Catch up on everything announced at Firebase Summit, and learn how Firebase can help you accelerate app development and run your app with confidence. Learn More

Przeprowadź migrację ze starszego interfejsu API modelu niestandardowego

Zadbaj o dobrą organizację dzięki kolekcji Zapisuj i kategoryzuj treści zgodnie ze swoimi preferencjami.

Wersja 22.0.2 biblioteki firebase-ml-model-interpreter wprowadza nową getLatestModelFile() , która pobiera lokalizację modeli niestandardowych na urządzeniu. Możesz użyć tej metody do bezpośredniego utworzenia instancji obiektu TensorFlow Lite Interpreter , którego można użyć zamiast opakowania FirebaseModelInterpreter .

W przyszłości jest to preferowane podejście. Ponieważ wersja interpretera TensorFlow Lite nie jest już połączona z wersją biblioteki Firebase, masz większą elastyczność w zakresie uaktualniania do nowych wersji TensorFlow Lite, kiedy chcesz, lub łatwiejszego korzystania z niestandardowych kompilacji TensorFlow Lite.

Ta strona pokazuje, jak przejść z FirebaseModelInterpreter do TensorFlow Lite Interpreter .

1. Zaktualizuj zależności projektu

Zaktualizuj zależności projektu, aby uwzględnić wersję 22.0.2 biblioteki firebase-ml-model-interpreter (lub nowszą) oraz bibliotekę tensorflow-lite :

Zanim

implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.1'

Później

implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.2'
implementation 'org.tensorflow:tensorflow-lite:2.0.0'

2. Utwórz interpreter TensorFlow Lite zamiast interpretera FirebaseModel

Zamiast tworzyć FirebaseModelInterpreter , pobierz lokalizację modelu na urządzeniu za pomocą getLatestModelFile() i użyj go do utworzenia Interpreter TensorFlow Lite .

Zanim

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelInterpreterOptions options =
        new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val options = FirebaseModelInterpreterOptions.Builder(remoteModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Później

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    // Instantiate an org.tensorflow.lite.Interpreter object.
                    Interpreter interpreter = new Interpreter(modelFile);
                }
            }
        });

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.getResult()
        if (modelFile != null) {
            // Instantiate an org.tensorflow.lite.Interpreter object.
            interpreter = Interpreter(modelFile)
        }
    }

3. Zaktualizuj kod przygotowania wejścia i wyjścia

Za pomocą FirebaseModelInterpreter określasz kształty wejściowe i wyjściowe modelu, przekazując obiekt FirebaseModelInputOutputOptions do interpretera podczas jego uruchamiania.

W przypadku interpretera TensorFlow Lite zamiast tego przydzielasz obiekty ByteBuffer o odpowiednim rozmiarze dla danych wejściowych i wyjściowych modelu.

Na przykład, jeśli twój model ma kształt wejściowy równy [1 224 224 3] wartości float i kształt wyjściowy równy [1 1000] wartości float , wprowadź następujące zmiany:

Zanim

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
                .build();

float[][][][] input = new float[1][224][224][3];
// Then populate with input data.

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)
        .build();

interpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
    .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
    .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
    .build()

val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
// Then populate with input data.

val inputs = FirebaseModelInputs.Builder()
    .add(input)
    .build()

interpreter.run(inputs, inputOutputOptions)
    .addOnSuccessListener { outputs ->
        // ...
    }
    .addOnFailureListener {
        // Task failed with an exception.
        // ...
    }

Później

Java

int inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer inputBuffer =
        ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder());
// Then populate with input data.

int outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer outputBuffer =
        ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder());

interpreter.run(inputBuffer, outputBuffer);

Kotlin+KTX

val inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val inputBuffer = ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder())
// Then populate with input data.

val outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val outputBuffer = ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder())

interpreter.run(inputBuffer, outputBuffer)

4. Zaktualizuj kod obsługi wyjścia

Na koniec, zamiast pobierać dane wyjściowe modelu za pomocą metody FirebaseModelOutputs getOutput() obiektu FirebaseModelOutputs, przekonwertuj dane wyjściowe ByteBuffer na dowolną strukturę, która jest dogodna dla danego przypadku użycia.

Na przykład, jeśli robisz klasyfikację, możesz wprowadzić następujące zmiany:

Zanim

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];
try {
    BufferedReader reader = new BufferedReader(
          new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (float probability : probabilities) {
        String label = reader.readLine();
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

val output = result.getOutput(0)
val probabilities = output[0]
try {
    val reader = BufferedReader(InputStreamReader(assets.open("custom_labels.txt")))
    for (probability in probabilities) {
        val label: String = reader.readLine()
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Później

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}