TensorFlow系列:第五讲:移动端部署模型
项目地址:https://github.com/LionJackson/imageClassification
Flutter项目地址:https://github.com/LionJackson/flutter_image
一. 模型转换
编写tflite模型工具类:
import os
import PIL
import tensorflow as tf
import keras
import numpy as np
from PIL.Image import Image
from matplotlib import pyplot as plt
from utils.dataset_loader import DatasetLoader
from utils.utils import Utils
"""
tflite模型工具类
"""
class TFLiteUtil:
    """Convert a trained image classifier to TFLite and inspect the result.

    Every artifact uses ``saved_model_dir`` as a path prefix:
    ``<saved_model_dir>.label`` (class names, one per line),
    ``<saved_model_dir>.tflite`` (converted model) and, for
    :meth:`convert_model_tflite`, the input ``<saved_model_dir>.keras``.
    """

    def __init__(self, saved_model_dir, path_url):
        """Store the model path prefix and the dataset root.

        Args:
            saved_model_dir: SavedModel directory; also used as the prefix
                for the generated ``.label`` / ``.tflite`` files.
            path_url: dataset root; class folders are read from
                ``<path_url>/train``.
        """
        self.save_model_dir = saved_model_dir
        self.path_url = path_url

    def get_folder_names(self):
        """Write the label file for the trained model and return the labels.

        Labels are the names of the immediate sub-directories of
        ``<path_url>/train``, sorted alphanumerically, written one per line
        (UTF-8, each line newline-terminated) to ``<save_model_dir>.label``.

        Returns:
            list[str]: the label names, in file order.

        Raises:
            FileNotFoundError: if ``<path_url>/train`` does not exist.
        """
        train_dir = os.path.join(self.path_url, 'train')
        # Only top-level folders are classes (os.walk would also pick up nested
        # folders). Sorting reproduces the alphanumeric class order Keras'
        # directory loaders use, so line i names model output i.
        # NOTE(review): assumes DatasetLoader relies on such a loader - confirm.
        with os.scandir(train_dir) as entries:
            folder_names = sorted(entry.name for entry in entries if entry.is_dir())
        # Explicit UTF-8 so non-ASCII class names survive on every platform.
        with open(self.save_model_dir + '.label', 'w', encoding='utf-8') as file:
            file.writelines(name + '\n' for name in folder_names)
        return folder_names

    def _write_tflite(self, tflite_model):
        """Write serialized TFLite model bytes to ``<save_model_dir>.tflite``."""
        with open(self.save_model_dir + '.tflite', 'wb') as f:
            f.write(tflite_model)

    def convert_tflite(self):
        """Convert the SavedModel at ``save_model_dir`` to TFLite (no optimizations).

        Also regenerates the ``.label`` file.
        """
        self.get_folder_names()
        converter = tf.lite.TFLiteConverter.from_saved_model(self.save_model_dir)
        self._write_tflite(converter.convert())
        print("转换成功,已保存为 tflite")

    def convert_model_tflite(self):
        """Load ``<save_model_dir>.keras`` and convert it to a float16-quantized TFLite model.

        Also regenerates the ``.label`` file.
        """
        self.get_folder_names()
        model = keras.models.load_model(self.save_model_dir + ".keras")
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        # Post-training float16 quantization (typically about half the size).
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
        self._write_tflite(converter.convert())
        print("转换成功(model),已保存为 tflite")

    def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):
        """Classify one test batch with the TFLite model and plot the predictions.

        Args:
            class_mode: forwarded to ``DatasetLoader``.
            image_size: forwarded to ``DatasetLoader``; presumably must match
                the model's input size.
            num_images: maximum number of images to plot; capped at the size
                of the first test batch.
        """
        dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)
        _, _, test_ds, class_names = dataset_loader.load_data()

        interpreter = tf.lite.Interpreter(self.save_model_dir + '.tflite')
        interpreter.allocate_tensors()
        input_detail = interpreter.get_input_details()[0]
        output_detail = interpreter.get_output_details()[0]

        plt.figure(figsize=(10, 10))
        for images, _labels in test_ds.take(1):
            # Never index past the batch, and size the grid to fit every plot
            # (5x5 for the default of 25) instead of a fixed 5x5.
            count = max(0, min(num_images, len(images)))
            cols = max(1, int(np.ceil(np.sqrt(count))))
            rows = max(1, -(-count // cols))
            for i in range(count):
                img = np.asarray(images[i])
                # set_tensor rejects inputs whose dtype differs from the model's.
                batch = np.expand_dims(img, axis=0).astype(input_detail['dtype'])
                interpreter.set_tensor(input_detail['index'], batch)
                interpreter.invoke()
                scores = interpreter.get_tensor(output_detail['index'])[0]
                index = int(np.argmax(scores))
                plt.subplot(rows, cols, i + 1)
                plt.imshow(img.astype("uint8"))
                plt.title(f"{class_names[index]}: {scores[index] * 100:.2f}%")
                plt.axis("off")
            plt.subplots_adjust(hspace=0.5, wspace=0.5)
            plt.show()

    def tflite_analyzer(self):
        """Print input/output details and every tensor of the ``.tflite`` model."""
        # Load the TFLite model.
        interpreter = tf.lite.Interpreter(model_path=self.save_model_dir + '.tflite')
        interpreter.allocate_tensors()
        # Input and output descriptors.
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print("Input Details:")
        for detail in input_details:
            print(detail)
        print("\nOutput Details:")
        for detail in output_details:
            print(detail)
        # List every tensor in the model (tensors, not operators).
        tensor_details = interpreter.get_tensor_details()
        print("\nTensor Details:")
        for tensor_detail in tensor_details:
            print("Index:", tensor_detail['index'])
            print("Name:", tensor_detail['name'])
            print("Shape:", tensor_detail['shape'])
            print("Shape Signature:", tensor_detail['shape_signature'])
            print("dtype:", tensor_detail['dtype'])
            print("Quantization:", tensor_detail['quantization'])
            print("Quantization Parameters:", tensor_detail['quantization_parameters'])
            print("Sparsity Parameters:", tensor_detail['sparsity_parameters'])
            print()
在主程序中调用工具类(依次完成模型转换、查看模型信息和批量识别可视化):
if __name__ == '__main__':
    # NOTE(review): SAVED_MODEL_DIR and PATH_URL are not defined in this
    # snippet; presumably they come from the project's config - confirm.
    # Earlier alternative steps (left disabled):
    #   train()
    #   ModelUtil(SAVED_MODEL_DIR, PATH_URL).batch_evaluation()
    util = TFLiteUtil(SAVED_MODEL_DIR, PATH_URL)
    # Convert the SavedModel, dump its tensor layout, then spot-check one test batch.
    for step in (util.convert_tflite, util.tflite_analyzer, util.batch_evaluation):
        step()
运行后会以 SAVED_MODEL_DIR 为路径前缀生成 .tflite 模型文件和 .label 标签文件:
二. 使用模型
创建 Flutter 项目,在 pubspec.yaml 的 dependencies 中引入以下库:
dependencies:
  image: ^4.0.17
  path: ^1.8.3
  path_provider: ^2.0.15
  image_picker: ^0.8.8
  tflite_flutter: ^0.10.4
  camera: ^0.10.5+2
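把生成的 fruits.tflite 和 fruits.label 文件拷贝到 Flutter 项目的 assets/models/ 目录下,并在 pubspec.yaml 的 flutter: assets: 中声明该目录: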
核心代码:
import 'dart:developer';
import 'dart:io';
import 'dart:isolate';
import 'package:camera/camera.dart';
import 'package:flutter/services.dart';
import 'package:image/image.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'isolate_inference.dart';
/// Loads the fruit-classification TFLite model and its labels from assets
/// and runs inference on a background isolate.
class ImageClassificationHelper {
  static const modelPath = 'assets/models/fruits.tflite';
  static const labelsPath = 'assets/models/fruits.label';

  late final Interpreter interpreter;
  late final List<String> labels;
  late final IsolateInference isolateInference;
  late Tensor inputTensor;
  late Tensor outputTensor;

  /// Loads the model from assets with a platform-specific delegate and
  /// caches its first input and output tensors.
  Future<void> _loadModel() async {
    final options = InterpreterOptions();

    // Use the XNNPACK delegate on Android.
    if (Platform.isAndroid) {
      options.addDelegate(XNNPackDelegate());
    }

    // GPU delegate alternative for Android; doesn't work on the emulator.
    // if (Platform.isAndroid) {
    //   options.addDelegate(GpuDelegateV2());
    // }

    // Use the Metal (GPU) delegate on iOS.
    if (Platform.isIOS) {
      options.addDelegate(GpuDelegate());
    }

    interpreter = await Interpreter.fromAsset(modelPath, options: options);
    // Input shape, e.g. [1, 224, 224, 3] if exported with image_size (224, 224).
    inputTensor = interpreter.getInputTensors().first;
    // Output shape presumably [1, <number of labels>] - confirm against the label file.
    outputTensor = interpreter.getOutputTensors().first;
    log('Interpreter loaded successfully');
  }

  /// Loads the newline-separated label file from assets.
  Future<void> _loadLabels() async {
    final labelTxt = await rootBundle.loadString(labelsPath);
    // Tolerate CRLF line endings and the file's trailing newline. Only
    // trailing blank lines are dropped, so line i still maps to output i.
    final lines = labelTxt.replaceAll('\r', '').split('\n');
    while (lines.isNotEmpty && lines.last.isEmpty) {
      lines.removeLast();
    }
    labels = lines;
  }

  /// Loads labels and model, then starts the inference isolate.
  /// Must complete before any inference call.
  Future<void> initHelper() async {
    // Await both loads: unawaited, `labels`/`interpreter` could be read before
    // being assigned, and load errors would go unobserved.
    await Future.wait([_loadLabels(), _loadModel()]);
    isolateInference = IsolateInference();
    await isolateInference.start();
  }

  /// Sends [inferenceModel] to the isolate and waits for its label->score map.
  Future<Map<String, double>> _inference(InferenceModel inferenceModel) async {
    ReceivePort responsePort = ReceivePort();
    isolateInference.sendPort
        .send(inferenceModel..responsePort = responsePort.sendPort);
    // `first` completes with the single reply and then closes the port.
    var results = await responsePort.first;
    return results;
  }

  /// Classifies a live camera frame.
  Future<Map<String, double>> inferenceCameraFrame(
      CameraImage cameraImage) async {
    var isolateModel = InferenceModel(cameraImage, null, interpreter.address,
        labels, inputTensor.shape, outputTensor.shape);
    return _inference(isolateModel);
  }

  /// Classifies a still image.
  Future<Map<String, double>> inferenceImage(Image image) async {
    var isolateModel = InferenceModel(null, image, interpreter.address, labels,
        inputTensor.shape, outputTensor.shape);
    return _inference(isolateModel);
  }

  /// Shuts down the inference isolate. The interpreter itself is not
  /// released here.
  Future<void> close() async {
    isolateInference.close();
  }
}
页面部分:(页面代码在转载时缺失,完整代码可参考文首的 Flutter 项目地址)
原文地址:https://blog.csdn.net/wang_yong_hui_1234/article/details/140355816
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!