在 load_data 和 predict_image 函数中,将图像转换为灰度图像 (color.rgb2gray)。
将图像数据类型转换为 uint8 (img_as_ubyte)。
训练数据集(包括增强后的样本)全部为灰度图像:彩色图像在加载时先被转换为灰度,再进行数据增强。
这样做可以确保模型能够识别黑白打印的图像。你可以进一步调整数据增强策略,以更好地适应黑白图像的特性。
import os
from skimage import io, color, transform, img_as_ubyte
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import imgaug.augmenters as iaa
import pickle
# 导入UGOT机器人库
from ugot import ugot
import cv2
def load_data(data_folder):
    """Load labelled training images from a folder-per-class directory.

    Every sub-directory of ``data_folder`` is treated as one class and its
    name is used as the label. Each image is resized to 50x50, converted to
    grayscale, cast to uint8 and flattened into a 2500-element vector.

    Args:
        data_folder: Path of the directory holding one sub-directory per class.

    Returns:
        A ``(data, labels)`` tuple of parallel lists: flattened uint8 image
        vectors and the class (folder) name of each vector.
    """
    data = []
    labels = []
    # os.listdir order is arbitrary; sorting keeps the sample order (and so
    # any later seeded train/test split) reproducible across machines.
    for folder in sorted(os.listdir(data_folder)):
        folder_path = os.path.join(data_folder, folder)
        print("正在:", folder_path)
        if not os.path.isdir(folder_path):
            continue
        label = folder
        for filename in sorted(os.listdir(folder_path)):
            print("正在读取文件:", filename)
            img_path = os.path.join(folder_path, filename)
            if not os.path.isfile(img_path):
                # Nested directories are not images; ignore them.
                continue
            try:
                img = io.imread(img_path)
            except (OSError, ValueError) as exc:
                # A stray non-image file should not abort the whole load.
                print("Skipping unreadable file:", img_path, exc)
                continue
            img = transform.resize(img, (50, 50))  # float image, 50x50
            if img.ndim == 3:
                if img.shape[-1] == 4:
                    # rgb2gray expects exactly 3 channels; blend alpha away.
                    img = color.rgba2rgb(img)
                img = color.rgb2gray(img)
            # A 2-D image is already grayscale and is used as-is.
            img = img_as_ubyte(img)
            data.append(img.flatten())
            labels.append(label)
    return data, labels
def augment_data(images, labels, n_augmentations=5):
    """Create randomly augmented copies of flattened 50x50 images.

    Each input vector is reshaped to ``(50, 50, -1)``, passed through the
    augmentation pipeline ``n_augmentations`` times, and every result is
    flattened again.

    Args:
        images: Iterable of flattened image arrays.
        labels: Iterable of labels, parallel to ``images``.
        n_augmentations: Number of augmented copies produced per image.

    Returns:
        A ``(augmented_images, augmented_labels)`` tuple of parallel lists
        containing only the augmented samples (the originals are not included).
    """
    pipeline = iaa.Sequential([
        iaa.Fliplr(0.5),                # horizontal flip with probability 0.5
        iaa.Affine(rotate=(-20, 20)),   # random rotation
        iaa.Multiply((0.8, 1.2)),       # random brightness scaling
        iaa.Affine(scale=(0.9, 1.1)),   # random zoom
    ])
    augmented_images = []
    augmented_labels = []
    for flat_img, label in zip(images, labels):
        restored = flat_img.reshape((50, 50, -1))  # back to an image layout
        variants = [
            pipeline(image=restored).flatten()
            for _ in range(n_augmentations)
        ]
        augmented_images.extend(variants)
        augmented_labels.extend([label] * len(variants))
    return augmented_images, augmented_labels
# Load the labelled images from disk.
data_folder = './data'
print(data_folder)
X, y = load_data(data_folder)
if not X:
    raise ValueError(f"No training images found under {data_folder!r}")

# Split BEFORE augmenting: if augmented copies of one source image end up in
# both the training and the test set, the reported accuracy is inflated.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Augment the training split only and append the copies to it.
X_augmented, y_augmented = augment_data(X_train, y_train)
X_train = list(X_train) + X_augmented
y_train = list(y_train) + y_augmented

