373 lines
12 KiB
Python
373 lines
12 KiB
Python
import torch
|
||
from PIL import Image
|
||
import torchvision.transforms as transforms
|
||
from torch.autograd import Variable
|
||
from torch import nn
|
||
import matplotlib.pyplot as plt
|
||
import cv2
|
||
import time
|
||
import os
|
||
from model import resnet34
|
||
import json
|
||
import serial
|
||
from datetime import datetime
|
||
from scipy.optimize import curve_fit
|
||
import numpy as np
|
||
import re
|
||
import json
|
||
import Find_COM
|
||
|
||
|
||
def get_picture(frame, typ=0, date=''): # 获取照片
|
||
|
||
# 捕获一帧的数据
|
||
# ret, frame = cap.read()
|
||
if frame is None:
|
||
print(frame)
|
||
# if ret:
|
||
# # 默认不阻塞
|
||
# cv2.imshow("picture", frame)
|
||
# 数据帧写入图片中
|
||
label = "1"
|
||
timeStamp = 1381419600
|
||
if typ:
|
||
image_name = f'{date}{int(time.time())}.jpg'
|
||
else:
|
||
image_name = f'{date}PH{int(time.time())}.jpg'
|
||
# 照片存储位置
|
||
filepath = "Input/" + image_name # 改成跟上面一样的位置
|
||
str_name = filepath.replace('%s', label)
|
||
cv2.imwrite(str_name, frame) # 将照片保存起来
|
||
|
||
return image_name
|
||
|
||
|
||
def start_move_1(ser): # 抽取原料
|
||
# 注意:这里我们将控制器的模式切换成了20ml注射泵模式,由于丝杠的区别,这里设定的速度为实际速度(ml/min)的一半
|
||
data = b"q1h12d" # 每分钟加样24ml
|
||
ser.write(data)
|
||
time.sleep(0.01)
|
||
data = b"q4h0d" # 转0分钟
|
||
ser.write(data)
|
||
time.sleep(0.01)
|
||
data = b"q5h30d" # 转30秒
|
||
ser.write(data) # 合计抽取12ml
|
||
time.sleep(0.01)
|
||
data = b"q6h3d" # 抽取
|
||
ser.write(data)
|
||
time.sleep(30) # 等待抽取
|
||
print('完成抽取')
|
||
# ser.close()
|
||
|
||
|
||
def start_move_2(ser): # 缓慢加样程序
|
||
data = b"q1h1d" # 每分钟加样3ml,每秒0.1ml
|
||
ser.write(data)
|
||
time.sleep(0.01)
|
||
data = b"q2h50d" # 每分钟加样6ml,每秒0.1ml
|
||
ser.write(data)
|
||
time.sleep(0.01)
|
||
data = b"q4h30d" # 转0分钟
|
||
ser.write(data)
|
||
time.sleep(0.01)
|
||
data = b"q5h0d" # 转1秒
|
||
ser.write(data) # 合计进样12ml
|
||
time.sleep(0.01)
|
||
data = b"q6h2d" # 进样
|
||
ser.write(data)
|
||
time.sleep(1)
|
||
# 注意,这里没有将阀门转回去,而是持续几秒钟滴加一次的状态
|
||
# ser.close()
|
||
|
||
|
||
def start_move_4(ser): # 缓慢加样程序0.2
|
||
data = b"q1h6d" # 每分钟加样12ml,每秒0.2ml
|
||
ser.write(data)
|
||
time.sleep(0.01)
|
||
data = b"q2h0d" # 每分钟加样6ml,每秒0.1ml
|
||
ser.write(data)
|
||
time.sleep(0.01)
|
||
data = b"q4h30d" # 转30分钟
|
||
ser.write(data)
|
||
time.sleep(0.01)
|
||
data = b"q5h0d" # 转0秒
|
||
ser.write(data) # 持续进样
|
||
time.sleep(0.01)
|
||
data = b"q6h2d" # 进样
|
||
ser.write(data)
|
||
time.sleep(1)
|
||
# 注意,这里没有将阀门转回去,而是持续几秒钟滴加一次的状态
|
||
# ser.close()
|
||
|
||
|
||
def start_move_3(ser): # 停止加酸程序
|
||
data = b"q6h6d" # 停止指令
|
||
ser.write(data)
|
||
# 将阀门转回去
|
||
ser.close()
|
||
|
||
|
||
def read_number_new(filepath):
|
||
from paddleocr import PaddleOCR, draw_ocr
|
||
# 创建一个OCR实例,配置语言为中文
|
||
ocr = PaddleOCR(use_angle_cls=True, lang="ch")
|
||
# 对图片进行OCR识别
|
||
img_path = filepath
|
||
result = ocr.ocr(img_path, cls=True)
|
||
print('-----------------------------------------')
|
||
print(result)
|
||
ans = []
|
||
for line in result:
|
||
if line:
|
||
for line1 in line:
|
||
# print(line1[-1])
|
||
# print(line1[-1][0])
|
||
try:
|
||
ans.append(float(line1[-1][0]))
|
||
except:
|
||
continue
|
||
print(ans)
|
||
if not ans:
|
||
ans.append(10)
|
||
return ans
|
||
|
||
|
||
# 定义反正切函数
|
||
def poly_func(x, a, b, c, d):
|
||
return a * np.tanh(d * x + b) + c
|
||
|
||
|
||
def line_chart(date="1", volume_list=[], voltage_list=[], color_list=[]):
|
||
x = volume_list
|
||
y = voltage_list
|
||
z = color_list
|
||
|
||
# '''
|
||
fig, ax1 = plt.subplots()
|
||
plt.title("titration curve")
|
||
# 绘制第一个Y轴的数据,绘制电位曲线
|
||
color = 'tab:red'
|
||
ax1.set_xlabel('value')
|
||
ax1.set_ylabel('voltage', color=color)
|
||
ax1.plot(x, y, color=color, antialiased=True)
|
||
ax1.tick_params(axis='y', labelcolor=color)
|
||
|
||
# 创建一个共享X轴的第二个Y轴,绘制颜色曲线
|
||
ax2 = ax1.twinx()
|
||
color = 'tab:blue'
|
||
ax2.set_ylabel('color', color=color)
|
||
print(x,z)
|
||
ax2.plot(x, z, color=color)
|
||
ax2.tick_params(axis='y', labelcolor=color)
|
||
ax2.set_yticks([0, 1]) # 设置Y轴的刻度位置
|
||
ax2.set_yticklabels(['yellow', 'orange']) # 设置Y轴的刻度标签
|
||
ax2.spines['right'].set_position(('outward', 60)) # 将第三个Y轴向右移动
|
||
# ax2.tick_params(axis='y', labelcolor='none')
|
||
try:
|
||
# 初始参数估计
|
||
popt, pcov = curve_fit(poly_func, x, y, p0=[max(y)*3/4, -max(x), max(y), 1.5])
|
||
|
||
# 打印最优参数
|
||
print("最优参数:", popt)
|
||
print(f'电位突跃点:{-popt[1]/popt[3]:.3f}')
|
||
# print(max(x))
|
||
x_d = np.arange(0, max(x), 0.05)
|
||
# 使用拟合得到的参数计算二阶导数
|
||
y_fit = poly_func(x_d, *popt)
|
||
# 计算一阶微商(即电位对体积的导数)
|
||
dE_dV = np.gradient(y_fit)
|
||
# 计算二阶微商
|
||
d2E_dV2 = np.gradient(dE_dV)
|
||
y2 = d2E_dV2.tolist()
|
||
# y2 = dE_dV.tolist()
|
||
# y2 = y_fit
|
||
# 创建一个共享X轴的第3个Y轴,绘制二阶导
|
||
ax3 = ax1.twinx()
|
||
color = 'tab:green'
|
||
|
||
ax3.set_ylabel('2nd Derivative', color=color)
|
||
ax3.plot(x_d, y2, color=color)
|
||
# ax3.plot(x_d, y_fit, color=color)
|
||
ax3.tick_params(axis='y', labelcolor=color)
|
||
ax3.grid(True, linestyle='--', linewidth=0.5, color='gray', axis='both')
|
||
|
||
# 画出电位突变点
|
||
x_d, y_d = -popt[1]/popt[3], 0.0
|
||
ax3.plot(x_d, y_d, 'ro') # 'ro' 表示红色圆圈,'r' 表示红色,'o' 表示圆圈4\4
|
||
|
||
# 标注坐标
|
||
ax3.annotate(f'({x_d:.2f})', # 标注的文本,使用格式化字符串显示坐标
|
||
xy=(x_d, y_d), # 标注指向的点
|
||
color='red', # 标注文本的颜色
|
||
xytext=(x_d-1, y_d + max(y2)/10) # 标注文本的位置,这里相对于点的位置稍微偏移
|
||
)
|
||
# 画出视觉突变点
|
||
x_c, y_c = x[xz] - 0.025, 0.0
|
||
ax3.plot(x_c, y_c, 'bo') # 'bo' 表示蓝色圆圈,'b' 表示蓝色,'o' 表示圆圈4\4
|
||
# 标注坐标
|
||
ax3.annotate(f'({x_c:.2f})', # 标注的文本,使用格式化字符串显示坐标
|
||
xy=(x_c, y_c), # 标注指向的点
|
||
color='blue', # 标注文本的颜色
|
||
xytext=(x_c - 1, y_c - max(y2) / 10) # 标注文本的位置,这里相对于点的位置稍微偏移
|
||
)
|
||
print(f"视觉突跃点:{x_c:.3f}")
|
||
# '''
|
||
except Exception as e:
|
||
print(e)
|
||
pass
|
||
|
||
fig.tight_layout() # 自动调整子图参数, 使之填充整个图像区域
|
||
plt.savefig(f'Output/{date}.png')
|
||
# plt.savefig('O1.png')
|
||
plt.show()
|
||
plt.pause(1)
|
||
plt.close()
|
||
|
||
|
||
def predictor(im_file, device):
|
||
# 使用PIL库打开图片
|
||
image = Image.open(im_file)
|
||
# 定义图片预处理流程
|
||
data_transform = transforms.Compose(
|
||
[
|
||
# 调整图片大小为256x256
|
||
transforms.Resize(256),
|
||
# 从中心裁剪出224x224大小的图片
|
||
transforms.CenterCrop(224),
|
||
# 将图片转换为PyTorch的Tensor格式
|
||
transforms.ToTensor(),
|
||
# 对图片进行归一化,使用ImageNet的均值和标准差
|
||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||
])
|
||
# [N, C, H, W]
|
||
img = data_transform(image)
|
||
# expand batch dimension
|
||
img = torch.unsqueeze(img, dim=0)
|
||
# 定义模型分类文件的路径
|
||
json_path = './class_indices.json'
|
||
|
||
with open(json_path, "r") as f:
|
||
class_indict = json.load(f)
|
||
# create model
|
||
model = resnet34(num_classes=2).to(device) # 根据分类数量修改
|
||
|
||
# load model weights
|
||
weights_path = "./resnet34-1Net.pth" # 根据实际需要使用的模型名称修改
|
||
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
|
||
model.load_state_dict(torch.load(weights_path, map_location=device))
|
||
|
||
# prediction
|
||
model.eval()
|
||
with torch.no_grad():
|
||
# predict class
|
||
output = torch.squeeze(model(img.to(device))).cpu()
|
||
# 对预测结果进行softmax,得到每个类别的概率
|
||
predict = torch.softmax(output, dim=0)
|
||
# 找到概率最大的类别的索引
|
||
predict_cla = torch.argmax(predict).numpy()
|
||
# 根据索引从类别字典中获取类别名称
|
||
class_a = "{}".format(class_indict[str(predict_cla)])
|
||
# 格式化概率值,保留三位小数
|
||
prob_a = "{:.3}".format(predict[predict_cla].numpy())
|
||
# 将概率值转换为浮点数
|
||
prob_b = float(prob_a)
|
||
# 打印预测的类别和概率
|
||
print(class_a)
|
||
print(prob_b)
|
||
return class_a, prob_b
|
||
|
||
|
||
def voltage(ser):
|
||
# data = "VOL|" # 每分钟加样12ml,每秒0.2ml
|
||
ser.write("VOL|\n".encode())
|
||
time.sleep(0.1) # 等待设备响应
|
||
while True:
|
||
# 读取响应
|
||
response = ser.readline().decode().strip()
|
||
if response:
|
||
# print(f"设备响应: {response}")
|
||
# if "END" in response:
|
||
# break
|
||
try:
|
||
return float(response)
|
||
except:
|
||
pass
|
||
|
||
|
||
def main():
|
||
# port = "COM11" # 串口名,根据实际情况修改
|
||
port = Find_COM.list_ch340_ports()[0] # 串口名,根据实际情况修改
|
||
baudrate = 9600 # 波特率,根据实际情况修改
|
||
pump_ser = serial.Serial(port, baudrate)
|
||
# port_USB = []
|
||
port_USB = Find_COM.list_USB_ports() # 串口名,根据实际情况修改
|
||
if port_USB:
|
||
USB_ser = serial.Serial(port_USB[0], baudrate=115200, timeout=1)
|
||
# print(voltage(USB_ser))
|
||
videoSourceIndex = 0 # 摄像机编号,请根据自己的情况调整
|
||
cap = cv2.VideoCapture(videoSourceIndex, cv2.CAP_DSHOW) # 打开摄像头
|
||
# 是否用GPU
|
||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||
# 循环开始之前需要一个变量来记录初始状态 比如说就叫color_type
|
||
total_volume = 0
|
||
now_volume = 0
|
||
volume_list = []
|
||
voltage_list = []
|
||
color_list = []
|
||
start_time = time.time()
|
||
# 将时间戳转换为datetime对象
|
||
dt_object = datetime.fromtimestamp(start_time)
|
||
# 格式化datetime对象为字符串,该时间用于保存图像名称
|
||
formatted_time = dt_object.strftime('%Y%m%d_%H%M%S')
|
||
print("实验开始于", formatted_time)
|
||
n = 10
|
||
total_n = n
|
||
start_move_2(pump_ser)
|
||
while True:
|
||
total_volume += 1
|
||
volume_list.append(total_volume)
|
||
# 读取图片
|
||
ret, frame = cap.read()
|
||
name = get_picture(frame, 0, formatted_time)
|
||
|
||
# 图片完整路径
|
||
im_file = 'Input/' + name
|
||
cv2.imshow('Color', frame)
|
||
cv2.waitKey(1)
|
||
class_a ,prob_b =predictor(im_file,device)
|
||
volume_list.append(total_volume)
|
||
if port_USB:
|
||
voltage_list.append(voltage(USB_ser))
|
||
if class_a == "orange" and prob_b > 0.5: # 判断终点
|
||
# 如果判断为终点
|
||
# 使用两个空列表用来记录后续五次的判断结果
|
||
start_move_3(pump_ser)
|
||
print('----->>Visual Endpoint<<-----')
|
||
print(volume_list[-1])
|
||
print(im_file)
|
||
color_list.append(1)
|
||
break
|
||
color_list.append(0)
|
||
print(total_volume)
|
||
print(volume_list)
|
||
print(voltage_list)
|
||
print(color_list)
|
||
|
||
with open(f'Output/{formatted_time}.json', 'w') as f:
|
||
# 使用json.dump()将列表保存到文件
|
||
json.dump({"volume_list": volume_list, 'voltage_list': voltage_list, 'color_list': color_list}, f)
|
||
# 关闭串口
|
||
pump_ser.close()
|
||
if port_USB:
|
||
USB_ser.close()
|
||
line_chart(formatted_time, volume_list = volume_list, voltage_list = voltage_list, color_list = color_list)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import warnings
|
||
# 忽略所有警告
|
||
warnings.filterwarnings('ignore')
|
||
main()
|
||
|
||
|