hsv_1
Former-commit-id: 57a8fb7799ea7543b4af1ea49626346d50546f82
This commit is contained in:
3
.gitattributes
vendored
3
.gitattributes
vendored
@ -1,2 +1,5 @@
|
||||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
||||
*.exe filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
|
3
Auto_Ctrl/.gitattributes
vendored
3
Auto_Ctrl/.gitattributes
vendored
@ -1,3 +0,0 @@
|
||||
*.exe filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
@ -1,16 +1,68 @@
|
||||
# 这是一个示例 Python 脚本。
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from train_2 import train_model,prepare_data,predict
|
||||
import joblib
|
||||
|
||||
# 按 Shift+F10 执行或将其替换为您的代码。
|
||||
# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。
|
||||
# Dataset locations, relative to the current working directory.
inp = Path("data/train")
val = Path("data/val")
proc = Path("data/proc")
# parents=True so a missing "data" folder does not abort the script here.
proc.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def print_hi(name):
    """Print a one-line greeting addressed to ``name``."""
    greeting = f'Hi, {name}'
    print(greeting)
|
||||
def preproc(dir: Path, sat_threshold: int = 150) -> dict[str, list[tuple[int, int, int]]]:
    """Reduce every image under ``dir`` to one mean-HSV triple, grouped by class.

    Each ``<class>/<image>.jpg`` one level below ``dir`` is read, converted to
    HSV, and summarised as the rounded mean hue, saturation and value of the
    pixels whose saturation is strictly greater than ``sat_threshold``.

    Args:
        dir: Root folder containing one sub-folder per class.
        sat_threshold: Saturation cut-off used to select the pixels that
            contribute to the mean. The default of 150 is the value that was
            previously hard-coded.

    Returns:
        A dict mapping the class name (the image's parent folder name) to the
        list of ``(h, s, v)`` integer tuples, one per usable image. Images that
        cannot be read, or that have no pixel above the threshold, are reported
        on stdout and skipped.
    """
    d: dict[str, list[tuple[int, int, int]]] = {}
    for file in dir.glob("*/*.jpg"):
        # str(): not every OpenCV release accepts a pathlib.Path here.
        im = cv2.imread(str(file))
        if im is None:
            print(f"Error reading image: {file}")
            continue
        cl = file.parent.name
        samples = d.setdefault(cl, [])
        hsv = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)
        mask = hsv[:, :, 1] > sat_threshold
        cnt = int(np.count_nonzero(mask))
        if cnt == 0:
            # Nothing to average; dividing by zero would yield NaN and make
            # round() raise, so skip the image instead.
            print(f"No pixel above saturation threshold in image: {file}")
            continue
        # Sum only the selected pixels, one total per channel; a 64-bit
        # accumulator avoids overflow on large images.
        totals = hsv[mask].sum(axis=0, dtype=np.uint64)
        h, s, v = (round(int(total) / cnt) for total in totals)
        samples.append((h, s, v))
    return d
|
||||
|
||||
# Mean-HSV features per class for the training and validation folders.
# NOTE: ``val`` is rebound here from the validation directory path to the
# feature dict extracted from that directory.
d: dict[str, list[tuple[int, int, int]]] = preproc(inp)
val: dict[str, list[tuple[int, int, int]]] = preproc(val)

if __name__ == '__main__':
    print_hi('PyCharm')
    print("数据预处理完成")

    # Fit the classifier on the training features and persist it.
    model, label_map = train_model(d)
    print("训练完成")
    joblib.dump(model, "model.pkl")

    # Fraction of correctly classified samples on the training and
    # validation features respectively.
    print(predict(model, label_map, d))
    print(predict(model, label_map, val))

    # Deliberately imported only after model.pkl has been written.
    # NOTE(review): presumably src_predict loads model.pkl at import time;
    # confirm before moving this import to the top of the file.
    from src_predict import predictor

    # End-to-end check: run the file-based predictor on every training image
    # and compare its answer with the folder name.
    suc = cnt = 0
    for file in inp.glob("*/*.jpg"):
        cnt += 1
        pcl, _ = predictor(file)
        if pcl == file.parent.name:
            suc += 1

    # Guard against an empty training folder.
    accuracy = suc / cnt if cnt else 0.0
    print(f"预测准确率: {accuracy:.4f}")
