knn增强数据训练

简介: 【7月更文挑战第28天】

在 load_data 和 predict_image 函数中,将图像转换为灰度图像 (color.rgb2gray)。
将图像数据类型转换为 uint8 (img_as_ubyte)。
增强数据集包括彩色图像和其灰度版本。
这样做可以确保模型能够识别黑白打印的图像。你可以进一步调整数据增强策略,以更好地适应黑白图像的特性。

import os
from skimage import io, color, transform, img_as_ubyte
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import imgaug.augmenters as iaa
import pickle

# 导入UGOT机器人库
from ugot import ugot
import cv2

# 函数:从文件夹加载图像数据
def load_data(data_folder):
    data = []
    labels = []
    for folder in os.listdir(data_folder):
        folder_path = os.path.join(data_folder, folder)
        print("正在:", folder_path)  # 新增的打印语句
        if os.path.isdir(folder_path):
            label = folder
            for filename in os.listdir(folder_path):
                print("正在读取文件:", filename)  # 新增的打印语句
                img_path = os.path.join(folder_path, filename)
                img = io.imread(img_path)
                img = transform.resize(img, (50, 50))  # 调整图像大小为50x50像素
                img_gray = color.rgb2gray(img)  # 转换为灰度图
                img = img_as_ubyte(img_gray)  # 转换图像数据类型为 uint8
                data.append(img.flatten())  # 将图像数据展平
                labels.append(label)
    return data, labels

# 数据增强
def augment_data(images, labels, n_augmentations=5):
    aug = iaa.Sequential([
        iaa.Fliplr(0.5), # 水平翻转
        iaa.Affine(rotate=(-20, 20)), # 随机旋转
        iaa.Multiply((0.8, 1.2)), # 随机亮度
        iaa.Affine(scale=(0.9, 1.1)) # 随机缩放
    ])
    augmented_images = []
    augmented_labels = []
    for img, label in zip(images, labels):
        img = img.reshape((50, 50, -1))  # 恢复图像形状
        for _ in range(n_augmentations):
            augmented_img = aug(image=img)
            augmented_images.append(augmented_img.flatten())  # 展平图像
            augmented_labels.append(label)
    return augmented_images, augmented_labels

# 加载数据
data_folder = './data'
print(data_folder)
X, y = load_data(data_folder)

# 增强数据
X_augmented, y_augmented = augment_data(X, y)

# 将原始数据和增强数据合并
X_combined = X + X_augmented
y_combined = y + y_augmented

# 将数据拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_combined, y_combined, test_size=0.2, random_state=42)

# 初始化KNN分类器
knn_classifier = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn_classifier.fit(X_train, y_train)

# 评估模型
accuracy = knn_classifier.score(X_test, y_test)
print("Accuracy:", accuracy)

# 保存模型
model_filename = 'knn_model.pkl'
with open(model_filename, 'wb') as model_file:
    pickle.dump(knn_classifier, model_file)

# 加载模型
with open(model_filename, 'rb') as model_file:
    loaded_model = pickle.load(model_file)

# 设置置信度阈值
confidence_threshold = 0.5

# 函数:预测图像内容
def predict_image(img):
    img = transform.resize(img, (50, 50))
    img_gray = color.rgb2gray(img)  # 转换为灰度图
    img = img_as_ubyte(img_gray)  # 转换图像数据类型为 uint8
    flattened_img = img.flatten()
    probs = loaded_model.predict_proba([flattened_img])
    print("Prediction Probabilities:", probs)  # 打印预测概率分布
    max_prob = np.max(probs)
    if max_prob < confidence_threshold:
        return "no"
    prediction = loaded_model.predict([flattened_img])
    return prediction[0]

