Migracja ze starszej wersji interfejsu API modelu niestandardowego

W wersji 22.0.2 biblioteki firebase-ml-model-interpreter wprowadzamy getLatestModelFile(), która pobiera lokalizację urządzenia niestandardowego modeli ML. Za pomocą tej metody możesz bezpośrednio utworzyć instancję TensorFlow Lite Interpreter, którego możesz użyć zamiast obiektu Kod FirebaseModelInterpreter.

Od tej pory jest to preferowane podejście. Ponieważ TensorFlow Lite wersja interpretera nie jest już połączona z wersją biblioteki Firebase, będziesz mieć większą elastyczność przy przejściu na nowe wersje TensorFlow Lite, jeśli lub z łatwością wykorzystać niestandardowe kompilacje TensorFlow Lite.

Na tej stronie dowiesz się, jak przejść z FirebaseModelInterpreter na TensorFlow Lite Interpreter.

1. Zaktualizuj zależności projektu

Zaktualizuj zależności projektu, tak aby zawierały wersję 22.0.2 biblioteka firebase-ml-model-interpreter (lub nowsza) oraz tensorflow-lite, biblioteka:

Przed

implementation("com.google.firebase:firebase-ml-model-interpreter:22.0.1")

Po

implementation("com.google.firebase:firebase-ml-model-interpreter:22.0.2")
implementation("org.tensorflow:tensorflow-lite:2.0.0")

2. Tworzenie interpretera TensorFlow Lite zamiast FirebaseModelInterpreter

Zamiast tworzyć obiekt FirebaseModelInterpreter, włącz lokalizację modelu na urządzeniu z getLatestModelFile() i użyj go do utworzenia TensorFlow Lite Interpreter

Przed

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val options = FirebaseModelInterpreterOptions.Builder(remoteModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelInterpreterOptions options =
        new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);

Po

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.getResult()
        if (modelFile != null) {
            // Instantiate an org.tensorflow.lite.Interpreter object.
            interpreter = Interpreter(modelFile)
        }
    }

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    // Instantiate an org.tensorflow.lite.Interpreter object.
                    Interpreter interpreter = new Interpreter(modelFile);
                }
            }
        });

3. Zaktualizuj kod przygotowania danych wejściowych i wyjściowych

Funkcja FirebaseModelInterpreter pozwala określić kształty wejściowe i wyjściowe modelu przez przekazanie obiektu FirebaseModelInputOutputOptions do interpretera, gdy musisz go uruchomić.

W przypadku interpretera TensorFlow Lite przydzielono obiekty ByteBuffer z odpowiednim rozmiarem danych wejściowych i wyjściowych modelu.

Jeśli na przykład model ma dane wejściowe o kształcie: [1 224 224 3] float i wyjściowym kształtem [11000] wartości float, wprowadź te zmiany:

Przed

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
    .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
    .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
    .build()

val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
// Then populate with input data.

val inputs = FirebaseModelInputs.Builder()
    .add(input)
    .build()

interpreter.run(inputs, inputOutputOptions)
    .addOnSuccessListener { outputs ->
        // ...
    }
    .addOnFailureListener {
        // Task failed with an exception.
        // ...
    }

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
                .build();

float[][][][] input = new float[1][224][224][3];
// Then populate with input data.

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)
        .build();

interpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Po

Kotlin+KTX

val inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val inputBuffer = ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder())
// Then populate with input data.

val outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val outputBuffer = ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder())

interpreter.run(inputBuffer, outputBuffer)

Java

int inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer inputBuffer =
        ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder());
// Then populate with input data.

int outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer outputBuffer =
        ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder());

interpreter.run(inputBuffer, outputBuffer);

4. Zaktualizuj kod obsługi danych wyjściowych

Zamiast pobierania danych wyjściowych modelu za pomocą funkcji FirebaseModelOutputs metody getOutput() obiektu, przekonwertuj dane wyjściowe ByteBuffer na dowolną jest wygodna w Twoim przypadku.

Na przykład, jeśli przeprowadzasz klasyfikację, możesz wprowadzić zmiany takie jak :

Przed

Kotlin+KTX

val output = result.getOutput(0)
val probabilities = output[0]
try {
    val reader = BufferedReader(InputStreamReader(assets.open("custom_labels.txt")))
    for (probability in probabilities) {
        val label: String = reader.readLine()
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];
try {
    BufferedReader reader = new BufferedReader(
          new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (float probability : probabilities) {
        String label = reader.readLine();
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Po

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}