Files
ai-titration/utils.py
flt6 e0e1c649eb clean and format codebase
Former-commit-id: 5d0497ac67199a7ea475849a6ec3f28df46371cb
2025-07-07 18:52:50 +08:00

311 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import json
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import List, Literal, Optional
import numpy as np
import requests
@dataclass
class HistoryRecord:
    """A single sample stored in the sliding-window history."""

    timestamp: float  # time of the sample, on the same clock the caller passes to add_record
    state: Literal["transport", "middle", "about", "colored"]  # detected titration state
    rate: float
    volume: float
    image: np.ndarray  # frame associated with this sample


class History:
    """Sliding-window history of titration samples.

    Records older than ``max_window_size`` seconds (relative to the timestamp of
    the newest ``add_record`` call) are discarded. Once enough history has
    accumulated, the mean ``rate`` over the last ``base_time`` seconds is stored
    once in ``self.base``.
    """

    # Number of qualifying add_record calls skipped before the baseline is taken.
    _BASE_SETTLE_SAMPLES = 30

    def __init__(self, max_window_size: float = 5.0, base_time=5.0):
        """Create an empty history.

        Args:
            max_window_size: Maximum time span (seconds) of records kept.
            base_time: Time span (seconds) averaged to compute ``self.base``.
                If it is not strictly smaller than ``max_window_size``, the
                window is widened to ``base_time + 0.2``.
        """
        # The baseline is only taken once the oldest retained record is at least
        # base_time old, which requires the window to be strictly longer than
        # base_time. With equal values (the defaults) that could only happen on an
        # exact float tie, so the window is widened for equality as well.
        if base_time >= max_window_size:
            max_window_size = base_time + 0.2
        self.records: List[HistoryRecord] = []
        self.max_window_size = max_window_size
        self.fulled = False  # not used inside this class
        self.about_history = []  # recent "about"-state records; not used inside this class
        self.base: Optional[float] = None  # baseline mean rate, set once by add_record
        self._base_time = base_time
        self._base_cnt = 0  # qualifying add_record calls seen before the baseline is taken
        self.end_history: list[bool] = []
        self.last_end = 0

    def add_record(
        self,
        timestamp: float,
        state: str,
        rate: float,
        volume: float,
        image: np.ndarray,
    ):
        """Append a sample, prune stale ones and, when possible, compute the baseline.

        The baseline ``self.base`` is the mean ``rate`` of the records in the last
        ``base_time`` seconds. It is computed only once: after the oldest retained
        record is at least ``base_time`` old, the first ``_BASE_SETTLE_SAMPLES``
        such calls are skipped as a settling period, and the next one computes it.
        """
        self.records.append(HistoryRecord(timestamp, state, rate, volume, image))
        self._cleanup_old_records(timestamp)

        if self.base is not None:
            return
        # Not enough history yet to cover a full base_time span.
        if (timestamp - self._base_time) < self.records[0].timestamp:
            return
        if self._base_cnt < self._BASE_SETTLE_SAMPLES:
            self._base_cnt += 1
            return
        # The record just added is always inside this span, so it is non-empty.
        base_records = self.get_recent_records(self._base_time, timestamp)
        self.base = sum(rec.rate for rec in base_records) / len(base_records)
        get_endpoint_logger().info("Base rate calculated: %.2f", self.base)

    def _cleanup_old_records(self, current_time: float) -> bool:
        """Drop leading records older than ``max_window_size`` before ``current_time``.

        Only the leading run of stale records is removed (records are expected
        to be appended in time order).

        Returns:
            True if at least one record was removed, False otherwise
            (including when the history is empty).
        """
        cutoff_time = current_time - self.max_window_size
        stale = 0
        for record in self.records:
            if record.timestamp < cutoff_time:
                stale += 1
            else:
                break
        if not stale:
            return False
        # One slice deletion instead of repeated O(n) list.pop(0) calls.
        del self.records[:stale]
        return True

    def get_records_in_timespan(
        self, start_time: float, end_time: Optional[float] = None
    ) -> List[HistoryRecord]:
        """Return records with ``start_time <= timestamp`` (and ``<= end_time`` if given)."""
        if end_time is None:
            return [record for record in self.records if record.timestamp >= start_time]
        return [
            record
            for record in self.records
            if start_time <= record.timestamp <= end_time
        ]

    def get_recent_records(
        self, duration: float, current_time: float
    ) -> List[HistoryRecord]:
        """Return records within the last ``duration`` seconds up to ``current_time``."""
        return self.get_records_in_timespan(current_time - duration, current_time)

    def get_state_ratio(
        self, target_state: str, records: Optional[List[HistoryRecord]] = None
    ) -> float:
        """Return the fraction of ``records`` (default: all) whose state is ``target_state``.

        Returns 0.0 for an empty record list.
        """
        if records is None:
            records = self.records
        if not records:
            return 0.0
        target_count = sum(1 for record in records if record.state == target_state)
        return target_count / len(records)

    def get_states_by_type(self, target_state: str) -> List[float]:
        """Return the timestamps of all records whose state is ``target_state``."""
        return [
            record.timestamp for record in self.records if record.state == target_state
        ]

    def find_record_by_timestamp(
        self, target_timestamp: float
    ) -> Optional[HistoryRecord]:
        """Return the first record whose timestamp equals ``target_timestamp``, else None."""
        return next(
            (record for record in self.records if record.timestamp == target_timestamp),
            None,
        )

    def is_empty(self) -> bool:
        """Return True if no records are stored."""
        return not self.records
