Używanie niestandardowego modelu TensorFlow Lite na Androidzie

Jeśli Twoja aplikacja używa niestandardowych funkcji TensorFlow Lite, możesz używać Firebase ML do wdrażania modeli. Według wdrażając modele za pomocą Firebase, możesz zmniejszyć początkowy rozmiar pobieranych danych i aktualizowanie jej modeli ML bez publikowania nowej wersji do aplikacji. Dzięki Remote Config i A/B Testing możesz dynamicznie udostępniają różne modele różnym grupom użytkowników.

Modele TensorFlow Lite

Modele TensorFlow Lite to modele ML zoptymalizowane pod kątem uruchamiania na urządzeniach mobilnych urządzenia. Aby pobrać model TensorFlow Lite:

Zanim zaczniesz

  1. Jeśli jeszcze nie masz tego za sobą, dodaj Firebase do swojego projektu na Androida.
  2. w pliku Gradle (na poziomie aplikacji) modułu, (zwykle <project>/<app-module>/build.gradle.kts lub <project>/<app-module>/build.gradle), dodaj zależność z biblioteką pobierania modeli Firebase ML na Androida. Zalecamy użycie metody Firebase Android BoM aby kontrolować obsługę wersji biblioteki.

    W ramach konfigurowania programu do pobierania modeli Firebase ML musisz też dodać atrybut pakietu SDK TensorFlow Lite do aplikacji.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:33.5.1"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Korzystając z narzędzia Firebase Android BoM, Twoja aplikacja zawsze używa zgodnych wersji bibliotek Firebase na Androida.

    (Wersja alternatywna) Dodaj zależności biblioteki Firebase bez użycia komponentu BoM

    Jeśli nie chcesz używać biblioteki Firebase BoM, musisz określić każdą wersję biblioteki Firebase w wierszu zależności.

    Pamiętaj, że jeśli używasz wielu bibliotek Firebase w swojej aplikacji, zalecamy korzystanie z BoM do zarządzania wersjami biblioteki. Dzięki temu wszystkie wersje są zgodne.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:25.0.1")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
    Szukasz modułu biblioteki korzystającego z usługi Kotlin? Zaczyna się za Październik 2023 r. (Firebase BoM 32.5.0) zarówno programiści Kotlin, jak i Java zależą od modułu biblioteki głównej (więcej informacji znajdziesz w Najczęstsze pytania na temat tej inicjatywy).
  3. W pliku manifestu aplikacji zadeklaruj, że wymagane są uprawnienia INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Wdrażanie modelu

Wdróż niestandardowe modele TensorFlow za pomocą konsoli Firebase lub pakiety SDK Python dla administratorów Firebase i Node.js. Zobacz Wdrażanie modeli niestandardowych i zarządzanie nimi

Po dodaniu do projektu Firebase modelu niestandardowego możesz się odwoływać do w swoich aplikacjach o podanej przez Ciebie nazwie. W każdej chwili możesz wdrożyć utworzyć nowy model TensorFlow Lite i pobrać go na pliki użytkowników urządzenia według Dzwonię pod getModel() (patrz poniżej).

2. Pobierz model na urządzenie i zainicjuj interpreter TensorFlow Lite

Aby użyć modelu TensorFlow Lite w aplikacji, najpierw użyj pakietu SDK Firebase ML aby pobrać na urządzenie najnowszą wersję modelu. Następnie utwórz instancję Interpreter TensorFlow Lite z modelem.

Aby rozpocząć pobieranie modelu, wywołaj metodę getModel() narzędzia do pobierania modelu, określając nazwę przypisaną do modelu podczas jego przesyłania. aby zawsze pobierać najnowsze modele oraz warunki, którzy chcą zezwolić na pobieranie.

Możesz wybrać jeden z 3 sposobów pobierania:

