ใช้โมเดล TensorFlow Lite เพื่อการอนุมานด้วย ML Kit บน Android

คุณใช้ ML Kit เพื่อทำการอนุมานบนอุปกรณ์ด้วยโมเดล TensorFlow Lite ได้

API นี้ต้องใช้ Android SDK ระดับ 16 (Jelly Bean) ขึ้นไป

ก่อนเริ่มต้น

  1. เพิ่ม Firebase ลงในโปรเจ็กต์ Android หากยังไม่ได้เพิ่ม
  2. เพิ่มทรัพยากร Dependency สำหรับคลัง Android ของ ML Kit ลงในไฟล์ Gradle ของโมดูล (ระดับแอป) (โดยปกติคือ app/build.gradle)
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
  3. แปลงโมเดล TensorFlow ที่ต้องการใช้เป็นรูปแบบ TensorFlow Lite ดู TOCO: ตัวแปลงการเพิ่มประสิทธิภาพ TensorFlow Lite

โฮสต์หรือรวมโมเดลของคุณ

ก่อนที่จะใช้โมเดล TensorFlow Lite เพื่อการอนุมานในแอปได้ คุณต้องทำให้โมเดลพร้อมใช้งานกับ ML Kit ML Kit สามารถใช้โมเดล TensorFlow Lite ที่โฮสต์จากระยะไกลโดยใช้ Firebase, รวมไว้กับไบนารีของแอป หรือทั้ง 2 อย่าง

การโฮสต์โมเดลใน Firebase ช่วยให้คุณอัปเดตโมเดลได้โดยไม่ต้องเผยแพร่แอปเวอร์ชันใหม่ และใช้ Remote Config และ A/B Testing เพื่อแสดงโมเดลที่แตกต่างกันแบบไดนามิกต่อผู้ใช้กลุ่มต่างๆ ได้

หากเลือกที่จะระบุเฉพาะโมเดลโดยโฮสต์ด้วย Firebase และไม่รวมไว้กับแอป คุณจะลดขนาดการดาวน์โหลดเริ่มต้นของแอปได้ โปรดทราบว่าหากไม่ได้รวมโมเดลไว้กับแอป ฟังก์ชันที่เกี่ยวข้องกับโมเดลจะใช้ไม่ได้จนกว่าแอปจะดาวน์โหลดโมเดลเป็นครั้งแรก

การรวมโมเดลไว้กับแอปจะช่วยให้มั่นใจได้ว่าฟีเจอร์ ML ของแอปจะยังคงทำงานได้เมื่อโมเดลที่โฮสต์ใน Firebase ไม่พร้อมใช้งาน

โฮสต์โมเดลใน Firebase

วิธีโฮสต์โมเดล TensorFlow Lite ใน Firebase

  1. ในส่วน ML Kit ของFirebase คอนโซล ให้คลิก แท็บกำหนดเอง
  2. คลิกเพิ่มโมเดลที่กำหนดเอง (หรือเพิ่มโมเดลอื่น)
  3. ระบุชื่อที่จะใช้เพื่อระบุโมเดลในโปรเจ็กต์ Firebase จากนั้นอัปโหลดไฟล์โมเดล TensorFlow Lite (โดยปกติจะลงท้ายด้วย .tflite หรือ .lite)
  4. ในไฟล์ Manifest ของแอป ให้ประกาศว่าต้องมีสิทธิ์ INTERNET
    <uses-permission android:name="android.permission.INTERNET" />

หลังจากเพิ่มโมเดลที่กำหนดเองลงในโปรเจ็กต์ Firebase แล้ว คุณจะอ้างอิงโมเดลในแอปโดยใช้ชื่อที่ระบุได้ คุณอัปโหลด โมเดล TensorFlow Lite ใหม่ได้ทุกเมื่อ และแอปจะดาวน์โหลดโมเดลใหม่และ เริ่มใช้เมื่อแอปรีสตาร์ทครั้งถัดไป คุณกำหนดเงื่อนไขของอุปกรณ์ ที่แอปต้องใช้เพื่อพยายามอัปเดตโมเดลได้ (ดูด้านล่าง)

รวมโมเดลกับแอป

หากต้องการรวมโมเดล TensorFlow Lite กับแอป ให้คัดลอกไฟล์โมเดล (โดยปกติจะลงท้ายด้วย .tflite หรือ .lite) ไปยังโฟลเดอร์ assets/ ของแอป (คุณอาจต้องสร้างโฟลเดอร์ก่อนโดยคลิกขวาที่โฟลเดอร์ app/ แล้วคลิกใหม่ > โฟลเดอร์ > โฟลเดอร์ชิ้นงาน)

จากนั้นเพิ่มโค้ดต่อไปนี้ลงในไฟล์ build.gradle ของแอปเพื่อให้แน่ใจว่า Gradle จะไม่บีบอัดโมเดลเมื่อสร้างแอป

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

