Google 致力于为黑人社区推动种族平等。查看具体举措

在 iOS 上使用自定义 TensorFlow Lite 模型

如果应用程序使用自定义的TensorFlow精简版机型,您可以用火力地堡ML部署您的机型。通过使用 Firebase 部署模型,您可以减少应用的初始下载大小并更新应用的机器学习模型,而无需发布应用的新版本。而且,通过远程配置和 A/B 测试,您可以为不同的用户组动态地提供不同的模型。

先决条件

  • MLModelDownloader库仅适用于斯威夫特。
  • TensorFlow Lite 只能在使用 iOS 9 及更新版本的设备上运行。

TensorFlow Lite 模型

TensorFlow Lite 模型是经过优化以在移动设备上运行的 ML 模型。要获取 TensorFlow Lite 模型:

在你开始之前

  1. 如果您尚未添加火力地堡到您的应用程序,通过遵循的步骤做这样的入门指南
  2. 在 Podfile 中包含 Firebase:

    迅速

    pod 'Firebase/MLModelDownloader'
    pod 'TensorFlowLiteSwift'
    
    您安装或更新项目的吊舱后,务必使用它来打开你的Xcode项目.xcworkspace
  3. 在您的应用中,导入 Firebase:

    迅速

    import Firebase
    import TensorFlowLite
    

1. 部署您的模型

使用 Firebase 控制台或 Firebase Admin Python 和 Node.js SDK 部署您的自定义 TensorFlow 模型。请参阅部署和管理定制机型

将自定义模型添加到 Firebase 项目后,您可以使用指定的名称在应用中引用该模型。在任何时候,你可以部署一个新的TensorFlow精简版模型,并通过调用下载的新模式到用户的设备getModel()见下文)。

2. 将模型下载到设备并初始化一个 TensorFlow Lite 解释器

要在您的应用中使用 TensorFlow Lite 模型,请首先使用 Firebase ML SDK 将模型的最新版本下载到设备。

要启动模型下载,调用模型下载的getModel()方法,指定你分配模式,当你上传它,你是否要随时下载最新的模型,以及在何种条件下要允许下载的名称。

您可以从三种下载行为中进行选择:

下载类型描述
localModel从设备获取本地模型。如果没有可用的本地模式,这种行为就像latestModel 。如果您对检查模型更新不感兴趣,请使用此下载类型。例如,您使用远程配置来检索模型名称,并且您总是以新名称(推荐)上传模型。
localModelUpdateInBackground从设备获取本地模型并在后台开始更新模型。如果没有可用的本地模式,这种行为就像latestModel
latestModel获取最新型号。如果本地模型是最新版本,则返回本地模型。否则,请下载最新型号。此行为将阻止,直到下载最新版本(不推荐)。仅在您明确需要最新版本的情况下才使用此行为。

您应该禁用与模型相关的功能——例如,灰显或隐藏部分 UI——直到您确认模型已被下载。

迅速

let conditions = ModelDownloadConditions(allowsCellularAccess: false)
ModelDownloader.modelDownloader()
    .getModel(name: "your_model",
              downloadType: .localModelUpdateInBackground,
              conditions: conditions) { result in
        switch (result) {
        case .success(let customModel):
            do {
                // Download complete. Depending on your app, you could enable the ML
                // feature, or switch from the local model to the remote model, etc.

                // The CustomModel object contains the local path of the model file,
                // which you can use to instantiate a TensorFlow Lite interpreter.
                let interpreter = try Interpreter(modelPath: customModel.path)
            } catch {
                // Error. Bad model file?
            }
        case .failure(let error):
            // Download was unsuccessful. Don't enable ML features.
            print(error)
        }
}

许多应用程序在其初始化代码中启动下载任务,但您可以在需要使用模型之前随时执行此操作。

3. 对输入数据进行推理

获取模型的输入和输出形状

TensorFlow Lite 模型解释器将一个或多个多维数组作为输入并作为输出生成。这些阵列包含任一byteintlong ,或float值。在将数据传递给模型或使用其结果之前,您必须知道模型使用的数组的数量和维度(“形状”)。

如果您自己构建模型,或者模型的输入和输出格式已记录,则您可能已经拥有此信息。如果您不知道模型输入和输出的形状和数据类型,您可以使用 TensorFlow Lite 解释器来检查您的模型。例如:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

示例输出:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

运行解释器

在确定模型输入和输出的格式后,获取输入数据并对数据执行任何必要的转换,以获取模型的正确形状的输入。

例如,如果模型处理图像,和模型具有的输入维数[1, 224, 224, 3]浮点值,则可能必须对图像的颜色值扩展到浮点范围如下面的示例:

迅速

let image: CGImage = // Your input image
guard let context = CGContext(
  data: nil,
  width: image.width, height: image.height,
  bitsPerComponent: 8, bytesPerRow: image.width * 4,
  space: CGColorSpaceCreateDeviceRGB(),
  bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
  return false
}

context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }

var inputData = Data()
for row in 0 ..&lt; 224 {
  for col in 0 ..&lt; 224 {
    let offset = 4 * (row * context.width + col)
    // (Ignore offset 0, the unused alpha channel)
    let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
    let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
    let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)

    // Normalize channel values to [0.0, 1.0]. This requirement varies
    // by model. For example, some models might require values to be
    // normalized to the range [-1.0, 1.0] instead, and others might
    // require fixed-point values or the original bytes.
    var normalizedRed = Float32(red) / 255.0
    var normalizedGreen = Float32(green) / 255.0
    var normalizedBlue = Float32(blue) / 255.0

    // Append normalized values to Data object in RGB order.
    let elementSize = MemoryLayout.size(ofValue: normalizedRed)
    var bytes = [UInt8](repeating: 0, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedRed, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&amp;bytes, &amp;normalizedGreen, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
    memcpy(&ammp;bytes, &amp;normalizedBlue, elementSize)
    inputData.append(&amp;bytes, count: elementSize)
  }
}

然后,您输入复制NSData来解释并运行它:

迅速

try interpreter.allocateTensors()
try interpreter.copy(inputData, toInputAt: 0)
try interpreter.invoke()

你可以通过调用解释者获得该模型的输出中output(at:)方法。您如何使用输出取决于您使用的模型。

例如,如果您正在执行分类,作为下一步,您可以将结果的索引映射到它们代表的标签:

迅速

let output = try interpreter.output(at: 0)
let probabilities =
        UnsafeMutableBufferPointer<Float32>.allocate(capacity: 1000)
output.data.copyBytes(to: probabilities)

guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }

for i in labels.indices {
    print("\(labels[i]): \(probabilities[i])")
}

附录:模型安全

无论您如何将 TensorFlow Lite 模型提供给 Firebase ML,Firebase ML 都会将它们以标准序列化 protobuf 格式存储在本地存储中。

理论上,这意味着任何人都可以复制您的模型。然而,在实践中,大多数模型都是特定于应用程序的,并且被优化混淆,以至于其风险与竞争对手反汇编和重用您的代码的风险相似。尽管如此,在您的应用程序中使用自定义模型之前,您应该意识到这种风险。