基于TensorFlow.js的JavaScript机器学习-Hello World

简介: 我们有一些基于TensorFlow.js的JavaScript机器学习尝试,可以分享一点心得。

image.png

前言

2017年的双十一为了解决运营图片审核任务繁重的问题,我们发起了素材智能审核项目。在这个项目中,我们基于深度学习拿到了很好的项目结果,至今已经审核数千万张图片。

后续我们也尝试做一个 JavaScript 版本 TensorFlow - Tens.js(github.com/tensjs/tens),不过发现很多问题比较难解决。好在后续 TensorFlow 官方发布了 JavaScript 版本,并在近期发布 2.0 版本。

最近几年深度学习在人工智能领域取得了非凡进展,在很多任务中远远超过人类的表现,我相信深度学习后续会在更多的工作场景中被广泛使用。随着 TensorFlow.js 的发布,我们的学习成本进一步降低,我认为深度学习的工程化、大众化是接下来的必然趋势。

介绍

本系列不需要深度学习基础,会避免使用数学符号,会通过代码示例来介绍概念。 代码示例使用 JavaScript,使用 TensorFlow.js(浏览器/Node.js)框架。

Hello World

几乎所有深度学习相关的教程都会以 Minist 项目为例,Mnist 是机器学习一个经典的数据集,包含 60000 张训练图片和 10000 张测试图片,在本项目中我们会基于这些图片数据通过深度学习训练出一个图像识别的模型。 图像目前也是深度学习最有优势的场景,通过该项目,我们也可以理解TensorFlow.js 的基础用法和深度学习的一些核心概念。 简单来看,通过深度学习解决一个问题一般抽象为下面的流程

3.1. 数据预处理

对于浏览器环境来说,数据处理与其他环境相比还是比较麻烦。这次 Mnist 我们使用单张雪碧图来存储,对应 label 使用二进制进行存储。我们可以直接从 Google 的 url 中获取。

3.1.1. 图片数据预处理

图片数据的处理在浏览器场景非常有用,无论是图片数据集处理、模型预测时的上传图片处理都经常用到

const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

在浏览器将图片转换成二进制数据,需要

const img = new Image();
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
img.crossOrigin = '';
// 图片加载后,使用 canvas drawImage,然后 getImageData 获取二进制数据
img.onload = () => {
  img.width = img.naturalWidth;
  img.height = img.naturalHeight;
  ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,chunkSize);
  const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
}
img.src = MNIST_IMAGES_SPRITE_PATH;

3.1.2. ArrayBuffer & DataView

在 Tensorflow.js 的使用中,会经常使用 ArrayBuffer 类型,比如Canvas、Fetch API、File API。

// Canvas Uint8ClampedArray
const canvas = document.getElementById('myCanvas');
const ctx = canvas.getContext('2d');
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
const uint8ClampedArray = imageData.data; 
// Fetch ArrayBuffer
fetch(url)
.then(function(response){
  return response.arrayBuffer()
})
.then(function(arrayBuffer){
  // ...
});
// File ArrayBuffer
const fileInput = document.getElementById('fileInput');
const file = fileInput.files<a href="https://storage.googleapis.com/tfjs-examples/mnist/dist/index.html">0];
const reader = new FileReader();
reader.readAsArrayBuffer(file);
reader.onload = function () {
  const arrayBuffer = reader.result;
  // ···
};

3.1.3. 测试数据与验证数据

如果用全量数据来训练模型,模型很容易和数据集过度拟合,并不一定能很好的应对未来产生的新数据。为了解决这个问题,一般我们会将数据分为多份,分别用来做训练和验证使用。

image.png

this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
  // Slice the the images and labels into train and test sets.
  this.trainImages =
      this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
  this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
  this.trainLabels =
      this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  this.testLabels =
      this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);

3.2. 构建模型

深度学习本质上是构建了一个多层网络模型,不能层会来实现 特征提取、关联学习的工作,通过 TensorFlow 我们可以很简单的构建一个自己的深度网络

function createDenseModel() {
  const model = tf.sequential();
  model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 1]}));
  model.add(tf.layers.dense({units: 42, activation: 'relu'}));
  model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
  return model;
}
// 下一章会来介绍更多细节

3.3. 训练模型

模型编译后,便可进行训练,目前 TensorFlow.js 也提供 tfjs-vis 方便将训练过程可视化

model.compile({
  optimizer,
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy'],
});
await model.fit(trainData.xs, trainData.labels, {
  batchSize,
  validationSplit,
  epochs: trainEpochs
});