ไฟล์โมเดลจะรวมอยู่ในแพ็กเกจแอปและพร้อมใช้งานใน ML Kit เป็นเนื้อหาดิบ

โหลดโมเดล

หากต้องการใช้โมเดล TensorFlow Lite ในแอป ให้กำหนดค่า ML Kit ด้วยตำแหน่งที่โมเดลพร้อมใช้งานก่อน ซึ่งอาจเป็นแบบระยะไกลโดยใช้ Firebase, ในพื้นที่เก็บข้อมูลในเครื่อง หรือทั้ง 2 อย่าง หากระบุทั้งโมเดลในเครื่องและโมเดลระยะไกล คุณจะใช้โมเดลระยะไกลได้หากมี และถอยกลับไปใช้ โมเดลที่เก็บไว้ในเครื่องหากโมเดลระยะไกลไม่พร้อมใช้งาน

กำหนดค่าโมเดลที่โฮสต์ใน Firebase

หากโฮสต์โมเดลด้วย Firebase ให้สร้างFirebaseCustomRemoteModel ออบเจ็กต์ โดยระบุชื่อที่คุณกำหนดให้กับโมเดลเมื่ออัปโหลด

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

จากนั้นเริ่มงานดาวน์โหลดโมเดล โดยระบุเงื่อนไขที่คุณต้องการอนุญาตให้ดาวน์โหลด หากโมเดลไม่ได้อยู่ในอุปกรณ์ หรือหากมีโมเดลเวอร์ชันใหม่กว่า งานจะดาวน์โหลดโมเดลจาก Firebase แบบไม่พร้อมกัน

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

แอปจำนวนมากจะเริ่มงานดาวน์โหลดในโค้ดการเริ่มต้น แต่คุณสามารถทำได้ทุกเมื่อก่อนที่จะต้องใช้โมเดล

กำหนดค่าโมเดลในเครื่อง

หากคุณรวมโมเดลไว้กับแอป ให้สร้างออบเจ็กต์ FirebaseCustomLocalModel โดยระบุชื่อไฟล์ของโมเดล TensorFlow Lite ดังนี้

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

สร้างโปรแกรมตีความจากโมเดล

หลังจากกำหนดค่าแหล่งที่มาของโมเดลแล้ว ให้สร้างFirebaseModelInterpreter ออบเจ็กต์จากแหล่งที่มาใดแหล่งที่มาหนึ่ง

หากมีเฉพาะโมเดลที่รวมไว้ในเครื่อง ให้สร้างอินเทอร์พรีเตอร์จากออบเจ็กต์ FirebaseCustomLocalModel ดังนี้

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

หากมีโมเดลที่โฮสต์จากระยะไกล คุณจะต้องตรวจสอบว่าได้ ดาวน์โหลดโมเดลแล้วก่อนที่จะเรียกใช้ คุณตรวจสอบสถานะของงานดาวน์โหลดโมเดลได้โดยใช้เมธอด isModelDownloaded() ของตัวจัดการโมเดล

แม้ว่าคุณจะต้องยืนยันเรื่องนี้ก่อนเรียกใช้ Interpreter แต่หากมีทั้งโมเดลที่โฮสต์จากระยะไกลและโมเดลที่รวมไว้ในเครื่อง คุณอาจต้องตรวจสอบนี้เมื่อสร้างอินสแตนซ์ของโมเดล Interpreter นั่นคือ สร้าง Interpreter จากโมเดลระยะไกลหากดาวน์โหลดแล้ว และจากโมเดลในเครื่องหากยังไม่ได้ดาวน์โหลด

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

หากมีเฉพาะโมเดลที่โฮสต์จากระยะไกล คุณควรปิดใช้ฟังก์ชันที่เกี่ยวข้องกับโมเดล เช่น ทำให้ส่วนหนึ่งของ UI เป็นสีเทาหรือซ่อนไว้ จนกว่าคุณจะยืนยันว่าดาวน์โหลดโมเดลแล้ว คุณทำได้โดยแนบ Listener ไปยังเมธอด download() ของ Model Manager ดังนี้

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

ระบุอินพุตและเอาต์พุตของโมเดล

จากนั้นกำหนดค่ารูปแบบอินพุตและเอาต์พุตของตัวตีความโมเดล

โมเดล TensorFlow Lite รับอาร์เรย์หลายมิติอย่างน้อย 1 รายการเป็นอินพุตและสร้างเป็นเอาต์พุต อาร์เรย์เหล่านี้มีค่า byte, int, long หรือ float คุณต้อง กำหนดค่า ML Kit ด้วยจำนวนและขนาด ("รูปร่าง") ของอาร์เรย์ที่โมเดล ใช้

