在 Android 上使用自訂 TensorFlow Lite 模型

如果應用程式使用自訂 TensorFlow Lite 模型,可以透過 Firebase ML 部署模型。使用 Firebase 部署模型,可減少應用程式的初始下載大小,並更新應用程式的 ML 模型,不必發布新版應用程式。此外,您也可以使用 Remote ConfigA/B Testing,動態為不同使用者群組提供不同模型。

TensorFlow Lite 模型

TensorFlow Lite 模型是經過最佳化的機器學習模型,可在行動裝置上執行。如要取得 TensorFlow Lite 模型,請按照下列步驟操作:

事前準備

  1. 如果您尚未將 Firebase 新增至 Android 專案,請先新增。
  2. 模組 (應用程式層級) Gradle 檔案 (通常為 <project>/<app-module>/build.gradle.kts<project>/<app-module>/build.gradle) 中,新增 Android 適用的 Firebase ML 模型下載器程式庫依附元件。建議使用 Firebase Android BoM 控制程式庫版本。

    此外,設定 Firebase ML 模型下載工具時,您也需要在應用程式中新增 TensorFlow Lite SDK。

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:34.0.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    只要使用 Firebase Android BoM,應用程式就會一律使用相容的 Firebase Android 程式庫版本。

    如果選擇不使用 Firebase BoM,則必須在依附元件行中指定每個 Firebase 程式庫版本。

    請注意,如果應用程式使用多個 Firebase 程式庫,強烈建議使用 BoM 管理程式庫版本,確保所有版本都相容。

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:26.0.0")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
  3. 在應用程式的資訊清單中,宣告需要 INTERNET 權限:
    <uses-permission android:name="android.permission.INTERNET" />

1. 部署模型

使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK,部署自訂 TensorFlow 模型。請參閱「部署及管理自訂模型」。

將自訂模型新增至 Firebase 專案後,您可以使用指定的名稱在應用程式中參照模型。您隨時可以部署新的 TensorFlow Lite 模型,並呼叫 getModel() (如下所示),將新模型下載到使用者裝置上。

2. 將模型下載至裝置,並初始化 TensorFlow Lite 解譯器

如要在應用程式中使用 TensorFlow Lite 模型,請先使用 Firebase ML SDK 將最新版模型下載至裝置。接著,使用模型例項化 TensorFlow Lite 解譯器。

如要開始下載模型,請呼叫模型下載器的 getModel() 方法,並指定您上傳模型時指派的名稱、是否要一律下載最新模型,以及允許下載的條件。

您可以選擇三種下載行為:

下載類型 說明
LOCAL_MODEL 從裝置取得本機模型。 如果沒有可用的本機模型,這項功能的行為與 LATEST_MODEL 類似。如果您不想檢查模型更新,請使用這個下載類型。舉例來說,您使用遠端設定擷取模型名稱,並一律以新名稱上傳模型 (建議做法)。
LOCAL_MODEL_UPDATE_IN_BACKGROUND 從裝置取得本機模型,並在背景開始更新模型。如果沒有可用的本機模型,這項功能的行為與 LATEST_MODEL 相同。
LATEST_MODEL 取得最新型號。如果本機模型是最新版本,則會傳回本機模型。否則,請下載最新模型。這個行為會封鎖,直到下載最新版本為止 (不建議)。只有在明確需要最新版本時,才使用這項行為。

在確認模型已下載完畢前,您應停用模型相關功能,例如將部分 UI 設為灰色或隱藏。

KotlinJava
val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }
CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

許多應用程式會在初始化程式碼中啟動下載工作,但您可以在需要使用模型前的任何時間點執行這項操作。

3. 對輸入資料執行推論

取得模型的輸入和輸出形狀

TensorFlow Lite 模型解譯器會將一或多個多維度陣列做為輸入,並產生一或多個多維度陣列做為輸出。這些陣列包含 byteintlongfloat 值。如要將資料傳遞至模型或使用其結果,您必須瞭解模型使用的陣列數量和維度 (「形狀」)。

如果您自行建構模型,或模型輸入和輸出格式已記錄在文件中,您可能已經有這項資訊。如果您不知道模型輸入和輸出的形狀和資料類型,可以使用 TensorFlow Lite 解譯器檢查模型。例如:

Python
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

輸出內容範例:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

執行解譯器

決定模型輸入和輸出的格式後,請取得輸入資料,並對資料執行任何必要的轉換,以取得適合模型的輸入內容。

舉例來說,如果您有輸入形狀為 [1 224 224 3] 浮點值的圖片分類模型,可以從 Bitmap 物件產生輸入 ByteBuffer,如下列範例所示:

KotlinJava
val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}
Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

接著,配置足夠大的 ByteBuffer 來包含模型的輸出內容,並將輸入緩衝區和輸出緩衝區傳遞至 TensorFlow Lite 解譯器的 run() 方法。舉例來說,如果輸出形狀為 [1 1000] 浮點值:

KotlinJava
val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)
int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

輸出內容的使用方式取決於使用的模型。

舉例來說,如果您要執行分類作業,下一步可能就是將結果的索引對應至代表的標籤:

KotlinJava
modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}
modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

附錄:模型安全性

無論您如何提供 TensorFlow Lite 模型給 Firebase MLFirebase ML 都會以標準序列化 protobuf 格式將模型儲存在本機儲存空間。

從理論上來說,這表示任何人都能複製你的模型。不過,實際上大多數模型都經過最佳化,因此會針對特定應用程式進行混淆處理,風險與競爭對手拆解及重複使用您程式碼的風險類似。不過,在應用程式中使用自訂模型前,請務必瞭解這項風險。

在 Android API 級別 21 (Lollipop) 以上版本中,模型會下載至 自動備份排除的目錄

在 Android API 級別 20 以下版本中,模型會下載至應用程式私有內部儲存空間中名為 com.google.firebase.ml.custom.models 的目錄。如果您使用 BackupAgent 啟用檔案備份功能,可以選擇排除這個目錄。