Usa un modello TensorFlow Lite personalizzato su Android

Se la vostra applicazione utilizza personalizzati tensorflow Lite modelli, è possibile utilizzare Firebase ML per distribuire i modelli. Distribuendo i modelli con Firebase, puoi ridurre le dimensioni di download iniziali della tua app e aggiornare i modelli ML della tua app senza rilasciare una nuova versione della tua app. Inoltre, con Remote Config e A/B Testing, puoi servire dinamicamente diversi modelli a diversi gruppi di utenti.

Modelli TensorFlow Lite

I modelli TensorFlow Lite sono modelli ML ottimizzati per l'esecuzione su dispositivi mobili. Per ottenere un modello TensorFlow Lite:

Prima di iniziare

  1. Se non l'hai già, aggiungi Firebase al progetto Android .
  2. Utilizzando la Firebase Android BoM , dichiarare la dipendenza per il modello ML Firebase downloader biblioteca Android nel modulo (a livello di app) File Gradle (di solito app/build.gradle ).

    Inoltre, come parte della configurazione del downloader del modello Firebase ML, devi aggiungere l'SDK TensorFlow Lite alla tua app.

    Giava

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:28.4.1')
    
        // Declare the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Usando il Firebase Android BoM , la vostra applicazione sarà sempre utilizzare versioni compatibili delle librerie Firebase Android.

    (Alternativa) Dichiarare Firebase dipendenze delle librerie senza utilizzare la distinta

    Se scegli di non utilizzare Firebase BoM, devi specificare ogni versione della libreria Firebase nella relativa riga di dipendenza.

    Si noti che se si utilizzano più librerie Firebase nella vostra app, ti consigliamo di utilizzare la distinta di gestire versioni della libreria, che assicura che tutte le versioni sono compatibili.

    dependencies {
        // Declare the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader:24.0.0'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Kotlin+KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:28.4.1')
    
        // Declare the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    Usando il Firebase Android BoM , la vostra applicazione sarà sempre utilizzare versioni compatibili delle librerie Firebase Android.

    (Alternativa) Dichiarare Firebase dipendenze delle librerie senza utilizzare la distinta

    Se scegli di non utilizzare Firebase BoM, devi specificare ogni versione della libreria Firebase nella relativa riga di dipendenza.

    Si noti che se si utilizzano più librerie Firebase nella vostra app, ti consigliamo di utilizzare la distinta di gestire versioni della libreria, che assicura che tutte le versioni sono compatibili.

    dependencies {
        // Declare the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-modeldownloader-ktx:24.0.0'
    // Also declare the dependency for the TensorFlow Lite library and specify its version implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }
  3. Nel manifesto, dichiarare che è necessario il permesso Internet del app:
    <uses-permission android:name="android.permission.INTERNET" />

1. Distribuisci il tuo modello

Distribuisci i tuoi modelli TensorFlow personalizzati utilizzando la console Firebase o gli SDK Firebase Admin Python e Node.js. Vedere implementare e gestire modelli personalizzati .

Dopo aver aggiunto un modello personalizzato al tuo progetto Firebase, puoi fare riferimento al modello nelle tue app utilizzando il nome specificato. In qualsiasi momento, è possibile distribuire un nuovo modello tensorflow Lite e scaricare il nuovo modello sui dispositivi degli utenti chiamando getModel() (vedi sotto).

2. Scaricare il modello sul dispositivo e inizializzare un interprete TensorFlow Lite

Per utilizzare il modello TensorFlow Lite nella tua app, utilizza prima l'SDK Firebase ML per scaricare l'ultima versione del modello sul dispositivo. Quindi, crea un'istanza di un interprete TensorFlow Lite con il modello.

Per avviare il modello di download, chiamare il modello del downloader getModel() il metodo, specificando il nome assegnato al modello quando hai caricata, se si desidera scaricare sempre l'ultimo modello, e le condizioni in cui si desidera consentire il download.

Puoi scegliere tra tre comportamenti di download:

Tipo di download Descrizione
MODELLO_LOCALE Ottieni il modello locale dal dispositivo. Se non c'è un modello locale disponibile, questo si comporta come LATEST_MODEL . Usa questo tipo di download se non sei interessato a controllare gli aggiornamenti del modello. Ad esempio, stai utilizzando Remote Config per recuperare i nomi dei modelli e carichi sempre i modelli con nuovi nomi (consigliato).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Ottieni il modello locale dal dispositivo e inizia ad aggiornare il modello in background. Se non c'è un modello locale disponibile, questo si comporta come LATEST_MODEL .
ULTIMO_MODELLO Ottieni l'ultimo modello. Se il modello locale è la versione più recente, restituisce il modello locale. Altrimenti, scarica l'ultimo modello. Questo comportamento si bloccherà fino al download dell'ultima versione (non consigliato). Utilizzare questo comportamento solo nei casi in cui è necessaria esplicitamente la versione più recente.

Dovresti disabilitare la funzionalità relativa al modello, ad esempio disattivare o nascondere parte dell'interfaccia utente, fino a quando non confermi che il modello è stato scaricato.

Giava

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Molte app avviano l'attività di download nel codice di inizializzazione, ma puoi farlo in qualsiasi momento prima di dover utilizzare il modello.

3. Eseguire l'inferenza sui dati di input

Ottieni le forme di input e output del tuo modello

L'interprete del modello TensorFlow Lite prende come input e produce come output uno o più array multidimensionali. Questi array contengono sia byte , int , long , o float valori. Prima di poter passare i dati a un modello o utilizzarne il risultato, è necessario conoscere il numero e le dimensioni ("forma") degli array utilizzati dal modello.

Se hai creato il modello da solo o se il formato di input e output del modello è documentato, potresti già disporre di queste informazioni. Se non si conosce la forma e il tipo di dati dell'input e dell'output del modello, è possibile utilizzare l'interprete TensorFlow Lite per ispezionare il modello. Per esempio:

Pitone

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Esempio di output:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Esegui l'interprete

Dopo aver determinato il formato dell'input e dell'output del modello, ottenere i dati di input ed eseguire le trasformazioni sui dati necessarie per ottenere un input della forma corretta per il modello.

Ad esempio, se si dispone di un modello di classificazione un'immagine con una forma di ingresso [1 224 224 3] valori in virgola mobile, è possibile generare un ingresso ByteBuffer da un Bitmap oggetto come mostrato nel seguente esempio:

Giava

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Poi, allocare una ByteBuffer abbastanza grande da contenere l'output del modello e passare il buffer di input e output buffer per l'interprete tensorflow Lite run() metodo. Ad esempio, per una forma di uscita [1 1000] valori in virgola mobile:

Giava

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Il modo in cui utilizzi l'output dipende dal modello che stai utilizzando.

Ad esempio, se stai eseguendo la classificazione, come passaggio successivo, potresti mappare gli indici del risultato alle etichette che rappresentano:

Giava

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Appendice: sicurezza del modello

Indipendentemente da come rendi disponibili i tuoi modelli TensorFlow Lite a Firebase ML, Firebase ML li archivia nel formato protobuf serializzato standard nell'archiviazione locale.

In teoria, questo significa che chiunque può copiare il tuo modello. Tuttavia, in pratica, la maggior parte dei modelli è così specifica per l'applicazione e offuscata dalle ottimizzazioni che il rischio è simile a quello dei concorrenti che smontano e riutilizzano il codice. Tuttavia, dovresti essere consapevole di questo rischio prima di utilizzare un modello personalizzato nella tua app.

Su Android livello di API 21 (Lollipop) e più recenti, il modello viene scaricato in una directory che viene escluso dal backup automatico .

Su Android livello di API 20 anni e più, il modello viene scaricato in una directory denominata com.google.firebase.ml.custom.models nella memoria interna app-privato. Se è stato abilitato il backup di file utilizzando BackupAgent , si potrebbe scegliere di escludere questa directory.