Check out what’s new from Firebase at Google I/O 2022. Learn more

ใช้โมเดล TensorFlow Lite สำหรับการอนุมานด้วย ML Kit บน Android

คุณสามารถใช้ ML Kit เพื่อทำการอนุมานในอุปกรณ์ด้วยรุ่น TensorFlow Lite

API นี้ต้องใช้ Android SDK ระดับ 16 (Jelly Bean) หรือใหม่กว่า

ก่อนจะเริ่ม

  1. หากคุณยังไม่ได้ เพิ่ม Firebase ในโครงการ Android ของคุณ
  2. เพิ่มการพึ่งพาสำหรับไลบรารี ML Kit Android ไปยังโมดูลของคุณ (ระดับแอป) ไฟล์ Gradle (โดยปกติคือ app/build.gradle ):
    apply plugin: 'com.android.application'
    apply plugin: 'com.google.gms.google-services'
    
    dependencies {
      // ...
    
      implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.3'
    }
    
  3. แปลงโมเดล TensorFlow ที่คุณต้องการใช้เป็นรูปแบบ TensorFlow Lite ดู TOCO: TensorFlow Lite Optimizing Converter

โฮสต์หรือรวมโมเดลของคุณ

ก่อนที่คุณจะสามารถใช้โมเดล TensorFlow Lite สำหรับการอนุมานในแอปของคุณได้ คุณต้องทำให้โมเดลพร้อมใช้งานใน ML Kit ML Kit สามารถใช้โมเดล TensorFlow Lite ที่โฮสต์จากระยะไกลโดยใช้ Firebase ที่มาพร้อมกับไบนารีของแอป หรือทั้งสองอย่าง

ด้วยการโฮสต์โมเดลบน Firebase คุณสามารถอัปเดตโมเดลโดยไม่ต้องเปิดตัวแอปเวอร์ชันใหม่ และคุณสามารถใช้การกำหนดค่าระยะไกลและการทดสอบ A/B เพื่อแสดงโมเดลต่างๆ แบบไดนามิกให้กับผู้ใช้กลุ่มต่างๆ

หากคุณเลือกที่จะระบุโมเดลโดยโฮสต์โมเดลกับ Firebase เท่านั้น และไม่ได้รวมเข้ากับแอปของคุณ คุณสามารถลดขนาดการดาวน์โหลดเริ่มต้นของแอปได้ อย่างไรก็ตาม พึงระลึกไว้เสมอว่า หากโมเดลนั้นไม่ได้รวมอยู่ในแอพของคุณ ฟังก์ชันที่เกี่ยวข้องกับโมเดลจะไม่สามารถใช้ได้จนกว่าแอพของคุณจะดาวน์โหลดโมเดลนั้นเป็นครั้งแรก

เมื่อรวมโมเดลเข้ากับแอป คุณจะมั่นใจได้ว่าฟีเจอร์ ML ของแอปจะยังทำงานเมื่อโมเดลที่โฮสต์กับ Firebase ไม่พร้อมใช้งาน

โฮสต์โมเดลบน Firebase

ในการโฮสต์โมเดล TensorFlow Lite ของคุณบน Firebase:

  1. ในส่วน ML Kit ของ คอนโซล Firebase ให้คลิกแท็บ กำหนดเอง
  2. คลิก เพิ่มโมเดลที่กำหนดเอง (หรือ เพิ่มโมเดลอื่น )
  3. ระบุชื่อที่จะใช้เพื่อระบุโมเดลของคุณในโปรเจ็กต์ Firebase จากนั้นอัปโหลดไฟล์โมเดล TensorFlow Lite (มักจะลงท้ายด้วย . .tflite หรือ . .lite )
  4. ในไฟล์ Manifest ของแอป ให้ประกาศว่าต้องได้รับอนุญาตจากอินเทอร์เน็ต:
    <uses-permission android:name="android.permission.INTERNET" />
    

หลังจากเพิ่มโมเดลที่กำหนดเองลงในโปรเจ็กต์ Firebase แล้ว คุณจะอ้างอิงโมเดลในแอปได้โดยใช้ชื่อที่คุณระบุ คุณสามารถอัปโหลดโมเดล TensorFlow Lite ใหม่ได้ทุกเมื่อ และแอปของคุณจะดาวน์โหลดโมเดลใหม่และเริ่มใช้งานเมื่อแอปรีสตาร์ทครั้งถัดไป คุณสามารถกำหนดเงื่อนไขอุปกรณ์ที่จำเป็นสำหรับแอปของคุณเพื่อพยายามอัปเดตรุ่น (ดูด้านล่าง)

