knn增强数据训练

简介: 【7月更文挑战第27天】

在 load_data 和 predict_image 函数中,将图像转换为灰度图像 (color.rgb2gray)。
将图像数据类型转换为 uint8 (img_as_ubyte)。
增强数据集包括彩色图像和其灰度版本。
这样做可以确保模型能够识别黑白打印的图像。你可以进一步调整数据增强策略,以更好地适应黑白图像的特性。

import os
from skimage import io, color, transform, img_as_ubyte
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import imgaug.augmenters as iaa
import pickle

# 导入UGOT机器人库
from ugot import ugot
import cv2

# 函数:从文件夹加载图像数据
def load_data(data_folder):
    data = []
    labels = []
    for folder in os.listdir(data_folder):
        folder_path = os.path.join(data_folder, folder)
        print("正在:", folder_path)  # 新增的打印语句
        if os.path.isdir(folder_path):
            label = folder
            for filename in os.listdir(folder_path):
                print("正在读取文件:", filename)  # 新增的打印语句
                img_path = os.path.join(folder_path, filename)
                img = io.imread(img_path)
                img = transform.resize(img, (50, 50))  # 调整图像大小为50x50像素
                img_gray = color.rgb2gray(img)  # 转换为灰度图
                img = img_as_ubyte(img_gray)  # 转换图像数据类型为 uint8
                data.append(img.flatten())  # 将图像数据展平
                labels.append(label)
    return data, labels

# 数据增强
def augment_data(images, labels, n_augmentations=5):
    aug = iaa.Sequential([
        iaa.Fliplr(0.5), # 水平翻转
        iaa.Affine(rotate=(-20, 20)), # 随机旋转
        iaa.Multiply((0.8, 1.2)), # 随机亮度
        iaa.Affine(scale=(0.9, 1.1)) # 随机缩放
    ])
    augmented_images = []
    augmented_labels = []
    for img, label in zip(images, labels):
        img = img.reshape((50, 50, -1))  # 恢复图像形状
        for _ in range(n_augmentations):
            augmented_img = aug(image=img)
            augmented_images.append(augmented_img.flatten())  # 展平图像
            augmented_labels.append(label)
    return augmented_images, augmented_labels

# 加载数据
data_folder = './data'
print(data_folder)
X, y = load_data(data_folder)

# 增强数据
X_augmented, y_augmented = augment_data(X, y)

# 将原始数据和增强数据合并
X_combined = X + X_augmented
y_combined = y + y_augmented

# 将数据拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_combined, y_combined, test_size=0.2, random_state=42)

# 初始化KNN分类器
knn_classifier = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn_classifier.fit(X_train, y_train)

# 评估模型
accuracy = knn_classifier.score(X_test, y_test)
print("Accuracy:", accuracy)

# 保存模型
model_filename = 'knn_model.pkl'
with open(model_filename, 'wb') as model_file:
    pickle.dump(knn_classifier, model_file)

# 加载模型
with open(model_filename, 'rb') as model_file:
    loaded_model = pickle.load(model_file)

# 设置置信度阈值
confidence_threshold = 0.5

# 函数:预测图像内容
def predict_image(img):
    img = transform.resize(img, (50, 50))
    img_gray = color.rgb2gray(img)  # 转换为灰度图
    img = img_as_ubyte(img_gray)  # 转换图像数据类型为 uint8
    flattened_img = img.flatten()
    probs = loaded_model.predict_proba([flattened_img])
    print("Prediction Probabilities:", probs)  # 打印预测概率分布
    max_prob = np.max(probs)
    if max_prob < confidence_threshold:
        return "no"
    prediction = loaded_model.predict([flattened_img])
    return prediction[0]

