ใช้โมเดล TensorFlow Lite ที่กำหนดเองบน Android

หากแอปของคุณใช้โมเดล TensorFlow Lite ที่กำหนดเอง คุณสามารถใช้ Firebase ML เพื่อปรับใช้โมเดลของคุณได้ การปรับใช้โมเดลกับ Firebase ช่วยให้คุณลดขนาดการดาวน์โหลดเริ่มต้นของแอปและอัปเดตโมเดล ML ของแอปได้โดยไม่ต้องออกแอปเวอร์ชันใหม่ และด้วย Remote Config และ A/B Testing คุณสามารถให้บริการโมเดลต่างๆ แบบไดนามิกแก่กลุ่มผู้ใช้ที่แตกต่างกันได้

รุ่น TensorFlow Lite

รุ่น TensorFlow Lite เป็นรุ่น ML ที่ได้รับการปรับแต่งให้ทำงานบนอุปกรณ์เคลื่อนที่ ในการรับโมเดล TensorFlow Lite:

ก่อนที่คุณจะเริ่มต้น

  1. หากคุณยังไม่ได้ ดำเนินการ ให้เพิ่ม Firebase ในโครงการ Android ของคุณ
  2. ใน ไฟล์ Gradle ของโมดูล (ระดับแอป) (โดยปกติคือ <project>/<app-module>/build.gradle.kts หรือ <project>/<app-module>/build.gradle ) เพิ่มการพึ่งพาสำหรับ Firebase ML ไลบรารีตัวดาวน์โหลดโมเดล Android ขอแนะนำให้ใช้ Firebase Android BoM เพื่อควบคุมการกำหนดเวอร์ชันของไลบรารี

    นอกจากนี้ ในการตั้งค่าตัวดาวน์โหลดโมเดล Firebase ML คุณต้องเพิ่ม TensorFlow Lite SDK ลงในแอปของคุณ

    Kotlin+KTX

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:32.3.1"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader-ktx")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    เมื่อใช้ Firebase Android BoM แอปของคุณจะใช้ไลบรารี Firebase Android เวอร์ชันที่เข้ากันได้เสมอ

    (ทางเลือก) เพิ่มการอ้างอิงไลบรารี Firebase โดยไม่ ใช้ BoM

    หากคุณเลือกที่จะไม่ใช้ Firebase BoM คุณต้องระบุแต่ละเวอร์ชันของไลบรารี Firebase ในบรรทัดอ้างอิง

    โปรดทราบว่าหากคุณใช้ไลบรารี Firebase หลาย ไลบรารีในแอป เราขอแนะนำอย่างยิ่งให้ใช้ BoM เพื่อจัดการเวอร์ชันของไลบรารี ซึ่งทำให้แน่ใจว่าเวอร์ชันทั้งหมดเข้ากันได้

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader-ktx:24.1.3")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    Java

    dependencies {
        // Import the BoM for the Firebase platform
        implementation(platform("com.google.firebase:firebase-bom:32.3.1"))
    
        // Add the dependency for the Firebase ML model downloader library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }

    เมื่อใช้ Firebase Android BoM แอปของคุณจะใช้ไลบรารี Firebase Android เวอร์ชันที่เข้ากันได้เสมอ

    (ทางเลือก) เพิ่มการอ้างอิงไลบรารี Firebase โดยไม่ ใช้ BoM

    หากคุณเลือกที่จะไม่ใช้ Firebase BoM คุณต้องระบุแต่ละเวอร์ชันของไลบรารี Firebase ในบรรทัดอ้างอิง

    โปรดทราบว่าหากคุณใช้ไลบรารี Firebase หลาย ไลบรารีในแอป เราขอแนะนำอย่างยิ่งให้ใช้ BoM เพื่อจัดการเวอร์ชันของไลบรารี ซึ่งทำให้แน่ใจว่าเวอร์ชันทั้งหมดเข้ากันได้

    dependencies {
        // Add the dependency for the Firebase ML model downloader library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation("com.google.firebase:firebase-ml-modeldownloader:24.1.3")
    // Also add the dependency for the TensorFlow Lite library and specify its version implementation("org.tensorflow:tensorflow-lite:2.3.0")
    }
  3. ในไฟล์ Manifest ของแอป ให้ประกาศว่าต้องได้รับอนุญาตจากอินเทอร์เน็ต:
    <uses-permission android:name="android.permission.INTERNET" />

1. ปรับใช้โมเดลของคุณ

ปรับใช้โมเดล TensorFlow ที่กำหนดเองโดยใช้คอนโซล Firebase หรือ Firebase Admin Python และ Node.js SDK ดู ปรับใช้และจัดการโมเดลแบบกำหนดเอง

หลังจากเพิ่มโมเดลที่กำหนดเองในโครงการ Firebase แล้ว คุณจะอ้างอิงโมเดลในแอปได้โดยใช้ชื่อที่คุณระบุ คุณสามารถปรับใช้โมเดล TensorFlow Lite ใหม่ได้ทุกเมื่อ และดาวน์โหลดโมเดลใหม่ลงในอุปกรณ์ของผู้ใช้โดยเรียก getModel() (ดูด้านล่าง)

2. ดาวน์โหลดโมเดลลงในอุปกรณ์และเริ่มต้นล่าม TensorFlow Lite

หากต้องการใช้โมเดล TensorFlow Lite ในแอป ก่อนอื่นให้ใช้ Firebase ML SDK เพื่อดาวน์โหลดโมเดลเวอร์ชันล่าสุดลงในอุปกรณ์ จากนั้นสร้างอินสแตนซ์ของล่าม TensorFlow Lite ด้วยโมเดล

ในการเริ่มการดาวน์โหลดโมเดล ให้เรียกเมธอด getModel() ของตัวดาวน์โหลดโมเดล โดยระบุชื่อที่คุณกำหนดให้กับโมเดลเมื่ออัปโหลดโมเดล คุณต้องการดาวน์โหลดโมเดลล่าสุดเสมอหรือไม่ และเงื่อนไขที่คุณต้องการอนุญาตให้ดาวน์โหลด

คุณสามารถเลือกพฤติกรรมการดาวน์โหลดได้สามแบบ:

ประเภทการดาวน์โหลด คำอธิบาย
LOCAL_MODEL รับโมเดลท้องถิ่นจากอุปกรณ์ หากไม่มีโมเดลท้องถิ่น การทำงานจะเป็นแบบ LATEST_MODEL ใช้ประเภทการดาวน์โหลดนี้หากคุณไม่สนใจที่จะตรวจสอบการอัปเดตโมเดล ตัวอย่างเช่น คุณกำลังใช้ Remote Config เพื่อดึงชื่อโมเดล และคุณมักจะอัปโหลดโมเดลโดยใช้ชื่อใหม่ (แนะนำ)
LOCAL_MODEL_UPDATE_IN_BACKGROUND รับโมเดลในเครื่องจากอุปกรณ์และเริ่มอัปเดตโมเดลในเบื้องหลัง หากไม่มีโมเดลท้องถิ่น การทำงานจะเป็นแบบ LATEST_MODEL
LATEST_MODEL รับรุ่นใหม่ล่าสุด หากโมเดลโลคัลเป็นเวอร์ชันล่าสุด ให้ส่งคืนโมเดลโลคัล มิฉะนั้นให้ดาวน์โหลดโมเดลล่าสุด ลักษณะการทำงานนี้จะบล็อกจนกว่าจะดาวน์โหลดเวอร์ชันล่าสุด (ไม่แนะนำ) ใช้ลักษณะการทำงานนี้เฉพาะในกรณีที่คุณต้องการเวอร์ชันล่าสุดอย่างชัดเจน

คุณควรปิดใช้งานฟังก์ชันที่เกี่ยวข้องกับโมเดล เช่น สีเทาหรือซ่อนบางส่วนของ UI ของคุณ จนกว่าคุณจะยืนยันว่าดาวน์โหลดโมเดลแล้ว

Kotlin+KTX

val conditions = CustomModelDownloadConditions.Builder()
        .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
        .build()
FirebaseModelDownloader.getInstance()
        .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND,
            conditions)
        .addOnSuccessListener { model: CustomModel? ->
            // Download complete. Depending on your app, you could enable the ML
            // feature, or switch from the local model to the remote model, etc.

            // The CustomModel object contains the local path of the model file,
            // which you can use to instantiate a TensorFlow Lite interpreter.
            val modelFile = model?.file
            if (modelFile != null) {
                interpreter = Interpreter(modelFile)
            }
        }

