add source single test for train

Former-commit-id: bbbc19eaa23049a6edabf93731f799f391e9623a
This commit is contained in:
2025-05-16 15:48:59 +08:00
parent a6b9531df5
commit 3605b974ab
4 changed files with 244 additions and 3 deletions

View File

@ -6,7 +6,7 @@ from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import vgg
from model_train import vgg
# 单张图片预测程序
def main():
@ -18,7 +18,7 @@ def main():
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# load image
img_path = "C:/Picture_Train/test-data/pic_0032.jpg"
img_path = "data/val/yellow/ticket_0_3.jpg"
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path)
plt.imshow(img)
@ -37,7 +37,7 @@ def main():
# create model
model = vgg(model_name="vgg16", num_classes=2).to(device)
# load model weights
weights_path = "C:/Picture_Train/resnet34Net.pth"
weights_path = "resnet34-1Net.pth"
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
model.load_state_dict(torch.load(weights_path, map_location=device))