DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别

简介: DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别

输出结果

image.png



image.png





实现代码


from __future__ import print_function

print(__doc__)

import numpy as np              

import matplotlib.pyplot as plt  

from scipy.ndimage import convolve

from sklearn import linear_model, datasets, metrics  

from sklearn.cross_validation import train_test_split

from sklearn.neural_network import BernoulliRBM  

from sklearn.pipeline import Pipeline            

def nudge_dataset(X, Y):  

direction_vectors = [

   [[0, 1, 0],[0, 0, 0],[0, 0, 0]],

   [[0, 0, 0],[1, 0, 0],[0, 0, 0]],

   [[0, 0, 0],[0, 0, 1],[0, 0, 0]],

   [[0, 0, 0],[0, 0, 0],[0, 1, 0]]

   ]

shift = lambda x, w: convolve(x.reshape((8, 8)), mode='constant',weights=w).ravel()

X = np.concatenate([X] +

[np.apply_along_axis(shift, 1, X, vector)

for vector in direction_vectors])

Y = np.concatenate([Y for _ in range(5)], axis=0)

return X, Y

digits = datasets.load_digits()

X = np.asarray(digits.data, 'float32')

X, Y = nudge_dataset(X, digits.target)  

X = (X - np.min(X, 0)) / (np.max(X, 0) + 0.0001)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y,test_size=0.2,random_state=0)

logistic = linear_model.LogisticRegression()

rbm = BernoulliRBM(random_state=0, verbose=True)

classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])

rbm.learning_rate = 0.06  

rbm.n_iter = 20

# More components tend to give better prediction performance, but larger fitting time

rbm.n_components = 100

logistic.C = 6000.0

classifier.fit(X_train, Y_train)  

logistic_classifier = linear_model.LogisticRegression(C=100.0)

logistic_classifier.fit(X_train, Y_train)

print()

print("Logistic regression using RBM features:\n%s\n" % (

   metrics.classification_report(

       Y_test,classifier.predict(X_test)  

       )

   ))

print("Logistic regression using raw pixel features:\n%s\n" % (

metrics.classification_report(

Y_test,

logistic_classifier.predict(X_test))))

plt.figure(figsize=(4.2, 4))

for i, comp in enumerate(rbm.components_):

plt.subplot(10, 10, i + 1)

plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,

interpolation='nearest')

plt.xticks(())

plt.yticks(())

plt.suptitle('100 components extracted by RBM', fontsize=16)

plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

plt.show()


相关文章
|
1月前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现智能食品消费模式分析的深度学习模型
使用Python实现智能食品消费模式分析的深度学习模型
126 70
|
2月前
|
机器学习/深度学习 数据采集 供应链
使用Python实现智能食品库存管理的深度学习模型
使用Python实现智能食品库存管理的深度学习模型
190 63
|
22天前
|
机器学习/深度学习 数据可视化 TensorFlow
使用Python实现深度学习模型的分布式训练
使用Python实现深度学习模型的分布式训练
165 73
|
6天前
|
机器学习/深度学习 存储 人工智能
MNN:阿里开源的轻量级深度学习推理框架,支持在移动端等多种终端上运行,兼容主流的模型格式
MNN 是阿里巴巴开源的轻量级深度学习推理框架,支持多种设备和主流模型格式,具备高性能和易用性,适用于移动端、服务器和嵌入式设备。
49 18
MNN:阿里开源的轻量级深度学习推理框架,支持在移动端等多种终端上运行,兼容主流的模型格式
|
1月前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现智能食品消费习惯分析的深度学习模型
使用Python实现智能食品消费习惯分析的深度学习模型
144 68
|
1月前
|
机器学习/深度学习 数据采集 数据挖掘
使用Python实现智能食品消费市场分析的深度学习模型
使用Python实现智能食品消费市场分析的深度学习模型
117 36
|
25天前
|
机器学习/深度学习 数据采集 供应链
使用Python实现智能食品消费需求分析的深度学习模型
使用Python实现智能食品消费需求分析的深度学习模型
79 21
|
27天前
|
机器学习/深度学习 数据采集 搜索推荐
使用Python实现智能食品消费偏好预测的深度学习模型
使用Python实现智能食品消费偏好预测的深度学习模型
75 23
|
28天前
|
机器学习/深度学习 数据采集 数据挖掘
使用Python实现智能食品消费习惯预测的深度学习模型
使用Python实现智能食品消费习惯预测的深度学习模型
106 19
|
29天前
|
机器学习/深度学习 数据采集 数据挖掘
使用Python实现智能食品消费趋势分析的深度学习模型
使用Python实现智能食品消费趋势分析的深度学习模型
114 18