Java

CustomModelDownloadConditions conditions = new CustomModelDownloadConditions.Builder()
    .requireWifi()  // Also possible: .requireCharging() and .requireDeviceIdle()
    .build();
FirebaseModelDownloader.getInstance()
    .getModel("your_model", DownloadType.LOCAL_MODEL_UPDATE_IN_BACKGROUND, conditions)
    .addOnSuccessListener(new OnSuccessListener<CustomModel>() {
      @Override
      public void onSuccess(CustomModel model) {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.

        // The CustomModel object contains the local path of the model file,
        // which you can use to instantiate a TensorFlow Lite interpreter.
        File modelFile = model.getFile();
        if (modelFile != null) {
            interpreter = new Interpreter(modelFile);
        }
      }
    });

แอพจำนวนมากเริ่มงานดาวน์โหลดด้วยรหัสการเริ่มต้น แต่คุณสามารถทำได้ทุกเมื่อก่อนที่คุณจะต้องใช้โมเดล

3. ทำการอนุมานข้อมูลเข้า

รับรูปร่างอินพุตและเอาต์พุตของโมเดลของคุณ

ตัวแปลโมเดล TensorFlow Lite ใช้เป็นอินพุตและสร้างอาร์เรย์หลายมิติตั้งแต่หนึ่งอาร์เรย์ขึ้นไปเป็นเอาต์พุต อาร์เรย์เหล่านี้มีค่า byte , int , long หรือ float ก่อนที่คุณจะสามารถส่งข้อมูลไปยังโมเดลหรือใช้ผลลัพธ์ คุณต้องทราบจำนวนและขนาด ("รูปร่าง") ของอาร์เรย์ที่โมเดลของคุณใช้

