ML Kit を使用して推論に TensorFlow Lite モデルを使用する(Android)

ML Kit を使用すると、TensorFlow Lite モデルを使用してデバイス上で推論を実行できます。

この API を使用するには、Android SDK レベル 16(Jelly Bean)以上が必要です。

始める前に

  1. まだ Firebase を Android プロジェクトに追加していない場合は追加します。
  2. ML Kit Android ライブラリの依存関係をモジュール(アプリレベル)の Gradle ファイル(通常は app/build.gradle)に追加します:
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. 使用する TensorFlow モデルを TensorFlow Lite 形式に変換します。TOCO: TensorFlow Lite 最適化コンバータをご覧ください。

モデルをホストまたはバンドルする

アプリの推論に TensorFlow Lite モデルを使用するには、事前に ML Kit でモデルを利用可能にしておく必要があります。ML Kit は、Firebase を使用してリモートでホストされている TensorFlow Lite モデル、アプリのバイナリにバンドルされている TensorFlow Lite モデルのいずれか、またはその両方を使用できます。

Firebase でモデルをホストすることで、新しいアプリ バージョンをリリースすることなくモデルを更新できます。また、Remote ConfigA/B Testing を使用して、さまざまなモデルをさまざまなユーザーセットに動的に提供できます。

モデルをアプリにバンドルしないで、Firebase でホストすることによってのみモデルを提供することで、アプリの初期ダウンロード サイズを小さくできます。ただし、モデルがアプリにバンドルされていない場合、モデルに関連する機能は、アプリでモデルを初めてダウンロードするまで使用できません。

モデルをアプリにバンドルすると、Firebase でホストされているモデルを取得できないときにもアプリの ML 機能を引き続き使用できます。

モデルを Firebase でホストする

TensorFlow Lite モデルを Firebase でホストするには:

  1. Firebase コンソールの [ML Kit] セクションで [カスタム] タブをクリックします。
  2. [カスタムモデルを追加](または [別のモデルを追加])をクリックします。
  3. Firebase プロジェクトでモデルを識別するための名前を指定し、TensorFlow Lite モデルファイル(拡張子は通常 .tflite または .lite)をアップロードします。
  4. アプリのマニフェストで、INTERNET 権限が必要であることを宣言します:
    <uses-permission android:name="android.permission.INTERNET" />
    

Firebase プロジェクトにカスタムモデルを追加した後は、指定した名前を使用してアプリ内でモデルを参照できます。新しい TensorFlow Lite モデルはいつでもアップロードできます。新しいモデルは、次回アプリが起動したときにダウンロードされて使用されます。アプリがモデルを更新するために必要なデバイスの条件を定義できます(以下を参照)。

モデルをアプリにバンドルする

TensorFlow Lite モデルをアプリにバンドルするには、モデルファイル(拡張子は通常 .tflite または .lite)をアプリの assets/ フォルダにコピーします(先にこのフォルダを作成する必要がある場合があります。作成するには app/ フォルダを右クリックし、[新規] > [フォルダ] > Assets フォルダをクリックします)。

次に、アプリの build.gradle ファイルに次の行を追加します。これにより、アプリのビルド時にモデルが圧縮されなくなります。

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

モデルファイルはアプリ パッケージに含められ、ML Kit から生のアセットとして使用できます。

モデルを読み込む

TensorFlow Lite モデルをアプリで使用するには、まずモデルが利用可能なロケーション(Firebase を使用したリモート、ローカル ストレージ、またはその両方)で ML Kit を構成します。モデルとしてローカルとリモートの両方を指定すると、使用可能な場合はリモートモデルが使用され、リモートモデルが使用できない場合はローカルに保存されているモデルにフォールバックします。

Firebase ホストモデルを構成する

Firebase でモデルをホストする場合は、FirebaseCustomRemoteModel オブジェクトを作成します。その際に、モデルをアップロードしたときに割り当てた名前を指定します。

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

次に、ダウンロードを許可する条件を指定してモデルのダウンロード タスクを開始します。モデルがデバイスにない場合、または新しいバージョンのモデルが使用可能な場合、このタスクは Firebase から非同期でモデルをダウンロードします。

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

多くのアプリは、初期化コードでモデルのダウンロード タスクを開始しますが、モデルを使用する前に開始することもできます。

ローカルモデルを構成する