หากไม่ทราบรูปร่างและประเภทข้อมูลของอินพุตและเอาต์พุตของโมเดล คุณสามารถใช้ตัวแปล Python ของ TensorFlow Lite เพื่อตรวจสอบโมเดลได้ เช่น

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

หลังจากกำหนดรูปแบบอินพุตและเอาต์พุตของโมเดลแล้ว คุณจะ กำหนดค่าตัวแปลโมเดลของแอปได้โดยการสร้างออบเจ็กต์ FirebaseModelInputOutputOptions

ตัวอย่างเช่น โมเดลการจัดประเภทรูปภาพแบบทศนิยมอาจรับอาร์เรย์ Nx224x224x3 ของค่า float เป็นอินพุต ซึ่งแสดงถึงกลุ่มรูปภาพขนาด 224x224 แบบ 3 แชแนล (RGB) และสร้างเอาต์พุตเป็นรายการค่า float จำนวน 1,000 ค่า โดยแต่ละค่าแสดงถึงความน่าจะเป็นที่รูปภาพจะเป็นสมาชิกของหมวดหมู่ใดหมวดหมู่หนึ่งใน 1,000 หมวดหมู่ที่โมเดลคาดการณ์N

สำหรับโมเดลดังกล่าว คุณจะต้องกำหนดค่าอินพุตและเอาต์พุตของตัวตีความโมเดล ตามที่แสดงด้านล่าง

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

ทำการอนุมานในข้อมูลอินพุต

สุดท้ายนี้ หากต้องการทำการอนุมานโดยใช้โมเดล ให้รับข้อมูลอินพุตและทำการ แปลงข้อมูลที่จำเป็นเพื่อให้ได้อาร์เรย์อินพุตที่มี รูปร่างที่เหมาะสมสำหรับโมเดล

เช่น หากคุณมีโมเดลการแยกประเภทรูปภาพที่มีรูปร่างอินพุตเป็นค่าจุดลอยตัว [1 224 224 3] คุณจะสร้างอาร์เรย์อินพุตจากออบเจ็กต์ Bitmap ได้ดังตัวอย่างต่อไปนี้

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

จากนั้นสร้างออบเจ็กต์ FirebaseModelInputs ด้วย ข้อมูลอินพุต แล้วส่งออบเจ็กต์ดังกล่าวพร้อมกับข้อกำหนดอินพุตและเอาต์พุตของโมเดลไปยังเมธอด run ของตัวตีความโมเดล

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

หากการเรียกสำเร็จ คุณจะรับเอาต์พุตได้โดยการเรียกใช้getOutput()เมธอด ของออบเจ็กต์ที่ส่งไปยัง Listener ที่สำเร็จ เช่น

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

วิธีใช้เอาต์พุตจะขึ้นอยู่กับโมเดลที่คุณใช้

ตัวอย่างเช่น หากคุณทำการแยกประเภท ขั้นตอนถัดไปอาจเป็นการ แมปดัชนีของผลลัพธ์กับป้ายกำกับที่แสดง

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

ภาคผนวก: ความปลอดภัยของโมเดล

ไม่ว่าคุณจะทำให้โมเดล TensorFlow Lite พร้อมใช้งานใน ML Kit อย่างไรก็ตาม ML Kit จะจัดเก็บโมเดลในรูปแบบ protobuf ที่ซีเรียลไลซ์มาตรฐานในพื้นที่เก็บข้อมูลในเครื่อง

ในทางทฤษฎี หมายความว่าทุกคนสามารถคัดลอกโมเดลของคุณได้ อย่างไรก็ตาม ในทางปฏิบัติ โมเดลส่วนใหญ่มีความเฉพาะเจาะจงกับแอปพลิเคชันและมีการปกปิดโดยการเพิ่มประสิทธิภาพ ซึ่งทำให้ความเสี่ยงคล้ายกับที่คู่แข่งถอดแยกชิ้นส่วนและนำโค้ดของคุณกลับมาใช้ใหม่ อย่างไรก็ตาม คุณควรทราบถึงความเสี่ยงนี้ก่อนใช้ โมเดลที่กำหนดเองในแอป

ใน Android API ระดับ 21 (Lollipop) ขึ้นไป ระบบจะดาวน์โหลดโมเดลไปยังไดเรกทอรีที่ ยกเว้นจากการสำรองข้อมูลอัตโนมัติ

ใน Android API ระดับ 20 และเก่ากว่า ระบบจะดาวน์โหลดโมเดลไปยังไดเรกทอรี ชื่อ com.google.firebase.ml.custom.models ในที่เก็บข้อมูลภายในแบบส่วนตัวของแอป หากเปิดใช้การสำรองข้อมูลไฟล์โดยใช้ BackupAgent คุณอาจเลือกที่จะยกเว้นไดเรกทอรีนี้