src
Former-commit-id: 43bc977a970a5bba09d0afa6f2a85169fe1ed253
This commit is contained in:
61
Picture_Train/predict.py
Normal file
61
Picture_Train/predict.py
Normal file
@ -0,0 +1,61 @@
|
||||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from model import vgg
|
||||
|
||||
# 单张图片预测程序
|
||||
def main():
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
data_transform = transforms.Compose(
|
||||
[transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
# load image
|
||||
img_path = "C:/Picture_Train/test-data/pic_0032.jpg"
|
||||
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
|
||||
img = Image.open(img_path)
|
||||
plt.imshow(img)
|
||||
# [N, C, H, W]
|
||||
img = data_transform(img)
|
||||
# expand batch dimension
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
|
||||
# read class_indict
|
||||
json_path = './class_indices.json'
|
||||
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
class_indict = json.load(f)
|
||||
|
||||
# create model
|
||||
model = vgg(model_name="vgg16", num_classes=2).to(device)
|
||||
# load model weights
|
||||
weights_path = "C:/Picture_Train/resnet34Net.pth"
|
||||
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
|
||||
model.load_state_dict(torch.load(weights_path, map_location=device))
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# predict class
|
||||
output = torch.squeeze(model(img.to(device))).cpu()
|
||||
predict = torch.softmax(output, dim=0)
|
||||
predict_cla = torch.argmax(predict).numpy()
|
||||
|
||||
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
|
||||
predict[predict_cla].numpy())
|
||||
plt.title(print_res)
|
||||
for i in range(len(predict)):
|
||||
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
|
||||
predict[i].numpy()))
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Reference in New Issue
Block a user