ใช้โมเดล TensorFlow Lite สำหรับการอนุมานด้วย ML Kit บน Android

คุณสามารถใช้ ML Kit เพื่อทำการอนุมานในอุปกรณ์ด้วยโมเดล TensorFlow Lite

API นี้ต้องใช้ Android SDK ระดับ 16 (Jelly Bean) หรือใหม่กว่า

ก่อนที่คุณจะเริ่ม

  1. หากคุณยังไม่ได้ เพิ่ม Firebase ในโครงการ Android ของคุณ
  2. เพิ่มการพึ่งพาสำหรับไลบรารี ML Kit Android ให้กับไฟล์ Gradle ของโมดูล (ระดับแอป) (โดยปกติคือ app/build.gradle ):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. แปลงโมเดล TensorFlow ที่คุณต้องการใช้เป็นรูปแบบ TensorFlow Lite ดู TOCO: ตัวแปลงการเพิ่มประสิทธิภาพ TensorFlow Lite

โฮสต์หรือรวมโมเดลของคุณ

ก่อนที่คุณจะสามารถใช้โมเดล TensorFlow Lite สำหรับการอนุมานในแอปได้ คุณต้องทำให้โมเดลพร้อมใช้งานสำหรับ ML Kit ML Kit สามารถใช้โมเดล TensorFlow Lite ที่โฮสต์จากระยะไกลโดยใช้ Firebase ที่มาพร้อมกับไบนารีของแอป หรือทั้งสองอย่าง

ด้วยการโฮสต์โมเดลบน Firebase คุณสามารถอัปเดตโมเดลได้โดยไม่ต้องเปิดตัวแอปเวอร์ชันใหม่ และคุณสามารถใช้การกำหนดค่าระยะไกลและการทดสอบ A/B เพื่อให้บริการโมเดลต่างๆ แบบไดนามิกแก่ผู้ใช้กลุ่มต่างๆ

หากคุณเลือกที่จะระบุโมเดลโดยการโฮสต์โมเดลนั้นกับ Firebase เท่านั้น และไม่รวมโมเดลเข้ากับแอปของคุณ คุณสามารถลดขนาดการดาวน์โหลดเริ่มต้นของแอปได้ อย่างไรก็ตาม โปรดทราบว่าหากโมเดลไม่ได้รวมอยู่กับแอปของคุณ ฟังก์ชันการทำงานใดๆ ที่เกี่ยวข้องกับโมเดลจะไม่สามารถใช้งานได้จนกว่าแอปของคุณจะดาวน์โหลดโมเดลเป็นครั้งแรก

เมื่อรวมโมเดลเข้ากับแอป คุณจะมั่นใจได้ว่าฟีเจอร์ ML ของแอปยังคงใช้งานได้เมื่อโมเดลที่โฮสต์โดย Firebase ไม่พร้อมใช้งาน

โฮสต์โมเดลบน Firebase

หากต้องการโฮสต์โมเดล TensorFlow Lite ของคุณบน Firebase:

  1. ในส่วน ML Kit ของ คอนโซล Firebase ให้คลิกแท็บ กำหนดเอง
  2. คลิก เพิ่มโมเดลที่กำหนดเอง (หรือ เพิ่มโมเดลอื่น )
  3. ระบุชื่อที่จะใช้ระบุโมเดลของคุณในโปรเจ็กต์ Firebase จากนั้นอัปโหลดไฟล์โมเดล TensorFlow Lite (โดยปกติจะลงท้ายด้วย .tflite หรือ .lite )
  4. ในรายการแอปของคุณ ให้ประกาศว่าต้องได้รับอนุญาตจากอินเทอร์เน็ต:
    <uses-permission android:name="android.permission.INTERNET" />
    

หลังจากที่คุณเพิ่มโมเดลที่กำหนดเองลงในโปรเจ็กต์ Firebase แล้ว คุณจะอ้างอิงโมเดลในแอปได้โดยใช้ชื่อที่คุณระบุ คุณสามารถอัปโหลดโมเดล TensorFlow Lite ใหม่ได้ตลอดเวลา และแอปของคุณจะดาวน์โหลดโมเดลใหม่และเริ่มใช้งานเมื่อแอปรีสตาร์ทครั้งถัดไป คุณสามารถกำหนดเงื่อนไขอุปกรณ์ที่จำเป็นสำหรับแอปของคุณเพื่อพยายามอัปเดตโมเดลได้ (ดูด้านล่าง)

รวมโมเดลเข้ากับแอป

หากต้องการรวมโมเดล TensorFlow Lite เข้ากับแอปของคุณ ให้คัดลอกไฟล์โมเดล (โดยปกติจะลงท้ายด้วย .tflite หรือ .lite ) ไปยังโฟลเดอร์ assets/ ของแอป (คุณอาจต้องสร้างโฟลเดอร์ก่อนโดยคลิกขวาที่ app/ โฟลเดอร์ จากนั้นคลิก ใหม่ > โฟลเดอร์ > โฟลเดอร์สินทรัพย์ )

