Usa un modelo de TensorFlow Lite para realizar inferencias con el Kit de AA en Android

Puedes usar el Kit de AA para realizar inferencias en el dispositivo con un modelo de TensorFlow Lite.

Esta API requiere un SDK de Android nivel 16 (Jelly Bean) o posterior.

Consulta la muestra de inicio rápido del Kit de AA en GitHub para ver un ejemplo de esta API en uso, o bien prueba el codelab.

Antes de comenzar

  1. Si aún no agregaste Firebase a tu app, sigue los pasos en la guía de introducción para hacerlo.
  2. Incluye las dependencias para el Kit de AA en el archivo build.gradle del nivel de tu app:
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.2'
    }
    
  3. Convierte el modelo de TensorFlow que deseas usar al formato de TensorFlow Lite (tflite). Consulta TOCO: Convertidor de optimización de TensorFlow Lite.

Aloja o empaqueta tu modelo

Si quieres usar un modelo de TensorFlow Lite para las inferencias en tu app, primero debes hacer que esté disponible para el Kit de AA. El Kit puede usar modelos de TensorFlow Lite alojados de forma remota con Firebase, almacenados de forma local en el dispositivo, o ambas opciones.

Podrás asegurarte de que se usará la versión más reciente del modelo (cuando esté disponible) si lo alojas en Firebase y lo almacenas de forma local. Sin embargo, las funciones del AA de tu app seguirán funcionando cuando no esté disponible el modelo alojado en Firebase.

Seguridad del modelo

Sin importar lo que hagas para que estén disponibles tus modelos de TensorFlow para AA, el Kit de AA los almacena en forma local en el formato estándar de protobuf serializado.

En teoría, eso significa que cualquier persona puede copiar tu modelo. Sin embargo, en la práctica, la mayoría de los modelos son tan específicos para la app, y ofuscados además por las optimizaciones, que el riesgo es comparable a que alguien de la competencia desensamble y vuelva a usar tu código. Con todo, debes estar al tanto de ese riesgo antes de usar un modelo personalizado en tu app.

En la API de Android nivel 21 (Lollipop) o posterior, el modelo se descarga en un directorio excluido de las copias de seguridad automáticas.

En una API de Android nivel 20 o anterior, el modelo se descarga en un directorio llamado com.google.firebase.ml.custom.models en el almacenamiento interno privado de la app. Si habilitas la copia de seguridad a través de BackupAgent, tienes la opción de excluir este directorio.

Cómo alojar modelos en Firebase

Para alojar tu modelo de TensorFlow Lite en Firebase, sigue estos pasos:

  1. Haz clic en la pestaña Personalizado en la sección Kit de AA de Firebase console.
  2. Haz clic en Agregar modelo personalizado (o Agregar otro modelo).
  3. Ingresa el nombre que se usará para identificar el modelo en tu proyecto de Firebase. Luego, sube el archivo .tflite.
  4. En el manifiesto de tu app, declara que se requiera el permiso de INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />
    

Después de agregar un modelo personalizado al proyecto de Firebase, podrás usar el nombre que especificaste para hacer referencia al modelo en tus apps. Puedes subir un nuevo archivo .tflite para un modelo en cualquier momento. Tu app descargará el nuevo modelo y comenzará a usarlo cuando se reinicie. Puedes definir las condiciones del dispositivo requeridas por tu app para intentar actualizar el modelo (ver a continuación).

Cómo hacer que los modelos estén disponibles de forma local

Para que tu modelo de TensorFlow Lite esté disponible de forma local, puedes empaquetarlo con tu app o descargarlo desde tu propio servidor en el tiempo de ejecución.

Para empaquetar el modelo de TensorFlow Lite con tu app, copia el archivo .tflite a la carpeta assets/ de tu app. Es posible que primero debas crear la carpeta. Para ello, haz clic con el botón derecho en la carpeta app/ y, luego, en Nuevo > Carpeta > Carpeta de elementos.

Luego, agrega lo siguiente al archivo de tu proyecto build.gradle:

android {

    // ...

    aaptOptions {
        noCompress "tflite"
    }
}

Se incluirá el archivo .tflite en el paquete de la app y estará disponible para el Kit de AA como elemento sin procesar.

Por otro lado, si alojas el modelo en tu propio servidor, podrás descargarlo al almacenamiento local de tu app en el momento adecuado. De esta forma, el modelo estará disponible para el Kit de AA como un archivo local.

Carga el modelo

Si quieres usar un modelo de TensorFlow Lite para las inferencias, primero especifica las ubicaciones del archivo .tflite.

Si alojaste tu modelo con Firebase, crea un objeto FirebaseCloudModelSource y especifica el nombre que le asignaste cuando lo subiste, las condiciones bajo las que el Kit de AA debería descargarlo en un comienzo y cuándo habrá actualizaciones disponibles.

FirebaseModelDownloadConditions.Builder conditionsBuilder =
        new FirebaseModelDownloadConditions.Builder().requireWifi();
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
    // Enable advanced conditions on Android Nougat and newer.
    conditionsBuilder = conditionsBuilder
            .requireCharging()
            .requireDeviceIdle();
}
FirebaseModelDownloadConditions conditions = conditionsBuilder.build();

