基于DJL的机器学习

简介: 本文介绍了基于Java的深度学习框架DJL,涵盖机器学习与深度学习的核心概念、神经网络结构及生命周期,并通过MNIST数据集展示了从模型构建、训练到推理的完整流程。内容深入浅出,适合初学者入门。

Hi~各位读者朋友们,感谢您阅读本文,我是笠泱,本期分享基于Java语言入门机器学习,本期内容灵感来源于B站@雷丰阳(https://www.bilibili.com/video/BV15vqcYbEM9

1. 核心概念

1.1. 机器学习

机器学习是一个通过利用统计学知识,将数据输入到计算机中进行训练并完成特定目标任务的过程。这种归纳学习的方法可以让计算机学习一些特征并进行一系列复杂的任务,比如识别照片中的物体。由于需要写复杂的逻辑以及测量标准,这些任务在传统计算科学领域中很难实现。

参阅文档:https://docs.djl.ai/master/docs/demos/jupyter/tutorial/02_train_your_first_model.html

1.2. 深度学习

深度学习是机器学习的一个分支,主要侧重于对于人工神经网络的开发。人工神经网络是通过研究人脑如何学习和实现目标的过程中归纳而得出一套计算逻辑。它通过模拟部分人脑神经间信息传递的过程,从而实现各类复杂的任务。深度学习中的“深度”来源于我们会在人工神经网络中编织构建出许多层(多层感知机)从而进一步对数据信息进行更深层的传导。深度学习技术应用范围十分广泛,现在被用来做目标检测、动作识别、机器翻译、语意分析等各类现实应用中。

1.3. Neural Network - 神经网络

神经网络是一个黑盒程序。不用你自己编写这个函数,你需要为这个函数提供很多“输入/输出”对样本。然后,我们尝试训练网络,让他最大近似于我们给定的输入/输出对。拥有更多数据的更好的模型可以更准确地近似函数。

1.4. 生命周期

机器学习的生命周期与传统的软件开发生命周期有所不同,它包含七个具体的步骤:

  1. 获取数据,清洗并准备数据
  2. 构建神经网络
  3. 构建模型(这个模型应用上面的神经网络)
  4. 进行训练配置(如何训练、训练集、验证集、测试集),多轮训练,生成模型
  5. 保存模型
  6. 加载模型
  7. 从模型中获得预测(或推理)

生命周期的最终结果是一个可以查询并返回答案(或预测)的机器学习模型。

1.5. MLP

多层感知机(Multilayer Perceptron,简称 MLP)是一种前向型人工神经网络(Feedforward Artificial Neural Network),是最简单的前馈神经网络之一。MLP 由多个层组成,每层包含多个神经元(也称为节点),并且每一层的神经元与下一层的所有神经元相连。MLP 主要用于解决分类和回归问题。

1.5.1. MLP 结构

MLP 通常由以下几个部分组成:

  • 输入层(Input Layer)
  • 输入层接收外部输入数据。每个神经元对应一个输入特征。
  • 例如,对于一个图像分类任务,输入层的神经元数量等于图像的像素数量。
  • 隐藏层(Hidden Layers)
  • 隐藏层位于输入层和输出层之间。MLP 可以有一个或多个隐藏层。
  • 每个隐藏层包含多个神经元,每个神经元通过权重和偏置与前一层的所有神经元相连。
  • 隐藏层引入了非线性变换,使得 MLP 能够学习复杂的模式。
  • 输出层(Output Layer)
  • 输出层产生最终的预测结果。

对于分类任务,输出层的神经元数量通常等于类别数量,使用 softmax 激活函数。

对于回归任务,输出层的神经元数量通常为 1,使用线性激活函数。

1.5.2. 工作原理

参考:https://zhuanlan.zhihu.com/p/642537175

MLP 的工作原理可以分为前向传播(Forward Propagation)反向传播(Backward Propagation)

两个主要步骤:

前向传播

  • 输入数据从输入层传递到输出层。
  • 每个神经元计算加权和(线性组合),然后通过激活函数引入非线性。

反向传播

  • 反向传播用于计算损失函数相对于每个权重和偏置的梯度。
  • 使用梯度下降法更新权重和偏置,以最小化损失函数。

具体步骤包括:

  • 计算输出层的误差。
  • 逐层向后传播误差,计算每个隐藏层的误差。
  • 更新权重和偏置。

当我们准备好数据集和神经网络之后,就可以开始训练模型了。在深度学习中,一般会由下面几步来完成一个训练过程:

  • 初始化:我们会对每一个 Block 的参数进行初始化,初始化每个参数的函数都是由 设定的 Initializer 决定的。
  • 前向传播:这一步将输入数据在神经网络中逐层传递,然后产生输出数据。
  • 计算损失:我们会根据特定的损失函数 Loss 来计算输出和标记结果的偏差。
  • 反向传播:在这一步中,你可以利用损失反向求导算出每一个参数的梯度。
  • 更新权重:我们会根据选择的优化器(Optimizer)更新每一个在 Block 上参数的值。

2. DJL - Deep Java Library

Deep Java Library(DJL)是一个开源的、高级的、与引擎无关的深度学习Java框架。DJL旨在易于上手,易于Java开发人员使用。DJL提供原生Java开发体验和功能,就像任何其他常规Java库一样。

您不必成为机器学习/深度学习专家才能开始。您可以使用现有的Java专业知识作为学习和使用机器学习深度学习的入口。您可以使用您最喜欢的IDE来构建、训练和部署您的模型。DJL可以轻松地将这些模型与您的Java应用程序集成。

因为DJL与深度学习引擎无关,所以您在创建项目时不必在引擎之间做出选择。您可以在任何时候交换机引擎。为了确保最佳性能,DJL还提供了基于硬件配置的自动CPU/GPU选择。

2.1. 架构

2.2. 依赖

<properties>
        <java.version>17</java.version>
        <djl.version>0.26.0</djl.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <!--        引入 djl 依赖  -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.opencv</groupId>
            <artifactId>opencv</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-native-auto</artifactId>
            <version>1.8.0</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
    </dependencies>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

2.3. 流程

2.3.1. NDArray - N维向量

创建、运算、操作

@Test
    void testNDArray(){
        // NDManager 创建和管理深度学习期间的临时数据。销毁后自动释放所有资源
        try(NDManager manager = NDManager.newBaseManager()){
            //创建 2x2 矩阵
            /**
             * ND: (2, 2) cpu() float32
             * [[0., 1.],
             *  [2., 3.],
             * ]
             */
            NDArray ndArray = manager.create(new float[]{0f, 1f, 2f, 3f}, new Shape(2, 2));
            /**
             * ND: (2, 3) cpu() float32
             * [[1., 1., 1.],
             *  [1., 1., 1.],
             * ]
             */
            NDArray ones = manager.ones(new Shape(2, 3));
            /**
             * ND: (2, 3, 4) cpu() float32
             * [[[0.5488, 0.5928, 0.7152, 0.8443],
             *   [0.6028, 0.8579, 0.5449, 0.8473],
             *   [0.4237, 0.6236, 0.6459, 0.3844],
             *  ],
             *  [[0.4376, 0.2975, 0.8918, 0.0567],
             *   [0.9637, 0.2727, 0.3834, 0.4777],
             *   [0.7917, 0.8122, 0.5289, 0.48  ],
             *  ],
             * ]
             */
            NDArray uniform = manager.randomUniform(0, 1, new Shape(2,3, 4));
            System.out.println(uniform);
        }
    }
//===================
NDArray uniform = manager.randomUniform(0, 1, new Shape(2,3, 4));
// 加
NDArray add = uniform.add(5);
System.out.println(add);
//===================
    @Test
    void operation(){
        try(NDManager manager = NDManager.newBaseManager()){
            var array1 = manager. create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
            var array2 = manager. create(new float[] {1f, 2f});
            /**
             * ND: (2) cpu() float32
             * [ 5., 11.]
             */
            var array3 = array1.matMul(array2);
            System.out.println(array3);
            /**
             * ND: (1, 4) cpu() float32
             * [[1., 2., 3., 4.],
             * ]
             */
            var array4 = array1.reshape(1, 4);
            System.out.println(array4);
            /**
             * ND: (2, 2) cpu() float32
             * [[1., 3.],
             *  [2., 4.],
             * ]
             */
            var array5 = array1.transpose();
            System.out.println(array5);
        }
    }

2.3.2. Dataset - 数据集

2.3.2.1. 作用

数据集是用于训练机器学习模型的数据集合。

机器学习通常使用三个数据集https://machinelearningmastery.com/difference-test-validation-datasets/

  • 训练数据集
  • 我们用来训练模型的实际数据集。模型从这些数据中学习权重和参数。
  • 验证数据集
  • 验证集用于在训练过程中评估给定模型。它帮助机器学习工程师在模型开发阶段微调超参数。模型不从验证数据集学习;验证数据集是可选的。
  • 测试数据集
  • 测试数据集提供了用于评估模型的黄金标准,它只在模型完全训练后使用,测试数据集应该更准确地评估模型将如何在新数据上执行。

2.3.2.2. 内置数据集

DJL提供了许多内置的基本和标准数据集,这些数据集用于训练深度学习模型,本模块包含以下数据集:

详细参阅:https://docs.djl.ai/master/docs/dataset.html

CV(计算机视觉)

  • Image Classification
  • MNIST - A small and fast handwritten digits dataset
  • Fashion MNIST - A small and fast clothing type detection dataset
  • CIFAR10 - A dataset consisting of 60,000 32x32 color images in 10 classes
  • ImageNet - An image database organized according to the WordNet hierarchy

Note: You have to manually download the ImageNet dataset due to licensing requirements.

  • Object Detection
  • Pikachu - 1000 Pikachu images of different angles and sizes created using an open source 3D Pikachu model
  • Banana Detection - A testing single object detection dataset
  • Other CV
  • Captcha - A dataset for a grayscale 6-digit CAPTCHA task
  • Coco - A large-scale object detection, segmentation, and captioning dataset that contains 1.5 million object instances

You have to manually add com.twelvemonkeys.imageio:imageio-jpeg:3.11.0 dependency to your project

NLP(自然语言处理)

Text Classification and Sentiment Analysis

  • AmazonReview - A sentiment analysis dataset of Amazon Reviews with their ratings
  • Stanford Movie Review - A sentiment analysis dataset of movie reviews and sentiments sourced from IMDB
  • GoEmotions - A dataset classifying 50k curated reddit comments into either 27 emotion categories or neutral

Unlabeled Text

  • Penn Treebank Text - The text (not POS tags) from the Penn Treebank, a collection of Wall Street Journal stories
  • WikiText2 - A collection of over 100 million tokens extracted from good and featured articles on wikipedia

Other NLP

Tabular(表格化)

Time Series(时序)

2.3.3. Block - 块函数

顺序块用于创建块函数链,其中一个块的输出作为下一个块的输入进行传递

@Test
    void testBlock(){
        long inputSize = 28*28;
        long outputSize = 10;
        SequentialBlock block = new SequentialBlock();
        //一个批量扁平块,将二维图像输入转化为一维特征向量
        block.add(Blocks.batchFlattenBlock(inputSize));
        //添加一个和隐藏层,其线性变换大小为128
        block.add(Linear.builder().setUnits(128).build());
        //添加相应的激活函数;
        block.add(Activation::relu);
        //作为第二个隐藏层的激活函数,有大小为64的变换
        block.add(Linear.builder().setUnits(64).build());
        block.add(Activation::relu);
        //最后输出10大小的特征向量
        block.add(Linear.builder().setUnits(outputSize).build());
        //这些大小是在实验过程中选择的
        //围绕块,可以构建我们的模型,它添加了一些重要的元数据,例如:可以在训练和推理期间使用的名称
        Model model = Model.newInstance("mlp");
        model.setBlock(block);
        
        //现在已经拥有了块和模型,接下来就是如何进行训练
    }

2.3.4. Model - 模型

2.3.5. Inference - 推理

2.3.5.1. 流程

2.3.5.2. Translator - 转换器

ImageClassificationTranslator classificationTranslator = ImageClassificationTranslator.builder()
        .addTransform(new CenterCrop()) //中心裁剪
        .addTransform(new Resize(224, 224)) //调整大小
        .addTransform(new ToTensor())//将图像 NDArray 从预处理格式转换为神经网络格式的变换
        .build();

2.3.6. Model Loading - 模型加载

  • 创建 Filters 而查找模型
  • Local Disk
  • URL:file、http、jar
  • DJL 内置 modelzoo
  • Builder Pattern
  • setTypes 必须
  • optXxx 可选
@Test
    void testModelLoading() throws ModelNotFoundException, MalformedModelException, IOException {
        ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
                .addTransform(new CenterCrop())
                .addTransform(new Resize(224, 224))
                .addTransform(new ToTensor())
                .build();
        //模型检索条件
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .optFilter("layers", "50")
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();
        //加载模型
        ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
        
        System.out.println(model.getName());
    }

2.3.7. Predictor - 预测

//预测
Predictor<Image, Classifications> predictor = model.newPredictor();
Classifications predict = predictor.predict(image);

3. Mnist数据训练 - 完整过程

基于多层感知机训练手写数字识别模型

3.1. 准备数据集

//0、准备数据集
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);
private RandomAccessDataset getDataset(Dataset.Usage usage) {
    Mnist mnist =
            Mnist.builder()
                    .optUsage(usage)
                    .setSampling(64, true)
                    .build();
    mnist.prepare(new ProgressBar());
    return mnist;
}

