281 lines
9.9 KiB
Python
281 lines
9.9 KiB
Python
import torch
|
||
from PIL import Image
|
||
import torchvision.transforms as transforms
|
||
from torch.autograd import Variable
|
||
from torch import nn
|
||
import matplotlib.pyplot as plt
|
||
import cv2
|
||
import time
|
||
import os
|
||
from model import resnet34
|
||
import serial
|
||
from datetime import datetime
|
||
from scipy.optimize import curve_fit
|
||
import numpy as np
|
||
import re
|
||
import json
|
||
import Find_COM
|
||
import builtins
|
||
|
||
|
||
class MAT:
    """Automated colour-endpoint titration with a syringe pump and a camera.

    Each step dispenses a small volume of titrant through the pump on the
    CH340 serial port, photographs the flask, classifies the image colour with
    a ResNet-34 model and stops once the target colour class is recognised
    with sufficient probability. The volume, optional potential and colour
    flag of every step are recorded and saved to ``Output/<start time>.json``.
    """

    def __init__(self, videoSourceIndex=0, weights_path = "resnet34-1Net.pth", json_path = 'class_indices.json', classes = 2):
        """Open the camera and serial ports and reset the experiment state.

        Args:
            videoSourceIndex: OpenCV camera index (opened with the DirectShow backend).
            weights_path: model weight file, relative to the current working directory.
            json_path: JSON file mapping class-index strings to class names,
                relative to the current working directory.
            classes: number of output classes of the model.

        Raises:
            RuntimeError: if no CH340 serial port (the pump) is found.
        """
        print('实验初始化中')

        # Defaults first, so __del__ can clean up if anything below fails.
        self.cap = None
        self.pump_ser = None
        self.usb_ser = None
        self._model = None
        self._class_indict = None
        self._transform = None

        self.data_root = os.getcwd()
        self.videoSourceIndex = videoSourceIndex  # camera index
        self.cap = cv2.VideoCapture(videoSourceIndex, cv2.CAP_DSHOW)  # open the camera
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        ch340_ports = Find_COM.list_ch340_ports()
        if not ch340_ports:
            raise RuntimeError("No CH340 serial port found for the syringe pump.")
        self.port = ch340_ports[0]  # pump serial port name
        self.pump_ser = serial.Serial(self.port, 9600)

        # Optional second serial device, used by voltage() for potential readings.
        self.usb_port = Find_COM.list_USB_ports()
        if self.usb_port:
            self.usb_ser = serial.Serial(self.usb_port, 115200)

        self.classes = classes
        self.total_volume = 0  # cumulative dispensed volume (ml)
        self.now_volume = 0  # volume currently left in the syringe (ml)
        self.volume_list = []  # cumulative volume after each step
        self.voltage_list = []  # potential after each step (only if a meter is used)
        self.color_list = []  # 1 at the step where the endpoint colour was detected, else 0
        self.start_time = time.time()
        self.weights_path = os.path.join(self.data_root, weights_path)
        self.json_path = os.path.join(self.data_root, json_path)

        # Start time as YYYYmmdd_HHMMSS; all output file names are based on it.
        self.formatted_time = datetime.fromtimestamp(self.start_time).strftime('%Y%m%d_%H%M%S')

        print("实验开始于", self.formatted_time)

    def _input_path(self, image_name):
        """Return the full path of *image_name* inside ``<data_root>/Input``."""
        # NOTE(review): legacy substitution of a literal '%s' in the path, kept unchanged.
        return os.path.join(self.data_root, "Input", image_name).replace('%s', '1')

    def get_picture(self, frame, typ=0, date=''):
        """Save *frame* as ``Input/<date>_<total_volume>.jpg`` under ``data_root``.

        Args:
            frame: image array as returned by ``cv2.VideoCapture.read``.
            typ: unused; kept for backward compatibility.
            date: file-name prefix (``run`` passes the experiment start time).

        Returns:
            The bare file name (not the full path) of the saved image.

        Raises:
            ValueError: if *frame* is None.
            OSError: if OpenCV could not write the file.
        """
        if frame is None:
            raise ValueError("get_picture() received an empty frame (None).")

        image_name = f'{date}_{self.total_volume}.jpg'
        # cv2.imwrite does not create missing folders; it just returns False.
        os.makedirs(os.path.join(self.data_root, "Input"), exist_ok=True)
        filepath = self._input_path(image_name)
        if not cv2.imwrite(filepath, frame):
            raise OSError(f"Failed to write image to {filepath}")
        return image_name

    def _send(self, command, delay=0.01):
        """Write one pump command (bytes or ASCII str), then sleep *delay* seconds."""
        if isinstance(command, str):
            command = command.encode('ascii')
        self.pump_ser.write(command)
        time.sleep(delay)

    def start_move_1(self):
        """Refill the syringe (aspirate), blocking 15 s while the stroke runs."""
        # NOTE(review): the meanings of the q1/q2/q4/q5/q6 registers come from the
        # pump controller protocol, which is not visible here. q1h24d is annotated "*2".
        for command in (b"q1h24d", b"q2h0d", b"q4h0d", b"q5h15d"):
            self._send(command)
        self._send(b"q6h3d", delay=15)
        print('完成抽取')

    @staticmethod
    def _split_speed(speed):
        """Convert a dispense volume into the controller's (integer, hundredths) pair.

        The volume is scaled by 30 (a controller factor not documented here)
        and rounded to the nearest hundredth. Truncating instead would turn
        binary float artefacts into real errors, e.g. 0.11 * 30 is
        3.2999999999999998 and would be sent as 3 + 29/100 instead of 3 + 30/100.

        >>> MAT._split_speed(0.05)
        (1, 50)
        """
        hundredths = int(round(speed * 30 * 100))
        return divmod(hundredths, 100)

    def start_move_2(self, speed=0.1):
        """Dispense one increment of titrant (feed stroke), then wait 1 s.

        Args:
            speed: volume to dispense in this step (``run`` treats it as ml).
        """
        whole, hundredths = self._split_speed(speed)
        self._send(f"q1h{whole}d")
        self._send(f"q2h{hundredths}d")
        self._send(b"q4h0d")
        self._send(b"q5h1d")
        self._send(b"q6h2d", delay=1)  # start feeding

    def start_move_3(self):
        """Emergency-stop the feed stroke."""
        self.pump_ser.write(b"q6h6d")

    def voltage(self):
        """Query the USB meter once and return the reading.

        Sends ``VOL|``, skips blank response lines, and returns the first
        non-empty line as a float, or 0 if that line is not a number.

        Raises:
            RuntimeError: if no USB meter port was found at start-up.
        """
        usb_ser = getattr(self, 'usb_ser', None)
        if usb_ser is None:
            raise RuntimeError("No USB measurement port is connected.")

        usb_ser.write("VOL|\n".encode())
        time.sleep(0.1)
        while True:
            response = usb_ser.readline().decode().strip()
            if not response:
                continue
            try:
                return float(response)
            except ValueError:
                return 0

    @staticmethod
    def poly_func(x, a, b, c, d):
        """Sigmoid titration-curve model ``a * tanh(d * x + b) + c``.

        Its inflection point (used as the equivalence point) is at x = -b / d.
        """
        return a * np.tanh(d * x + b) + c

    def line_chart(self):
        """Plot the recorded titration data and save it to ``Output/<start time>.png``.

        Plots volume vs. colour flag. When a potential was recorded for every
        step, also plots volume vs. potential and tries a tanh fit. On success
        it draws the fitted curve's 2nd derivative and marks the fitted
        equivalence point (-b / d). A failed fit is printed and skipped.
        """
        x = self.volume_list
        y = self.voltage_list
        z = self.color_list
        # voltage_list stays empty unless a meter is read in run(); plotting or
        # fitting mismatched lists would raise.
        has_voltage = len(y) > 0 and len(y) == len(x)

        fig, ax1 = plt.subplots()
        plt.title("titration curve")

        color = 'tab:red'
        ax1.set_xlabel('volume')
        ax1.set_ylabel('voltage', color=color)
        if has_voltage:
            ax1.plot(x, y, color=color, antialiased=True)
        ax1.tick_params(axis='y', labelcolor=color)

        ax2 = ax1.twinx()
        color = 'tab:blue'
        ax2.set_ylabel('color', color=color)
        ax2.plot(x, z, color=color)
        ax2.tick_params(axis='y', labelcolor=color)
        ax2.set_yticks([0, 1])
        ax2.set_yticklabels(['yellow', 'orange'])
        ax2.spines['right'].set_position(('outward', 60))

        if has_voltage:
            try:
                popt, _ = curve_fit(self.poly_func, x, y, p0=[max(y) * 3 / 4, -max(x), max(y), 1.5])
                print("最优参数:", popt)
                endpoint = -popt[1] / popt[3]
                print(f'电位突跃点:{endpoint:.3f}')

                x_fit = np.arange(0, max(x), 0.05)
                y_fit = self.poly_func(x_fit, *popt)
                # Pass the sample coordinates so the derivatives are per ml, not per sample.
                d2E_dV2 = np.gradient(np.gradient(y_fit, x_fit), x_fit)

                ax3 = ax1.twinx()
                color = 'tab:green'
                ax3.set_ylabel('2nd Derivative', color=color)
                ax3.plot(x_fit, d2E_dV2, color=color)
                ax3.tick_params(axis='y', labelcolor=color)
                ax3.grid(True, linestyle='--', linewidth=0.5, color='gray', axis='both')

                ax3.plot(endpoint, 0.0, 'ro')
                ax3.annotate(f'({endpoint:.2f})', xy=(endpoint, 0.0), color='red',
                             xytext=(endpoint - 1, d2E_dV2.max() / 10))
            except Exception as e:  # best effort: the fit may not converge or have too few points
                print(e)

        fig.tight_layout()
        output_dir = os.path.join(self.data_root, 'Output')
        os.makedirs(output_dir, exist_ok=True)
        plt.savefig(os.path.join(output_dir, f'{self.formatted_time}.png'))
        plt.show()
        plt.pause(1)
        plt.close()

    def _load_model(self):
        """Load the class map, image transform and classifier once, then reuse them.

        Returns:
            (model, class_indict, transform).

        Raises:
            FileNotFoundError: if the class map or the weight file is missing.
        """
        if getattr(self, '_model', None) is None:
            with open(self.json_path, "r", encoding="utf-8") as f:
                self._class_indict = json.load(f)

            self._transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])

            model = resnet34(num_classes=self.classes).to(self.device)
            if not os.path.exists(self.weights_path):
                raise FileNotFoundError("file: '{}' dose not exist.".format(self.weights_path))
            model.load_state_dict(torch.load(self.weights_path, map_location=self.device))
            model.eval()
            self._model = model
        return self._model, self._class_indict, self._transform

    def predictor(self, im_file):
        """Classify the colour of the image stored at *im_file*.

        Returns:
            (class_name, probability). The probability is the softmax score of
            the predicted class, rounded to 3 significant digits.
        """
        model, class_indict, data_transform = self._load_model()

        with Image.open(im_file) as image:
            img = data_transform(image.convert('RGB'))
        img = torch.unsqueeze(img, dim=0)  # add the batch dimension

        with torch.no_grad():
            output = torch.squeeze(model(img.to(self.device))).cpu()
            predict = torch.softmax(output, dim=0)
            predict_cla = int(torch.argmax(predict))

        class_a = "{}".format(class_indict[str(predict_cla)])
        prob_b = float("{:.3}".format(predict[predict_cla].item()))
        print('class_:', class_a)
        print('prob_:', prob_b)
        return class_a, prob_b

    def __del__(self):
        """Close the serial ports and the camera; safe on partially built objects."""
        # The titration curve can be drawn on exit by calling self.line_chart() here.
        for attr in ('pump_ser', 'usb_ser'):
            port = getattr(self, attr, None)
            if port is not None:
                port.close()

        cap = getattr(self, 'cap', None)
        if cap is not None:
            cap.release()
            cv2.destroyAllWindows()

        print("Experiment finished.")

    def _save_data(self):
        """Write the recorded lists to ``Output/<start time>.json`` under ``data_root``."""
        output_dir = os.path.join(self.data_root, 'Output')
        os.makedirs(output_dir, exist_ok=True)
        with builtins.open(os.path.join(output_dir, f'{self.formatted_time}.json'), 'w') as f:
            json.dump(
                {"volume_list": self.volume_list, 'voltage_list': self.voltage_list, 'color_list': self.color_list},
                f)

    def run(self,quick_speed = 0.2, slow_speed = 0.05,switching_point = 5, end_kind = 'orange', end_prob =0.5):
        """Titrate step by step until the endpoint colour is recognised.

        Each step dispenses ``quick_speed`` ml while the total volume is below
        ``switching_point`` ml, and ``slow_speed`` ml afterwards. The syringe
        is refilled (12 ml) whenever it is empty. After each step a frame is
        saved and classified. The loop stops when the predicted class equals
        ``end_kind`` with probability above ``end_prob``, or when the camera
        fails. The recorded lists are written to JSON even if the loop raises.
        """
        empty_tolerance = 1e-9  # absorbs float drift from repeated subtraction

        try:
            while True:
                if self.now_volume <= empty_tolerance:
                    self.start_move_1()  # draw 12 ml into the syringe
                    self.now_volume += 12

                speed = quick_speed if self.total_volume < switching_point else slow_speed
                self.start_move_2(speed)
                self.now_volume -= speed
                self.total_volume = round(self.total_volume + speed, 3)

                ret, frame = self.cap.read()
                if not ret:
                    print("Failed to capture frame from camera.")
                    break

                name = self.get_picture(frame, 0, self.formatted_time)
                im_file = self._input_path(name)  # same path get_picture wrote to

                cv2.imshow('Color', frame)
                cv2.waitKey(1)

                class_a, prob_b = self.predictor(im_file)
                self.volume_list.append(self.total_volume)

                # With a potential meter attached, record a reading here:
                # self.voltage_list.append(self.voltage())

                if class_a == end_kind and prob_b > end_prob:  # endpoint reached
                    print('----->>Visual Endpoint<<-----')
                    print(f"Total Volume: {self.total_volume} ml")
                    print(f"Image File: {im_file}")
                    self.color_list.append(1)
                    break
                self.color_list.append(0)

                print(f"Current Total Volume: {self.total_volume} ml")
                print("Volume List:", self.volume_list)
                print("Voltage List:", self.voltage_list)
                print("Color List:", self.color_list)
        finally:
            self._save_data()
|
||
|
||
|
||
if __name__ == "__main__":
    import warnings

    # Silence every warning category so the console shows only experiment output.
    warnings.simplefilter('ignore')

    # Build the experiment (camera 0, two-class colour model) and titrate to the orange endpoint.
    experiment = MAT(
        videoSourceIndex=0,
        weights_path="resnet34-1Net.pth",
        json_path='class_indices.json',
        classes=2,
    )
    experiment.run(
        quick_speed=0.2,
        slow_speed=0.05,
        switching_point=5,
        end_kind='orange',
        end_prob=0.5,
    )
|
||
|
||
|