Files
ai-titration/Picture_Train/main.py
flt6 b0670fa2ca enhance train and adapt main to train only
Former-commit-id: dd0f5b88b7a401eb41006cf357eff993ca661dd7
2025-05-17 23:46:34 +08:00

70 lines
2.0 KiB
Python

import cv2
import numpy as np
from pathlib import Path
from train_2 import train_model,prepare_data,predict
import joblib
# Image folders, each laid out as <class_name>/<image>.jpg.
inp = Path("data/train_new")  # training images
val = Path("data/val_new")  # validation images
# Folder for processed (debug) images; created eagerly at import time.
# parents=True so a fresh checkout without data/ does not crash here with
# FileNotFoundError.
proc = Path("data/proc")
proc.mkdir(parents=True, exist_ok=True)
def preproc(dir: Path, *, sat_threshold: int = 150) -> dict[str, list[tuple[int, int, int]]]:
    """Compute the mean HSV colour of the saturated region of every image.

    Images are read from ``dir/<class_name>/<image>.jpg``; the name of each
    image's parent folder is used as its class label.

    For each image, the pixels whose HSV saturation is strictly greater than
    ``sat_threshold`` are selected, and the rounded mean of their H, S and V
    channels is recorded.

    Args:
        dir: Root directory containing one sub-folder per class.
        sat_threshold: Saturation cut-off (exclusive) that selects the pixels
            to average. Defaults to 150, the previously hard-coded value.

    Returns:
        Mapping from class label to a list of ``(h, s, v)`` integer tuples,
        one per usable image. Unreadable images and images with no pixel
        above the threshold are reported on stdout and skipped; a class with
        no usable image does not appear in the result.
    """
    d: dict[str, list[tuple[int, int, int]]] = {}
    for file in dir.glob("*/*.jpg"):
        # str() keeps this working on OpenCV builds whose imread rejects
        # path-like objects.
        im = cv2.imread(str(file))
        if im is None:
            print(f"Error reading image: {file}")
            continue
        cl = file.parent.name
        hsv = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)
        mask = hsv[:, :, 1] > sat_threshold
        cnt = np.count_nonzero(mask)
        if cnt == 0:
            # Averaging over zero pixels divides by zero, and round(NaN)
            # raises ValueError, which would abort the whole run.
            print(f"No pixel with saturation > {sat_threshold}: {file}")
            continue
        # NOTE(review): OpenCV's 8-bit hue is circular (0-179); a plain mean
        # can land mid-range for reds straddling the wrap-around -- confirm
        # this is acceptable for the colours being classified.
        sums = hsv[mask].sum(axis=0)  # per-channel sums over selected pixels
        h, s, v = (round(float(total) / cnt) for total in sums)
        d.setdefault(cl, []).append((h, s, v))
    return d
# Script driver: preprocess images into per-class HSV features, train a model,
# then print the result of predicting on the same data.
# NOTE(review): the training-split preprocessing below is commented out, so the
# model is both trained on and evaluated against the `val` features; the
# printed `rate` is therefore an in-sample figure, not a held-out one.
# Confirm this "train only" mode is intended.
# d:dict[str,list[tuple[int,int,int]]] = preproc(inp)
# `val` is rebound here from the validation-folder Path to the feature dict
# returned by preproc: {class label: [(h, s, v), ...]}.
val:dict[str,list[tuple[int,int,int]]] = preproc(val)
print("数据预处理完成")  # "Data preprocessing complete"
model, label_map = train_model(val)
# Disabled: print "training complete" and persist/reload the model with joblib.
# print("训练完成")
# joblib.dump(model, "model.pkl")
# model = joblib.load("model.pkl")
# NOTE(review): X_train / y_train are only used by the commented-out score
# lines below; no live statement reads them, so this call looks like leftover.
X_train, y_train = prepare_data(val)
# predict() returns a pair; only the first element is used (printed below).
# What `rate` and `failed` hold is defined in train_2 -- presumably an accuracy
# and the misclassified samples; verify there.
rate,failed = predict(model, label_map, val)
print(rate)
# Disabled: "training-set accuracy" reports via model.score.
# print(predict(model, label_map, val))
# print(f"\n训练集准确率: {model.score(X_train, y_train):.5f}")
# X_train, y_train = prepare_data(val)
# print(f"\n训练集准确率: {model.score(X_train, y_train):.5f}")
# model.predict()
# Disabled: legacy "prediction accuracy" check over data/val via src_predict.
# from src_predict import predictor
# suc = cnt = 0
# for file in Path("data/val").glob("*/*.jpg"):
# cnt+=1
# pcl,_=predictor(file)
# acl = file.parents[0].name
# if acl == pcl:
# suc+=1
# print(f"预测准确率: {suc/cnt:.4f}")