Usar um modelo personalizado do TensorFlow Lite no Android

Se o app usa modelos personalizados do TensorFlow Lite, é possível usar o Firebase ML para implantar os modelos. Ao implantar modelos com o Firebase, é possível reduzir o tamanho inicial do download do app e atualizar os modelos de ML sem lançar uma nova versão do aplicativo. Além disso, com a Configuração remota e o Teste A/B, você pode exibir dinamicamente diferentes modelos para diferentes conjuntos de usuários.

Modelos do TensorFlow Lite

Os modelos do TensorFlow Lite são modelos de ML otimizados para execução em dispositivos móveis. Para receber um modelo do TensorFlow Lite:

Antes de começar

  1. Adicione o Firebase ao seu projeto para Android, caso ainda não tenha feito isso.
  2. No arquivo Gradle do módulo (nível do app) (geralmente <project>/<app-module>/build.gradle.kts ou <project>/<app-module>/build.gradle), adicione a dependência da biblioteca Android de download de modelos do Firebase ML. Para gerenciar o controle de versões das bibliotecas, recomendamos usar a BoM do Firebase para Android.

    Além disso, como parte da configuração do download de modelos do Firebase ML, é necessário adicionar o SDK do TensorFlow Lite ao app.

    Kotlin+KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:32.3.1"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader-ktx")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Com a BoM do Firebase para Android, seu app sempre vai usar versões compatíveis das bibliotecas Android do Firebase.

    (Alternativa) Adicionar dependências das bibliotecas do Firebase sem usar a BoM

    Se você preferir não usar a BoM do Firebase, especifique cada versão das bibliotecas do Firebase na linha de dependência correspondente.

    Se você usa várias bibliotecas do Firebase no seu app, recomendamos utilizar a BoM para gerenciar as versões delas, porque isso ajuda a garantir a compatibilidade de todas as bibliotecas.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader-ktx:24.1.3")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Java

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:32.3.1"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Com a BoM do Firebase para Android, seu app sempre vai usar versões compatíveis das bibliotecas do Firebase para Android.

    (Alternativa) Adicionar dependências das bibliotecas do Firebase sem usar a BoM

    Se você preferir não usar a BoM do Firebase, especifique cada versão das bibliotecas do Firebase na linha de dependência correspondente.

    Se você usa várias bibliotecas do Firebase no seu app, recomendamos utilizar a BoM para gerenciar as versões delas, porque isso ajuda a garantir a compatibilidade de todas as bibliotecas.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:24.1.3")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
  3. No manifesto do app, informe que a permissão INTERNET é necessária:
    <uses-permission android:name="android.permission.INTERNET" />

1. Implantar seu modelo

Implante seus modelos personalizados do TensorFlow usando o Console do Firebase ou os SDKs Admin para Python e Node.js do Firebase. Consulte Implantar e gerenciar modelos personalizados.

Depois de adicionar um modelo personalizado ao seu projeto do Firebase, você pode referenciá-lo nos seus apps usando o nome especificado. A qualquer momento é possível implantar um novo modelo do TensorFlow Lite e fazer o download do novo modelo nos dispositivos dos usuários chamando getModel() (veja abaixo).

2. Fazer o download do modelo no dispositivo e inicializar um intérprete do TensorFlow Lite

Para usar o modelo do TensorFlow Lite no app, primeiro use o SDK do Firebase ML para fazer o download da versão mais recente do modelo no dispositivo. Em seguida, instancie um intérprete do TensorFlow Lite com o modelo.

Para iniciar o download do modelo, chame o método getModel() da ferramenta de download de modelos, especificando o nome atribuído ao modelo durante o upload, se você quer sempre fazer o download do modelo mais recente e as condições em que quer permitir o download.

Você pode escolher entre três comportamentos de download:

Tipo de download Descrição
LOCAL_MODEL Consiga o modelo local do dispositivo. Se não houver um modelo local disponível, o comportamento será como LATEST_MODEL. Use esse tipo de download se você não tiver interesse em verificar as atualizações do modelo. Por exemplo, você está usando o Configuração remota para recuperar nomes de modelos e sempre faz upload de modelos usando novos nomes (recomendado).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Consiga o modelo local do dispositivo e comece a atualizá-lo em segundo plano. Se não houver um modelo local disponível, o comportamento será como LATEST_MODEL.
LATEST_MODEL Receba o modelo mais recente. Se o modelo local for a versão mais recente, retornará o modelo local. Caso contrário, faça o download do modelo mais recente. Esse comportamento será bloqueado até o download da versão mais recente (não recomendado). Use esse comportamento somente quando precisar da versão mais recente.

Desative o recurso relacionado ao modelo, por exemplo, usar o recurso esmaecido ou ocultar parte da IU, até confirmar que o download do modelo foi concluído.

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Muitos apps iniciam a tarefa de download no código de inicialização, mas você pode fazer isso a qualquer momento antes de precisar usar o modelo.

3. Realizar inferência em dados de entrada

Gerar formas de entrada e saída do modelo

O intérprete de modelos do TensorFlow Lite utiliza como entrada e produz como saída uma ou mais matrizes multidimensionais. Essas matrizes contêm valores byte, int, long ou float. Antes de transmitir dados para um modelo ou usar o resultado dele, você precisa saber o número e as dimensões ("forma") das matrizes usadas pelo modelo.

Se você mesmo criou o modelo ou se o formato de entrada e saída do modelo está documentado, talvez já tenha essas informações. Se você não sabe qual é a forma e o tipo de dados da entrada e da saída do seu modelo, pode usar o intérprete do TensorFlow Lite para inspecionar seu modelo. Exemplo:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Exemplo de saída:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Executar o intérprete

Depois de determinar o formato da entrada e da saída do modelo, colete os dados de entrada e execute todas as transformações necessárias nos dados para ter uma entrada da forma certa para o modelo.

Por exemplo, se você tiver um modelo de classificação de imagem com uma forma de entrada de valores de ponto flutuante [1 224 224 3], será possível gerar um ByteBuffer de entrada usando um objeto Bitmap, conforme exibido no exemplo a seguir:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Em seguida, aloque um ByteBuffer grande o suficiente para conter a saída do modelo e transmitir os buffers de entrada e de saída para o método run() do intérprete do TensorFlow Lite. Por exemplo, para uma forma de saída de valores de ponto flutuante [1 1000]:

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Como você usa a saída depende do modelo que está usando.

Por exemplo, se você estiver realizando uma classificação, como próxima etapa, será possível mapear os índices do resultado para os rótulos representados:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Apêndice: segurança do modelo

Independentemente de como você disponibiliza seus modelos do TensorFlow Lite para o Firebase ML, o Firebase ML os armazena localmente no formato padrão protobuf serializado.

Teoricamente, isso significa que qualquer pessoa pode copiar seu modelo. No entanto, na prática, a maioria dos modelos é tão específica de cada aplicativo e ofuscada por otimizações que o risco é comparável ao de concorrentes desmontando e reutilizando seu código. Apesar disso, você deve estar ciente desse risco antes de usar um modelo personalizado no seu app.

Na API Android nível 21 (Lollipop) e versões mais recentes, o download do modelo é feito em um diretório que não é incluído no backup automático.

Na API Android nível 20 ou versões mais antigas, o download do modelo é feito em um diretório chamado com.google.firebase.ml.custom.models em um armazenamento interno particular do app. Se você ativou o backup de arquivos usando BackupAgent, pode optar por excluir esse diretório.