Usa un modelo de TensorFlow Lite para las inferencias con el Kit de AA en Android

Puedes usar el Kit de AA para realizar inferencias en el dispositivo con un modelo de TensorFlow Lite.

Esta API requiere un SDK de Android nivel 16 (Jelly Bean) o posterior.

Antes de comenzar

  1. Si aún no lo has hecho, agrega Firebase a tu proyecto de Android.
  2. En tu archivo build.gradle de nivel de proyecto, asegúrate de incluir el repositorio Maven de Google en las secciones buildscript y allprojects.
  3. Agrega las dependencias para las bibliotecas de Android del Kit de AA al archivo Gradle (generalmente app/build.gradle) de tu módulo (nivel de app):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  4. Convierte el modelo de TensorFlow que deseas usar al formato de TensorFlow Lite. Consulta TOCO: Convertidor de optimización de TensorFlow Lite.

Aloja o empaqueta tu modelo

Si quieres usar un modelo de TensorFlow Lite para las inferencias en tu app, primero debes hacer que esté disponible para el Kit de AA. El Kit de AA puede usar modelos de TensorFlow Lite alojados de forma remota con Firebase, empaquetados con el objeto binario de la app o ambas opciones.

Si alojas un modelo en Firebase, puedes actualizarlo sin lanzar una nueva versión de la app, y puedes usar Remote Config y A/B Testing para entregar de forma dinámica diferentes modelos a distintos conjuntos de usuarios.

Si eliges proporcionar el modelo únicamente mediante el alojamiento con Firebase y no empaquetarlo con tu app, puedes reducir el tamaño de descarga inicial de la app. Sin embargo, ten en cuenta que, si el modelo no se empaqueta con la aplicación, las funcionalidades relacionadas con el modelo no estarán disponibles hasta que la app lo descargue por primera vez.

Si empaquetas el modelo con la app, puedes asegurarte de que las funciones de AA de tu app estén activas cuando el modelo alojado en Firebase no esté disponible.

Cómo alojar modelos en Firebase

Para alojar tu modelo de TensorFlow Lite en Firebase, sigue estos pasos:

  1. En la sección Kit de AA de Firebase console, haz clic en la pestaña Personalizado.
  2. Haz clic en Agregar modelo personalizado (o Agregar otro modelo).
  3. Ingresa el nombre que se usará para identificar el modelo en tu proyecto de Firebase. Luego, sube el archivo del modelo de TensorFlow Lite (generalmente, con la extensión .tflite o .lite).
  4. En el manifiesto de tu app, declara que se requiera el permiso de INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />
    

Después de agregar un modelo personalizado al proyecto de Firebase, podrás usar el nombre que especificaste para hacer referencia al modelo en tus apps. Puedes subir un nuevo modelo de TensorFlow Lite en cualquier momento. Tu app descargará el nuevo modelo y comenzará a usarlo cuando se reinicie. Podrás definir las condiciones del dispositivo que tu app requiera para intentar actualizar el modelo (ver a continuación).

Cómo empaquetar modelos con una app

Para empaquetar tu modelo de TensorFlow Lite con tu app, copia el archivo del modelo (por lo general tiene la extensión .tflite o .lite) a la carpeta assets/ de tu app. Es posible que primero debas crear la carpeta. Para ello, haz clic con el botón derecho en la carpeta app/ y, luego, en Nuevo > Carpeta > Carpeta de elementos.

Luego, agrega lo siguiente al archivo build.gradle de tu app para asegurarte de que Gradle no comprima los modelos cuando se compile la app:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

El archivo del modelo se incluirá en el paquete de la app y estará disponible para el Kit de AA como elemento sin procesar.

Carga el modelo

Para usar el modelo de TensorFlow Lite en tu app, primero configura el Kit de AA con las ubicaciones donde se encuentra disponible el modelo: de manera remota mediante Firebase, en el almacenamiento local, o en ambos. Si especificas un modelo local y uno remoto, puedes usar el modelo remoto si está disponible y, en caso de no estarlo, puedes usar el modelo almacenado localmente.

Configura un modelo alojado en Firebase

Si alojaste tu modelo con Firebase, crea un objeto FirebaseCustomRemoteModel y especifica el nombre que asignaste al modelo cuando lo publicaste:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Luego, inicia la tarea de descarga del modelo y especifica las condiciones en las que deseas permitir la descarga. Si el modelo no está en el dispositivo o si hay una versión más reciente de este, la tarea descargará el modelo de Firebase de forma asíncrona:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Muchas apps comienzan la tarea de descarga en su código de inicialización, pero puedes hacerlo en cualquier momento antes de usar el modelo.

