Join us in person and online for Firebase Summit on October 18, 2022. Learn how Firebase can help you accelerate app development, release your app with confidence, and scale with ease. Register now

在 Apple 平台上使用自定義 TensorFlow Lite 模型

透過集合功能整理內容 你可以依據偏好儲存及分類內容。

如果您的應用使用自定義TensorFlow Lite模型,您可以使用 Firebase ML 來部署您的模型。通過使用 Firebase 部署模型,您可以減少應用的初始下載大小並更新應用的 ML 模型,而無需發布應用的新版本。而且,通過遠程配置和 A/B 測試,您可以為不同的用戶組動態地提供不同的模型。

先決條件

  • MLModelDownloader庫僅適用於 Swift。
  • TensorFlow Lite 只能在使用 iOS 9 及更新版本的設備上運行。

TensorFlow Lite 模型

TensorFlow Lite 模型是經過優化以在移動設備上運行的 ML 模型。要獲取 TensorFlow Lite 模型:

在你開始之前

要將 TensorFlowLite 與 Firebase 一起使用,您必須使用 CocoaPods,因為 TensorFlowLite 目前不支持使用 Swift 包管理器進行安裝。有關如何安裝MLModelDownloader的說明,請參閱CocoaPods 安裝指南

安裝後,導入 Firebase 和 TensorFlowLite 以使用它們。

迅速

import FirebaseMLModelDownloader
import TensorFlowLite

1. 部署你的模型

使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK 部署您的自定義 TensorFlow 模型。請參閱部署和管理自定義模型

將自定義模型添加到 Firebase 項目後,您可以使用您指定的名稱在應用中引用該模型。您可以隨時部署新的 TensorFlow Lite 模型,並通過調用getModel()將新模型下載到用戶的設備上(見下文)。

2. 將模型下載到設備並初始化 TensorFlow Lite 解釋器

要在您的應用中使用您的 TensorFlow Lite 模型,請首先使用 Firebase ML SDK 將最新版本的模型下載到設備中。

要開始模型下載,請調用模型下載器的getModel()方法,指定上傳模型時指定的名稱、是否要始終下載最新模型以及允許下載的條件。

您可以從三種下載行為中進行選擇:

下載類型描述
localModel從設備獲取本地模型。如果沒有可用的本地模型,則其行為類似於latestModel 。如果您對檢查模型更新不感興趣,請使用此下載類型。例如,您正在使用遠程配置來檢索模型名稱,並且您總是以新名稱上傳模型(推薦)。
localModelUpdateInBackground從設備獲取本地模型並開始在後台更新模型。如果沒有可用的本地模型,則其行為類似於latestModel
latestModel獲取最新型號。如果本地模型是最新版本,則返回本地模型。否則,請下載最新型號。在下載最新版本之前,此行為將被阻止(不推薦)。僅在您明確需要最新版本的情況下使用此行為。

您應該禁用與模型相關的功能(例如,灰顯或隱藏部分 UI),直到您確認模型已下載。

迅速

let conditions = ModelDownloadConditions(allowsCellularAccess: false)
ModelDownloader.modelDownloader()
    .getModel(name: "your_model",
              downloadType: .localModelUpdateInBackground,
              conditions: conditions) { result in
        switch (result) {
        case .success(let customModel):
            do {
                // Download complete. Depending on your app, you could enable the ML
                // feature, or switch from the local model to the remote model, etc.

                // The CustomModel object contains the local path of the model file,
                // which you can use to instantiate a TensorFlow Lite interpreter.
                let interpreter = try Interpreter(modelPath: customModel.path)
            } catch {
                // Error. Bad model file?
            }
        case .failure(let error):
            // Download was unsuccessful. Don't enable ML features.
            print(error)
        }
}

許多應用程序在其初始化代碼中啟動下載任務,但您可以在需要使用模型之前的任何時候這樣做。

3. 對輸入數據進行推理

獲取模型的輸入和輸出形狀

TensorFlow Lite 模型解釋器將一個或多個多維數組作為輸入並生成輸出。這些數組包含byteintlongfloat值。在將數據傳遞給模型或使用其結果之前,您必須知道模型使用的數組的數量和維度(“形狀”)。

如果您自己構建了模型,或者模型的輸入和輸出格式已記錄在案,您可能已經擁有此信息。如果您不知道模型輸入和輸出的形狀和數據類型,您可以使用 TensorFlow Lite 解釋器來檢查您的模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

示例輸出:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

運行解釋器

確定模型輸入和輸出的格式後,獲取輸入數據並對數據執行任何必要的轉換,以獲得模型正確形狀的輸入。

例如,如果您的模型處理圖像,並且您的模型具有[1, 224, 224, 3]浮點值的輸入尺寸,則您可能必須將圖像的顏色值縮放到浮點範圍,如下例所示:

迅速

let image: CGImage = // Your input image
guard let context = CGContext(
  data: nil,
  width: image.width, height: image.height,
  bitsPerComponent: 8, bytesPerRow: image.width * 4,
  space: CGColorSpaceCreateDeviceRGB(),
  bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
  return false
}

context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }

var inputData = Data()
for row in 0 ..&lt; 224 {
  for col in 0 ..&lt; 224 {
    let offset = 4 * (row * context.width + col)
    // (Ignore offset 0, the unused alpha channel)
    let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
    let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
    let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)

    // Normalize channel values to [0.0, 1.0]. This requirement varies
    // by model. For example, some models might require values to be
    // normalized to the range [-1.0, 1.0] instead, and others might
    // require fixed-point values or the original bytes.
    var normalizedRed = Float32(red) / 255.0
    var normalizedGreen = Float32(green) / 255.0
    var normalizedBlue = Float32(blue) / 255.0

    // Append normalized values to Data object in RGB order.
    let elementSize = MemoryLayout.size(ofValue: normalizedRed)
    var bytes = [UInt8](repeating: 0, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedRed, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedGreen, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&ammp;bytes, &amp;normalizedBlue, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
  }
}

然後,將輸入的NSData複製到解釋器並運行它:

迅速

try interpreter.allocateTensors()
try interpreter.copy(inputData, toInputAt: 0)
try interpreter.invoke()

您可以通過調用解釋器的output(at:)方法來獲取模型的輸出。您如何使用輸出取決於您使用的模型。

例如,如果您正在執行分類,作為下一步,您可能會將結果的索引映射到它們所代表的標籤:

迅速

let output = try interpreter.output(at: 0)
let probabilities =
        UnsafeMutableBufferPointer<Float32>.allocate(capacity: 1000)
output.data.copyBytes(to: probabilities)

guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }

for i in labels.indices {
    print("\(labels[i]): \(probabilities[i])")
}

附錄:模型安全

無論您如何使 TensorFlow Lite 模型可用於 Firebase ML,Firebase ML 都會以標準序列化 protobuf 格式將它們存儲在本地存儲中。

理論上,這意味著任何人都可以復制您的模型。然而,在實踐中,大多數模型都是特定於應用程序的,並且被優化混淆了,其風險類似於競爭對手反彙編和重用代碼的風險。不過,在您的應用程序中使用自定義模型之前,您應該意識到這種風險。