获取我们在 Firebase 峰会上发布的所有信息,了解 Firebase 可如何帮助您加快应用开发速度并满怀信心地运行应用。了解详情

在 Android 上使用 TensorFlow Lite 模型与机器学习套件进行推理

使用集合让一切井井有条 根据您的偏好保存内容并对其进行分类。

您可以使用 ML Kit 通过TensorFlow Lite模型执行设备上的推理。

此 API 需要 Android SDK 级别 16 (Jelly Bean) 或更高版本。

在你开始之前

  1. 如果您还没有,请将 Firebase 添加到您的 Android 项目中。
  2. 将 ML Kit Android 库的依赖项添加到您的模块(应用程序级)Gradle 文件(通常是app/build.gradle ):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. 将您要使用的 TensorFlow 模型转换为 TensorFlow Lite 格式。请参阅TOCO:TensorFlow Lite 优化转换器

托管或捆绑您的模型

在您可以在您的应用程序中使用 TensorFlow Lite 模型进行推理之前,您必须使该模型可用于 ML Kit。 ML Kit 可以使用使用 Firebase 远程托管的 TensorFlow Lite 模型,与应用程序二进制文件捆绑在一起,或两者兼而有之。

通过在 Firebase 上托管模型,您可以在不发布新应用程序版本的情况下更新模型,并且可以使用远程配置和 A/B 测试为不同的用户组动态提供不同的模型。

如果您选择仅通过使用 Firebase 托管模型来提供模型,而不是将其与您的应用程序捆绑在一起,则可以减少应用程序的初始下载大小。但请记住,如果模型未与您的应用程序捆绑,则在您的应用程序首次下载该模型之前,任何与模型相关的功能都将不可用。

通过将您的模型与您的应用程序捆绑在一起,您可以确保您的应用程序的 ML 功能在 Firebase 托管的模型不可用时仍然有效。

在 Firebase 上托管模型

在 Firebase 上托管您的 TensorFlow Lite 模型:

  1. Firebase 控制台ML Kit部分,单击Custom选项卡。
  2. 单击添加自定义模型(或添加另一个模型)。
  3. 指定一个名称,用于在您的 Firebase 项目中识别您的模型,然后上传 TensorFlow Lite 模型文件(通常以.tflite.lite )。
  4. 在您应用的清单中,声明需要 INTERNET 权限:
    <uses-permission android:name="android.permission.INTERNET" />
    

将自定义模型添加到 Firebase 项目后,您可以使用指定的名称在应用中引用该模型。您可以随时上传新的 TensorFlow Lite 模型,您的应用程序将下载新模型并在应用程序下次重启时开始使用它。您可以定义应用程序尝试更新模型所需的设备条件(见下文)。

将模型与应用程序捆绑在一起

要将您的 TensorFlow Lite 模型与您的应用程序捆绑在一起,请将模型文件(通常以.tflite.lite结尾)复制到您应用程序的assets/文件夹中。 (您可能需要先创建文件夹,方法是右键单击app/文件夹,然后单击New > Folder > Assets Folder 。)

然后,将以下内容添加到应用程序的build.gradle文件中,以确保 Gradle 在构建应用程序时不会压缩模型:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

模型文件将包含在应用程序包中,并作为原始资产提供给 ML Kit。

加载模型

要在您的应用程序中使用您的 TensorFlow Lite 模型,首先使用您的模型可用的位置配置 ML Kit:远程使用 Firebase,在本地存储中,或两者兼而有之。如果同时指定本地和远程模型,则可以使用远程模型(如果可用),如果远程模型不可用则回退到本地存储的模型。

配置 Firebase 托管的模型

如果您使用 Firebase 托管您的模型,请创建一个FirebaseCustomRemoteModel对象,指定您在上传时分配给模型的名称:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

然后,启动模型下载任务,指定允许下载的条件。如果模型不在设备上,或者有更新版本的模型可用,任务将从 Firebase 异步下载模型:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

许多应用程序在其初始化代码中开始下载任务,但您可以在需要使用该模型之前的任何时候执行此操作。

配置本地模型

如果您将模型与您的应用捆绑在一起,请创建一个FirebaseCustomLocalModel对象,并指定 TensorFlow Lite 模型的文件名:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

从您的模型创建解释器

配置模型源后,从其中之一创建一个FirebaseModelInterpreter对象。

如果您只有一个本地绑定的模型,只需从您的FirebaseCustomLocalModel对象创建一个解释器:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

如果您有一个远程托管的模型,则必须在运行之前检查它是否已下载。您可以使用模型管理器的isModelDownloaded()方法检查模型下载任务的状态。

虽然您只需要在运行解释器之前确认这一点,但如果您同时拥有远程托管模型和本地绑定模型,则在实例化模型解释器时执行此检查可能有意义:如果满足以下条件,则从远程模型创建解释器它已被下载,否则来自本地模型。

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

如果您只有一个远程托管的模型,您应该禁用与模型相关的功能——例如,灰显或隐藏部分 UI——直到您确认模型已下载。您可以通过将侦听器附加到模型管理器的download()方法来实现:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

指定模型的输入和输出

接下来,配置模型解释器的输入和输出格式。

TensorFlow Lite 模型将一个或多个多维数组作为输入并生成输出。这些数组包含byteintlongfloat值。您必须使用模型使用的数组的数量和维度(“形状”)配置机器学习套件。

如果您不知道模型输入和输出的形状和数据类型,您可以使用 TensorFlow Lite Python 解释器来检查您的模型。例如:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

确定模型输入和输出的格式后,您可以通过创建FirebaseModelInputOutputOptions对象来配置应用程序的模型解释器。

例如,浮点图像分类模型可能将N x224x224x3 float值数组作为输入,表示一批N 224x224 三通道 (RGB) 图像,并生成 1000 个float值列表作为输出,每个浮点值表示图像是模型预测的 1000 个类别之一的概率。

对于这样的模型,您将配置模型解释器的输入和输出,如下所示:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

对输入数据执行推理

最后,要使用模型执行推理,请获取您的输入数据并对数据执行任何必要的转换,以获得适合您的模型的正确形状的输入数组。

例如,如果您有一个输入形状为 [1 224 224 3] 浮点值的图像分类模型,您可以从Bitmap对象生成一个输入数组,如以下示例所示:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

然后,使用您的输入数据创建一个FirebaseModelInputs对象,并将它和模型的输入和输出规范传递给模型解释器run方法:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

如果调用成功,您可以通过调用传递给成功侦听器的对象的getOutput()方法来获取输出。例如:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

如何使用输出取决于您使用的模型。

例如,如果您正在执行分类,作为下一步,您可以将结果的索引映射到它们代表的标签:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

附录:模型安全

无论您如何使您的 TensorFlow Lite 模型可供 ML Kit 使用,ML Kit 都会将它们以标准序列化 protobuf 格式存储在本地存储中。

从理论上讲,这意味着任何人都可以复制您的模型。然而,在实践中,大多数模型都是特定于应用程序的,并且被优化所混淆,以至于风险类似于竞争对手反汇编和重用您的代码的风险。然而,在您的应用程序中使用自定义模型之前,您应该意识到这种风险。

在 Android API 级别 21 (Lollipop) 和更新版本中,模型被下载到一个从自动备份中排除的目录。

在 Android API 级别 20 及更早版本上,模型将下载到应用私有内部存储中名为com.google.firebase.ml.custom.models的目录中。如果您使用BackupAgent启用文件备份,您可以选择排除此目录。