如果您的應用使用自定義TensorFlow Lite模型,您可以使用 Firebase ML 來部署您的模型。通過使用 Firebase 部署模型,您可以減少應用的初始下載大小並更新應用的 ML 模型,而無需發布應用的新版本。而且,通過遠程配置和 A/B 測試,您可以為不同的用戶組動態地提供不同的模型。
先決條件
MLModelDownloader
庫僅適用於 Swift。- TensorFlow Lite 只能在使用 iOS 9 及更新版本的設備上運行。
TensorFlow Lite 模型
TensorFlow Lite 模型是經過優化以在移動設備上運行的 ML 模型。要獲取 TensorFlow Lite 模型:
在你開始之前
要將 TensorFlowLite 與 Firebase 一起使用,您必須使用 CocoaPods,因為 TensorFlowLite 目前不支持使用 Swift 包管理器進行安裝。有關如何安裝MLModelDownloader
的說明,請參閱CocoaPods 安裝指南。
安裝後,導入 Firebase 和 TensorFlowLite 以使用它們。
迅速
import FirebaseMLModelDownloader
import TensorFlowLite
1. 部署你的模型
使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK 部署您的自定義 TensorFlow 模型。請參閱部署和管理自定義模型。
將自定義模型添加到 Firebase 項目後,您可以使用您指定的名稱在應用中引用該模型。您可以隨時部署新的 TensorFlow Lite 模型,並通過調用getModel()
將新模型下載到用戶的設備上(見下文)。
2. 將模型下載到設備並初始化 TensorFlow Lite 解釋器
要在您的應用中使用您的 TensorFlow Lite 模型,請首先使用 Firebase ML SDK 將最新版本的模型下載到設備中。要開始模型下載,請調用模型下載器的getModel()
方法,指定上傳模型時指定的名稱、是否要始終下載最新模型以及允許下載的條件。
您可以從三種下載行為中進行選擇:
下載類型 | 描述 |
---|---|
localModel | 從設備獲取本地模型。如果沒有可用的本地模型,則其行為類似於latestModel 。如果您對檢查模型更新不感興趣,請使用此下載類型。例如,您正在使用遠程配置來檢索模型名稱,並且您總是以新名稱上傳模型(推薦)。 |
localModelUpdateInBackground | 從設備獲取本地模型並開始在後台更新模型。如果沒有可用的本地模型,則其行為類似於latestModel 。 |
latestModel | 獲取最新型號。如果本地模型是最新版本,則返回本地模型。否則,請下載最新型號。在下載最新版本之前,此行為將被阻止(不推薦)。僅在您明確需要最新版本的情況下使用此行為。 |
您應該禁用與模型相關的功能(例如,灰顯或隱藏部分 UI),直到您確認模型已下載。
迅速
let conditions = ModelDownloadConditions(allowsCellularAccess: false)
ModelDownloader.modelDownloader()
.getModel(name: "your_model",
downloadType: .localModelUpdateInBackground,
conditions: conditions) { result in
switch (result) {
case .success(let customModel):
do {
// Download complete. Depending on your app, you could enable the ML
// feature, or switch from the local model to the remote model, etc.
// The CustomModel object contains the local path of the model file,
// which you can use to instantiate a TensorFlow Lite interpreter.
let interpreter = try Interpreter(modelPath: customModel.path)
} catch {
// Error. Bad model file?
}
case .failure(let error):
// Download was unsuccessful. Don't enable ML features.
print(error)
}
}
許多應用程序在其初始化代碼中啟動下載任務,但您可以在需要使用模型之前的任何時候這樣做。
3. 對輸入數據進行推理
獲取模型的輸入和輸出形狀
TensorFlow Lite 模型解釋器將一個或多個多維數組作為輸入並生成輸出。這些數組包含byte
、 int
、 long
或float
值。在將數據傳遞給模型或使用其結果之前,您必須知道模型使用的數組的數量和維度(“形狀”)。
如果您自己構建了模型,或者模型的輸入和輸出格式已記錄在案,您可能已經擁有此信息。如果您不知道模型輸入和輸出的形狀和數據類型,您可以使用 TensorFlow Lite 解釋器來檢查您的模型。例如:
Python
import tensorflow as tf interpreter = tf.lite.Interpreter(model_path="your_model.tflite") interpreter.allocate_tensors() # Print input shape and type inputs = interpreter.get_input_details() print('{} input(s):'.format(len(inputs))) for i in range(0, len(inputs)): print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype'])) # Print output shape and type outputs = interpreter.get_output_details() print('\n{} output(s):'.format(len(outputs))) for i in range(0, len(outputs)): print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))
示例輸出:
1 input(s): [ 1 224 224 3] <class 'numpy.float32'> 1 output(s): [1 1000] <class 'numpy.float32'>
運行解釋器
確定模型輸入和輸出的格式後,獲取輸入數據並對數據執行任何必要的轉換,以獲得模型正確形狀的輸入。例如,如果您的模型處理圖像,並且您的模型具有[1, 224, 224, 3]
浮點值的輸入尺寸,則您可能必須將圖像的顏色值縮放到浮點範圍,如下例所示:
迅速
let image: CGImage = // Your input image
guard let context = CGContext(
data: nil,
width: image.width, height: image.height,
bitsPerComponent: 8, bytesPerRow: image.width * 4,
space: CGColorSpaceCreateDeviceRGB(),
bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
return false
}
context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }
var inputData = Data()
for row in 0 ..< 224 {
for col in 0 ..< 224 {
let offset = 4 * (row * context.width + col)
// (Ignore offset 0, the unused alpha channel)
let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)
// Normalize channel values to [0.0, 1.0]. This requirement varies
// by model. For example, some models might require values to be
// normalized to the range [-1.0, 1.0] instead, and others might
// require fixed-point values or the original bytes.
var normalizedRed = Float32(red) / 255.0
var normalizedGreen = Float32(green) / 255.0
var normalizedBlue = Float32(blue) / 255.0
// Append normalized values to Data object in RGB order.
let elementSize = MemoryLayout.size(ofValue: normalizedRed)
var bytes = [UInt8](repeating: 0, count: elementSize)
memcpy(&bytes, &normalizedRed, elementSize)
inputData.append(&bytes, count: elementSize)
memcpy(&bytes, &normalizedGreen, elementSize)
inputData.append(&bytes, count: elementSize)
memcpy(&ammp;bytes, &normalizedBlue, elementSize)
inputData.append(&bytes, count: elementSize)
}
}
然後,將輸入的NSData
複製到解釋器並運行它:
迅速
try interpreter.allocateTensors()
try interpreter.copy(inputData, toInputAt: 0)
try interpreter.invoke()
您可以通過調用解釋器的output(at:)
方法來獲取模型的輸出。您如何使用輸出取決於您使用的模型。
例如,如果您正在執行分類,作為下一步,您可能會將結果的索引映射到它們所代表的標籤:
迅速
let output = try interpreter.output(at: 0)
let probabilities =
UnsafeMutableBufferPointer<Float32>.allocate(capacity: 1000)
output.data.copyBytes(to: probabilities)
guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }
for i in labels.indices {
print("\(labels[i]): \(probabilities[i])")
}
附錄:模型安全
無論您如何使 TensorFlow Lite 模型可用於 Firebase ML,Firebase ML 都會以標準序列化 protobuf 格式將它們存儲在本地存儲中。
理論上,這意味著任何人都可以復制您的模型。然而,在實踐中,大多數模型都是特定於應用程序的,並且被優化混淆了,其風險類似於競爭對手反彙編和重用代碼的風險。不過,在您的應用程序中使用自定義模型之前,您應該意識到這種風險。