tensorflow2.0回归模型---如何用好keras对sklearn的api

简介: tensorflow2.0回归模型---如何用好keras对sklearn的api

之前写过如何用tf.keras搭建模型,那个时候埋下了一个伏笔,就是超参数搜索的问题。如何得到最好的模型,我们用sklearn的时候就是GridSearchCV或者RandomizedSearchCV,所以我今天想讲讲怎么通过tf.keras的api来实现超参数搜索。


1.看看官方文档的介绍


官方文档


image.png


发现调用这个api只需要写一个build_fn,也就是写一个搭建网络的函数,知道这个之后就来实战看看。


2.实现超参数搜索


  1. 首先导入数据集,我使用的是California Housing dataset
  2. 切分好训练集,验证集和测试集,并且对数据标准化
  3. 写好我们需要的网络模型,调用keras.wrappers.scikit_learn.KerasRegressor实现model可以使用sklearn的方法
  4. 训练得到最好的参数以及模型


代码:


import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib
import matplotlib.pyplot as plt
import sklearn
import os,sys
from sklearn.datasets import fetch_california_housing
housing=fetch_california_housing()
house=pd.DataFrame(housing.data)
house.columns=housing.feature_names
house['price']=housing.target
house.info()
house.head(10)
from sklearn.model_selection import train_test_split
x_train_all,x_test,y_train_all,y_test=train_test_split(housing.data,housing.target,test_size=0.25,random_state=2)
x_train,x_valid,y_train,y_valid=train_test_split(x_train_all,y_train_all,test_size=0.25,random_state=2)
#标准化
from sklearn.preprocessing import StandardScaler
scaler=StandardScaler()
x_train_scaled=scaler.fit_transform(x_train)
x_valid_scaled=scaler.fit_transform(x_valid)
x_test_scaled=scaler.fit_transform(x_test)
#构建自己的模型
def build_model(hidden_layers=1,layer_size=30,learning_rate=3e-3):
    model=keras.models.Sequential()
    model.add(keras.layers.Dense(layer_size,activation="relu",input_shape=x_train.shape[1:]))
    for _ in range(hidden_layers-1):
        model.add(keras.layers.Dense(layer_size,activation="relu"))
    model.add(keras.layers.Dense(1))
    optimizer=keras.optimizers.SGD(learning_rate)
    model.compile(loss="mse",optimizer=optimizer)
    return model
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join('logs', current_time)
callbacks=[
    keras.callbacks.EarlyStopping(patience=3,min_delta=1e-3),
    keras.callbacks.TensorBoard(log_dir=logdir)
]
model=keras.wrappers.scikit_learn.KerasRegressor(build_model)
callback=[keras.callbacks.EarlyStopping(patience=3,min_delta=1e-3)]
#使用默认参数的模型
history=model.fit(x_train_scaled,y_train,validation_data=(x_valid_scaled,y_valid),epochs=100,callbacks=callback)
def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8,5))
    plt.grid()
    plt.gca().set_ylim(0,1)
    plt.show()
plot_learning_curves(history)
#实现超参数搜索
from scipy.stats import reciprocal
param_distribution={
    "hidden_layers":[3,4],
    "layer_size":np.arange(24,27),
    "learning_rate":reciprocal(0.001,0.005)
}
#随机搜索
from sklearn.model_selection import RandomizedSearchCV
random_searchcv=RandomizedSearchCV(model,param_distribution,n_iter=10,verbose=0)
random_searchcv.fit(x_train_scaled,y_train,validation_data=(x_valid_scaled,y_valid),epochs=100)
print("得到的最好参数为:",random_searchcv.best_params_)
print("最好的得分为:",random_searchcv.best_score_)
model=random_searchcv.best_estimator_.model
history1=model.fit(x_train_scaled,y_train,validation_data=(x_valid_scaled,y_valid),epochs=100,callbacks=callbacks)
plot_learning_curves(history1)
print(model.evaluate(x_test_scaled,y_test,verbose=0))


image.png

image.png

默认参数的模型效果


image.png

使用随机搜索的结果

image.png


打开tensorboard看看模型


image.png

3.总结:


虽然使用超参数搜索很方便,但是也有一些需要注意的地方。比如需要知道哪些参数比较重要,就尽量精确;相反的就可以稍微放宽要求。毕竟不论是使用网格搜索还是随机搜索如果参数太多会导致计算时间很久。BTW,在使用超参数搜索的时候,我想把n_jobs设置成大于1的,运行就会报错,还没有找到很好的办法解决,这也就导致了对数据的计算时间会更久,终究调参还是一个需要经验的事情鸭!

