ใช้โมเดล TensorFlow Lite เพื่อการอนุมานด้วย ML Kit บน Android

คุณใช้ ML Kit เพื่อทำการอนุมานในอุปกรณ์ด้วยโมเดล TensorFlow Lite ได้

API นี้ต้องใช้ Android SDK ระดับ 16 (Jelly Bean) ขึ้นไป

ก่อนเริ่มต้น

  1. เพิ่ม Firebase ลงในโปรเจ็กต์ Android หากยังไม่ได้ทำ
  2. เพิ่มทรัพยากร Dependency สำหรับไลบรารี ML Kit Android ลงในไฟล์ Gradle ของโมดูล (ระดับแอป) (โดยปกติจะเป็น app/build.gradle):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. แปลงโมเดล TensorFlow ที่ต้องการใช้เป็นรูปแบบ TensorFlow Lite โปรดดู TOCO: TensorFlow Lite เพิ่มประสิทธิภาพตัวแปลง

โฮสต์หรือรวมกลุ่มโมเดลของคุณ

คุณต้องทำให้โมเดลพร้อมใช้งานใน ML Kit ก่อนจึงจะใช้โมเดล TensorFlow Lite เพื่อการอนุมานในแอปได้ ML Kit สามารถใช้โมเดล TensorFlow Lite ที่โฮสต์จากระยะไกลโดยใช้ Firebase ที่มาพร้อมกับไบนารีของแอป หรือทั้ง 2 อย่างก็ได้

การโฮสต์โมเดลใน Firebase ทำให้คุณอัปเดตโมเดลได้โดยไม่ต้องเผยแพร่แอปเวอร์ชันใหม่ รวมถึงใช้การกำหนดค่าระยะไกลและการทดสอบ A/B เพื่อแสดงโมเดลต่างๆ แก่ผู้ใช้กลุ่มต่างๆ แบบไดนามิกได้

หากเลือกระบุโมเดลโดยการโฮสต์โมเดลด้วย Firebase เท่านั้น โดยไม่รวมเข้ากับแอป ให้ลดขนาดการดาวน์โหลดเริ่มต้นของแอปได้ แต่โปรดทราบว่าหากโมเดลไม่ได้รวมมากับแอป ฟังก์ชันที่เกี่ยวข้องกับโมเดลจะใช้งานไม่ได้จนกว่าแอปจะดาวน์โหลดโมเดลเป็นครั้งแรก

การรวมโมเดลกับแอปจะทำให้ฟีเจอร์ ML ของแอปยังคงทำงานได้เมื่อโมเดลที่โฮสต์โดย Firebase ไม่พร้อมใช้งาน

โมเดลโฮสต์บน Firebase

วิธีโฮสต์โมเดล TensorFlow Lite บน Firebase

  1. ในส่วน ML Kit ของคอนโซล Firebase ให้คลิกแท็บกำหนดเอง
  2. คลิกเพิ่มรูปแบบที่กำหนดเอง (หรือเพิ่มโมเดลอื่น)
  3. ระบุชื่อที่จะใช้ระบุโมเดลในโปรเจ็กต์ Firebase แล้วอัปโหลดไฟล์โมเดล TensorFlow Lite (โดยปกติจะลงท้ายด้วย .tflite หรือ .lite)
  4. ให้ประกาศว่าต้องใช้สิทธิ์อินเทอร์เน็ตในไฟล์ Manifest ของแอป โดยทำดังนี้
    <uses-permission android:name="android.permission.INTERNET" />
    

หลังจากเพิ่มโมเดลที่กำหนดเองลงในโปรเจ็กต์ Firebase แล้ว คุณจะอ้างอิงโมเดลในแอปโดยใช้ชื่อที่ระบุได้ คุณอัปโหลดโมเดล TensorFlow Lite ใหม่ได้ทุกเมื่อ และแอปจะดาวน์โหลดโมเดลใหม่และเริ่มใช้งานเมื่อแอปรีสตาร์ทครั้งถัดไป คุณกําหนดเงื่อนไขอุปกรณ์ที่จําเป็นเพื่อให้แอปพยายามอัปเดตโมเดลได้ (ดูด้านล่าง)

รวมโมเดลเข้ากับแอป

หากต้องการรวมโมเดล TensorFlow Lite กับแอป ให้คัดลอกไฟล์โมเดล (โดยปกติแล้วจะลงท้ายด้วย .tflite หรือ .lite) ไปยังโฟลเดอร์ assets/ ของแอป (คุณอาจต้องสร้างโฟลเดอร์ก่อนโดยคลิกขวาที่โฟลเดอร์ app/ จากนั้นคลิกใหม่ > โฟลเดอร์ > โฟลเดอร์เนื้อหา)

จากนั้นเพิ่มโค้ดต่อไปนี้ลงในไฟล์ build.gradle ของแอปเพื่อไม่ให้ Gradle ไม่บีบอัดโมเดลเมื่อสร้างแอป

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

