
Predictive Prefetching in Android with TensorFlow Lite: Training Navigation Models on User Session Data and the On-Device Inference Pipeline That Cut Our P95 Screen Load Time by 40%

Krystian Wiewiór · 5 min read

Why predict navigation at all?

The same logic behind ML-based molecular screening (where teams like 10x Science predict which molecules matter out of millions of candidates) applies to mobile UX. You have a combinatorial space of possible next screens, and a model that narrows it down saves real resources. In our case, the resource is the user’s time.

Most Android apps treat navigation reactively: user taps, system inflates Fragment, network call fires, data renders. Every millisecond in that chain is felt. Predictive prefetching flips the sequence by starting work before the tap.

The model: sequential navigation prediction

We treat each user session as a sequence of screen IDs and train a model to predict the next screen given the last N screens.

Training data pipeline

| Step | Detail |
|---|---|
| Collection | Anonymized screen_id sequences from Firebase Analytics, bucketed by session |
| Vocabulary | 47 unique screens mapped to integer tokens |
| Sequence length | Sliding window of 5 (last 5 screens predict the 6th) |
| Dataset size | ~2.1M sequences from 90 days of production logs |
| Split | 80/10/10 train/val/test |
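
The sliding-window step above is simple enough to sketch in plain Python. The function name and token values here are illustrative, not from our actual pipeline:

```python
def make_training_examples(session, window=5):
    """Turn one session's screen-token sequence into (last-5-screens, next-screen) pairs."""
    examples = []
    for i in range(len(session) - window):
        examples.append((session[i:i + window], session[i + window]))
    return examples

# A 7-screen session yields two training examples:
session = [3, 12, 7, 7, 19, 4, 30]
print(make_training_examples(session))
# [([3, 12, 7, 7, 19], 4), ([12, 7, 7, 19, 4], 30)]
```

Sessions shorter than six screens produce no examples, which is the desired behavior: there is nothing to predict from.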

The model architecture is deliberately simple: a two-layer LSTM with a 32-unit hidden size feeding a softmax output over the 47-screen vocabulary. I’ve shipped enough production ML to know that the winning move is almost always the simplest model that clears the accuracy bar, not the cleverest one.

import tensorflow as tf

model = tf.keras.Sequential([
    # 47 screen tokens -> 16-dim embeddings over a 5-step window
    tf.keras.layers.Embedding(vocab_size, 16, input_length=seq_len),
    tf.keras.layers.LSTM(32, return_sequences=True),
    tf.keras.layers.LSTM(32),
    # softmax over the 47-screen vocabulary
    tf.keras.layers.Dense(vocab_size, activation='softmax')
])

Top-1 accuracy landed at 68%; top-3 hit 89%. For prefetching, top-3 is the metric that matters. We speculatively load the three most likely next screens.
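
Top-3 selection itself is just a sort over the model's softmax output. A minimal sketch in Python; the probability values and six-screen vocabulary are made up for illustration:

```python
def top_k(probs, k=3):
    """Return the k screen IDs with the highest predicted probability."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]

# Toy 6-screen vocabulary instead of our 47 screens
probs = [0.05, 0.40, 0.10, 0.25, 0.15, 0.05]
print(top_k(probs))  # [1, 3, 4]
```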

TFLite conversion with dynamic quantization

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # dynamic range quantization
tflite_model = converter.convert()  # 94 KB output

| Metric | Full Keras | TFLite (quantized) |
|---|---|---|
| Model size | 410 KB | 94 KB |
| Inference latency (Pixel 6) | 12 ms | 3.1 ms |
| Top-3 accuracy | 89.2% | 88.7% |

Half a percentage point of accuracy for a 4x size reduction and 4x speed improvement. A 94 KB model running inference in ~3 ms is practically invisible to the runtime budget.

On-device inference: Lifecycle-aware coroutine

What most teams get wrong: they run inference on every screen transition without respecting the Android lifecycle. That leads to wasted work during config changes and leaked coroutines. We bind inference to the NavController destination change listener inside a lifecycleScope.

import android.os.Bundle
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.lifecycleScope
import androidx.navigation.NavController
import androidx.navigation.NavDestination
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch

class PrefetchNavigationObserver(
    private val lifecycleOwner: LifecycleOwner,
    private val predictor: ScreenPredictor,
    private val prefetcher: FragmentPrefetcher
) : NavController.OnDestinationChangedListener {

    // Rolling window of the last 5 screen IDs (the model's input length)
    private val screenHistory = ArrayDeque<Int>()

    override fun onDestinationChanged(
        controller: NavController, dest: NavDestination, args: Bundle?
    ) {
        if (screenHistory.size == 5) screenHistory.removeFirst()
        screenHistory.addLast(dest.id)

        // lifecycleScope cancels the job with its owner, so config changes
        // never leak inference work; Dispatchers.Default keeps it off main
        lifecycleOwner.lifecycleScope.launch(Dispatchers.Default) {
            val predictions = predictor.topK(screenHistory.toList(), k = 3)
            predictions.forEach { screenId ->
                prefetcher.prefetch(screenId) // inflate + cache data
            }
        }
    }
}

FragmentPrefetcher inflates the Fragment view hierarchy into an off-screen cache and fires the associated ViewModel data load. When the user actually navigates, the cached view and pre-loaded data are swapped in.
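
FragmentPrefetcher itself is app-specific Android code, but the cache discipline behind it is worth making explicit: bounded, LRU-evicting, keyed by screen ID, with at most three entries alive (one per speculative prediction). A language-neutral sketch in Python; the class and method names are illustrative, not our production API:

```python
from collections import OrderedDict

class PrefetchCache:
    """Bounded LRU cache: at most `capacity` prefetched screens live at once."""

    def __init__(self, capacity=3):
        self.capacity = capacity
        self._entries = OrderedDict()

    def put(self, screen_id, payload):
        self._entries.pop(screen_id, None)      # re-prefetch refreshes position
        self._entries[screen_id] = payload
        if len(self._entries) > self.capacity:
            self._entries.popitem(last=False)   # evict least-recently prefetched

    def take(self, screen_id):
        """Remove and return the cached entry on real navigation, or None on a miss."""
        return self._entries.pop(screen_id, None)

cache = PrefetchCache(capacity=3)
for sid in [10, 11, 12, 13]:      # 4 inserts into a 3-slot cache: 10 is evicted
    cache.put(sid, f"view+data for {sid}")
print(cache.take(10))  # None (evicted)
print(cache.take(13))  # view+data for 13
```

Capping the cache is what keeps the memory overhead in the benchmark below bounded: stale speculative work is dropped rather than accumulated.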

Benchmarks: the production impact

We ran an A/B test over four weeks with 22K daily active users per cohort.

| Metric | Control (no prefetch) | Prefetch cohort | Delta |
|---|---|---|---|
| P50 screen load | 280 ms | 210 ms | -25% |
| P95 screen load | 820 ms | 490 ms | -40% |
| Memory overhead | — | +2.8 MB avg | — |
| Battery (24h drain) | 100% baseline | +0.3% | Negligible |
| Network (daily) | 100% baseline | +4.2% | Acceptable |

The P95 improvement is where this pays off. Tail latency is what users remember. Shaving 330 ms off the worst-case path changed our app store review sentiment measurably.

Solving the cold-start bootstrap problem

A fresh install has zero navigation history. Our fallback strategy layers three sources:

  1. Population prior — a static frequency table baked into the APK at build time, derived from aggregate navigation patterns across all users.
  2. Session accumulation — after three screen transitions, the model begins issuing live predictions.
  3. Model update — the TFLite file ships via Firebase ML Model Management, updated monthly without an app release.

The population prior alone puts the actual next screen in its top-3 72% of the time, so even first-session users see some benefit.
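
The layering logic reduces to a small gate. A sketch, with caveats: the threshold of three transitions is ours, but the prior table values and function names here are illustrative:

```python
MIN_HISTORY = 3  # live predictions only after three screen transitions

# Static population prior baked into the APK at build time: screen IDs
# ranked by aggregate navigation frequency (values illustrative)
POPULATION_PRIOR = [0, 4, 2, 9, 1]

def predict_next(history, model_top_k, k=3):
    """Layers 1 and 2 of the cold-start strategy: prior until enough history exists."""
    if len(history) < MIN_HISTORY:
        return POPULATION_PRIOR[:k]
    return model_top_k(history, k)

fake_model = lambda history, k: [7, 8, 9]   # stand-in for the TFLite model
print(predict_next([5], fake_model))        # [0, 4, 2]  (population prior)
print(predict_next([5, 6, 7], fake_model))  # [7, 8, 9]  (live model)
```

Layer 3, the monthly Firebase model update, sits outside this function: it swaps the model file, not the gating logic.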

What I’d take away from this

Start with the simplest sequence model that clears top-3 accuracy above 85%. A two-layer LSTM with 32 hidden units and dynamic quantization gives you a sub-100 KB artifact with ~3 ms inference. Don’t over-architect this.

Bind inference to the Android lifecycle. Use lifecycleScope and Dispatchers.Default so prediction work is automatically cancelled on configuration changes and never blocks the main thread.

Solve cold-start on day one. Ship a population-prior frequency table in your APK and switch to live predictions after a minimum session history threshold. Without this, your first-install experience (the moment that matters most) gets no benefit.


TAGS: android, kotlin, architecture, mobile, jetpackcompose

