Former-commit-id: 43bc977a970a5bba09d0afa6f2a85169fe1ed253
This commit is contained in:
2025-05-15 23:49:13 +08:00
commit 2c2822ff11
720 changed files with 2735 additions and 0 deletions

View File

@ -0,0 +1,31 @@
"""
该程序使用的是resnet34网络用到其他网络可自行更改
保存的权重字典目录如下所示。
ckpt = {
'weight': model.state_dict(),
'epoch': epoch,
'cfg': opt.model,
'index': name
}
"""
from model import resnet34 # 确保引用你的正确模型架构
import torch
import torch.nn as nn
# 假设你的ResNet定义在resnet.py文件中
model = resnet34()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2) # 修改这里为你的类别数
# 加载权重
checkpoint = torch.load('resnet34-1Net.pth', map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), weights_only=True)
# print(checkpoint.keys())
# 根据实际情况修改键名
model.load_state_dict(checkpoint, strict=False) # 使用strict=False可以忽略不匹配的键
model.eval()
# 将模型转换为TorchScript
example_input = torch.rand(1, 3, 32, 32) # 修改这里以匹配你的模型输入尺寸
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("resnet34-1Net.pt")
print('Finished Model Convertion')