จากนั้น เพิ่มสิ่งต่อไปนี้ลงในไฟล์ build.gradle ของแอปของคุณเพื่อให้แน่ใจว่า Gradle จะไม่บีบอัดโมเดลเมื่อสร้างแอป:

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

ไฟล์โมเดลจะรวมอยู่ในแพ็คเกจแอปและพร้อมใช้งานสำหรับ ML Kit ในรูปแบบสินทรัพย์ดิบ

โหลดโมเดล

หากต้องการใช้โมเดล TensorFlow Lite ในแอปของคุณ ขั้นแรกให้กำหนดค่า ML Kit ด้วยตำแหน่งที่โมเดลของคุณพร้อมใช้งาน: การใช้ Firebase จากระยะไกล ในพื้นที่จัดเก็บในเครื่อง หรือทั้งสองอย่าง หากคุณระบุทั้งโมเดลภายในและโมเดลระยะไกล คุณสามารถใช้โมเดลระยะไกลได้หากมี และถอยกลับไปยังโมเดลที่จัดเก็บไว้ในเครื่องหากไม่มีโมเดลระยะไกล

กำหนดค่าโมเดลที่โฮสต์โดย Firebase

หากคุณโฮสต์โมเดลของคุณด้วย Firebase ให้สร้างออบเจ็กต์ FirebaseCustomRemoteModel โดยระบุชื่อที่คุณกำหนดให้กับโมเดลเมื่อคุณอัปโหลด:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

จากนั้น เริ่มต้นงานการดาวน์โหลดโมเดล โดยระบุเงื่อนไขที่คุณต้องการอนุญาตให้ดาวน์โหลด หากไม่มีโมเดลดังกล่าวอยู่ในอุปกรณ์ หรือมีเวอร์ชันที่ใหม่กว่า งานจะดาวน์โหลดโมเดลจาก Firebase แบบอะซิงโครนัส:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

แอพจำนวนมากเริ่มงานดาวน์โหลดด้วยโค้ดเริ่มต้น แต่คุณสามารถทำได้ทุกเมื่อก่อนจำเป็นต้องใช้โมเดล

กำหนดค่าโมเดลท้องถิ่น

หากคุณรวมโมเดลเข้ากับแอปของคุณ ให้สร้างออบเจ็กต์ FirebaseCustomLocalModel โดยระบุชื่อไฟล์ของโมเดล TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

สร้างล่ามจากแบบจำลองของคุณ

หลังจากที่คุณกำหนดค่าแหล่งที่มาของโมเดลแล้ว ให้สร้างออบเจ็กต์ FirebaseModelInterpreter จากหนึ่งในนั้น

หากคุณมีเฉพาะโมเดลที่รวมอยู่ในเครื่อง ให้สร้างล่ามจากออบเจ็กต์ FirebaseCustomLocalModel ของคุณ:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

หากคุณมีโมเดลที่โฮสต์ระยะไกล คุณจะต้องตรวจสอบว่ามีการดาวน์โหลดก่อนที่จะรัน คุณสามารถตรวจสอบสถานะของงานการดาวน์โหลดโมเดลได้โดยใช้เมธอด isModelDownloaded() ของตัวจัดการโมเดล

แม้ว่าคุณจะต้องยืนยันสิ่งนี้ก่อนเรียกใช้ล่ามเท่านั้น หากคุณมีทั้งโมเดลที่โฮสต์จากระยะไกลและโมเดลที่รวมอยู่ในเครื่อง มันอาจจะสมเหตุสมผลที่จะดำเนินการตรวจสอบนี้เมื่อสร้างอินสแตนซ์ของล่ามโมเดล: สร้างล่ามจากโมเดลระยะไกลหาก มันถูกดาวน์โหลดแล้ว และจากรุ่นท้องถิ่นเป็นอย่างอื่น

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

หากคุณมีเฉพาะโมเดลที่โฮสต์จากระยะไกล คุณควรปิดใช้งานฟังก์ชันที่เกี่ยวข้องกับโมเดล เช่น ทำให้เป็นสีเทาหรือซ่อนบางส่วนของ UI จนกว่าคุณจะยืนยันว่าดาวน์โหลดโมเดลแล้ว คุณสามารถทำได้โดยการแนบ Listener เข้ากับ download() ของตัวจัดการโมเดล:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

ระบุอินพุตและเอาต์พุตของโมเดล

ถัดไป กำหนดค่ารูปแบบอินพุตและเอาต์พุตของตัวแปลโมเดล

โมเดล TensorFlow Lite ใช้เป็นอินพุตและสร้างเป็นเอาต์พุตอาร์เรย์หลายมิติตั้งแต่หนึ่งอาร์เรย์ขึ้นไป อาร์เรย์เหล่านี้มีค่า byte , int , long หรือ float คุณต้องกำหนดค่า ML Kit ด้วยจำนวนและขนาด ("รูปร่าง") ของอาร์เรย์ที่โมเดลของคุณใช้

