Google is committed to advancing racial equity for Black communities. See how.
Эта страница была переведа с помощью Cloud Translation API.
Switch to English

Используйте пользовательскую модель TensorFlow Lite на Android

Если ваше приложение использует пользовательские модели TensorFlow Lite , вы можете использовать Firebase ML для развертывания своих моделей. Развертывая модели с Firebase, вы можете уменьшить начальный размер загрузки вашего приложения и обновить модели машинного обучения вашего приложения без выпуска новой версии вашего приложения. А с помощью Remote Config и A / B Testing вы можете динамически обслуживать разные модели для разных групп пользователей.

Модели TensorFlow Lite

Модели TensorFlow Lite - это модели машинного обучения, оптимизированные для работы на мобильных устройствах. Чтобы получить модель TensorFlow Lite:

Прежде чем вы начнете

  1. Если вы еще этого не сделали, добавьте Firebase в свой проект Android .
  2. build.gradle , что в build.gradle файле build.gradle уровне проекта build.gradle репозиторий Google Maven как в buildscript и в разделы allprojects .
  3. Добавьте библиотеки Firebase ML и TensorFlow Lite Android в свой модуль (на уровне приложения) файл Gradle (обычно app/build.gradle ):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.4'
      implementation 'org.tensorflow:tensorflow-lite:2.0.0'
    }
    
  4. В манифесте вашего приложения объявите, что требуется разрешение INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Разверните свою модель

Разверните свои собственные модели TensorFlow с помощью консоли Firebase или SDK Firebase Admin Python и Node.js. См. Развертывание пользовательских моделей и управление ими .

После добавления пользовательской модели в проект Firebase вы можете ссылаться на модель в своих приложениях, используя указанное вами имя. В любое время вы можете загрузить новую модель TensorFlow Lite, и ваше приложение загрузит новую модель и начнет использовать ее при следующем перезапуске приложения. Вы можете определить условия устройства, необходимые для вашего приложения, чтобы попытаться обновить модель (см. Ниже).

2. Загрузите модель в устройство

Чтобы использовать модель TensorFlow Lite в своем приложении, сначала используйте Firebase ML SDK, чтобы загрузить последнюю версию модели на устройство.

Чтобы начать загрузку модели, вызовите метод download() диспетчера моделей, указав имя, присвоенное модели при ее загрузке, и условия, при которых вы хотите разрешить загрузку. Если модели нет на устройстве или доступна более новая версия модели, задача асинхронно загрузит модель из Firebase.

Вам следует отключить функции, связанные с моделью, например, затемнение или скрытие части пользовательского интерфейса, пока вы не подтвердите, что модель была загружена.

Ява

FirebaseCustomRemoteModel remoteModel =
      new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Котлин + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Многие приложения запускают задачу загрузки в своем коде инициализации, но вы можете сделать это в любой момент, прежде чем вам понадобится использовать модель.

3. Инициализировать интерпретатор TensorFlow Lite.

После загрузки модели в устройство вы можете получить местоположение файла модели, getLatestModelFile() метод getLatestModelFile() диспетчера getLatestModelFile() . Используйте это значение для создания экземпляра интерпретатора TensorFlow Lite:

Ява

FirebaseCustomRemoteModel remoteModel = new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                }
            }
        });

Котлин + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        }
    }

4. Выполните вывод по входным данным.

Получите входные и выходные формы вашей модели

Интерпретатор модели TensorFlow Lite принимает на вход и выдает на выходе один или несколько многомерных массивов. Эти массивы содержат значения типа byte , int , long или float . Прежде чем вы сможете передавать данные в модель или использовать их результат, вы должны знать количество и размеры («форму») массивов, используемых вашей моделью.

Если вы построили модель самостоятельно или задокументирован формат ввода и вывода модели, возможно, у вас уже есть эта информация. Если вы не знаете форму и тип данных входных и выходных данных вашей модели, вы можете использовать интерпретатор TensorFlow Lite для проверки вашей модели. Например:

питон

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Пример вывода:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Запускаем интерпретатор

После того, как вы определили формат ввода и вывода вашей модели, получите ваши входные данные и выполните любые преобразования данных, которые необходимы для получения входных данных правильной формы для вашей модели.

Например, если у вас есть модель классификации изображений с входной формой [1 224 224 3] значений с плавающей запятой, вы можете сгенерировать входной ByteBuffer из объекта Bitmap как показано в следующем примере:

Ява

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Котлин + KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Затем выделите ByteBuffer достаточно большой, чтобы содержать выходные данные модели, и передайте входной буфер и выходной буфер методу run() интерпретатора TensorFlow Lite. Например, для формы вывода [1 1000] значений с плавающей запятой:

Ява

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Котлин + KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Как вы используете вывод, зависит от модели, которую вы используете.

Например, если вы выполняете классификацию, в качестве следующего шага вы можете сопоставить индексы результата с метками, которые они представляют:

Ява

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Котлин + KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Приложение: возврат к модели с локальной связью

Когда вы размещаете свою модель в Firebase, любые связанные с ней функции не будут доступны, пока ваше приложение не загрузит модель в первый раз. Для некоторых приложений это может быть нормально, но если ваша модель поддерживает основные функции, вы можете связать версию своей модели с вашим приложением и использовать наилучшую доступную версию. Таким образом, вы можете убедиться, что функции машинного обучения вашего приложения работают, когда модель, размещенная в Firebase, недоступна.

Чтобы связать модель TensorFlow Lite с приложением:

  1. Скопируйте файл модели (обычно с .tflite или .lite ) в папку assets/ вашего приложения. (Возможно, вам потребуется сначала создать папку, щелкнув app/ папку правой кнопкой мыши, а затем выбрав « Создать»> «Папка»> «Папка с активами» .)

  2. Добавьте следующее в файл build.gradle вашего приложения, чтобы Gradle не сжимал модели при создании приложения:

    android {
    
        // ...
    
        aaptOptions {
            noCompress "tflite", "lite"
        }
    }
    

Затем используйте модель с локальным пакетом, когда размещенная модель недоступна:

Ява

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                } else {
                    try {
                        InputStream inputStream = getAssets().open("your_fallback_model.tflite");
                        byte[] model = new byte[inputStream.available()];
                        inputStream.read(model);
                        ByteBuffer buffer = ByteBuffer.allocateDirect(model.length)
                                .order(ByteOrder.nativeOrder());
                        buffer.put(model);
                        interpreter = new Interpreter(buffer);
                    } catch (IOException e) {
                        // File not found?
                    }
                }
            }
        });

Котлин + KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        } else {
            val model = assets.open("your_fallback_model.tflite").readBytes()
            val buffer = ByteBuffer.allocateDirect(model.size).order(ByteOrder.nativeOrder())
            buffer.put(model)
            interpreter = Interpreter(buffer)
        }
    }

Приложение: Модель безопасности

Независимо от того, как вы делаете свои модели TensorFlow Lite доступными для Firebase ML, Firebase ML сохраняет их в стандартном сериализованном формате protobuf в локальном хранилище.

Теоретически это означает, что любой может скопировать вашу модель. Однако на практике большинство моделей настолько привязаны к конкретному приложению и обфусцированы оптимизацией, что риск такой же, как у конкурентов, разобравших и повторно использующих ваш код. Тем не менее, вы должны знать об этом риске, прежде чем использовать пользовательскую модель в своем приложении.

На уровне 21 API Android (Lollipop) и новее модель загружается в каталог, который исключен из автоматического резервного копирования .

На Android API уровня 20 и старше модель загружается в каталог с именем com.google.firebase.ml.custom.models во внутренней памяти приложения. Если вы включили резервное копирование файлов с помощью BackupAgent , вы можете исключить этот каталог.