add source single test for train
Former-commit-id: bbbc19eaa23049a6edabf93731f799f391e9623a
This commit is contained in:
@ -6,7 +6,7 @@ from PIL import Image
|
||||
from torchvision import transforms
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from model import vgg
|
||||
from model_train import vgg
|
||||
|
||||
# 单张图片预测程序
|
||||
def main():
|
||||
@ -18,7 +18,7 @@ def main():
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
# load image
|
||||
img_path = "C:/Picture_Train/test-data/pic_0032.jpg"
|
||||
img_path = "data/val/yellow/ticket_0_3.jpg"
|
||||
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
|
||||
img = Image.open(img_path)
|
||||
plt.imshow(img)
|
||||
@ -37,7 +37,7 @@ def main():
|
||||
# create model
|
||||
model = vgg(model_name="vgg16", num_classes=2).to(device)
|
||||
# load model weights
|
||||
weights_path = "C:/Picture_Train/resnet34Net.pth"
|
||||
weights_path = "resnet34-1Net.pth"
|
||||
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
|
||||
model.load_state_dict(torch.load(weights_path, map_location=device))
|
||||
|
||||
|
Reference in New Issue
Block a user