Benutzerdefiniertes TensorFlow Lite-Modell unter Android verwenden

Wenn Ihre App benutzerdefinierte TensorFlow Lite-Modellen verwenden, können Sie sie mit Firebase ML bereitstellen. Von bei der Bereitstellung von Modellen mit Firebase die anfängliche Downloadgröße Ihre App aktualisieren und die ML-Modelle Ihrer App aktualisieren, ohne eine neue Version des für Ihre App. Mit Remote Config und A/B Testing können Sie außerdem verschiedene Modelle für verschiedene Gruppen von Nutzenden bereitzustellen.

TensorFlow Lite-Modelle

TensorFlow Lite-Modelle sind ML-Modelle, die für die Ausführung auf Mobilgeräten optimiert sind Geräte. So rufen Sie ein TensorFlow Lite-Modell ab:

Hinweis

  1. Falls noch nicht geschehen, Fügen Sie Firebase zu Ihrem Android-Projekt hinzu.
  2. In der Gradle-Datei des Moduls (auf App-Ebene) (normalerweise <project>/<app-module>/build.gradle.kts oder <project>/<app-module>/build.gradle) Fügen Sie die Abhängigkeit für die Firebase ML-Modelldownloader-Bibliothek für Android hinzu. Wir empfehlen die Verwendung des Firebase Android BoM um die Versionsverwaltung der Bibliothek zu steuern.

    Beim Einrichten des Firebase ML-Modelldownloads müssen Sie außerdem den Parameter TensorFlow Lite SDK zu Ihrer App hinzufügen.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:33.2.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Mit dem Firebase Android BoM Ihre App verwendet immer kompatible Versionen der Firebase Android Libraries.

    Alternative: Firebase-Bibliotheksabhängigkeiten ohne BoM hinzufügen

    Wenn Sie Firebase BoM nicht verwenden, müssen Sie jede Firebase-Bibliotheksversion angeben in der Abhängigkeitszeile ein.

    Wenn Sie in Ihrer App mehrere Firebase-Bibliotheken verwenden, empfehlen, Bibliotheksversionen mit der BoM zu verwalten. Dadurch wird sichergestellt, dass alle Versionen kompatibel.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:25.0.0")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    Suchen Sie nach einem Kotlin-spezifischen Bibliotheksmodul? Beginnt in Oktober 2023 (Firebase BoM 32.5.0) können sowohl Kotlin- als auch Java-Entwickler sind vom Modul der Hauptbibliothek abhängig (Details finden Sie in der FAQs zu dieser Initiative).
  3. Deklarieren Sie im Manifest Ihrer App, dass die Berechtigung INTERNET erforderlich ist:
    <uses-permission android:name="android.permission.INTERNET" />

1. Modell bereitstellen

Stellen Sie Ihre benutzerdefinierten TensorFlow-Modelle über die Firebase-Konsole oder Firebase Admin Python und Node.js SDKs. Weitere Informationen finden Sie unter Benutzerdefinierte Modelle bereitstellen und verwalten

Nachdem Sie Ihrem Firebase-Projekt ein benutzerdefiniertes Modell hinzugefügt haben, können Sie auf das Modell in Ihren Apps unter Verwendung des von Ihnen angegebenen Namens. Sie können jederzeit ein neues TensorFlow Lite-Modell bereitstellen und es auf die Geräte der Nutzer herunterladen, indem Sie getModel() aufrufen (siehe unten).

2. Modell auf das Gerät herunterladen und einen TensorFlow Lite-Interpreter initialisieren

Zur Verwendung des TensorFlow Lite-Modells in Ihrer App müssen Sie zuerst das Firebase ML SDK verwenden um die neueste Version des Modells auf das Gerät herunterzuladen. Instanziieren Sie dann eine TensorFlow Lite-Interpreter mit dem Modell.

Um den Modelldownload zu starten, rufen Sie die Methode getModel() des Modelldownloads auf. Sie geben den Namen an, den Sie dem Modell beim Hochladen zugewiesen haben, unabhängig davon, ob Sie immer das neueste Modell herunterladen möchten und unter welchen Bedingungen den Download erlauben möchten.

Sie können zwischen drei Downloadverhalten wählen:

