Files
ai-titration/Picture_Train/pth_to_pt.py
flt6 2c2822ff11 src
Former-commit-id: 43bc977a970a5bba09d0afa6f2a85169fe1ed253
2025-05-15 23:49:13 +08:00

32 lines
1.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
该程序使用的是resnet34网络用到其他网络可自行更改
保存的权重字典目录如下所示。
ckpt = {
'weight': model.state_dict(),
'epoch': epoch,
'cfg': opt.model,
'index': name
}
"""
from model import resnet34 # 确保引用你的正确模型架构
import torch
import torch.nn as nn
# 假设你的ResNet定义在resnet.py文件中
model = resnet34()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2) # 修改这里为你的类别数
# 加载权重
checkpoint = torch.load('resnet34-1Net.pth', map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), weights_only=True)
# print(checkpoint.keys())
# 根据实际情况修改键名
model.load_state_dict(checkpoint, strict=False) # 使用strict=False可以忽略不匹配的键
model.eval()
# 将模型转换为TorchScript
example_input = torch.rand(1, 3, 32, 32) # 修改这里以匹配你的模型输入尺寸
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("resnet34-1Net.pt")
print('Finished Model Convertion')