在浏览器上也能训练神经网络?TensorFlow.js带你玩游戏~

在线体验各类最新模型,更有模型 免费Token 额度领取!
立即体验
简介: 一直以来训练神经网络给我们的印象都是复杂、耗时、对硬件要求高。你有没有想过有一天在浏览器上也能训练神经网络~ 本文通过一篇详细的TensorFlow.js教程,带你玩一个用浏览器训练神经网络的游戏!

How to train neural network on browser

无论你是刚开始深度学习,亦或是个老练的老手,建立一个神经网络的训练环境有时都会很痛苦。让神经网络的训练像加载一个网页,然后点击几下,然后你就准备好马上进行推理,会不会是件很棒的事呢?(那必须棒)

在本教程中,我将向你展示如何使用浏览器上的框架 TensorFlow.js 构建一个模型,其中包含从你的网络摄像头收集到的数据,并在你的浏览器上进行训练。为了使模型有用,我们将把一个摄像头变成一个游戏 - Pong。

来玩个游戏先!

准备工作:

  1. 下载dist.zip [1] 并将其解压缩到你的本地机器上。
  2. 安装一个HTTP服务器,我的建议是通过 npm 在全球范围内安装http-server。
npm install -g http-server

你会问什么是 npm?它是Node.js的包安装程序,就像 python 的 pip 一样,可以在 [2] 获得。

在dist文件夹所在的命令行中运行以下命令,以便在端口上为本地计算机上的Web应用程序提供服务,例如1234。

http-server dist --cors -p 1234 -s

将浏览器窗口指向http://localhost:1234 ,我已经在Chrome和Firefox上进行了测试。

当页面完成加载后,开始收集三个动作的训练图像,左、中、右。在这里有一个提示,平衡训练样本,每个case可能大约有20个样本。

点击“TRAIN”,开始训练,并显示loss。

如果loss没有变化了,那么训练结束,现在点击“PLAY”开始游戏。

如果想重新开始,点击“RESET”。

让我们来看看游戏是如何构建的。本教程中使用了两种模型,第一种是一个预先训练过的卷积网络,它是从Keras导出的,它负责从网络摄像头图像中提取图像特征。第二个模型在你的浏览器上建立和训练,用图像特征对游戏控制进行预测。它是一个回归模型,预测值在-1~1之间,以控制玩家的paddle速度。它本质上是一个迁移学习任务。更多关于迁移学习的主题,请参考 [3]。这里不做进一步的讨论,可以从我Github [4]上下载源代码。

将预训练模型到处到tfjs


如果你只想学习web应用程序部分,可以跳过本节。

让我们首先将一个预先训练过的卷积网络导出到 tensorflow.js(tfjs) 格式。我选择使用本教程中的 ImageNet 数据集训练的 DenseNet,但是你可以使用其他模型,如MobileNet。尽量避免像 ResNet 和 VGGs 这样的大型深度卷积网络,尽管它们可能提供更高的精度,但不适合像我们这样运行在浏览器上的边缘设备。

第一步是在python脚本中将经过预先训练的 DenNet 的 keras 模型保存到一个.h5文件中。

from keras.applications.densenet import DenseNet121

model = DenseNet121(input_shape=(224, 224, 3), 
                    weights='imagenet')
model.save('./tfjs-densenet/model.h5')

然后运行转换脚本将.h5文件转换为浏览器缓存优化的tfjs文件。在继续之前,通过pip3安装tensorflowjs转换脚本python包。

pip3 install tensorflowjs

我们现在可以通过运行生成tfjs文件:

cd ./tfjs-densenet
tensorflowjs_converter --input_format keras 
./model.h5 ./model

你会看到一个名为 model 的文件夹,里面有几个文件。model.json文件定义了模型结构和权重文件的路径。经过预先训练的模型可以为 web 应用程序提供服务。例如,你可以将模型文件夹重命名为 serveDenseNet 并复制到你的 web app served文件夹,然后可以像这样加载模型:

const modelPath = window.location.origin + 
                    '/serveDenseNet/model.json'
const pretrainedNet = await tf.loadModel(modelPath);
const layer = pretrainedNet.getLayer(
                'conv5_block16_concat');
// Feature extractor model
cnnNet = tf.model({inputs: pretrainedNet.inputs,
                 outputs: layer.output});

window.location.origin 是web应用程序url,或者如果你在1234端口本地为其提供服务,它将是 localhost:1234。await 语句只允许 Web 应用程序在后台加载模型,而不冻结主用户界面。

另外,由于我们加载的模型是一个图像分类模型,顶层我们不需要,我们只需要模型的特征提取部分,解决方案是定位最顶层的卷积层,并截断前面代码片段中显示的模型。

