Join us for Firebase Summit on November 10, 2021. Tune in to learn how Firebase can help you accelerate app development, release with confidence, and scale with ease. Register

Utiliser un modèle TensorFlow Lite personnalisé sur Android

Si votre application utilise sur mesure tensorflow Lite modèles, vous pouvez utiliser Firebase ML pour déployer vos modèles. En déployant des modèles avec Firebase, vous pouvez réduire la taille de téléchargement initiale de votre application et mettre à jour les modèles de ML de votre application sans publier une nouvelle version de votre application. Et, avec la configuration à distance et les tests A/B, vous pouvez proposer dynamiquement différents modèles à différents groupes d'utilisateurs.

Modèles TensorFlow Lite

Les modèles TensorFlow Lite sont des modèles ML optimisés pour s'exécuter sur des appareils mobiles. Pour obtenir un modèle TensorFlow Lite :

Avant que tu commences

  1. Si vous avez pas déjà, ajoutez Firebase à votre projet Android .
  2. Utilisation de la Firebase Android BoM , déclarer la dépendance du modèle ML Firebase téléchargeur bibliothèque Android dans votre module (app-niveau) de fichier Gradle (généralement app/build.gradle ).

    De plus, dans le cadre de la configuration du téléchargeur de modèles Firebase ML, vous devez ajouter le SDK TensorFlow Lite à votre application.

    Java

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:28.4.2')
    
        // Declare the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    En utilisant le Firebase Android BoM , votre application utilise toujours des versions compatibles des bibliothèques Firebase Android.

    (Alternative) déclarer des dépendances de bibliothèque firebase sans utiliser la nomenclature

    Si vous choisissez de ne pas utiliser la nomenclature de Firebase, vous devez spécifier chaque version de la bibliothèque Firebase dans sa ligne de dépendance.

    Notez que si vous utilisez plusieurs bibliothèques Firebase dans votre application, nous vous recommandons fortement d' utiliser la BoM pour gérer les versions bibliothèque, ce qui garantit que toutes les versions sont compatibles.

    dependencies {
        // Declare the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader:24.0.0'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Kotlin+KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:28.4.2')
    
        // Declare the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    En utilisant le Firebase Android BoM , votre application utilise toujours des versions compatibles des bibliothèques Firebase Android.

    (Alternative) déclarer des dépendances de bibliothèque firebase sans utiliser la nomenclature

    Si vous choisissez de ne pas utiliser la nomenclature de Firebase, vous devez spécifier chaque version de la bibliothèque Firebase dans sa ligne de dépendance.

    Notez que si vous utilisez plusieurs bibliothèques Firebase dans votre application, nous vous recommandons fortement d' utiliser la BoM pour gérer les versions bibliothèque, ce qui garantit que toutes les versions sont compatibles.

    dependencies {
        // Declare the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx:24.0.0'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }
  3. En manifeste, DÉCLARE de votre application que l' autorisation Internet est nécessaire:
    <uses-permission android:name="android.permission.INTERNET" />

1. Déployez votre modèle

Déployez vos modèles TensorFlow personnalisés à l'aide de la console Firebase ou des SDK Firebase Admin Python et Node.js. Voir déployer et gérer des modèles personnalisés .

Après avoir ajouté un modèle personnalisé à votre projet Firebase, vous pouvez référencer le modèle dans vos applications en utilisant le nom que vous avez spécifié. A tout moment, vous pouvez déployer un nouveau modèle tensorflow Lite et télécharger le nouveau modèle sur les appareils des utilisateurs en appelant getModel() (voir ci - dessous).

2. Téléchargez le modèle sur l'appareil et initialisez un interpréteur TensorFlow Lite

Pour utiliser votre modèle TensorFlow Lite dans votre application, utilisez d'abord le SDK Firebase ML pour télécharger la dernière version du modèle sur l'appareil. Ensuite, instanciez un interpréteur TensorFlow Lite avec le modèle.

Pour commencer le téléchargement du modèle, appelez le de downloader modèle getModel() méthode, spécifiant le nom que vous avez attribué le modèle lorsque vous l' avez téléchargé, si vous voulez toujours télécharger le dernier modèle, et les conditions dans lesquelles vous souhaitez autoriser le téléchargement.

Vous pouvez choisir parmi trois comportements de téléchargement :

Type de téléchargement La description
MODÈLE_LOCAL Obtenez le modèle local de l'appareil. S'il n'y a pas de modèle locale disponible, ce se comporte comme LATEST_MODEL . Utilisez ce type de téléchargement si vous n'êtes pas intéressé par la recherche de mises à jour de modèle. Par exemple, vous utilisez Remote Config pour récupérer les noms de modèles et vous téléchargez toujours des modèles sous de nouveaux noms (recommandé).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Obtenez le modèle local de l'appareil et commencez à mettre à jour le modèle en arrière-plan. S'il n'y a pas de modèle locale disponible, ce se comporte comme LATEST_MODEL .
DERNIER MODÈLE Obtenez le dernier modèle. Si le modèle local est la dernière version, renvoie le modèle local. Sinon, téléchargez le dernier modèle. Ce comportement se bloquera jusqu'à ce que la dernière version soit téléchargée (non recommandé). Utilisez ce comportement uniquement dans les cas où vous avez explicitement besoin de la dernière version.

Vous devez désactiver les fonctionnalités liées au modèle, par exemple griser ou masquer une partie de votre interface utilisateur, jusqu'à ce que vous confirmiez que le modèle a été téléchargé.

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

De nombreuses applications démarrent la tâche de téléchargement dans leur code d'initialisation, mais vous pouvez le faire à tout moment avant de devoir utiliser le modèle.

3. Effectuer une inférence sur les données d'entrée

Obtenez les formes d'entrée et de sortie de votre modèle

L'interpréteur de modèle TensorFlow Lite prend en entrée et produit en sortie un ou plusieurs tableaux multidimensionnels. Ces tableaux contiennent soit des byte , int , long ou float valeurs. Avant de pouvoir transmettre des données à un modèle ou utiliser son résultat, vous devez connaître le nombre et les dimensions ("forme") des tableaux utilisés par votre modèle.

Si vous avez créé le modèle vous-même, ou si le format d'entrée et de sortie du modèle est documenté, vous disposez peut-être déjà de ces informations. Si vous ne connaissez pas la forme et le type de données des entrées et sorties de votre modèle, vous pouvez utiliser l'interpréteur TensorFlow Lite pour inspecter votre modèle. Par exemple:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Exemple de sortie :

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Exécuter l'interpréteur

Après avoir déterminé le format de l'entrée et de la sortie de votre modèle, récupérez vos données d'entrée et effectuez toutes les transformations sur les données qui sont nécessaires pour obtenir une entrée de la bonne forme pour votre modèle.

Par exemple, si vous avez un modèle de classification de l' image avec une forme d'entrée [1 224 224 3] valeurs à virgule flottante, on peut générer une entrée ByteBuffer à partir d' une Bitmap objet comme indiqué dans l'exemple suivant:

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Ensuite, allouer un ByteBuffer assez grande pour contenir la sortie du modèle et passer le tampon d'entrée et un tampon de sortie de l'interpréteur tensorflow Lite run() méthode. Par exemple, pour une forme de sortie [1 1000] valeurs à virgule flottante:

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

La façon dont vous utilisez la sortie dépend du modèle que vous utilisez.

Par exemple, si vous effectuez une classification, à l'étape suivante, vous pouvez mapper les index du résultat aux étiquettes qu'ils représentent :

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Annexe : Modèle de sécurité

Quelle que soit la manière dont vous mettez vos modèles TensorFlow Lite à la disposition de Firebase ML, Firebase ML les stocke au format protobuf sérialisé standard dans le stockage local.

En théorie, cela signifie que n'importe qui peut copier votre modèle. Cependant, dans la pratique, la plupart des modèles sont tellement spécifiques à l'application et obscurcis par les optimisations que le risque est similaire à celui des concurrents qui désassemblent et réutilisent votre code. Néanmoins, vous devez être conscient de ce risque avant d'utiliser un modèle personnalisé dans votre application.

Sur Android API niveau 21 (Lollipop) et plus récent, le modèle est téléchargé dans un répertoire qui est exclu de la sauvegarde automatique .

Le niveau de l' API Android 20 ans et plus, le modèle est téléchargé dans un répertoire nommé com.google.firebase.ml.custom.models dans le stockage interne app-privé. Si vous avez activé la sauvegarde de fichiers en utilisant BackupAgent , vous pouvez choisir d'exclure ce répertoire.