มัดโมเดลด้วยแอพ

ในการรวมโมเดล TensorFlow Lite กับแอปของคุณ ให้คัดลอกไฟล์โมเดล (โดยปกติจะลงท้ายด้วย . .tflite หรือ .lite ) ไปยัง assets/ โฟลเดอร์ของแอป (คุณอาจต้องสร้างโฟลเดอร์ก่อนโดยคลิกขวาที่ app/ โฟลเดอร์ จากนั้นคลิก ใหม่ > โฟลเดอร์ > โฟลเดอร์เนื้อหา )

จากนั้น เพิ่มข้อมูลต่อไปนี้ในไฟล์ build.gradle ของแอปเพื่อให้แน่ใจว่า Gradle จะไม่บีบอัดโมเดลเมื่อสร้างแอป

android {

    // ...

    aaptOptions {
        noCompress "tflite"  // Your model's file extension: "tflite", "lite", etc.
    }
}

ไฟล์โมเดลจะรวมอยู่ในแพ็คเกจแอปและพร้อมให้ ML Kit เป็นสินทรัพย์ดิบ

โหลดโมเดล

หากต้องการใช้โมเดล TensorFlow Lite ในแอปของคุณ ก่อนอื่นให้กำหนดค่า ML Kit ด้วยตำแหน่งที่โมเดลของคุณพร้อมใช้งาน: จากระยะไกลโดยใช้ Firebase ในที่จัดเก็บในเครื่อง หรือทั้งสองอย่าง หากคุณระบุทั้งรุ่นในพื้นที่และรุ่นระยะไกล คุณสามารถใช้รุ่นระยะไกลได้หากมี และถอยกลับไปใช้รุ่นที่จัดเก็บในเครื่องหากรุ่นระยะไกลไม่พร้อมใช้งาน

กำหนดค่าโมเดลที่โฮสต์โดย Firebase

หากคุณโฮสต์โมเดลของคุณด้วย Firebase ให้สร้างวัตถุ FirebaseCustomRemoteModel โดยระบุชื่อที่คุณกำหนดให้กับโมเดลเมื่อคุณอัปโหลด:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()

จากนั้น เริ่มงานดาวน์โหลดแบบจำลอง โดยระบุเงื่อนไขที่คุณต้องการอนุญาตให้ดาวน์โหลด หากโมเดลนั้นไม่มีอยู่ในอุปกรณ์ หรือมีเวอร์ชันที่ใหม่กว่าอยู่ งานจะดาวน์โหลดโมเดลแบบอะซิงโครนัสจาก Firebase:

Java

FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnCompleteListener(new OnCompleteListener<Void>() {
            @Override
            public void onComplete(@NonNull Task<Void> task) {
                // Success.
            }
        });

Kotlin+KTX

val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Success.
    }

แอพจำนวนมากเริ่มงานดาวน์โหลดในรหัสเริ่มต้น แต่คุณสามารถทำได้ทุกเมื่อก่อนจะต้องใช้โมเดล

กำหนดค่าโมเดลท้องถิ่น

หากคุณรวมโมเดลเข้ากับแอปของคุณ ให้สร้างอ็อบเจ็กต์ FirebaseCustomLocalModel โดยระบุชื่อไฟล์ของโมเดล TensorFlow Lite:

Java

FirebaseCustomLocalModel localModel = new FirebaseCustomLocalModel.Builder()
        .setAssetFilePath("your_model.tflite")
        .build();

Kotlin+KTX

val localModel = FirebaseCustomLocalModel.Builder()
    .setAssetFilePath("your_model.tflite")
    .build()

สร้างล่ามจากโมเดลของคุณ

หลังจากที่คุณกำหนดค่าแหล่งที่มาของโมเดลแล้ว ให้สร้างออบเจ็กต์ FirebaseModelInterpreter จากหนึ่งในนั้น

หากคุณมีเฉพาะโมเดลที่รวมอยู่ในเครื่อง ให้สร้างล่ามจากอ็อบเจ็กต์ FirebaseCustomLocalModel ของคุณ:

Java

FirebaseModelInterpreter interpreter;
try {
    FirebaseModelInterpreterOptions options =
            new FirebaseModelInterpreterOptions.Builder(localModel).build();
    interpreter = FirebaseModelInterpreter.getInstance(options);
} catch (FirebaseMLException e) {
    // ...
}

Kotlin+KTX

val options = FirebaseModelInterpreterOptions.Builder(localModel).build()
val interpreter = FirebaseModelInterpreter.getInstance(options)