// Build a FirebaseCloudModelSource object by specifying the name you assigned the model
// when you uploaded it in the Firebase console.
FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model")
        .enableModelUpdates(true)
        .setInitialDownloadConditions(conditions)
        .setUpdatesDownloadConditions(conditions)
        .build();
FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource);

Si empaquetaste el modelo con tu app o lo descargaste desde tu propio host en el tiempo de ejecución, crea un objeto FirebaseLocalModelSource, indica el nombre de archivo del modelo .tflite y si este archivo es un elemento sin procesar (si está empaquetado) o si está en el almacenamiento local (si se descargó en el tiempo de ejecución).

FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model")
        .setAssetFilePath("mymodel.tflite")  // Or setFilePath if you downloaded from your host
        .build();
FirebaseModelManager.getInstance().registerLocalModelSource(localSource);

A continuación, crea un objeto FirebaseModelOptions con el nombre de tu fuente de Cloud, el de la fuente local, o ambos, y úsalo para obtener una instancia de FirebaseModelInterpreter:

FirebaseModelOptions options = new FirebaseModelOptions.Builder()
        .setCloudModelName("my_cloud_model")
        .setLocalModelName("my_local_model")
        .build();
FirebaseModelInterpreter firebaseInterpreter =
        FirebaseModelInterpreter.getInstance(options);

Si especificas una fuente de modelo de Cloud y una fuente de modelo local, el intérprete de modelo usará la de Cloud cuando esté disponible; si no, usará la local.

Especifica la entrada y el resultado del modelo

A continuación, debes especificar el formato de entrada y resultado del modelo. Para ello, crea un objeto FirebaseModelInputOutputOptions.

Un modelo de TensorFlow Lite toma como entrada y produce como salida uno o más arreglos multidimensionales. Estos arreglos contienen valores byte, int, long o float. Debes configurar el Kit de AA con la cantidad y las dimensiones ("shape") de las matrices que usa tu modelo.

Por ejemplo, un modelo de clasificación de imágenes cuantificado podría tomar como entrada una matriz de bytes de Nx224x224x3, que representa un lote de imágenes de N224x224 con color verdadero (24 bits), y produce una salida de 1,000 valores byte, que representan la probabilidad de que la imagen pertenezca a cada una de las 1,000 categorías que predice el modelo.

FirebaseModelInputOutputOptions inputOutputOptions =
    new FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 224, 224, 3})
        .setOutputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 1000})
        .build();

O bien, un modelo de clasificación de imágenes de coma flotante podría tomar como entrada una matriz de floats de Nx224x224x3, que representa un lote de imágenes RGB de N224x224:

FirebaseModelInputOutputOptions inputOutputOptions =
    new FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1000})
        .build();

Realiza inferencias sobre los datos de entrada

Por último, para realizar inferencias con el modelo, obtén tus datos de entrada (por ejemplo, una imagen capturada con la cámara del dispositivo):

// Quantized model:
byte[][][][] input = new byte[1][224][224][3];
input = getYourInputData();

Si estás usando un modelo de coma flotante con datos de imagen como entrada, es posible que debas convertir esos datos al formato de coma flotante:

// Floating-point model:
byte[][][][] prenormalizedInput = new byte[1][224][224][3];
prenormalizedInput = getYourInputData();

float[][][][] input = new float[1][224][224][3];
for (int b = 0; b < 1; b++) {
    for (int x = 0; x < 224; x++) {
        for (int y = 0; y < 224; y++) {
            for (int ch = 0; ch < 3; ch++) {
                // Normalize channel values to [-1.0, 1.0]
                input[b][x][y][ch] =
                        ((float) prenormalizedInput[b][x][y][ch] / 255.0f - 0.5f) * 2;
            }
        }
    }
}

Luego, crea un objeto FirebaseModelInputs con tus datos de entrada y pásalo junto con la especificación de la entrada y el resultado del modelo al método run del intérprete del modelo:

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
    .add(input)  // add() as many input arrays as your model requires
    .build();
Task<FirebaseModelOutputs> result =
    firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
          new OnSuccessListener<FirebaseModelOutputs>() {
            @Override
            public void onSuccess(FirebaseModelOutputs result) {
              // ...
            }
          })
        .addOnFailureListener(
          new OnFailureListener() {
            @Override
            public void onFailure(@NonNull Exception e) {
              // Task failed with an exception
              // ...
            }
          });

Para obtener el resultado, llama al método getOutput() del objeto que se pasa al agente de escucha que detectó el resultado correcto. Por ejemplo:

// Quantized model:
byte[][] output = result.<byte[][]>getOutput(0);
byte[] probabilities = output[0];

// Floating-point model:
float[][] output = result.<float[][]>getOutput(0);
float[] probabilities = output[0];

La manera de usar el resultado depende del modelo que uses. Por ejemplo, si realizas una clasificación, el paso siguiente podría ser asignar los índices del resultado a las etiquetas que representan.

Enviar comentarios sobre…

¿Necesitas ayuda? Visita nuestra página de asistencia.