欢迎参加我们将于 2022 年 10 月 18 日举办的 Firebase 峰会(线上线下同时进行),了解 Firebase 如何帮助您加快应用开发速度、满怀信心地发布应用并在之后需要时轻松地扩大应用规模。立即报名

在 Android 上使用 AutoML 训练的模型检测图像中的对象

使用集合让一切井井有条 根据您的偏好保存内容并对其进行分类。

使用 AutoML Vision Edge 训练您自己的模型后,您可以在您的应用程序中使用它来检测图像中的对象。

有两种方法可以集成从 AutoML Vision Edge 训练的模型:您可以通过将模型放在应用的资产文件夹中来捆绑模型,也可以从 Firebase 动态下载它。

模型捆绑选项
捆绑在您的应用程序中
  • 该模型是您应用的 APK 的一部分
  • 该模型立即可用,即使 Android 设备处于离线状态
  • 无需 Firebase 项目
由 Firebase 托管
  • 通过将模型上传到Firebase 机器学习来托管模型
  • 减小 APK 大小
  • 模型按需下载
  • 无需重新发布应用即可推送模型更新
  • 使用Firebase 远程配置轻松进行 A/B 测试
  • 需要 Firebase 项目

在你开始之前

  1. 如果您想下载模型,请确保将 Firebase 添加到您的 Android 项目(如果您尚未这样做)。捆绑模型时不需要这样做。

  2. 将 TensorFlow Lite Task 库的依赖项添加到模块的应用级 gradle 文件中,该文件通常为app/build.gradle

    要将模型与您的应用程序捆绑在一起:

    dependencies {
      // ...
      // Object detection with a bundled Auto ML model
      implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly-SNAPSHOT'
    }
    

    要从 Firebase 动态下载模型,还需要添加 Firebase ML 依赖项:

    dependencies {
      // ...
      // Object detection with an Auto ML model deployed to Firebase
      implementation platform('com.google.firebase:firebase-bom:26.1.1')
      implementation 'com.google.firebase:firebase-ml-model-interpreter'
    
      implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
    }
    

1.加载模型

配置本地模型源

要将模型与您的应用程序捆绑在一起:

  1. 从您从 Google Cloud Console 下载的 zip 存档中提取模型。
  2. 将您的模型包含在您的应用程序包中:
    1. 如果您的项目中没有 assets 文件夹,请右键单击app/文件夹,然后单击New > Folder > Assets Folder创建一个。
    2. 将带有嵌入元数据的tflite模型文件复制到 assets 文件夹。
  3. 将以下内容添加到应用的build.gradle文件中,以确保 Gradle 在构建应用时不会压缩模型文件:

    android {
        // ...
        aaptOptions {
            noCompress "tflite"
        }
    }
    

    模型文件将包含在应用程序包中并作为原始资产提供。

配置 Firebase 托管的模型源

要使用远程托管模型,请创建一个RemoteModel对象,指定您在发布模型时为其分配的名称:

爪哇

// Specify the name you assigned when you deployed the model.
FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

科特林

// Specify the name you assigned when you deployed the model.
val remoteModel =
    FirebaseCustomRemoteModel.Builder("your_model_name").build()

然后,启动模型下载任务,指定您希望允许下载的条件。如果模型不在设备上,或者有更新版本的模型可用,任务将从 Firebase 异步下载模型:

爪哇

DownloadConditions downloadConditions = new DownloadConditions.Builder()
        .requireWifi()
        .build();
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(@NonNull Task<Void> task) {
                // Success.
            }
        });

科特林

val downloadConditions = DownloadConditions.Builder()
    .requireWifi()
    .build()
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
    .addOnSuccessListener {
        // Success.
    }

许多应用程序在其初始化代码中启动下载任务,但您可以在需要使用模型之前的任何时候这样做。

从您的模型创建对象检测器

配置模型源后,从其中之一创建ObjectDetector对象。