หากคุณมีโมเดลที่โฮสต์จากระยะไกล คุณจะต้องตรวจสอบว่าได้ดาวน์โหลดโมเดลดังกล่าวแล้วก่อนที่จะเรียกใช้ คุณสามารถตรวจสอบสถานะของงานดาวน์โหลดโมเดลได้โดยใช้ isModelDownloaded() ของตัวจัดการโมเดล

แม้ว่าคุณจะต้องยืนยันสิ่งนี้ก่อนเรียกใช้ล่ามเท่านั้น หากคุณมีทั้งโมเดลที่โฮสต์จากระยะไกลและโมเดลที่รวมอยู่ในเครื่อง คุณควรดำเนินการตรวจสอบนี้เมื่อสร้างอินสแตนซ์ของล่ามโมเดล: สร้างล่ามจากโมเดลระยะไกลหาก มันถูกดาวน์โหลดและจากรุ่นในพื้นที่เป็นอย่างอื่น

Java

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                FirebaseModelInterpreterOptions options;
                if (isDownloaded) {
                    options = new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
                } else {
                    options = new FirebaseModelInterpreterOptions.Builder(localModel).build();
                }
                FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
                // ...
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded -> 
    val options =
        if (isDownloaded) {
            FirebaseModelInterpreterOptions.Builder(remoteModel).build()
        } else {
            FirebaseModelInterpreterOptions.Builder(localModel).build()
        }
    val interpreter = FirebaseModelInterpreter.getInstance(options)
}

หากคุณมีเฉพาะโมเดลที่โฮสต์จากระยะไกล คุณควรปิดใช้งานฟังก์ชันที่เกี่ยวข้องกับโมเดล เช่น สีเทาหรือซ่อนส่วนหนึ่งของ UI ของคุณ จนกว่าคุณจะยืนยันว่าได้ดาวน์โหลดโมเดลแล้ว คุณสามารถทำได้โดยแนบ Listener เข้ากับเมธอด download() ของ model manager:

Java

FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnCompleteListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

ระบุอินพุตและเอาต์พุตของโมเดล

ถัดไป กำหนดค่ารูปแบบอินพุตและเอาต์พุตของล่ามแบบจำลอง

โมเดล TensorFlow Lite ใช้เป็นอินพุตและสร้างอาร์เรย์หลายมิติเป็นเอาต์พุต อาร์เรย์เหล่านี้มีค่า byte , int , long หรือ float คุณต้องกำหนดค่า ML Kit ด้วยจำนวนและมิติข้อมูล ("รูปร่าง") ของอาร์เรย์ที่โมเดลของคุณใช้

หากคุณไม่ทราบรูปร่างและประเภทข้อมูลของอินพุตและเอาต์พุตของโมเดล คุณสามารถใช้ตัวแปล TensorFlow Lite Python เพื่อตรวจสอบโมเดลของคุณได้ ตัวอย่างเช่น:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
print(interpreter.get_input_details()[0]['shape'])  # Example: [1 224 224 3]
print(interpreter.get_input_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

# Print output shape and type
print(interpreter.get_output_details()[0]['shape'])  # Example: [1 1000]
print(interpreter.get_output_details()[0]['dtype'])  # Example: <class 'numpy.float32'>

หลังจากที่คุณได้กำหนดรูปแบบของอินพุตและเอาต์พุตของโมเดลของคุณแล้ว คุณสามารถกำหนดค่าตัวแปลโมเดลของแอปได้โดยการสร้างอ็อบเจ็กต์ FirebaseModelInputOutputOptions

ตัวอย่างเช่น แบบจำลองการจัดประเภทรูปภาพแบบทศนิยมอาจใช้อินพุตอาร์เรย์ N float ของค่าทศนิยม แทนชุดของรูปภาพ N 224x224 สามช่องสัญญาณ (RGB) และสร้างรายการค่า float ลต 1000 รายการเป็นเอาต์พุต โดยแต่ละรายการแทน ความน่าจะเป็นที่รูปภาพจะเป็นหนึ่งใน 1,000 หมวดหมู่ที่แบบจำลองคาดการณ์

สำหรับโมเดลดังกล่าว คุณจะต้องกำหนดค่าอินพุตและเอาต์พุตของตัวแปลโมเดลดังที่แสดงด้านล่าง:

Java

FirebaseModelInputOutputOptions inputOutputOptions =
        new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
                .build();

Kotlin+KTX

val inputOutputOptions = FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 224, 224, 3))
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 5))
        .build()

ทำการอนุมานข้อมูลอินพุต

