Sử dụng mô hình TensorFlow Lite để suy luận bằng Bộ công cụ học máy trên Android

Bạn có thể dùng Bộ công cụ học máy để thực hiện suy luận trên thiết bị bằng mô hình TensorFlow Lite.

API này yêu cầu SDK Android cấp 16 (Jelly Bean) trở lên.

Trước khi bắt đầu

  1. Nếu bạn chưa thực hiện, hãy thêm Firebase vào dự án Android.
  2. Thêm các phần phụ thuộc cho thư viện ML Kit Android vào tệp Gradle (ở cấp ứng dụng) trong mô-đun của bạn (thường là app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
  3. Chuyển đổi mô hình TensorFlow mà bạn muốn sử dụng sang định dạng TensorFlow Lite. Xem TOCO: Trình chuyển đổi tối ưu hoá TensorFlow Lite.

Lưu trữ hoặc đóng gói mô hình của bạn

Trước khi có thể sử dụng mô hình TensorFlow Lite để suy luận trong ứng dụng, bạn phải cung cấp mô hình cho ML Kit. Bộ công cụ học máy có thể sử dụng các mô hình TensorFlow Lite được lưu trữ từ xa bằng Firebase, đi kèm với tệp nhị phân của ứng dụng hoặc cả hai.

Bằng cách lưu trữ một mô hình trên Firebase, bạn có thể cập nhật mô hình mà không cần phát hành phiên bản ứng dụng mới và bạn có thể sử dụng Remote ConfigA/B Testing để phân phát linh hoạt các mô hình khác nhau cho các nhóm người dùng khác nhau.

Nếu chỉ cung cấp mô hình bằng cách lưu trữ mô hình đó với Firebase và không đi kèm mô hình đó với ứng dụng, bạn có thể giảm kích thước tải xuống ban đầu của ứng dụng. Tuy nhiên, hãy lưu ý rằng nếu mô hình không đi kèm với ứng dụng, thì mọi chức năng liên quan đến mô hình sẽ không hoạt động cho đến khi ứng dụng tải mô hình xuống lần đầu tiên.

Bằng cách kết hợp mô hình với ứng dụng, bạn có thể đảm bảo các tính năng ML của ứng dụng vẫn hoạt động khi mô hình do Firebase lưu trữ không có sẵn.

Lưu trữ các mô hình trên Firebase

Cách lưu trữ mô hình TensorFlow Lite trên Firebase:

  1. Trong mục ML Kit của bảng điều khiển Firebase, hãy nhấp vào thẻ Tuỳ chỉnh.
  2. Nhấp vào Thêm mô hình tuỳ chỉnh (hoặc Thêm mô hình khác).
  3. Chỉ định một tên sẽ được dùng để xác định mô hình của bạn trong dự án Firebase, sau đó tải tệp mô hình TensorFlow Lite lên (thường có đuôi là .tflite hoặc .lite).
  4. Trong tệp kê khai của ứng dụng, hãy khai báo rằng bạn cần có quyền INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />

Sau khi thêm một mô hình tuỳ chỉnh vào dự án Firebase, bạn có thể tham chiếu mô hình đó trong các ứng dụng bằng tên mà bạn đã chỉ định. Bất cứ lúc nào, bạn cũng có thể tải một mô hình TensorFlow Lite mới lên và ứng dụng của bạn sẽ tải mô hình mới xuống rồi bắt đầu sử dụng mô hình đó vào lần khởi động lại ứng dụng tiếp theo. Bạn có thể xác định các điều kiện về thiết bị mà ứng dụng cần đáp ứng để cố gắng cập nhật mô hình (xem bên dưới).

Gói các mô hình với một ứng dụng

Để gói mô hình TensorFlow Lite với ứng dụng, hãy sao chép tệp mô hình (thường kết thúc bằng .tflite hoặc .lite) vào thư mục assets/ của ứng dụng. (Trước tiên, bạn có thể cần tạo thư mục bằng cách nhấp chuột phải vào thư mục app/, sau đó nhấp vào New > Folder > Assets Folder (Mới > Thư mục > Thư mục thành phần)).

Sau đó, hãy thêm nội dung sau vào tệp build.gradle của ứng dụng để đảm bảo Gradle không nén các mô hình khi tạo ứng dụng:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Tệp mô hình sẽ được đưa vào gói ứng dụng và có sẵn cho ML Kit dưới dạng một tài sản thô.

Tải mô hình

Để sử dụng mô hình TensorFlow Lite trong ứng dụng, trước tiên, hãy định cấu hình ML Kit bằng các vị trí mà mô hình của bạn có sẵn: từ xa bằng Firebase, trong bộ nhớ cục bộ hoặc cả hai. Nếu chỉ định cả mô hình cục bộ và mô hình từ xa, bạn có thể dùng mô hình từ xa nếu có và quay lại mô hình được lưu trữ cục bộ nếu mô hình từ xa không có.

Định cấu hình một mô hình được lưu trữ trên Firebase

Nếu bạn lưu trữ mô hình bằng Firebase, hãy tạo một đối tượng FirebaseCustomRemoteModel, chỉ định tên mà bạn đã chỉ định cho mô hình khi tải mô hình lên:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Sau đó, hãy bắt đầu tác vụ tải mô hình xuống, chỉ định các điều kiện mà bạn muốn cho phép tải xuống. Nếu mô hình không có trên thiết bị hoặc nếu có phiên bản mới hơn của mô hình, thì tác vụ sẽ tải mô hình xuống không đồng bộ từ Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Nhiều ứng dụng bắt đầu tác vụ tải xuống trong mã khởi tạo, nhưng bạn có thể thực hiện việc này bất cứ lúc nào trước khi cần sử dụng mô hình.

Định cấu hình mô hình cục bộ

Nếu bạn đã gói mô hình này cùng với ứng dụng, hãy tạo một đối tượng FirebaseCustomLocalModel, chỉ định tên tệp của mô hình TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Tạo một trình thông dịch từ mô hình của bạn

Sau khi định cấu hình các nguồn mô hình, hãy tạo một đối tượng FirebaseModelInterpreter từ một trong các nguồn đó.

Nếu bạn chỉ có một mô hình được gói cục bộ, hãy tạo một trình thông dịch từ đối tượng FirebaseCustomLocalModel:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Nếu có một mô hình được lưu trữ từ xa, bạn sẽ phải kiểm tra để đảm bảo mô hình đó đã được tải xuống trước khi chạy. Bạn có thể kiểm tra trạng thái của tác vụ tải mô hình xuống bằng phương thức isModelDownloaded() của trình quản lý mô hình.

Mặc dù bạn chỉ phải xác nhận điều này trước khi chạy trình thông dịch, nhưng nếu có cả mô hình được lưu trữ từ xa và mô hình được gói cục bộ, thì bạn nên thực hiện quy trình kiểm tra này khi tạo thực thể trình thông dịch mô hình: tạo trình thông dịch từ mô hình từ xa nếu mô hình đó đã được tải xuống và từ mô hình cục bộ nếu không.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Nếu chỉ có một mô hình được lưu trữ từ xa, bạn nên tắt chức năng liên quan đến mô hình (ví dụ: làm mờ hoặc ẩn một phần giao diện người dùng) cho đến khi xác nhận rằng mô hình đã được tải xuống. Bạn có thể làm như vậy bằng cách đính kèm một trình nghe vào phương thức download() của trình quản lý mô hình:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Chỉ định đầu vào và đầu ra của mô hình

Tiếp theo, hãy định cấu hình định dạng đầu vào và đầu ra của trình thông dịch mô hình.

Mô hình TensorFlow Lite nhận dữ liệu đầu vào và tạo ra dữ liệu đầu ra là một hoặc nhiều mảng đa chiều. Các mảng này chứa các giá trị byte, int, long hoặc float. Bạn phải định cấu hình ML Kit bằng số lượng và kích thước ("hình dạng") của các mảng mà mô hình của bạn sử dụng.

Nếu không biết hình dạng và kiểu dữ liệu của đầu vào và đầu ra của mô hình, bạn có thể dùng trình thông dịch Python của TensorFlow Lite để kiểm tra mô hình. Ví dụ:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Sau khi xác định định dạng đầu vào và đầu ra của mô hình, bạn có thể định cấu hình trình thông dịch mô hình của ứng dụng bằng cách tạo một đối tượng FirebaseModelInputOutputOptions.

Ví dụ: một mô hình phân loại hình ảnh dấu phẩy động có thể lấy làm đầu vào một mảng Nx224x224x3 gồm các giá trị float, biểu thị một lô gồm N hình ảnh 224x224 có 3 kênh (RGB) và tạo ra một danh sách gồm 1000 giá trị float làm đầu ra, mỗi giá trị biểu thị xác suất hình ảnh là thành viên của một trong 1000 danh mục mà mô hình dự đoán.

Đối với mô hình như vậy, bạn sẽ định cấu hình đầu vào và đầu ra của trình thông dịch mô hình như minh hoạ dưới đây:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Thực hiện suy luận trên dữ liệu đầu vào

Cuối cùng, để thực hiện suy luận bằng mô hình, hãy lấy dữ liệu đầu vào và thực hiện mọi phép biến đổi cần thiết trên dữ liệu để có được một mảng đầu vào có hình dạng phù hợp cho mô hình của bạn.

Ví dụ: nếu bạn có một mô hình phân loại hình ảnh với hình dạng đầu vào là [1 224 224 3] giá trị dấu phẩy động, bạn có thể tạo một mảng đầu vào từ đối tượng Bitmap như trong ví dụ sau:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Sau đó, hãy tạo một đối tượng FirebaseModelInputs bằng dữ liệu đầu vào của bạn, rồi truyền đối tượng đó cùng với quy cách đầu vào và đầu ra của mô hình đến phương thức run của trình thông dịch mô hình:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Nếu lệnh gọi thành công, bạn có thể nhận được đầu ra bằng cách gọi phương thức getOutput() của đối tượng được truyền đến trình nghe thành công. Ví dụ:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Cách bạn sử dụng đầu ra phụ thuộc vào mô hình bạn đang dùng.

Ví dụ: nếu đang thực hiện phân loại, thì bước tiếp theo có thể là bạn sẽ ánh xạ các chỉ mục của kết quả với nhãn mà chúng đại diện:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Phụ lục: Bảo mật mô hình

Bất kể bạn cung cấp các mô hình TensorFlow Lite cho Bộ công cụ học máy bằng cách nào, Bộ công cụ học máy đều lưu trữ các mô hình đó ở định dạng protobuf được chuyển đổi tuần tự tiêu chuẩn trong bộ nhớ cục bộ.

Về lý thuyết, điều này có nghĩa là bất kỳ ai cũng có thể sao chép mô hình của bạn. Tuy nhiên, trên thực tế, hầu hết các mô hình đều quá dành riêng cho ứng dụng và bị làm rối mã nguồn bằng các quy trình tối ưu hoá đến mức rủi ro tương tự như việc đối thủ cạnh tranh tháo rời và sử dụng lại mã của bạn. Tuy nhiên, bạn cần lưu ý đến rủi ro này trước khi sử dụng một mô hình tuỳ chỉnh trong ứng dụng của mình.

Trên Android API cấp 21 (Lollipop) trở lên, mô hình được tải xuống một thư mục được loại trừ khỏi tính năng sao lưu tự động.

Trên Android API cấp 20 trở xuống, mô hình này được tải xuống một thư mục có tên là com.google.firebase.ml.custom.models trong bộ nhớ trong riêng của ứng dụng. Nếu đã bật tính năng sao lưu tệp bằng BackupAgent, bạn có thể chọn loại trừ thư mục này.