src
Former-commit-id: 43bc977a970a5bba09d0afa6f2a85169fe1ed253
This commit is contained in:
31
Picture_Train/pth_to_pt.py
Normal file
31
Picture_Train/pth_to_pt.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
该程序使用的是resnet34网络,用到其他网络可自行更改
|
||||
保存的权重字典目录如下所示。
|
||||
ckpt = {
|
||||
'weight': model.state_dict(),
|
||||
'epoch': epoch,
|
||||
'cfg': opt.model,
|
||||
'index': name
|
||||
}
|
||||
"""
|
||||
from model import resnet34 # 确保引用你的正确模型架构
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# 假设你的ResNet定义在resnet.py文件中
|
||||
model = resnet34()
|
||||
num_ftrs = model.fc.in_features
|
||||
model.fc = nn.Linear(num_ftrs, 2) # 修改这里为你的类别数
|
||||
|
||||
# 加载权重
|
||||
checkpoint = torch.load('resnet34-1Net.pth', map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), weights_only=True)
|
||||
# print(checkpoint.keys())
|
||||
# 根据实际情况修改键名
|
||||
model.load_state_dict(checkpoint, strict=False) # 使用strict=False可以忽略不匹配的键
|
||||
|
||||
model.eval()
|
||||
# 将模型转换为TorchScript
|
||||
example_input = torch.rand(1, 3, 32, 32) # 修改这里以匹配你的模型输入尺寸
|
||||
traced_script_module = torch.jit.trace(model, example_input)
|
||||
traced_script_module.save("resnet34-1Net.pt")
|
||||
print('Finished Model Convertion')
|
Reference in New Issue
Block a user