在 Android 上使用自訂 TensorFlow Lite 模型

如果您的應用程式使用自訂 TensorFlow Lite 模型,您可以使用 Firebase ML 部署模型。透過 Firebase 部署模型,您可以縮減應用程式的初始下載大小,並更新應用程式的 ML 模型,而無需發布新版應用程式。此外,您還可以使用 Remote ConfigA/B Testing,為不同的使用者群組動態提供不同的模型。

TensorFlow Lite 模型

TensorFlow Lite 模型是經過最佳化,可在行動裝置上執行的機器學習模型。如要取得 TensorFlow Lite 模型,請按照下列步驟操作:

事前準備

  1. 如果您尚未將 Firebase 新增至 Android 專案,請新增 Firebase
  2. 模組 (應用程式層級) Gradle 檔案 (通常為 <project>/<app-module>/build.gradle.kts<project>/<app-module>/build.gradle) 中,加入 Android 專用的 Firebase ML 模型下載器程式庫依附元件。建議您使用 Firebase Android BoM 來控制程式庫版本。

    此外,您還需要在設定 Firebase ML 模型下載工具時,將 TensorFlow Lite SDK 新增至應用程式。

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:33.7.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    只要使用 Firebase Android BoM,應用程式就會一律使用相容的 Firebase Android 程式庫版本。

    (替代做法)  使用 BoM 新增 Firebase 程式庫依附元件

    如果您選擇不使用 Firebase BoM,則必須在依附元件行中指定每個 Firebase 程式庫版本。

    請注意,如果您在應用程式中使用多個 Firebase 程式庫,強烈建議您使用 BoM 來管理程式庫版本,確保所有版本皆相容。

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:25.0.1")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    想尋找 Kotlin 專屬的程式庫模組嗎?2023 年 10 月 (Firebase BoM 32.5.0)起,Kotlin 和 Java 開發人員都可以依附主要程式庫模組 (詳情請參閱這項計畫的常見問題)。
  3. 在應用程式的資訊清單中,宣告需要 INTERNET 權限:
    <uses-permission android:name="android.permission.INTERNET" />

1. 部署模型

您可以使用 Firebase 主控台或 Firebase Admin Python 和 Node.js SDK 部署自訂 TensorFlow 模型。請參閱「部署及管理自訂模型」。

將自訂模型新增至 Firebase 專案後,您就可以使用指定的名稱在應用程式中參照模型。您隨時可以呼叫 getModel() (請參閱下文),部署新的 TensorFlow Lite 模型,並將新模型下載到使用者的裝置上。

2. 將模型下載至裝置,並初始化 TensorFlow Lite 解譯器

如要在應用程式中使用 TensorFlow Lite 模型,請先使用 Firebase ML SDK 將最新版本的模型下載到裝置。接著,請使用模型將 TensorFlow Lite 解譯器例項化。

如要開始下載模型,請呼叫模型下載器的 getModel() 方法,指定您在上傳模型時指派的名稱、是否要一律下載最新模型,以及要允許下載的條件。

您可以選擇下列三種下載行為:

下載類型 說明
LOCAL_MODEL 從裝置取得本機模型。如果沒有可用的本機模型,這項操作會像 LATEST_MODEL 一樣運作。如果您不想檢查模型更新,請使用這個下載類型。舉例來說,您使用遠端設定擷取模型名稱,且一律以新名稱上傳模型 (建議做法)。
LOCAL_MODEL_UPDATE_IN_BACKGROUND 從裝置取得本機模型,並開始在背景更新模型。如果沒有可用的本機模型,則會像 LATEST_MODEL 一樣運作。
LATEST_MODEL 取得最新模型。如果本機模型是最新版本,則會傳回本機模型。否則,請下載最新模型。這個行為會阻斷,直到最新版本下載完成為止 (不建議使用)。只有在您明確需要最新版本時,才使用這項行為。

請在確認模型已下載完成前,停用模型相關功能,例如將 UI 部分設為灰色或隱藏。

Kotlin

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

許多應用程式會在初始化程式碼中啟動下載工作,但您可以在需要使用模型之前的任何時間點啟動下載工作。

3. 對輸入資料執行推論

取得模型的輸入和輸出形狀

TensorFlow Lite 模型解譯器會將一或多個多維陣列做為輸入,並產生一或多個多維陣列做為輸出。這些陣列包含 byteintlongfloat 值。您必須先瞭解模型使用的陣列數量和維度 (「形狀」),才能將資料傳遞至模型或使用模型的結果。

如果您自行建構模型,或是已記錄模型的輸入和輸出格式,可能就已經有這項資訊。如果您不知道模型輸入和輸出的形狀和資料類型,可以使用 TensorFlow Lite 解譯器檢查模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

輸出內容範例:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

執行解譯器

確定模型輸入和輸出格式後,請取得輸入資料,並對資料執行必要的轉換,以便取得模型適當形狀的輸入資料。

舉例來說,如果您有圖片分類模式,且輸入形狀為 [1 224 224 3] 浮點值,您可以從 Bitmap 物件產生輸入 ByteBuffer,如以下範例所示:

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

接著,請配置足以容納模型輸出的 ByteBuffer,並將輸入緩衝區和輸出緩衝區傳遞至 TensorFlow Lite 解譯器的 run() 方法。例如,如果輸出形狀為 [1 1000] 浮點值:

Kotlin

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

輸出內容的使用方式取決於您使用的模型。

舉例來說,如果您要執行分類作業,下一步可以將結果的索引對應至所代表的標籤:

Kotlin

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

附錄:模型安全性

無論您如何將 TensorFlow Lite 模型提供給 Firebase MLFirebase ML 都會以標準序列化 protobuf 格式儲存在本機儲存空間中。

理論上,這表示任何人都可以複製您的模型。不過,實際上,大多數模型都是應用程式專屬,且經過最佳化處理而變得難以解讀,因此風險與競爭對手將您的程式碼反組合及重複使用相似。不過,在應用程式中使用自訂模型前,請先瞭解這項風險。

在 Android API 級別 21 (Lollipop) 以上版本中,系統會將模型下載至 從自動備份中排除的目錄。

在 Android API 級別 20 以下版本中,模型會下載至應用程式私人內部儲存空間中的 com.google.firebase.ml.custom.models 目錄。如果您使用 BackupAgent 啟用檔案備份功能,可以選擇排除這個目錄。