Files
ai-titration/Picture_Train/src_predict.py
flt6 3605b974ab add source single test for train
Former-commit-id: bbbc19eaa23049a6edabf93731f799f391e9623a
2025-05-16 15:48:59 +08:00

40 lines
1.3 KiB
Python

from PIL import Image
from torchvision import transforms
import torch
import os
import json
from model import resnet34
# Checkpoint holding the trained ResNet-34 weights; the path is resolved
# against the current working directory.
_WEIGHTS_PATH = "resnet34-1Net.pth"

print("init model")
# Two-class classifier, kept on the CPU for inference.
model = resnet34(num_classes=2).to("cpu")
# Raise explicitly instead of using `assert`: assertions are stripped under
# `python -O`, which would defer the failure to torch.load.
if not os.path.exists(_WEIGHTS_PATH):
    raise FileNotFoundError(
        "file: '{}' does not exist.".format(_WEIGHTS_PATH)
    )
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load(_WEIGHTS_PATH, map_location="cpu"))
print("load model done")
def _load_input_tensor(im_file):
    """Open *im_file* and return it as a normalized 1x3x224x224 batch tensor."""
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # The context manager releases the file handle that Image.open keeps open.
    # convert("RGB") returns an independent, fully loaded copy and forces three
    # channels, so grayscale / palette / RGBA inputs match the 3-channel
    # Normalize statistics instead of failing.
    with Image.open(im_file) as image:
        rgb_image = image.convert("RGB")
    # Add the leading batch dimension expected by the model.
    return torch.unsqueeze(preprocess(rgb_image), dim=0)


def predictor(im_file):
    """Predict the class of a single image with the module-level model.

    Args:
        im_file: Image location passed straight to ``PIL.Image.open``.

    Returns:
        A ``(class_name, probability)`` tuple. ``class_name`` is the entry of
        ``class_indices.json`` keyed by the stringified index of the
        top-scoring class; ``probability`` is that class's softmax score as a
        ``float`` rounded to three significant digits.

    Raises:
        KeyError: If ``class_indices.json`` has no entry for the predicted
            class index.
    """
    batch = _load_input_tensor(im_file)

    # The label map is read from the current working directory on every call.
    with open("class_indices.json", "r") as f:
        class_indict = json.load(f)

    model.eval()  # inference mode for the forward pass
    with torch.no_grad():
        # Drop the batch dimension so softmax runs over the class scores.
        logits = torch.squeeze(model(batch)).cpu()
        probs = torch.softmax(logits, dim=0)

    # Plain Python scalars instead of 0-d numpy arrays for indexing/formatting.
    top_idx = int(torch.argmax(probs))
    class_a = "{}".format(class_indict[str(top_idx)])
    # Round-trip through "{:.3}" to keep three significant digits.
    prob_b = float("{:.3}".format(probs[top_idx].item()))
    return class_a, prob_b