目录
相关文章
|
16天前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
60 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
1月前
|
人工智能 Serverless API
一键服务化:从魔搭开源模型到OpenAI API服务
在多样化大模型的背后,OpenAI得益于在领域的先发优势,其API接口今天也成为了业界的一个事实标准。
一键服务化:从魔搭开源模型到OpenAI API服务
|
2月前
|
机器学习/深度学习 存储 前端开发
实战揭秘:如何借助TensorFlow.js的强大力量,轻松将高效能的机器学习模型无缝集成到Web浏览器中,从而打造智能化的前端应用并优化用户体验
【8月更文挑战第31天】将机器学习模型集成到Web应用中,可让用户在浏览器内体验智能化功能。TensorFlow.js作为在客户端浏览器中运行的库,提供了强大支持。本文通过问答形式详细介绍如何使用TensorFlow.js将机器学习模型带入Web浏览器,并通过具体示例代码展示最佳实践。首先,需在HTML文件中引入TensorFlow.js库;接着,可通过加载预训练模型如MobileNet实现图像分类;然后,编写代码处理图像识别并显示结果;此外,还介绍了如何训练自定义模型及优化模型性能的方法,包括模型量化、剪枝和压缩等。
34 1
|
2月前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
62 1
|
2月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
57 0
|
2月前
|
C# 开发者 前端开发
揭秘混合开发新趋势:Uno Platform携手Blazor,教你一步到位实现跨平台应用,代码复用不再是梦!
【8月更文挑战第31天】随着前端技术的发展,混合开发日益受到开发者青睐。本文详述了如何结合.NET生态下的两大框架——Uno Platform与Blazor,进行高效混合开发。Uno Platform基于WebAssembly和WebGL技术,支持跨平台应用构建;Blazor则让C#成为可能的前端开发语言,实现了客户端与服务器端逻辑共享。二者结合不仅提升了代码复用率与跨平台能力,还简化了项目维护并增强了Web应用性能。文中提供了从环境搭建到示例代码的具体步骤,并展示了如何创建一个简单的计数器应用,帮助读者快速上手混合开发。
45 0
|
2月前
|
UED 开发工具 iOS开发
Uno Platform大揭秘:如何在你的跨平台应用中,巧妙融入第三方库与服务,一键解锁无限可能,让应用功能飙升,用户体验爆棚!
【8月更文挑战第31天】Uno Platform 让开发者能用同一代码库打造 Windows、iOS、Android、macOS 甚至 Web 的多彩应用。本文介绍如何在 Uno Platform 中集成第三方库和服务,如 Mapbox 或 Google Maps 的 .NET SDK,以增强应用功能并提升用户体验。通过 NuGet 安装所需库,并在 XAML 页面中添加相应控件,即可实现地图等功能。尽管 Uno 平台减少了平台差异,但仍需关注版本兼容性和性能问题,确保应用在多平台上表现一致。掌握正确方法,让跨平台应用更出色。
34 0
|
2月前
|
开发者 算法 虚拟化
惊爆!Uno Platform 调试与性能分析终极攻略,从工具运用到代码优化,带你攻克开发难题成就完美应用
【8月更文挑战第31天】在 Uno Platform 中,调试可通过 Visual Studio 设置断点和逐步执行代码实现,同时浏览器开发者工具有助于 Web 版本调试。性能分析则利用 Visual Studio 的性能分析器检查 CPU 和内存使用情况,还可通过记录时间戳进行简单分析。优化性能涉及代码逻辑优化、资源管理和用户界面简化,综合利用平台提供的工具和技术,确保应用高效稳定运行。
40 0
|
2月前
|
前端开发 开发者 设计模式
揭秘Uno Platform状态管理之道:INotifyPropertyChanged、依赖注入、MVVM大对决,帮你找到最佳策略!
【8月更文挑战第31天】本文对比分析了 Uno Platform 中的关键状态管理策略,包括内置的 INotifyPropertyChanged、依赖注入及 MVVM 框架。INotifyPropertyChanged 方案简单易用,适合小型项目;依赖注入则更灵活,支持状态共享与持久化,适用于复杂场景;MVVM 框架通过分离视图、视图模型和模型,使状态管理更清晰,适合大型项目。开发者可根据项目需求和技术栈选择合适的状态管理方案,以实现高效管理。
32 0
|
2月前
|
Apache 开发者 Java
Apache Wicket揭秘:如何巧妙利用模型与表单机制,实现Web应用高效开发?
【8月更文挑战第31天】本文深入探讨了Apache Wicket的模型与表单处理机制。Wicket作为一个组件化的Java Web框架,提供了多种模型实现,如CompoundPropertyModel等,充当组件与数据间的桥梁。文章通过示例介绍了模型创建及使用方法,并详细讲解了表单组件、提交处理及验证机制,帮助开发者更好地理解如何利用Wicket构建高效、易维护的Web应用程序。
28 0

热门文章

最新文章

下一篇
无影云桌面