基于DJL的机器学习

本文涉及的产品
RDS DuckDB + QuickBI 企业套餐,8核32GB + QuickBI 专业版
简介: 本文介绍了基于Java的深度学习框架DJL,涵盖机器学习与深度学习的核心概念、神经网络结构及生命周期,并通过MNIST数据集展示了从模型构建、训练到推理的完整流程。内容深入浅出,适合初学者入门。

Hi~各位读者朋友们,感谢您阅读本文,我是笠泱,本期分享基于Java语言入门机器学习,本期内容灵感来源于B站@雷丰阳(https://www.bilibili.com/video/BV15vqcYbEM9

1. 核心概念

1.1. 机器学习

机器学习是一个通过利用统计学知识,将数据输入到计算机中进行训练并完成特定目标任务的过程。这种归纳学习的方法可以让计算机学习一些特征并进行一系列复杂的任务,比如识别照片中的物体。由于需要写复杂的逻辑以及测量标准,这些任务在传统计算科学领域中很难实现。

参阅文档:https://docs.djl.ai/master/docs/demos/jupyter/tutorial/02_train_your_first_model.html

1.2. 深度学习

深度学习是机器学习的一个分支,主要侧重于对于人工神经网络的开发。人工神经网络是通过研究人脑如何学习和实现目标的过程中归纳而得出一套计算逻辑。它通过模拟部分人脑神经间信息传递的过程,从而实现各类复杂的任务。深度学习中的“深度”来源于我们会在人工神经网络中编织构建出许多层(多层感知机)从而进一步对数据信息进行更深层的传导。深度学习技术应用范围十分广泛,现在被用来做目标检测、动作识别、机器翻译、语意分析等各类现实应用中。

1.3. Neural Network - 神经网络

神经网络是一个黑盒程序。不用你自己编写这个函数,你需要为这个函数提供很多“输入/输出”对样本。然后,我们尝试训练网络,让他最大近似于我们给定的输入/输出对。拥有更多数据的更好的模型可以更准确地近似函数。

1.4. 生命周期

机器学习的生命周期与传统的软件开发生命周期有所不同,它包含七个具体的步骤:

  1. 获取数据,清洗并准备数据
  2. 构建神经网络
  3. 构建模型(这个模型应用上面的神经网络)
  4. 进行训练配置(如何训练、训练集、验证集、测试集),多轮训练,生成模型
  5. 保存模型
  6. 加载模型
  7. 从模型中获得预测(或推理)

生命周期的最终结果是一个可以查询并返回答案(或预测)的机器学习模型。

1.5. MLP

多层感知机(Multilayer Perceptron,简称 MLP)是一种前向型人工神经网络(Feedforward Artificial Neural Network),是最简单的前馈神经网络之一。MLP 由多个层组成,每层包含多个神经元(也称为节点),并且每一层的神经元与下一层的所有神经元相连。MLP 主要用于解决分类和回归问题。

1.5.1. MLP 结构

MLP 通常由以下几个部分组成:

  • 输入层(Input Layer)
  • 输入层接收外部输入数据。每个神经元对应一个输入特征。
  • 例如,对于一个图像分类任务,输入层的神经元数量等于图像的像素数量。
  • 隐藏层(Hidden Layers)
  • 隐藏层位于输入层和输出层之间。MLP 可以有一个或多个隐藏层。
  • 每个隐藏层包含多个神经元,每个神经元通过权重和偏置与前一层的所有神经元相连。
  • 隐藏层引入了非线性变换,使得 MLP 能够学习复杂的模式。
  • 输出层(Output Layer)
  • 输出层产生最终的预测结果。

对于分类任务,输出层的神经元数量通常等于类别数量,使用 softmax 激活函数。

对于回归任务,输出层的神经元数量通常为 1,使用线性激活函数。

1.5.2. 工作原理

参考:https://zhuanlan.zhihu.com/p/642537175

MLP 的工作原理可以分为前向传播(Forward Propagation)反向传播(Backward Propagation)

两个主要步骤:

前向传播

  • 输入数据从输入层传递到输出层。
  • 每个神经元计算加权和(线性组合),然后通过激活函数引入非线性。

反向传播

  • 反向传播用于计算损失函数相对于每个权重和偏置的梯度。
  • 使用梯度下降法更新权重和偏置,以最小化损失函数。

具体步骤包括:

  • 计算输出层的误差。
  • 逐层向后传播误差,计算每个隐藏层的误差。
  • 更新权重和偏置。

当我们准备好数据集和神经网络之后,就可以开始训练模型了。在深度学习中,一般会由下面几步来完成一个训练过程:

  • 初始化:我们会对每一个 Block 的参数进行初始化,初始化每个参数的函数都是由 设定的 Initializer 决定的。
  • 前向传播:这一步将输入数据在神经网络中逐层传递,然后产生输出数据。
  • 计算损失:我们会根据特定的损失函数 Loss 来计算输出和标记结果的偏差。
  • 反向传播:在这一步中,你可以利用损失反向求导算出每一个参数的梯度。
  • 更新权重:我们会根据选择的优化器(Optimizer)更新每一个在 Block 上参数的值。

2. DJL - Deep Java Library

Deep Java Library(DJL)是一个开源的、高级的、与引擎无关的深度学习Java框架。DJL旨在易于上手,易于Java开发人员使用。DJL提供原生Java开发体验和功能,就像任何其他常规Java库一样。

您不必成为机器学习/深度学习专家才能开始。您可以使用现有的Java专业知识作为学习和使用机器学习深度学习的入口。您可以使用您最喜欢的IDE来构建、训练和部署您的模型。DJL可以轻松地将这些模型与您的Java应用程序集成。

因为DJL与深度学习引擎无关,所以您在创建项目时不必在引擎之间做出选择。您可以在任何时候交换机引擎。为了确保最佳性能,DJL还提供了基于硬件配置的自动CPU/GPU选择。

2.1. 架构

2.2. 依赖

<properties>
        <java.version>17</java.version>
        <djl.version>0.26.0</djl.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <!--        引入 djl 依赖  -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.opencv</groupId>
            <artifactId>opencv</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-native-auto</artifactId>
            <version>1.8.0</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
    </dependencies>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

2.3. 流程

2.3.1. NDArray - N维向量

创建、运算、操作

@Test
    void testNDArray(){
        // NDManager 创建和管理深度学习期间的临时数据。销毁后自动释放所有资源
        try(NDManager manager = NDManager.newBaseManager()){
            //创建 2x2 矩阵
            /**
             * ND: (2, 2) cpu() float32
             * [[0., 1.],
             *  [2., 3.],
             * ]
             */
            NDArray ndArray = manager.create(new float[]{0f, 1f, 2f, 3f}, new Shape(2, 2));
            /**
             * ND: (2, 3) cpu() float32
             * [[1., 1., 1.],
             *  [1., 1., 1.],
             * ]
             */
            NDArray ones = manager.ones(new Shape(2, 3));
            /**
             * ND: (2, 3, 4) cpu() float32
             * [[[0.5488, 0.5928, 0.7152, 0.8443],
             *   [0.6028, 0.8579, 0.5449, 0.8473],
             *   [0.4237, 0.6236, 0.6459, 0.3844],
             *  ],
             *  [[0.4376, 0.2975, 0.8918, 0.0567],
             *   [0.9637, 0.2727, 0.3834, 0.4777],
             *   [0.7917, 0.8122, 0.5289, 0.48  ],
             *  ],
             * ]
             */
            NDArray uniform = manager.randomUniform(0, 1, new Shape(2,3, 4));
            System.out.println(uniform);
        }
    }
//===================
NDArray uniform = manager.randomUniform(0, 1, new Shape(2,3, 4));
// 加
NDArray add = uniform.add(5);
System.out.println(add);
//===================
    @Test
    void operation(){
        try(NDManager manager = NDManager.newBaseManager()){
            var array1 = manager. create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
            var array2 = manager. create(new float[] {1f, 2f});
            /**
             * ND: (2) cpu() float32
             * [ 5., 11.]
             */
            var array3 = array1.matMul(array2);
            System.out.println(array3);
            /**
             * ND: (1, 4) cpu() float32
             * [[1., 2., 3., 4.],
             * ]
             */
            var array4 = array1.reshape(1, 4);
            System.out.println(array4);
            /**
             * ND: (2, 2) cpu() float32
             * [[1., 3.],
             *  [2., 4.],
             * ]
             */
            var array5 = array1.transpose();
            System.out.println(array5);
        }
    }

2.3.2. Dataset - 数据集

2.3.2.1. 作用

数据集是用于训练机器学习模型的数据集合。

机器学习通常使用三个数据集https://machinelearningmastery.com/difference-test-validation-datasets/

  • 训练数据集
  • 我们用来训练模型的实际数据集。模型从这些数据中学习权重和参数。
  • 验证数据集
  • 验证集用于在训练过程中评估给定模型。它帮助机器学习工程师在模型开发阶段微调超参数。模型不从验证数据集学习;验证数据集是可选的。
  • 测试数据集
  • 测试数据集提供了用于评估模型的黄金标准,它只在模型完全训练后使用,测试数据集应该更准确地评估模型将如何在新数据上执行。

2.3.2.2. 内置数据集

DJL提供了许多内置的基本和标准数据集,这些数据集用于训练深度学习模型,本模块包含以下数据集:

详细参阅:https://docs.djl.ai/master/docs/dataset.html

CV(计算机视觉)

  • Image Classification
  • MNIST - A small and fast handwritten digits dataset
  • Fashion MNIST - A small and fast clothing type detection dataset
  • CIFAR10 - A dataset consisting of 60,000 32x32 color images in 10 classes
  • ImageNet - An image database organized according to the WordNet hierarchy

Note: You have to manually download the ImageNet dataset due to licensing requirements.

  • Object Detection
  • Pikachu - 1000 Pikachu images of different angles and sizes created using an open source 3D Pikachu model
  • Banana Detection - A testing single object detection dataset
  • Other CV
  • Captcha - A dataset for a grayscale 6-digit CAPTCHA task
  • Coco - A large-scale object detection, segmentation, and captioning dataset that contains 1.5 million object instances

You have to manually add com.twelvemonkeys.imageio:imageio-jpeg:3.11.0 dependency to your project

NLP(自然语言处理)

Text Classification and Sentiment Analysis

  • AmazonReview - A sentiment analysis dataset of Amazon Reviews with their ratings
  • Stanford Movie Review - A sentiment analysis dataset of movie reviews and sentiments sourced from IMDB
  • GoEmotions - A dataset classifying 50k curated reddit comments into either 27 emotion categories or neutral

Unlabeled Text

  • Penn Treebank Text - The text (not POS tags) from the Penn Treebank, a collection of Wall Street Journal stories
  • WikiText2 - A collection of over 100 million tokens extracted from good and featured articles on wikipedia

Other NLP

Tabular(表格化)

Time Series(时序)

2.3.3. Block - 块函数

顺序块用于创建块函数链,其中一个块的输出作为下一个块的输入进行传递

@Test
    void testBlock(){
        long inputSize = 28*28;
        long outputSize = 10;
        SequentialBlock block = new SequentialBlock();
        //一个批量扁平块,将二维图像输入转化为一维特征向量
        block.add(Blocks.batchFlattenBlock(inputSize));
        //添加一个和隐藏层,其线性变换大小为128
        block.add(Linear.builder().setUnits(128).build());
        //添加相应的激活函数;
        block.add(Activation::relu);
        //作为第二个隐藏层的激活函数,有大小为64的变换
        block.add(Linear.builder().setUnits(64).build());
        block.add(Activation::relu);
        //最后输出10大小的特征向量
        block.add(Linear.builder().setUnits(outputSize).build());
        //这些大小是在实验过程中选择的
        //围绕块,可以构建我们的模型,它添加了一些重要的元数据,例如:可以在训练和推理期间使用的名称
        Model model = Model.newInstance("mlp");
        model.setBlock(block);
        
        //现在已经拥有了块和模型,接下来就是如何进行训练
    }

2.3.4. Model - 模型

2.3.5. Inference - 推理

2.3.5.1. 流程

2.3.5.2. Translator - 转换器

ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder()
        .addTransform(new CenterCrop()) //中心裁剪
        .addTransform(new Resize(224, 224)) //调整大小
        .addTransform(new ToTensor())//将图像 NDArray 从预处理格式转换为神经网络格式的变换
        .build();

2.3.6. Model Loading - 模型加载

  • 创建 Filters 而查找模型
  • Local Disk
  • URL:file、http、jar
  • DJL 内置 modelzoo
  • Builder Pattern
  • setTypes 必须
  • optXxx 可选
@Test
    void testModelLoading() throws ModelNotFoundException, MalformedModelException, IOException {
        ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
                .addTransform(new CenterCrop())
                .addTransform(new Resize(224, 224))
                .addTransform(new ToTensor())
                .build();
        //模型检索条件
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .optFilter("layers", "50")
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();
        //加载模型
        ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
        
        System.out.println(model.getName());
    }

2.3.7. Predictor - 预测

//预测
Predictor<Image, Classifications> predictor = model.newPredictor();
Classifications predict = predictor.predict(image);

3. Mnist数据训练 - 完整过程

基于多层感知机训练手写数字识别模型

3.1. 准备数据集

//0、准备数据集
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);
private RandomAccessDataset getDataset(Dataset.Usage usage) {
    Mnist mnist =
            Mnist.builder()
                    .optUsage(usage)
                    .setSampling(64, true)
                    .build();
    mnist.prepare(new ProgressBar());
    return mnist;
}

