在 Android 上使用自訂 TensorFlow Lite 模型

如果您的應用程式使用自訂 TensorFlow Lite 模型,則可以使用 Firebase ML 來部署模型。變更者: 可以縮減初始下載大小 應用程式及更新應用程式的機器學習模型,不必發布新版本 透過 Remote ConfigA/B Testing,您可以透過 向不同的使用者群組提供不同的模型。

TensorFlow Lite 模型

TensorFlow Lite 模型是最佳化的機器學習模型,適合在行動裝置上執行 裝置。如要取得 TensorFlow Lite 模型,請按照下列步驟操作:

事前準備

  1. 如果還沒試過 將 Firebase 新增至您的 Android 專案
  2. 模組 (應用程式層級) Gradle 檔案中 (通常為 <project>/<app-module>/build.gradle.kts<project>/<app-module>/build.gradle)、 新增 Android 適用的 Firebase ML 模型下載工具程式庫的依附元件。建議您使用 Firebase Android BoM敬上 管理程式庫版本管理

    此外,在設定 Firebase ML 模型下載工具時,您需要 導入 TensorFlow Lite SDK。

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:33.5.1"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    只要使用 Firebase Android BoM, 應用程式一律會使用相容的 Firebase Android 程式庫版本。

    (替代做法) 新增 Firebase 程式庫依附元件,「不使用」 BoM

    如果選擇不使用 Firebase BoM,請指定各個 Firebase 程式庫版本 都屬於依附元件行

    請注意,如果您在應用程式中使用多個 Firebase 程式庫,強烈建議您 建議使用 BoM 管理程式庫版本,確保所有版本 相容。

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:25.0.1")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    敬上
    在尋找 Kotlin 專用的程式庫模組嗎?距離開始還有 2023 年 10 月 (Firebase BoM 32.5.0),Kotlin 和 Java 開發人員皆可 依附於主要程式庫模組 (詳情請參閱 這項計畫的常見問題)。
  3. 在應用程式的資訊清單中,宣告需要 INTERNET 權限:
    <uses-permission android:name="android.permission.INTERNET" />

1. 部署模型

使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK詳情請見 部署及管理自訂模型

在 Firebase 專案中加入自訂模型後,您就能參照 在應用程式中使用您指定的名稱您隨時可以將 新的 TensorFlow Lite 模型,並將新模型下載至使用者的裝置劃分依據: 呼叫 getModel() (請見下方)。

2. 將模型下載至裝置,然後初始化 TensorFlow Lite 解譯器

如要在應用程式中使用 TensorFlow Lite 模型,請先使用 Firebase ML SDK 下載最新版本的模型到裝置。接著,將 TensorFlow Lite 解譯器搭配模型。

如要開始下載模型,請呼叫模型下載工具的 getModel() 方法。 指定您在上傳模型時為模型指派的名稱 因此建議您一律下載最新模型 需要允許下載

有三種下載行為可供選擇:

下載類型 說明
LOCAL_MODEL 取得裝置的本機模型。 如果沒有本機模型可用,這個 的運作方式與 LATEST_MODEL 類似。使用這份草稿 然後下載 檢查模型更新狀態例如: 就是使用遠端設定 而且每次都會將模型上傳到 的新名稱 (建議)。
LOCAL_MODEL_UPDATE_IN_BACKGROUND 取得裝置的本機模型,並 開始在背景更新模型 如果沒有本機模型可用,這個 的運作方式與 LATEST_MODEL 類似。
最新機型 取得最新模型。如果本機模型 最新版本,會傳回本機 模型否則,請下載最新版本 模型這項行為會封鎖,直到 就會下載最新版本 建議)。這個行為僅限用於 明確需要 版本。

您應該停用模型相關功能,例如 隱藏部分使用者介面,直到您確認下載模型為止。

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

許多應用程式會在初始化程式碼中啟動下載工作,但您可以這麼做 因此在您需要使用模型前

3. 對輸入資料執行推論

取得模型的輸入和輸出形狀

TensorFlow Lite 模型解譯器會做為輸入內容並產生輸出內容 一或多個多維陣列這些陣列包含 byteintlongfloat 輕鬆分配獎金如要將資料傳遞至模型或使用其結果,請務必瞭解 模型使用的陣列數量和維度 (「形狀」)。

如果您是自行建構模型,或模型的輸入和輸出格式 您可能已有這項資訊如果不知道 模型輸入和輸出的形狀和資料類型 TensorFlow Lite 解譯器檢查模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

輸出內容範例:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

執行解譯器

確定模型的輸入和輸出內容格式後,請取得 先取得所需資料,再對必要的資料執行任何轉換 為模型建立正確形狀的輸入內容

舉例來說,假設圖片分類模型的輸入形狀 [1 224 224 3] 浮點值,可以產生輸入 ByteBufferBitmap 物件載入資料,如以下範例所示:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

接著,分配夠大的 ByteBuffer 來容納模型輸出內容。 將輸入緩衝區和輸出緩衝區傳遞至 TensorFlow Lite 解譯器 run() 方法。例如 [1 1000] 浮點的輸出形狀 值:

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

您使用輸出內容的方式取決於您使用的模型。

舉例來說,如果您接下來要進行分類 將結果的索引對應至它們代表的標籤:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

附錄:模型安全性

無論您是以何種方式將 TensorFlow Lite 模型提供給 Firebase MLFirebase ML 會以標準的序列化 protobuf 格式儲存檔案,位置在 本機儲存空間

理論上,這代表任何人都可以複製您的模型不過 實務上,大多數模型 都特別適合應用程式,並經過模糊處理 並預測其他可能與競爭對手拆解及 重複使用程式碼然而,在使用 之前,請務必先瞭解此風險 在應用程式中加入自訂模型

在 Android API 級別 21 (Lollipop) 以上版本中,模型會下載到 目錄, 從自動備份中排除

在 Android API 級別 20 及以下級別中,模型會下載至目錄 名為「com.google.firebase.ml.custom.models」在「應用程式私人」之下 內部儲存空間如果使用 BackupAgent 啟用檔案備份功能, 您可以選擇排除這個目錄