enhance train and adapt main to train only
Former-commit-id: dd0f5b88b7a401eb41006cf357eff993ca661dd7
This commit is contained in:
@ -71,7 +71,7 @@ def predict(model, label_map: Dict[int, str], val_data: Dict[str, List[np.ndarra
|
||||
返回:
|
||||
预测结果字典,格式为 dict[str, list[list[str]]],表示每个输入数组中样本的预测类别
|
||||
"""
|
||||
results = {}
|
||||
failed = []
|
||||
|
||||
suc = 0
|
||||
cnt = 0
|
||||
@ -87,14 +87,15 @@ def predict(model, label_map: Dict[int, str], val_data: Dict[str, List[np.ndarra
|
||||
# 进行预测并转换为类别名
|
||||
pred_labels = model.predict(arr)
|
||||
pred_classes = [label_map[label] for label in pred_labels]
|
||||
if len(pred_classes) > 1:continue
|
||||
if class_name==pred_classes[0]:
|
||||
if len(pred_classes) == 1 and class_name==pred_classes[0]:
|
||||
suc+=1
|
||||
else:
|
||||
failed.append(arrays_list)
|
||||
# class_predictions.append(pred_classes)
|
||||
|
||||
results[class_name] = class_predictions
|
||||
# results[class_name] = class_predictions
|
||||
|
||||
return suc/cnt
|
||||
return suc/cnt,failed
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit()
|
||||
|
Reference in New Issue
Block a user