Catch up on everthing we announced at this year's Firebase Summit. Learn more

在 Apple 平台上使用自定義 TensorFlow Lite 模型

如果應用程序使用自定義的TensorFlow精簡版機型,您可以用火力地堡ML部署您的機型。通過使用 Firebase 部署模型,您可以減少應用的初始下載大小並更新應用的機器學習模型,而無需發布應用的新版本。而且,通過遠程配置和 A/B 測試,您可以為不同的用戶組動態地提供不同的模型。

先決條件

  • MLModelDownloader庫僅適用於斯威夫特。
  • TensorFlow Lite 只能在使用 iOS 9 及更新版本的設備上運行。

TensorFlow Lite 模型

TensorFlow Lite 模型是經過優化可在移動設備上運行的 ML 模型。要獲取 TensorFlow Lite 模型:

在你開始之前

如果您尚未添加火力地堡到您的應用程序,通過遵循的步驟做這樣的入門指南

使用 Swift Package Manager 安裝和管理 Firebase 依賴項。

  1. 在Xcode中,您的應用項目打開,導航到File>斯威夫特包>添加包的依賴
  2. 出現提示時,添加 Firebase Apple 平台 SDK 存儲庫:
  3.   https://github.com/firebase/firebase-ios-sdk
      
  4. 選擇 Firebase ML 庫。
  5. 完成後,Xcode 將在後台自動開始解析和下載您的依賴項。

接下來,執行一些應用程序內設置:

  1. 在您的應用中,導入 Firebase:

    迅速

    import Firebase
    import TensorFlowLite
    

1. 部署您的模型

使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK 部署您的自定義 TensorFlow 模型。請參閱部署和管理定制機型

將自定義模型添加到 Firebase 項目後,您可以使用指定的名稱在應用中引用該模型。在任何時候,你可以部署一個新的TensorFlow精簡版模型,並通過調用下載的新模式到用戶的設備getModel()見下文)。

2. 將模型下載到設備並初始化一個 TensorFlow Lite 解釋器

要在您的應用中使用 TensorFlow Lite 模型,請首先使用 Firebase ML SDK 將模型的最新版本下載到設備。

要啟動模型下載,調用模型下載的getModel()方法,指定你分配模式,當你上傳它,你是否要隨時下載最新的模型,以及在何種條件下要允許下載的名稱。

您可以從三種下載行為中進行選擇:

下載類型描述
localModel從設備獲取本地模型。如果沒有可用的本地模式,這種行為就像latestModel 。如果您對檢查模型更新不感興趣,請使用此下載類型。例如,您使用遠程配置來檢索模型名稱,並且始終以新名稱(推薦)上傳模型。
localModelUpdateInBackground從設備獲取本地模型並在後台開始更新模型。如果沒有可用的本地模式,這種行為就像latestModel
latestModel獲取最新型號。如果本地模型是最新版本,則返回本地模型。否則,請下載最新型號。在下載最新版本之前,此行為將被阻止(不推薦)。僅在您明確需要最新版本的情況下才使用此行為。

您應該禁用與模型相關的功能——例如,灰顯或隱藏部分 UI——直到您確認模型已被下載。

迅速

let conditions = ModelDownloadConditions(allowsCellularAccess: false)
ModelDownloader.modelDownloader()
    .getModel(name: "your_model",
              downloadType: .localModelUpdateInBackground,
              conditions: conditions) { result in
        switch (result) {
        case .success(let customModel):
            do {
                // Download complete. Depending on your app, you could enable the ML
                // feature, or switch from the local model to the remote model, etc.

                // The CustomModel object contains the local path of the model file,
                // which you can use to instantiate a TensorFlow Lite interpreter.
                let interpreter = try Interpreter(modelPath: customModel.path)
            } catch {
                // Error. Bad model file?
            }
        case .failure(let error):
            // Download was unsuccessful. Don't enable ML features.
            print(error)
        }
}

許多應用程序在其初始化代碼中啟動下載任務,但您可以在需要使用模型之前隨時執行此操作。

3. 對輸入數據進行推理

獲取模型的輸入和輸出形狀

TensorFlow Lite 模型解釋器將一個或多個多維數組作為輸入並作為輸出生成。這些陣列包含任一byteintlong ,或float值。在將數據傳遞給模型或使用其結果之前,您必須知道模型使用的數組的數量和維度(“形狀”)。

如果您自己構建了模型,或者模型的輸入和輸出格式已記錄在案,則您可能已經擁有這些信息。如果您不知道模型輸入和輸出的形狀和數據類型,您可以使用 TensorFlow Lite 解釋器來檢查您的模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

示例輸出:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

運行解釋器

在確定模型輸入和輸出的格式後,獲取輸入數據並對數據執行任何必要的轉換,以獲取模型的正確形狀的輸入。

例如,如果模型處理圖像,和模型具有的輸入維數[1, 224, 224, 3]浮點值,則可能必須對圖像的顏色值擴展到浮點範圍如下面的示例:

迅速

let image: CGImage = // Your input image
guard let context = CGContext(
  data: nil,
  width: image.width, height: image.height,
  bitsPerComponent: 8, bytesPerRow: image.width * 4,
  space: CGColorSpaceCreateDeviceRGB(),
  bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
  return false
}

context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }

var inputData = Data()
for row in 0 ..&lt; 224 {
  for col in 0 ..&lt; 224 {
    let offset = 4 * (row * context.width + col)
    // (Ignore offset 0, the unused alpha channel)
    let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
    let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
    let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)

    // Normalize channel values to [0.0, 1.0]. This requirement varies
    // by model. For example, some models might require values to be
    // normalized to the range [-1.0, 1.0] instead, and others might
    // require fixed-point values or the original bytes.
    var normalizedRed = Float32(red) / 255.0
    var normalizedGreen = Float32(green) / 255.0
    var normalizedBlue = Float32(blue) / 255.0

    // Append normalized values to Data object in RGB order.
    let elementSize = MemoryLayout.size(ofValue: normalizedRed)
    var bytes = [UInt8](repeating: 0, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedRed, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedGreen, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&ammp;bytes, &amp;normalizedBlue, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
  }
}

然後,您輸入複製NSData來解釋並運行它:

迅速

try interpreter.allocateTensors()
try interpreter.copy(inputData, toInputAt: 0)
try interpreter.invoke()

你可以通過調用解釋者獲得該模型的輸出中output(at:)方法。您如何使用輸出取決於您使用的模型。

例如,如果您正在執行分類,作為下一步,您可以將結果的索引映射到它們代表的標籤:

迅速

let output = try interpreter.output(at: 0)
let probabilities =
        UnsafeMutableBufferPointer<Float32>.allocate(capacity: 1000)
output.data.copyBytes(to: probabilities)

guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }

for i in labels.indices {
    print("\(labels[i]): \(probabilities[i])")
}

附錄:模型安全

無論您如何將 TensorFlow Lite 模型提供給 Firebase ML,Firebase ML 都會將它們以標準序列化 protobuf 格式存儲在本地存儲中。

理論上,這意味著任何人都可以復制您的模型。然而,在實踐中,大多數模型都是特定於應用程序的,並且被優化混淆,以至於其風險與競爭對手反彙編和重用您的代碼的風險相似。不過,在您的應用程序中使用自定義模型之前,您應該意識到這種風險。