Sử dụng mô hình TensorFlow Lite để suy luận bằng Bộ công cụ học máy trên Android

Bạn có thể sử dụng Bộ công cụ học máy để tiến hành suy luận trên thiết bị bằng một Mô hình TensorFlow Lite.

API này yêu cầu SDK Android cấp 16 (Jelly Bean) trở lên.

Trước khi bắt đầu

  1. Nếu bạn chưa làm như vậy, thêm Firebase vào dự án Android của bạn.
  2. Thêm các phần phụ thuộc của thư viện Android cho Bộ công cụ học máy vào mô-đun của bạn Tệp Gradle (cấp ứng dụng) (thường là app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. Chuyển đổi mô hình TensorFlow mà bạn muốn sử dụng sang định dạng TensorFlow Lite. Xem TOCO: Trình chuyển đổi tối ưu hoá TensorFlow Lite.

Lưu trữ hoặc nhóm mô hình của bạn

Trước khi có thể sử dụng mô hình TensorFlow Lite để suy luận trong ứng dụng, bạn phải cung cấp mô hình đó cho Bộ công cụ học máy. Bộ công cụ học máy có thể sử dụng TensorFlow Lite các mô hình được lưu trữ từ xa bằng Firebase, đi kèm với tệp nhị phân của ứng dụng hoặc cả hai.

Bằng cách lưu trữ mô hình trên Firebase, bạn có thể cập nhật mô hình đó mà không cần phát hành phiên bản ứng dụng mới và bạn có thể sử dụng Cấu hình từ xa và Thử nghiệm A/B để phân phát linh động các mô hình khác nhau cho các nhóm người dùng khác nhau.

Nếu bạn chỉ chọn cung cấp mô hình bằng cách lưu trữ mô hình đó bằng Firebase, chứ không phải hãy kết hợp ứng dụng đó với ứng dụng, bạn có thể giảm kích thước tải xuống ban đầu của ứng dụng. Mặc dù vậy, hãy lưu ý rằng nếu mô hình không được đóng gói với ứng dụng của bạn, bất kỳ sẽ không có sẵn chức năng liên quan đến mô hình cho đến khi ứng dụng của bạn tải xuống mô hình lần đầu tiên.

Bằng cách kết hợp mô hình với ứng dụng, bạn có thể đảm bảo các tính năng học máy của ứng dụng vẫn hoạt động khi không có mô hình lưu trữ trên Firebase.

Mô hình lưu trữ trên Firebase

Cách lưu trữ mô hình TensorFlow Lite trên Firebase:

  1. Trong mục Bộ công cụ học máy của bảng điều khiển của Firebase, hãy nhấp vào thẻ Tuỳ chỉnh.
  2. Nhấp vào Thêm mô hình tuỳ chỉnh (hoặc Thêm mô hình khác).
  3. Chỉ định tên sẽ được dùng để xác định mô hình trong Firebase dự án, sau đó tải tệp mô hình TensorFlow Lite lên (thường kết thúc bằng .tflite hoặc .lite).
  4. Trong tệp kê khai của ứng dụng, hãy khai báo rằng cần phải có quyền INTERNET:
    <uses-permission android:name="android.permission.INTERNET" />
    

Sau khi thêm mô hình tùy chỉnh vào dự án Firebase, bạn có thể tham khảo trong các ứng dụng của mình bằng tên mà bạn đã chỉ định. Bất cứ lúc nào, bạn cũng có thể tải lên một mô hình TensorFlow Lite mới và ứng dụng của bạn sẽ tải mô hình mới xuống cũng như bắt đầu sử dụng vào lần tiếp theo ứng dụng khởi động lại. Bạn có thể xác định thiết bị các điều kiện cần thiết để ứng dụng của bạn cố gắng cập nhật mô hình (xem bên dưới).

Gộp các mô hình bằng một ứng dụng

Để nhóm mô hình TensorFlow Lite với ứng dụng của bạn, hãy sao chép tệp mô hình (thường là kết thúc bằng .tflite hoặc .lite) vào thư mục assets/ của ứng dụng. (Bạn có thể cần để tạo thư mục trước tiên, hãy nhấp chuột phải vào thư mục app/ rồi nhấp vào Mới > Thư mục > Thư mục thành phần.)

Sau đó, hãy thêm đoạn mã sau vào tệp build.gradle của ứng dụng để đảm bảo Gradle không nén các mô hình khi tạo ứng dụng:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

Tệp mô hình sẽ được đưa vào gói ứng dụng và được cung cấp cho Bộ công cụ học máy dưới dạng nội dung thô.

Tải mô hình

Để sử dụng mô hình TensorFlow Lite trong ứng dụng của bạn, trước tiên hãy định cấu hình Bộ công cụ học máy bằng các vị trí nơi mô hình của bạn có thể sử dụng: từ xa bằng Firebase, bộ nhớ cục bộ hoặc cả hai. Nếu bạn chỉ định cả mô hình cục bộ và từ xa, bạn có thể sử dụng mô hình điều khiển từ xa (nếu có) và quay lại dùng mô hình được lưu trữ cục bộ nếu không có mô hình từ xa.

Định cấu hình mô hình lưu trữ trên Firebase

Nếu bạn lưu trữ mô hình bằng Firebase, hãy tạo một FirebaseCustomRemoteModel , chỉ định tên mà bạn đã chỉ định cho mô hình khi tải lên:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

Sau đó, hãy bắt đầu tác vụ tải mô hình xuống, xác định các điều kiện mà bạn muốn cho phép tải xuống. Nếu kiểu máy này không có trên thiết bị hoặc nếu là kiểu máy mới hơn phiên bản của mô hình sẵn có, tác vụ sẽ tải xuống không đồng bộ mô hình từ Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

Nhiều ứng dụng bắt đầu tác vụ tải xuống trong mã khởi chạy của ứng dụng, nhưng bạn có thể làm được vì vậy, tại bất kỳ thời điểm nào trước khi cần sử dụng mô hình này.

Định cấu hình mô hình cục bộ

Nếu bạn đã đóng gói mô hình với ứng dụng của mình, hãy tạo một FirebaseCustomLocalModel xác định tên tệp của mô hình TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

Tạo phiên dịch từ mô hình của bạn

Sau khi bạn định cấu hình các nguồn mô hình, hãy tạo một FirebaseModelInterpreter khỏi một trong số chúng.

Nếu bạn chỉ có mô hình được gói cục bộ, chỉ cần tạo trình thông dịch từ Đối tượng FirebaseCustomLocalModel:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

Nếu có mô hình được lưu trữ từ xa, bạn sẽ phải kiểm tra xem mô hình đó đã được tải xuống trước khi chạy nó. Bạn có thể kiểm tra trạng thái tải mô hình xuống bằng cách sử dụng phương thức isModelDownloaded() của trình quản lý mô hình.

Mặc dù bạn chỉ phải xác nhận điều này trước khi chạy phiên dịch, nếu bạn có cả mô hình được lưu trữ từ xa và mô hình được gói cục bộ, thực hiện kiểm tra này khi tạo thực thể cho trình thông dịch mô hình: tạo một từ mô hình từ xa nếu mô hình đã được tải xuống và từ mô hình mô hình khác.

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

Nếu chỉ có một mô hình được lưu trữ từ xa, bạn nên tắt tính năng liên quan đến mô hình đó chức năng (ví dụ: chuyển sang màu xám hoặc ẩn một phần giao diện người dùng) cho đến khi bạn xác nhận mô hình đã được tải xuống. Bạn có thể thực hiện việc này bằng cách đính kèm một trình nghe đối với phương thức download() của trình quản lý mô hình:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Chỉ định dữ liệu đầu vào và đầu ra của mô hình

Tiếp theo, hãy định cấu hình định dạng đầu vào và đầu ra của trình phiên dịch mô hình.

Mô hình TensorFlow Lite lấy dữ liệu đầu vào và tạo ra một hoặc nhiều đầu ra mảng đa chiều. Các mảng này chứa byte, Giá trị int, long hoặc float. Bạn phải định cấu hình Bộ công cụ học máy bằng số lượng và kích thước ("hình dạng") của các mảng mà bạn mô hình sử dụng.

Nếu bạn không biết hình dạng và kiểu dữ liệu của đầu vào và đầu ra của mô hình, bạn có thể sử dụng trình thông dịch TensorFlow Lite Python để kiểm tra mô hình của mình. Cho ví dụ:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

Sau khi xác định định dạng cho đầu vào và đầu ra của mô hình, bạn có thể định cấu hình trình thông dịch mô hình của ứng dụng bằng cách tạo một Đối tượng FirebaseModelInputOutputOptions.

Ví dụ: mô hình phân loại hình ảnh dấu phẩy động có thể lấy dữ liệu đầu vào là NMảng x224x224x3 của float, đại diện cho một loạt N hình ảnh ba kênh (RGB) 224x224 và xuất ra danh sách 1000 giá trị float, mỗi giá trị thể hiện xác suất mà hình ảnh là một thành phần của một trong 1.000 danh mục mà mô hình dự đoán.

Đối với mô hình như vậy, bạn sẽ định cấu hình đầu vào và đầu ra của trình phiên dịch mô hình như minh hoạ dưới đây:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

Tiến hành suy luận về dữ liệu đầu vào

Cuối cùng, để tiến hành suy luận bằng mô hình này, hãy lấy dữ liệu đầu vào và thực hiện bất kỳ biến đổi nào trên dữ liệu cần thiết để có được một mảng đầu vào của hình dạng phù hợp cho mô hình của bạn.

Ví dụ: nếu bạn có mô hình phân loại hình ảnh với hình dạng đầu vào là [1 224 224 3] giá trị dấu phẩy động, bạn có thể tạo một mảng đầu vào từ một Đối tượng Bitmap như trong ví dụ sau:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

Sau đó, hãy tạo một đối tượng FirebaseModelInputs bằng dữ liệu đầu vào rồi truyền dữ liệu đó cũng như thông số đầu vào và đầu ra của mô hình đến phương thức run của trình phiên dịch mô hình:

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Nếu lệnh gọi thành công, bạn có thể nhận kết quả bằng cách gọi phương thức getOutput() của đối tượng được truyền đến trình nghe thành công. Ví dụ:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

Cách bạn sử dụng dữ liệu đầu ra phụ thuộc vào mô hình bạn đang sử dụng.

Ví dụ: nếu bạn đang thực hiện phân loại, trong bước tiếp theo, bạn có thể ánh xạ chỉ mục của kết quả với nhãn mà chúng đại diện:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

Phụ lục: Bảo mật mô hình

Bất kể bạn áp dụng mô hình TensorFlow Lite bằng cách nào ML Kit, ML Kit lưu trữ chúng ở định dạng protobuf được chuyển đổi tuần tự tiêu chuẩn trong lưu trữ cục bộ.

Về mặt lý thuyết, điều này có nghĩa là bất kỳ ai cũng có thể sao chép mô hình của bạn. Tuy nhiên, trong thực tế, hầu hết các mô hình đều dành riêng cho ứng dụng và bị làm rối mã nguồn rủi ro tương tự như các biện pháp tối ưu hoá của đối thủ cạnh tranh bị loại bỏ và việc sử dụng lại mã. Tuy nhiên, bạn nên lưu ý rủi ro này trước khi sử dụng một mô hình tuỳ chỉnh trong ứng dụng của bạn.

Trên API Android cấp 21 (Lollipop) trở lên, mô hình sẽ được tải xuống thư mục là bị loại trừ khỏi tính năng sao lưu tự động.

Trên API Android cấp 20 trở xuống, mô hình sẽ được tải xuống một thư mục có tên là com.google.firebase.ml.custom.models ở chế độ riêng tư trong ứng dụng bộ nhớ trong. Nếu bạn đã bật tính năng sao lưu tệp bằng BackupAgent, bạn có thể chọn loại trừ thư mục này.