如果您的應用程序使用自定義TensorFlow Lite模型,則可以使用Firebase ML部署模型。通過使用Firebase部署模型,您可以減少應用程序的初始下載大小並更新應用程序的ML模型,而無需發布新版本的應用程序。並且,通過遠程配置和A / B測試,您可以為不同的用戶組動態提供不同的模型。
TensorFlow Lite模型
TensorFlow Lite模型是經過優化可在移動設備上運行的ML模型。要獲取TensorFlow Lite模型:
在你開始之前
- 如果尚未將Firebase添加到您的Android項目中。
- 使用Firebase Android BoM ,在模塊(應用程序級)Gradle文件(通常為
app/build.gradle
)中聲明Firebase ML自定義模型Android庫的依賴app/build.gradle
。另外,作為設置Firebase ML自定義模型的一部分,您需要將TensorFlow Lite SDK添加到您的應用中。
dependencies { // Import the BoM for the Firebase platform implementation platform('com.google.firebase:firebase-bom:26.3.0') // Declare the dependency for the Firebase ML Custom Models library // When using the BoM, you don't specify versions in Firebase library dependencies implementation 'com.google.firebase:firebase-ml-model-interpreter'
// Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0' }通過使用Firebase Android BoM ,您的應用將始終使用Firebase Android庫的兼容版本。
(可選)不使用BoM聲明Firebase庫依賴關係
如果選擇不使用Firebase BoM,則必須在其依賴關係行中指定每個Firebase庫版本。
請注意,如果您在應用中使用多個Firebase庫,我們強烈建議您使用BoM來管理庫版本,以確保所有版本兼容。
dependencies { // Declare the dependency for the Firebase ML Custom Models library // When NOT using the BoM, you must specify versions in Firebase library dependencies implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.4'
// Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0' } - 在您的應用清單中,聲明需要INTERNET權限:
<uses-permission android:name="android.permission.INTERNET" />
1.部署模型
使用Firebase控制台或Firebase Admin Python和Node.js SDK部署自定義TensorFlow模型。請參閱部署和管理定制模型。
將自定義模型添加到Firebase項目後,您可以使用指定的名稱在應用程序中引用該模型。您隨時可以上傳新的TensorFlow Lite模型,您的應用將下載新模型並在下次重啟時開始使用它。您可以定義應用程序嘗試更新模型所需的設備條件(請參見下文)。
2.將模型下載到設備
要在您的應用中使用TensorFlow Lite模型,請首先使用Firebase ML SDK將模型的最新版本下載到設備上。要開始下載模型,請調用模型管理器的download()
方法,並指定在上載時為模型分配的名稱以及允許下載的條件。如果模型不在設備上,或者有較新版本的模型可用,則該任務將從異步從Firebase下載模型。
您應禁用與模型相關的功能(例如,變灰或隱藏UI的一部分),直到確認已下載模型為止。
爪哇
FirebaseCustomRemoteModel remoteModel =
new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
.requireWifi()
.build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
.addOnSuccessListener(new OnSuccessListener<Void>() {
@Override
public void onSuccess(Void v) {
// Download complete. Depending on your app, you could enable
// the ML feature, or switch from the local model to the remote
// model, etc.
}
});
Kotlin + KTX
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val conditions = FirebaseModelDownloadConditions.Builder()
.requireWifi()
.build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
.addOnCompleteListener {
// Download complete. Depending on your app, you could enable the ML
// feature, or switch from the local model to the remote model, etc.
}
許多應用程序都在其初始化代碼中啟動下載任務,但是您可以在需要使用模型之前隨時進行下載。
3.初始化TensorFlow Lite解釋器
將模型下載到設備後,可以通過調用模型管理器的getLatestModelFile()
方法來獲取模型文件的位置。使用此值實例化TensorFlow Lite解釋器:
爪哇
FirebaseCustomRemoteModel remoteModel = new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnCompleteListener(new OnCompleteListener<File>() {
@Override
public void onComplete(@NonNull Task<File> task) {
File modelFile = task.getResult();
if (modelFile != null) {
interpreter = new Interpreter(modelFile);
}
}
});
Kotlin + KTX
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnCompleteListener { task ->
val modelFile = task.result
if (modelFile != null) {
interpreter = Interpreter(modelFile)
}
}
4.對輸入數據進行推斷
獲取模型的輸入和輸出形狀
TensorFlow Lite模型解釋器將輸入或產生一個或多個多維數組作為輸出。這些數組包含byte
, int
, long
或float
值。在將數據傳遞到模型或使用其結果之前,必須知道模型使用的數組的數量和尺寸(“形狀”)。
如果您自己構建模型,或者記錄了模型的輸入和輸出格式,則可能已經具有此信息。如果您不知道模型輸入和輸出的形狀和數據類型,則可以使用TensorFlow Lite解釋器檢查模型。例如:
蟒蛇
import tensorflow as tf interpreter = tf.lite.Interpreter(model_path="your_model.tflite") interpreter.allocate_tensors() # Print input shape and type inputs = interpreter.get_input_details() print('{} input(s):'.format(len(inputs))) for i in range(0, len(inputs)): print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype'])) # Print output shape and type outputs = interpreter.get_output_details() print('\n{} output(s):'.format(len(outputs))) for i in range(0, len(outputs)): print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))
輸出示例:
1 input(s): [ 1 224 224 3] <class 'numpy.float32'> 1 output(s): [1 1000] <class 'numpy.float32'>
運行口譯員
確定模型輸入和輸出的格式之後,獲取輸入數據,並對數據進行任何必要的轉換,以獲取適合模型的正確形狀的輸入。例如,如果您的圖像分類模型的輸入形狀為[1 224 224 3]
浮點值,則可以從Bitmap
像生成輸入ByteBuffer
,如以下示例所示:
爪哇
Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
for (int x = 0; x < 224; x++) {
int px = bitmap.getPixel(x, y);
// Get channel values from the pixel value.
int r = Color.red(px);
int g = Color.green(px);
int b = Color.blue(px);
// Normalize channel values to [-1.0, 1.0]. This requirement depends
// on the model. For example, some models might require values to be
// normalized to the range [0.0, 1.0] instead.
float rf = (r - 127) / 255.0f;
float gf = (g - 127) / 255.0f;
float bf = (b - 127) / 255.0f;
input.putFloat(rf);
input.putFloat(gf);
input.putFloat(bf);
}
}
Kotlin + KTX
val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
for (x in 0 until 224) {
val px = bitmap.getPixel(x, y)
// Get channel values from the pixel value.
val r = Color.red(px)
val g = Color.green(px)
val b = Color.blue(px)
// Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
// For example, some models might require values to be normalized to the range
// [0.0, 1.0] instead.
val rf = (r - 127) / 255f
val gf = (g - 127) / 255f
val bf = (b - 127) / 255f
input.putFloat(rf)
input.putFloat(gf)
input.putFloat(bf)
}
}
然後,分配一個足夠大的ByteBuffer
來容納模型的輸出,並將輸入緩衝區和輸出緩衝區傳遞給TensorFlow Lite解釋器的run()
方法。例如,對於輸出形狀為[1 1000]
浮點值:
爪哇
int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);
Kotlin + KTX
val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)
您如何使用輸出取決於您使用的模型。
例如,如果要執行分類,那麼下一步,您可以將結果的索引映射到它們表示的標籤:
爪哇
modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
BufferedReader reader = new BufferedReader(
new InputStreamReader(getAssets().open("custom_labels.txt")));
for (int i = 0; i < probabilities.capacity(); i++) {
String label = reader.readLine();
float probability = probabilities.get(i);
Log.i(TAG, String.format("%s: %1.4f", label, probability));
}
} catch (IOException e) {
// File not found?
}
Kotlin + KTX
modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
val reader = BufferedReader(
InputStreamReader(assets.open("custom_labels.txt")))
for (i in probabilities.capacity()) {
val label: String = reader.readLine()
val probability = probabilities.get(i)
println("$label: $probability")
}
} catch (e: IOException) {
// File not found?
}
附錄:回退到本地捆綁的模型
當您使用Firebase託管模型時,與模型相關的任何功能將不可用,直到您的應用程序首次下載模型為止。對於某些應用程序,這可能很好,但是如果您的模型啟用了核心功能,則您可能希望將模型版本與應用程序捆綁在一起,並使用最佳版本。這樣,當Firebase託管模型不可用時,您可以確保應用的ML功能正常工作。
要將TensorFlow Lite模型與您的應用捆綁在一起:
將模型文件(通常以
.tflite
或.lite
).lite
到應用程序的assets/
文件夾中。 (您可能需要先右鍵單擊app/
文件夾,然後單擊新建>文件夾>資產文件夾來創建文件夾。)將以下內容添加到應用程序的
build.gradle
文件中,以確保Gradle在構建應用程序時不會壓縮模型:android { // ... aaptOptions { noCompress "tflite", "lite" } }
然後,在託管模型不可用時使用本地捆綁的模型:
爪哇
FirebaseCustomRemoteModel remoteModel =
new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnCompleteListener(new OnCompleteListener<File>() {
@Override
public void onComplete(@NonNull Task<File> task) {
File modelFile = task.getResult();
if (modelFile != null) {
interpreter = new Interpreter(modelFile);
} else {
try {
InputStream inputStream = getAssets().open("your_fallback_model.tflite");
byte[] model = new byte[inputStream.available()];
inputStream.read(model);
ByteBuffer buffer = ByteBuffer.allocateDirect(model.length)
.order(ByteOrder.nativeOrder());
buffer.put(model);
interpreter = new Interpreter(buffer);
} catch (IOException e) {
// File not found?
}
}
}
});
Kotlin + KTX
val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
.addOnCompleteListener { task ->
val modelFile = task.result
if (modelFile != null) {
interpreter = Interpreter(modelFile)
} else {
val model = assets.open("your_fallback_model.tflite").readBytes()
val buffer = ByteBuffer.allocateDirect(model.size).order(ByteOrder.nativeOrder())
buffer.put(model)
interpreter = Interpreter(buffer)
}
}
附錄:模型安全
無論您如何使TensorFlow Lite模型可用於Firebase ML,Firebase ML都將它們以標準序列化protobuf格式存儲在本地存儲中。
從理論上講,這意味著任何人都可以復制您的模型。但是,實際上,大多數模型都是特定於應用程序的,並且由於優化而模糊不清,以至於風險與競爭對手拆卸和重用您的代碼相似。但是,在應用程序中使用自定義模型之前,您應該意識到這種風險。
在Android API級別21(Lollipop)和更高版本上,模型被下載到自動備份所排除的目錄中。
在Android API級別20和更高版本上,模型被下載到應用程序專用內部存儲中的名為com.google.firebase.ml.custom.models
的目錄中。如果使用BackupAgent
啟用了文件備份,則可以選擇排除此目錄。