Używanie niestandardowego modelu TensorFlow Lite na Androidzie

Jeśli Twoja aplikacja korzysta z niestandardowych TensorFlow Lite modeli, możesz użyć Firebase ML do wdrożenia swoich modeli. Dzięki wdrażaniu modeli za pomocą Firebase możesz zmniejszyć początkowy rozmiar pobierania aplikacji i aktualizować modele ML aplikacji bez publikowania nowej wersji aplikacji. Za pomocą Remote Config i A/B Testing możesz dynamicznie udostępniać różne modele różnym grupom użytkowników.

Modele TensorFlow Lite

Modele TensorFlow Lite to modele ML zoptymalizowane pod kątem działania na urządzeniach mobilnych. Aby uzyskać model TensorFlow Lite:

Zanim zaczniesz

  1. Dodaj Firebase do projektu aplikacji na Androida, jeśli nie korzystasz w nim jeszcze z tej usługi.
  2. W pliku Gradle na poziomie modułu (aplikacji) (zwykle <project>/<app-module>/build.gradle.kts lub <project>/<app-module>/build.gradle), dodaj zależność od biblioteki pobierania modeli Firebase ML na Androida. Zalecamy używanie Firebase Android BoM do kontrolowania obsługi wersji bibliotek.

    W ramach konfigurowania pobierania modeli Firebase ML musisz też dodać do aplikacji pakiet SDK TensorFlow Lite.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:34.13.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Gdy korzystamy z Firebase Android BoM, aplikacja zawsze używa zgodnych wersji bibliotek Firebase na Androida.

    (Alternatywnie)  Dodaj zależności biblioteki Firebase bez użycia BoM

    Jeśli nie chcesz używać Firebase BoM, musisz określić wersję każdej biblioteki Firebase w wierszu zależności.

    Pamiętaj, że jeśli w aplikacji używasz kilku bibliotek Firebase, zdecydowanie zalecamy używanie BoM do zarządzania wersjami bibliotek, co zapewnia zgodność wszystkich wersji.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:26.0.2")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
  3. W manifeście aplikacji zadeklaruj, że wymagane jest uprawnienie INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Wdróż model

Wdróż niestandardowe modele TensorFlow za pomocą Firebase konsoli lub pakietów Firebase Admin SDK na Pythona i Node.js. Zobacz Wdrażanie niestandardowych modeli i zarządzanie nimi.

Po dodaniu niestandardowego modelu do projektu w Firebase możesz odwoływać się do niego w swoich aplikacjach, podając nazwę, którą określisz. W każdej chwili możesz wdrożyć nowy model TensorFlow Lite i pobrać go na urządzenia użytkowników, wywołując metodę getModel() (patrz poniżej).

2. Pobierz model na urządzenie i zainicjuj interpretera TensorFlow Lite

Aby używać modelu TensorFlow Lite w aplikacji, najpierw pobierz najnowszą wersję modelu na urządzenie za pomocą pakietu Firebase ML SDK. Następnie utwórz instancję interpretera TensorFlow Lite z modelem.

Aby rozpocząć pobieranie modelu, wywołaj metodę getModel() pobierania modelu, podając nazwę przypisaną do modelu podczas przesyłania, informację, czy chcesz zawsze pobierać najnowszy model, oraz warunki, w których chcesz zezwolić na pobieranie.

Możesz wybrać jeden z 3 sposobów pobierania:

Typ pobierania Opis
LOCAL_MODEL Pobierz model lokalny z urządzenia. Jeśli nie ma dostępnego modelu lokalnego, ta opcja działa jak LATEST_MODEL. Użyj tego typu pobierania, jeśli nie chcesz sprawdzać aktualizacji modelu. Na przykład używasz Zdalnej konfiguracji do pobierania nazw modeli i zawsze przesyłasz modele pod nowymi nazwami (zalecane).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Pobierz model lokalny z urządzenia i rozpocznij aktualizowanie modelu w tle. Jeśli nie ma dostępnego modelu lokalnego, ta opcja działa jak LATEST_MODEL.
LATEST_MODEL Pobierz najnowszy model. Jeśli model lokalny jest najnowszą wersją, zwraca model lokalny. W przeciwnym razie pobierz najnowszy model. To działanie będzie blokować, dopóki nie zostanie pobrana najnowsza wersja (niezalecane). Używaj tego działania tylko w przypadkach, gdy wyraźnie potrzebujesz najnowszej wersji.

