Check out what’s new from Firebase at Google I/O 2022. Learn more

레거시 사용자 지정 모델 API에서 마이그레이션

firebase firebase-ml-model-interpreter 라이브러리 버전 22.0.2에는 맞춤 모델의 기기에서 위치를 가져오는 새로운 getLatestModelFile() 메서드가 도입되었습니다. 이 메서드를 사용하여 FirebaseModelInterpreter 래퍼 대신 사용할 수 있는 TensorFlow Lite Interpreter 객체를 직접 인스턴스화할 수 있습니다.

앞으로는 이것이 선호되는 접근 방식입니다. TensorFlow Lite 인터프리터 버전은 더 이상 Firebase 라이브러리 버전과 연결되어 있지 않으므로 원할 때 TensorFlow Lite의 새 버전으로 업그레이드하거나 맞춤형 TensorFlow Lite 빌드를 더 쉽게 사용할 수 있습니다.

이 페이지에서는 FirebaseModelInterpreter 를 사용하여 TensorFlow Lite Interpreter 로 마이그레이션하는 방법을 보여줍니다.

1. 프로젝트 종속성 업데이트

firebase firebase-ml-model-interpreter 라이브러리(또는 그 이상) 및 tensorflow-lite 라이브러리 버전 22.0.2를 포함하도록 프로젝트의 종속성을 업데이트합니다.

전에

implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.1'

후에

implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.2'
implementation 'org.tensorflow:tensorflow-lite:2.0.0'

2. FirebaseModelInterpreter 대신 TensorFlow Lite 인터프리터 만들기

FirebaseModelInterpreter 를 만드는 대신 getLatestModelFile() 을 사용하여 기기에서 모델의 위치를 ​​가져와 TensorFlow Lite Interpreter 를 만드는 데 사용합니다.

전에

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelInterpreterOptions options =
        new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val options = FirebaseModelInterpreterOptions.Builder(remoteModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

후에

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    // Instantiate an org.tensorflow.lite.Interpreter object.
                    Interpreter interpreter = new Interpreter(modelFile);
                }
            }
        });

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.getResult()
        if (modelFile != null) {
            // Instantiate an org.tensorflow.lite.Interpreter object.
            interpreter = Interpreter(modelFile)
        }
    }

3. 입력 및 출력 준비 코드 업데이트

FirebaseModelInterpreter 를 사용하면 실행할 때 FirebaseModelInputOutputOptions 객체를 인터프리터에 전달하여 모델의 입력 및 출력 형태를 지정합니다.

TensorFlow Lite 인터프리터의 경우 대신 모델의 입력 및 출력에 적합한 크기로 ByteBuffer 객체를 할당합니다.

예를 들어, 모델의 입력 형태가 [1 224 224 3] float 값이고 출력 형태가 [1 1000] float 값인 경우 다음과 같이 변경합니다.

전에

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
                .build();

float[][][][] input = new float[1][224][224][3];
// Then populate with input data.

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)
        .build();

interpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
    .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
    .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
    .build()

val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
// Then populate with input data.

val inputs = FirebaseModelInputs.Builder()
    .add(input)
    .build()

interpreter.run(inputs, inputOutputOptions)
    .addOnSuccessListener { outputs ->
        // ...
    }
    .addOnFailureListener {
        // Task failed with an exception.
        // ...
    }

후에

Java

int inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer inputBuffer =
        ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder());
// Then populate with input data.

int outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer outputBuffer =
        ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder());

interpreter.run(inputBuffer, outputBuffer);

Kotlin+KTX

val inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val inputBuffer = ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder())
// Then populate with input data.

val outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val outputBuffer = ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder())

interpreter.run(inputBuffer, outputBuffer)

4. 출력 처리 코드 업데이트

마지막으로 FirebaseModelOutputs 객체의 getOutput() 메서드를 사용하여 모델의 출력을 가져오는 대신 ByteBuffer 출력을 사용 사례에 편리한 구조로 변환합니다.

예를 들어 분류를 수행하는 경우 다음과 같이 변경할 수 있습니다.

전에

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];
try {
    BufferedReader reader = new BufferedReader(
          new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (float probability : probabilities) {
        String label = reader.readLine();
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

val output = result.getOutput(0)
val probabilities = output[0]
try {
    val reader = BufferedReader(InputStreamReader(assets.open("custom_labels.txt")))
    for (probability in probabilities) {
        val label: String = reader.readLine()
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

후에

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}