Używanie niestandardowego modelu TensorFlow Lite na Androidzie

Jeśli Twoja aplikacja korzysta z niestandardowych modeli TensorFlow Lite, możesz użyć Firebase ML do wdrożenia modeli. Wdrażając modele za pomocą Firebase, możesz zmniejszyć początkowy rozmiar pobierania aplikacji i aktualizować modele ML aplikacji bez publikowania jej nowej wersji. Dzięki Remote ConfigA/B Testing możesz dynamicznie udostępniać różne modele różnym grupom użytkowników.

Modele TensorFlow Lite

Modele TensorFlow Lite to modele ML zoptymalizowane pod kątem uruchamiania na urządzeniach mobilnych. Aby uzyskać model TensorFlow Lite:

Zanim zaczniesz

  1. Jeśli jeszcze tego nie zrobiono, dodaj Firebase do projektu na Androida.
  2. pliku Gradle modułu (na poziomie aplikacji) (zwykle <project>/<app-module>/build.gradle.kts lub <project>/<app-module>/build.gradle) dodaj zależność z biblioteką Firebase ML do pobierania modeli na Androida. Zalecamy używanie Firebase Android BoM do kontrolowania wersji biblioteki.

    W ramach konfigurowania Firebase ML narzędzia do pobierania modeli musisz też dodać do aplikacji pakiet SDK TensorFlow Lite.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:34.0.0"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Gdy korzystamy z Firebase Android BoM, aplikacja zawsze używa zgodnych wersji bibliotek Firebase na Androida.

    Jeśli nie chcesz używać Firebase BoM, musisz określić każdą wersję biblioteki Firebase w wierszu zależności.

    Pamiętaj, że jeśli w aplikacji używasz wielu bibliotek Firebase, zdecydowanie zalecamy korzystanie z BoM do zarządzania wersjami bibliotek, co zapewnia zgodność wszystkich wersji.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:26.0.0")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
  3. W pliku manifestu aplikacji zadeklaruj, że wymagane jest uprawnienie INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Wdrażanie modelu

Wdrażaj niestandardowe modele TensorFlow za pomocą Firebase konsoli lub pakietów Firebase Admin SDK w językach Python i Node.js. Zobacz Wdrażanie modeli niestandardowych i zarządzanie nimi.

Po dodaniu modelu niestandardowego do projektu Firebase możesz odwoływać się do niego w swoich aplikacjach, podając określoną nazwę. W dowolnym momencie możesz wdrożyć nowy model TensorFlow Lite i pobrać go na urządzenia użytkowników, wywołując funkcję getModel() (patrz poniżej).

2. Pobierz model na urządzenie i zainicjuj interpreter TensorFlow Lite.

Aby użyć modelu TensorFlow Lite w aplikacji, najpierw pobierz najnowszą wersję modelu na urządzenie za pomocą Firebase MLpakietu SDK. Następnie utwórz instancję interpretera TensorFlow Lite z modelem.

Aby rozpocząć pobieranie modelu, wywołaj metodę getModel() narzędzia do pobierania modeli, podając nazwę przypisaną do modelu podczas przesyłania, informację, czy chcesz zawsze pobierać najnowszy model, oraz warunki, w których chcesz zezwolić na pobieranie.

Możesz wybrać jeden z 3 sposobów pobierania:

Typ pobierania Opis
LOCAL_MODEL Pobierz model lokalny z urządzenia. Jeśli nie ma dostępnego modelu lokalnego, ta funkcja działa jak LATEST_MODEL. Użyj tego typu pobierania, jeśli nie chcesz sprawdzać aktualizacji modelu. Na przykład używasz Zdalnej konfiguracji do pobierania nazw modeli i zawsze przesyłasz modele pod nowymi nazwami (zalecane).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Pobierz model lokalny z urządzenia i zacznij aktualizować go w tle. Jeśli nie ma dostępnego modelu lokalnego, ta funkcja działa jak LATEST_MODEL.
LATEST_MODEL Pobierz najnowszy model. Jeśli model lokalny jest najnowszą wersją, zwraca model lokalny. W przeciwnym razie pobierz najnowszy model. To działanie będzie blokować pobieranie, dopóki nie zostanie pobrana najnowsza wersja (niezalecane). Używaj tego działania tylko w przypadkach, gdy wyraźnie potrzebujesz najnowszej wersji.