|
||||
|
3
Picture_Train/model.pkl
Normal file
3
Picture_Train/model.pkl
Normal file
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:409653310924dc977ea2ae5dd46042ef144f4c8500c460ba5ca5b0f5ce68bed8
|
||||
size 535673
|
29
Picture_Train/show.py
Normal file
29
Picture_Train/show.py
Normal file
@ -0,0 +1,29 @@
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import random
|
||||
|
||||
|
||||
# Create a 3D figure.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Fixed colours for the known classes; any other class gets a random RGB
# colour when it is first encountered below.
colors = {"orange": (1, 0, 0), "yellow": (0, 1, 0)}

# NOTE(review): ``d`` is neither defined nor imported in this file. It is
# presumably the class -> [(h, s, v), ...] mapping produced by the
# preprocessing step; confirm how it is meant to be supplied before running
# this script on its own.

# Plot the points of every class.
for key, points in d.items():
    if key not in colors:
        colors[key] = (random.random(), random.random(), random.random())
    x_vals = [point[0] for point in points]
    y_vals = [point[1] for point in points]
    z_vals = [point[2] for point in points]
    ax.scatter(x_vals, y_vals, z_vals, label=key, color=colors[key])

# Axis labels and legend.
ax.set_xlabel('X轴')
ax.set_ylabel('Y轴')
ax.set_zlabel('Z轴')
ax.legend()

# Display the figure.
plt.show()
|
118
Picture_Train/train_2.py
Normal file
118
Picture_Train/train_2.py
Normal file
@ -0,0 +1,118 @@
|
||||
|
||||
import numpy as np
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from typing import Dict, List, Tuple
|
||||
import joblib
|
||||
|
||||
def prepare_data(data: Dict[str, List[np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
    """Flatten a class -> samples mapping into a feature matrix and label vector.

    Args:
        data: Maps each class name to the list of feature arrays (or
            array-likes) belonging to that class.

    Returns:
        ``(X, y)``: ``X`` stacks every sample in iteration order and ``y``
        holds the matching integer label. Labels are the position of the
        class among the keys of ``data``, starting at 0.
    """
    samples = []
    targets = []

    # enumerate() over the values yields the same index each class would get
    # from enumerating the keys.
    for class_index, class_samples in enumerate(data.values()):
        for sample in class_samples:
            samples.append(np.array(sample))
            targets.append(class_index)

    return np.array(samples), np.array(targets)
|
||||
|
||||
def train_model(data: Dict[str, List[np.ndarray]]):
    """Fit a random-forest classifier on the given per-class samples.

    Args:
        data: Training data mapping each class name to its list of feature
            arrays.

    Returns:
        ``(model, label_map)``: the fitted classifier and a dict that maps
        each numeric label back to its class name.
    """
    features, targets = prepare_data(data)

    # Fixed seed so repeated runs build the same forest.
    classifier = RandomForestClassifier(n_estimators=100, random_state=42)
    classifier.fit(features, targets)

    # Inverse of the numbering used by prepare_data: position -> class name.
    index_to_class = dict(enumerate(data.keys()))

    return classifier, index_to_class
|
||||
|
||||
def predict(model, label_map: Dict[int, str], val_data: Dict[str, List[np.ndarray]]) -> float:
    """Return the fraction of samples in ``val_data`` that ``model`` classifies correctly.

    Args:
        model: Fitted estimator exposing ``predict(X)`` for a 2-D feature
            matrix and returning one numeric label per row.
        label_map: Maps the numeric labels returned by ``model`` to class names.
        val_data: Maps each true class name to its list of feature arrays.
            An entry may be a single 1-D sample or a 2-D array holding one
            sample per row; every row is scored.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when ``val_data`` contains
        no samples at all.
    """
    suc = 0
    cnt = 0
    for class_name, arrays_list in val_data.items():
        if not arrays_list:
            continue
        # Promote 1-D samples to single-row matrices and predict the whole
        # class in one call instead of once per sample.
        batch = np.vstack([np.atleast_2d(np.array(arr)) for arr in arrays_list])
        for label in model.predict(batch):
            cnt += 1
            if label_map[label] == class_name:
                suc += 1

    return suc / cnt if cnt else 0.0
|
||||
|
||||
if __name__ == "__main__":
    # Running this module directly intentionally does nothing: training needs
    # a feature dict that is not built in this file, so the helpers above are
    # meant to be imported and called by a driver script.
    raise SystemExit
|
Reference in New Issue
Block a user