获取我们在 Firebase 峰会上发布的所有信息,了解 Firebase 可如何帮助您加快应用开发速度并满怀信心地运行应用。了解详情

在 Android 上使用自定义 TensorFlow Lite 模型

使用集合让一切井井有条 根据您的偏好保存内容并对其进行分类。

如果您的应用使用自定义TensorFlow Lite模型,您可以使用 Firebase ML 来部署您的模型。通过使用 Firebase 部署模型,您可以减少应用的初始下载大小并更新应用的 ML 模型,而无需发布应用的新版本。而且,通过远程配置和 A/B 测试,您可以为不同的用户组动态地提供不同的模型。

TensorFlow Lite 模型

TensorFlow Lite 模型是经过优化以在移动设备上运行的 ML 模型。要获取 TensorFlow Lite 模型:

在你开始之前

  1. 如果您还没有,请将 Firebase 添加到您的 Android 项目中。
  2. 在您的模块(应用级)Gradle 文件(通常是<project>/<app-module>/build.gradle )中,添加 Firebase ML 模型下载器 Android 库的依赖项。我们建议使用Firebase Android BoM来控制库版本控制。

    此外,作为设置 Firebase ML 模型下载器的一部分,您需要将 TensorFlow Lite SDK 添加到您的应用中。

    Java

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:31.1.0')
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader'
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    通过使用Firebase Android BoM ,您的应用将始终使用兼容版本的 Firebase Android 库。

    (替代)使用 BoM 的情况下添加 Firebase 库依赖项

    如果您选择不使用 Firebase BoM,则必须在其依赖行中指定每个 Firebase 库版本。

    请注意,如果您在应用中使用多个Firebase 库,我们强烈建议您使用 BoM 来管理库版本,以确保所有版本都兼容。

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader:24.1.1'
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Kotlin+KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:31.1.0')
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx'
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    通过使用Firebase Android BoM ,您的应用将始终使用兼容版本的 Firebase Android 库。

    (替代)使用 BoM 的情况下添加 Firebase 库依赖项

    如果您选择不使用 Firebase BoM,则必须在其依赖行中指定每个 Firebase 库版本。

    请注意,如果您在应用中使用多个Firebase 库,我们强烈建议您使用 BoM 来管理库版本,以确保所有版本都兼容。

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx:24.1.1'
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }
  3. 在您应用的清单中,声明需要 INTERNET 权限:
    <uses-permission android:name="android.permission.INTERNET" />

1. 部署你的模型

使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK 部署您的自定义 TensorFlow 模型。请参阅部署和管理自定义模型

将自定义模型添加到 Firebase 项目后,您可以使用您指定的名称在应用中引用该模型。您可以随时部署新的 TensorFlow Lite 模型,并通过调用getModel()将新模型下载到用户的设备上(见下文)。

2. 将模型下载到设备并初始化 TensorFlow Lite 解释器

要在您的应用中使用您的 TensorFlow Lite 模型,请首先使用 Firebase ML SDK 将最新版本的模型下载到设备中。然后,使用模型实例化一个 TensorFlow Lite 解释器。

要开始模型下载,请调用模型下载器的getModel()方法,指定上传模型时指定的名称、是否要始终下载最新模型以及允许下载的条件。

您可以从三种下载行为中进行选择:

下载类型描述
本地模型从设备获取本地模型。如果没有可用的本地模型,则其行为类似于LATEST_MODEL 。如果您对检查模型更新不感兴趣,请使用此下载类型。例如,您正在使用远程配置来检索模型名称,并且您总是以新名称上传模型(推荐)。
LOCAL_MODEL_UPDATE_IN_BACKGROUND从设备获取本地模型并开始在后台更新模型。如果没有可用的本地模型,则其行为类似于LATEST_MODEL
最新款获取最新型号。如果本地模型是最新版本,则返回本地模型。否则,请下载最新型号。在下载最新版本之前,此行为将被阻止(不推荐)。仅在您明确需要最新版本的情况下使用此行为。

您应该禁用与模型相关的功能(例如,灰显或隐藏部分 UI),直到您确认模型已下载。

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

许多应用程序在其初始化代码中启动下载任务,但您可以在需要使用模型之前的任何时候这样做。

3. 对输入数据进行推理

获取模型的输入和输出形状

TensorFlow Lite 模型解释器将一个或多个多维数组作为输入并生成输出。这些数组包含byteintlongfloat值。在将数据传递给模型或使用其结果之前,您必须知道模型使用的数组的数量和维度(“形状”)。

如果您自己构建了模型,或者模型的输入和输出格式已记录在案,您可能已经拥有此信息。如果您不知道模型输入和输出的形状和数据类型,您可以使用 TensorFlow Lite 解释器来检查您的模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

示例输出:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

运行解释器

确定模型输入和输出的格式后,获取输入数据并对数据执行任何必要的转换,以获得模型正确形状的输入。

例如,如果您有一个输入形状为[1 224 224 3]浮点值的图像分类模型,您可以从Bitmap对象生成输入ByteBuffer ,如下例所示:

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

然后,分配一个足够大的ByteBuffer来包含模型的输出,并将输入缓冲区和输出缓冲区传递给 TensorFlow Lite 解释器的run()方法。例如,对于[1 1000]浮点值的输出形状:

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

您如何使用输出取决于您使用的模型。

例如,如果您正在执行分类,作为下一步,您可能会将结果的索引映射到它们所代表的标签:

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

附录:模型安全

无论您如何使 TensorFlow Lite 模型可用于 Firebase ML,Firebase ML 都会以标准序列化 protobuf 格式将它们存储在本地存储中。

理论上,这意味着任何人都可以复制您的模型。然而,在实践中,大多数模型都是特定于应用程序的,并且被优化混淆了,其风险类似于竞争对手反汇编和重用代码的风险。不过,在您的应用程序中使用自定义模型之前,您应该意识到这种风险。

在 Android API 级别 21 (Lollipop) 和更高版本上,模型会下载到自动备份中排除的目录。

在 Android API 级别 20 和更早版本上,模型会下载到应用专用内部存储中名为com.google.firebase.ml.custom.models的目录中。如果您使用BackupAgent启用了文件备份,则可以选择排除此目录。