Используйте пользовательскую модель TensorFlow Lite на Android

Если ваше приложение использует пользовательские модели TensorFlow Lite , вы можете использовать Firebase ML для развертывания своих моделей. Развертывая модели с помощью Firebase, вы можете уменьшить первоначальный размер загружаемого приложения и обновить модели машинного обучения своего приложения, не выпуская новую версию своего приложения. А с помощью удаленной настройки и A/B-тестирования вы можете динамически обслуживать разные модели для разных групп пользователей.

Модели TensorFlow Lite

Модели TensorFlow Lite — это модели машинного обучения, оптимизированные для работы на мобильных устройствах. Чтобы получить модель TensorFlow Lite:

Прежде чем вы начнете

  1. Если вы еще этого не сделали, добавьте Firebase в свой проект Android .
  2. Используя Firebase Android BoM , объявите зависимость для библиотеки Android-загрузчика модели Firebase ML в файле Gradle вашего модуля (на уровне приложения) (обычно app/build.gradle ).

    Кроме того, в рамках настройки загрузчика моделей Firebase ML вам необходимо добавить TensorFlow Lite SDK в свое приложение.

    Java

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:30.2.0')
    
        // Declare the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Используя Firebase Android BoM , ваше приложение всегда будет использовать совместимые версии библиотек Firebase Android.

    (Альтернатива) Объявите зависимости библиотеки Firebase без использования BoM

    Если вы решите не использовать Firebase BoM, вы должны указать каждую версию библиотеки Firebase в строке зависимостей.

    Обратите внимание: если вы используете несколько библиотек Firebase в своем приложении, мы настоятельно рекомендуем использовать BoM для управления версиями библиотек, что гарантирует совместимость всех версий.

    dependencies {
        // Declare the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader:24.0.3'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Kotlin+KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:30.2.0')
    
        // Declare the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Используя Firebase Android BoM , ваше приложение всегда будет использовать совместимые версии библиотек Firebase Android.

    (Альтернатива) Объявите зависимости библиотеки Firebase без использования BoM

    Если вы решите не использовать Firebase BoM, вы должны указать каждую версию библиотеки Firebase в строке зависимостей.

    Обратите внимание: если вы используете несколько библиотек Firebase в своем приложении, мы настоятельно рекомендуем использовать BoM для управления версиями библиотек, что гарантирует совместимость всех версий.

    dependencies {
        // Declare the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx:24.0.3'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }
  3. В манифесте вашего приложения объявите, что требуется разрешение INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Разверните свою модель

Разверните свои пользовательские модели TensorFlow с помощью консоли Firebase или пакетов Firebase Admin Python и Node.js SDK. См. Развертывание пользовательских моделей и управление ими .

После добавления пользовательской модели в проект Firebase вы можете ссылаться на модель в своих приложениях, используя указанное вами имя. В любое время вы можете развернуть новую модель TensorFlow Lite и загрузить новую модель на устройства пользователей, вызвав getModel() (см. ниже).

2. Загрузите модель на устройство и инициализируйте интерпретатор TensorFlow Lite.

Чтобы использовать модель TensorFlow Lite в своем приложении, сначала используйте SDK Firebase ML, чтобы загрузить последнюю версию модели на устройство. Затем создайте экземпляр интерпретатора TensorFlow Lite с моделью.

Чтобы начать загрузку модели, вызовите метод getModel() загрузчика модели, указав имя, которое вы присвоили модели при ее загрузке, хотите ли вы всегда загружать последнюю модель и условия, при которых вы хотите разрешить загрузку.

Вы можете выбрать один из трех способов загрузки:

Тип загрузки Описание
LOCAL_MODEL Получите локальную модель с устройства. Если нет доступной локальной модели, это ведет себя как LATEST_MODEL . Используйте этот тип загрузки, если вы не заинтересованы в проверке обновлений модели. Например, вы используете Remote Config для получения имен моделей и всегда загружаете модели под новыми именами (рекомендуется).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Получите локальную модель с устройства и запустите обновление модели в фоновом режиме. Если нет доступной локальной модели, это ведет себя как LATEST_MODEL .
ПОСЛЕДНЯЯ МОДЕЛЬ Получите последнюю модель. Если локальная модель является последней версией, возвращает локальную модель. В противном случае загрузите последнюю модель. Это поведение будет заблокировано до тех пор, пока не будет загружена последняя версия (не рекомендуется). Используйте это поведение только в тех случаях, когда вам явно нужна последняя версия.

Вы должны отключить функции, связанные с моделью, например сделать серым или скрыть часть вашего пользовательского интерфейса, пока вы не подтвердите, что модель была загружена.

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Многие приложения запускают задачу загрузки в своем коде инициализации, но вы можете сделать это в любой момент, прежде чем вам понадобится использовать модель.

3. Выполните вывод по входным данным

Получите входные и выходные формы вашей модели

Интерпретатор модели TensorFlow Lite принимает на вход и создает на выходе один или несколько многомерных массивов. Эти массивы содержат значения byte , int , long или float . Прежде чем вы сможете передавать данные в модель или использовать ее результат, вы должны знать количество и размеры ("форму") массивов, которые использует ваша модель.

Если вы построили модель самостоятельно или если формат ввода и вывода модели задокументирован, у вас уже может быть эта информация. Если вы не знаете форму и тип данных ввода и вывода вашей модели, вы можете использовать интерпретатор TensorFlow Lite для проверки вашей модели. Например:

Питон

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Пример вывода:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Запустите интерпретатор

После того, как вы определили формат входных и выходных данных вашей модели, получите входные данные и выполните любые преобразования данных, необходимые для получения входных данных правильной формы для вашей модели.

Например, если у вас есть модель классификации изображений с входной формой [1 224 224 3] значений с плавающей запятой, вы можете сгенерировать входной ByteBuffer из объекта Bitmap , как показано в следующем примере:

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Затем выделите ByteBuffer , достаточно большой, чтобы содержать выходные данные модели, и передайте входной буфер и выходной буфер методу run() интерпретатора TensorFlow Lite. Например, для формы вывода [1 1000] значений с плавающей запятой:

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

То, как вы используете выходные данные, зависит от используемой модели.

Например, если вы выполняете классификацию, в качестве следующего шага вы можете сопоставить индексы результата с метками, которые они представляют:

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Приложение: Безопасность модели

Независимо от того, как вы делаете свои модели TensorFlow Lite доступными для Firebase ML, Firebase ML сохраняет их в стандартном сериализованном формате protobuf в локальном хранилище.

Теоретически это означает, что любой может скопировать вашу модель. Однако на практике большинство моделей настолько зависят от приложения и запутаны оптимизацией, что риск аналогичен риску дизассемблирования и повторного использования вашего кода конкурентами. Тем не менее, вы должны знать об этом риске, прежде чем использовать пользовательскую модель в своем приложении.

В Android API уровня 21 (Lollipop) и новее модель загружается в каталог, который исключен из автоматического резервного копирования .

В Android API уровня 20 и старше модель загружается в каталог с именем com.google.firebase.ml.custom.models во внутренней памяти приложения. Если вы включили резервное копирование файлов с помощью BackupAgent , вы можете исключить этот каталог.