Join us in person and online for Firebase Summit on October 18, 2022. Learn how Firebase can help you accelerate app development, release your app with confidence, and scale with ease. Register now

在 Android 上使用自定義 TensorFlow Lite 模型

透過集合功能整理內容 你可以依據偏好儲存及分類內容。

如果您的應用使用自定義TensorFlow Lite模型,您可以使用 Firebase ML 來部署您的模型。通過使用 Firebase 部署模型,您可以減少應用的初始下載大小並更新應用的 ML 模型,而無需發布應用的新版本。而且,通過遠程配置和 A/B 測試,您可以為不同的用戶組動態地提供不同的模型。

TensorFlow Lite 模型

TensorFlow Lite 模型是經過優化以在移動設備上運行的 ML 模型。要獲取 TensorFlow Lite 模型:

在你開始之前

  1. 如果您還沒有,請將 Firebase 添加到您的 Android 項目中。
  2. 在您的模塊(應用級)Gradle 文件(通常是<project>/<app-module>/build.gradle )中,添加 Firebase ML 模型下載器 Android 庫的依賴項。我們建議使用Firebase Android BoM來控制庫版本控制。

    此外,作為設置 Firebase ML 模型下載器的一部分,您需要將 TensorFlow Lite SDK 添加到您的應用中。

    Java

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:30.5.0')
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader'
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    通過使用Firebase Android BoM ,您的應用將始終使用兼容版本的 Firebase Android 庫。

    (替代)使用 BoM 的情況下添加 Firebase 庫依賴項

    如果您選擇不使用 Firebase BoM,則必須在其依賴行中指定每個 Firebase 庫版本。

    請注意,如果您在應用中使用多個Firebase 庫,我們強烈建議您使用 BoM 來管理庫版本,以確保所有版本都兼容。

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader:24.0.5'
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Kotlin+KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:30.5.0')
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx'
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    通過使用Firebase Android BoM ,您的應用將始終使用兼容版本的 Firebase Android 庫。

    (替代)使用 BoM 的情況下添加 Firebase 庫依賴項

    如果您選擇不使用 Firebase BoM,則必須在其依賴行中指定每個 Firebase 庫版本。

    請注意,如果您在應用中使用多個Firebase 庫,我們強烈建議您使用 BoM 來管理庫版本,以確保所有版本都兼容。

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx:24.0.5'
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }
  3. 在您應用的清單中,聲明需要 INTERNET 權限:
    <uses-permission android:name="android.permission.INTERNET" />

1. 部署你的模型

使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK 部署您的自定義 TensorFlow 模型。請參閱部署和管理自定義模型

將自定義模型添加到 Firebase 項目後,您可以使用您指定的名稱在應用中引用該模型。您可以隨時部署新的 TensorFlow Lite 模型,並通過調用getModel()將新模型下載到用戶的設備上(見下文)。

2. 將模型下載到設備並初始化 TensorFlow Lite 解釋器

要在您的應用中使用您的 TensorFlow Lite 模型,請首先使用 Firebase ML SDK 將最新版本的模型下載到設備中。然後,使用模型實例化一個 TensorFlow Lite 解釋器。

要開始模型下載,請調用模型下載器的getModel()方法,指定上傳模型時指定的名稱、是否要始終下載最新模型以及允許下載的條件。

您可以從三種下載行為中進行選擇:

下載類型描述
本地模型從設備獲取本地模型。如果沒有可用的本地模型,則其行為類似於LATEST_MODEL 。如果您對檢查模型更新不感興趣,請使用此下載類型。例如,您正在使用遠程配置來檢索模型名稱,並且您總是以新名稱上傳模型(推薦)。
LOCAL_MODEL_UPDATE_IN_BACKGROUND從設備獲取本地模型並開始在後台更新模型。如果沒有可用的本地模型,則其行為類似於LATEST_MODEL
最新款獲取最新型號。如果本地模型是最新版本,則返回本地模型。否則,請下載最新型號。在下載最新版本之前,此行為將被阻止(不推薦)。僅在您明確需要最新版本的情況下使用此行為。

您應該禁用與模型相關的功能(例如,灰顯或隱藏部分 UI),直到您確認模型已下載。

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

許多應用程序在其初始化代碼中啟動下載任務,但您可以在需要使用模型之前的任何時候這樣做。

3. 對輸入數據進行推理

獲取模型的輸入和輸出形狀

TensorFlow Lite 模型解釋器將一個或多個多維數組作為輸入並生成輸出。這些數組包含byteintlongfloat值。在將數據傳遞給模型或使用其結果之前,您必須知道模型使用的數組的數量和維度(“形狀”)。

如果您自己構建了模型,或者模型的輸入和輸出格式已記錄在案,您可能已經擁有此信息。如果您不知道模型輸入和輸出的形狀和數據類型,您可以使用 TensorFlow Lite 解釋器來檢查您的模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

示例輸出:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

運行解釋器

確定模型輸入和輸出的格式後,獲取輸入數據並對數據執行任何必要的轉換,以獲得模型正確形狀的輸入。

例如,如果您有一個輸入形狀為[1 224 224 3]浮點值的圖像分類模型,您可以從Bitmap對像生成輸入ByteBuffer ,如下例所示:

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

然後,分配一個足夠大的ByteBuffer來包含模型的輸出,並將輸入緩衝區和輸出緩衝區傳遞給 TensorFlow Lite 解釋器的run()方法。例如,對於[1 1000]浮點值的輸出形狀:

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

您如何使用輸出取決於您使用的模型。

例如,如果您正在執行分類,作為下一步,您可能會將結果的索引映射到它們所代表的標籤:

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

附錄:模型安全

無論您如何使 TensorFlow Lite 模型可用於 Firebase ML,Firebase ML 都會以標準序列化 protobuf 格式將它們存儲在本地存儲中。

理論上,這意味著任何人都可以復制您的模型。然而,在實踐中,大多數模型都是特定於應用程序的,並且被優化混淆了,其風險類似於競爭對手反彙編和重用代碼的風險。不過,在您的應用程序中使用自定義模型之前,您應該意識到這種風險。

在 Android API 級別 21 (Lollipop) 和更高版本上,模型會下載到自動備份中排除的目錄。

在 Android API 級別 20 和更早版本上,模型會下載到應用專用內部存儲中名為com.google.firebase.ml.custom.models的目錄中。如果您使用BackupAgent啟用了文件備份,則可以選擇排除此目錄。