Do czasu potwierdzenia pobrania modelu należy wyłączyć funkcje z nim związane, np. wyszarzyć lub ukryć część interfejsu.

KotlinJava
val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }
CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Wiele aplikacji rozpoczyna pobieranie w kodzie inicjującym, ale możesz to zrobić w dowolnym momencie przed użyciem modelu.

3. Przeprowadzanie wnioskowania na podstawie danych wejściowych

Pobieranie kształtów wejściowych i wyjściowych modelu

Interpreter modelu TensorFlow Lite przyjmuje jako dane wejściowe i generuje jako dane wyjściowe co najmniej 1 wielowymiarową tablicę. Tablice te zawierają wartości byte, int, long lub float. Zanim przekażesz dane do modelu lub użyjesz jego wyniku, musisz znać liczbę i wymiary („kształt”) tablic używanych przez model.

Jeśli model został utworzony samodzielnie lub jeśli format danych wejściowych i wyjściowych modelu jest udokumentowany, możesz już mieć te informacje. Jeśli nie znasz kształtu i typu danych wejścia i wyjścia modelu, możesz użyć interpretera TensorFlow Lite, aby sprawdzić model. Przykład:

Python
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Przykładowe dane wyjściowe:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Uruchamianie interpretera

Po określeniu formatu danych wejściowych i wyjściowych modelu pobierz dane wejściowe i przeprowadź na nich wszelkie niezbędne przekształcenia, aby uzyskać dane wejściowe o odpowiednim kształcie dla modelu.

Jeśli na przykład masz model klasyfikacji obrazów o kształcie danych wejściowych [1 224 224 3] wartości zmiennoprzecinkowe, możesz wygenerować dane wejściowe ByteBuffer z obiektu Bitmap, jak pokazano w tym przykładzie:

KotlinJava
val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}
Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Następnie przydziel ByteBuffer wystarczająco duży, aby pomieścić dane wyjściowe modelu, i przekaż bufor wejściowy i bufor wyjściowy do metody run() interpretera TensorFlow Lite. Na przykład w przypadku kształtu wyjściowego [1 1000] wartości zmiennoprzecinkowe:

KotlinJava
val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)
int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Sposób wykorzystania danych wyjściowych zależy od używanego modelu.

Jeśli na przykład przeprowadzasz klasyfikację, w następnym kroku możesz przypisać indeksy wyniku do etykiet, które reprezentują:

KotlinJava
modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}
modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Dodatek: bezpieczeństwo modelu

Niezależnie od tego, w jaki sposób udostępniasz modele TensorFlow Lite, Firebase MLFirebase ML przechowuje je w standardowym serializowanym formacie protobuf w pamięci lokalnej.

Teoretycznie oznacza to, że każdy może skopiować Twój model. W praktyce jednak większość modeli jest tak ściśle powiązana z aplikacją i zaciemniona przez optymalizacje, że ryzyko jest podobne do ryzyka związanego z rozłożeniem i ponownym wykorzystaniem kodu przez konkurencję. Zanim jednak użyjesz w aplikacji modelu niestandardowego, musisz mieć świadomość tego ryzyka.

Na Androidzie w wersji API 21 (Lollipop) i nowszych model jest pobierany do katalogu, który jest wykluczony z automatycznej kopii zapasowej.

Na Androidzie w wersji API 20 i starszej model jest pobierany do katalogu o nazwie com.google.firebase.ml.custom.models w pamięci wewnętrznej aplikacji. Jeśli włączysz tworzenie kopii zapasowych plików za pomocą BackupAgent, możesz wykluczyć ten katalog.