3.2. 构建神经网络

//1、构建神经网络
Block block =
        new Mlp(
                Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
                Mnist.NUM_CLASSES,
                new int[]{128, 64});

3.3. 创建模型

//2、创建模型
try (Model model = Model.newInstance("mlp")) {
    model.setBlock(block);
    // 3.4及以后的代码
    
}

3.4. 训练配置

//3、训练配置
   DefaultTrainingConfig config = setupTrainingConfig();
    private DefaultTrainingConfig setupTrainingConfig() {
        String outputDir = "build/model";
        SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model = trainer.getModel();
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
                .addTrainingListeners(listener);
    }

3.5. 开始训练

//4、拿到训练器
try (Trainer trainer = model.newTrainer(config)) {
    trainer.setMetrics(new Metrics());
    /*
     * MNIST is 28x28 grayscale image and pre processed into 28 * 28 NDArray.
     * 1st axis is batch axis, we can use 1 for initialization.
     */
    Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
    // initialize trainer with proper input shape
    trainer.initialize(inputShape);
    //5、开始训练
    EasyTrain.fit(trainer, 5, trainingSet, validateSet);
    //6、训练结果
    TrainingResult result = trainer.getTrainingResult();
    System.out.println("result = " + result);
}

3.6. 保存模型

// 前面  SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir); 
// 已经可以保存模型
//也可以直接调用如下API
Path modelDir = Paths.get("build/mlpx");
Files.createDirectories(modelDir);
// Save the model
model.save(modelDir, "mlpx");