ไฟล์โมเดลจะรวมอยู่ในแพ็กเกจแอปและพร้อมให้ ML Kit เป็นเนื้อหาดิบ

โหลดโมเดล

หากต้องการใช้โมเดล TensorFlow Lite ในแอป ก่อนอื่นให้กำหนดค่า ML Kit ด้วยตำแหน่งที่โมเดลของคุณพร้อมใช้งาน ไม่ว่าจะเป็นจากระยะไกลโดยใช้ Firebase ในพื้นที่เก็บข้อมูลในเครื่อง หรือทั้ง 2 อย่าง หากคุณระบุทั้งโมเดลในเครื่องและระยะไกล คุณจะใช้โมเดลระยะไกลได้หากพร้อมใช้งาน และกลับไปใช้โมเดลที่เก็บไว้ในเครื่องหากโมเดลระยะไกลไม่พร้อมใช้งาน

กำหนดค่าโมเดลที่โฮสต์ด้วย Firebase

หากคุณฝากโมเดลไว้กับ Firebase ให้สร้างออบเจ็กต์ FirebaseCustomRemoteModel โดยระบุชื่อที่กำหนดให้กับโมเดลเมื่ออัปโหลด ดังนี้

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

จากนั้นเริ่มงานดาวน์โหลดโมเดล โดยระบุเงื่อนไขที่คุณต้องการอนุญาตให้ดาวน์โหลด หากโมเดลไม่ได้อยู่ในอุปกรณ์ หรือหากมีโมเดลเวอร์ชันใหม่กว่า งานจะดาวน์โหลดโมเดลแบบไม่พร้อมกันจาก Firebase ดังนี้

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

แอปจำนวนมากจะเริ่มงานดาวน์โหลดในโค้ดการเริ่มต้น แต่คุณทำได้ทุกเมื่อก่อนที่จะต้องใช้โมเดลดังกล่าว

กำหนดค่าโมเดลในเครื่อง

หากคุณจัดกลุ่มโมเดลกับแอป ให้สร้างออบเจ็กต์ FirebaseCustomLocalModel โดยระบุชื่อไฟล์ของโมเดล TensorFlow Lite ดังนี้

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

สร้างล่ามจากโมเดล

หลังจากที่กำหนดค่าแหล่งที่มาของโมเดลแล้ว ให้สร้างออบเจ็กต์ FirebaseModelInterpreter จากหนึ่งในแหล่งที่มาดังกล่าว

หากคุณมีเฉพาะโมเดลที่รวมภายในเครื่อง ก็เพียงแค่สร้างล่ามจากออบเจ็กต์ FirebaseCustomLocalModel โดยทำดังนี้

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

หากมีโมเดลที่โฮสต์จากระยะไกล คุณจะต้องตรวจสอบว่ามีการดาวน์โหลดโมเดลแล้วก่อนที่จะเรียกใช้ คุณตรวจสอบสถานะของงานดาวน์โหลดโมเดลได้โดยใช้เมธอด isModelDownloaded() ของตัวจัดการโมเดล

แม้ว่าคุณจะต้องยืนยันข้อมูลนี้ก่อนเรียกใช้ล่ามเท่านั้น หากคุณมีทั้งโมเดลที่โฮสต์จากระยะไกลและโมเดลแบบกลุ่มในเครื่อง ก็อาจทำให้ดำเนินการตรวจสอบนี้เมื่อเริ่มต้นล่ามโมเดล โดยให้สร้างล่ามจากโมเดลระยะไกลหากดาวน์โหลดแล้ว และจากโมเดลในเครื่องหากไม่เป็นเช่นนั้น

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

หากคุณมีเฉพาะโมเดลที่โฮสต์จากระยะไกล คุณควรปิดใช้ฟังก์ชันที่เกี่ยวข้องกับโมเดล เช่น เป็นสีเทาหรือซ่อนบางส่วนของ UI จนกว่าจะยืนยันว่าดาวน์โหลดโมเดลแล้ว ซึ่งทำได้โดยการแนบ Listener ลงในเมธอด download() ของตัวจัดการโมเดล ดังนี้

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

ระบุอินพุตและเอาต์พุตของโมเดล

ถัดไป ให้กำหนดค่ารูปแบบอินพุตและเอาต์พุตของตัวแปลมโมเดล

โมเดล TensorFlow Lite รับเป็นอินพุตและสร้างเอาต์พุตเป็นอาร์เรย์หลายมิติอย่างน้อย 1 รายการ อาร์เรย์เหล่านี้มีค่า byte, int, long หรือ float คุณต้องกำหนดค่า ML Kit ด้วยจำนวนและขนาด ("รูปร่าง") ของอาร์เรย์ที่โมเดลของคุณใช้

