Google is committed to advancing racial equity for Black communities. See how.
本頁面由 Cloud Translation API 翻譯而成。
Switch to English

使用TensorFlow Lite模型在iOS上通過ML Kit進行推理

{%setvar ml_link%} 最新版本 {%endsetvar%}

您可以使用ML Kit通過TensorFlow Lite模型執行設備上的推斷。

ML Kit只能在運行iOS 9及更高版本的設備上使用TensorFlow Lite模型。

在你開始之前

  1. 如果您尚未將Firebase添加到您的應用程序,請按照入門指南中的步驟進行操作。
  2. 在您的Podfile中包括ML Kit庫:
    pod 'Firebase/MLModelInterpreter', '6.25.0'
    
    安裝或更新項目的Pod之後,請確保使用其.xcworkspace打開Xcode項目。
  3. 在您的應用中,導入Firebase:

    迅速

    import Firebase

    目標C

    @import Firebase;
  4. 將您要使用的TensorFlow模型轉換為TensorFlow Lite格式。請參閱TOCO:TensorFlow Lite優化轉換器

託管或捆綁您的模型

在將TensorFlow Lite模型用於應用程序推理之前,您必須使該模型可用於ML Kit。 ML Kit可以使用通過Firebase遠程託管的TensorFlow Lite模型,與應用程序二進製文件捆綁在一起的版本,或同時使用兩者。

通過在Firebase上託管模型,您可以在不發布新應用版本的情況下更新模型,並且可以使用“遠程配置”和“ A / B測試”為不同的用戶組動態提供不同的模型。

如果您選擇僅通過將模型託管在Firebase中來提供模型,而不將其與應用程序捆綁在一起,則可以減小應用程序的初始下載大小。但是請記住,如果模型未與您的應用程序捆綁在一起,則在您的應用程序首次下載模型之前,與模型相關的任何功能將不可用。

通過將模型與應用程序捆綁在一起,可以確保當Firebase託管的模型不可用時,應用程序的ML功能仍然可以使用。

Firebase上的主機模型

要將TensorFlow Lite模型託管在Firebase上:

  1. Firebase控制台ML Kit部分中,單擊“ 自定義”選項卡。
  2. 點擊添加自定義模型 (或添加其他模型 )。
  3. 指定將用於在Firebase項目中標識模型的名稱,然後上傳TensorFlow Lite模型文件(通常以.tflite.lite )。

將自定義模型添加到Firebase項目後,您可以使用指定的名稱在應用程序中引用該模型。您隨時可以上傳新的TensorFlow Lite模型,您的應用將下載新模型並在下次重新啟動時開始使用它。您可以定義應用程序嘗試更新模型所需的設備條件(請參見下文)。

將模型與應用捆綁

要將TensorFlow Lite模型與您的應用程序捆綁在一起,請將模型文件(通常以.tflite.lite )添加到Xcode項目中,請注意選擇複製捆綁包資源 。模型文件將包含在應用程序捆綁包中,並且可用於ML Kit。

加載模型

要在您的應用程序中使用TensorFlow Lite模型,請先使用模型可用的位置配置ML Kit:使用Firebase遠程,在本地存儲中或在這兩者中。如果同時指定了本地模型和遠程模型,則可以使用遠程模型(如果可用),如果遠程模型不可用,則可以使用本地模型。

配置Firebase託管的模型

如果您使用Firebase託管模型,請創建CustomRemoteModel對象,並在發布模型時指定分配給模型的名稱:

迅速

 let remoteModel = CustomRemoteModel(
  name: "your_remote_model"  // The name you assigned in the Firebase console.
)
 

目標C

 // Initialize using the name you assigned in the Firebase console.
FIRCustomRemoteModel *remoteModel =
    [[FIRCustomRemoteModel alloc] initWithName:@"your_remote_model"];
 

然後,啟動模型下載任務,指定要允許下載的條件。如果模型不在設備上,或者有較新版本的模型可用,則該任務將從Firebase異步下載模型:

迅速

 let downloadConditions = ModelDownloadConditions(
  allowsCellularAccess: true,
  allowsBackgroundDownloading: true
)

let downloadProgress = ModelManager.modelManager().download(
  remoteModel,
  conditions: downloadConditions
)
 

目標C

 FIRModelDownloadConditions *downloadConditions =
    [[FIRModelDownloadConditions alloc] initWithAllowsCellularAccess:YES
                                         allowsBackgroundDownloading:YES];