アプリにモデルをバンドルする場合は、TensorFlow Lite モデルのファイル名を指定して FirebaseCustomLocalModel オブジェクトを作成します。

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

モデルからインタープリタを作成する

モデルソースを構成したら、そのソースのいずれか 1 つから FirebaseModelInterpreter オブジェクトを作成します。

ローカル バンドルモデルのみがある場合は FirebaseCustomLocalModel オブジェクトからインタープリタを作成するだけで済みます。

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

リモートでホストされるモデルがある場合は、そのモデルを実行する前にダウンロード済みであることを確認する必要があります。モデルのダウンロード タスクのステータスは、モデル マネージャーの isModelDownloaded() メソッドを使用して確認できます。

ダウンロードのステータスはインタープリタを実行する前に確認するだけで済みますが、リモートでホストされるモデルとローカル バンドルモデルの両方がある場合は、モデル インタープリタをインスタンス化する、つまりインタープリタを作成する(リモートモデルをダウンロード済みの場合はリモートモデルから、ダウンロードされていない場合はローカルモデルから作成する)ときに確認しても問題ありません。

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

リモートでホストされるモデルのみがある場合は、モデルがダウンロード済みであることを確認するまで、モデルに関連する機能を無効にする必要があります(UI の一部をグレー表示または非表示にするなど)。確認はモデル マネージャーの download() メソッドにリスナーを接続して行います。

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

モデルの入力と出力を指定する

次に、モデル インタープリタの入出力形式を構成します。

TensorFlow Lite モデルは、1 つ以上の多次元配列を入力として受け取り、出力として生成します。これらの配列には、byteintlongfloat 値のいずれかが含まれます。モデルで使用する配列の数と次元(「シェイプ」)で ML キットを構成する必要があります。

モデルの入出力のシェイプとデータ型がわからない場合は、TensorFlow Lite Python インタープリタを使用してモデルを検査できます。次に例を示します。

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

モデルの入出力の形式がわかったら、FirebaseModelInputOutputOptions オブジェクトを作成してアプリのモデル インタープリタを構成できます。

たとえば、浮動小数点画像分類モデルは、N 個の 224 x 224 x 3 チャネル(RGB)画像のまとまりを表す N x 224 x 224 x 3 の float 値の配列を入力として受け取り、1,000 個の float 値のリストを出力として生成します。このリストの値はそれぞれ、対象の画像が、モデルによって予測される 1,000 個のカテゴリのいずれか 1 つのメンバーである確率を表します。

このようなモデルの場合は、モデル インタープリタの入力と出力を次のように構成します。

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

入力データの推論を行う

最後に、モデルを使用して推論を行うため、入力データを取得して必要な変換を実行し、モデルに適したシェイプの入力配列を取得します。

たとえば、使用する画像分類モデルの入力シェイプが [1 224 224 3] 個の浮動小数点値である場合は、次の例に示すように、Bitmap オブジェクトから入力配列を生成できます。

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

次に、入力データを使用して FirebaseModelInputs オブジェクトを作成し、そのオブジェクトとモデルの入出力指定をモデル インタープリタrun メソッドに渡します。

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

呼び出しが成功した場合は、成功リスナーに渡されたオブジェクトの getOutput() メソッドを呼び出すことで出力を取得できます。例:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

出力をどのように使用するかは、使用しているモデルによって異なります。

たとえば、分類を行う場合は、次のステップとして、結果のインデックスをそれぞれが表すラベルにマッピングできます。

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

付録: モデルのセキュリティ

TensorFlow Lite モデルをどの方法で ML Kit に提供するかにかかわらず、これらのモデルは標準のシリアル化された protobuf 形式でローカル ストレージに保存されます。

理論上、これは誰でもモデルをコピーできることを意味します。ただし、実際には、ほとんどのモデルはアプリケーションに固有であり、最適化により難読化されています。このため、リスクは、競合他社がコードを逆アセンブルして再利用する場合と同程度です。そうであっても、アプリでカスタムモデルを使用する前に、このリスクを認識しておく必要があります。

Android API レベル 21(Lollipop)以降では、モデルは自動バックアップから除外されるディレクトリにダウンロードされます。

Android API レベル 20 以前では、モデルはアプリ専用の内部ストレージ内の com.google.firebase.ml.custom.models というディレクトリにダウンロードされます。BackupAgent を使用したファイルのバックアップを有効にした場合は、このディレクトリを除外できます。