Sử dụng mô hình TensorFlow Lite tùy chỉnh trên Android

Nếu ứng dụng của bạn sử dụng các mô hình TensorFlow Lite tùy chỉnh, thì bạn có thể sử dụng Firebase ML để triển khai các mô hình của mình. Bằng cách triển khai các mô hình với Firebase, bạn có thể giảm kích thước tải xuống ban đầu của ứng dụng và cập nhật các mô hình ML của ứng dụng mà không cần phát hành phiên bản mới của ứng dụng. Ngoài ra, với Cấu hình từ xa và Thử nghiệm A/B, bạn có thể tự động phân phát các mô hình khác nhau cho các nhóm người dùng khác nhau.

Các mẫu TensorFlow Lite

Các mô hình TensorFlow Lite là các mô hình ML được tối ưu hóa để chạy trên thiết bị di động. Để có được mô hình TensorFlow Lite:

Trước khi bắt đầu

  1. Nếu bạn chưa có, hãy thêm Firebase vào dự án Android của bạn .
  2. Trong tệp Gradle mô-đun (cấp ứng dụng) của bạn (thường là <project>/<app-module>/build.gradle.kts hoặc <project>/<app-module>/build.gradle ), hãy thêm phần phụ thuộc cho Firebase ML thư viện Android của trình tải xuống mô hình. Chúng tôi khuyên bạn nên sử dụng Firebase Android BoM để kiểm soát việc lập phiên bản thư viện.

    Ngoài ra, trong quá trình thiết lập trình tải xuống mô hình Firebase ML, bạn cần thêm SDK TensorFlow Lite vào ứng dụng của mình.

    Kotlin+KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:32.3.1"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader-ktx")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Bằng cách sử dụng Firebase Android BoM , ứng dụng của bạn sẽ luôn sử dụng các phiên bản tương thích của thư viện Android Firebase.

    (Thay thế) Thêm phụ thuộc thư viện Firebase mà không cần sử dụng BoM

    Nếu chọn không sử dụng Firebase BoM, bạn phải chỉ định từng phiên bản thư viện Firebase trong dòng phụ thuộc của nó.

    Lưu ý rằng nếu bạn sử dụng nhiều thư viện Firebase trong ứng dụng của mình, chúng tôi thực sự khuyên bạn nên sử dụng BoM để quản lý các phiên bản thư viện, điều này đảm bảo rằng tất cả các phiên bản đều tương thích.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader-ktx:24.1.3")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Java

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:32.3.1"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Bằng cách sử dụng Firebase Android BoM , ứng dụng của bạn sẽ luôn sử dụng các phiên bản tương thích của thư viện Android Firebase.

    (Thay thế) Thêm phụ thuộc thư viện Firebase mà không cần sử dụng BoM

    Nếu chọn không sử dụng Firebase BoM, bạn phải chỉ định từng phiên bản thư viện Firebase trong dòng phụ thuộc của nó.

    Lưu ý rằng nếu bạn sử dụng nhiều thư viện Firebase trong ứng dụng của mình, chúng tôi thực sự khuyên bạn nên sử dụng BoM để quản lý các phiên bản thư viện, điều này đảm bảo rằng tất cả các phiên bản đều tương thích.

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:24.1.3")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
  3. Trong tệp kê khai ứng dụng của bạn, hãy khai báo rằng cần có quyền INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

1. Triển khai mô hình của bạn

Triển khai các mô hình TensorFlow tùy chỉnh của bạn bằng cách sử dụng bảng điều khiển Firebase hoặc SDK Python và Node.js dành cho quản trị viên Firebase. Xem Triển khai và quản lý các mô hình tùy chỉnh .

Sau khi thêm mô hình tùy chỉnh vào dự án Firebase, bạn có thể tham chiếu mô hình đó trong ứng dụng của mình bằng tên bạn đã chỉ định. Bất cứ lúc nào, bạn có thể triển khai mô hình TensorFlow Lite mới và tải mô hình mới xuống thiết bị của người dùng bằng cách gọi getModel() (xem bên dưới).

2. Tải mô hình xuống thiết bị và khởi tạo trình thông dịch TensorFlow Lite

Để sử dụng mô hình TensorFlow Lite trong ứng dụng của bạn, trước tiên hãy sử dụng SDK Firebase ML để tải phiên bản mới nhất của mô hình xuống thiết bị. Sau đó, khởi tạo trình thông dịch TensorFlow Lite với mô hình.

Để bắt đầu tải xuống mô hình, hãy gọi phương thức getModel() của trình tải xuống mô hình, chỉ định tên bạn đã gán cho mô hình khi tải lên, bạn có muốn luôn tải xuống mô hình mới nhất hay không và các điều kiện mà bạn muốn cho phép tải xuống.

Bạn có thể chọn từ ba hành vi tải xuống:

loại tải xuống Sự miêu tả
LOCAL_MODEL Lấy mô hình cục bộ từ thiết bị. Nếu không có sẵn mô hình cục bộ, mô hình này sẽ hoạt động như LATEST_MODEL . Sử dụng loại tải xuống này nếu bạn không quan tâm đến việc kiểm tra các bản cập nhật kiểu máy. Ví dụ: bạn đang sử dụng Cấu hình từ xa để truy xuất tên kiểu máy và bạn luôn tải các kiểu máy lên dưới tên mới (được khuyến nghị).
LOCAL_MODEL_UPDATE_IN_BACKGROUND Lấy mô hình cục bộ từ thiết bị và bắt đầu cập nhật mô hình trong nền. Nếu không có sẵn mô hình cục bộ, mô hình này sẽ hoạt động như LATEST_MODEL .
MẪU MỚI NHẤT Lấy mẫu mới nhất. Nếu mô hình cục bộ là phiên bản mới nhất, hãy trả về mô hình cục bộ. Nếu không, hãy tải xuống mô hình mới nhất. Hành vi này sẽ chặn cho đến khi phiên bản mới nhất được tải xuống (không được khuyến nghị). Chỉ sử dụng hành vi này trong trường hợp bạn rõ ràng cần phiên bản mới nhất.

Bạn nên tắt chức năng liên quan đến mô hình—ví dụ: tô xám hoặc ẩn một phần giao diện người dùng của bạn—cho đến khi bạn xác nhận mô hình đã được tải xuống.

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

Nhiều ứng dụng bắt đầu tác vụ tải xuống trong mã khởi tạo của chúng nhưng bạn có thể thực hiện việc này bất kỳ lúc nào trước khi cần sử dụng mô hình.

3. Thực hiện suy luận trên dữ liệu đầu vào

Nhận hình dạng đầu vào và đầu ra của mô hình của bạn

Trình thông dịch mô hình TensorFlow Lite lấy đầu vào làm đầu vào và tạo đầu ra một hoặc nhiều mảng đa chiều. Các mảng này chứa các giá trị byte , int , long hoặc float . Trước khi bạn có thể truyền dữ liệu cho một mô hình hoặc sử dụng kết quả của nó, bạn phải biết số lượng và kích thước ("hình dạng") của các mảng mà mô hình của bạn sử dụng.

Nếu bạn tự xây dựng mô hình hoặc nếu định dạng đầu vào và đầu ra của mô hình được ghi lại, bạn có thể đã có thông tin này. Nếu không biết hình dạng và loại dữ liệu đầu vào và đầu ra của mô hình, bạn có thể sử dụng trình thông dịch TensorFlow Lite để kiểm tra mô hình của mình. Ví dụ:

con trăn

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Đầu ra ví dụ:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

Chạy trình thông dịch

Sau khi bạn đã xác định định dạng đầu vào và đầu ra của mô hình, hãy lấy dữ liệu đầu vào và thực hiện bất kỳ phép biến đổi nào trên dữ liệu cần thiết để có được đầu vào có hình dạng phù hợp cho mô hình của bạn.

Ví dụ: nếu bạn có mô hình phân loại hình ảnh với hình dạng đầu vào là [1 224 224 3] giá trị dấu phẩy động, thì bạn có thể tạo ByteBuffer đầu vào từ đối tượng Bitmap như minh họa trong ví dụ sau:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Sau đó, phân bổ một ByteBuffer đủ lớn để chứa đầu ra của mô hình và chuyển bộ đệm đầu vào và bộ đệm đầu ra cho phương thức run() của trình thông dịch TensorFlow Lite. Ví dụ: đối với hình dạng đầu ra của [1 1000] giá trị dấu phẩy động:

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Cách bạn sử dụng đầu ra tùy thuộc vào kiểu máy bạn đang sử dụng.

Ví dụ: nếu bạn đang thực hiện phân loại, trong bước tiếp theo, bạn có thể ánh xạ các chỉ mục của kết quả tới các nhãn mà chúng đại diện:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Phụ lục: Bảo mẫu

Bất kể cách bạn cung cấp các mô hình TensorFlow Lite của mình cho Firebase ML, Firebase ML sẽ lưu trữ chúng ở định dạng protobuf được đánh số thứ tự tiêu chuẩn trong bộ nhớ cục bộ.

Về lý thuyết, điều này có nghĩa là bất kỳ ai cũng có thể sao chép mô hình của bạn. Tuy nhiên, trên thực tế, hầu hết các mô hình đều dành riêng cho ứng dụng và không được tối ưu hóa nên rủi ro tương tự như rủi ro đối thủ cạnh tranh tháo rời và sử dụng lại mã của bạn. Tuy nhiên, bạn nên biết rủi ro này trước khi sử dụng mô hình tùy chỉnh trong ứng dụng của mình.

Trên API Android cấp 21 (Lollipop) trở lên, mô hình được tải xuống một thư mục bị loại trừ khỏi sao lưu tự động .

Trên API Android cấp 20 trở lên, mô hình được tải xuống thư mục có tên com.google.firebase.ml.custom.models trong bộ nhớ trong riêng tư của ứng dụng. Nếu bạn đã bật sao lưu tệp bằng BackupAgent , bạn có thể chọn loại trừ thư mục này.