NSProgress *downloadProgress =
    [[FIRModelManager modelManager] downloadRemoteModel:remoteModel
                                             conditions:downloadConditions];
 

許多應用程序都使用其初始化代碼啟動下載任務,但是您可以在需要使用模型之前隨時進行下載。

配置本地模型

如果您將模型與應用程序捆綁在一起,請創建一個CustomLocalModel對象,並指定TensorFlow Lite模型的文件名:

迅速

 guard let modelPath = Bundle.main.path(
  forResource: "your_model",
  ofType: "tflite",
  inDirectory: "your_model_directory"
) else { /* Handle error. */ }
let localModel = CustomLocalModel(modelPath: modelPath)
 

目標C

 NSString *modelPath = [NSBundle.mainBundle pathForResource:@"your_model"
                                                    ofType:@"tflite"
                                               inDirectory:@"your_model_directory"];
FIRCustomLocalModel *localModel =
    [[FIRCustomLocalModel alloc] initWithModelPath:modelPath];
 

根據模型創建解釋器

配置模型源之後,請從其中一個創建ModelInterpreter對象。

如果您只有本地捆綁的模型,只需將CustomLocalModelmodelInterpreter(localModel:)傳遞給modelInterpreter(localModel:)

迅速

 let interpreter = ModelInterpreter.modelInterpreter(localModel: localModel)
 

目標C

 FIRModelInterpreter *interpreter =
    [FIRModelInterpreter modelInterpreterForLocalModel:localModel];
 

如果您有一個遠程託管的模型,則必須在運行它之前檢查它是否已下載。您可以使用模型管理器的isModelDownloaded(remoteModel:)方法檢查模型下載任務的狀態。

儘管您只需要在運行解釋器之前進行確認,但是如果您同時具有遠程託管的模型和本地捆綁的模型,則在實例化ModelInterpreter時執行此檢查可能是有意義的:如果是遠程模型,則從遠程模型創建解釋器已下載,否則從本地模型下載。

迅速

 var interpreter: ModelInterpreter
if ModelManager.modelManager().isModelDownloaded(remoteModel) {
  interpreter = ModelInterpreter.modelInterpreter(remoteModel: remoteModel)
} else {
  interpreter = ModelInterpreter.modelInterpreter(localModel: localModel)
}
 

目標C

 FIRModelInterpreter *interpreter;
if ([[FIRModelManager modelManager] isModelDownloaded:remoteModel]) {
  interpreter = [FIRModelInterpreter modelInterpreterForRemoteModel:remoteModel];
} else {
  interpreter = [FIRModelInterpreter modelInterpreterForLocalModel:localModel];
}
 

如果只有遠程託管的模型,則應禁用與模型相關的功能(例如,變灰或隱藏UI的一部分),直到確認已下載模型為止。

您可以通過將觀察者附加到默認的通知中心來獲取模型下載狀態。確保在觀察者塊中使用對self的弱引用,因為下載可能會花費一些時間,並且在下載完成時可以釋放原始對象。例如:

迅速

NotificationCenter.default.addObserver(
    forName: .firebaseMLModelDownloadDidSucceed,
    object: nil,
    queue: nil
) { [weak self] notification in
    guard let strongSelf = self,
        let userInfo = notification.userInfo,
        let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue]
            as? RemoteModel,
        model.name == "your_remote_model"
        else { return }
    // The model was downloaded and is available on the device
}

NotificationCenter.default.addObserver(
    forName: .firebaseMLModelDownloadDidFail,
    object: nil,
    queue: nil
) { [weak self] notification in
    guard let strongSelf = self,
        let userInfo = notification.userInfo,
        let model = userInfo[ModelDownloadUserInfoKey.remoteModel.rawValue]
            as? RemoteModel
        else { return }
    let error = userInfo[ModelDownloadUserInfoKey.error.rawValue]
    // ...
}

目標C

__weak typeof(self) weakSelf = self;

[NSNotificationCenter.defaultCenter
    addObserverForName:FIRModelDownloadDidSucceedNotification
                object:nil
                 queue:nil
            usingBlock:^(NSNotification *_Nonnull note) {
              if (weakSelf == nil | note.userInfo == nil) {
                return;
              }
              __strong typeof(self) strongSelf = weakSelf;

              FIRRemoteModel *model = note.userInfo[FIRModelDownloadUserInfoKeyRemoteModel];
              if ([model.name isEqualToString:@"your_remote_model"]) {
                // The model was downloaded and is available on the device
              }
            }];