สุดท้าย ในการอนุมานโดยใช้โมเดล ให้รับข้อมูลที่คุณป้อนและทำการแปลงข้อมูลที่จำเป็นเพื่อให้ได้อาร์เรย์อินพุตที่มีรูปทรงที่เหมาะสมสำหรับโมเดลของคุณ

ตัวอย่างเช่น หากคุณมีโมเดลการจำแนกรูปภาพที่มีรูปร่างอินพุตเป็นค่าจุดลอยตัว [1 224 224 3] คุณสามารถสร้างอาร์เรย์อินพุตจากออบเจ็กต์ Bitmap ดังที่แสดงในตัวอย่างต่อไปนี้:

Java

Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

int batchNum = 0;
float[][][][] input = new float[1][224][224][3];
for (int x = 0; x < 224; x++) {
    for (int y = 0; y < 224; y++) {
        int pixel = bitmap.getPixel(x, y);
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}

จากนั้น สร้างอ็อบเจ็กต์ FirebaseModelInputs ด้วยข้อมูลอินพุตของคุณ และส่งผ่านและข้อมูลจำเพาะอินพุตและเอาต์พุตของโมเดลไปยังวิธีการ run ของ model interpreter :

Java

FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(input)  // add() as many input arrays as your model requires
        .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener(
                new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        // ...
                    }
                })
        .addOnFailureListener(
                new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        // ...
                    }
                });

Kotlin+KTX

val inputs = FirebaseModelInputs.Builder()
        .add(input) // add() as many input arrays as your model requires
        .build()
firebaseInterpreter.run(inputs, inputOutputOptions)
        .addOnSuccessListener { result ->
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

หากการโทรสำเร็จ คุณสามารถรับผลลัพธ์ได้โดยการเรียกเมธอด getOutput() ของอ็อบเจ็กต์ที่ส่งผ่านไปยังผู้ฟังที่ประสบความสำเร็จ ตัวอย่างเช่น:

Java

float[][] output = result.getOutput(0);
float[] probabilities = output[0];

Kotlin+KTX

val output = result.getOutput<Array<FloatArray>>(0)
val probabilities = output[0]

วิธีที่คุณใช้ผลลัพธ์ขึ้นอยู่กับรุ่นที่คุณใช้

ตัวอย่างเช่น หากคุณกำลังจัดประเภท คุณอาจแมปดัชนีของผลลัพธ์กับป้ายกำกับในขั้นตอนต่อไป

Java

BufferedReader reader = new BufferedReader(
        new InputStreamReader(getAssets().open("retrained_labels.txt")));
for (int i = 0; i < probabilities.length; i++) {
    String label = reader.readLine();
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}

Kotlin+KTX

val reader = BufferedReader(
        InputStreamReader(assets.open("retrained_labels.txt")))
for (i in probabilities.indices) {
    val label = reader.readLine()
    Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]))
}

ภาคผนวก: ความปลอดภัยของแบบจำลอง

ไม่ว่าคุณจะทำให้รุ่น TensorFlow Lite ของคุณพร้อมใช้งานใน ML Kit อย่างไร ML Kit จะจัดเก็บโมเดลเหล่านี้ไว้ในรูปแบบโปรโตบัฟซีเรียลไลซ์มาตรฐานในที่จัดเก็บในเครื่อง

ตามทฤษฎีแล้ว นี่หมายความว่าใครๆ ก็สามารถคัดลอกแบบจำลองของคุณได้ อย่างไรก็ตาม ในทางปฏิบัติ โมเดลส่วนใหญ่มีความเฉพาะเจาะจงกับแอปพลิเคชันและทำให้สับสนโดยการปรับให้เหมาะสม ซึ่งความเสี่ยงนั้นใกล้เคียงกับของคู่แข่งในการถอดแยกชิ้นส่วนและนำโค้ดของคุณมาใช้ซ้ำ อย่างไรก็ตาม คุณควรตระหนักถึงความเสี่ยงนี้ก่อนที่คุณจะใช้โมเดลแบบกำหนดเองในแอปของคุณ

ใน Android API ระดับ 21 (Lollipop) และใหม่กว่า โมเดลจะถูกดาวน์โหลดไปยังไดเร็กทอรีที่ ไม่รวมอยู่ในการสำรองข้อมูลอัตโนมัติ

ใน Android API ระดับ 20 ขึ้นไป โมเดลจะถูกดาวน์โหลดไปยังไดเร็กทอรีชื่อ com.google.firebase.ml.custom.models ในที่จัดเก็บข้อมูลภายในแบบส่วนตัวของแอป หากคุณเปิดใช้งานการสำรองไฟล์โดยใช้ BackupAgent คุณอาจเลือกที่จะยกเว้นไดเร็กทอรีนี้