Utiliser un modèle TensorFlow Lite personnalisé sur Android

Si votre application utilise des TensorFlow Lite, vous pouvez utiliser Firebase ML pour déployer vos modèles. Par déployer des modèles avec Firebase, vous pouvez réduire la taille de téléchargement initiale votre application et mettre à jour ses modèles de ML sans publier de nouvelle version votre application. De plus, avec Remote Config et A/B Testing, vous pouvez dynamiquement peuvent servir des modèles différents pour différents ensembles d'utilisateurs.

Modèles TensorFlow Lite

Les modèles TensorFlow Lite sont des modèles de ML optimisés pour s'exécuter sur des appareils mobiles appareils. Pour obtenir un modèle TensorFlow Lite, procédez comme suit:

Avant de commencer

  1. Si ce n'est pas déjà fait, Ajoutez Firebase à votre projet Android.
  2. Dans le fichier Gradle de votre module (au niveau de l'application) (généralement <project>/<app-module>/build.gradle.kts ou <project>/<app-module>/build.gradle), Ajoutez la dépendance pour la bibliothèque de téléchargement de modèles Firebase ML pour Android. Nous vous recommandons d'utiliser Firebase Android BoM pour contrôler le contrôle des versions de la bibliothèque.

    De plus, pour configurer l'outil de téléchargement de modèles Firebase ML, vous devez ajouter SDK TensorFlow Lite à votre application.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:33.2.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    En utilisant Firebase Android BoM, votre application utilisera toujours des versions compatibles des bibliothèques Firebase Android.

    (Alternative) Ajoutez des dépendances de bibliothèque Firebase sans utiliser BoM.

    Si vous choisissez de ne pas utiliser Firebase BoM, vous devez spécifier chaque version de la bibliothèque Firebase dans sa ligne de dépendance.

    Notez que si vous utilisez plusieurs bibliothèques Firebase dans votre application, recommandent d'utiliser BoM pour gérer les versions de la bibliothèque, ce qui garantit que toutes les versions sont compatibles.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:25.0.0")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    Vous recherchez un module de bibliothèque spécifique à Kotlin ? Début dans Octobre 2023 (Firebase BoM 32.5.0), les développeurs Kotlin et Java peuvent dépendent du module de bibliothèque principal (pour en savoir plus, consultez Questions fréquentes sur cette initiative).
  3. Dans le fichier manifeste de votre application, déclarez que l'autorisation INTERNET est requise :
    <uses-permission android:name="android.permission.INTERNET" />

1. Déployer le modèle

Déployez vos modèles TensorFlow personnalisés à l'aide de la console Firebase ou les SDK Firebase Admin Python et Node.js. Voir Déployer et gérer des modèles personnalisés

Après avoir ajouté un modèle personnalisé à votre projet Firebase, vous pouvez référencer le dans vos applications en utilisant le nom spécifié. Vous pouvez déployer à tout moment un nouveau modèle TensorFlow Lite et le télécharger appareils par appelez getModel() (voir ci-dessous).

2. Télécharger le modèle sur l'appareil et initialiser un interpréteur TensorFlow Lite

Pour utiliser votre modèle TensorFlow Lite dans votre application, commencez par utiliser le SDK Firebase ML pour télécharger la dernière version du modèle sur l'appareil. Instanciez ensuite Interpréteur TensorFlow Lite avec le modèle.

Pour lancer le téléchargement du modèle, appelez la méthode getModel() de l'outil de téléchargement de modèles. en spécifiant le nom que vous avez attribué au modèle lors de son importation, télécharger systématiquement le modèle le plus récent, ainsi que les conditions dans lesquelles vous si vous souhaitez autoriser le téléchargement.

Vous avez le choix entre trois comportements de téléchargement :

Type de téléchargement Description
MODÈLE_LOCAL Récupérez le modèle local de l'appareil. Si aucun modèle local n'est disponible, le comportement est identique à celui de LATEST_MODEL. Utiliser ceci de téléchargement si vous n'êtes pas intéressé par des mises à jour du modèle. Par exemple, vous utilisez Remote Config pour récupérer les noms de modèles et vous importez toujours des modèles sous de nouveaux noms (recommandé).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Récupérez le modèle local de l'appareil et commencer à mettre à jour le modèle en arrière-plan. Si aucun modèle local n'est disponible, se comporte comme LATEST_MODEL.
LATEST_MODEL Obtenez le dernier modèle. Si le modèle local est la dernière version, renvoie la valeur du modèle de ML. Sinon, téléchargez la dernière version du modèle de ML. Ce comportement bloquera l'application jusqu'à ce que la dernière version soit téléchargée (non recommandé). Utilisez ce comportement uniquement dans dans les cas où vous avez explicitement besoin version.

Vous devez désactiver les fonctionnalités liées aux modèles, par exemple, les options grisées ou masquer une partie de l'interface utilisateur jusqu'à ce que vous confirmiez que le modèle a été téléchargé.

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

De nombreuses applications lancent la tâche de téléchargement dans leur code d'initialisation, mais vous pouvez avant de devoir utiliser le modèle.

3. Effectuer une inférence sur des données d'entrée

Obtenir les formes d'entrée et de sortie de votre modèle

L'interpréteur de modèle TensorFlow Lite prend en entrée et produit en sortie un ou plusieurs tableaux multidimensionnels. Ces tableaux contiennent des valeurs byte, int, long ou float. Avant de pouvoir transmettre des données à un modèle ou utiliser son résultat, vous devez savoir le nombre et les dimensions ("forme") des tableaux utilisés par le modèle ;

Si vous avez créé le modèle vous-même, ou si le format d'entrée et de sortie du modèle est documentées, vous disposez peut-être déjà de ces informations. Si vous ne connaissez pas le et le type de données des entrées et sorties de votre modèle, vous pouvez utiliser Interpréteur TensorFlow Lite pour inspecter votre modèle. Exemple :

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Exemple de résultat :

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Exécuter l'interpréteur

Une fois que vous avez déterminé le format des entrées et des sorties de votre modèle, obtenez les données d'entrée et effectuer toutes les transformations nécessaires sur les données une entrée dont la forme est adaptée à votre modèle.

Par exemple, si vous disposez d'un modèle de classification d'images avec une forme d'entrée de valeurs à virgule flottante [1 224 224 3], vous pouvez générer une ByteBuffer d'entrée à partir d'un objet Bitmap, comme illustré dans l'exemple suivant :

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Ensuite, allouez un ByteBuffer suffisamment grand pour contenir la sortie du modèle et transmettre les tampons d'entrée et de sortie au run(). Par exemple, pour une forme de sortie en virgule flottante [1 1000], :

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

La manière dont vous utilisez la sortie dépend du modèle que vous utilisez.

Par exemple, si vous effectuez une classification, vous pourriez mappez les index du résultat aux étiquettes qu'ils représentent:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Annexe: Sécurité du modèle

Quelle que soit la manière dont vous mettez vos modèles TensorFlow Lite à la disposition de Firebase ML, Firebase ML les stocke au format protobuf sérialisé standard dans l'espace de stockage local.

En théorie, cela signifie que n'importe qui peut copier votre modèle. Toutefois, en pratique, la plupart des modèles sont tellement spécifiques à l'application et masqués par des optimisations que le risque est semblable à celui de vos concurrents qui désassemblent et réutilisent votre code. Néanmoins, vous devez être conscient de ce risque avant d'utiliser un modèle personnalisé dans votre application.

Sur le niveau d'API Android 21 (Lollipop) ou version ultérieure, le modèle est téléchargé répertoire contenant exclus de la sauvegarde automatique.

À partir du niveau d'API Android 20, le modèle est téléchargé dans un répertoire nommée com.google.firebase.ml.custom.models dans le dossier privé de l'application mémoire de stockage interne. Si vous avez activé la sauvegarde de fichiers avec BackupAgent, vous pouvez choisir d'exclure ce répertoire.