Google is committed to advancing racial equity for Black communities. See how.
Cette page a été traduite par l'API Cloud Translation.
Switch to English

Utilisez un modèle TensorFlow Lite personnalisé sur Android

Si votre application utilise des modèles TensorFlow Lite personnalisés, vous pouvez utiliser Firebase ML pour déployer vos modèles. En déployant des modèles avec Firebase, vous pouvez réduire la taille de téléchargement initiale de votre application et mettre à jour les modèles ML de votre application sans publier une nouvelle version de votre application. Et, avec la configuration à distance et les tests A / B, vous pouvez servir dynamiquement différents modèles à différents groupes d'utilisateurs.

Modèles TensorFlow Lite

Les modèles TensorFlow Lite sont des modèles ML optimisés pour s'exécuter sur des appareils mobiles. Pour obtenir un modèle TensorFlow Lite:

Avant que tu commences

  1. Si vous ne l'avez pas déjà fait, ajoutez Firebase à votre projet Android .
  2. À l'aide de Firebase Android BoM , déclarez la dépendance de la bibliothèque Android de modèles personnalisés Firebase ML dans le fichier Gradle de votre module (au niveau de l'application) (généralement app/build.gradle ).

    De plus, dans le cadre de la configuration des modèles personnalisés Firebase ML, vous devez ajouter le SDK TensorFlow Lite à votre application.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:26.6.0')
    
        // Declare the dependency for the Firebase ML Custom Models library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-model-interpreter'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    En utilisant Firebase Android BoM , votre application utilisera toujours des versions compatibles des bibliothèques Firebase Android.

    (Alternative) Déclarez les dépendances de la bibliothèque Firebase sans utiliser le BoM

    Si vous choisissez de ne pas utiliser Firebase BoM, vous devez spécifier chaque version de la bibliothèque Firebase dans sa ligne de dépendance.

    Notez que si vous utilisez plusieurs bibliothèques Firebase dans votre application, nous vous recommandons vivement d'utiliser BoM pour gérer les versions de bibliothèque, ce qui garantit que toutes les versions sont compatibles.

    dependencies {
        // Declare the dependency for the Firebase ML Custom Models library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.4'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }
  3. Dans le manifeste de votre application, déclarez que l'autorisation INTERNET est requise:
    <uses-permission android:name="android.permission.INTERNET" />

1. Déployez votre modèle

Déployez vos modèles TensorFlow personnalisés à l'aide de la console Firebase ou des SDK Firebase Admin Python et Node.js. Voir Déployer et gérer des modèles personnalisés .

Après avoir ajouté un modèle personnalisé à votre projet Firebase, vous pouvez référencer le modèle dans vos applications en utilisant le nom que vous avez spécifié. À tout moment, vous pouvez télécharger un nouveau modèle TensorFlow Lite, et votre application téléchargera le nouveau modèle et commencera à l'utiliser au prochain redémarrage de l'application. Vous pouvez définir les conditions d'appareil requises pour que votre application tente de mettre à jour le modèle (voir ci-dessous).

2. Téléchargez le modèle sur l'appareil

Pour utiliser votre modèle TensorFlow Lite dans votre application, utilisez d'abord le SDK Firebase ML pour télécharger la dernière version du modèle sur l'appareil.

Pour démarrer le téléchargement du modèle, appelez la méthode download() du gestionnaire de modèles, en spécifiant le nom que vous avez attribué au modèle lorsque vous l'avez téléchargé et les conditions dans lesquelles vous souhaitez autoriser le téléchargement. Si le modèle ne se trouve pas sur l'appareil ou si une version plus récente du modèle est disponible, la tâche téléchargera le modèle de manière asynchrone depuis Firebase.

Vous devez désactiver les fonctionnalités liées au modèle (par exemple, griser ou masquer une partie de votre interface utilisateur) jusqu'à ce que vous confirmiez que le modèle a été téléchargé.

Java

FirebaseCustomRemoteModel remoteModel =
      new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

De nombreuses applications démarrent la tâche de téléchargement dans leur code d'initialisation, mais vous pouvez le faire à tout moment avant d'utiliser le modèle.

3. Initialisez un interpréteur TensorFlow Lite

Après avoir téléchargé le modèle sur l'appareil, vous pouvez obtenir l'emplacement du fichier de modèle en appelant la méthode getLatestModelFile() du gestionnaire de getLatestModelFile() . Utilisez cette valeur pour instancier un interpréteur TensorFlow Lite:

Java

FirebaseCustomRemoteModel remoteModel = new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                }
            }
        });

Kotlin + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        }
    }

4. Effectuer une inférence sur les données d'entrée

Obtenez les formes d'entrée et de sortie de votre modèle