3.7. 加载模型

//加载模型
Path modelDir = Paths.get("build/model");
Model model = Model.newInstance("mlp");
model.setBlock(block);
model.load(modelDir);

3.8. 创建推理器

//如果没有 Translator,非标图像则无法处理
Translator<Image, Classifications> translator =
        ImageClassificationTranslator.builder()
                .addTransform(new Resize(Mnist.IMAGE_WIDTH, Mnist.IMAGE_HEIGHT))
                .addTransform(new ToTensor())
                .optApplySoftmax(true)
                .build();
Predictor<Image, Classifications> predictor = model.newPredictor(translator);

3.9. 分类推理

Classifications predict = predictor.predict(img);
System.out.println("predict = " + predict);

4. 进阶学习

我们继续分享机器学习,进阶训练一个鞋分类模型;

代码(带训练好的模型):📎foot-classify-demo.zip

由于鞋分类模型训练时间会比较长,建议使用显卡,需要安装 CUDA、cuDNN环境

nvcc -V确认cuda版本

安装pytorch环境

pip3 install torch torchvision torchaudio --user --index-url https://download.pytorch.org/whl/cu121

1、准备数据集(训练数据【训练集、验证集、测试集】)

2、定义神经网络(不同网络模型来训练不同的效果)

3、定义训练配置