Typ pobierania Opis
LOCAL_MODEL Pobierz model lokalny z urządzenia. Jeśli nie jest dostępny model lokalny, działa jak LATEST_MODEL. Użyj tej Jeśli Cię to nie interesuje, typ pliku do pobrania sprawdzając dostępność aktualizacji modelu. Przykład: używasz Zdalnej konfiguracji do pobierania nazwy modeli i zawsze przesłane modele pod nowymi nazwami (zalecane).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Pobierz model lokalny z urządzenia zacznij aktualizować model w tle. Jeśli nie jest dostępny model lokalny, działa jak LATEST_MODEL.
NAJNOWSZY_MODEL Pobierz najnowszy model. Jeśli model lokalny to najnowsza wersja, zwraca błąd lokalny model atrybucji. Jeśli nie, pobierz najnowszą wersję model atrybucji. To działanie będzie blokowane do czasu pobrana najnowsza wersja (nie ). Używaj tego sposobu tylko w w przypadku, gdy musisz podać najnowsze wersji.

Należy wyłączyć funkcje związane z modelem – na przykład wyszarzone lub ukryj część interfejsu użytkownika, dopóki nie potwierdzisz, że model został pobrany.

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Wiele aplikacji rozpoczyna zadanie pobierania w kodzie inicjowania, ale możesz to zrobić. więc w dowolnym momencie przed użyciem modelu.

3. Przeprowadź wnioskowanie na danych wejściowych

Pobierz kształty danych wejściowych i wyjściowych modelu

Interpreter modelu TensorFlow Lite pobiera jako dane wejściowe i tworzy jako dane wyjściowe co najmniej jednej tablicy wielowymiarowej. Te tablice zawierają: byte, int, long lub float . Aby przekazywać dane do modelu lub używać jego wyników, musisz wiedzieć, liczby i wymiary („kształt”) tablic używanych przez model.

Jeśli model został utworzony przez Ciebie lub jeśli format wejściowy i wyjściowy modelu to możesz mieć już te informacje. Jeśli nie znasz kształtu i typu danych danych wejściowych i wyjściowych modelu, możesz użyć funkcji Interpreter TensorFlow Lite do sprawdzenia modelu. Przykład:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Przykładowe dane wyjściowe:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Uruchamianie tłumaczenia rozmowy

Po określeniu formatu danych wejściowych i wyjściowych modelu pobierz plik i wykonywać przekształcenia danych, które są niezbędne, aby model miał kształt właściwy.

Jeśli na przykład masz model klasyfikacji obrazów o wejściowym kształcie [1 224 224 3] wartości zmiennoprzecinkowych. Możesz wygenerować wartość wejściową ByteBuffer. z obiektu Bitmap, jak pokazano w tym przykładzie:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Następnie przydziel plik ByteBuffer na tyle duży, aby zawierał dane wyjściowe modelu, przekazać bufor wejściowy i bufor wyjściowy do interpretera TensorFlow Lite Metoda run(). Na przykład dla kształtu wyjściowego liczby zmiennoprzecinkowej [1 1000] wartości:

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Sposób wykorzystania danych wyjściowych zależy od używanego modelu.

Jeśli na przykład przeprowadzasz klasyfikację, kolejnym krokiem może być zmapuj indeksy wyników na etykiety, które reprezentują:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Dodatek: zabezpieczenia modelu

Niezależnie od tego, jak udostępnisz swoje modele TensorFlow Lite Firebase ML, Firebase ML przechowuje je w standardowym zserializowanym formacie protokołu w formacie pamięci lokalnej.

Teoretycznie oznacza to, że każdy może skopiować Twój model. Pamiętaj jednak: W praktyce większość modeli jest specyficzna dla danej aplikacji i pod kątem podobnych optymalizacji, jakie stwarzają konkurencji, demontaż ponownego wykorzystania kodu. Musisz jednak wiedzieć o tym ryzyku, zanim zaczniesz niestandardowy model w swojej aplikacji.

W przypadku interfejsu API Androida na poziomie 21 (Lollipop) lub nowszym model jest pobierany do katalogu, który jest wykluczono z automatycznej kopii zapasowej.

W przypadku interfejsu Android API na poziomie 20 lub starszym model jest pobierany do katalogu z nazwą com.google.firebase.ml.custom.models w sekcji prywatnej w aplikacji pamięci wewnętrznej. Jeśli masz włączoną kopię zapasową plików za pomocą usługi BackupAgent, możesz wykluczyć ten katalog.