L'interpréteur de modèle TensorFlow Lite prend en entrée et produit en sortie un ou plusieurs tableaux multidimensionnels. Ces tableaux contiennent des valeurs byte , int , long ou float . Avant de pouvoir transmettre des données à un modèle ou utiliser son résultat, vous devez connaître le nombre et les dimensions («forme») des tableaux utilisés par votre modèle.

Si vous avez créé le modèle vous-même ou si le format d'entrée et de sortie du modèle est documenté, vous disposez peut-être déjà de ces informations. Si vous ne connaissez pas la forme et le type de données de l'entrée et de la sortie de votre modèle, vous pouvez utiliser l'interpréteur TensorFlow Lite pour inspecter votre modèle. Par exemple:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Exemple de sortie:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Exécutez l'interpréteur

Une fois que vous avez déterminé le format de l'entrée et de la sortie de votre modèle, récupérez vos données d'entrée et effectuez toutes les transformations sur les données nécessaires pour obtenir une entrée de la bonne forme pour votre modèle.

Par exemple, si vous avez un modèle de classification d'image avec une forme d'entrée de [1 224 224 3] valeurs à virgule flottante, vous pouvez générer un ByteBuffer entrée à partir d'un objet Bitmap , comme illustré dans l'exemple suivant:

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin + KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Ensuite, ByteBuffer un ByteBuffer suffisamment grand pour contenir la sortie du modèle et passez le tampon d'entrée et le tampon de sortie à la méthode run() l'interpréteur TensorFlow Lite. Par exemple, pour une forme de sortie de [1 1000] valeurs à virgule flottante:

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin + KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

La façon dont vous utilisez la sortie dépend du modèle que vous utilisez.

Par exemple, si vous effectuez une classification, à l'étape suivante, vous pouvez mapper les index du résultat aux étiquettes qu'ils représentent:

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin + KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Annexe: Revenir à un modèle groupé localement

Lorsque vous hébergez votre modèle avec Firebase, aucune fonctionnalité liée au modèle ne sera disponible tant que votre application n'aura pas téléchargé le modèle pour la première fois. Pour certaines applications, cela peut convenir, mais si votre modèle active les fonctionnalités de base, vous souhaiterez peut-être regrouper une version de votre modèle avec votre application et utiliser la meilleure version disponible. Ce faisant, vous pouvez vous assurer que les fonctionnalités ML de votre application fonctionnent lorsque le modèle hébergé par Firebase n'est pas disponible.

Pour regrouper votre modèle TensorFlow Lite avec votre application:

  1. Copiez le fichier de modèle (se terminant généralement par .tflite ou .lite ) dans les assets/ dossier de votre application. (Vous devrez peut-être d'abord créer le dossier en cliquant avec le bouton droit sur l' app/ dossier, puis en cliquant sur Nouveau> Dossier> Dossier d'actifs .)

  2. Ajoutez ce qui suit au fichier build.gradle votre application pour vous assurer que Gradle ne compresse pas les modèles lors de la création de l'application:

    android {
    
        // ...
    
        aaptOptions {
            noCompress "tflite", "lite"
        }
    }
    

Ensuite, utilisez le modèle fourni localement lorsque le modèle hébergé n'est pas disponible:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                } else {
                    try {
                        InputStream inputStream = getAssets().open("your_fallback_model.tflite");
                        byte[] model = new byte[inputStream.available()];
                        inputStream.read(model);
                        ByteBuffer buffer = ByteBuffer.allocateDirect(model.length)
                                .order(ByteOrder.nativeOrder());
                        buffer.put(model);
                        interpreter = new Interpreter(buffer);
                    } catch (IOException e) {
                        // File not found?
                    }
                }
            }
        });

Kotlin + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        } else {
            val model = assets.open("your_fallback_model.tflite").readBytes()
            val buffer = ByteBuffer.allocateDirect(model.size).order(ByteOrder.nativeOrder())
            buffer.put(model)
            interpreter = Interpreter(buffer)
        }
    }

Annexe: sécurité du modèle

Quelle que soit la façon dont vous mettez vos modèles TensorFlow Lite à disposition de Firebase ML, Firebase ML les stocke au format protobuf sérialisé standard dans le stockage local.

En théorie, cela signifie que n'importe qui peut copier votre modèle. Cependant, dans la pratique, la plupart des modèles sont tellement spécifiques à l'application et obscurcis par les optimisations que le risque est similaire à celui des concurrents désassemblant et réutilisant votre code. Néanmoins, vous devez être conscient de ce risque avant d'utiliser un modèle personnalisé dans votre application.

Sur l'API Android de niveau 21 (Lollipop) et plus récent, le modèle est téléchargé dans un répertoire exclu de la sauvegarde automatique .

Sur l'API Android de niveau 20 et plus ancien, le modèle est téléchargé dans un répertoire nommé com.google.firebase.ml.custom.models dans le stockage interne privé de l'application. Si vous avez activé la sauvegarde de fichiers à l'aide de BackupAgent , vous pouvez choisir d'exclure ce répertoire.