Downloadtyp Beschreibung
LOCAL_MODEL Rufen Sie das lokale Modell vom Gerät ab. Wenn kein lokales Modell verfügbar ist, verhält sich wie LATEST_MODEL. Verwenden Downloadtyp, wenn Sie kein Interesse an Modellaktualisierungen suchen. Beispiel: rufen Sie mit Remote Config und Sie laden Modelle immer hoch, unter neuen Namen (empfohlen).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Lokales Modell vom Gerät abrufen und mit der Aktualisierung des Modells im Hintergrund beginnen. Wenn kein lokales Modell verfügbar ist, verhält sich wie LATEST_MODEL.
NEUES_MODELL Holen Sie sich das neueste Modell. Wenn das lokale Modell die neueste Version, gibt den lokalen modellieren. Laden Sie andernfalls die aktuelle modellieren. Dieses Verhalten wird blockiert, bis das die neueste Version heruntergeladen wurde (nicht empfohlen). Verwenden Sie dieses Verhalten nur in in denen Sie ausdrücklich die aktuellsten Version.

Sie sollten modellbezogene Funktionen deaktivieren, z. B. einen Teil der Benutzeroberfläche ausblenden, bis Sie bestätigen, dass das Modell heruntergeladen wurde.

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Viele Apps starten die Download-Aufgabe im Initialisierungscode, aber Sie können bevor Sie das Modell verwenden.

3. Eingabedaten Inferenz durchführen

Ein- und Ausgabeformen des Modells abrufen

Der TensorFlow Lite-Modellinterpreter nimmt als Eingabe an und erstellt als Ausgabe ein oder mehrere multidimensionale Arrays. Diese Arrays enthalten entweder byte, int, long oder float Werte. Bevor Sie Daten an ein Modell übergeben oder dessen Ergebnisse verwenden können, müssen Sie wissen, die Anzahl und Abmessungen ("Form") der Arrays, die Ihr Modell verwendet

Wenn Sie das Modell selbst erstellt haben oder das Eingabe- und Ausgabeformat des Modells wie folgt lautet: dokumentiert ist, haben Sie diese Informationen vielleicht schon. Wenn Sie nicht wissen, die Form und den Datentyp der Eingabe und Ausgabe Ihres Modells ein, können Sie die Methode TensorFlow Lite-Interpreter zur Prüfung Ihres Modells. Beispiel:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Beispielausgabe:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Dolmetscher ausführen

Nachdem Sie das Format der Ein- und Ausgabe Ihres Modells ermittelt haben, Eingabedaten und Transformationen der Daten, die für den Erhalt Eingabe der richtigen Form für Ihr Modell.

Wenn Sie z. B. ein Bildklassifizierungsmodell mit der Eingabeform [1 224 224 3] Gleitkommawerten ist, könnten Sie eine Eingabe generieren ByteBuffer aus einem Bitmap-Objekt, wie im folgenden Beispiel gezeigt:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Weisen Sie dann einen ByteBuffer zu, der groß genug ist, um die Modellausgabe und Eingabe- und Ausgabepuffer an den TensorFlow Lite-Interpreter übergeben. run()-Methode. Beispiel: Für die Ausgabeform [1 1000] als Gleitkommazahl Werte:

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Wie Sie die Ausgabe verwenden, hängt vom verwendeten Modell ab.

Bei der Klassifizierung könnten Sie als nächsten Schritt ordnen Sie die Indexe des Ergebnisses den von ihnen dargestellten Labels zu:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Anhang: Modellsicherheit

Unabhängig davon, wie Sie Ihre TensorFlow Lite-Modelle für Firebase ML, speichert Firebase ML sie im standardisierten serialisierten protobuf-Format unter lokalen Speicher.

Theoretisch bedeutet dies, dass jeder Ihr Modell kopieren kann. Sie können jedoch in der Praxis sind die meisten Modelle so anwendungsspezifisch Optimierungen vorzunehmen, bei denen das Risiko dem der Konkurrenz beim Auseinanderbauen und Ihren Code wiederverwenden. Sie sollten sich jedoch über dieses Risiko im Klaren sein, bevor Sie ein benutzerdefiniertes Modell in Ihrer App erstellen.

Unter Android API-Level 21 (Lollipop) und höher wird das Modell Verzeichnis, das ist aus der automatischen Sicherung ausgeschlossen werden.

Bei der Android API-Ebene 20 und niedriger wird das Modell in den privaten internen Speicher der App in ein Verzeichnis mit dem Namen com.google.firebase.ml.custom.models heruntergeladen. Wenn Sie die Dateisicherung mit BackupAgent aktiviert haben, können Sie dieses Verzeichnis ausschließen.