4、开始训练

5、生产模型

6、保存模型

7、加载模型

8、预测(推理),预测结果

5. 鞋分类

数据集下载:https://vision.cs.utexas.edu/projects/finegrained/utzap50k/

DJL官方案例:https://docs.djl.ai/master/docs/demos/footwear_classification/index.html

作者在训练鞋分类模型过程中似乎没用上电脑的英伟达显卡加速,CPU跑满了都救不了这训练速度,希望得到读者朋友的帮助与指导。

预测验证

6. 扩展:

6.1. CUDA

https://developer.nvidia.com/cuda-toolkit

https://developer.nvidia.com/rdp/cudnn-download

CUDA:下载安装即可

cuDNN:下载后,解压把 bin、lib等目录复制到cuda的安装目录下

6.2. 机器学习算法

算法一览表,根据不同问题选择不同算法

本期总结

本期分享了基于DJL的机器学习相关基础内容,浅尝辄止,希望有兴趣的读者朋友们一起探讨交流。

最后,感谢您的阅读!系列文章会同步在微信公众号@云上的喵酱、阿里云开发者社区@云上的喵酱、CSDN@笠泱 更新,您的点赞+关注+转发是我后续更新的动力!

相关文章
|
6月前
|
机器学习/深度学习 设计模式 人工智能
TinyAI :全栈式轻量级 AI 框架
一个完全用Java实现的全栈式轻量级AI框架,TinyAI IS ALL YOU NEED。
TinyAI :全栈式轻量级 AI 框架
|
druid Java 数据库连接
Spring Boot3整合MyBatis Plus
Spring Boot3整合MyBatis Plus
2133 1
|
人工智能 自然语言处理 前端开发
从理论到实践:使用JAVA实现RAG、Agent、微调等六种常见大模型定制策略
大语言模型(LLM)在过去几年中彻底改变了自然语言处理领域,展现了在理解和生成类人文本方面的卓越能力。然而,通用LLM的开箱即用性能并不总能满足特定的业务需求或领域要求。为了将LLM更好地应用于实际场景,开发出了多种LLM定制策略。本文将深入探讨RAG(Retrieval Augmented Generation)、Agent、微调(Fine-Tuning)等六种常见的大模型定制策略,并使用JAVA进行demo处理,以期为AI资深架构师提供实践指导。
2116 73
|
机器学习/深度学习 人工智能 Java
Java机器学习实战:基于DJL框架的手写数字识别全解析
在人工智能蓬勃发展的今天,Python凭借丰富的生态库(如TensorFlow、PyTorch)成为AI开发的首选语言。但Java作为企业级应用的基石,其在生产环境部署、性能优化和工程化方面的优势不容忽视。DJL(Deep Java Library)的出现完美填补了Java在深度学习领域的空白,它提供了一套统一的API,允许开发者无缝对接主流深度学习框架,将AI模型高效部署到Java生态中。本文将通过手写数字识别的完整流程,深入解析DJL框架的核心机制与应用实践。
887 3
|
8月前
|
存储 缓存 监控
LangChain4j 详细教程
LangChain4j 详细教程
3519 7
|
8月前
|
人工智能 Java API
Java AI智能体实战:使用LangChain4j构建能使用工具的AI助手
随着AI技术的发展,AI智能体(Agent)能够通过使用工具来执行复杂任务,从而大幅扩展其能力边界。本文介绍如何在Java中使用LangChain4j框架构建一个能够使用外部工具的AI智能体。我们将通过一个具体示例——一个能获取天气信息和执行数学计算的AI助手,详细讲解如何定义工具、创建智能体并处理执行流程。本文包含完整的代码示例和架构说明,帮助Java开发者快速上手AI智能体的开发。
3184 8
|
人工智能 开发框架 Java
重磅发布!AI 驱动的 Java 开发框架:Spring AI Alibaba
随着生成式 AI 的快速发展,基于 AI 开发框架构建 AI 应用的诉求迅速增长,涌现出了包括 LangChain、LlamaIndex 等开发框架,但大部分框架只提供了 Python 语言的实现。但这些开发框架对于国内习惯了 Spring 开发范式的 Java 开发者而言,并非十分友好和丝滑。因此,我们基于 Spring AI 发布并快速演进 Spring AI Alibaba,通过提供一种方便的 API 抽象,帮助 Java 开发者简化 AI 应用的开发。同时,提供了完整的开源配套,包括可观测、网关、消息队列、配置中心等。
11344 126
|
运维 NoSQL 应用服务中间件
云服务器规格与带宽选型
本文主要分享了云服务器规格与带宽选型的经验,包括PV、UV、IP等概念的解释及其简化换算关系。文章详细介绍了根据业务访问规律计算合适的服务器资源配置,并提供了CPU与内存不同配比适用的业务场景。同时,针对带宽配置选择,提出了基于总请求量和单次请求大小的估算模型,以及按量付费和固定带宽的选择标准。最后简述了云上运维从人工到智能化(AIOps)的发展阶段,为读者提供实用参考。
765 57
|
缓存 Unix 应用服务中间件
Nginx,最强单体之一
Nginx是一款高性能的HTTP Web服务器、反向代理、内容缓存及负载均衡器,由伊戈尔·赛索耶夫开发并开源。它采用多进程和I/O多路复用技术,支持高并发和高效处理请求,广泛应用于各大互联网公司。Nginx不仅具备基本的HTTP服务功能,如静态文件处理、反向代理和负载均衡,还支持高级特性如SSL、HTTP/2、动静分离等。其架构设计使其在性能、可靠性、扩展性等方面表现出色,成为Web技术学习和应用的首选工具之一。本文将分两部分介绍Nginx的架构及其原生常用功能。
910 25
Nginx,最强单体之一