Thư viện firebase-ml-model-interpreter
phiên bản 22.0.2 giới thiệu một phương thức getLatestModelFile()
mới, giúp lấy vị trí của các mô hình tuỳ chỉnh trên thiết bị. Bạn có thể dùng phương thức này để trực tiếp tạo thực thể cho một đối tượng Interpreter
TensorFlow Lite. Bạn có thể dùng đối tượng này thay cho trình bao bọc FirebaseModelInterpreter
.
Từ nay về sau, đây là phương pháp được ưu tiên. Vì phiên bản trình thông dịch TensorFlow Lite không còn liên kết với phiên bản thư viện Firebase nữa, nên bạn có thể linh hoạt hơn khi nâng cấp lên các phiên bản mới của TensorFlow Lite bất cứ khi nào bạn muốn hoặc dễ dàng sử dụng các bản dựng TensorFlow Lite tuỳ chỉnh.
Trang này cho biết cách bạn có thể di chuyển từ việc sử dụng FirebaseModelInterpreter
sang Interpreter
TensorFlow Lite.
1. Cập nhật các phần phụ thuộc của dự án
Cập nhật các phần phụ thuộc của dự án để thêm thư viện firebase-ml-model-interpreter
phiên bản 22.0.2 (hoặc mới hơn) và thư viện tensorflow-lite
:
Trước
implementation("com.google.firebase:firebase-ml-model-interpreter:22.0.1")
Sau
implementation("com.google.firebase:firebase-ml-model-interpreter:22.0.2")
implementation("org.tensorflow:tensorflow-lite:2.0.0")
2. Tạo một trình thông dịch TensorFlow Lite thay vì FirebaseModelInterpreter
Thay vì tạo một FirebaseModelInterpreter
, hãy lấy vị trí của mô hình trên thiết bị bằng getLatestModelFile()
và dùng vị trí đó để tạo một Interpreter
TensorFlow Lite.
Trước
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val options = FirebaseModelInterpreterOptions.Builder(remoteModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)
FirebaseCustomRemoteModel remoteModel =
new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelInterpreterOptions options =
new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
Sau
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnCompleteListener { task ->
val modelFile = task.getResult()
if (modelFile != null) {
// Instantiate an org.tensorflow.lite.Interpreter object.
interpreter = Interpreter(modelFile)
}
}
FirebaseCustomRemoteModel remoteModel =
new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnCompleteListener(new OnCompleteListener<File>() {
@Override
public void onComplete(@NonNull Task<File> task) {
File modelFile = task.getResult();
if (modelFile != null) {
// Instantiate an org.tensorflow.lite.Interpreter object.
Interpreter interpreter = new Interpreter(modelFile);
}
}
});
3. Cập nhật mã chuẩn bị đầu vào và đầu ra
Với FirebaseModelInterpreter
, bạn chỉ định các hình dạng đầu vào và đầu ra của mô hình bằng cách truyền một đối tượng FirebaseModelInputOutputOptions
đến trình thông dịch khi chạy.
Đối với trình thông dịch TensorFlow Lite, thay vào đó, bạn sẽ phân bổ các đối tượng ByteBuffer
có kích thước phù hợp cho đầu vào và đầu ra của mô hình.
Ví dụ: nếu mô hình của bạn có hình dạng đầu vào là [1 224 224 3] giá trị float
và hình dạng đầu ra là [1 1000] giá trị float
, hãy thực hiện những thay đổi sau:
Trước
val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
.setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
.setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
.build()
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
// Then populate with input data.
val inputs = FirebaseModelInputs.Builder()
.add(input)
.build()
interpreter.run(inputs, inputOutputOptions)
.addOnSuccessListener { outputs ->
// ...
}
.addOnFailureListener {
// Task failed with an exception.
// ...
}
FirebaseModelInputOutputOptions inputOutputOptions =
new FirebaseModelInputOutputOptions.Builder()
.setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
.setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
.build();
float[][][][] input = new float[1][224][224][3];
// Then populate with input data.
FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
.add(input)
.build();
interpreter.run(inputs, inputOutputOptions)
.addOnSuccessListener(
new OnSuccessListener<FirebaseModelOutputs>() {
@Override
public void onSuccess(FirebaseModelOutputs result) {
// ...
}
})
.addOnFailureListener(
new OnFailureListener() {
@Override
public void onFailure(@NonNull Exception e) {
// Task failed with an exception
// ...
}
});
Sau
val inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val inputBuffer = ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder())
// Then populate with input data.
val outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val outputBuffer = ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder())
interpreter.run(inputBuffer, outputBuffer)
int inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer inputBuffer =
ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder());
// Then populate with input data.
int outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer outputBuffer =
ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder());
interpreter.run(inputBuffer, outputBuffer);
4. Cập nhật mã xử lý đầu ra
Cuối cùng, thay vì nhận đầu ra của mô hình bằng phương thức getOutput()
của đối tượng FirebaseModelOutputs
, hãy chuyển đổi đầu ra ByteBuffer
thành bất kỳ cấu trúc nào thuận tiện cho trường hợp sử dụng của bạn.
Ví dụ: nếu đang phân loại, bạn có thể thực hiện những thay đổi như sau:
Trước
val output = result.getOutput(0)
val probabilities = output[0]
try {
val reader = BufferedReader(InputStreamReader(assets.open("custom_labels.txt")))
for (probability in probabilities) {
val label: String = reader.readLine()
println("$label: $probability")
}
} catch (e: IOException) {
// File not found?
}
float[][] output = result.getOutput(0);
float[] probabilities = output[0];
try {
BufferedReader reader = new BufferedReader(
new InputStreamReader(getAssets().open("custom_labels.txt")));
for (float probability : probabilities) {
String label = reader.readLine();
Log.i(TAG, String.format("%s: %1.4f", label, probability));
}
} catch (IOException e) {
// File not found?
}
Sau
modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
val reader = BufferedReader(
InputStreamReader(assets.open("custom_labels.txt")))
for (i in probabilities.capacity()) {
val label: String = reader.readLine()
val probability = probabilities.get(i)
println("$label: $probability")
}
} catch (e: IOException) {
// File not found?
}
modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
BufferedReader reader = new BufferedReader(
new InputStreamReader(getAssets().open("custom_labels.txt")));
for (int i = 0; i < probabilities.capacity(); i++) {
String label = reader.readLine();
float probability = probabilities.get(i);
Log.i(TAG, String.format("%s: %1.4f", label, probability));
}
} catch (IOException e) {
// File not found?
}