3.2. 构建神经网络

//1、构建神经网络
Block block =
        new Mlp(
                Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
                Mnist.NUM_CLASSES,
                new int[]{128, 64});

3.3. 创建模型

//2、创建模型
try (Model model = Model.newInstance("mlp")) {
    model.setBlock(block);
    // 3.4及以后的代码
    
}

3.4. 训练配置

//3、训练配置
   DefaultTrainingConfig config = setupTrainingConfig();
    private DefaultTrainingConfig setupTrainingConfig() {
        String outputDir = "build/model";
        SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model = trainer.getModel();
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
                .addTrainingListeners(listener);
    }

3.5. 开始训练

//4、拿到训练器
try (Trainer trainer = model.newTrainer(config)) {
    trainer.setMetrics(new Metrics());
    /*
     * MNIST is 28x28 grayscale image and pre processed into 28 * 28 NDArray.
     * 1st axis is batch axis, we can use 1 for initialization.
     */
    Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
    // initialize trainer with proper input shape
    trainer.initialize(inputShape);
    //5、开始训练
    EasyTrain.fit(trainer, 5, trainingSet, validateSet);
    //6、训练结果
    TrainingResult result = trainer.getTrainingResult();
    System.out.println("result = " + result);
}

3.6. 保存模型

// 前面  SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir); 
// 已经可以保存模型
//也可以直接调用如下API
Path modelDir = Paths.get("build/mlpx");
Files.createDirectories(modelDir);
// Save the model
model.save(modelDir, "mlpx");