class State:
    """Titration mode state (FAST / SLOW / ABOUT) plus check flags and timing."""

    class Mode(Enum):
        FAST = 0  # fast mode
        SLOW = 1  # slow mode (entered on a "middle" detection)
        ABOUT = 2  # approaching the end point

    def __init__(self, bounce_time=1):
        """
        Args:
            bounce_time: Seconds that must elapse after a middle detection before
                the middle-exit check (and the ``MIDCHK`` status) applies.
        """
        self.mode = self.Mode.FAST
        self.bounce_time = bounce_time
        # Check flags (in_end_check / about_check are not changed by any method
        # below except exit_about; presumably driven by callers).
        self.in_middle_check = False
        self.in_end_check = False
        self.about_check = False
        self.about_first_flag = False
        # Time at which the current middle detection happened, or None.
        self.middle_detected_time: Optional[float] = None

    def is_fast_mode(self):
        """Return True if in FAST mode."""
        return self.mode == self.Mode.FAST

    def is_slow_mode(self):
        """Return True if in SLOW mode."""
        return self.mode == self.Mode.SLOW

    def is_about_mode(self):
        """Return True if in ABOUT mode."""
        return self.mode == self.Mode.ABOUT

    def enter_middle_state(self, current_time):
        """Enter the middle state: switch to SLOW mode and start the middle check."""
        self.mode = self.Mode.SLOW
        self.in_middle_check = True
        self.middle_detected_time = current_time

    def enter_about_state(self, current_time):
        """Move from SLOW to ABOUT mode; no-op in any other mode.

        ``current_time`` is accepted for interface symmetry but not used.
        """
        if self.mode == self.Mode.SLOW:
            self.mode = self.Mode.ABOUT

    def exit_middle_check(self):
        """Leave the middle check and return to FAST mode."""
        self.in_middle_check = False
        self.middle_detected_time = None
        self.mode = self.Mode.FAST

    def exit_about(self):
        """Leave the about state: clear about_check, set about_first_flag, ABOUT -> SLOW."""
        self.about_check = False
        self.about_first_flag = True
        if self.mode == self.Mode.ABOUT:
            self.mode = self.Mode.SLOW

    def should_check_middle_exit(self, current_time):
        """Return True if the middle-exit check is due.

        That is: a middle check is active, more than ``bounce_time`` has passed
        since the detection, and the mode is still SLOW.
        """
        return (
            self.in_middle_check
            and self.middle_detected_time is not None
            and current_time - self.middle_detected_time > self.bounce_time
            and (self.mode == self.Mode.SLOW)
        )

    def get_status_text(self):
        """Return extra status text for display: ``", MIDCHK"`` or ``""``.

        Compares against ``time.time()``, so this assumes the ``current_time``
        given to ``enter_middle_state`` is on the same clock — NOTE(review): confirm.
        """
        status = []
        current_time = time.time()
        # Guard against None: in_middle_check may be set without a detection time.
        if (
            self.in_middle_check
            and self.middle_detected_time is not None
            and current_time - self.middle_detected_time > self.bounce_time
        ):
            status.append("MIDCHK")
        return ", " + ", ".join(status) if status else ""
