"""Load a two-class ResNet-34 checkpoint at import time and expose ``predictor``."""
import json
import os

import torch
from PIL import Image
from torchvision import transforms

from model import resnet34

# Checkpoint and label-map files, resolved relative to the current working directory.
WEIGHTS_PATH = "resnet34-1Net.pth"
CLASS_INDICES_PATH = "class_indices.json"

# Preprocessing pipeline, built once at import instead of on every predictor() call.
# The mean/std literals are the standard ImageNet values (presumably what the
# checkpoint was trained with -- confirm against the training script).
_DATA_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

print("init model")
model = resnet34(num_classes=2).to("cpu")
if not os.path.exists(WEIGHTS_PATH):
    # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
    raise FileNotFoundError("file: '{}' does not exist.".format(WEIGHTS_PATH))
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
print("load model done")


def predictor(im_file):
    """Classify a single image with the module-level ``model``.

    Args:
        im_file: A path or file object accepted by ``PIL.Image.open``.

    Returns:
        A ``(class_name, probability)`` tuple. ``class_name`` is the label that
        ``class_indices.json`` maps to the arg-max class index (keyed by the
        index as a string). ``probability`` is that class's softmax score as a
        float, rounded to 3 significant digits.

    Raises:
        OSError / FileNotFoundError: if the image or the label file cannot be opened.
        KeyError: if the predicted index is missing from the label file.
    """
    # Convert to RGB so grayscale / RGBA / palette images produce the 3 channels
    # that Normalize and the model expect. ``with`` releases the file handle.
    with Image.open(im_file) as image:
        img = _DATA_TRANSFORM(image.convert("RGB"))
    img = torch.unsqueeze(img, dim=0)  # add a batch dimension of size 1

    # The label file is re-read on each call (as originally), so edits to it are picked up.
    with open(CLASS_INDICES_PATH, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    # Re-assert eval mode in case other code has switched the shared model to training mode.
    model.eval()
    with torch.no_grad():
        # Presumably (1, num_classes) -> (num_classes,) after squeezing the batch dim.
        output = torch.squeeze(model(img)).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(predict))

    class_a = "{}".format(class_indict[str(predict_cla)])
    # "{:.3}" keeps 3 significant digits; the str round-trip reproduces the original rounding.
    prob_b = float("{:.3}".format(predict[predict_cla].item()))
    return class_a, prob_b