src
Former-commit-id: 43bc977a970a5bba09d0afa6f2a85169fe1ed253
This commit is contained in:
170
Auto_Ctrl/predictor_burette.py
Normal file
170
Auto_Ctrl/predictor_burette.py
Normal file
@ -0,0 +1,170 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
from torch.autograd import Variable
|
||||
from torch import nn
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
import time
|
||||
import os
|
||||
from model import resnet34
|
||||
import json
|
||||
import serial
|
||||
import Find_COM
|
||||
|
||||
|
||||
def get_picture(cap): # 获取照片
|
||||
# 捕获一帧的数据
|
||||
ret, frame = cap.read()
|
||||
if frame is None:
|
||||
print(frame)
|
||||
if ret:
|
||||
# 默认不阻塞
|
||||
cv2.imshow("picture", frame)
|
||||
cv2.waitKey(1)
|
||||
# 数据帧写入图片中
|
||||
label = "1"
|
||||
timeStamp = 1381419600
|
||||
image_name = str(int(time.time())) + ".jpg"
|
||||
# 照片存储位置
|
||||
filepath = "Input/" + image_name # 改成跟上面一样的位置
|
||||
str_name = filepath.replace('%s', label)
|
||||
cv2.imwrite(str_name, frame) # 将照片保存起来
|
||||
return image_name
|
||||
|
||||
|
||||
def start_move_1(port, baudrate): # 快速加酸程序
|
||||
ser = serial.Serial(port, baudrate)
|
||||
data = b"q1h15d" # 每分钟转15圈,一个比慢滴略高的旋转速度
|
||||
ser.write(data)
|
||||
time.sleep(0.01)
|
||||
data = b"q5h1d" # 转1秒
|
||||
ser.write(data)
|
||||
time.sleep(0.01)
|
||||
data = b"q6h3d" # 逆时针
|
||||
ser.write(data)
|
||||
time.sleep(20) # 等待20秒,为快滴时间,注意,这里等待时间不能少于1秒(阀门开启的时间)
|
||||
data = b"q6h2d" # 顺时针
|
||||
ser.write(data)
|
||||
time.sleep(1) # 转回去的时间
|
||||
ser.close()
|
||||
|
||||
|
||||
def start_move_2(port, baudrate): # 缓慢加酸程序
|
||||
ser = serial.Serial(port, baudrate)
|
||||
data = b"q1h14d" # 每分钟转14圈,每秒转14/60=0.233圈,可以根据实际情况调整
|
||||
ser.write(data)
|
||||
time.sleep(0.01)
|
||||
# 注意:实际上我们也可以修改速度的小数部分,如下列注释所示
|
||||
# data = b"q2h50d" # 结合q1指令,每分钟转14.5圈,可以根据实际情况调整
|
||||
# ser.write(data)
|
||||
# time.sleep(0.01)
|
||||
data = b"q5h1d" # 转1秒
|
||||
ser.write(data)
|
||||
time.sleep(0.01)
|
||||
data = b"q6h3d" # 逆时针
|
||||
ser.write(data)
|
||||
time.sleep(1) # 转阀门的时间
|
||||
# 注意,这里没有将阀门转回去,而是持续几秒钟滴加一次的状态
|
||||
ser.close()
|
||||
|
||||
|
||||
def start_move_3(port, baudrate): # 停止加酸程序
|
||||
ser = serial.Serial(port, baudrate)
|
||||
data = b"q1h14d" # 每分钟转14圈,需要与move2的速度保持一致
|
||||
ser.write(data)
|
||||
time.sleep(0.01)
|
||||
data = b"q5h1d" # 转1秒
|
||||
ser.write(data)
|
||||
time.sleep(0.01)
|
||||
data = b"q6h2d" # 顺时针
|
||||
ser.write(data)
|
||||
time.sleep(1) # 转阀门的时间
|
||||
# 将阀门转回去
|
||||
ser.close()
|
||||
|
||||
|
||||
def main():
|
||||
# port = "COM6" # 串口名,根据实际情况修改
|
||||
port = Find_COM.list_ch340_ports()[0] # 串口名,根据实际情况修改
|
||||
baudrate = 9600 # 波特率,根据实际情况修改
|
||||
# # 快速滴加过程,这里请自己根据滴加量优化
|
||||
# start_move_1(port, baudrate)
|
||||
# time.sleep(15)
|
||||
videoSourceIndex = 0 # 摄像机编号,请根据自己的情况调整
|
||||
cap = cv2.VideoCapture(videoSourceIndex, cv2.CAP_DSHOW) # 打开摄像头
|
||||
# 是否用GPU
|
||||
device = torch.device( "cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
start_move_2(port, baudrate) # 开启慢滴状态
|
||||
while True:
|
||||
# 读取图片
|
||||
name = get_picture(cap)
|
||||
# 图片完整路径
|
||||
im_file = 'Input/' + name
|
||||
# 使用PIL库打开图片
|
||||
image = Image.open(im_file)
|
||||
# print(type(image)) # 打印图片的类型
|
||||
# 定义图片预处理流程
|
||||
data_transform = transforms.Compose(
|
||||
[
|
||||
# 调整图片大小为256x256
|
||||
transforms.Resize(256),
|
||||
# 从中心裁剪出224x224大小的图片
|
||||
transforms.CenterCrop(224),
|
||||
# 将图片转换为PyTorch的Tensor格式
|
||||
transforms.ToTensor(),
|
||||
# 对图片进行归一化,使用ImageNet的均值和标准差
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
# [N, C, H, W]
|
||||
img = data_transform(image)
|
||||
# expand batch dimension
|
||||
img = torch.unsqueeze(img, dim=0)
|
||||
# 定义模型权重文件的路径
|
||||
json_path = './class_indices.json'
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
class_indict = json.load(f)
|
||||
# create model
|
||||
model = resnet34(num_classes=2).to(device)
|
||||
|
||||
# load model weights
|
||||
weights_path = "./resnet34-1Net.pth"
|
||||
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
|
||||
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
|
||||
|
||||
# prediction
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# predict class
|
||||
output = torch.squeeze(model(img.to(device))).cpu()
|
||||
# 对预测结果进行softmax,得到每个类别的概率
|
||||
predict = torch.softmax(output, dim=0)
|
||||
# 找到概率最大的类别的索引
|
||||
predict_cla = torch.argmax(predict).numpy()
|
||||
# 根据索引从类别字典中获取类别名称
|
||||
class_a = "{}".format(class_indict[str(predict_cla)])
|
||||
# 格式化概率值,保留三位小数
|
||||
prob_a = "{:.3}".format(predict[predict_cla].numpy())
|
||||
# 将概率值转换为浮点数
|
||||
prob_b = float(prob_a)
|
||||
# 打印预测的类别和概率
|
||||
print(class_a)
|
||||
print(prob_b)
|
||||
if class_a == "orange" and prob_b >= 0.5: # 到达滴定终点
|
||||
# 关闭阀门
|
||||
start_move_3(port, baudrate)
|
||||
print('----->>End<<-----')
|
||||
print(im_file)
|
||||
time.sleep(1)
|
||||
# 释放摄像头
|
||||
cap.release()
|
||||
# 关闭所有OpenCV窗口
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
break
|
||||
time.sleep(1) # 拍照间隔
|
||||
|
||||
|
||||
if True:
|
||||
main()
|
Reference in New Issue
Block a user