如果您的应用使用自定义TensorFlow Lite模型,您可以使用 Firebase ML 来部署您的模型。通过使用 Firebase 部署模型,您可以减少应用的初始下载大小并更新应用的 ML 模型,而无需发布应用的新版本。而且,通过远程配置和 A/B 测试,您可以动态地为不同的用户组提供不同的模型。
TensorFlow Lite 模型
TensorFlow Lite 模型是经过优化以在移动设备上运行的 ML 模型。要获取 TensorFlow Lite 模型:
在你开始之前
- 如果您还没有,请将 Firebase 添加到您的 Android 项目中。
- 在您的模块(应用级)Gradle 文件(通常为
<project>/<app-module>/build.gradle
)中,添加 Firebase ML 模型下载器 Android 库的依赖项。我们建议使用Firebase Android BoM来控制库版本。此外,作为设置 Firebase ML 模型下载器的一部分,您需要将 TensorFlow Lite SDK 添加到您的应用中。
Kotlin+KTX
dependencies { // Import the BoM for the Firebase platform implementation platform('com.google.firebase:firebase-bom:31.2.0') // Add the dependency for the Firebase ML model downloader library // When using the BoM, you don't specify versions in Firebase library dependencies implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx'
// Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0' }通过使用Firebase Android BoM ,您的应用将始终使用兼容版本的 Firebase Android 库。
(备选)在不使用 BoM 的情况下添加 Firebase 库依赖项
如果您选择不使用 Firebase BoM,则必须在其依赖项行中指定每个 Firebase 库版本。
请注意,如果您在应用中使用多个Firebase 库,我们强烈建议您使用 BoM 来管理库版本,以确保所有版本都兼容。
dependencies { // Add the dependency for the Firebase ML model downloader library // When NOT using the BoM, you must specify versions in Firebase library dependencies implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx:24.1.2'
// Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0' }Java
dependencies { // Import the BoM for the Firebase platform implementation platform('com.google.firebase:firebase-bom:31.2.0') // Add the dependency for the Firebase ML model downloader library // When using the BoM, you don't specify versions in Firebase library dependencies implementation 'com.google.firebase:firebase-ml-modeldownloader'
// Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0' }通过使用Firebase Android BoM ,您的应用将始终使用兼容版本的 Firebase Android 库。
(备选)在不使用 BoM 的情况下添加 Firebase 库依赖项
如果您选择不使用 Firebase BoM,则必须在其依赖项行中指定每个 Firebase 库版本。
请注意,如果您在应用中使用多个Firebase 库,我们强烈建议您使用 BoM 来管理库版本,以确保所有版本都兼容。
dependencies { // Add the dependency for the Firebase ML model downloader library // When NOT using the BoM, you must specify versions in Firebase library dependencies implementation 'com.google.firebase:firebase-ml-modeldownloader:24.1.2'
// Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0' }- 在您应用的清单中,声明需要 INTERNET 权限:
<uses-permission android:name="android.permission.INTERNET" />
1. 部署你的模型
使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK 部署自定义 TensorFlow 模型。请参阅部署和管理自定义模型。
将自定义模型添加到 Firebase 项目后,您可以使用指定的名称在应用中引用该模型。您可以随时部署新的 TensorFlow Lite 模型并通过调用
getModel()
(见下文)将新模型下载到用户的设备上。2. 将模型下载到设备并初始化一个 TensorFlow Lite 解释器
要在应用中使用 TensorFlow Lite 模型,请先使用 Firebase ML SDK 将最新版本的模型下载到设备。然后,使用模型实例化 TensorFlow Lite 解释器。要开始模型下载,请调用模型下载器的
getModel()
方法,指定上传时为模型分配的名称、是否要始终下载最新模型以及允许下载的条件。您可以从三种下载行为中进行选择:
下载类型 描述 本地型号 从设备获取本地模型。如果没有可用的本地模型,则其行为类似于 LATEST_MODEL
。如果您对检查模型更新不感兴趣,请使用此下载类型。例如,您正在使用 Remote Config 检索模型名称,并且始终以新名称上传模型(推荐)。LOCAL_MODEL_UPDATE_IN_BACKGROUND 从设备获取本地模型并开始在后台更新模型。如果没有可用的本地模型,则其行为类似于 LATEST_MODEL
。最新款 获取最新型号。如果本地模型是最新版本,则返回本地模型。否则,下载最新模型。在下载最新版本之前,此行为将被阻止(不推荐)。仅在您明确需要最新版本的情况下使用此行为。 您应该禁用与模型相关的功能——例如,灰显或隐藏部分 UI——直到您确认模型已下载。
Kotlin+KTX
val conditions = CustomModelDownloadConditions.Builder() .requireWifi() // Also possible: .requireCharging() and .requireDeviceIdle() .build() FirebaseModelDownloader.getInstance() .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions) .addOnSuccessListener { model: CustomModel? -> // Download complete. Depending on your app, you could enable the ML // feature, or switch from the local model to the remote model, etc. // The CustomModel object contains the local path of the model file, // which you can use to instantiate a TensorFlow Lite interpreter. val modelFile = model?.file if (modelFile != null) { interpreter = Interpreter(modelFile) } }
Java
CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder() .requireWifi() // Also possible: .requireCharging() and .requireDeviceIdle() .build(); FirebaseModelDownloader.getInstance() .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions) .addOnSuccessListener(new OnSuccessListener<CustomModel>() { @Override public void onSuccess(CustomModel model) { // Download complete. Depending on your app, you could enable the ML // feature, or switch from the local model to the remote model, etc. // The CustomModel object contains the local path of the model file, // which you can use to instantiate a TensorFlow Lite interpreter. File modelFile = model.getFile(); if (modelFile != null) { interpreter = new Interpreter(modelFile); } } });
许多应用程序在其初始化代码中开始下载任务,但您可以在需要使用该模型之前的任何时候执行此操作。
3. 对输入数据进行推理
获取模型的输入和输出形状
TensorFlow Lite 模型解释器将一个或多个多维数组作为输入并生成输出。这些数组包含
byte
、int
、long
或float
值。在将数据传递给模型或使用其结果之前,您必须知道模型使用的数组的数量和维度(“形状”)。如果您自己构建了模型,或者如果模型的输入和输出格式已记录在案,您可能已经有了这些信息。如果您不知道模型输入和输出的形状和数据类型,可以使用 TensorFlow Lite 解释器检查您的模型。例如:
Python
import tensorflow as tf interpreter = tf.lite.Interpreter(model_path="your_model.tflite") interpreter.allocate_tensors() # Print input shape and type inputs = interpreter.get_input_details() print('{} input(s):'.format(len(inputs))) for i in range(0, len(inputs)): print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype'])) # Print output shape and type outputs = interpreter.get_output_details() print('\n{} output(s):'.format(len(outputs))) for i in range(0, len(outputs)): print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))
示例输出:
1 input(s): [ 1 224 224 3] <class 'numpy.float32'> 1 output(s): [1 1000] <class 'numpy.float32'>
运行解释器
确定模型输入和输出的格式后,获取输入数据并对数据执行任何必要的转换,以获得模型正确形状的输入。例如,如果您的图像分类模型的输入形状为
[1 224 224 3]
浮点值,则可以从Bitmap
对象生成输入ByteBuffer
,如以下示例所示:Kotlin+KTX
val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true) val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder()) for (y in 0 until 224) { for (x in 0 until 224) { val px = bitmap.getPixel(x, y) // Get channel values from the pixel value. val r = Color.red(px) val g = Color.green(px) val b = Color.blue(px) // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model. // For example, some models might require values to be normalized to the range // [0.0, 1.0] instead. val rf = (r - 127) / 255f val gf = (g - 127) / 255f val bf = (b - 127) / 255f input.putFloat(rf) input.putFloat(gf) input.putFloat(bf) } }
Java
Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true); ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder()); for (int y = 0; y < 224; y++) { for (int x = 0; x < 224; x++) { int px = bitmap.getPixel(x, y); // Get channel values from the pixel value. int r = Color.red(px); int g = Color.green(px); int b = Color.blue(px); // Normalize channel values to [-1.0, 1.0]. This requirement depends // on the model. For example, some models might require values to be // normalized to the range [0.0, 1.0] instead. float rf = (r - 127) / 255.0f; float gf = (g - 127) / 255.0f; float bf = (b - 127) / 255.0f; input.putFloat(rf); input.putFloat(gf); input.putFloat(bf); } }
然后,分配一个足够大的
ByteBuffer
来包含模型的输出,并将输入缓冲区和输出缓冲区传递给 TensorFlow Lite 解释器的run()
方法。例如,对于[1 1000]
浮点值的输出形状:Kotlin+KTX
val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()) interpreter?.run(input, modelOutput)
Java
int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE; ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()); interpreter.run(input, modelOutput);
如何使用输出取决于您使用的模型。
例如,如果您正在执行分类,作为下一步,您可以将结果的索引映射到它们代表的标签:
Kotlin+KTX
modelOutput.rewind() val probabilities = modelOutput.asFloatBuffer() try { val reader = BufferedReader( InputStreamReader(assets.open("custom_labels.txt"))) for (i in probabilities.capacity()) { val label: String = reader.readLine() val probability = probabilities.get(i) println("$label: $probability") } } catch (e: IOException) { // File not found? }
Java
modelOutput.rewind(); FloatBuffer probabilities = modelOutput.asFloatBuffer(); try { BufferedReader reader = new BufferedReader( new InputStreamReader(getAssets().open("custom_labels.txt"))); for (int i = 0; i < probabilities.capacity(); i++) { String label = reader.readLine(); float probability = probabilities.get(i); Log.i(TAG, String.format("%s: %1.4f", label, probability)); } } catch (IOException e) { // File not found? }
附录:模型安全
无论您如何使 TensorFlow Lite 模型可用于 Firebase ML,Firebase ML 都会将它们以标准序列化 protobuf 格式存储在本地存储中。
从理论上讲,这意味着任何人都可以复制您的模型。然而,在实践中,大多数模型都是特定于应用程序的,并且被优化所混淆,以至于风险类似于竞争对手反汇编和重用您的代码的风险。然而,在您的应用程序中使用自定义模型之前,您应该意识到这种风险。
在 Android API 级别 21 (Lollipop) 和更新版本中,模型被下载到一个从自动备份中排除的目录。
在 Android API 级别 20 及更早版本上,模型将下载到应用私有内部存储中名为
com.google.firebase.ml.custom.models
的目录中。如果您使用BackupAgent
启用文件备份,您可以选择排除此目录。Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-02-03 UTC.
[{ "type": "thumb-down", "id": "missingTheInformationINeed", "label":"没有我需要的信息" },{ "type": "thumb-down", "id": "tooComplicatedTooManySteps", "label":"太复杂/步骤太多" },{ "type": "thumb-down", "id": "outOfDate", "label":"内容需要更新" },{ "type": "thumb-down", "id": "translationIssue", "label":"翻译问题" },{ "type": "thumb-down", "id": "samplesCodeIssue", "label":"示例/代码问题" },{ "type": "thumb-down", "id": "otherDown", "label":"其他" }] [{ "type": "thumb-up", "id": "easyToUnderstand", "label":"易于理解" },{ "type": "thumb-up", "id": "solvedMyProblem", "label":"解决了我的问题" },{ "type": "thumb-up", "id": "otherUp", "label":"其他" }] - 在您应用的清单中,声明需要 INTERNET 权限: