enhance train and adapt main to train only

Former-commit-id: dd0f5b88b7a401eb41006cf357eff993ca661dd7
This commit is contained in:
2025-05-17 23:46:34 +08:00
parent 8849c28c14
commit b0670fa2ca
2 changed files with 27 additions and 25 deletions

View File

@ -71,7 +71,7 @@ def predict(model, label_map: Dict[int, str], val_data: Dict[str, List[np.ndarra
返回:
预测结果字典,格式为 dict[str, list[list[str]]],表示每个输入数组中样本的预测类别
"""
results = {}
failed = []
suc = 0
cnt = 0
@ -87,14 +87,15 @@ def predict(model, label_map: Dict[int, str], val_data: Dict[str, List[np.ndarra
# 进行预测并转换为类别名
pred_labels = model.predict(arr)
pred_classes = [label_map[label] for label in pred_labels]
if len(pred_classes) > 1:continue
if class_name==pred_classes[0]:
if len(pred_classes) == 1 and class_name==pred_classes[0]:
suc+=1
else:
failed.append(arrays_list)
# class_predictions.append(pred_classes)
results[class_name] = class_predictions
# results[class_name] = class_predictions
return suc/cnt
return suc/cnt,failed
if __name__ == "__main__":
exit()