在 Android 上使用 AutoML 訓練的模型檢測圖像中的對象

使用 AutoML Vision Edge 訓練您自己的模型後,您可以在您的應用程序中使用它來檢測圖像中的對象。

有兩種方法可以集成從 AutoML Vision Edge 訓練的模型:您可以通過將模型放在應用的資產文件夾中來捆綁模型,也可以從 Firebase 動態下載它。

模型捆綁選項
捆綁在您的應用程序中
  • 該模型是您應用的 APK 的一部分
  • 該模型立即可用,即使 Android 設備處於離線狀態
  • 無需 Firebase 項目
由 Firebase 託管
  • 通過將模型上傳到Firebase 機器學習來託管模型
  • 減小 APK 大小
  • 模型按需下載
  • 無需重新發​​布應用即可推送模型更新
  • 使用Firebase 遠程配置輕鬆進行 A/B 測試
  • 需要 Firebase 項目

在你開始之前

  1. 如果您想下載模型,請確保將 Firebase 添加到您的 Android 項目(如果您尚未這樣做)。捆綁模型時不需要這樣做。

  2. 將 TensorFlow Lite Task 庫的依賴項添加到模塊的應用級 gradle 文件中,該文件通常為app/build.gradle

    要將模型與您的應用程序捆綁在一起:

    dependencies {
      // ...
      // Object detection with a bundled Auto ML model
      implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly-SNAPSHOT'
    }
    

    要從 Firebase 動態下載模型,還需要添加 Firebase ML 依賴項:

    dependencies {
      // ...
      // Object detection with an Auto ML model deployed to Firebase
      implementation platform('com.google.firebase:firebase-bom:26.1.1')
      implementation 'com.google.firebase:firebase-ml-model-interpreter'
    
      implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
    }
    

1.加載模型

配置本地模型源

要將模型與您的應用程序捆綁在一起:

  1. 從您從 Google Cloud Console 下載的 zip 存檔中提取模型。
  2. 將您的模型包含在您的應用程序包中:
    1. 如果您的項目中沒有 assets 文件夾,請右鍵單擊app/文件夾,然後單擊New > Folder > Assets Folder創建一個。
    2. 將帶有嵌入元數據的tflite模型文件複製到 assets 文件夾。
  3. 將以下內容添加到應用的build.gradle文件中,以確保 Gradle 在構建應用時不會壓縮模型文件:

    android {
        // ...
        aaptOptions {
            noCompress "tflite"
        }
    }
    

    模型文件將包含在應用程序包中並作為原始資產提供。

配置 Firebase 託管的模型源

要使用遠程託管模型,請創建一個RemoteModel對象,指定您在發布模型時為其分配的名稱:

爪哇

// Specify the name you assigned when you deployed the model.
FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

科特林

// Specify the name you assigned when you deployed the model.
val remoteModel =
    FirebaseCustomRemoteModel.Builder("your_model_name").build()

然後,啟動模型下載任務,指定您希望允許下載的條件。如果模型不在設備上,或者有更新版本的模型可用,任務將從 Firebase 異步下載模型:

爪哇

DownloadConditions downloadConditions = new DownloadConditions.Builder()
        .requireWifi()
        .build();
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(@NonNull Task<Void> task) {
                // Success.
            }
        });

科特林

val downloadConditions = DownloadConditions.Builder()
    .requireWifi()
    .build()
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
    .addOnSuccessListener {
        // Success.
    }

許多應用程序在其初始化代碼中啟動下載任務,但您可以在需要使用模型之前的任何時候這樣做。

從您的模型創建對象檢測器

配置模型源後,從其中之一創建ObjectDetector對象。

如果您只有一個本地捆綁的模型,只需從您的模型文件創建一個對象檢測器並配置您想要的置信度分數閾值(請參閱評估您的模型):

爪哇

// Initialization
ObjectDetectorOptions options = ObjectDetectorOptions.builder()
    .setScoreThreshold(0)  // Evaluate your model in the Google Cloud Console
                           // to determine an appropriate value.
    .build();
ObjectDetector objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options);

科特林

// Initialization
val options = ObjectDetectorOptions.builder()
    .setScoreThreshold(0)  // Evaluate your model in the Google Cloud Console
                           // to determine an appropriate value.
    .build()
val objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options)

如果您有遠程託管模型,則必須在運行之前檢查它是否已下載。您可以使用模型管理器的isModelDownloaded()方法檢查模型下載任務的狀態。

儘管您只需要在運行對象檢測器之前確認這一點,但如果您同時擁有遠程託管模型和本地捆綁模型,則在實例化對象檢測器時執行此檢查可能是有意義的:從遠程創建對象檢測器如果已下載模型,則從本地模型下載。

爪哇

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
            }
        });

科特林

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener { success ->

        }

如果您只有遠程託管的模型,則應禁用與模型相關的功能(例如,灰顯或隱藏部分 UI),直到您確認模型已下載。您可以通過將偵聽器附加到模型管理器的download()方法來實現。

一旦您知道您的模型已下載,請從模型文件創建對象檢測器:

爪哇

FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    ObjectDetectorOptions options = ObjectDetectorOptions.builder()
                            .setScoreThreshold(0)
                            .build();
                    objectDetector = ObjectDetector.createFromFileAndOptions(
                            getApplicationContext(), modelFile.getPath(), options);
                }
            }
        });

科特林

FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnSuccessListener { modelFile ->
            val options = ObjectDetectorOptions.builder()
                    .setScoreThreshold(0f)
                    .build()
            objectDetector = ObjectDetector.createFromFileAndOptions(
                    applicationContext, modelFile.path, options)
        }

2.準備輸入圖像

然後,對於您要標記的每個圖像,從您的圖像創建一個TensorImage對象。您可以使用fromBitmap方法從Bitmap創建TensorImage對象:

爪哇

TensorImage image = TensorImage.fromBitmap(bitmap);

科特林

val image = TensorImage.fromBitmap(bitmap)

如果您的圖像數據不在Bitmap中,您可以加載像素數組,如TensorFlow Lite 文檔中所示。

3. 運行物體檢測器

要檢測圖像中的對象,請將TensorImage對像傳遞給ObjectDetectordetect()方法。

爪哇

List<Detection> results = objectDetector.detect(image);

科特林

val results = objectDetector.detect(image)

4. 獲取有關標記對象的信息

如果對象檢測操作成功,則返回Detection對象列表。每個Detection對象代表在圖像中檢測到的東西。您可以獲得每個對象的邊界框及其標籤。

例如:

爪哇

for (Detection result : results) {
    RectF bounds = result.getBoundingBox();
    List<Category> labels = result.getCategories();
}

科特林

for (result in results) {
    val bounds = result.getBoundingBox()
    val labels = result.getCategories()
}

提高實時性能的技巧

如果您想在實時應用程序中標記圖像,請遵循以下指南以獲得最佳幀率:
  • 限制對圖像標註器的調用。如果在圖像標註器運行時有新的視頻幀可用,則丟棄該幀。有關示例,請參閱快速入門示例應用程序中的VisionProcessorBase類。
  • 如果您使用圖像標註器的輸出在輸入圖像上疊加圖形,首先獲取結果,然後在一個步驟中渲染圖像並疊加。通過這樣做,您只為每個輸入幀渲染到顯示表面一次。有關示例,請參閱快速入門示例應用程序中的CameraSourcePreviewGraphicOverlay類。
  • 如果您使用 Camera2 API,請以ImageFormat.YUV_420_888格式捕獲圖像。

    如果您使用較舊的 Camera API,請以ImageFormat.NV21格式捕獲圖像。