knn增强数据训练

简介: 【7月更文挑战第28天】

在 load_data 和 predict_image 函数中,将图像转换为灰度图像 (color.rgb2gray)。
将图像数据类型转换为 uint8 (img_as_ubyte)。
增强数据集包括彩色图像和其灰度版本。
这样做可以确保模型能够识别黑白打印的图像。你可以进一步调整数据增强策略,以更好地适应黑白图像的特性。

import os
from skimage import io, color, transform, img_as_ubyte
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import imgaug.augmenters as iaa
import pickle

# 导入UGOT机器人库
from ugot import ugot
import cv2

# 函数:从文件夹加载图像数据
def load_data(data_folder):
    data = []
    labels = []
    for folder in os.listdir(data_folder):
        folder_path = os.path.join(data_folder, folder)
        print("正在:", folder_path)  # 新增的打印语句
        if os.path.isdir(folder_path):
            label = folder
            for filename in os.listdir(folder_path):
                print("正在读取文件:", filename)  # 新增的打印语句
                img_path = os.path.join(folder_path, filename)
                img = io.imread(img_path)
                img = transform.resize(img, (50, 50))  # 调整图像大小为50x50像素
                img_gray = color.rgb2gray(img)  # 转换为灰度图
                img = img_as_ubyte(img_gray)  # 转换图像数据类型为 uint8
                data.append(img.flatten())  # 将图像数据展平
                labels.append(label)
    return data, labels

# 数据增强
def augment_data(images, labels, n_augmentations=5):
    aug = iaa.Sequential([
        iaa.Fliplr(0.5), # 水平翻转
        iaa.Affine(rotate=(-20, 20)), # 随机旋转
        iaa.Multiply((0.8, 1.2)), # 随机亮度
        iaa.Affine(scale=(0.9, 1.1)) # 随机缩放
    ])
    augmented_images = []
    augmented_labels = []
    for img, label in zip(images, labels):
        img = img.reshape((50, 50, -1))  # 恢复图像形状
        for _ in range(n_augmentations):
            augmented_img = aug(image=img)
            augmented_images.append(augmented_img.flatten())  # 展平图像
            augmented_labels.append(label)
    return augmented_images, augmented_labels

# 加载数据
data_folder = './data'
print(data_folder)
X, y = load_data(data_folder)

# 增强数据
X_augmented, y_augmented = augment_data(X, y)

# 将原始数据和增强数据合并
X_combined = X + X_augmented
y_combined = y + y_augmented

# 将数据拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_combined, y_combined, test_size=0.2, random_state=42)

# 初始化KNN分类器
knn_classifier = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn_classifier.fit(X_train, y_train)

# 评估模型
accuracy = knn_classifier.score(X_test, y_test)
print("Accuracy:", accuracy)

# 保存模型
model_filename = 'knn_model.pkl'
with open(model_filename, 'wb') as model_file:
    pickle.dump(knn_classifier, model_file)

# 加载模型
with open(model_filename, 'rb') as model_file:
    loaded_model = pickle.load(model_file)

# 设置置信度阈值
confidence_threshold = 0.5

# 函数:预测图像内容
def predict_image(img):
    img = transform.resize(img, (50, 50))
    img_gray = color.rgb2gray(img)  # 转换为灰度图
    img = img_as_ubyte(img_gray)  # 转换图像数据类型为 uint8
    flattened_img = img.flatten()
    probs = loaded_model.predict_proba([flattened_img])
    print("Prediction Probabilities:", probs)  # 打印预测概率分布
    max_prob = np.max(probs)
    if max_prob < confidence_threshold:
        return "no"
    prediction = loaded_model.predict([flattened_img])
    return prediction[0]