def predict_and_display_image(img):
    img_resized = transform.resize(img, (50, 50))
    img_gray = color.rgb2gray(img_resized)  # 转换为灰度图
    img_resized = img_as_ubyte(img_gray)  # 转换图像数据类型为 uint8
    flattened_img = img_resized.flatten()
    probs = loaded_model.predict_proba([flattened_img])
    max_prob = np.max(probs)
    if max_prob < confidence_threshold:
        prediction = "no"
    else:
        prediction = loaded_model.predict([flattened_img])[0]
    print(prediction)  # 调试信息
    # 在图像上显示预测结果
    cv2.putText(img, f'Prediction: {prediction}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

    cv2.imshow('Image', img)
    cv2.waitKey(1)
    # cv2.destroyAllWindows()


try:
    while True:
        frame = got.read_camera_data()
        if frame is not None:
            nparr = np.frombuffer(frame, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

            # 使用加载的模型进行预测并在图像上显示结果
            predict_and_display_image(img)

            # 在此处添加物体跟随和闭环控制逻辑
            # print('-------:', got.get_knn_result('knn_model'))
except KeyboardInterrupt:
    print('-----KeyboardInterrupt')

image.png

目录
相关文章
|
29天前
|
机器学习/深度学习 算法 前端开发
别再用均值填充了!MICE算法教你正确处理缺失数据
MICE是一种基于迭代链式方程的缺失值插补方法,通过构建后验分布并生成多个完整数据集,有效量化不确定性。相比简单填补,MICE利用变量间复杂关系,提升插补准确性,适用于多变量关联、缺失率高的场景。本文结合PMM与线性回归,详解其机制并对比效果,验证其在统计推断中的优势。
629 11
别再用均值填充了!MICE算法教你正确处理缺失数据
|
2月前
|
传感器 机器学习/深度学习 算法
【使用 DSP 滤波器加速速度和位移】使用信号处理算法过滤加速度数据并将其转换为速度和位移研究(Matlab代码实现)
【使用 DSP 滤波器加速速度和位移】使用信号处理算法过滤加速度数据并将其转换为速度和位移研究(Matlab代码实现)
182 1
|
2月前
|
机器学习/深度学习 算法 调度
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
14种智能算法优化BP神经网络(14种方法)实现数据预测分类研究(Matlab代码实现)
284 0
|
22天前
|
机器学习/深度学习 人工智能 算法
【基于TTNRBO优化DBN回归预测】基于瞬态三角牛顿-拉夫逊优化算法(TTNRBO)优化深度信念网络(DBN)数据回归预测研究(Matlab代码实现)
【基于TTNRBO优化DBN回归预测】基于瞬态三角牛顿-拉夫逊优化算法(TTNRBO)优化深度信念网络(DBN)数据回归预测研究(Matlab代码实现)
|
2月前
|
存储 监控 算法
企业电脑监控系统中基于 Go 语言的跳表结构设备数据索引算法研究
本文介绍基于Go语言的跳表算法在企业电脑监控系统中的应用,通过多层索引结构将数据查询、插入、删除操作优化至O(log n),显著提升海量设备数据管理效率,解决传统链表查询延迟问题,实现高效设备状态定位与异常筛选。
98 3
|
1月前
|
数据采集 分布式计算 并行计算
mRMR算法实现特征选择-MATLAB
mRMR算法实现特征选择-MATLAB
109 2
|
2月前
|
传感器 机器学习/深度学习 编解码
MATLAB|主动噪声和振动控制算法——对较大的次级路径变化具有鲁棒性
MATLAB|主动噪声和振动控制算法——对较大的次级路径变化具有鲁棒性
173 3
|
22天前
|
机器学习/深度学习 算法 机器人
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
121 0
|
2月前
|
存储 编解码 算法
【多光谱滤波器阵列设计的最优球体填充】使用MSFA设计方法进行各种重建算法时,图像质量可以提高至多2 dB,并在光谱相似性方面实现了显著提升(Matlab代码实现)
【多光谱滤波器阵列设计的最优球体填充】使用MSFA设计方法进行各种重建算法时,图像质量可以提高至多2 dB,并在光谱相似性方面实现了显著提升(Matlab代码实现)
|
22天前
|
机器学习/深度学习 算法 机器人
使用哈里斯角Harris和SIFT算法来实现局部特征匹配(Matlab代码实现)
使用哈里斯角Harris和SIFT算法来实现局部特征匹配(Matlab代码实现)
116 8

热门文章

最新文章