Files
ai-titration/Auto_Ctrl/predictor_Syringe_Pump.py
flt6 2c2822ff11 src
Former-commit-id: 43bc977a970a5bba09d0afa6f2a85169fe1ed253
2025-05-15 23:49:13 +08:00

281 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch import nn
import matplotlib.pyplot as plt
import cv2
import time
import os
from model import resnet34
import serial
from datetime import datetime
from scipy.optimize import curve_fit
import numpy as np
import re
import json
import Find_COM
import builtins
class MAT:
    """Automated titration controller driven by a camera-based colour classifier.

    Combines a syringe pump (first CH340 serial port), an optional USB serial
    device used for potential readings, a camera, and a ResNet-34 classifier.
    Titrant is dispensed step by step; after each step a frame is captured,
    saved and classified, and the run stops once the end-point colour is seen.
    """

    def __init__(self, videoSourceIndex=0, weights_path = "resnet34-1Net.pth", json_path = 'class_indices.json', classes = 2):
        """Open the camera and serial ports and reset the experiment state.

        Args:
            videoSourceIndex: camera index passed to ``cv2.VideoCapture``.
            weights_path: classifier weights file, relative to the current
                working directory.
            json_path: JSON file mapping class index (as a string) to class
                name, relative to the current working directory.
            classes: number of output classes of the classifier.

        Raises:
            RuntimeError: if no CH340 serial port (the pump) is found.
        """
        print('实验初始化中')
        # Pre-set every handle so close()/__del__ stay safe even when
        # initialisation fails part-way through.
        self.cap = None
        self.pump_ser = None
        self.usb_ser = None
        self._model = None
        self._class_indict = None
        self.data_root = os.getcwd()
        self.videoSourceIndex = videoSourceIndex  # camera index
        self.cap = cv2.VideoCapture(videoSourceIndex, cv2.CAP_DSHOW)  # open the camera
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        ch340_ports = Find_COM.list_ch340_ports()
        if not ch340_ports:
            raise RuntimeError("No CH340 serial port found for the syringe pump.")
        self.port = ch340_ports[0]  # pump serial port name
        self.pump_ser = serial.Serial(self.port, 9600)
        self.usb_port = Find_COM.list_USB_ports()  # optional measurement-device port
        if self.usb_port:
            self.usb_ser = serial.Serial(self.usb_port, 115200)
        self.classes = classes
        self.total_volume = 0  # cumulative dispensed volume (ml)
        self.now_volume = 0  # volume currently left in the syringe (ml)
        self.volume_list = []  # total volume after each step
        self.voltage_list = []  # potential readings (only if recorded)
        self.color_list = []  # 1 when the step hit the end-point colour, else 0
        self.start_time = time.time()
        self.weights_path = os.path.join(self.data_root, weights_path)
        self.json_path = os.path.join(self.data_root, json_path)
        # Start time as YYYYmmdd_HHMMSS; every output file is named after it.
        self.formatted_time = datetime.fromtimestamp(self.start_time).strftime('%Y%m%d_%H%M%S')
        print("实验开始于", self.formatted_time)

    def get_picture(self, frame, typ=0, date=''):
        """Save *frame* as ``<data_root>/Input/<date>_<total_volume>.jpg``.

        Args:
            frame: image as returned by ``cv2.VideoCapture.read``.
            typ: unused; kept for backward compatibility.
            date: file-name prefix (run() passes the formatted start time).

        Returns:
            The bare file name (not the full path) of the saved image.

        Raises:
            ValueError: if *frame* is None.
            OSError: if OpenCV fails to write the image.
        """
        if frame is None:
            raise ValueError("get_picture() received no frame (frame is None).")
        image_name = f'{date}_{self.total_volume}.jpg'
        input_dir = os.path.join(self.data_root, "Input")
        # cv2.imwrite only returns False (no exception) when the folder is missing.
        os.makedirs(input_dir, exist_ok=True)
        filepath = os.path.join(input_dir, image_name)
        if not cv2.imwrite(filepath, frame):
            raise OSError(f"Failed to write image to {filepath}")
        return image_name

    def _send_commands(self, commands, final_delay):
        """Write each pump command, sleeping 0.01 s between writes and *final_delay* s after the last."""
        last = len(commands) - 1
        for i, command in enumerate(commands):
            self.pump_ser.write(command)
            time.sleep(final_delay if i == last else 0.01)

    def start_move_1(self):
        """Run the pump's draw (refill) program and block 15 s while it completes."""
        # The original note on q1h24d reads "*2"; presumably 24 = 2 x the 12 ml
        # that run() credits per refill -- TODO confirm against the pump manual.
        self._send_commands([b"q1h24d", b"q2h0d", b"q4h0d", b"q5h15d", b"q6h3d"], final_delay=15)
        print('完成抽取')

    @staticmethod
    def _speed_command_parts(speed):
        """Split ``speed * 30`` into (integer part, hundredths 0-99) for the pump.

        Rounds to the nearest hundredth instead of truncating: float products
        such as ``0.03 * 30`` land just below the exact value, and truncation
        would then encode 0.89 instead of 0.90.
        """
        hundredths = int(round(speed * 30 * 100))
        return divmod(hundredths, 100)

    def start_move_2(self, speed=0.1):
        """Run one dispensing (feed) step and wait 1 s.

        Args:
            speed: step size; run() adds it to the dispensed volume in ml.
                ``speed * 30`` is sent to the controller with the integer part
                in command q1 and the hundredths in command q2.
        """
        speed_int, speed_hundredths = self._speed_command_parts(speed)
        self._send_commands(
            [
                f"q1h{speed_int}d".encode('ascii'),
                f"q2h{speed_hundredths}d".encode('ascii'),
                b"q4h0d",
                b"q5h1d",
                b"q6h2d",  # start feeding
            ],
            final_delay=1,
        )

    def start_move_3(self):
        """Send the feed emergency-stop command to the pump."""
        data = b"q6h6d"
        self.pump_ser.write(data)

    def voltage(self):
        """Request one potential reading from the USB measurement device.

        Sends the ``VOL|`` query (newline-terminated), waits 0.1 s, then reads
        lines until a non-empty one arrives.

        Returns:
            The reading as a float, or 0 if the reply is not a number.

        Raises:
            RuntimeError: if no USB measurement device was found at start-up.
        """
        usb_ser = getattr(self, 'usb_ser', None)
        if usb_ser is None:
            raise RuntimeError("No USB measurement device connected; cannot read voltage.")
        usb_ser.write("VOL|\n".encode())
        time.sleep(0.1)
        while True:
            response = usb_ser.readline().decode().strip()
            if not response:
                continue
            try:
                return float(response)
            except ValueError:
                return 0

    @staticmethod
    def poly_func(x, a, b, c, d):
        """Sigmoid model ``a * tanh(d*x + b) + c``; its inflection point is at x = -b/d."""
        return a * np.tanh(d * x + b) + c

    def line_chart(self):
        """Plot voltage and colour against volume, fit the voltage curve, and save the figure.

        The voltage data are fitted with poly_func; the inflection point -b/d
        is printed as the potential jump and marked on a second-derivative
        axis. The figure is saved to ``Output/<formatted_time>.png``.
        """
        x = self.volume_list
        y = self.voltage_list
        z = self.color_list
        # run() leaves voltage_list empty unless voltage recording is enabled;
        # plotting lists of different lengths would raise, so skip voltage then.
        has_voltage = len(y) > 0 and len(y) == len(x)
        fig, ax1 = plt.subplots()
        plt.title("titration curve")
        color = 'tab:red'
        ax1.set_xlabel('value')
        ax1.set_ylabel('voltage', color=color)
        if has_voltage:
            ax1.plot(x, y, color=color, antialiased=True)
        ax1.tick_params(axis='y', labelcolor=color)
        ax2 = ax1.twinx()
        color = 'tab:blue'
        ax2.set_ylabel('color', color=color)
        ax2.plot(x, z, color=color)
        ax2.tick_params(axis='y', labelcolor=color)
        ax2.set_yticks([0, 1])
        ax2.set_yticklabels(['yellow', 'orange'])
        ax2.spines['right'].set_position(('outward', 60))
        if has_voltage:
            try:
                popt, pcov = curve_fit(self.poly_func, x, y, p0=[max(y) * 3 / 4, -max(x), max(y), 1.5])
                print("最优参数:", popt)
                print(f'电位突跃点:{-popt[1] / popt[3]:.3f}')
                x_d = np.arange(0, max(x), 0.05)
                y_fit = self.poly_func(x_d, *popt)
                # Differentiate with respect to volume, not array index.
                dE_dV = np.gradient(y_fit, x_d)
                d2E_dV2 = np.gradient(dE_dV, x_d)
                y2 = d2E_dV2.tolist()
                ax3 = ax1.twinx()
                color = 'tab:green'
                ax3.set_ylabel('2nd Derivative', color=color)
                ax3.plot(x_d, y2, color=color)
                ax3.tick_params(axis='y', labelcolor=color)
                ax3.grid(True, linestyle='--', linewidth=0.5, color='gray', axis='both')
                x_d, y_d = -popt[1] / popt[3], 0.0
                ax3.plot(x_d, y_d, 'ro')
                ax3.annotate(f'({x_d:.2f})', xy=(x_d, y_d), color='red', xytext=(x_d - 1, y_d + max(y2) / 10))
            except Exception as e:
                # Best effort: a failed fit must not prevent saving the raw plot.
                print(e)
        fig.tight_layout()
        os.makedirs('Output', exist_ok=True)
        plt.savefig(f'Output/{self.formatted_time}.png')
        plt.show()
        plt.pause(1)
        plt.close()

    def _load_model(self):
        """Load the classifier and class-index mapping on first use; return the cached pair."""
        if getattr(self, '_model', None) is None:
            with open(self.json_path, "r") as f:
                self._class_indict = json.load(f)
            if not os.path.exists(self.weights_path):
                raise FileNotFoundError("file: '{}' does not exist.".format(self.weights_path))
            model = resnet34(num_classes=self.classes).to(self.device)
            model.load_state_dict(torch.load(self.weights_path, map_location=self.device))
            model.eval()
            self._model = model
        return self._model, self._class_indict

    def predictor(self, im_file):
        """Classify the image at *im_file* with the ResNet-34 colour model.

        The model and class mapping are loaded once and cached on the instance
        rather than re-read from disk for every frame.

        Returns:
            (class_name, probability) with probability rounded to 3 significant
            figures.

        Raises:
            FileNotFoundError: if the weights file does not exist.
        """
        data_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Context manager closes the file handle; convert guarantees 3 channels.
        with Image.open(im_file) as image:
            img = data_transform(image.convert('RGB'))
        img = torch.unsqueeze(img, dim=0)
        model, class_indict = self._load_model()
        with torch.no_grad():
            output = torch.squeeze(model(img.to(self.device))).cpu()
            predict = torch.softmax(output, dim=0)
            predict_cla = torch.argmax(predict).numpy()
        class_a = "{}".format(class_indict[str(predict_cla)])
        prob_b = float("{:.3}".format(predict[predict_cla].numpy()))
        print('class_:', class_a)
        print('prob_:', prob_b)
        return class_a, prob_b

    def close(self):
        """Close the serial ports, release the camera and OpenCV windows (idempotent)."""
        if getattr(self, '_closed', False):
            return
        self._closed = True
        for ser in (getattr(self, 'pump_ser', None), getattr(self, 'usb_ser', None)):
            if ser is not None:
                ser.close()
        cap = getattr(self, 'cap', None)
        if cap is not None:
            cap.release()
            cv2.destroyAllWindows()
        print("Experiment finished.")

    def __del__(self):
        # Plotting the titration curve here is intentionally disabled:
        # self.line_chart()
        self.close()

    def _save_data(self):
        """Print the recorded lists and dump them to ``Output/<formatted_time>.json``."""
        print("Volume List:", self.volume_list)
        print("Voltage List:", self.voltage_list)
        print("Color List:", self.color_list)
        os.makedirs('Output', exist_ok=True)
        with builtins.open(f'Output/{self.formatted_time}.json', 'w') as f:
            json.dump(
                {"volume_list": self.volume_list, 'voltage_list': self.voltage_list, 'color_list': self.color_list},
                f)

    def run(self, quick_speed = 0.2, slow_speed = 0.05, switching_point = 5, end_kind = 'orange', end_prob = 0.5):
        """Titrate step by step until the classifier reports the end-point colour.

        Each iteration refills the syringe when it is empty, dispenses one step
        (quick_speed while total_volume < switching_point, slow_speed after),
        captures and classifies a frame, and records the volume and a colour
        flag. The loop stops at the first frame classified as *end_kind* with
        probability > *end_prob*, or when the camera fails. The recorded data
        are saved to JSON even if the loop is interrupted by an exception.
        """
        try:
            while True:
                if self.now_volume <= 0:
                    self.start_move_1()  # draw 12 ml
                    self.now_volume += 12
                speed = quick_speed if self.total_volume < switching_point else slow_speed
                self.start_move_2(speed)
                # Round both tallies so repeated float +/- cannot drift, e.g.
                # leaving ~1e-15 ml "in the syringe" and skipping a refill.
                self.total_volume = round(self.total_volume + speed, 3)
                self.now_volume = round(self.now_volume - speed, 3)
                ret, frame = self.cap.read()
                if not ret:
                    print("Failed to capture frame from camera.")
                    break
                name = self.get_picture(frame, 0, self.formatted_time)
                # Same directory get_picture() wrote to.
                im_file = os.path.join(self.data_root, "Input", name)
                cv2.imshow('Color', frame)
                cv2.waitKey(1)
                class_a, prob_b = self.predictor(im_file)
                self.volume_list.append(self.total_volume)
                # With a voltage measurement device, record the potential here:
                # self.voltage_list.append(self.voltage())
                if class_a == end_kind and prob_b > end_prob:  # end point reached
                    print('----->>Visual Endpoint<<-----')
                    print(f"Total Volume: {self.total_volume} ml")
                    print(f"Image File: {im_file}")
                    self.color_list.append(1)
                    break
                self.color_list.append(0)
                print(f"Current Total Volume: {self.total_volume} ml")
        finally:
            self._save_data()
if __name__ == "__main__":
    import warnings

    # Silence every warning for the interactive experiment run.
    warnings.filterwarnings('ignore')

    # Experiment configuration: hardware/model setup, then titration parameters.
    init_kwargs = dict(
        videoSourceIndex=0,
        weights_path="resnet34-1Net.pth",
        json_path='class_indices.json',
        classes=2,
    )
    run_kwargs = dict(
        quick_speed=0.2,
        slow_speed=0.05,
        switching_point=5,
        end_kind='orange',
        end_prob=0.5,
    )

    mat = MAT(**init_kwargs)
    mat.run(**run_kwargs)