หากคุณสร้างโมเดลด้วยตัวเอง หรือหากมีการบันทึกรูปแบบอินพุตและเอาต์พุตของโมเดล คุณอาจมีข้อมูลนี้อยู่แล้ว หากคุณไม่ทราบรูปร่างและประเภทข้อมูลของอินพุตและเอาต์พุตของโมเดล คุณสามารถใช้ตัวแปล TensorFlow Lite เพื่อตรวจสอบโมเดลของคุณได้ ตัวอย่างเช่น:

หลาม

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

ตัวอย่างเอาต์พุต:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

เรียกใช้ล่าม

หลังจากที่คุณได้กำหนดรูปแบบของอินพุตและเอาต์พุตของโมเดลของคุณแล้ว ให้รับข้อมูลอินพุตของคุณและทำการแปลงใดๆ กับข้อมูลที่จำเป็นเพื่อให้ได้อินพุตของรูปร่างที่เหมาะสมสำหรับโมเดลของคุณ

ตัวอย่างเช่น หากคุณมีโมเดลการจัดประเภทรูปภาพที่มีรูปร่างอินพุตเป็นค่าทศนิยม [1 224 224 3] คุณสามารถสร้าง ByteBuffer อินพุตจากวัตถุ Bitmap ดังที่แสดงในตัวอย่างต่อไปนี้:

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127) / 255f
        val gf = (g - 127) / 255f
        val bf = (b - 127) / 255f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127) / 255.0f;
        float gf = (g - 127) / 255.0f;
        float bf = (b - 127) / 255.0f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

จากนั้น จัดสรร ByteBuffer ให้ใหญ่พอที่จะเก็บเอาต์พุตของโมเดล และส่งผ่านอินพุตบัฟเฟอร์และเอาต์พุตบัฟเฟอร์ไปยังเมธอด run() ของล่าม TensorFlow Lite ตัวอย่างเช่น สำหรับรูปร่างเอาต์พุตของค่าทศนิยม [1 1000] :

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

วิธีที่คุณใช้เอาต์พุตขึ้นอยู่กับรุ่นที่คุณใช้

ตัวอย่างเช่น หากคุณกำลังจัดประเภท ในขั้นตอนต่อไป คุณอาจแมปดัชนีของผลลัพธ์กับป้ายกำกับที่แสดง:

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

ภาคผนวก: ความปลอดภัยของโมเดล

ไม่ว่าคุณจะทำให้โมเดล TensorFlow Lite พร้อมใช้งานกับ Firebase ML ด้วยวิธีใด Firebase ML จะจัดเก็บโมเดลเหล่านี้ในรูปแบบ protobuf ที่ทำให้เป็นอนุกรมมาตรฐานในที่จัดเก็บในตัวเครื่อง

ในทางทฤษฎีหมายความว่าใครก็ตามสามารถคัดลอกโมเดลของคุณได้ อย่างไรก็ตาม ในทางปฏิบัติ โมเดลส่วนใหญ่มีความเฉพาะเจาะจงกับแอปพลิเคชันและถูกทำให้ยุ่งเหยิงด้วยการปรับให้เหมาะสม ซึ่งมีความเสี่ยงใกล้เคียงกับคู่แข่งในการถอดประกอบและนำโค้ดของคุณกลับมาใช้ใหม่ อย่างไรก็ตาม คุณควรตระหนักถึงความเสี่ยงนี้ก่อนที่จะใช้โมเดลที่กำหนดเองในแอปของคุณ

ใน Android API ระดับ 21 (Lollipop) และใหม่กว่า โมเดลจะถูกดาวน์โหลดไปยังไดเร็กทอรีที่ ไม่รวมอยู่ในการสำรองข้อมูลอัตโนมัติ

ใน Android API ระดับ 20 และเก่ากว่า โมเดลจะถูกดาวน์โหลดไปยังไดเร็กทอรีชื่อ com.google.firebase.ml.custom.models ในที่จัดเก็บข้อมูลภายในแอปส่วนตัว หากคุณเปิดใช้งานการสำรองไฟล์โดยใช้ BackupAgent คุณอาจเลือกที่จะไม่รวมไดเร็กทอรีนี้