add source single test for train
Former-commit-id: bbbc19eaa23049a6edabf93731f799f391e9623a
This commit is contained in:
40
Picture_Train/src_predict.py
Normal file
40
Picture_Train/src_predict.py
Normal file
@ -0,0 +1,40 @@
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
from model import resnet34
|
||||
|
||||
print("init model")
|
||||
model = resnet34(num_classes=2).to("cpu")
|
||||
|
||||
assert os.path.exists("resnet34-1Net.pth"), "file: '{}' dose not exist.".format("resnet34-1Net.pth")
|
||||
model.load_state_dict(torch.load("resnet34-1Net.pth", map_location="cpu"))
|
||||
|
||||
print("load model done")
|
||||
|
||||
def predictor(im_file): # 预测分类
|
||||
image = Image.open(im_file)
|
||||
data_transform = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
img = data_transform(image)
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
with open("class_indices.json", "r") as f:
|
||||
class_indict = json.load(f)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
output = torch.squeeze(model(img)).cpu()
|
||||
predict = torch.softmax(output, dim=0)
|
||||
predict_cla = torch.argmax(predict).numpy()
|
||||
|
||||
class_a = "{}".format(class_indict[str(predict_cla)])
|
||||
prob_a = "{:.3}".format(predict[predict_cla].numpy())
|
||||
prob_b = float(prob_a)
|
||||
# print('class_:',class_a)
|
||||
# print('prob_:',prob_b)
|
||||
return class_a, prob_b
|
Reference in New Issue
Block a user