40 lines
1.3 KiB
Python
40 lines
1.3 KiB
Python
from PIL import Image
|
|
from torchvision import transforms
|
|
import torch
|
|
import os
|
|
import json
|
|
from model import resnet34
|
|
|
|
# Checkpoint holding the trained weights; resolved relative to the current
# working directory.
_WEIGHTS_PATH = "resnet34-1Net.pth"

print("init model")
model = resnet34(num_classes=2).to("cpu")

# Validate with an explicit raise instead of `assert`: assert statements are
# removed when Python runs with -O, which would silently skip this check.
if not os.path.exists(_WEIGHTS_PATH):
    raise FileNotFoundError(
        "file: '{}' does not exist.".format(_WEIGHTS_PATH)
    )

# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
model.load_state_dict(torch.load(_WEIGHTS_PATH, map_location="cpu"))

print("load model done")
|
|
|
|
# Preprocessing applied to every input image. It does not depend on the input,
# so it is built once at import time instead of on every call.
_PREDICT_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predictor(im_file):
    """Predict the class of a single image.

    Args:
        im_file: Anything accepted by ``PIL.Image.open`` (for example a file
            path or a binary file object).

    Returns:
        A ``(class_name, probability)`` tuple. ``class_name`` is the entry of
        ``class_indices.json`` whose key is the predicted class index, as a
        string. ``probability`` is the softmax score of that class as a
        float, rounded to 3 significant digits.

    Raises:
        OSError: If the image or ``class_indices.json`` cannot be opened.
        KeyError: If the predicted index has no entry in
            ``class_indices.json``.
    """
    # Normalize is configured with 3-channel statistics, so force RGB:
    # grayscale, palette or RGBA inputs would otherwise not match.
    # convert() loads the pixel data, so the source handle can be closed
    # as soon as the block exits.
    with Image.open(im_file) as raw_image:
        image = raw_image.convert("RGB")

    # Add a leading batch dimension of size 1 for the model.
    batch = torch.unsqueeze(_PREDICT_TRANSFORM(image), dim=0)

    with open("class_indices.json", "r") as f:
        class_indict = json.load(f)

    model.eval()
    with torch.no_grad():
        # squeeze() drops the batch dimension added above.
        output = torch.squeeze(model(batch)).cpu()
        predict = torch.softmax(output, dim=0)
        class_idx = int(torch.argmax(predict).item())

    # The JSON mapping is keyed by the class index rendered as a string.
    class_a = "{}".format(class_indict[str(class_idx)])
    # "{:.3}" keeps 3 significant digits; round-trip through str to keep the
    # same rounded float the function has always returned.
    prob_b = float("{:.3}".format(predict[class_idx].item()))
    return class_a, prob_b