线上示例

后续

下一周会来介绍 TensorFlow.js 中的核心概念

参考

相关文章
|
2月前
|
JavaScript 前端开发 Java
springboot解决js前端跨域问题,javascript跨域问题解决
本文介绍了如何在Spring Boot项目中编写Filter过滤器以处理跨域问题,并通过一个示例展示了使用JavaScript进行跨域请求的方法。首先,在Spring Boot应用中添加一个实现了`Filter`接口的类,设置响应头允许所有来源的跨域请求。接着,通过一个简单的HTML页面和jQuery发送AJAX请求到指定URL,验证跨域请求是否成功。文中还提供了请求成功的响应数据样例及请求效果截图。
springboot解决js前端跨域问题,javascript跨域问题解决
|
2月前
|
JavaScript 前端开发
Moment.js与其他处理时间戳格式差异的JavaScript库相比有什么优势?
Moment.js与其他处理时间戳格式差异的JavaScript库相比有什么优势?
|
2月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
109 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
2月前
|
机器学习/深度学习 TensorFlow API
机器学习实战:TensorFlow在图像识别中的应用探索
【10月更文挑战第28天】随着深度学习技术的发展,图像识别取得了显著进步。TensorFlow作为Google开源的机器学习框架,凭借其强大的功能和灵活的API,在图像识别任务中广泛应用。本文通过实战案例,探讨TensorFlow在图像识别中的优势与挑战,展示如何使用TensorFlow构建和训练卷积神经网络(CNN),并评估模型的性能。尽管面临学习曲线和资源消耗等挑战,TensorFlow仍展现出广阔的应用前景。
81 5
|
3月前
|
机器学习/深度学习 自然语言处理 JavaScript
信息论、机器学习的核心概念:熵、KL散度、JS散度和Renyi散度的深度解析及应用
在信息论、机器学习和统计学领域中,KL散度(Kullback-Leibler散度)是量化概率分布差异的关键概念。本文深入探讨了KL散度及其相关概念,包括Jensen-Shannon散度和Renyi散度。KL散度用于衡量两个概率分布之间的差异,而Jensen-Shannon散度则提供了一种对称的度量方式。Renyi散度通过可调参数α,提供了更灵活的散度度量。这些概念不仅在理论研究中至关重要,在实际应用中也广泛用于数据压缩、变分自编码器、强化学习等领域。通过分析电子商务中的数据漂移实例,展示了这些散度指标在捕捉数据分布变化方面的独特优势,为企业提供了数据驱动的决策支持。
193 2
信息论、机器学习的核心概念:熵、KL散度、JS散度和Renyi散度的深度解析及应用
|
2月前
|
机器学习/深度学习 人工智能 TensorFlow
基于TensorFlow的深度学习模型训练与优化实战
基于TensorFlow的深度学习模型训练与优化实战
110 0
|
4月前
|
机器学习/深度学习 算法 TensorFlow
交通标志识别系统Python+卷积神经网络算法+深度学习人工智能+TensorFlow模型训练+计算机课设项目+Django网页界面
交通标志识别系统。本系统使用Python作为主要编程语言,在交通标志图像识别功能实现中,基于TensorFlow搭建卷积神经网络算法模型,通过对收集到的58种常见的交通标志图像作为数据集,进行迭代训练最后得到一个识别精度较高的模型文件,然后保存为本地的h5格式文件。再使用Django开发Web网页端操作界面,实现用户上传一张交通标志图片,识别其名称。
163 6
交通标志识别系统Python+卷积神经网络算法+深度学习人工智能+TensorFlow模型训练+计算机课设项目+Django网页界面
|
3月前
|
机器学习/深度学习 人工智能 算法
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
玉米病害识别系统,本系统使用Python作为主要开发语言,通过收集了8种常见的玉米叶部病害图片数据集('矮花叶病', '健康', '灰斑病一般', '灰斑病严重', '锈病一般', '锈病严重', '叶斑病一般', '叶斑病严重'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。再使用Django搭建Web网页操作平台,实现用户上传一张玉米病害图片识别其名称。
84 0
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
|
3月前
|
人工智能 JavaScript 前端开发
使用Node.js模拟执行JavaScript
使用Node.js模拟执行JavaScript
35 2
|
3月前
|
JavaScript 前端开发
电话号码正则表达式 代码 javascript+html,JS正则表达式判断11位手机号码
电话号码正则表达式 代码 javascript+html,JS正则表达式判断11位手机号码
141 1