3.7. 加载模型

//加载模型
Path modelDir = Paths.get("build/model");
Model model = Model.newInstance("mlp");
model.setBlock(block);
model.load(modelDir);

3.8. 创建推理器

//如果没有 Translator,非标图像则无法处理
Translator<Image, Classifications> translator =
        ImageClassificationTranslator.builder()
                .addTransform(new Resize(Mnist.IMAGE_WIDTH, Mnist.IMAGE_HEIGHT))
                .addTransform(new ToTensor())
                .optApplySoftmax(true)
                .build();
Predictor<Image, Classifications> predictor = model.newPredictor(translator);

3.9. 分类推理

Classifications predict = predictor.predict(img);
System.out.println("predict = " + predict);

4. 进阶学习

我们继续分享机器学习,进阶训练一个鞋分类模型;

代码(带训练好的模型):📎foot-classify-demo.zip

由于鞋分类模型训练时间会比较长,建议使用显卡,需要安装 CUDA、cuDNN环境

nvcc -V确认cuda版本

安装pytorch环境

pip3 install torch torchvision torchaudio --user --index-url https://download.pytorch.org/whl/cu121

1、准备数据集(训练数据【训练集、验证集、测试集】)

2、定义神经网络(不同网络模型来训练不同的效果)

3、定义训练配置

