
Use a custom TensorFlow Lite model on Android

If your app uses custom TensorFlow Lite models, you can use Firebase ML to deploy your models. By deploying models with Firebase, you can reduce the initial download size of your app and update your app's ML models without releasing a new version of your app. And, with Remote Config and A/B Testing, you can dynamically serve different models to different sets of users.
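For example, a minimal sketch of the Remote Config approach might look like the following. It assumes you have also added the Remote Config SDK to your app and defined a parameter named model_name in the Firebase console; both the parameter name and the default model name here are illustrative, not part of any API.

// Sketch: choose which hosted model to download based on a Remote Config value.
val remoteConfig = FirebaseRemoteConfig.getInstance()
remoteConfig.fetchAndActivate().addOnCompleteListener {
    // "model_name" is a parameter you would define yourself in Remote Config.
    val modelName = remoteConfig.getString("model_name")
    val remoteModel = FirebaseCustomRemoteModel.Builder(modelName).build()
    // Pass remoteModel to FirebaseModelManager.download() as described below.
}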

TensorFlow Lite models

TensorFlow Lite models are ML models that are optimized to run on mobile devices. To get a TensorFlow Lite model, you can either use a pre-built model, such as one of the official TensorFlow Lite models, or convert a TensorFlow model, Keras model, or concrete function to the TensorFlow Lite format.

Before you begin

  1. If you haven't already, add Firebase to your Android project.
  2. Using the Firebase Android BoM, declare the dependency for the Firebase ML Custom Models Android library in your module (app-level) Gradle file (usually app/build.gradle).

    Also, as part of setting up Firebase ML Custom Models, you need to add the TensorFlow Lite SDK to your app.

    dependencies {
        // Import the BoM for the Firebase platform
        implementation platform('com.google.firebase:firebase-bom:26.1.0')
    
        // Declare the dependency for the Firebase ML Custom Models library
        // When using the BoM, you don't specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-model-interpreter'
        // Also declare the dependency for the TensorFlow Lite library and specify its version
        implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }

    By using the Firebase Android BoM, your app will always use compatible versions of the Firebase Android libraries.

    (Alternative) Declare Firebase library dependencies without using the BoM

    If you choose not to use the Firebase BoM, you must specify each Firebase library version in its dependency line.

    Note that if you use multiple Firebase libraries in your app, we highly recommend using the BoM to manage library versions, which ensures that all versions are compatible.

    dependencies {
        // Declare the dependency for the Firebase ML Custom Models library
        // When NOT using the BoM, you must specify versions in Firebase library dependencies
        implementation 'com.google.firebase:firebase-ml-model-interpreter:22.0.4'
        // Also declare the dependency for the TensorFlow Lite library and specify its version
        implementation 'org.tensorflow:tensorflow-lite:2.3.0'
    }
  3. In your app's manifest, declare that INTERNET permission is required:
    <uses-permission android:name="android.permission.INTERNET" />

1. Deploy your model

Deploy your custom TensorFlow Lite models using either the Firebase console or the Firebase Admin Python and Node.js SDKs. See Deploy and manage custom models.

After you add a custom model to your Firebase project, you can reference the model in your apps using the name you specified. At any time, you can upload a new TensorFlow Lite model, and your app will download the new model and start using it when the app next restarts. You can define the device conditions required for your app to attempt to update the model (see below).

2. Download the model to the device

To use your TensorFlow Lite model in your app, first use the Firebase ML SDK to download the latest version of the model to the device.

To start the model download, call the model manager's download() method, specifying the name you assigned the model when you uploaded it and the conditions under which you want to allow downloading. If the model isn't on the device, or if a newer version of the model is available, the task will asynchronously download the model from Firebase.

You should disable model-related functionality, for example by greying out or hiding part of your UI, until you confirm the model has been downloaded.

Java

FirebaseCustomRemoteModel remoteModel =
      new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
        .requireWifi()
        .build();
FirebaseModelManager.getInstance().download(remoteModel, conditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
val conditions = FirebaseModelDownloadConditions.Builder()
    .requireWifi()
    .build()
FirebaseModelManager.getInstance().download(remoteModel, conditions)
    .addOnSuccessListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Many apps start the download task in their initialization code, but you can do so at any point before you need to use the model.
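If you want to check whether the hosted model is already on the device before enabling a feature, you can ask the model manager. A minimal sketch (enableMlFeature() stands in for whatever UI logic your app uses):

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded ->
        if (isDownloaded) {
            // enableMlFeature() is a hypothetical method in your own app.
            enableMlFeature()
        }
    }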

3. Initialize a TensorFlow Lite interpreter

After you download the model to the device, you can get the model file location by calling the model manager's getLatestModelFile() method. Use this value to instantiate a TensorFlow Lite interpreter:

Java

FirebaseCustomRemoteModel remoteModel = new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                }
            }
        });

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        }
    }
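If you want to tune how the interpreter runs, you can pass an Interpreter.Options object when you construct it. A minimal sketch that sets the number of threads (the value 4 is an arbitrary choice, not a recommendation):

val options = Interpreter.Options().setNumThreads(4)
val interpreter = Interpreter(modelFile, options)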

4. Perform inference on input data

Get your model's input and output shapes

The TensorFlow Lite model interpreter takes as input and produces as output one or more multidimensional arrays. These arrays contain either byte, int, long, or float values. Before you can pass data to a model or use its result, you must know the number and dimensions ("shape") of the arrays your model uses.

If you built the model yourself, or if the model's input and output format is documented, you might already have this information. If you don't know the shape and data type of your model's input and output, you can use the TensorFlow Lite interpreter to inspect your model. For example:

Python

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

Example output:

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>
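You can also inspect the same information on the device after you create the interpreter, for example as a sanity check during development. A minimal sketch using the TensorFlow Lite Android API:

interpreter?.let {
    // Log the shape and data type of the first input and output tensors.
    Log.i(TAG, "Input: ${it.getInputTensor(0).shape().contentToString()} ${it.getInputTensor(0).dataType()}")
    Log.i(TAG, "Output: ${it.getOutputTensor(0).shape().contentToString()} ${it.getOutputTensor(0).dataType()}")
}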

Run the interpreter

After you have determined the format of your model's input and output, get your input data and perform any transformations on the data that are necessary to get an input of the right shape for your model.

For example, if you have an image classification model with an input shape of [1 224 224 3] floating-point values, you could generate an input ByteBuffer from a Bitmap object as shown in the following example:

Java

Bitmap bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true);
ByteBuffer input = ByteBuffer.allocateDirect(224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
for (int y = 0; y < 224; y++) {
    for (int x = 0; x < 224; x++) {
        int px = bitmap.getPixel(x, y);

        // Get channel values from the pixel value.
        int r = Color.red(px);
        int g = Color.green(px);
        int b = Color.blue(px);

        // Normalize channel values to [-1.0, 1.0]. This requirement depends
        // on the model. For example, some models might require values to be
        // normalized to the range [0.0, 1.0] instead.
        float rf = (r - 127.5f) / 127.5f;
        float gf = (g - 127.5f) / 127.5f;
        float bf = (b - 127.5f) / 127.5f;

        input.putFloat(rf);
        input.putFloat(gf);
        input.putFloat(bf);
    }
}

Kotlin+KTX

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)
val input = ByteBuffer.allocateDirect(224*224*3*4).order(ByteOrder.nativeOrder())
for (y in 0 until 224) {
    for (x in 0 until 224) {
        val px = bitmap.getPixel(x, y)

        // Get channel values from the pixel value.
        val r = Color.red(px)
        val g = Color.green(px)
        val b = Color.blue(px)

        // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
        // For example, some models might require values to be normalized to the range
        // [0.0, 1.0] instead.
        val rf = (r - 127.5f) / 127.5f
        val gf = (g - 127.5f) / 127.5f
        val bf = (b - 127.5f) / 127.5f

        input.putFloat(rf)
        input.putFloat(gf)
        input.putFloat(bf)
    }
}

Then, allocate a ByteBuffer large enough to contain the model's output and pass the input buffer and output buffer to the TensorFlow Lite interpreter's run() method. For example, for an output shape of [1 1000] floating-point values:

Java

int bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE;
ByteBuffer modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
interpreter.run(input, modelOutput);

Kotlin+KTX

val bufferSize = 1000 * java.lang.Float.SIZE / java.lang.Byte.SIZE
val modelOutput = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder())
interpreter?.run(input, modelOutput)
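Alternatively, the interpreter's run() method also accepts plain multidimensional arrays, which can be easier to work with than raw ByteBuffers. For the same [1 1000] floating-point output:

val output = Array(1) { FloatArray(1000) }
interpreter?.run(input, output)
val probabilities = output[0]  // 1000 class probabilities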

How you use the output depends on the model you are using.

For example, if you are performing classification, as a next step, you might map the indexes of the result to the labels they represent:

Java

modelOutput.rewind();
FloatBuffer probabilities = modelOutput.asFloatBuffer();
try {
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(getAssets().open("custom_labels.txt")));
    for (int i = 0; i < probabilities.capacity(); i++) {
        String label = reader.readLine();
        float probability = probabilities.get(i);
        Log.i(TAG, String.format("%s: %1.4f", label, probability));
    }
} catch (IOException e) {
    // File not found?
}

Kotlin+KTX

modelOutput.rewind()
val probabilities = modelOutput.asFloatBuffer()
try {
    val reader = BufferedReader(
            InputStreamReader(assets.open("custom_labels.txt")))
    for (i in 0 until probabilities.capacity()) {
        val label: String = reader.readLine()
        val probability = probabilities.get(i)
        println("$label: $probability")
    }
} catch (e: IOException) {
    // File not found?
}
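If you only need the most likely classes, you could also copy the probabilities into an array and sort the indexes. A minimal sketch for the output above (topK is an arbitrary choice):

modelOutput.rewind()
val probabilities = FloatArray(1000)
modelOutput.asFloatBuffer().get(probabilities)

val labels = assets.open("custom_labels.txt").bufferedReader().readLines()
val topK = 5
probabilities.withIndex()
    .sortedByDescending { it.value }
    .take(topK)
    .forEach { (index, probability) ->
        println("${labels.getOrNull(index)}: $probability")
    }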

Appendix: Fall back to a locally-bundled model

When you host your model with Firebase, any model-related functionality will not be available until your app downloads the model for the first time. For some apps, this might be fine, but if your model enables core functionality, you might want to bundle a version of your model with your app and use the best-available version. By doing so, you can ensure your app's ML features work when the Firebase-hosted model isn't available.

To bundle your TensorFlow Lite model with your app:

  1. Copy the model file (usually ending in .tflite or .lite) to your app's assets/ folder. (You might need to create the folder first by right-clicking the app/ folder, then clicking New > Folder > Assets Folder.)

  2. Add the following to your app's build.gradle file to ensure Gradle doesn't compress the models when building the app:

    android {
    
        // ...
    
        aaptOptions {
            noCompress "tflite", "lite"
        }
    }
    

Then, use the locally-bundled model when the hosted model isn't available:

Java

FirebaseCustomRemoteModel remoteModel =
        new FirebaseCustomRemoteModel.Builder("your_model").build();
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
        .addOnCompleteListener(new OnCompleteListener<File>() {
            @Override
            public void onComplete(@NonNull Task<File> task) {
                File modelFile = task.getResult();
                if (modelFile != null) {
                    interpreter = new Interpreter(modelFile);
                } else {
                    try {
                        InputStream inputStream = getAssets().open("your_fallback_model.tflite");
                        byte[] model = new byte[inputStream.available()];
                        inputStream.read(model);
                        ByteBuffer buffer = ByteBuffer.allocateDirect(model.length)
                                .order(ByteOrder.nativeOrder());
                        buffer.put(model);
                        interpreter = new Interpreter(buffer);
                    } catch (IOException e) {
                        // File not found?
                    }
                }
            }
        });

Kotlin+KTX

val remoteModel = FirebaseCustomRemoteModel.Builder("your_model").build()
FirebaseModelManager.getInstance().getLatestModelFile(remoteModel)
    .addOnCompleteListener { task ->
        val modelFile = task.result
        if (modelFile != null) {
            interpreter = Interpreter(modelFile)
        } else {
            val model = assets.open("your_fallback_model.tflite").readBytes()
            val buffer = ByteBuffer.allocateDirect(model.size).order(ByteOrder.nativeOrder())
            buffer.put(model)
            interpreter = Interpreter(buffer)
        }
    }

Appendix: Model security

Regardless of how you make your TensorFlow Lite models available to Firebase ML, Firebase ML stores them in the standard serialized protobuf format in local storage.

In theory, this means that anybody can copy your model. However, in practice, most models are so application-specific and obfuscated by optimizations that the risk is similar to that of competitors disassembling and reusing your code. Nevertheless, you should be aware of this risk before you use a custom model in your app.

On Android API level 21 (Lollipop) and newer, the model is downloaded to a directory that is excluded from automatic backup.

On Android API level 20 and older, the model is downloaded to a directory named com.google.firebase.ml.custom.models in app-private internal storage. If you enabled file backup using BackupAgent, you might choose to exclude this directory.