หากไม่ทราบรูปร่างและประเภทข้อมูลของอินพุตและเอาต์พุตของโมเดล คุณสามารถใช้อินเทอร์พรีเตอร์ของ TensorFlow Lite Python เพื่อตรวจสอบโมเดลได้ ตัวอย่างเช่น

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

หลังจากที่ระบุรูปแบบอินพุตและเอาต์พุตของโมเดลแล้ว คุณจะกำหนดค่าล่ามโมเดลของแอปได้โดยการสร้างออบเจ็กต์ FirebaseModelInputOutputOptions

ตัวอย่างเช่น โมเดลการจัดประเภทรูปภาพจุดลอยตัวอาจใช้เป็นอินพุตอาร์เรย์ Nx224x224x3 ของค่า float โดยแสดงกลุ่มรูปภาพ 3 ช่อง (RGB) N ขนาด 224x224 และสร้างเอาต์พุตเป็นรายการค่า float 1,000 ค่า โดยแต่ละค่าแสดงถึงความน่าจะเป็นที่รูปภาพจะเป็นสมาชิกของหนึ่งในหมวดหมู่ 1,000 หมวดหมู่ที่โมเดลคาดการณ์

สำหรับโมเดลดังกล่าว คุณจะต้องกำหนดค่าอินพุตและเอาต์พุตของตัวแปลมโมเดลดังที่แสดงด้านล่าง

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

ทำการอนุมานเกี่ยวกับข้อมูลอินพุต

สุดท้าย ในการอนุมานโดยใช้โมเดล ให้รับข้อมูลอินพุตและแปลงข้อมูลที่จำเป็นต่อการหาอาร์เรย์อินพุตที่มีรูปร่างเหมาะสมกับโมเดลของคุณ

เช่น หากคุณมีโมเดลการจัดประเภทรูปภาพที่มีรูปร่างอินพุตเป็น [1 224 224 3] ค่าจุดลอยตัว คุณอาจสร้างอาร์เรย์อินพุตจากออบเจ็กต์ Bitmap ดังที่แสดงในตัวอย่างต่อไปนี้

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

จากนั้นสร้างออบเจ็กต์ FirebaseModelInputs พร้อมข้อมูลอินพุตของคุณ แล้วส่งออบเจ็กต์ดังกล่าวและข้อมูลจำเพาะของอินพุตและเอาต์พุตของโมเดลไปยังเมธอด โมเดลอินเทอร์พรีเตอร์run ของโมเดล

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

หากการเรียกใช้สำเร็จ คุณจะได้รับเอาต์พุตโดยการเรียกใช้เมธอด getOutput() ของออบเจ็กต์ที่ส่งไปยัง Listener ที่สำเร็จ ตัวอย่างเช่น

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

วิธีที่คุณใช้เอาต์พุตจะขึ้นอยู่กับโมเดลที่คุณใช้

ตัวอย่างเช่น หากใช้การแยกประเภท ในขั้นตอนถัดไป คุณอาจจับคู่ดัชนีของผลลัพธ์กับป้ายกำกับที่แสดงดังนี้

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

ภาคผนวก: ความปลอดภัยของโมเดล

ML Kit จะจัดเก็บโมเดล TensorFlow Lite แบบอนุกรมมาตรฐานในพื้นที่เก็บข้อมูลในเครื่องไว้ ไม่ว่าคุณจะทำให้โมเดล TensorFlow Lite ใช้งานได้กับ ML Kit อย่างไร

ในทางทฤษฎี หมายความว่าทุกคนสามารถคัดลอกโมเดลของคุณได้ อย่างไรก็ตาม ในทางปฏิบัติ โมเดลส่วนใหญ่จะมีความเฉพาะเจาะจงกับแอปพลิเคชันโดยเฉพาะและทำให้ยากต่อการอ่าน (Obfuscate) ด้วยการเพิ่มประสิทธิภาพ ซึ่งมีความเสี่ยงคล้ายกับของคู่แข่งที่แยกชิ้นส่วนและนำโค้ดกลับมาใช้ใหม่ อย่างไรก็ตาม คุณควรตระหนักถึงความเสี่ยงนี้ก่อนที่จะใช้โมเดลที่กำหนดเองในแอป

ใน Android API ระดับ 21 (Lollipop) ขึ้นไป ระบบจะดาวน์โหลดโมเดลไปยังไดเรกทอรีที่ ยกเว้นจากการสำรองข้อมูลอัตโนมัติ

ใน Android API ระดับ 20 และเก่ากว่า ระบบจะดาวน์โหลดโมเดลไปยังไดเรกทอรีชื่อ com.google.firebase.ml.custom.models ในที่จัดเก็บข้อมูลภายในแบบส่วนตัวของแอป หากเปิดใช้การสำรองไฟล์โดยใช้ BackupAgent คุณอาจเลือกยกเว้นไดเรกทอรีนี้ได้