Save the date - Google I/O returns May 18-20. Register to get the most out of the digital experience: Build your schedule, reserve space, participate in Q&As, earn Google Developer profile badges, and more. Register now
本頁面由 Cloud Translation API 翻譯而成。
Switch to English

在iOS上使用自定義TensorFlow Lite模型

如果您的應用程序使用自定義TensorFlow Lite模型,則可以使用Firebase ML部署模型。通過使用Firebase部署模型,您可以減少應用程序的初始下載大小並更新應用程序的ML模型,而無需發布新版本的應用程序。而且,通過遠程配置和A / B測試,您可以為不同的用戶組動態提供不同的模型。

先決條件

  • MLModelDownloader庫僅可用於Swift。
  • TensorFlow Lite僅在使用iOS 9及更高版本的設備上運行。

TensorFlow Lite模型

TensorFlow Lite模型是經過優化可在移動設備上運行的ML模型。要獲取TensorFlow Lite模型:

在你開始之前

  1. 如果您尚未將Firebase添加到您的應用程序,請按照入門指南中的步驟進行操作。
  2. 在您的Podfile中包括Firebase:

    迅速

    第0993章
    安裝或更新項目的Pod之後,請確保使用其.xcworkspace打開Xcode項目。
  3. 在您的應用中,導入Firebase:

    迅速

    import Firebase
    import TensorFlowLite
    

1.部署模型

使用Firebase控制台或Firebase Admin Python和Node.js SDK部署自定義TensorFlow模型。請參閱部署和管理定制模型

將自定義模型添加到Firebase項目後,可以使用您指定的名稱在應用程序中引用該模型。您隨時可以部署新的TensorFlow Lite模型,並通過調用getModel()將新模型下載到用戶的設備上(請參見下文)。

2.將模型下載到設備並初始化TensorFlow Lite解釋器

要在您的應用中使用TensorFlow Lite模型,請首先使用Firebase ML SDK將模型的最新版本下載到設備上。

要開始下載模型,請調用模型下載器的getModel()方法,指定在上載模型時為模型分配的名稱,是否要始終下載最新模型以及允許下載的條件。

您可以從三種下載行為中進行選擇:

下載類型描述
localModel從設備獲取本地模型。如果沒有可用的本地模型,則其行為類似於latestModel 。如果您對檢查模型更新不感興趣,請使用此下載類型。例如,您正在使用“遠程配置”來檢索模型名稱,並且始終以新名稱上載模型(推薦)。
localModelUpdateInBackground從設備獲取本地模型,然後開始在後台更新模型。如果沒有可用的本地模型,則其行為類似於latestModel
latestModel獲取最新型號。如果本地模型是最新版本,則返回本地模型。否則,請下載最新型號。在下載最新版本之前,此行為將一直阻止(不推薦)。僅在明確需要最新版本的情況下,才使用此行為。

您應禁用與模型相關的功能,例如灰色或隱藏UI的一部分,直到您確認模型已下載。

迅速

let conditions = ModelDownloadConditions(allowsCellularAccess: false)
ModelDownloader.modelDownloader()
    .getModel(name: "your_model",
              downloadType: .localModelUpdateInBackground,
              conditions: conditions) { result in
        switch (result) {
        case .success(let customModel):
            do {
                // Download complete. Depending on your app, you could enable the ML
                // feature, or switch from the local model to the remote model, etc.

                // The CustomModel object contains the local path of the model file,
                // which you can use to instantiate a TensorFlow Lite interpreter.
                let interpreter = try Interpreter(modelPath: customModel.path)
            } catch {
                // Error. Bad model file?
            }
        case .failure(let error):
            // Download was unsuccessful. Don't enable ML features.
            print(error)
        }
}

許多應用程序都使用其初始化代碼啟動下載任務,但是您可以在需要使用模型之前隨時進行下載。

3.對輸入數據進行推斷

獲取模型的輸入和輸出形狀

TensorFlow Lite模型解釋器將輸入或輸出一個或多個多維數組作為輸出。這些數組包含byteintlongfloat值。在將數據傳遞到模型或使用其結果之前,必須知道模型使用的數組的數量和尺寸(“形狀”)。

如果您自己構建模型,或者記錄了模型的輸入和輸出格式,則可能已經有了此信息。如果您不知道模型輸入和輸出的形狀和數據類型,則可以使用TensorFlow Lite解釋器檢查模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

輸出示例:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

運行口譯員

確定模型輸入和輸出的格式之後,獲取輸入數據並對數據進行任何必要的轉換,以獲取適合模型的正確形狀的輸入。

例如,如果模型處理圖像,並且模型的輸入尺寸為[1, 224, 224, 3]浮點值,則可能需要將圖像的顏色值縮放到浮點範圍,如以下示例所示:

迅速

let image: CGImage = // Your input image
guard let context = CGContext(
  data: nil,
  width: image.width, height: image.height,
  bitsPerComponent: 8, bytesPerRow: image.width * 4,
  space: CGColorSpaceCreateDeviceRGB(),
  bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
  return false
}

context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }

var inputData = Data()
for row in 0 ..&lt; 224 {
  for col in 0 ..&lt; 224 {
    let offset = 4 * (row * context.width + col)
    // (Ignore offset 0, the unused alpha channel)
    let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
    let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
    let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)

    // Normalize channel values to [0.0, 1.0]. This requirement varies
    // by model. For example, some models might require values to be
    // normalized to the range [-1.0, 1.0] instead, and others might
    // require fixed-point values or the original bytes.
    var normalizedRed = Float32(red) / 255.0
    var normalizedGreen = Float32(green) / 255.0
    var normalizedBlue = Float32(blue) / 255.0

    // Append normalized values to Data object in RGB order.
    let elementSize = MemoryLayout.size(ofValue: normalizedRed)
    var bytes = [UInt8](repeating: 0, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedRed, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedGreen, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&ammp;bytes, &amp;normalizedBlue, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
  }
}

然後,將輸入的NSData複製到解釋器並運行它:

迅速

try interpreter.allocateTensors()
try interpreter.copy(inputData, toInputAt: 0)
try interpreter.invoke()

您可以通過調用解釋器的output(at:)方法來獲取模型的輸出。您如何使用輸出取決於您使用的模型。

例如,如果要執行分類,那麼下一步,您可以將結果的索引映射到它們代表的標籤上:

迅速

let output = try interpreter.output(at: 0)
let probabilities =
        UnsafeMutableBufferPointer<Float32>.allocate(capacity: 1000)
output.data.copyBytes(to: probabilities)

guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }

for i in labels.indices {
    print("\(labels[i]): \(probabilities[i])")
}

附錄:模型安全

無論您如何使TensorFlow Lite模型可用於Firebase ML,Firebase ML都將它們以標準的序列化protobuf格式存儲在本地存儲中。

從理論上講,這意味著任何人都可以復制您的模型。但是,實際上,大多數模型都是特定於應用程序的,並且由於優化而模糊不清,以至於風險與競爭者拆卸和重用您的代碼相似。但是,在應用程序中使用自定義模型之前,您應該意識到這種風險。