init trans

Former-commit-id: 32f405e2fbc6096ba6f7a8b3f55e743bf4550d78
This commit is contained in:
2025-05-18 11:33:07 +08:00
parent 2e1734b811
commit e27e5d6129

View File

@ -13,7 +13,7 @@ import json
import Find_COM import Find_COM
from threading import Thread from threading import Thread
LOCAL_DEBUG = True LOCAL_DEBUG = False
if LOCAL_DEBUG: if LOCAL_DEBUG:
print("WARNING: Local debug mode is enabled. Serial communication will be skipped.") print("WARNING: Local debug mode is enabled. Serial communication will be skipped.")
@ -116,9 +116,10 @@ class MAT:
return None return None
ret = self.my_predictor(im) ret = self.my_predictor(im)
# print(ret)
if ret is None: if ret is None:
cv2.imwrite("tmp.jpg",im) print("Failed")
return self.predictor("tmp.jpg") self.thr = Thread(target=self._pred).start()
else: else:
if ret == self.end_kind: if ret == self.end_kind:
print("Stop at ",self.total_volume) print("Stop at ",self.total_volume)
@ -129,22 +130,18 @@ class MAT:
return ret,0.9 return ret,0.9
def my_predictor(self,im): def my_predictor(self,im):
model = self.model # im = cv2.imread(file)
ret = self.preproc(im) hsv = cv2.cvtColor(im,cv2.COLOR_BGR2HSV)
if ret is None: s = hsv[:,:,1]
return None mask = s>100
arr = np.array(ret) # print(mask)
if len(arr.shape) == 1: tot = mask.shape[0]*mask.shape[1]
arr = arr.reshape(1, -1) val = np.sum(mask)
# print(val/tot)
# 进行预测并转换为类别名 if val<tot*0.3:
pred_labels = model.predict(arr) return "transport"
# pred_classes = [label_map[label] for label in pred_labels]
mp = ["orange", "yellow"]
if len(pred_labels) == 1:
return mp[pred_labels[0]]
else: else:
return None return "colored"
def predictor(self, im_file): # 预测分类 def predictor(self, im_file): # 预测分类
image = Image.open(im_file) image = Image.open(im_file)
@ -229,7 +226,6 @@ class MAT:
print(f"Total Volume: {self.total_volume} ml") print(f"Total Volume: {self.total_volume} ml")
# print(f"Image File: {im_file}") # print(f"Image File: {im_file}")
print("Volume List:", self.volume_list) print("Volume List:", self.volume_list)
print("Voltage List:", self.voltage_list)
print("Color List:", self.color_list) print("Color List:", self.color_list)
@ -239,12 +235,12 @@ if __name__ == "__main__":
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
# 创建MAT类的实例并运行 # 创建MAT类的实例并运行
mat = MAT(videoSourceIndex = 0, weights_path = "resnet34-1Net.pth", json_path = 'class_indices.json', classes = 2) mat = MAT(videoSourceIndex = 1, weights_path = "resnet34-1Net.pth", json_path = 'class_indices.json', classes = 2)
mat.run( mat.run(
quick_speed = 0.3, quick_speed = 0.3,
slow_speed = 0.2, slow_speed = 0.2,
expect = 11.2, expect = 10,
end_kind = 'orange', end_kind = 'colored',
) )