หากคุณไม่ทราบรูปร่างและประเภทข้อมูลของอินพุตและเอาต์พุตของโมเดล คุณสามารถใช้ล่าม TensorFlow Lite Python เพื่อตรวจสอบโมเดลของคุณได้ ตัวอย่างเช่น:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

หลังจากที่คุณกำหนดรูปแบบของอินพุตและเอาต์พุตของโมเดลแล้ว คุณสามารถกำหนดค่าตัวแปลโมเดลของแอปได้โดยการสร้างออบเจ็กต์ FirebaseModelInputOutputOptions

ตัวอย่างเช่น โมเดลการจำแนกภาพจุดลอยตัวอาจใช้อาร์เรย์ N x224x224x3 ของค่าจำนวน float แทนชุดของภาพสามช่องสัญญาณ (RGB) N 224x224 N และสร้างรายการค่า float 1,000 ค่าเป็นเอาต์พุต โดยแต่ละค่าแสดงถึง ความน่าจะเป็นที่รูปภาพจะเป็นหนึ่งใน 1,000 หมวดหมู่ที่แบบจำลองคาดการณ์

สำหรับโมเดลดังกล่าว คุณจะต้องกำหนดค่าอินพุตและเอาท์พุตของตัวแปลโมเดลดังที่แสดงด้านล่าง:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

ดำเนินการอนุมานข้อมูลอินพุต

สุดท้าย หากต้องการอนุมานโดยใช้โมเดล ให้รับข้อมูลอินพุตของคุณและดำเนินการแปลงข้อมูลที่จำเป็นเพื่อให้ได้อาร์เรย์อินพุตที่มีรูปร่างที่เหมาะสมสำหรับโมเดลของคุณ

ตัวอย่างเช่น หากคุณมีโมเดลการจัดหมวดหมู่รูปภาพที่มีรูปร่างอินพุตเป็นค่าจุดลอยตัว [1 224 224 3] คุณสามารถสร้างอาร์เรย์อินพุตจากออบ Bitmap ได้ดังที่แสดงในตัวอย่างต่อไปนี้:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

จากนั้น สร้างออบเจ็กต์ FirebaseModelInputs ด้วยข้อมูลอินพุตของคุณ และส่งผ่านข้อมูลดังกล่าวและข้อกำหนดอินพุตและเอาต์พุตของโมเดลไปยังวิธี run ของ ตัวแปลโมเดล :

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

หากการโทรสำเร็จ คุณสามารถรับเอาต์พุตได้โดยการเรียกเมธอด getOutput() ของอ็อบเจ็กต์ที่ส่งผ่านไปยัง Listener ที่สำเร็จ ตัวอย่างเช่น:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

วิธีที่คุณใช้เอาต์พุตจะขึ้นอยู่กับรุ่นที่คุณใช้

ตัวอย่างเช่น หากคุณกำลังดำเนินการจัดหมวดหมู่ ในขั้นตอนถัดไป คุณอาจจับคู่ดัชนีของผลลัพธ์กับป้ายกำกับที่เป็นตัวแทน:

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

ภาคผนวก: ความปลอดภัยของโมเดล

ไม่ว่าคุณจะทำให้โมเดล TensorFlow Lite พร้อมใช้งานใน ML Kit ได้อย่างไร ML Kit จะจัดเก็บโมเดลเหล่านั้นในรูปแบบ protobuf ที่เป็นอนุกรมมาตรฐานในพื้นที่จัดเก็บในตัวเครื่อง

ตามทฤษฎีแล้ว นี่หมายความว่าใครๆ ก็สามารถคัดลอกโมเดลของคุณได้ อย่างไรก็ตาม ในทางปฏิบัติ โมเดลส่วนใหญ่จะมีความเฉพาะเจาะจงกับแอปพลิเคชันมากและถูกทำให้สับสนด้วยการปรับให้เหมาะสมให้เหมาะสม ซึ่งมีความเสี่ยงใกล้เคียงกับความเสี่ยงที่คู่แข่งจะแยกส่วนและนำโค้ดของคุณกลับมาใช้ใหม่ อย่างไรก็ตาม คุณควรตระหนักถึงความเสี่ยงนี้ก่อนที่จะใช้โมเดลที่กำหนดเองในแอปของคุณ

บน Android API ระดับ 21 (Lollipop) และใหม่กว่า โมเดลจะถูกดาวน์โหลดไปยังไดเร็กทอรีที่ ไม่รวมอยู่ในการสำรองข้อมูลอัตโนมัติ

บน Android API ระดับ 20 และเก่ากว่า โมเดลจะถูกดาวน์โหลดไปยังไดเร็กทอรีชื่อ com.google.firebase.ml.custom.models ในที่จัดเก็บข้อมูลภายในส่วนตัวของแอป หากคุณเปิดใช้งานการสำรองข้อมูลไฟล์โดยใช้ BackupAgent คุณอาจเลือกที่จะยกเว้นไดเร็กทอรีนี้