从网络摄像头生成训练数据

为了准备回归模型的训练数据,我们将从网络摄像头抓取一些图像,并在Web应用程序中用预先训练的模型提取它们的特征。为了简化用于获取训练数据的用户界面,我们仅用三个值中的一个标记[-1, 0, 1 ]。

对于通过网络摄像头获取的每一幅图像,它都会被输入预先训练的 DenseNet 中提取特征并保存为训练样本。在通过特征提取器模型传递图像后,224×224彩色图像的维数将降为图像特征张量 [7,7,1024],大小取决于你所选择的预训练模型,并且可以通过在前面一节中选择的图层调用outputShape来获得,如下所示。

modelLayerShape = layer.outputShape.slice(1)

将提取的图像特征作为训练数据而不是原始图像的原因有两方面:一是节省了存储训练数据的内存,二是不运行特征提取模型,减少了训练时间。

下面的片段显示了一个图像是如何被网络摄像头捕获、提取和聚合的。请注意,所有图像特征都是以张量的形式保存的,这意味着如果你的模型运行在浏览器的WebGL后端,那么它一次可以在GPU内存中安全地包含多少个训练样本是有限制的。因此,不要期望使用数千甚至数百个图像样本来训练你的模型,这取决于你的硬件。

const img = webcam.capture();
controllerDataset.addExample(cnnNet.predict(img), 
                            CONTROLS_VALUES[label]);

神经网络的建立与训练


在不上传到任何云服务的情况下,建立和训练你的神经网络保护了你的隐私,因为数据永远不会离开你的设备,在你的浏览器上观察它的发生,让它变得更酷。

回归模型以图像特征作为输入,将其压平到一个向量,然后接着两个全连接层,生成一个浮点数来控制游戏。最后一个全连接层不需要激活函数,因为我们希望它产生实数在-1到1之间。我们选择的损失函数是训练过程中的均方误差,以最小化损失。更多选择可以阅读我的帖子,比如如何选择最后一层激活和损失函数[5]。

下面的代码将构建、编译和匹配模型。看起来非常类似于keras的工作流,对吗?

model = tf.sequential({
   layers: [
     tf.layers.flatten({inputShape: modelLayerShape}),
     // Layer 1
     tf.layers.dense({
       units: 100,
       activation: 'relu',
       kernelInitializer: 'varianceScaling',
       useBias: true
     }),
     // Layer 2.
     tf.layers.dense({
       units: 1,
       kernelInitializer: 'varianceScaling',
       useBias: false,
     })
   ]
});

// Creates the optimizers which drives training of 
//the model.const optimizer = tf.train.adam(
                               ui.getLearningRate());
model.compile({optimizer: optimizer, 
             loss: 'meanSquaredError'});

let batchSize = 32
// Train the model! Model.fit() will shuffle xs & ys
//so we don't have to.
model.fit(controllerDataset.xs, controllerDataset.ys,
 {
   batchSize,
   epochs: 10
});

将摄像头变成Pong控制器


你可能期望使用类似于Keras语法的图像进行预测。该图像首先被转换成图像特征,然后传递到经过训练的回归神经网络,该神经网络输出控制器值在-1到1之间。
// Capture the frame from the webcam.
const img = webcam.capture();

// Make a prediction through mobilenet, 
//getting the internal activation of
// the mobilenet model.
const activation = cnnNet.predict(img);

// Make a prediction through our newly-trained model
//using the activation
// from mobilenet as input.
const predictions = model.predict(activation);

// The predicted value between -1~1.
predictions.as1D();

一旦你对模型进行了训练,游戏开始运行,预测值就会通过这个调用 pong.updatePlayerSpeed(value) 来控制玩家paddle向左或向右移动的速度。你可以通过调用一下函数来启动和停止游戏:

pong.startGameplay():按下Play按钮该函数将被调用

pong.stopGameplay():按下Reset按钮该函数将被调用

可以通过调用 pong.updateMultiplier(multiplier) 来调整 paddle 运动的侵略性,在Pong类构造函数中,当前的multiplier值设置为12。

结论与探讨

在本教程中,你已经学习了如何在带有TensorFlow.js的浏览器上训练神经网络,并将你的网络摄像头转换为识别你的动作的Pong控制器。可以自由地查看我的源代码并对其进行实验、修改,比如激活函数、损失函数和切换另一个预训练模型等等,看看结果如何。用即时反馈在浏览器上训练神经网络的美妙之处,使我们能够更快地尝试新的想法,并为我们的原型获得更快的结果。

原文发布时间为:2018-07-30
本文作者:Chengwei Zhang
本文来自云栖社区合作伙伴“专知”,了解相关信息可以关注“专知