def predict_and_display_image(img):
    img_resized = transform.resize(img, (50, 50))
    img_gray = color.rgb2gray(img_resized)  # 转换为灰度图
    img_resized = img_as_ubyte(img_gray)  # 转换图像数据类型为 uint8
    flattened_img = img_resized.flatten()
    probs = loaded_model.predict_proba([flattened_img])
    max_prob = np.max(probs)
    if max_prob < confidence_threshold:
        prediction = "no"
    else:
        prediction = loaded_model.predict([flattened_img])[0]
    print(prediction)  # 调试信息
    # 在图像上显示预测结果
    cv2.putText(img, f'Prediction: {prediction}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

    cv2.imshow('Image', img)
    cv2.waitKey(1)
    # cv2.destroyAllWindows()


try:
    while True:
        frame = got.read_camera_data()
        if frame is not None:
            nparr = np.frombuffer(frame, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

            # 使用加载的模型进行预测并在图像上显示结果
            predict_and_display_image(img)

            # 在此处添加物体跟随和闭环控制逻辑
            # print('-------:', got.get_knn_result('knn_model'))
except KeyboardInterrupt:
    print('-----KeyboardInterrupt')

image.png

目录
相关文章
|
5天前
|
算法 Python
KNN
【9月更文挑战第11天】
24 13
|
5天前
|
存储 算法 测试技术
预见未来?Python线性回归算法:数据中的秘密预言家
【9月更文挑战第11天】在数据的海洋中,线性回归算法犹如智慧的预言家,助我们揭示未知。本案例通过收集房屋面积、距市中心距离等数据,利用Python的pandas和scikit-learn库构建房价预测模型。经过训练与测试,模型展现出较好的预测能力,均方根误差(RMSE)低,帮助房地产投资者做出更明智决策。尽管现实关系复杂多变,线性回归仍提供了有效工具,引领我们在数据世界中自信前行。
18 5
|
3天前
|
算法 大数据
K-最近邻(KNN)
K-最近邻(KNN)
|
14天前
|
机器学习/深度学习 算法 数据挖掘
R语言中的支持向量机(SVM)与K最近邻(KNN)算法实现与应用
【9月更文挑战第2天】无论是支持向量机还是K最近邻算法,都是机器学习中非常重要的分类算法。它们在R语言中的实现相对简单,但各有其优缺点和适用场景。在实际应用中,应根据数据的特性、任务的需求以及计算资源的限制来选择合适的算法。通过不断地实践和探索,我们可以更好地掌握这些算法并应用到实际的数据分析和机器学习任务中。
|
14天前
|
编解码 算法 图形学
同一路RTSP|RTMP流如何同时回调YUV和RGB数据实现渲染和算法分析
我们播放RTSP|RTMP流,如果需要同时做渲染和算法分析的话,特别是渲染在上层实现(比如Unity),算法是python这种情况,拉两路流,更耗费带宽和性能,拉一路流,同时回调YUV和RGB数据也可以,但是更灵活的是本文提到的按需转算法期望的RGB数据,然后做算法处理
|
1月前
|
存储 算法 大数据
小米教你:2GB内存搞定20亿数据的高效算法
你好,我是小米。本文介绍如何在2GB内存中找出20亿个整数里出现次数最多的数。通过将数据用哈希函数分至16个小文件,每份独立计数后选出频次最高的数,最终比对得出结果。这种方法有效解决大数据下的内存限制问题,并可应用于更广泛的场景。欢迎关注我的公众号“软件求生”,获取更多技术分享!
145 12
|
1月前
|
编解码 算法 Linux
Linux平台下RTSP|RTMP播放器如何跟python交互投递RGB数据供视觉算法分析
在对接Linux平台的RTSP播放模块时,需将播放数据同时提供给Python进行视觉算法分析。技术实现上,可在播放时通过回调函数获取视频帧数据,并以RGB32格式输出。利用`SetVideoFrameCallBackV2`接口设定缩放后的视频帧回调,以满足算法所需的分辨率。回调函数中,每收到一帧数据即保存为bitmap文件。Python端只需读取指定文件夹中的bitmap文件,即可进行视频数据的分析处理。此方案简单有效,但应注意控制输出的bitmap文件数量以避免内存占用过高。
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
深度学习的伦理困境:数据隐私与算法偏见
【8月更文挑战第9天】随着深度学习技术的飞速发展,其对个人隐私和数据安全的威胁日益凸显。本文探讨了深度学习在处理敏感信息时可能导致的数据泄露风险,以及训练数据中固有偏见如何影响算法公正性的问题。文章分析了当前隐私保护措施的局限性,并提出了减少算法偏见的方法。最后,本文讨论了如何在保障技术进步的同时,确保技术应用不侵犯个人权益,呼吁建立更为全面的伦理框架以指导深度学习的发展。
|
28天前
|
算法 搜索推荐
支付宝商业化广告算法问题之基于pretrain—>finetune范式的知识迁移中,finetune阶段全参数训练与部分参数训练的效果如何比较
支付宝商业化广告算法问题之基于pretrain—>finetune范式的知识迁移中,finetune阶段全参数训练与部分参数训练的效果如何比较
|
1月前
|
存储 算法
【C算法】编程初学者入门训练140道(1~20)
【C算法】编程初学者入门训练140道(1~20)