def predict_and_display_image(img):
    img_resized = transform.resize(img, (50, 50))
    img_gray = color.rgb2gray(img_resized)  # 转换为灰度图
    img_resized = img_as_ubyte(img_gray)  # 转换图像数据类型为 uint8
    flattened_img = img_resized.flatten()
    probs = loaded_model.predict_proba([flattened_img])
    max_prob = np.max(probs)
    if max_prob < confidence_threshold:
        prediction = "no"
    else:
        prediction = loaded_model.predict([flattened_img])[0]
    print(prediction)  # 调试信息
    # 在图像上显示预测结果
    cv2.putText(img, f'Prediction: {prediction}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

    cv2.imshow('Image', img)
    cv2.waitKey(1)
    # cv2.destroyAllWindows()


try:
    while True:
        frame = got.read_camera_data()
        if frame is not None:
            nparr = np.frombuffer(frame, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

            # 使用加载的模型进行预测并在图像上显示结果
            predict_and_display_image(img)

            # 在此处添加物体跟随和闭环控制逻辑
            # print('-------:', got.get_knn_result('knn_model'))
except KeyboardInterrupt:
    print('-----KeyboardInterrupt')

image.png

目录
相关文章
|
25天前
|
机器学习/深度学习 算法 搜索推荐
联邦学习的未来:深入剖析FedAvg算法与数据不均衡的解决之道
随着数据隐私和数据安全法规的不断加强,传统的集中式机器学习方法受到越来越多的限制。为了在分布式数据场景中高效训练模型,同时保护用户数据隐私,联邦学习(Federated Learning, FL)应运而生。它允许多个参与方在本地数据上训练模型,并通过共享模型参数而非原始数据,实现协同建模。
|
5天前
|
资源调度 算法 数据可视化
基于IEKF迭代扩展卡尔曼滤波算法的数据跟踪matlab仿真,对比EKF和UKF
本项目基于MATLAB2022A实现IEKF迭代扩展卡尔曼滤波算法的数据跟踪仿真,对比EKF和UKF的性能。通过仿真输出误差收敛曲线和误差协方差收敛曲线,展示三种滤波器的精度差异。核心程序包括数据处理、误差计算及可视化展示。IEKF通过多次迭代线性化过程,增强非线性处理能力;UKF避免线性化,使用sigma点直接处理非线性问题;EKF则通过一次线性化简化处理。
|
3天前
|
机器学习/深度学习 资源调度 算法
基于入侵野草算法的KNN分类优化matlab仿真
本程序基于入侵野草算法(IWO)优化KNN分类器,通过模拟自然界中野草的扩散与竞争过程,寻找最优特征组合和超参数。核心步骤包括初始化、繁殖、变异和选择,以提升KNN分类效果。程序在MATLAB2022A上运行,展示了优化后的分类性能。该方法适用于高维数据和复杂分类任务,显著提高了分类准确性。
|
18天前
|
算法 图形学 数据安全/隐私保护
基于NURBS曲线的数据拟合算法matlab仿真
本程序基于NURBS曲线实现数据拟合,适用于计算机图形学、CAD/CAM等领域。通过控制顶点和权重,精确表示复杂形状,特别适合真实对象建模和数据点光滑拟合。程序在MATLAB2022A上运行,展示了T1至T7的测试结果,无水印输出。核心算法采用梯度下降等优化技术调整参数,最小化误差函数E,确保迭代收敛,提供高质量的拟合效果。
|
17天前
|
存储 监控 算法
公司监控上网软件架构:基于 C++ 链表算法的数据关联机制探讨
在数字化办公时代,公司监控上网软件成为企业管理网络资源和保障信息安全的关键工具。本文深入剖析C++中的链表数据结构及其在该软件中的应用。链表通过节点存储网络访问记录,具备高效插入、删除操作及节省内存的优势,助力企业实时追踪员工上网行为,提升运营效率并降低安全风险。示例代码展示了如何用C++实现链表记录上网行为,并模拟发送至服务器。链表为公司监控上网软件提供了灵活高效的数据管理方式,但实际开发还需考虑安全性、隐私保护等多方面因素。
21 0
公司监控上网软件架构:基于 C++ 链表算法的数据关联机制探讨
|
23天前
|
存储 移动开发 算法
【狂热算法篇】解锁数据潜能:探秘前沿 LIS 算法
【狂热算法篇】解锁数据潜能:探秘前沿 LIS 算法
|
19天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于GRU网络的MQAM调制信号检测算法matlab仿真,对比LSTM
本研究基于MATLAB 2022a,使用GRU网络对QAM调制信号进行检测。QAM是一种高效调制技术,广泛应用于现代通信系统。传统方法在复杂环境下性能下降,而GRU通过门控机制有效提取时间序列特征,实现16QAM、32QAM、64QAM、128QAM的准确检测。仿真结果显示,GRU在低SNR下表现优异,且训练速度快,参数少。核心程序包括模型预测、误检率和漏检率计算,并绘制准确率图。
85 65
基于GRU网络的MQAM调制信号检测算法matlab仿真,对比LSTM
|
6天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于生物地理算法的MLP多层感知机优化matlab仿真
本程序基于生物地理算法(BBO)优化MLP多层感知机,通过MATLAB2022A实现随机数据点的趋势预测,并输出优化收敛曲线。BBO模拟物种在地理空间上的迁移、竞争与适应过程,以优化MLP的权重和偏置参数,提升预测性能。完整程序无水印,适用于机器学习和数据预测任务。
|
7天前
|
算法 数据安全/隐私保护
基于二次规划优化的OFDM系统PAPR抑制算法的matlab仿真
本程序基于二次规划优化的OFDM系统PAPR抑制算法,旨在降低OFDM信号的高峰均功率比(PAPR),以减少射频放大器的非线性失真并提高电源效率。通过MATLAB2022A仿真验证,核心算法通过对原始OFDM信号进行预编码,最小化最大瞬时功率,同时约束信号重构误差,确保数据完整性。完整程序运行后无水印,展示优化后的PAPR性能提升效果。
|
4天前
|
算法 数据安全/隐私保护 计算机视觉
基于sift变换的农田杂草匹配定位算法matlab仿真
本项目基于SIFT算法实现农田杂草精准识别与定位,运行环境为Matlab2022a。完整程序无水印,提供详细中文注释及操作视频。核心步骤包括尺度空间极值检测、关键点定位、方向分配和特征描述符生成。该算法通过特征匹配实现杂草定位,适用于现代农业中的自动化防控。

热门文章

最新文章