相关文章
|
11月前
|
机器学习/深度学习 人工智能 算法
AI 基础知识从 0.6 到 0.7—— 彻底拆解深度神经网络训练的五大核心步骤
本文以一个经典的PyTorch手写数字识别代码示例为引子,深入剖析了简洁代码背后隐藏的深度神经网络(DNN)训练全过程。
1559 56
|
SQL 分布式计算 Serverless
鹰角网络:EMR Serverless Spark 在《明日方舟》游戏业务的应用
鹰角网络为应对游戏业务高频活动带来的数据潮汐、资源弹性及稳定性需求,采用阿里云 EMR Serverless Spark 替代原有架构。迁移后实现研发效率提升,支持业务快速发展、计算效率提升,增强SLA保障,稳定性提升,降低运维成本,并支撑全球化数据架构部署。
1454 56
鹰角网络:EMR Serverless Spark 在《明日方舟》游戏业务的应用
|
机器学习/深度学习 存储 算法
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
反向传播算法虽是深度学习基石,但面临内存消耗大和并行扩展受限的问题。近期,牛津大学等机构提出NoProp方法,通过扩散模型概念,将训练重塑为分层去噪任务,无需全局前向或反向传播。NoProp包含三种变体(DT、CT、FM),具备低内存占用与高效训练优势,在CIFAR-10等数据集上达到与传统方法相当的性能。其层间解耦特性支持分布式并行训练,为无梯度深度学习提供了新方向。
860 1
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
|
9月前
|
机器学习/深度学习 数据可视化 网络架构
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
PINNs训练难因多目标优化易失衡。通过设计硬约束网络架构,将初始与边界条件内嵌于模型输出,可自动满足约束,仅需优化方程残差,简化训练过程,提升稳定性与精度,适用于气候、生物医学等高要求仿真场景。
976 4
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
|
编解码 JavaScript 前端开发
【Java进阶】详解JavaScript的BOM(浏览器对象模型)
总的来说,BOM提供了一种方式来与浏览器进行交互。通过BOM,你可以操作窗口、获取URL、操作历史、访问HTML文档、获取浏览器信息和屏幕信息等。虽然BOM并没有正式的标准,但大多数现代浏览器都实现了相似的功能,因此,你可以放心地在你的JavaScript代码中使用BOM。
408 23
|
存储 监控 算法
公司内部网络监控中的二叉搜索树算法:基于 Node.js 的实时设备状态管理
在数字化办公生态系统中,公司内部网络监控已成为企业信息安全管理体系的核心构成要素。随着局域网内终端设备数量呈指数级增长,实现设备状态的实时追踪与异常节点的快速定位,已成为亟待解决的关键技术难题。传统线性数据结构在处理动态更新的设备信息时,存在检索效率低下的固有缺陷;而树形数据结构因其天然的分层特性与高效的检索机制,逐渐成为网络监控领域的研究热点。本文以二叉搜索树(Binary Search Tree, BST)作为研究对象,系统探讨其在公司内部网络监控场景中的应用机制,并基于 Node.js 平台构建一套具备实时更新与快速查询功能的设备状态管理算法框架。
413 3
|
机器学习/深度学习 文件存储 异构计算
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
1518 18
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
害虫识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了12种常见的害虫种类数据集【"蚂蚁(ants)", "蜜蜂(bees)", "甲虫(beetle)", "毛虫(catterpillar)", "蚯蚓(earthworms)", "蜚蠊(earwig)", "蚱蜢(grasshopper)", "飞蛾(moth)", "鼻涕虫(slug)", "蜗牛(snail)", "黄蜂(wasp)", "象鼻虫(weevil)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Djan
851 1
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
蘑菇识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了9种常见的蘑菇种类数据集【"香菇(Agaricus)", "毒鹅膏菌(Amanita)", "牛肝菌(Boletus)", "网状菌(Cortinarius)", "毒镰孢(Entoloma)", "湿孢菌(Hygrocybe)", "乳菇(Lactarius)", "红菇(Russula)", "松茸(Suillus)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,
1429 11
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
Web App开发 前端开发 JavaScript
折腾之王:JavaScript之父Brave浏览器与BAT的诞生
2015年,JavaScript之父Brendan Eich再次创业,推出Brave浏览器和加密货币Basic Attention Token(BAT),旨在颠覆传统广告行业。Brave屏蔽广告、保护隐私,加载速度快;BAT则通过奖励机制让用户、内容创作者和广告主三方受益。尽管面临用户习惯和巨头竞争的挑战,Brave已拥有超4000万月活跃用户,成为全球增长最快的隐私浏览器,引领Web3生态发展。
592 22
折腾之王:JavaScript之父Brave浏览器与BAT的诞生

热门文章

最新文章