32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
"""
|
||
该程序使用的是resnet34网络,用到其他网络可自行更改
|
||
保存的权重字典目录如下所示。
|
||
ckpt = {
|
||
'weight': model.state_dict(),
|
||
'epoch': epoch,
|
||
'cfg': opt.model,
|
||
'index': name
|
||
}
|
||
"""
|
||
from model import resnet34 # 确保引用你的正确模型架构
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
# 假设你的ResNet定义在resnet.py文件中
|
||
model = resnet34()
|
||
num_ftrs = model.fc.in_features
|
||
model.fc = nn.Linear(num_ftrs, 2) # 修改这里为你的类别数
|
||
|
||
# 加载权重
|
||
checkpoint = torch.load('resnet34-1Net.pth', map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), weights_only=True)
|
||
# print(checkpoint.keys())
|
||
# 根据实际情况修改键名
|
||
model.load_state_dict(checkpoint, strict=False) # 使用strict=False可以忽略不匹配的键
|
||
|
||
model.eval()
|
||
# 将模型转换为TorchScript
|
||
example_input = torch.rand(1, 3, 32, 32) # 修改这里以匹配你的模型输入尺寸
|
||
traced_script_module = torch.jit.trace(model, example_input)
|
||
traced_script_module.save("resnet34-1Net.pt")
|
||
print('Finished Model Convertion')
|