Utilizzare un modello TensorFlow Lite per l'inferenza con ML Kit su Android

Puoi utilizzare ML Kit per eseguire l'inferenza on-device con un modello TensorFlow Lite.

Questa API richiede il livello SDK Android 16 (Jelly Bean) o versioni successive.

Prima di iniziare

  1. Se non l'hai ancora fatto, aggiungi Firebase al tuo progetto Android.
  2. Aggiungi le dipendenze per le librerie Android ML Kit al file Gradle del modulo (a livello di app, di solito app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
  3. Converti il modello TensorFlow che vuoi utilizzare nel formato TensorFlow Lite. Consulta TOCO: TensorFlow Lite Optimizing Converter.

Ospitare o raggruppare il modello

Prima di poter utilizzare un modello TensorFlow Lite per l'inferenza nella tua app, devi rendere il modello disponibile per ML Kit. ML Kit può utilizzare modelli TensorFlow Lite ospitati in remoto utilizzando Firebase, inclusi nel binario dell'app o entrambi.

Se ospiti un modello su Firebase, puoi aggiornarlo senza rilasciare una nuova versione dell'app e puoi utilizzare Remote Config e A/B Testing per mostrare dinamicamente modelli diversi a diversi gruppi di utenti.

Se scegli di fornire solo il modello ospitandolo con Firebase e non di raggrupparlo con la tua app, puoi ridurre le dimensioni del download iniziale dell'app. Tieni presente, tuttavia, che se il modello non è raggruppato con la tua app, qualsiasi funzionalità correlata al modello non sarà disponibile finché l'app non scarica il modello per la prima volta.

Se raggruppi il modello con l'app, puoi assicurarti che le funzionalità di ML dell'app continuino a funzionare anche quando il modello ospitato su Firebase non è disponibile.

Ospitare modelli su Firebase

Per ospitare il modello TensorFlow Lite su Firebase:

  1. Nella sezione ML Kit della console Firebase, fai clic sulla scheda Personalizzato.
  2. Fai clic su Aggiungi modello personalizzato (o Aggiungi un altro modello).
  3. Specifica un nome che verrà utilizzato per identificare il modello nel tuo progetto Firebase, poi carica il file del modello TensorFlow Lite (di solito termina con .tflite o .lite).
  4. Nel manifest della tua app, dichiara che è necessaria l'autorizzazione INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

Dopo aver aggiunto un modello personalizzato al tuo progetto Firebase, puoi fare riferimento al modello nelle tue app usando il nome specificato. In qualsiasi momento, puoi caricare un nuovo modello TensorFlow Lite e la tua app scaricherà il nuovo modello e inizierà a utilizzarlo al successivo riavvio. Puoi definire le condizioni del dispositivo necessarie per l'aggiornamento del modello da parte dell'app (vedi di seguito).

Raggruppare i modelli con un'app

Per raggruppare il modello TensorFlow Lite con l'app, copia il file del modello (di solito termina con .tflite o .lite) nella cartella assets/ dell'app. Potresti dover creare prima la cartella facendo clic con il tasto destro del mouse sulla cartella app/, quindi facendo clic su Nuovo > Cartella > Cartella asset.

Poi, aggiungi quanto segue al file build.gradle dell'app per assicurarti che Gradle non comprima i modelli durante la creazione dell'app:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Il file del modello verrà incluso nel pacchetto dell'app e sarà disponibile per ML Kit come asset non elaborato.

Carica il modello

Per utilizzare il modello TensorFlow Lite nella tua app, configura prima ML Kit con le posizioni in cui è disponibile il modello: in remoto utilizzando Firebase, nell'archivio locale o in entrambi. Se specifichi sia un modello locale che uno remoto, puoi utilizzare il modello remoto se è disponibile e ripiegare sul modello archiviato localmente se il modello remoto non è disponibile.

Configurare un modello ospitato da Firebase

Se hai ospitato il modello con Firebase, crea un oggetto FirebaseCustomRemoteModel specificando il nome che hai assegnato al modello quando lo hai caricato:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Poi avvia l'attività di download del modello, specificando le condizioni in base alle quali vuoi consentire il download. Se il modello non è sul dispositivo o se è disponibile una versione più recente, l'attività scaricherà in modo asincrono il modello da Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Molte app avviano l'attività di download nel codice di inizializzazione, ma puoi farlo in qualsiasi momento prima di dover utilizzare il modello.

Configura un modello locale

Se hai incluso il modello nel bundle dell'app, crea un oggetto FirebaseCustomLocalModel specificando il nome file del modello TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Crea un interprete dal tuo modello

Dopo aver configurato le origini del modello, crea un FirebaseModelInterpreter oggetto da una di queste.

Se hai solo un modello in bundle locale, crea un interprete dall'oggetto FirebaseCustomLocalModel:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Se hai un modello ospitato in remoto, devi verificare che sia stato scaricato prima di eseguirlo. Puoi controllare lo stato del download del modello utilizzando il metodo isModelDownloaded() di Model Manager.

Anche se devi confermare questa operazione solo prima di eseguire l'interprete, se hai sia un modello ospitato in remoto sia un modello incluso localmente, potrebbe essere utile eseguire questo controllo durante l'istanza dell'interprete del modello: crea un interprete dal modello remoto se è stato scaricato e dal modello locale in caso contrario.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Se hai solo un modello ospitato in remoto, devi disattivare la funzionalità correlata al modello, ad esempio disattivare o nascondere parte della tua UI, finché non confermi che il modello è stato scaricato. Puoi farlo collegando un listener al metodo download() del gestore dei modelli:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Specifica l'input e l'output del modello

Successivamente, configura i formati di input e output dell'interprete del modello.

Un modello TensorFlow Lite accetta come input e produce come output uno o più array multidimensionali. Questi array contengono valori byte, int, long o float. Devi configurare ML Kit con il numero e le dimensioni ("forma") degli array utilizzati dal modello.

Se non conosci la forma e il tipo di dati di input e output del tuo modello, puoi utilizzare l'interprete Python di TensorFlow Lite per esaminare il modello. Ad esempio:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Dopo aver determinato il formato dell'input e dell'output del modello, puoi configurare l'interprete del modello della tua app creando un oggetto FirebaseModelInputOutputOptions.

Ad esempio, un modello di classificazione delle immagini in virgola mobile potrebbe prendere come input un array Nx224x224x3 di valori float, che rappresenta un batch di N immagini a tre canali (RGB) 224x224 e produrre come output un elenco di 1000 valori float, ognuno dei quali rappresenta la probabilità che l'immagine appartenga a una delle 1000 categorie previste dal modello.

Per un modello di questo tipo, configureresti l'input e l'output dell'interprete del modello come mostrato di seguito:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Esegue l'inferenza sui dati di input

Infine, per eseguire l'inferenza utilizzando il modello, recupera i dati di input ed esegui le trasformazioni necessarie per ottenere un array di input della forma corretta per il modello.

Ad esempio, se hai un modello di classificazione delle immagini con una forma di input di [1 224 224 3] valori in virgola mobile, puoi generare un array di input da un oggetto Bitmap come mostrato nell'esempio seguente:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Quindi, crea un oggetto FirebaseModelInputs con i tuoi dati di input e trasmettilo insieme alla specifica di input e output del modello al metodo run dell'interprete del modello:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Se la chiamata ha esito positivo, puoi ottenere l'output chiamando il metodo getOutput() dell'oggetto passato al listener di successo. Ad esempio:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Il modo in cui utilizzi l'output dipende dal modello che stai utilizzando.

Ad esempio, se esegui la classificazione, come passaggio successivo potresti mappare gli indici del risultato alle etichette che rappresentano:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Appendice: Sicurezza del modello

Indipendentemente da come rendi disponibili i tuoi modelli TensorFlow Lite per ML Kit, ML Kit li memorizza nel formato protobuf serializzato standard nello spazio di archiviazione locale.

In teoria, questo significa che chiunque può copiare il tuo modello. Tuttavia, in pratica, la maggior parte dei modelli è così specifica per l'applicazione e offuscata dalle ottimizzazioni che il rischio è simile a quello dei concorrenti che disassemblano e riutilizzano il tuo codice. Tuttavia, devi essere consapevole di questo rischio prima di utilizzare un modello personalizzato nella tua app.

A partire dal livello API 21 (Lollipop) di Android, il modello viene scaricato in una directory esclusa dal backup automatico.

Nei livelli API Android 20 e precedenti, il modello viene scaricato in una directory denominata com.google.firebase.ml.custom.models nella memoria interna privata dell'app. Se hai attivato il backup dei file utilizzando BackupAgent, potresti scegliere di escludere questa directory.