Verwenden Sie ein benutzerdefiniertes TensorFlow Lite-Modell auf Android

Wenn Ihre App benutzerdefinierte TensorFlow Lite- Modelle verwendet, können Sie Firebase ML zum Bereitstellen Ihrer Modelle verwenden. Durch die Bereitstellung von Modellen mit Firebase können Sie die anfängliche Downloadgröße Ihrer App reduzieren und die ML-Modelle Ihrer App aktualisieren, ohne eine neue Version Ihrer App zu veröffentlichen. Und mit Remote Config und A/B-Testing können Sie unterschiedliche Modelle dynamisch für verschiedene Benutzergruppen bereitstellen.

TensorFlow Lite-Modelle

TensorFlow Lite-Modelle sind ML-Modelle, die für die Ausführung auf Mobilgeräten optimiert sind. So erhalten Sie ein TensorFlow Lite-Modell:

Bevor Sie beginnen

  1. Falls noch nicht geschehen, fügen Sie Firebase zu Ihrem Android-Projekt hinzu .
  2. Fügen Sie in Ihrer Modul-Gradle-Datei (auf App-Ebene) (normalerweise <project>/<app-module>/build.gradle.kts oder <project>/<app-module>/build.gradle ) die Abhängigkeit für den Firebase ML hinzu Modell-Downloader-Bibliothek für Android. Wir empfehlen die Verwendung der Firebase Android BoM zur Steuerung der Bibliotheksversionierung.

    Außerdem müssen Sie im Rahmen der Einrichtung des Firebase ML-Modell-Downloaders das TensorFlow Lite SDK zu Ihrer App hinzufügen.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:32.8.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Durch die Verwendung der Firebase Android BoM verwendet Ihre App immer kompatible Versionen der Firebase Android-Bibliotheken.

    (Alternative) Fügen Sie Firebase-Bibliotheksabhängigkeiten hinzu , ohne die Stückliste zu verwenden

    Wenn Sie die Firebase-Stückliste nicht verwenden möchten, müssen Sie jede Firebase-Bibliotheksversion in ihrer Abhängigkeitszeile angeben.

    Beachten Sie: Wenn Sie mehrere Firebase-Bibliotheken in Ihrer App verwenden, empfehlen wir dringend, die BoM zum Verwalten der Bibliotheksversionen zu verwenden, um sicherzustellen, dass alle Versionen kompatibel sind.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:24.2.3")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    Suchen Sie nach einem Kotlin-spezifischen Bibliotheksmodul? Ab Oktober 2023 (Firebase BoM 32.5.0) können sich sowohl Kotlin- als auch Java-Entwickler auf das Hauptbibliotheksmodul verlassen (Einzelheiten finden Sie in den FAQ zu dieser Initiative ).
  3. Erklären Sie im Manifest Ihrer App, dass eine INTERNET-Berechtigung erforderlich ist:
    <uses-permission android:name="android.permission.INTERNET" />

1. Stellen Sie Ihr Modell bereit

Stellen Sie Ihre benutzerdefinierten TensorFlow-Modelle entweder mit der Firebase-Konsole oder den Firebase Admin Python- und Node.js-SDKs bereit. Siehe Benutzerdefinierte Modelle bereitstellen und verwalten .

Nachdem Sie Ihrem Firebase-Projekt ein benutzerdefiniertes Modell hinzugefügt haben, können Sie in Ihren Apps unter dem von Ihnen angegebenen Namen auf das Modell verweisen. Sie können jederzeit ein neues TensorFlow Lite-Modell bereitstellen und das neue Modell auf die Geräte der Benutzer herunterladen, indem Sie getModel() aufrufen (siehe unten).

2. Laden Sie das Modell auf das Gerät herunter und initialisieren Sie einen TensorFlow Lite-Interpreter

Um Ihr TensorFlow Lite-Modell in Ihrer App zu verwenden, laden Sie zunächst mit dem Firebase ML SDK die neueste Version des Modells auf das Gerät herunter. Instanziieren Sie dann einen TensorFlow Lite-Interpreter mit dem Modell.

Um den Modell-Download zu starten, rufen Sie die getModel() -Methode des Modell-Downloaders auf. Geben Sie dabei den Namen an, den Sie dem Modell beim Hochladen zugewiesen haben, ob Sie immer das neueste Modell herunterladen möchten und die Bedingungen, unter denen Sie den Download zulassen möchten.

Sie können zwischen drei Download-Verhalten wählen:

Download-Typ Beschreibung
LOCAL_MODEL Rufen Sie das lokale Modell vom Gerät ab. Wenn kein lokales Modell verfügbar ist, verhält sich dies wie LATEST_MODEL . Verwenden Sie diesen Download-Typ, wenn Sie nicht daran interessiert sind, nach Modellaktualisierungen zu suchen. Sie verwenden beispielsweise Remote Config zum Abrufen von Modellnamen und laden Modelle immer unter neuen Namen hoch (empfohlen).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Holen Sie sich das lokale Modell vom Gerät und beginnen Sie mit der Aktualisierung des Modells im Hintergrund. Wenn kein lokales Modell verfügbar ist, verhält sich dies wie LATEST_MODEL .
NEUSTE MODELL Holen Sie sich das neueste Modell. Wenn das lokale Modell die neueste Version ist, wird das lokale Modell zurückgegeben. Andernfalls laden Sie das neueste Modell herunter. Dieses Verhalten wird blockiert, bis die neueste Version heruntergeladen wird (nicht empfohlen). Verwenden Sie dieses Verhalten nur in Fällen, in denen Sie ausdrücklich die neueste Version benötigen.

Sie sollten modellbezogene Funktionen deaktivieren – zum Beispiel einen Teil Ihrer Benutzeroberfläche ausgrauen oder ausblenden –, bis Sie bestätigen, dass das Modell heruntergeladen wurde.

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Viele Apps starten die Download-Aufgabe in ihrem Initialisierungscode, Sie können dies jedoch jederzeit tun, bevor Sie das Modell verwenden müssen.

3. Führen Sie eine Inferenz auf Eingabedaten durch

Rufen Sie die Eingabe- und Ausgabeformen Ihres Modells ab

Der TensorFlow Lite-Modellinterpreter verwendet als Eingabe ein oder mehrere mehrdimensionale Arrays und erzeugt als Ausgabe. Diese Arrays enthalten entweder byte , int , long oder float -Werte. Bevor Sie Daten an ein Modell übergeben oder dessen Ergebnis verwenden können, müssen Sie die Anzahl und Abmessungen („Form“) der von Ihrem Modell verwendeten Arrays kennen.

Wenn Sie das Modell selbst erstellt haben oder das Eingabe- und Ausgabeformat des Modells dokumentiert ist, verfügen Sie möglicherweise bereits über diese Informationen. Wenn Sie die Form und den Datentyp der Eingabe und Ausgabe Ihres Modells nicht kennen, können Sie Ihr Modell mit dem TensorFlow Lite-Interpreter überprüfen. Zum Beispiel:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Beispielausgabe:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Führen Sie den Interpreter aus

Nachdem Sie das Format der Eingabe und Ausgabe Ihres Modells bestimmt haben, rufen Sie Ihre Eingabedaten ab und führen Sie alle erforderlichen Transformationen an den Daten durch, um eine Eingabe mit der richtigen Form für Ihr Modell zu erhalten.

Wenn Sie beispielsweise über ein Bildklassifizierungsmodell mit einer Eingabeform von [1 224 224 3] Gleitkommawerten verfügen, könnten Sie einen Eingabe- ByteBuffer aus einem Bitmap Objekt generieren, wie im folgenden Beispiel gezeigt:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Weisen Sie dann einen ByteBuffer zu, der groß genug ist, um die Ausgabe des Modells aufzunehmen, und übergeben Sie den Eingabepuffer und den Ausgabepuffer an die run() Methode des TensorFlow Lite-Interpreters. Beispiel für eine Ausgabeform von [1 1000] Gleitkommawerten:

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Wie Sie die Ausgabe verwenden, hängt vom verwendeten Modell ab.

Wenn Sie beispielsweise eine Klassifizierung durchführen, können Sie im nächsten Schritt die Indizes des Ergebnisses den Beschriftungen zuordnen, die sie darstellen:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Anhang: Modellsicherheit

Unabhängig davon, wie Sie Ihre TensorFlow Lite-Modelle für Firebase ML verfügbar machen, speichert Firebase ML sie im standardmäßigen serialisierten Protobuf-Format im lokalen Speicher.

Theoretisch bedeutet das, dass jeder Ihr Modell kopieren kann. In der Praxis sind die meisten Modelle jedoch so anwendungsspezifisch und durch Optimierungen verschleiert, dass das Risiko mit dem der Konkurrenz vergleichbar ist, die Ihren Code zerlegt und wiederverwendet. Dennoch sollten Sie sich dieses Risikos bewusst sein, bevor Sie ein benutzerdefiniertes Modell in Ihrer App verwenden.

Auf Android API Level 21 (Lollipop) und höher wird das Modell in ein Verzeichnis heruntergeladen, das von der automatischen Sicherung ausgeschlossen ist.

Auf Android API Level 20 und älter wird das Modell in ein Verzeichnis namens com.google.firebase.ml.custom.models im privaten App-internen Speicher heruntergeladen. Wenn Sie die Dateisicherung mit BackupAgent aktiviert haben, können Sie dieses Verzeichnis ausschließen.