# Fit the KNN classifier.
knn_classifier = KNeighborsClassifier(n_neighbors=3)
knn_classifier.fit(X_train, y_train)

# Evaluate on held-out, un-augmented images.
accuracy = knn_classifier.score(X_test, y_test)
print("Accuracy:", accuracy)

# Persist the trained model.
model_filename = 'knn_model.pkl'
with open(model_filename, 'wb') as model_file:
    pickle.dump(knn_classifier, model_file)

# Reload it to verify the round trip. NOTE: pickle.load must only be used on
# trusted files; here it reads the file written just above.
with open(model_filename, 'rb') as model_file:
    loaded_model = pickle.load(model_file)

# Predictions whose top class probability is below this value are reported
# as "no" by the prediction functions.
confidence_threshold = 0.5
def predict_image(img):
    """Classify one image with the loaded KNN model.

    The image gets the same preprocessing as the training data: resize to
    50x50, convert to grayscale, cast to uint8, flatten.

    Args:
        img: Image array, either 2-D grayscale or 3-D with 3 (RGB) or
            4 (RGBA) channels.

    Returns:
        The predicted class label, or the string ``"no"`` when the highest
        class probability is below ``confidence_threshold``.
    """
    resized = transform.resize(img, (50, 50))
    if resized.ndim == 3:
        if resized.shape[-1] == 4:
            # rgb2gray expects exactly 3 channels; blend alpha away first.
            resized = color.rgba2rgb(resized)
        resized = color.rgb2gray(resized)
    # A 2-D input is already grayscale and skips the colour conversion.
    flattened_img = img_as_ubyte(resized).flatten()

    probs = loaded_model.predict_proba([flattened_img])
    print("Prediction Probabilities:", probs)
    best = int(np.argmax(probs[0]))
    if probs[0][best] < confidence_threshold:
        return "no"
    # Take the label from the probabilities already computed instead of
    # querying the neighbours a second time with predict().
    return loaded_model.classes_[best]
def predict_and_display_image(img):
    """Classify an OpenCV frame and show it with the prediction drawn on it.

    The frame is drawn and displayed with OpenCV, so it is treated as BGR.
    The training images were read with skimage (RGB), therefore the channels
    are reordered to RGB before the grayscale conversion so that inference
    uses the same channel weights as training.

    Args:
        img: Frame as a uint8 array, 3-channel BGR or 2-D grayscale. The
            prediction text is drawn onto this array in place.
    """
    if img.ndim == 3 and img.shape[-1] == 3:
        model_input = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    else:
        model_input = img
    img_resized = transform.resize(model_input, (50, 50))
    if img_resized.ndim == 3:
        img_resized = color.rgb2gray(img_resized)
    flattened_img = img_as_ubyte(img_resized).flatten()

    probs = loaded_model.predict_proba([flattened_img])
    best = int(np.argmax(probs[0]))
    if probs[0][best] < confidence_threshold:
        prediction = "no"
    else:
        # Reuse the computed probabilities rather than calling predict().
        prediction = loaded_model.classes_[best]
    print(prediction)  # debug output

    # Overlay the result on the original frame and refresh the window.
    cv2.putText(img, f'Prediction: {prediction}', (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
    cv2.imshow('Image', img)
    cv2.waitKey(1)
    # The window is left open on purpose so successive frames reuse it.
# NOTE(review): `got` is not defined anywhere in the visible file, so this
# loop raises NameError as written. Presumably a robot handle must be created
# and its camera opened via the `ugot` SDK before this point - confirm the
# exact calls and the robot address against the SDK documentation.
try:
    while True:
        frame = got.read_camera_data()
        if frame is None or len(frame) == 0:
            # No data available yet; poll again.
            continue
        nparr = np.frombuffer(frame, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode returns None for a corrupt or truncated buffer.
            continue
        # Run the loaded model on the frame and display the result.
        predict_and_display_image(img)
        # Add object-following and closed-loop control logic here.
        # print('-------:', got.get_knn_result('knn_model'))
except KeyboardInterrupt:
    print('-----KeyboardInterrupt')
finally:
    # Close the preview window when the loop ends for any reason.
    cv2.destroyAllWindows()