enhance train and adapt main to train only

Former-commit-id: dd0f5b88b7a401eb41006cf357eff993ca661dd7
This commit is contained in:
2025-05-17 23:46:34 +08:00
parent 8849c28c14
commit b0670fa2ca
2 changed files with 27 additions and 25 deletions

View File

@ -4,8 +4,8 @@ from pathlib import Path
from train_2 import train_model,prepare_data,predict from train_2 import train_model,prepare_data,predict
import joblib import joblib
inp = Path("data/train") inp = Path("data/train_new")
val = Path("data/val") val = Path("data/val_new")
proc = Path("data/proc") proc = Path("data/proc")
proc.mkdir(exist_ok=True) proc.mkdir(exist_ok=True)
@ -33,36 +33,37 @@ def preproc(dir:Path):
h = round(np.sum(hsv[:,:,0])/cnt) h = round(np.sum(hsv[:,:,0])/cnt)
s = round(np.sum(hsv[:,:,1])/cnt) s = round(np.sum(hsv[:,:,1])/cnt)
v = round(np.sum(hsv[:,:,2])/cnt) v = round(np.sum(hsv[:,:,2])/cnt)
name = f"{h}_{s}_{v}.jpg" name = f"{cl}_{h}_{s}_{v}.jpg"
d[cl].append((h,s,v)) d[cl].append((h,s,v))
# (proc/cl).mkdir(exist_ok=True) # (proc/cl).mkdir(exist_ok=True)
# cv2.imwrite(proc/cl/name,cv2.cvtColor(hsv,cv2.COLOR_HSV2BGR)) # cv2.imwrite(proc/name,cv2.cvtColor(hsv,cv2.COLOR_HSV2BGR))
return d return d
d:dict[str,list[tuple[int,int,int]]] = preproc(inp) # d:dict[str,list[tuple[int,int,int]]] = preproc(inp)
val:dict[str,list[tuple[int,int,int]]] = preproc(val) val:dict[str,list[tuple[int,int,int]]] = preproc(val)
print("数据预处理完成") print("数据预处理完成")
model, label_map = train_model(d) model, label_map = train_model(val)
print("训练完成") # print("训练完成")
joblib.dump(model, "model.pkl") # joblib.dump(model, "model.pkl")
# model = joblib.load("model.pkl") # model = joblib.load("model.pkl")
X_train, y_train = prepare_data(d) X_train, y_train = prepare_data(val)
print(predict(model, label_map, d)) rate,failed = predict(model, label_map, val)
print(predict(model, label_map, val)) print(rate)
# print(predict(model, label_map, val))
# print(f"\n训练集准确率: {model.score(X_train, y_train):.5f}") # print(f"\n训练集准确率: {model.score(X_train, y_train):.5f}")
# X_train, y_train = prepare_data(val) # X_train, y_train = prepare_data(val)
# print(f"\n训练集准确率: {model.score(X_train, y_train):.5f}") # print(f"\n训练集准确率: {model.score(X_train, y_train):.5f}")
# model.predict() # model.predict()
from src_predict import predictor # from src_predict import predictor
suc = cnt = 0 # suc = cnt = 0
for file in Path("data/train").glob("*/*.jpg"): # for file in Path("data/val").glob("*/*.jpg"):
cnt+=1 # cnt+=1
pcl,_=predictor(file) # pcl,_=predictor(file)
acl = file.parents[0].name # acl = file.parents[0].name
if acl == pcl: # if acl == pcl:
suc+=1 # suc+=1
print(f"预测准确率: {suc/cnt:.4f}") # print(f"预测准确率: {suc/cnt:.4f}")

View File

@ -71,7 +71,7 @@ def predict(model, label_map: Dict[int, str], val_data: Dict[str, List[np.ndarra
返回: 返回:
预测结果字典,格式为 dict[str, list[list[str]]],表示每个输入数组中样本的预测类别 预测结果字典,格式为 dict[str, list[list[str]]],表示每个输入数组中样本的预测类别
""" """
results = {} failed = []
suc = 0 suc = 0
cnt = 0 cnt = 0
@ -87,14 +87,15 @@ def predict(model, label_map: Dict[int, str], val_data: Dict[str, List[np.ndarra
# 进行预测并转换为类别名 # 进行预测并转换为类别名
pred_labels = model.predict(arr) pred_labels = model.predict(arr)
pred_classes = [label_map[label] for label in pred_labels] pred_classes = [label_map[label] for label in pred_labels]
if len(pred_classes) > 1:continue if len(pred_classes) == 1 and class_name==pred_classes[0]:
if class_name==pred_classes[0]:
suc+=1 suc+=1
else:
failed.append(arrays_list)
# class_predictions.append(pred_classes) # class_predictions.append(pred_classes)
results[class_name] = class_predictions # results[class_name] = class_predictions
return suc/cnt return suc/cnt,failed
if __name__ == "__main__": if __name__ == "__main__":
exit() exit()