Configura un modelo local

Si empaquetaste el modelo con tu app, crea un objeto FirebaseCustomLocalModel y especifica el nombre de archivo del modelo TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Crea un intérprete a partir de tu modelo

Después de configurar las fuentes de tu modelo, crea un objeto FirebaseModelInterpreter a partir de una de ellas.

Si solo tienes un modelo empaquetado a nivel local, crea un intérprete a partir del objeto FirebaseCustomLocalModel:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Si tienes un modelo alojado de forma remota, comprueba si se descargó antes de ejecutarlo. Puedes verificar el estado de la tarea de descarga del modelo con el método isModelDownloaded() del administrador del modelo.

Aunque solo tienes que confirmar esto antes de ejecutar el intérprete, si tienes un modelo alojado de forma remota y uno empaquetado localmente, tendría sentido realizar esta verificación cuando se crea una instancia de intérprete de modelo: crea un intérprete desde el modelo remoto si se descargó o, en su defecto, desde el modelo local.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded ->
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Si solo tienes un modelo alojado de forma remota, debes inhabilitar la funcionalidad relacionada con el modelo, por ejemplo, oculta o inhabilita parte de tu IU, hasta que confirmes que el modelo se descargó. Puedes hacerlo si adjuntas un objeto de escucha al método download() del administrador de modelos:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Especifica la entrada y salida del modelo

A continuación, configura los formatos de entrada y salida del intérprete de modelo.

Un modelo de TensorFlow Lite toma como entrada y produce como salida uno o más arreglos multidimensionales. Estos arreglos contienen valores byte, int, long o float. Debes configurar el Kit de AA con la cantidad y las dimensiones ("shape") de los arreglos que usa tu modelo.

Si no conoces la forma y el tipo de datos de la entrada y la salida del modelo, puedes usar el intérprete Python de TensorFlow Lite para inspeccionar tu modelo. Por ejemplo:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Después de determinar el formato de entrada y salida del modelo, puedes crear un objeto FirebaseModelInputOutputOptions para configurar el intérprete de modelo de tu app.

Por ejemplo, un modelo de clasificación de imágenes de punto flotante puede tomar como entrada un arreglo Nx 224 x 224 x 3 de valores float que representa un lote de N imágenes de 224 x 224 de tres canales (RGB), y producir como salida una lista de 1,000 valores float, cada uno de los cuales representa la probabilidad de que la imagen pertenezca a una de las 1,000 categorías que predice el modelo.

Para ese tipo de modelo, debes configurar la entrada y la salida del intérprete de modelo como se muestra a continuación:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Realiza inferencias sobre los datos de entrada

Por último, para realizar inferencias con el modelo, obtén tus datos de entrada y realiza todas las transformaciones en los datos que sean necesarias a fin de obtener un arreglo de entrada con la forma correcta para tu modelo.

Por ejemplo, si tienes un modelo de clasificación de imágenes con una forma de entrada de valores de punto flotante [1 224 224 3], puedes generar un arreglo de entrada a partir de un objeto Bitmap como se muestra en el siguiente ejemplo:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Luego, crea un objeto FirebaseModelInputs con los datos de entrada y pásalo junto con la especificación de entrada y salida del modelo al método run del intérprete de modelo:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Si la llamada se ejecuta correctamente, puedes llamar al método getOutput() del objeto transferido al objeto de escucha que detectó el resultado correcto para obtener la salida. Por ejemplo:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

La manera de usar el resultado depende del modelo que uses.

Por ejemplo, si realizas una clasificación, el paso siguiente puede ser asignar los índices del resultado a las etiquetas que representan:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Apéndice: Seguridad del modelo

Sin importar cómo hagas que estén disponibles tus modelos de TensorFlow Lite para el Kit de AA, este los almacena en el formato estándar de protobuf serializado en el almacenamiento local.

En teoría, eso significa que cualquier persona puede copiar tu modelo. Sin embargo, en la práctica, la mayoría de los modelos son tan específicos para la aplicación, y ofuscados por las optimizaciones, que el riesgo es comparable a que alguien de la competencia desensamble y vuelva a usar tu código. No obstante, debes estar al tanto de ese riesgo antes de usar un modelo personalizado en tu app.

En la API de Android nivel 21 (Lollipop) o posterior, el modelo se descarga en un directorio excluido de las copias de seguridad automáticas.

En una API de Android nivel 20 o anterior, el modelo se descarga en un directorio llamado com.google.firebase.ml.custom.models en el almacenamiento interno privado de la app. Si habilitas la copia de seguridad con BackupAgent, tienes la opción de excluir este directorio.