diff --git a/Picture_Train/main.py b/Picture_Train/main.py index 82f7da2..43d4c77 100644 --- a/Picture_Train/main.py +++ b/Picture_Train/main.py @@ -4,8 +4,8 @@ from pathlib import Path from train_2 import train_model,prepare_data,predict import joblib -inp = Path("data/train") -val = Path("data/val") +inp = Path("data/train_new") +val = Path("data/val_new") proc = Path("data/proc") proc.mkdir(exist_ok=True) @@ -33,36 +33,37 @@ def preproc(dir:Path): h = round(np.sum(hsv[:,:,0])/cnt) s = round(np.sum(hsv[:,:,1])/cnt) v = round(np.sum(hsv[:,:,2])/cnt) - name = f"{h}_{s}_{v}.jpg" + name = f"{cl}_{h}_{s}_{v}.jpg" d[cl].append((h,s,v)) # (proc/cl).mkdir(exist_ok=True) - # cv2.imwrite(proc/cl/name,cv2.cvtColor(hsv,cv2.COLOR_HSV2BGR)) + # cv2.imwrite(proc/name,cv2.cvtColor(hsv,cv2.COLOR_HSV2BGR)) return d -d:dict[str,list[tuple[int,int,int]]] = preproc(inp) +# d:dict[str,list[tuple[int,int,int]]] = preproc(inp) val:dict[str,list[tuple[int,int,int]]] = preproc(val) print("数据预处理完成") -model, label_map = train_model(d) -print("训练完成") -joblib.dump(model, "model.pkl") +model, label_map = train_model(val) +# print("训练完成") +# joblib.dump(model, "model.pkl") # model = joblib.load("model.pkl") -X_train, y_train = prepare_data(d) -print(predict(model, label_map, d)) -print(predict(model, label_map, val)) +X_train, y_train = prepare_data(val) +rate,failed = predict(model, label_map, val) +print(rate) +# print(predict(model, label_map, val)) # print(f"\n训练集准确率: {model.score(X_train, y_train):.5f}") # X_train, y_train = prepare_data(val) # print(f"\n训练集准确率: {model.score(X_train, y_train):.5f}") # model.predict() -from src_predict import predictor -suc = cnt = 0 -for file in Path("data/train").glob("*/*.jpg"): - cnt+=1 - pcl,_=predictor(file) - acl = file.parents[0].name - if acl == pcl: - suc+=1 +# from src_predict import predictor +# suc = cnt = 0 +# for file in Path("data/val").glob("*/*.jpg"): +# cnt+=1 +# pcl,_=predictor(file) +# acl = file.parents[0].name +# if acl == pcl: +# suc+=1 -print(f"预测准确率: {suc/cnt:.4f}") +# print(f"预测准确率: {suc/cnt:.4f}") diff --git a/Picture_Train/train_2.py b/Picture_Train/train_2.py index 8b841ed..3620c7f 100644 --- a/Picture_Train/train_2.py +++ b/Picture_Train/train_2.py @@ -71,7 +71,7 @@ def predict(model, label_map: Dict[int, str], val_data: Dict[str, List[np.ndarra 返回: 预测结果字典,格式为 dict[str, list[list[str]]],表示每个输入数组中样本的预测类别 """ - results = {} + failed = [] suc = 0 cnt = 0 @@ -87,14 +87,15 @@ def predict(model, label_map: Dict[int, str], val_data: Dict[str, List[np.ndarra # 进行预测并转换为类别名 pred_labels = model.predict(arr) pred_classes = [label_map[label] for label in pred_labels] - if len(pred_classes) > 1:continue - if class_name==pred_classes[0]: + if len(pred_classes) == 1 and class_name==pred_classes[0]: suc+=1 + else: + failed.append(arrays_list) # class_predictions.append(pred_classes) - results[class_name] = class_predictions + # results[class_name] = class_predictions - return suc/cnt + return suc/cnt,failed if __name__ == "__main__": exit()