[NSNotificationCenter.defaultCenter
    addObserverForName:FIRModelDownloadDidFailNotification
                object:nil
                 queue:nil
            usingBlock:^(NSNotification *_Nonnull note) {
              if (weakSelf == nil | note.userInfo == nil) {
                return;
              }
              __strong typeof(self) strongSelf = weakSelf;

              NSError *error = note.userInfo[FIRModelDownloadUserInfoKeyError];
            }];

指定模型的輸入和輸出

接下來,配置模型解釋器的輸入和輸出格式。

TensorFlow Lite模型將輸入作為輸入,並產生一個或多個多維數組作為輸出。這些數組包含byteintlongfloat值。您必須使用模型使用的陣列的數量和尺寸(“形狀”)來配置ML Kit。

如果您不知道模型輸入和輸出的形狀和數據類型,則可以使用TensorFlow Lite Python解釋器檢查模型。例如:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

確定模型輸入和輸出的格式之後,通過創建ModelInputOutputOptions對象來配置應用程序的模型解釋器。

例如,浮點圖像分類模型可能需要作為輸入N的x224x224x3陣列Float值,代表一批N 224x224三通道(RGB)圖像,並且產生作為輸出的千列表Float值,每個值代表圖像是模型預測的1000個類別之一的概率。

對於這樣的模型,您將配置模型解釋器的輸入和輸出,如下所示:

迅速

let ioOptions = ModelInputOutputOptions()
do {
    try ioOptions.setInputFormat(index: 0, type: .float32, dimensions: [1, 224, 224, 3])
    try ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: [1, 1000])
} catch let error as NSError {
    print("Failed to set input or output format with error: \(error.localizedDescription)")
}

目標C

FIRModelInputOutputOptions *ioOptions = [[FIRModelInputOutputOptions alloc] init];
NSError *error;
[ioOptions setInputFormatForIndex:0
                             type:FIRModelElementTypeFloat32
                       dimensions:@[@1, @224, @224, @3]
                            error:&error];
if (error != nil) { return; }
[ioOptions setOutputFormatForIndex:0
                              type:FIRModelElementTypeFloat32
                        dimensions:@[@1, @1000]
                             error:&error];
if (error != nil) { return; }

對輸入數據進行推斷

最後,要使用模型進行推理,請獲取輸入數據,對模型可能需要的數據進行任何轉換,並構建一個包含數據的Data對象。

例如,如果模型處理圖像,並且模型的輸入尺寸為[BATCH_SIZE, 224, 224, 3]浮點值,則可能必須將圖像的顏色值縮放到浮點範圍,如以下示例所示:

迅速

let image: CGImage = // Your input image
guard let context = CGContext(
  data: nil,
  width: image.width, height: image.height,
  bitsPerComponent: 8, bytesPerRow: image.width * 4,
  space: CGColorSpaceCreateDeviceRGB(),
  bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
  return false
}

context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return false }

let inputs = ModelInputs()
var inputData = Data()
do {
  for row in 0 ..< 224 {
    for col in 0 ..< 224 {
      let offset = 4 * (col * context.width + row)
      // (Ignore offset 0, the unused alpha channel)
      let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
      let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
      let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)

      // Normalize channel values to [0.0, 1.0]. This requirement varies
      // by model. For example, some models might require values to be
      // normalized to the range [-1.0, 1.0] instead, and others might
      // require fixed-point values or the original bytes.
      var normalizedRed = Float32(red) / 255.0
      var normalizedGreen = Float32(green) / 255.0
      var normalizedBlue = Float32(blue) / 255.0

      // Append normalized values to Data object in RGB order.
      let elementSize = MemoryLayout.size(ofValue: normalizedRed)
      var bytes = [UInt8](repeating: 0, count: elementSize)
      memcpy(&bytes, &normalizedRed, elementSize)
      inputData.append(&bytes, count: elementSize)
      memcpy(&bytes, &normalizedGreen, elementSize)
      inputData.append(&bytes, count: elementSize)
      memcpy(&ammp;bytes, &normalizedBlue, elementSize)
      inputData.append(&bytes, count: elementSize)
    }
  }
  try inputs.addInput(inputData)
} catch let error {
  print("Failed to add input: \(error)")
}

目標C

