Migrer à partir de l'ancienne API de modèle personnalisé

Version 22.0.2 du firebase-ml-model-interpreter bibliothèque présente une nouvelle getLatestModelFile() méthode, qui obtient l'emplacement sur le périphérique de modèles personnalisés. Vous pouvez utiliser cette méthode directement instancier un tensorflow Lite Interpreter objet, que vous pouvez utiliser au lieu de l' FirebaseModelInterpreter emballage.

À l'avenir, c'est l'approche privilégiée. Étant donné que la version de l'interpréteur TensorFlow Lite n'est plus couplée à la version de la bibliothèque Firebase, vous avez plus de flexibilité pour mettre à niveau vers de nouvelles versions de TensorFlow Lite quand vous le souhaitez, ou utilisez plus facilement des builds TensorFlow Lite personnalisés.

Cette page montre comment vous pouvez migrer d'utiliser FirebaseModelInterpreter au tensorflow Lite Interpreter .

1. Mettre à jour les dépendances du projet

Mettez à jour votre dépendances de projet pour inclure la version 22.0.2 du firebase-ml-model-interpreter bibliothèque (ou plus récent) et la tensorflow-lite bibliothèque:

Avant que

implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.1'

Après

implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.2'
implementation 'org.tensorflow:tensorflow-lite:2.0.0'

2. Créez un interpréteur TensorFlow Lite au lieu d'un FirebaseModelInterpreter

Au lieu de créer un FirebaseModelInterpreter , obtenir l'emplacement du modèle sur l' appareil avec getLatestModelFile() et l' utiliser pour créer un tensorflow Lite Interpreter .

Avant que

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelInterpreterOptions options =
        new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val options = FirebaseModelInterpreterOptions.Builder(remoteModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Après

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    // Instantiate an org.tensorflow.lite.Interpreter object.
                    Interpreter interpreter = new Interpreter(modelFile);
                }
            }
        });

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.getResult()
        if (modelFile != null) {
            // Instantiate an org.tensorflow.lite.Interpreter object.
            interpreter = Interpreter(modelFile)
        }
    }

3. Mettre à jour le code de préparation d'entrée et de sortie

Avec FirebaseModelInterpreter , vous spécifiez l'entrée du modèle et les formes de sortie en passant un FirebaseModelInputOutputOptions objet à l'interprète lorsque vous l' exécutez.

Pour l'interprète tensorflow Lite, vous à la place allouez ByteBuffer objets à la bonne taille pour l'entrée et la sortie de votre modèle.

Par exemple, si votre modèle a une forme d'entrée [1 224 224 3] float valeurs et une forme de sortie [1] 1000 float valeurs, effectuer ces modifications:

Avant que

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
                .build();

float[][][][] input = new float[1][224][224][3];
// Then populate with input data.

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)
        .build();

interpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
    .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
    .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
    .build()

val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
// Then populate with input data.

val inputs = FirebaseModelInputs.Builder()
    .add(input)
    .build()

interpreter.run(inputs, inputOutputOptions)
    .addOnSuccessListener { outputs ->
        // ...
    }
    .addOnFailureListener {
        // Task failed with an exception.
        // ...
    }

Après

Java

int inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer inputBuffer =
        ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder());
// Then populate with input data.

int outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer outputBuffer =
        ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder());

interpreter.run(inputBuffer, outputBuffer);

Kotlin+KTX

val inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val inputBuffer = ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder())
// Then populate with input data.

val outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val outputBuffer = ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder())

interpreter.run(inputBuffer, outputBuffer)

4. Mettre à jour le code de gestion des sorties

Enfin, au lieu d'obtenir la sortie du modèle avec le FirebaseModelOutputs de l' objet getOutput() méthode, convertir la ByteBuffer sortie quelle que soit la structure est pratique pour votre cas d'utilisation.

Par exemple, si vous effectuez une classification, vous pouvez apporter les modifications suivantes :

Avant que

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];
try {
    BufferedReader reader = new BufferedReader(
          new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (float probability : probabilities) {
        String label = reader.readLine();
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

val output = result.getOutput(0)
val probabilities = output[0]
try {
    val reader = BufferedReader(InputStreamReader(assets.open("custom_labels.txt")))
    for (probability in probabilities) {
        val label: String = reader.readLine()
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Après

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}