Do czasu potwierdzenia pobrania modelu należy wyłączyć funkcje związane z modelem, np. wyszarzyć lub ukryć część interfejsu.

Kotlin

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Wiele aplikacji rozpoczyna zadanie pobierania w kodzie inicjującym, ale możesz to zrobić w dowolnym momencie przed użyciem modelu.

3. Przeprowadź wnioskowanie na podstawie danych wejściowych

Pobierz kształty wejściowe i wyjściowe modelu

Interpreter modelu TensorFlow Lite przyjmuje jako dane wejściowe i generuje jako dane wyjściowe co najmniej 1 tablicę wielowymiarową. Te tablice zawierają wartości byte, int, long, lub float. Zanim przekażesz dane do modelu lub użyjesz jego wyniku, musisz znać liczbę i wymiary ("kształt") tablic używanych przez model.

Jeśli model został utworzony samodzielnie lub jeśli format wejściowy i wyjściowy modelu jest udokumentowany, możesz już mieć te informacje. Jeśli nie znasz kształtu i typu danych wejściowych i wyjściowych modelu, możesz użyć interpretera TensorFlow Lite, aby sprawdzić model. Przykład:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Przykładowe dane wyjściowe:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Uruchom interpretera

Po określeniu formatu wejściowego i wyjściowego modelu pobierz dane wejściowe i przeprowadź na nich wszelkie transformacje niezbędne do uzyskania danych wejściowych o odpowiednim kształcie dla modelu.

Jeśli na przykład masz model klasyfikacji obrazów o kształcie wejściowym [1 224 224 3] wartości reprezentacji zmiennoprzecinkowej, możesz wygenerować dane wejściowe ByteBuffer z obiektu Bitmap, jak pokazano w tym przykładzie:

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Następnie przydziel ByteBuffer wystarczająco duży, aby pomieścić dane wyjściowe modelu, i przekaż bufor wejściowy oraz bufor wyjściowy do metody run() interpretera TensorFlow Lite. Na przykład w przypadku kształtu wyjściowego [1 1000] wartości zmiennoprzecinkowych:

Kotlin

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Sposób użycia danych wyjściowych zależy od używanego modelu.

Jeśli na przykład przeprowadzasz klasyfikację, w następnym kroku możesz przypisać indeksy wyniku do etykiet, które reprezentują:

Kotlin

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Dodatek: bezpieczeństwo modelu

Niezależnie od tego, jak udostępniasz modele TensorFlow Lite w Firebase ML, Firebase ML przechowuje je w standardowym serializowanym formacie protobuf w pamięci lokalnej.

Teoretycznie oznacza to, że każdy może skopiować Twój model. W praktyce jednak większość modeli jest tak specyficzna dla aplikacji i zaciemniona przez optymalizacje, że ryzyko jest podobne do ryzyka, że konkurenci rozmontują i ponownie wykorzystają Twój kod. Mimo to przed użyciem niestandardowego modelu w aplikacji musisz zdawać sobie sprawę z tego ryzyka.

W przypadku Androida w wersji API 21 (Lollipop) i nowszych model jest pobierany do katalogu, który jest wykluczony z automatycznego tworzenia kopii zapasowych.

W przypadku Androida w wersji API 20 i starszych model jest pobierany do katalogu o nazwie com.google.firebase.ml.custom.models w pamięci wewnętrznej aplikacji. Jeśli włączysz tworzenie kopii zapasowych plików za pomocą BackupAgent, możesz wykluczyć ten katalog.