Utilizzare un modello TensorFlow Lite per l'inferenza con ML Kit su Android

Puoi utilizzare ML Kit per eseguire l'inferenza on-device con una Modello TensorFlow Lite.

Questa API richiede l'SDK per Android di livello 16 (Jelly Bean) o versioni successive.

Prima di iniziare

  1. Se non l'hai già fatto, aggiungi Firebase al tuo progetto Android.
  2. Aggiungi al modulo le dipendenze per le librerie Android di ML Kit file Gradle (a livello di app) (di solito app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
  3. Converti il modello TensorFlow che vuoi utilizzare nel formato TensorFlow Lite. Consulta TOCO: TensorFlow Lite Optimizing Converter.

Ospittare o raggruppare il modello

Prima di poter utilizzare un modello TensorFlow Lite per l'inferenza nella tua app, devi renderlo disponibile per ML Kit. ML Kit può utilizzare i modelli TensorFlow Lite ospitati in remoto utilizzando Firebase, inclusi nel file binario dell'app o entrambi.

Se ospiti un modello su Firebase, puoi aggiornarlo senza rilasciare un nuova versione dell'app e puoi usare Remote Config e A/B Testing per di pubblicare dinamicamente modelli diversi per insiemi di utenti diversi.

Se scegli di fornire il modello solo ospitandolo con Firebase e non puoi ridurne le dimensioni di download iniziali. Tieni presente, tuttavia, che se il modello non è integrato nella tua app, le funzionalità correlate al modello non saranno disponibili finché l'app non scarica l'app per la prima volta.

Se combini il modello con l'app, puoi assicurarti che le funzionalità di ML dell'app continuino a funzionare anche quando il modello ospitato su Firebase non è disponibile.

Ospita modelli su Firebase

Per ospitare il tuo modello TensorFlow Lite su Firebase:

  1. Nella sezione ML Kit della console Firebase, fai clic su la scheda Personalizzata.
  2. Fai clic su Aggiungi modello personalizzato (o Aggiungi un altro modello).
  3. Specifica un nome che verrà utilizzato per identificare il modello in Firebase progetto, quindi carica il file del modello TensorFlow Lite (che di solito termina con .tflite o .lite).
  4. Nel file manifest dell'app, dichiara che è necessaria l'autorizzazione INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

Dopo aver aggiunto un modello personalizzato al tuo progetto Firebase, puoi fare riferimento alla tuo modello nelle tue app utilizzando il nome specificato. Puoi caricare un nuovo modello TensorFlow Lite in qualsiasi momento e la tua app lo scaricherà e inizierà a utilizzarlo al successivo riavvio. Puoi definire il dispositivo le condizioni necessarie affinché l'app tenti di aggiornare il modello (vedi di seguito).

Raggruppa i modelli con un'app

Per raggruppare il modello TensorFlow Lite con la tua app, copia il file del modello (di solito che termina con .tflite o .lite) alla cartella assets/ dell'app. (Potrebbe servirti per creare prima la cartella facendo clic con il tasto destro del mouse sulla cartella app/, quindi facendo clic Nuovo > Cartella > nella cartella Asset.

Poi, aggiungi quanto segue al file build.gradle dell'app per assicurarti che Gradle non comprimono i modelli durante la creazione dell'app:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Il file del modello sarà incluso nel pacchetto dell'app e sarà disponibile per ML Kit come asset non elaborato.

Carica il modello

Per utilizzare il modello TensorFlow Lite nella tua app, configura prima ML Kit con le posizioni in cui è disponibile il modello: da remoto utilizzando Firebase, nello spazio di archiviazione locale o in entrambe le posizioni. Se specifichi sia un modello locale che uno remoto, puoi utilizzare il modello remoto se disponibile e utilizzare come fallback il modello archiviato localmente se il modello remoto non è disponibile.

Configura un modello ospitato da Firebase

Se hai ospitato il tuo modello con Firebase, crea un FirebaseCustomRemoteModel specificando il nome assegnato al modello al momento del caricamento:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Poi, avvia l'attività di download del modello, specificando le condizioni in cui vuoi consentire il download. Se il modello non è presente sul dispositivo o se una versione più recente del modello, l'attività scaricherà in modo asincrono modello di Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Molte app avviano l'attività di download nel codice di inizializzazione, ma puoi farlo quindi in qualsiasi momento prima di dover usare il modello.

Configurare un modello locale

Se hai bundle il modello con la tua app, crea un FirebaseCustomLocalModel specificando il nome file del modello TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Crea un interprete dal tuo modello

Dopo aver configurato le origini del modello, crea un oggetto FirebaseModelInterpreter da una di queste.

Se hai solo un modello in bundle locale, crea un interprete dalla tua Oggetto FirebaseCustomLocalModel:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Se il tuo modello è ospitato in remoto, dovrai verificare che sia stato scaricato prima di eseguirlo. Puoi controllare lo stato del download del modello utilizzando il metodo isModelDownloaded() del gestore del modello.

Anche se devi confermarlo solo prima di eseguire l'interprete, se hai sia un modello ospitato in remoto sia un modello in bundle locale, potrebbe essere utile eseguire questo controllo durante l'inizializzazione dell'interprete del modello: crea un interprete dal modello remoto se è stato scaricato e dal modello locale in caso contrario.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Se disponi solo di un modello ospitato in remoto, devi disattivare le relative funzionalità, ad esempio rendere non selezionabile o nascondere parte dell'interfaccia utente, fino a quando confermi che il modello è stato scaricato. Puoi farlo collegando un listener al metodo download() del gestore del modello:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Specifica l'input e l'output del modello

A questo punto, configura i formati di input e di output dell'interprete del modello.

Un modello TensorFlow Lite prende come input e genera come output uno o più matrici multidimensionali. Questi array contengono valori byte, int, long o float. Devi configurare ML Kit con il numero e le dimensioni ("forma") degli array che utilizzati dal modello.

Se non conosci la forma e il tipo di dati di input e output del tuo modello, puoi utilizzare l'interprete Python di TensorFlow Lite per ispezionarlo. Ad esempio:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Dopo aver determinato il formato di input e output del modello, puoi configurare l'interprete del modello della tua app creando un FirebaseModelInputOutputOptions.

Ad esempio, un modello di classificazione di immagini in virgola mobile potrebbe utilizzare come input NArray x224 x 224 x 3 di valori float, che rappresenta un gruppo di N Immagini a tre canali (RGB) 224 x 224 e genera come output un elenco di 1000 valori float, ognuno dei quali rappresenta la probabilità di un'immagine di appartenenza una delle 1000 categorie previste dal modello.

Per un modello di questo tipo, devi configurare l'input e l'output dell'interprete del modello come mostrato di seguito:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Esegui l'inferenza sui dati di input

Infine, per eseguire l'inferenza utilizzando il modello, ottieni i dati di input ed esegui le trasformazioni dei dati necessarie per ottenere un array di input la forma più adatta al tuo modello.

Ad esempio, se disponi di un modello di classificazione delle immagini con la forma di input [1 224 224 3] valori in virgola mobile, si potrebbe generare un array di input da un Oggetto Bitmap come mostrato nell'esempio seguente:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Quindi, crea un oggetto FirebaseModelInputs con i dati di input e passalo, insieme alla specifica di input e output del modello, al metodo run dell'interprete del modello:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Se la chiamata ha esito positivo, puoi ottenere l'output chiamando il metodo getOutput() dell'oggetto passato al listener di successo. Ad esempio:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

La modalità di utilizzo dell'output dipende dal modello utilizzato.

Ad esempio, se esegui la classificazione, come passaggio successivo, mappa gli indici del risultato alle etichette che rappresentano:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Appendice: Sicurezza del modello

Indipendentemente da come rendi disponibili i tuoi modelli TensorFlow Lite ML Kit, ML Kit li archivia nel formato protobuf serializzato standard in formato archiviazione locale.

In teoria, questo significa che chiunque può copiare il modello. Tuttavia, Nella pratica, la maggior parte dei modelli è così specifica per l'applicazione e offuscata secondo cui il rischio è simile a quello dello smontaggio della concorrenza a riutilizzare il codice. Tuttavia, è necessario essere consapevoli di questo rischio prima di utilizzare un modello personalizzato nella tua app.

Sul livello API Android 21 (Lollipop) e versioni successive, il modello viene scaricato in un che è esclusi dal backup automatico.

Sul livello API Android 20 e versioni precedenti, il modello viene scaricato in una directory denominato com.google.firebase.ml.custom.models in privato dell'app memoria interna. Se hai attivato il backup dei file utilizzando BackupAgent, puoi scegliere di escludere questa directory.