
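Integrating DL4J into a Java Project to Recognize Express Waybill Information from Scanned Images

This article uses DL4J (DeepLearning4J) to build a simple image recognition model and integrates it into a Spring Boot backend. We train a small convolutional neural network (CNN) on the MNIST handwritten-digit dataset and then deploy it in a Spring Boot application. Note that an MNIST model only classifies a single handwritten digit in a 28×28 grayscale image. It is therefore a starting point for the waybill scenario rather than a complete waybill reader.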

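1. Set up the Spring Boot project

First, create a new Spring Boot project. You can use Spring Initializr (https://start.spring.io/) to generate the project structure quickly. Select the following dependencies: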

  • Spring Web
  • Spring Boot DevTools
  • Lombok (optional, reduces boilerplate code)

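2. Add the DL4J dependencies

Add DL4J and the related dependencies to your pom.xml. All DL4J and ND4J artifacts should use the same version (1.0.0-beta7 here). deeplearning4j-nn is already a transitive dependency of deeplearning4j-core. The code in this article uses neither deeplearning4j-ui nor commons-fileupload/commons-io, because Spring Boot handles multipart uploads on its own, so those entries are optional: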

xml

<dependencies>
    <!-- Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>

    <!-- DL4J -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nn</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-ui</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>

    <!-- File Upload -->
    <dependency>
        <groupId>commons-fileupload</groupId>
        <artifactId>commons-fileupload</artifactId>
        <version>1.4</version>
    </dependency>
    <dependency>
        <groupId>commons-io</groupId>
        <artifactId>commons-io</artifactId>
        <version>2.11.0</version>
    </dependency>

    <!-- Lombok (optional) -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
</dependencies>

3. Train the DL4J model

We train a simple convolutional neural network (CNN) on the MNIST dataset. Create a new Java class, MnistModelTrainer.java, to train the model:

java

package com.example.scanapp;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.IOException;

public class MnistModelTrainer {

    public static void main(String[] args) throws IOException {
        int numEpochs = 10;
        int batchSize = 64;
        int numLabels = 10;
        int numRows = 28;
        int numColumns = 28;
        int numChannels = 1;

        // Load MNIST data (downloaded automatically on first use).
        // MnistDataSetIterator already scales pixel values to [0, 1], so no
        // ImagePreProcessingScaler is applied: it would divide by 255 a second time.
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        // Define the network architecture (LeNet-style CNN)
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .updater(new Adam(0.001))
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(numChannels)
                        .nOut(20)
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        .nOut(50)
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(numLabels)
                        .activation(Activation.SOFTMAX)
                        .build())
                // The iterator yields flat rows of 784 values. This reshapes them to
                // 1x28x28 and lets DL4J infer nIn for layers 2, 4 and 5.
                .setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels))
                .build();

        // Initialize the network and log the score every 10 iterations
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));

        // Train for numEpochs passes over the training set
        for (int i = 0; i < numEpochs; i++) {
            model.fit(mnistTrain);
        }

        // Evaluate on the held-out test set and print accuracy, precision, recall, F1
        Evaluation eval = model.evaluate(mnistTest);
        System.out.println(eval.stats());

        // Save the model together with its updater state
        File locationToSave = new File("mnist-model.zip");
        boolean saveUpdater = true;
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
    }
}

Run the MnistModelTrainer class to train the model and save it to mnist-model.zip. MnistDataSetIterator downloads the MNIST data the first time it runs, so the first run needs network access.

4. Create the Spring Boot controller

Create a new controller that handles image uploads and runs the recognition:

java

package com.example.scanapp.controller;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

@RestController
public class ImageController {

    private static final String MODEL_PATH = "mnist-model.zip"; // replace with the path to your model
    private static final int SIZE = 28;

    private final MultiLayerNetwork model;

    public ImageController() throws IOException {
        // Load the trained network once, when Spring creates the controller
        this.model = ModelSerializer.restoreMultiLayerNetwork(new File(MODEL_PATH));
    }

    @PostMapping("/recognize")
    public String recognize(@RequestParam("image") MultipartFile file) {
        try {
            BufferedImage image = ImageIO.read(file.getInputStream());
            if (image == null) {
                return "Unsupported or unreadable image";
            }

            // Scale any input size to 28x28 and convert it to 8-bit grayscale
            BufferedImage gray = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_BYTE_GRAY);
            Graphics2D g = gray.createGraphics();
            g.drawImage(image, 0, 0, SIZE, SIZE, null);
            g.dispose();

            // Build a flat [1, 784] row, the same layout MnistDataSetIterator produces.
            // Each pixel is divided by 255 exactly once, matching the [0, 1] training data.
            INDArray input = Nd4j.create(new int[]{1, SIZE * SIZE});
            for (int y = 0; y < SIZE; y++) {
                for (int x = 0; x < SIZE; x++) {
                    int value = gray.getRaster().getSample(x, y, 0); // 0..255
                    input.putScalar(0, y * SIZE + x, value / 255.0);
                }
            }

            INDArray output;
            synchronized (model) { // MultiLayerNetwork is not guaranteed to be thread-safe
                output = model.output(input);
            }
            int predictedClass = output.argMax(1).getInt(0);

            return "Predicted class: " + predictedClass;
        } catch (IOException e) {
            e.printStackTrace();
            return "Error processing image";
        }
    }
}
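
The MnistModelTrainer layers shrink the 28×28 input in a way that can be worked out by hand. DL4J performs the same calculation when it infers layer input sizes from the input type. Each 5×5 convolution with stride 1 and each 2×2 max-pooling with stride 2 reduces the spatial size, which leaves 4×4×50 = 800 values for the DenseLayer. The sketch below assumes no padding and DL4J's default convolution mode (Truncate), as in the configuration above. `LeNetShapeMath` is a hypothetical helper written for illustration and is not part of DL4J.

```java
/**
 * Hypothetical helper that reproduces the shape arithmetic for the
 * MnistModelTrainer network. Assumes no padding and the default
 * ConvolutionMode.Truncate, i.e. out = (in - kernel) / stride + 1.
 */
final class LeNetShapeMath {

    private LeNetShapeMath() {}

    /** Output width/height of a convolution or pooling window without padding. */
    static int windowOut(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    /** Spatial size after conv 5x5/1 -> max-pool 2x2/2 -> conv 5x5/1 -> max-pool 2x2/2. */
    static int finalSpatialSize(int inputSize) {
        int s = windowOut(inputSize, 5, 1); // layer 0: 28 -> 24
        s = windowOut(s, 2, 2);             // layer 1: 24 -> 12
        s = windowOut(s, 5, 1);             // layer 2: 12 -> 8
        return windowOut(s, 2, 2);          // layer 3: 8 -> 4
    }

    /** Values fed to the DenseLayer: height * width * channels of the last pooling output. */
    static int denseInputs(int inputSize, int lastConvChannels) {
        int s = finalSpatialSize(inputSize);
        return s * s * lastConvChannels;
    }

    /** Mini-batches per epoch; the last batch may be smaller than batchSize. */
    static int iterationsPerEpoch(int numExamples, int batchSize) {
        return (numExamples + batchSize - 1) / batchSize;
    }
}
```

For this configuration, `denseInputs(28, 50)` is 800. With the 60,000 MNIST training images and `batchSize = 64`, `iterationsPerEpoch(60000, 64)` is 938, so `ScoreIterationListener(10)` logs the score a little over 90 times per epoch.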

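5. Test the API

You can test the API with Postman or a similar tool. Send a POST request to the /recognize endpoint with the image attached as a multipart form field named image. For best results, the image should look like MNIST data: a 28×28-pixel grayscale image of a single digit, drawn light on a dark background.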
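
Photos of real digits are usually dark ink on light paper and rarely 28×28. A small converter can turn such a picture into an MNIST-style test file before you upload it. The following is a minimal sketch using only the JDK (java.awt and javax.imageio). `MnistStyleImage` and its methods are hypothetical names introduced for illustration. The optional inversion assumes the source shows a dark digit on a light background.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;

/**
 * Hypothetical helper (not part of the original project) that turns an
 * arbitrary picture into an MNIST-style test image: 28x28, 8-bit grayscale,
 * optionally inverted so a dark digit on white paper becomes light on dark.
 */
final class MnistStyleImage {

    static final int SIZE = 28;

    private MnistStyleImage() {}

    static BufferedImage convert(BufferedImage source, boolean invert) {
        BufferedImage gray = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        try {
            g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                    RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            // Scale the whole source into the 28x28 canvas; color is reduced to gray
            g.drawImage(source, 0, 0, SIZE, SIZE, null);
        } finally {
            g.dispose();
        }
        if (invert) {
            for (int y = 0; y < SIZE; y++) {
                for (int x = 0; x < SIZE; x++) {
                    int v = gray.getRaster().getSample(x, y, 0); // 0..255
                    gray.getRaster().setSample(x, y, 0, 255 - v);
                }
            }
        }
        return gray;
    }

    /** Reads any ImageIO-supported file and writes the converted image as PNG. */
    static void convertFile(File in, File out, boolean invert) throws IOException {
        BufferedImage source = ImageIO.read(in);
        if (source == null) {
            throw new IOException("Unsupported image format: " + in);
        }
        ImageIO.write(convert(source, invert), "png", out);
    }
}
```

For example, `MnistStyleImage.convertFile(new File("photo.jpg"), new File("digit.png"), true)` writes a file you can attach to the request. Both file names are placeholders.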

6. Run the Spring Boot application

Make sure your Spring Boot application starts without errors. You can run it with the following command:

bash

mvn spring-boot:run
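
Once the application is running, the endpoint can also be called from plain Java instead of Postman. The sketch below uses the JDK 11+ `java.net.http.HttpClient` and builds the multipart/form-data body by hand. It assumes the backend listens on http://localhost:8080 and reads the form field `image`, as in the controller above. `RecognizeClient` is a hypothetical class name.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.UUID;

/** Hypothetical JDK 11+ client sketch for POST /recognize. */
final class RecognizeClient {

    private RecognizeClient() {}

    /** Builds a multipart/form-data body containing a single file part. */
    static byte[] multipartBody(String boundary, String field, String fileName,
                                String contentType, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String header = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"" + field
                + "\"; filename=\"" + fileName + "\"\r\n"
                + "Content-Type: " + contentType + "\r\n\r\n";
        out.writeBytes(header.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    /** Uploads the image under the field name "image" and returns the response text. */
    static String recognize(String baseUrl, Path image, String contentType)
            throws IOException, InterruptedException {
        String boundary = "----RecognizeClient" + UUID.randomUUID().toString().replace("-", "");
        byte[] body = multipartBody(boundary, "image", image.getFileName().toString(),
                contentType, Files.readAllBytes(image));
        HttpRequest request = HttpRequest.newBuilder(URI.create(baseUrl + "/recognize"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IOException("HTTP " + response.statusCode() + ": " + response.body());
        }
        return response.body();
    }
}
```

A call such as `RecognizeClient.recognize("http://localhost:8080", Path.of("digit.png"), "image/png")` should return the controller's text, i.e. "Predicted class: " followed by a digit. `digit.png` is a placeholder file name.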

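7. Frontend integration (optional)

If you have a frontend application (for example, Vue.js), you can create a simple form that uploads an image and calls the backend API. Here is a simple Vue.js component: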

vue

<template>
  <div>
    <h1>Image Recognition</h1>
    <input type="file" @change="onFileChange" accept="image/*" />
    <button @click="uploadImage">Upload</button>
    <p v-if="result">{{ result }}</p>
  </div>
</template>

<script>
export default {
  data() {
    return {
      file: null,
      result: ''
    };
  },
  methods: {
    onFileChange(e) {
      this.file = e.target.files[0];
    },
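    async uploadImage() {
      if (!this.file) {
        this.result = 'Please choose an image first.';
        return;
      }
      const formData = new FormData();
      formData.append('image', this.file); // field name must match @RequestParam("image")

      try {
        const response = await fetch('http://localhost:8080/recognize', {
          method: 'POST',
          body: formData
        });
        if (!response.ok) {
          throw new Error(`HTTP ${response.status}`);
        }
        this.result = await response.text();
      } catch (error) {
        console.error('Error uploading image:', error);
        this.result = 'Upload failed';
      }
    }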
  }
};
</script>

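Add this component to your Vue project, then run the frontend to test the whole flow. If the frontend is served from a different origin (host or port) than the backend on localhost:8080, the browser blocks the request unless the backend allows cross-origin requests. One way to allow them is to annotate ImageController with Spring's @CrossOrigin.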

By following these steps, you should have a Spring Boot backend service that serves a DL4J model and performs image recognition for a frontend application.


Original article: https://blog.csdn.net/xiaozukun/article/details/144709937
