Save the date - Google I/O returns May 18-20. Register to get the most out of the digital experience: Build your schedule, reserve space, participate in Q&As, earn Google Developer profile badges, and more. Register now
本頁面由 Cloud Translation API 翻譯而成。
Switch to English

在Android上使用自定義TensorFlow Lite模型

如果您的應用程序使用自定義TensorFlow Lite模型,則可以使用Firebase ML部署模型。通過使用Firebase部署模型,您可以減少應用程序的初始下載大小並更新應用程序的ML模型,而無需發布新版本的應用程序。而且,通過遠程配置和A / B測試,您可以為不同的用戶組動態提供不同的模型。

TensorFlow Lite模型

TensorFlow Lite模型是經過優化可在移動設備上運行的ML模型。要獲取TensorFlow Lite模型:

在你開始之前

  1. 如果尚未將Firebase添加到您的Android項目中
  2. 使用Firebase Android BoM ,在模塊(應用程序級)Gradle文件(通常為app/build.gradle )中聲明Firebase ML模型下載器Android庫的依賴app/build.gradle

    另外,作為設置Firebase ML模型下載器的一部分,您需要將TensorFlow Lite SDK添加到您的應用中。

    爪哇

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:27.1.0')
    
        // Declare the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    通過使用Firebase Android BoM ,您的應用將始終使用Firebase Android庫的兼容版本。

    (可選)使用BoM聲明Firebase庫依賴關係

    如果選擇不使用Firebase BoM,則必須在其依賴關係行中指定每個Firebase庫版本。

    請注意,如果您在應用中使用多個Firebase庫,我們強烈建議您使用BoM來管理庫版本,以確保所有版本都兼容。

    dependencies {
        // Declare the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader:23.0.1'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Kotlin + KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:27.1.0')
    
        // Declare the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    通過使用Firebase Android BoM ,您的應用將始終使用Firebase Android庫的兼容版本。

    (可選)使用BoM聲明Firebase庫依賴關係

    如果選擇不使用Firebase BoM,則必須在其依賴關係行中指定每個Firebase庫版本。

    請注意,如果您在應用中使用多個Firebase庫,我們強烈建議您使用BoM來管理庫版本,以確保所有版本都兼容。

    dependencies {
        // Declare the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx:23.0.1'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }
  3. 在您的應用清單中,聲明需要INTERNET權限:
    <uses-permission android:name="android.permission.INTERNET" />

1.部署模型

使用Firebase控制台或Firebase Admin Python和Node.js SDK部署自定義TensorFlow模型。請參閱部署和管理定制模型

將自定義模型添加到Firebase項目後,可以使用您指定的名稱在應用程序中引用該模型。您隨時可以部署新的TensorFlow Lite模型,並通過調用getModel()將新模型下載到用戶設備上(請參見下文)。

2.將模型下載到設備並初始化TensorFlow Lite解釋器

要在您的應用中使用TensorFlow Lite模型,請首先使用Firebase ML SDK將模型的最新版本下載到設備上。然後,使用模型實例化TensorFlow Lite解釋器。

要開始下載模型,請調用模型下載器的getModel()方法,指定在上載模型時為模型分配的名稱,是否要始終下載最新模型以及允許下載的條件。

您可以從三種下載行為中進行選擇:

下載類型描述
LOCAL_MODEL從設備獲取本地模型。如果沒有可用的本地模型,則其行為類似於LATEST_MODEL 。如果您對檢查模型更新不感興趣,請使用此下載類型。例如,您正在使用“遠程配置”來檢索模型名稱,並且始終以新名稱上載模型(推薦)。
LOCAL_MODEL_UPDATE_IN_BACKGROUND從設備獲取本地模型,然後在後台開始更新模型。如果沒有可用的本地模型,則其行為類似於LATEST_MODEL
最新款獲取最新型號。如果本地模型是最新版本,則返回本地模型。否則,請下載最新型號。在下載最新版本之前,此行為將一直阻止(不推薦)。僅在明確需要最新版本的情況下,才使用此行為。

您應禁用與模型相關的功能,例如灰色或隱藏UI的一部分,直到您確認模型已下載。

爪哇

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Kotlin + KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { customModel ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            .addOnSuccessListener { model: CustomModel? ->
                val modelFile = model?.file
                if (modelFile != null) {
                    interpreter = Interpreter(modelFile)
                }
        }

許多應用程序都使用其初始化代碼啟動下載任務,但是您可以在需要使用模型之前隨時進行下載。

3.對輸入數據進行推斷

獲取模型的輸入和輸出形狀

TensorFlow Lite模型解釋器將輸入作為輸入,並產生一個或多個多維數組作為輸出。這些數組包含byteintlongfloat值。在將數據傳遞到模型或使用其結果之前,必須知道模型使用的數組的數量和尺寸(“形狀”)。

如果您自己構建模型,或者記錄了模型的輸入和輸出格式,則可能已經有了此信息。如果您不知道模型輸入和輸出的形狀和數據類型,則可以使用TensorFlow Lite解釋器檢查模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

輸出示例:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

運行口譯員

確定模型輸入和輸出的格式後,獲取輸入數據並對數據進行任何必要的轉換,以獲取適合模型的正確形狀的輸入。

例如,如果您的圖像分類模型的輸入形狀為[1 224 224 3]浮點值,則可以從Bitmap像生成輸入ByteBuffer ,如以下示例所示:

爪哇

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin + KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

然後,分配一個足夠大的ByteBuffer來容納模型的輸出,並將輸入緩衝區和輸出緩衝區傳遞給TensorFlow Lite解釋器的run()方法。例如,對於輸出形狀為[1 1000]浮點值:

爪哇

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin + KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

您如何使用輸出取決於您使用的模型。

例如,如果要執行分類,那麼下一步,您可以將結果的索引映射到它們代表的標籤上:

爪哇

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin + KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

附錄:模型安全

無論您如何使TensorFlow Lite模型可用於Firebase ML,Firebase ML都將它們以標準的序列化protobuf格式存儲在本地存儲中。

從理論上講,這意味著任何人都可以復制您的模型。但是,實際上,大多數模型都是特定於應用程序的,並且由於優化而模糊不清,以至於風險與競爭者拆卸和重用您的代碼相似。但是,在應用程序中使用自定義模型之前,您應該意識到這種風險。

在Android API級別21(Lollipop)和更高版本上,該模型下載到自動備份排除的目錄中。

在Android API級別20和更高版本上,模型被下載到應用程序專用內部存儲中的名為com.google.firebase.ml.custom.models的目錄中。如果使用BackupAgent啟用了文件備份,則可以選擇排除此目錄。