CGImageRef image = // Your input image
long imageWidth = CGImageGetWidth(image);
long imageHeight = CGImageGetHeight(image);
CGContextRef context = CGBitmapContextCreate(nil,
                                             imageWidth, imageHeight,
                                             8,
                                             imageWidth * 4,
                                             CGColorSpaceCreateDeviceRGB(),
                                             kCGImageAlphaNoneSkipFirst);
CGContextDrawImage(context, CGRectMake(0, 0, imageWidth, imageHeight), image);
UInt8 *imageData = CGBitmapContextGetData(context);

FIRModelInputs *inputs = [[FIRModelInputs alloc] init];
NSMutableData *inputData = [[NSMutableData alloc] initWithCapacity:0];

for (int row = 0; row < 224; row++) {
  for (int col = 0; col < 224; col++) {
    long offset = 4 * (col * imageWidth + row);
    // Normalize channel values to [0.0, 1.0]. This requirement varies
    // by model. For example, some models might require values to be
    // normalized to the range [-1.0, 1.0] instead, and others might
    // require fixed-point values or the original bytes.
    // (Ignore offset 0, the unused alpha channel)
    Float32 red = imageData[offset+1] / 255.0f;
    Float32 green = imageData[offset+2] / 255.0f;
    Float32 blue = imageData[offset+3] / 255.0f;

    [inputData appendBytes:&red length:sizeof(red)];
    [inputData appendBytes:&green length:sizeof(green)];
    [inputData appendBytes:&blue length:sizeof(blue)];
  }
}

[inputs addInput:inputData error:&error];
if (error != nil) { return nil; }

準備好模型輸入後(並確認模型可用),將輸入和輸入/輸出選項傳遞給模型解釋器run(inputs:options:completion:)方法。

迅速

interpreter.run(inputs: inputs, options: ioOptions) { outputs, error in
    guard error == nil, let outputs = outputs else { return }
    // Process outputs
    // ...
}

目標C

[interpreter runWithInputs:inputs
                   options:ioOptions
                completion:^(FIRModelOutputs * _Nullable outputs,
                             NSError * _Nullable error) {
  if (error != nil || outputs == nil) {
    return;
  }
  // Process outputs
  // ...
}];

您可以通過調用返回對象的output(index:)方法獲取輸出。例如:

迅速

// Get first and only output of inference with a batch size of 1
let output = try? outputs.output(index: 0) as? [[NSNumber]]
let probabilities = output??[0]

目標C

// Get first and only output of inference with a batch size of 1
NSError *outputError;
NSArray *probabilites = [outputs outputAtIndex:0 error:&outputError][0];

您如何使用輸出取決於您使用的模型。

例如,如果要執行分類,則下一步,您可以將結果的索引映射到它們表示的標籤。假設您有一個文本文件,其中包含每個模型類別的標籤字符串;您可以通過執行以下操作將標籤字符串映射到輸出概率:

迅速

guard let labelPath = Bundle.main.path(forResource: "retrained_labels", ofType: "txt") else { return }
let fileContents = try? String(contentsOfFile: labelPath)
guard let labels = fileContents?.components(separatedBy: "\n") else { return }

for i in 0 ..< labels.count {
  if let probability = probabilities?[i] {
    print("\(labels[i]): \(probability)")
  }
}

目標C

NSError *labelReadError = nil;
NSString *labelPath = [NSBundle.mainBundle pathForResource:@"retrained_labels"
                                                    ofType:@"txt"];
NSString *fileContents = [NSString stringWithContentsOfFile:labelPath
                                                   encoding:NSUTF8StringEncoding
                                                      error:&labelReadError];
if (labelReadError != nil || fileContents == NULL) { return; }
NSArray<NSString *> *labels = [fileContents componentsSeparatedByString:@"\n"];
for (int i = 0; i < labels.count; i++) {
    NSString *label = labels[i];
    NSNumber *probability = probabilites[i];
    NSLog(@"%@: %f", label, probability.floatValue);
}

附錄:模型安全

無論您如何將TensorFlow Lite模型提供給ML Kit,ML Kit均以標準序列化protobuf格式將它們存儲在本地存儲中。

從理論上講,這意味著任何人都可以復制您的模型。但是,實際上,大多數模型都是特定於應用程序的,並且由於優化而模糊不清,以致於風險與競爭對手拆卸和重用代碼的風險相似。但是,在應用程序中使用自定義模型之前,您應該意識到這種風險。