Di chuyển từ API mô hình tuỳ chỉnh cũ

Thư viện firebase-ml-model-interpreter phiên bản 22.0.2 giới thiệu một phương thức getLatestModelFile() mới. Phương thức này nhận thông tin vị trí trên thiết bị của các mô hình tuỳ chỉnh. Bạn có thể sử dụng phương thức này để tạo bản sao trực tiếp đối tượng Interpreter của TensorFlow Lite. Bạn có thể sử dụng đối tượng này thay cho trình bao bọc FirebaseModelInterpreter.

Từ giờ trở đi, đây là phương pháp ưu tiên. Vì phiên bản trình thông dịch TensorFlow Lite không còn được ghép nối với phiên bản thư viện Firebase, nên bạn có thể linh hoạt nâng cấp lên các phiên bản mới của TensorFlow Lite khi muốn hoặc dễ dàng sử dụng các bản dựng TensorFlow Lite tuỳ chỉnh hơn.

Trang này cho biết cách bạn có thể di chuyển từ việc sử dụng FirebaseModelInterpreter sang Interpreter TensorFlow Lite.

1. Cập nhật phần phụ thuộc của dự án

Cập nhật các phần phụ thuộc của dự án để bao gồm phiên bản 22.0.2 của thư viện firebase-ml-model-interpreter (hoặc mới hơn) và thư viện tensorflow-lite:

Trước

implementation("com.google.firebase:firebase-ml-model-interpreter:22.0.1")

Sau

implementation("com.google.firebase:firebase-ml-model-interpreter:22.0.2")
implementation("org.tensorflow:tensorflow-lite:2.0.0")

2. Tạo trình diễn giải TensorFlow Lite thay vì FirebaseModelInterpreter

Thay vì tạo FirebaseModelInterpreter, hãy lấy vị trí của mô hình trên thiết bị bằng getLatestModelFile() và sử dụng vị trí đó để tạo Interpreter TensorFlow Lite.

Trước

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val options = FirebaseModelInterpreterOptions.Builder(remoteModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelInterpreterOptions options =
        new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);

Sau

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.getResult()
        if (modelFile != null) {
            // Instantiate an org.tensorflow.lite.Interpreter object.
            interpreter = Interpreter(modelFile)
        }
    }

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    // Instantiate an org.tensorflow.lite.Interpreter object.
                    Interpreter interpreter = new Interpreter(modelFile);
                }
            }
        });

3. Cập nhật mã chuẩn bị đầu vào và đầu ra

Với FirebaseModelInterpreter, bạn chỉ định các hình dạng đầu vào và đầu ra của mô hình bằng cách truyền một đối tượng FirebaseModelInputOutputOptions đến trình thông dịch khi chạy đối tượng đó.

Đối với trình diễn giải TensorFlow Lite, bạn sẽ phân bổ các đối tượng ByteBuffer có kích thước phù hợp cho dữ liệu đầu vào và đầu ra của mô hình.

Ví dụ: nếu mô hình của bạn có hình dạng đầu vào là các giá trị float [1 224 224 3] và hình dạng đầu ra là các giá trị float [1 1000], hãy thực hiện những thay đổi sau:

Trước

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
    .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
    .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
    .build()

val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
// Then populate with input data.

val inputs = FirebaseModelInputs.Builder()
    .add(input)
    .build()

interpreter.run(inputs, inputOutputOptions)
    .addOnSuccessListener { outputs ->
        // ...
    }
    .addOnFailureListener {
        // Task failed with an exception.
        // ...
    }

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
                .build();

float[][][][] input = new float[1][224][224][3];
// Then populate with input data.

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)
        .build();

interpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Sau

Kotlin+KTX

val inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val inputBuffer = ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder())
// Then populate with input data.

val outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val outputBuffer = ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder())

interpreter.run(inputBuffer, outputBuffer)

Java

int inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer inputBuffer =
        ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder());
// Then populate with input data.

int outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer outputBuffer =
        ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder());

interpreter.run(inputBuffer, outputBuffer);

4. Cập nhật mã xử lý đầu ra

Cuối cùng, thay vì nhận kết quả của mô hình bằng phương thức getOutput() của đối tượng FirebaseModelOutputs, hãy chuyển đổi đầu ra ByteBuffer sang bất kỳ cấu trúc nào thuận tiện cho trường hợp sử dụng của bạn.

Ví dụ: nếu đang phân loại, bạn có thể thực hiện các thay đổi như sau:

Trước

Kotlin+KTX

val output = result.getOutput(0)
val probabilities = output[0]
try {
    val reader = BufferedReader(InputStreamReader(assets.open("custom_labels.txt")))
    for (probability in probabilities) {
        val label: String = reader.readLine()
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];
try {
    BufferedReader reader = new BufferedReader(
          new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (float probability : probabilities) {
        String label = reader.readLine();
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Sau

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}