def login_to_platform(username, password):
    """Log in to the platform and return its auth token.

    Args:
        username: Platform account name (sent as ``userName``).
        password: Platform account password.

    Returns:
        The ``token`` field of the server response when its ``code`` is 1,
        otherwise ``None``. Failures are reported with ``print``, not raised.
    """
    url = "https://jingsai.mools.net/api/login"
    form = {"userName": username, "password": password}
    try:
        response = requests.post(url, form, timeout=2)
    except requests.RequestException as e:
        # requests signals connection errors/timeouts by raising, never by
        # returning None, so network failures are reported here.
        print("错误", "网络连接失败 ", e)
        return None

    try:
        result = json.loads(response.text)
        code = result["code"]
        if code == 1:
            # Login succeeded: extract the token from the server response.
            token = result["token"]
            print("成功", "登录成功!")
            return token
        if code == 2:
            # Wrong username or password.
            print("错误", "用户名或密码错误!")
        else:
            print("错误", "登陆失败 ")
        return None
    except (ValueError, KeyError, TypeError) as e:
        # Non-JSON body or a response missing the expected fields.
        print("错误", f"发送数据时出错:{e}")
        return None
def send_data_to_platform(token, data, picture, timeout=30):
    """Upload a titration record, with its final image, to the platform.

    Args:
        token: Bearer token (as returned by ``login_to_platform``).
        data: JSON-serialisable dict of record fields. NOTE: it is modified in
            place — the base64-encoded image is stored under ``final_image``.
        picture: Path of the image file to attach.
        timeout: Seconds to wait for the server; ``None`` waits forever.

    Raises:
        OSError: if ``picture`` cannot be read.
        requests.RequestException: on network errors or timeout.
        TypeError: if ``data`` is not JSON-serialisable.

    Success or server-side failure is reported with ``print``.
    """
    # Read the image and embed it as Base64 text.
    with open(picture, "rb") as picture_file:
        encoded_picture = base64.b64encode(picture_file.read()).decode("utf-8")
    data["final_image"] = encoded_picture

    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    url = "https://jingsai.mools.net/api/upload-record"
    # Without a timeout, requests can block indefinitely on an unresponsive server.
    response = requests.post(
        url, headers=headers, data=json.dumps(data), timeout=timeout
    )

    try:
        result = json.loads(response.text)
    except ValueError:
        # Non-JSON body (e.g. an HTML error page): report it as a failure below.
        result = None
    if isinstance(result, dict) and result.get("code") == 1:
        print("提交成功", "提交成功")
    else:
        print(
            "错误",
            f"网络请求失败,状态码:{response.status_code}\n错误信息:{response.text}",
        )
def setup_logging(log_level=logging.INFO, log_dir="logs"):
    """Configure the root logger to log to a timestamped file and the console.

    Args:
        log_level: Root logger level; defaults to ``logging.INFO``.
        log_dir: Directory for log files (created if missing); defaults to ``"logs"``.

    Returns:
        Path of the log file, ``<log_dir>/titration_YYYYmmdd_HHMMSS.log``.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers. In that case the file handler is closed again instead of being
        leaked, and the returned file stays empty.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"titration_{timestamp}.log")

    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(name)8s - %(levelname)7s - %(message)s",
        handlers=[
            file_handler,
            logging.StreamHandler(),  # also log to the console
        ],
    )
    if file_handler not in logging.getLogger().handlers:
        # basicConfig was a no-op; release the file opened by FileHandler.
        file_handler.close()
    return log_file