4、开始训练

5、生产模型

6、保存模型

7、加载模型

8、预测(推理),预测结果

5. 鞋分类

数据集下载:https://vision.cs.utexas.edu/projects/finegrained/utzap50k/

DJL官方案例:https://docs.djl.ai/master/docs/demos/footwear_classification/index.html

作者在训练鞋分类模型过程中似乎没用上电脑的英伟达显卡加速,CPU跑满了都救不了这训练速度,希望得到读者朋友的帮助与指导。

预测验证

6. 扩展:

6.1. CUDA

https://developer.nvidia.com/cuda-toolkit

https://developer.nvidia.com/rdp/cudnn-download

CUDA:下载安装即可

cuDNN:下载后,解压把 bin、lib等目录复制到cuda的安装目录下

6.2. 机器学习算法

算法一览表,根据不同问题选择不同算法

本期总结

本期分享了基于DJL的机器学习相关基础内容,浅尝辄止,希望有兴趣的读者朋友们一起探讨交流。

最后,感谢您的阅读!系列文章会同步在微信公众号@云上的喵酱、阿里云开发者社区@云上的喵酱、CSDN@笠泱 更新,您的点赞+关注+转发是我后续更新的动力!

相关文章
|
6天前
|
人工智能 数据可视化 安全
王炸组合!阿里云 OpenClaw X 飞书 CLI,开启 Agent 基建狂潮!(附带免费使用6个月服务器)
本文详解如何用阿里云Lighthouse一键部署OpenClaw,结合飞书CLI等工具,让AI真正“动手”——自动群发、生成科研日报、整理知识库。核心理念:未来软件应为AI而生,CLI即AI的“手脚”,实现高效、安全、可控的智能自动化。
18721 12
王炸组合!阿里云 OpenClaw X 飞书 CLI,开启 Agent 基建狂潮!(附带免费使用6个月服务器)
|
18天前
|
人工智能 JSON 机器人
让龙虾成为你的“公众号分身” | 阿里云服务器玩Openclaw
本文带你零成本玩转OpenClaw:学生认证白嫖6个月阿里云服务器,手把手配置飞书机器人、接入免费/高性价比AI模型(NVIDIA/通义),并打造微信公众号“全自动分身”——实时抓热榜、AI选题拆解、一键发布草稿,5分钟完成热点→文章全流程!
30219 141
让龙虾成为你的“公众号分身” | 阿里云服务器玩Openclaw
|
7天前
|
人工智能 JSON 监控
Claude Code 源码泄露:一份价值亿元的 AI 工程公开课
我以为顶级 AI 产品的护城河是模型。读完这 51.2 万行泄露的源码,我发现自己错了。
4630 20
|
6天前
|
人工智能 API 开发者
阿里云百炼 Coding Plan 售罄、Lite 停售、Pro 抢不到?最新解决方案
阿里云百炼Coding Plan Lite已停售,Pro版每日9:30限量抢购难度大。本文解析原因,并提供两大方案:①掌握技巧抢购Pro版;②直接使用百炼平台按量付费——新用户赠100万Tokens,支持Qwen3.5-Max等满血模型,灵活低成本。
1474 3
阿里云百炼 Coding Plan 售罄、Lite 停售、Pro 抢不到?最新解决方案