Integrating DL4J into a Java project to recognize express waybill information from scanned images
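This article uses DL4J (DeepLearning4J) to build a simple image recognition model and integrates it into a Spring Boot backend. We train a simple convolutional neural network (CNN) on the MNIST dataset of handwritten digits and then deploy it in a Spring Boot application.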
1. Set up the Spring Boot project
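First, create a new Spring Boot project. Spring Initializr (https://start.spring.io/) can generate the project skeleton quickly. Select the following dependencies:
- Spring Web
- Spring Boot DevTools
- Lombok (optional, reduces boilerplate code)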
2. Add the DL4J dependencies
Add DL4J and the related dependencies to your pom.xml:
```xml
<dependencies>
    <!-- Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- DL4J -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nn</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-ui</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <!-- File Upload -->
    <dependency>
        <groupId>commons-fileupload</groupId>
        <artifactId>commons-fileupload</artifactId>
        <version>1.4</version>
    </dependency>
    <dependency>
        <groupId>commons-io</groupId>
        <artifactId>commons-io</artifactId>
        <version>2.11.0</version>
    </dependency>
    <!-- Lombok (optional) -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
</dependencies>
```
3. Train the DL4J model
We train a simple convolutional neural network (CNN) on the MNIST dataset. Create a new Java class, MnistModelTrainer.java, to train the model:
```java
package com.example.scanapp;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.IOException;

public class MnistModelTrainer {
    public static void main(String[] args) throws IOException {
        int numEpochs = 10;
        int batchSize = 64;
        int numLabels = 10;
        int numRows = 28;
        int numColumns = 28;
        int numChannels = 1;

        // Load MNIST (downloaded on first use). Each example is a flattened row of
        // 784 pixels already scaled to [0, 1], so no extra normalizer is applied.
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        // Define a LeNet-style network architecture
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .updater(new Adam(0.001))
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(numChannels)
                        .nOut(20)
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        .nOut(50)
                        .stride(1, 1)
                        .activation(Activation.RELU)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(numLabels)
                        .activation(Activation.SOFTMAX)
                        .build())
                // The iterator yields flattened 28x28x1 images: this reshapes them for the
                // first conv layer and infers nIn for the later layers
                .setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels))
                .build();

        // Initialize the network and log the score every 10 iterations
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));

        // Train, rewinding the iterator before each epoch
        for (int i = 0; i < numEpochs; i++) {
            mnistTrain.reset();
            model.fit(mnistTrain);
        }

        // Evaluate on the held-out test set
        Evaluation eval = model.evaluate(mnistTest);
        System.out.println(eval.stats());

        // Save the model; the updater state is only needed to resume training later
        File locationToSave = new File("mnist-model.zip");
        boolean saveUpdater = true;
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
    }
}
```
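Run the MnistModelTrainer class to train the model and save it to mnist-model.zip. On the first run, MnistDataSetIterator downloads the MNIST dataset, so the machine needs network access.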
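The configuration never states the dense layer's input size; it follows from the shapes produced by the convolution and pooling layers. The plain-Java sketch below (the class `LeNetShapes` is hypothetical, not part of DL4J) redoes that arithmetic by hand, assuming no padding and DL4J's default truncate convolution mode, where each window produces (in - kernel) / stride + 1 outputs per dimension. It also counts trainable parameters (weights plus one bias per output channel or unit).

```java
/**
 * Hypothetical helper that traces the spatial sizes of the LeNet-style network
 * (5x5 conv -> 2x2 max pool -> 5x5 conv -> 2x2 max pool -> dense 500 -> output 10),
 * assuming no padding (DL4J's default truncate mode).
 */
final class LeNetShapes {

    /** Output size of a convolution or pooling window with no padding. */
    static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    /** Returns {conv1, pool1, conv2, pool2} spatial sizes, then the flattened size. */
    static int[] trace(int inputSize) {
        int conv1 = outSize(inputSize, 5, 1); // 5x5 conv, stride 1, 20 channels
        int pool1 = outSize(conv1, 2, 2);     // 2x2 max pool, stride 2
        int conv2 = outSize(pool1, 5, 1);     // 5x5 conv, stride 1, 50 channels
        int pool2 = outSize(conv2, 2, 2);     // 2x2 max pool, stride 2
        int flattened = pool2 * pool2 * 50;   // input size of the dense layer
        return new int[] {conv1, pool1, conv2, pool2, flattened};
    }

    /** Trainable parameters: weights plus biases; pooling layers have none. */
    static long paramCount() {
        long conv1 = 5L * 5 * 1 * 20 + 20;
        long conv2 = 5L * 5 * 20 * 50 + 50;
        long dense = (long) trace(28)[4] * 500 + 500;
        long output = 500L * 10 + 10;
        return conv1 + conv2 + dense + output;
    }

    public static void main(String[] args) {
        int[] s = trace(28);
        System.out.printf("conv1 %dx%dx20, pool1 %dx%dx20, conv2 %dx%dx50, pool2 %dx%dx50, dense nIn %d%n",
                s[0], s[0], s[1], s[1], s[2], s[2], s[3], s[3], s[4]);
        System.out.println("parameters: " + paramCount());
    }
}
```

Running it prints `conv1 24x24x20, pool1 12x12x20, conv2 8x8x50, pool2 4x4x50, dense nIn 800` and `parameters: 431080`. Those numbers should agree with what `model.summary()` reports for the trained network; if they differ, the padding or convolution mode is not what this sketch assumes.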
4. Create the Spring Boot controller
Create a new controller that handles image uploads and runs recognition:
```java
package com.example.scanapp.controller;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import javax.imageio.ImageIO;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

@RestController
public class ImageController {

    private static final String MODEL_PATH = "mnist-model.zip"; // Replace with your model path
    private static final int SIZE = 28; // MNIST images are 28x28

    private final MultiLayerNetwork model;

    public ImageController() throws IOException {
        // Load the trained model once at startup; startup fails if the file is missing
        this.model = ModelSerializer.restoreMultiLayerNetwork(new File(MODEL_PATH));
    }

    @PostMapping("/recognize")
    public String recognize(@RequestParam("image") MultipartFile file) {
        try {
            BufferedImage image = ImageIO.read(file.getInputStream());
            if (image == null) {
                return "Unsupported or unreadable image format";
            }

            // Resize to 28x28 and convert to 8-bit grayscale in one step
            BufferedImage gray = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_BYTE_GRAY);
            Graphics2D g = gray.createGraphics();
            g.drawImage(image, 0, 0, SIZE, SIZE, null);
            g.dispose();

            // Flatten row by row into a [1, 784] array, the layout MnistDataSetIterator
            // produces, scaling pixels to [0, 1]. MNIST digits are light on a dark
            // background; use 1.0 - value instead if your images are dark on light.
            INDArray input = Nd4j.create(new int[]{1, SIZE * SIZE});
            for (int y = 0; y < SIZE; y++) {
                for (int x = 0; x < SIZE; x++) {
                    int pixel = gray.getRaster().getSample(x, y, 0);
                    input.putScalar(0, y * SIZE + x, pixel / 255.0);
                }
            }

            // MultiLayerNetwork is not thread-safe, so serialize concurrent requests
            INDArray output;
            synchronized (model) {
                output = model.output(input);
            }
            int predictedClass = output.argMax(1).getInt(0);
            return "Predicted class: " + predictedClass;
        } catch (IOException e) {
            e.printStackTrace();
            return "Error processing image";
        }
    }
}
```
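The image handling inside `recognize()` is plain JDK code, so it can be pulled out and unit-tested without loading a model. The sketch below uses a hypothetical `MnistPreprocessor` class: it resizes any input to 28x28, converts it to 8-bit grayscale, and flattens it row by row into 784 values in [0, 1], optionally inverting dark-on-light images to match MNIST's light-on-dark digits.

```java
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;

/**
 * Hypothetical helper: resize to 28x28, convert to 8-bit grayscale and flatten
 * row by row (index = y * 28 + x) into values in [0, 1].
 */
final class MnistPreprocessor {
    static final int SIZE = 28;

    static float[] toMnistVector(BufferedImage image, boolean invert) {
        // Drawing onto a TYPE_BYTE_GRAY image scales and converts to grayscale together
        BufferedImage gray = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        g.drawImage(image, 0, 0, SIZE, SIZE, null);
        g.dispose();

        float[] pixels = new float[SIZE * SIZE];
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                float v = gray.getRaster().getSample(x, y, 0) / 255.0f;
                // MNIST digits are light on dark; invert dark-on-light input if requested
                pixels[y * SIZE + x] = invert ? 1.0f - v : v;
            }
        }
        return pixels;
    }
}
```

The returned array can then be wrapped in a [1, 784] INDArray, for example with `Nd4j.create(pixels).reshape(1, 784)`, before calling `model.output`. Keeping the conversion in one function makes it easier to guarantee that inference sees the same layout and scaling as training.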
5. Test the API
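You can test the API with Postman or a similar tool. Send a POST request to the /recognize endpoint with the image file in the `image` form field. The model expects MNIST-style input: a 28x28 grayscale image of a light digit on a dark background.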
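To test from Java instead of Postman, the JDK 11+ `java.net.http.HttpClient` can send the request. It has no built-in multipart support, so the sketch below (hypothetical class `RecognizeClient`) builds the multipart/form-data body by hand. The URL, the default file name `digit.png`, and the `image/png` content type are assumptions; the field name `image` matches the controller's `@RequestParam`.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.UUID;

/** Hypothetical test client that posts one image to /recognize (Java 11+). */
final class RecognizeClient {

    /** Builds a multipart/form-data body containing a single file field. */
    static byte[] multipartBody(String boundary, String field, String filename,
                                String contentType, byte[] content) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String header = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"" + field
                + "\"; filename=\"" + filename + "\"\r\n"
                + "Content-Type: " + contentType + "\r\n\r\n";
        out.write(header.getBytes(StandardCharsets.UTF_8));
        out.write(content);
        // Closing boundary ends the multipart body
        out.write(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        Path image = Path.of(args.length > 0 ? args[0] : "digit.png"); // assumed file name
        String boundary = "----dl4j" + UUID.randomUUID();
        byte[] body = multipartBody(boundary, "image", image.getFileName().toString(),
                "image/png", Files.readAllBytes(image));

        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8080/recognize"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```

With the backend running, pass the path of a test image as the first argument; the response body is the controller's plain-text result.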
6. Run the Spring Boot application
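Make sure the Spring Boot application starts correctly. Because the controller loads mnist-model.zip from a relative path, run the application from the directory that contains the model file (or change MODEL_PATH). You can start it with: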
```bash
mvn spring-boot:run
```
7. Frontend integration (optional)
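If you have a frontend application (for example Vue.js), you can create a simple form that uploads an image and calls the backend API. Here is a simple Vue.js component: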
```vue
<template>
  <div>
    <h1>Image Recognition</h1>
    <input type="file" @change="onFileChange" accept="image/*" />
    <button @click="uploadImage" :disabled="!file">Upload</button>
    <p v-if="result">{{ result }}</p>
  </div>
</template>

<script>
export default {
  data() {
    return {
      file: null,
      result: ''
    };
  },
  methods: {
    onFileChange(e) {
      this.file = e.target.files[0] || null;
    },
    async uploadImage() {
      if (!this.file) return;
      const formData = new FormData();
      // The field name must match @RequestParam("image") in ImageController
      formData.append('image', this.file);
      try {
        const response = await fetch('http://localhost:8080/recognize', {
          method: 'POST',
          body: formData
        });
        if (!response.ok) {
          throw new Error(`HTTP ${response.status}`);
        }
        this.result = await response.text();
      } catch (error) {
        console.error('Error uploading image:', error);
        this.result = 'Upload failed';
      }
    }
  }
};
</script>
```
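Add the component above to your Vue project, then run the frontend to test the whole flow. If the Vue app is served from a different origin than http://localhost:8080 (for example a dev server on another port), the browser blocks the response unless the backend allows cross-origin requests, for instance by adding Spring's `@CrossOrigin` annotation to `ImageController`.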
With these steps you should have a Spring Boot backend serving a DL4J model, with image recognition available through a frontend application.
Original article: https://blog.csdn.net/xiaozukun/article/details/144709937