如果您只有一个本地捆绑的模型,只需从您的模型文件创建一个对象检测器并配置您想要的置信度分数阈值(请参阅评估您的模型):

爪哇

// Initialization
ObjectDetectorOptions options = ObjectDetectorOptions.builder()
    .setScoreThreshold(0)  // Evaluate your model in the Google Cloud Console
                           // to determine an appropriate value.
    .build();
ObjectDetector objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options);

科特林

// Initialization
val options = ObjectDetectorOptions.builder()
    .setScoreThreshold(0)  // Evaluate your model in the Google Cloud Console
                           // to determine an appropriate value.
    .build()
val objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options)

如果您有远程托管模型,则必须在运行之前检查它是否已下载。您可以使用模型管理器的isModelDownloaded()方法检查模型下载任务的状态。

尽管您只需要在运行对象检测器之前确认这一点,但如果您同时拥有远程托管模型和本地捆绑模型,则在实例化对象检测器时执行此检查可能是有意义的:从远程创建对象检测器如果已下载模型,则从本地模型下载。

爪哇

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
            }
        });

科特林

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener { success ->

        }

如果您只有远程托管的模型,则应禁用与模型相关的功能(例如,灰显或隐藏部分 UI),直到您确认模型已下载。您可以通过将侦听器附加到模型管理器的download()方法来实现。

一旦您知道您的模型已下载,请从模型文件创建对象检测器:

爪哇

FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    ObjectDetectorOptions options = ObjectDetectorOptions.builder()
                            .setScoreThreshold(0)
                            .build();
                    objectDetector = ObjectDetector.createFromFileAndOptions(
                            getApplicationContext(), modelFile.getPath(), options);
                }
            }
        });

科特林

FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnSuccessListener { modelFile ->
            val options = ObjectDetectorOptions.builder()
                    .setScoreThreshold(0f)
                    .build()
            objectDetector = ObjectDetector.createFromFileAndOptions(
                    applicationContext, modelFile.path, options)
        }

2.准备输入图像

然后,对于您要标记的每个图像,从您的图像创建一个TensorImage对象。您可以使用fromBitmap方法从Bitmap创建TensorImage对象:

爪哇

TensorImage image = TensorImage.fromBitmap(bitmap);

科特林

val image = TensorImage.fromBitmap(bitmap)

如果您的图像数据不在Bitmap中,您可以加载像素数组,如TensorFlow Lite 文档中所示。

3. 运行物体检测器

要检测图像中的对象,请将TensorImage对象传递给ObjectDetectordetect()方法。

爪哇

List<Detection> results = objectDetector.detect(image);

科特林

val results = objectDetector.detect(image)

4. 获取有关标记对象的信息

如果对象检测操作成功,则返回Detection对象列表。每个Detection对象代表在图像中检测到的东西。您可以获得每个对象的边界框及其标签。

例如:

爪哇

for (Detection result : results) {
    RectF bounds = result.getBoundingBox();
    List<Category> labels = result.getCategories();
}

科特林

for (result in results) {
    val bounds = result.getBoundingBox()
    val labels = result.getCategories()
}

提高实时性能的技巧

如果您想在实时应用程序中标记图像,请遵循以下指南以获得最佳帧率:

  • 限制对图像标注器的调用。如果在图像标注器运行时有新的视频帧可用,则丢弃该帧。有关示例,请参阅快速入门示例应用程序中的VisionProcessorBase类。
  • 如果您使用图像标注器的输出在输入图像上叠加图形,首先获取结果,然后在一个步骤中渲染图像并叠加。通过这样做,您只为每个输入帧渲染到显示表面一次。有关示例,请参阅快速入门示例应用程序中的CameraSourcePreviewGraphicOverlay类。
  • 如果您使用 Camera2 API,请以ImageFormat.YUV_420_888格式捕获图像。

    如果您使用较旧的 Camera API,请以ImageFormat.NV21格式捕获图像。