Version 22.0.2 of the firebase-ml-model-interpreter library introduces a new getLatestModelFile() method, which gets the location on the device of custom models. You can use this method to directly instantiate a TensorFlow Lite Interpreter object, which you can use instead of the FirebaseModelInterpreter wrapper.
Going forward, this is the preferred approach. Because the TensorFlow Lite interpreter version is no longer coupled with the Firebase library version, you have more flexibility to upgrade to new versions of TensorFlow Lite when you want, or to more easily use custom TensorFlow Lite builds.
This page shows how you can migrate from using FirebaseModelInterpreter to the TensorFlow Lite Interpreter.
1. Update project dependencies
Update your project's dependencies to include version 22.0.2 (or newer) of the firebase-ml-model-interpreter library and the tensorflow-lite library:
Before
implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.1'
After
implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.2'
implementation 'org.tensorflow:tensorflow-lite:2.0.0'
2. Create a TensorFlow Lite interpreter instead of a FirebaseModelInterpreter
Instead of creating a FirebaseModelInterpreter, get the model's location on the device with getLatestModelFile() and use it to create a TensorFlow Lite Interpreter.
Before
Java
FirebaseCustomRemoteModel remoteModel =
new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelInterpreterOptions options =
new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
Kotlin + KTX
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val options = FirebaseModelInterpreterOptions.Builder(remoteModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)
After
Java
FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    // Instantiate an org.tensorflow.lite.Interpreter object.
                    Interpreter interpreter = new Interpreter(modelFile);
                }
            }
        });
Kotlin + KTX
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.getResult()
        if (modelFile != null) {
            // Instantiate an org.tensorflow.lite.Interpreter object.
            val interpreter = Interpreter(modelFile)
        }
    }
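Note that getLatestModelFile() only yields a usable file once the model has been downloaded to the device, which is what the null check above guards against. The following is a minimal Kotlin sketch, not from the original page, of one way to chain a download before fetching the file; it assumes the same imports as the snippets above plus the FirebaseModelDownloadConditions and FirebaseModelManager.download() APIs from the same library release, and the Wi-Fi-only condition is just an illustrative choice:
Kotlin
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
// Illustrative condition: only fetch the model over Wi-Fi; adjust to your app's needs.
val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
val manager = FirebaseModelManager.getInstance()
manager.download(remoteModel, conditions)
    .continueWithTask { manager.getLatestModelFile(remoteModel) }
    .addOnSuccessListener { modelFile ->
        if (modelFile != null) {
            // The model is on the device; create the TensorFlow Lite interpreter.
            val interpreter = Interpreter(modelFile)
        }
    }
    .addOnFailureListener {
        // Download or lookup failed; retry or fall back as appropriate.
    }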
3. Update input and output preparation code
With FirebaseModelInterpreter, you specify the model's input and output shapes by passing a FirebaseModelInputOutputOptions object to the interpreter when you run it.
For the TensorFlow Lite interpreter, you instead allocate ByteBuffer objects with the right size for your model's input and output.
For example, if your model has an input shape of [1 224 224 3] float values and an output shape of [1 1000] float values, make these changes:
Before
Java
FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
                .build();
float[][][][] input = new float[1][224][224][3];
// Then populate with input data.
FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)
        .build();
interpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception.
                        // ...
                    }
                });
Kotlin + KTX
val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
    .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
    .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 1000))
    .build()
val input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder())
// Then populate with input data.
val inputs = FirebaseModelInputs.Builder()
    .add(input)
    .build()
interpreter.run(inputs, inputOutputOptions)
    .addOnSuccessListener { outputs ->
        // ...
    }
    .addOnFailureListener {
        // Task failed with an exception.
        // ...
    }
After
Java
int inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer inputBuffer =
ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder());
// Then populate with input data.
int outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer outputBuffer =
ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder());
interpreter.run(inputBuffer, outputBuffer);
Kotlin + KTX
val inBufferSize = 1 * 224 * 224 * 3 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val inputBuffer = ByteBuffer.allocateDirect(inBufferSize).order(ByteOrder.nativeOrder())
// Then populate with input data.
val outBufferSize = 1 * 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val outputBuffer = ByteBuffer.allocateDirect(outBufferSize).order(ByteOrder.nativeOrder())
interpreter.run(inputBuffer, outputBuffer)
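How you "populate with input data" depends on your input source. As one illustration, here is a hedged Kotlin sketch, not part of the Firebase or TensorFlow Lite APIs, that fills the 1x224x224x3 float input buffer from an Android Bitmap; the bitmapToInputBuffer helper name and the divide-by-255 normalization are assumptions, so use whatever preprocessing your model was trained with:
Kotlin
import android.graphics.Bitmap
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Hypothetical helper: converts a Bitmap into the direct ByteBuffer expected above.
fun bitmapToInputBuffer(bitmap: Bitmap): ByteBuffer {
    val scaled = Bitmap.createScaledBitmap(bitmap, 224, 224, true)
    val buffer = ByteBuffer.allocateDirect(1 * 224 * 224 * 3 * 4)
        .order(ByteOrder.nativeOrder())
    val pixels = IntArray(224 * 224)
    scaled.getPixels(pixels, 0, 224, 0, 0, 224, 224)
    for (pixel in pixels) {
        // Unpack the ARGB int and scale each channel to [0, 1] (assumed normalization).
        buffer.putFloat(((pixel shr 16) and 0xFF) / 255f)
        buffer.putFloat(((pixel shr 8) and 0xFF) / 255f)
        buffer.putFloat((pixel and 0xFF) / 255f)
    }
    buffer.rewind()
    return buffer
}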
4. Update output handling code
Finally, instead of getting the model's output with the FirebaseModelOutputs object's getOutput() method, convert the ByteBuffer output to whatever structure is convenient for your use case.
For example, if you're doing classification, you might make changes like the following:
Before
Java
float[][] output = result.getOutput(0);
float[] probabilities = output[0];
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (float probability : probabilities) {
        String label = reader.readLine();
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}
Kotlin + KTX
val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]
try {
    val reader = BufferedReader(InputStreamReader(assets.open("custom_labels.txt")))
    for (probability in probabilities) {
        val label: String = reader.readLine()
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}
After
Java
// modelOutput is the output ByteBuffer you passed to interpreter.run().
modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}
Kotlin + KTX
// modelOutput is the output ByteBuffer you passed to interpreter.run().
modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in 0 until probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}
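If you need a single result rather than a log line per class, the same FloatBuffer can be reduced to whatever shape suits your app. Below is a small hedged Kotlin sketch; the topLabel helper and its labels parameter are illustrative additions, not part of either API, and it assumes the label list is the same length as the probability buffer:
Kotlin
import java.nio.FloatBuffer

// Hypothetical helper: returns the highest-scoring (label, probability) pair, or null if empty.
fun topLabel(probabilities: FloatBuffer, labels: List<String>): Pair<String, Float>? {
    var bestIndex = -1
    var bestScore = Float.NEGATIVE_INFINITY
    for (i in 0 until probabilities.capacity()) {
        val score = probabilities.get(i)
        if (score > bestScore) {
            bestScore = score
            bestIndex = i
        }
    }
    return if (bestIndex >= 0) labels[bestIndex] to bestScore else null
}

// Example usage with the buffer from the previous step:
// val labels = assets.open("custom_labels.txt").bufferedReader().readLines()
// val best = topLabel(modelOutput.asFloatBuffer(), labels)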