Utilizzare un modello TensorFlow Lite per l'inferenza con ML Kit su Android

Puoi utilizzare ML Kit per eseguire l'inferenza on-device con un modello TensorFlow Lite.

Questa API richiede Android SDK 16 (Jelly Bean) o versioni successive.

Prima di iniziare

  1. Se non lo hai già fatto, aggiungi Firebase al tuo progetto Android.
  2. Aggiungi le dipendenze per le librerie Android di ML Kit al file Gradle del tuo modulo (a livello di app) (di solito app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
  3. Converti il modello TensorFlow che vuoi utilizzare in formato TensorFlow Lite. Consulta TOCO: TensorFlow Lite Optimizing Converter.

Ospittare o raggruppare il modello

Prima di poter utilizzare un modello TensorFlow Lite per l'inferenza nella tua app, devi renderlo disponibile per ML Kit. ML Kit può utilizzare i modelli TensorFlow Lite ospitati in remoto utilizzando Firebase, inclusi nel file binario dell'app o entrambi.

Se ospiti un modello su Firebase, puoi aggiornarlo senza rilasciare una nuova versione dell'app e puoi utilizzare Remote Config e A/B Testing per pubblicare dinamicamente modelli diversi per gruppi diversi di utenti.

Se scegli di fornire solo il modello ospitandolo su Firebase e di non includerlo nella tua app, puoi ridurre le dimensioni del download iniziale dell'app. Tuttavia, tieni presente che se il modello non è incluso nella tua app, qualsiasi funzionalità correlata al modello non sarà disponibile finché l'app non lo scarica per la prima volta.

Se combini il modello con l'app, puoi assicurarti che le funzionalità di ML dell'app continuino a funzionare anche quando il modello ospitato su Firebase non è disponibile.

Ospitare i modelli su Firebase

Per ospitare il modello TensorFlow Lite su Firebase:

  1. Nella sezione ML Kit della console Firebase, fai clic sulla scheda Personalizzata.
  2. Fai clic su Aggiungi modello personalizzato (o Aggiungi un altro modello).
  3. Specifica un nome che verrà utilizzato per identificare il modello nel progetto Firebase, quindi carica il file del modello TensorFlow Lite (di solito termina con .tflite o .lite).
  4. Nel file manifest dell'app, dichiara che l'autorizzazione INTERNET è obbligatoria:
    <uses-permission android:name="android.permission.INTERNET" />

Dopo aver aggiunto un modello personalizzato al tuo progetto Firebase, puoi fare riferimento al modello nelle tue app utilizzando il nome specificato. Puoi caricare un nuovo modello TensorFlow Lite in qualsiasi momento e l'app lo scaricherà e inizierà a utilizzarlo al successivo riavvio. Puoi definire le condizioni del dispositivo necessarie per consentire all'app di tentare di aggiornare il modello (vedi di seguito).

Raggruppare i modelli con un'app

Per raggruppare il modello TensorFlow Lite con l'app, copia il file del modello (in genere termina con .tflite o .lite) nella cartella assets/ dell'app. Potresti dover prima creare la cartella facendo clic con il tasto destro del mouse sulla cartella app/, quindi su Nuovo > Cartella > Cartella asset.

Poi, aggiungi quanto segue al file build.gradle dell'app per assicurarti che Gradle non comprima i modelli durante la compilazione dell'app:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Il file del modello verrà incluso nel pacchetto dell'app e sarà disponibile per ML Kit come asset non elaborato.

Carica il modello

Per utilizzare il modello TensorFlow Lite nella tua app, configura prima ML Kit con le posizioni in cui è disponibile il modello: da remoto utilizzando Firebase, nello spazio di archiviazione locale o in entrambe le posizioni. Se specifichi sia un modello locale che uno remoto, puoi utilizzare il modello remoto se è disponibile e utilizzare come fallback il modello archiviato localmente se il modello remoto non è disponibile.

Configurare un modello ospitato su Firebase

Se hai ospitato il modello con Firebase, crea un oggetto FirebaseCustomRemoteModel specificando il nome che hai assegnato al modello quando lo hai caricato:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Quindi, avvia l'attività di download del modello, specificando le condizioni in cui vuoi consentire il download. Se il modello non è sul dispositivo o se è disponibile una versione più recente, l'attività lo scarica in modo asincrono da Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Molte app avviano l'attività di download nel codice di inizializzazione, ma puoi farlo in qualsiasi momento prima di dover utilizzare il modello.

Configurare un modello locale

Se hai aggregato il modello con la tua app, crea un oggetto FirebaseCustomLocalModel specificando il nome file del modello TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Creare un interprete dal modello

Dopo aver configurato le origini del modello, crea un oggetto FirebaseModelInterpreter da una di queste.

Se hai solo un modello in bundle locale, crea un interprete dall'oggetto FirebaseCustomLocalModel:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Se hai un modello ospitato in remoto, dovrai verificare che sia stato scaricato prima di eseguirlo. Puoi controllare lo stato del compito di download del modello utilizzando il metodo isModelDownloaded() del gestore dei modelli.

Anche se devi confermarlo solo prima di eseguire l'interprete, se hai sia un modello ospitato in remoto sia un modello in bundle locale, potrebbe essere sensato eseguire questo controllo durante l'inizializzazione dell'interprete del modello: crea un interprete dal modello remoto se è stato scaricato e dal modello locale in caso contrario.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Se hai solo un modello ospitato in remoto, devi disattivare le funzionalità correlate al modello, ad esempio disattivare o nascondere parte dell'interfaccia utente, finché non confermi che il modello è stato scaricato. Puoi farlo collegando un ascoltatore al metodo download() del gestore del modello:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Specifica l'input e l'output del modello

Poi, configura i formati di input e output dell'interprete del modello.

Un modello TensorFlow Lite riceve come input e produce come output uno o più array multidimensionali. Questi array contengono valori byte, int, long o float. Devi configurare ML Kit con il numero e le dimensioni ("forma") degli array utilizzati dal tuo modello.

Se non conosci la forma e il tipo di dati di input e output del tuo modello, puoi utilizzare l'interprete Python di TensorFlow Lite per ispezionarlo. Ad esempio:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Dopo aver determinato il formato di input e output del modello, puoi configurare l'interprete del modello dell'app creando un oggetto FirebaseModelInputOutputOptions.

Ad esempio, un modello di classificazione delle immagini a virgola mobile potrebbe prendere come input un array di valori float di Nx224x224x3, che rappresenta un batch di N immagini a tre canali (RGB) di 224 x 224 e produrre come output un elenco di 1000 valori float, ciascuno dei quali rappresenta la probabilità che l'immagine appartenga a una delle 1000 categorie previste dal modello.

Per un modello di questo tipo, devi configurare l'input e l'output dell'interprete del modello come mostrato di seguito:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Eseguire l'inferenza sui dati di input

Infine, per eseguire l'inferenza utilizzando il modello, recupera i dati di input ed esegui tutte le trasformazioni necessarie per ottenere un array di input della forma corretta per il modello.

Ad esempio, se hai un modello di classificazione delle immagini con una forma di input di valori di tipo floating point [1 224 224 3], puoi generare un array di input da un oggetto Bitmap come mostrato nell'esempio seguente:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Quindi, crea un oggetto FirebaseModelInputs con i dati di input e passalo, insieme alle specifiche di input e output del modello, al metodo run dell'interprete del modello:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Se la chiamata va a buon fine, puoi ottenere l'output chiamando il metodo getOutput() dell'oggetto passato all'ascoltatore di eventi di successo. Ad esempio:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

La modalità di utilizzo dell'output dipende dal modello in uso.

Ad esempio, se stai eseguendo la classificazione, come passaggio successivo potresti mappare gli indici del risultato alle etichette che rappresentano:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Appendice: sicurezza del modello

Indipendentemente da come rendi disponibili i tuoi modelli TensorFlow Lite per ML Kit, ML Kit li memorizza nel formato protobuf serializzato standard nello spazio di archiviazione locale.

In teoria, ciò significa che chiunque può copiare il tuo modello. Tuttavia, in pratica, la maggior parte dei modelli è così specifica per l'applicazione e offuscata dalle ottimizzazioni che il rischio è simile a quello dei concorrenti che smontano e riutilizzano il codice. Tuttavia, devi essere consapevole di questo rischio prima di utilizzare un modello personalizzato nella tua app.

Su Android a partire dal livello API 21 (Lollipop), il modello viene scaricato in una directory esclusa dal backup automatico.

Sul livello API Android 20 e versioni precedenti, il modello viene scaricato in una directory denominata com.google.firebase.ml.custom.models nello spazio di archiviazione interno privato dell'app. Se hai attivato il backup dei file utilizzando BackupAgent, potresti scegliere di escludere questa directory.