Reauthoring & Converting models for edge inference: MambaV2 on LiteRT
The journey of taking an architecture from huggingface and getting it running on an on-device inference framework
Initially, this post was going to be mostly theory drawn from my past experience getting a novel architecture (say, an LLM like Llama) exported for on-device inference with frameworks like Google’s LiteRT, Microsoft’s ONNX Runtime, or Apple’s coremltools. But just then IBM’s Granite 4 Nano models were published, so I decided to dust off the old cobwebs and turn this into a practical demo :-)
The novel thing about the Granite 4 Nano (hybrid) models is the inclusion of MambaV2. For an intro, I would read the paper (or just ask ChatGPT in Study & learn mode!). The high-level description: instead of using an O(N*d) KV-cache to store past computations, Mamba keeps a couple of convolution & SSM states (analogous to RNN states) that are O(d). These are more efficient and use less memory, but aren’t as good quality-wise as attention layers - therefore most models, including Granite 4 Nano, mix some attention layers in with the Mamba layers.
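To make the memory argument concrete, here is a minimal sketch (with illustrative dimensions only, not the actual Granite 4 Nano hyperparameters) comparing how a per-layer KV cache grows with sequence length while a Mamba layer’s state stays fixed:

import torch

# Illustrative dimensions -- not the real model config.
batch, n_kv_heads, head_dim = 1, 8, 64
d_inner, d_conv, n_heads, d_state = 2048, 4, 32, 128

def kv_cache_elems(seq_len: int) -> int:
  # Keys + values for one attention layer: grows linearly with seq_len.
  k = torch.zeros(batch, n_kv_heads, seq_len, head_dim)
  v = torch.zeros_like(k)
  return k.numel() + v.numel()

def mamba_state_elems() -> int:
  # Conv + SSM state for one Mamba layer: independent of seq_len.
  conv_state = torch.zeros(batch, d_inner, d_conv)
  ssm_state = torch.zeros(batch, n_heads, head_dim, d_state)
  return conv_state.numel() + ssm_state.numel()

for seq_len in (128, 1024, 8192):
  print(seq_len, kv_cache_elems(seq_len), mamba_state_elems())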
Now that the model has been introduced, let’s see what we need to do to get it exported for on-device inference (for edge devices like mobile phones) on LiteRT via ai-edge-torch. All the code can be seen here, which I will try to contribute back to ai-edge-torch pending some approvals from my employer.
Why ai-edge-torch & LiteRT? Because as of writing this (November 2025), this framework was the most ergonomic, with plenty of examples & tooling. This ensured that I could just focus on the part interesting to me (implementing Mamba), rather than spend a ton of time understanding a new framework. coremltools doesn’t have a model zoo-type thing focused on LLMs, and optimum-executorch had too many layers of abstractions to wade through (just my opinion).
1) Re-authoring the model
Edge frameworks love static execution plans and predefined shapes to be efficient on diverse edge hardware like Adreno GPUs, Qualcomm DSPs, and Apple/Google NPUs. This differs from source code (usually from huggingface) used for training on GPUs or TPUs. As a result, most models undergo a “re-authoring” process pre-deployment, specifically to optimize the original PyTorch (or Jax, or TensorFlow) code for on-device inference. This re-authoring process generally involves two main categories of changes:
1.a.) Representing ops in an edge-friendly way
A prime example of this is ANE Transformers. This effort by engineers from Apple involved representing attention with convolutions and chunking large tensors within the computation graph, in a bid to make the transformer a better fit for the ANE (Apple Neural Engine).
In the Mamba case, one key “feature” to implement was the new Mamba layer for token mixing alongside the attention implementation from ai-edge-torch. A simplification I made here was to go from using PyTorch’s roll (which wouldn’t be supported by many frameworks) to a combination of slice+concat for the convolution state update:
Before:
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, -1] = hidden_states_B_C[:, 0, :].to(conv_state.device)
After:
new_conv_input = hidden_states_B_C[:, 0, :] # [batch_size, conv_dim]
# Slice out old states and concatenate with new input
old_conv_states = mamba_cache.conv_state[:, :, 1:]
new_conv_state = torch.cat([old_conv_states, new_conv_input.unsqueeze(-1)], dim=-1)
1.b.) Specially handling novel concepts to simplify the converter’s life
For LLMs, the KV cache is a great example of how to use memory efficiently. Rather than simply concatenating tensors as in naive PyTorch code, on-device implementations provide annotations for “in-place update” patterns. This allows the inference engine to update a pre-allocated memory chunk in place, which is much more efficient. A good example of this approach is coremltools’ state tensors.
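As a toy illustration of the difference (plain PyTorch with made-up shapes, not the ai-edge-torch API), compare growing a cache by concatenation with writing into a pre-allocated buffer:

import torch

max_seq_len, n_heads, head_dim = 1024, 8, 64
new_k = torch.randn(1, n_heads, 1, head_dim)  # keys for one new token

# Naive: reallocates and copies the whole cache every decode step.
k_cache = torch.zeros(1, n_heads, 0, head_dim)
k_cache = torch.cat([k_cache, new_k], dim=2)

# Edge-friendly: pre-allocate once, then write the new slice in place.
k_buffer = torch.zeros(1, n_heads, max_seq_len, head_dim)
input_pos = 0  # position of the token being decoded
k_buffer[:, :, input_pos : input_pos + 1, :] = new_k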
ai-edge-torch already has a kv_cache layer; I just had to implement a simple Mamba equivalent (key snippets):
@dataclasses.dataclass
class MambaCacheEntry:
  """A single cache entry that includes conv & SSM states in MambaV2 impl."""

  conv_state: torch.Tensor
  ssm_state: torch.Tensor

  @classmethod
  def from_model_config(
      cls,
      config: model_config.MambaConfig,
      dtype: torch.dtype = torch.float32,
      device: torch.device | None = None,
      batch_size: int = 1,
  ) -> "MambaCacheEntry":
    """Build an instance of the class based on a MambaConfig."""
    conv_state_shape = (
        batch_size,
        (
            config.expand * config.hidden_size
            + 2 * config.n_groups * config.d_state
        ),
        config.d_conv,
    )
    ssm_state_shape = (
        batch_size,
        config.n_heads,
        config.d_head,
        config.d_state,
    )
    conv_state = torch.zeros(conv_state_shape, dtype=dtype, device=device)
    ssm_state = torch.zeros(ssm_state_shape, dtype=dtype, device=device)
    obj = cls(conv_state=conv_state, ssm_state=ssm_state)
    return obj
def _update_impl(
    cache: MambaCacheEntry,
    new_conv_state: torch.Tensor,
    new_ssm_state: torch.Tensor,
) -> MambaCacheEntry:
  """Update the cache buffer for conv_state and ssm_state."""
  # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
  conv_state_indices = _get_slice_indices(3)
  ssm_state_indices = _get_slice_indices(4)
  conv_state = dus_utils.dynamic_update_slice(
      cache.conv_state, new_conv_state, conv_state_indices
  )
  ssm_state = dus_utils.dynamic_update_slice(
      cache.ssm_state, new_ssm_state, ssm_state_indices
  )
  updated_cache = MambaCacheEntry(conv_state, ssm_state)
  return updated_cache
Note that the updates for the Mamba caches are simpler than for the KV cache, since we just update the whole state in place instead of a particular per-token slice of the cache. More on dus_utils.dynamic_update_slice soon.
2) Testing the re-authored model
Once re-authoring is done, it is important to verify that the underlying math remains unchanged. The techniques used in re-authoring aren’t efficient for training, so they never ran during that phase - which is exactly why the re-authored model needs to be checked against the original numerically.
To load the trained weights, the original weight checkpoint keys are remapped to new patterns that correspond to the re-authored code. Here’s a snippet from my Mamba code (I had to add support for the Mamba layers):
HYBRID_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    embedding="model.embed_tokens",
    ff_up_proj="model.layers.{}.shared_mlp.input_linear",
    ff_gate_proj="model.layers.{}.shared_mlp.gate_linear",
    ff_down_proj="model.layers.{}.shared_mlp.output_linear",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    post_attn_norm="model.layers.{}.post_attention_layernorm",
    mamba_A_log="model.layers.{}.mamba.A_log",
    mamba_dt_bias="model.layers.{}.mamba.dt_bias",
    mamba_D="model.layers.{}.mamba.D",
    mamba_conv1d_weight="model.layers.{}.mamba.conv1d.weight",
    mamba_conv1d_bias="model.layers.{}.mamba.conv1d.bias",
    mamba_in_proj_weight="model.layers.{}.mamba.in_proj.weight",
    mamba_norm_weight="model.layers.{}.mamba.norm.weight",
    mamba_out_proj_weight="model.layers.{}.mamba.out_proj.weight",
    final_norm="model.norm",
)
Once the verification passed (using a combination of test prompts & output verification), I was free to convert and get a model artifact.
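The verification itself was conceptually simple: run the same prompt through the Hugging Face reference and the re-authored model and compare outputs. A minimal sketch of that kind of check (the checkpoint name is illustrative, and build_reauthored_model is a hypothetical stand-in for whatever builds the re-authored model and loads the remapped weights - this is not the exact test code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint name; substitute the actual Granite 4 Nano hybrid checkpoint.
checkpoint = "ibm-granite/granite-4.0-h-350m"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
reference = AutoModelForCausalLM.from_pretrained(checkpoint).eval()

# Hypothetical helper: builds the re-authored model and loads the remapped weights.
reauthored = build_reauthored_model(checkpoint).eval()

prompt = "The best place to visit in New York is"
tokens = tokenizer(prompt, return_tensors="pt").input_ids
input_pos = torch.arange(tokens.shape[1])

with torch.no_grad():
  ref_logits = reference(tokens).logits[:, -1, :]
  # The re-authored forward also takes cache/mask objects; simplified here.
  new_logits = reauthored(tokens, input_pos)[:, -1, :]

# Numerical parity on next-token logits (argmax equality is a weaker fallback).
assert torch.allclose(ref_logits, new_logits, atol=1e-2)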
3) Conversion
So, a natural question pops up: can’t converters like Google’s LiteRT, Microsoft’s ONNX Runtime, or Apple’s coremltools (which effectively act as frontends to their respective compilers) just “auto-detect” many of these patterns and handle them for us?
The key challenge here is the sheer size of LLMs (both in terms of number of parameters and operations in the computational graph) - it becomes prohibitive to “auto-detect” complex patterns across the entire graph. For instance, auto-detecting the usage of a KV cache across all layers of a Transformer would demand a very complex and quite fragile compiler pass. On top of this, squeezing every last drop of performance from these models is absolutely critical for developers. It’s not just about reducing latency, but also about keeping memory and power usage in check.
Re-authoring effectively asks users to become power users who understand how to leverage the export toolchain, with both the model and the hardware in mind. This isn’t particularly problematic, as foundation models are converging - and major players like Apple, Google, or Meta have the resources to develop end-to-end toolchains. Converters are still useful of course, handling mundane tasks and providing robust APIs that allow users to direct the compiler as needed.
When it comes to LLMs, two main features within a great toolchain are especially useful for “lowering” these models:
3.a.) Lowering to custom ops
For LLMs, it’s crucial for model authors to provide “hints” on how to map constructs in their re-authored model code to the lower-level IR understood by backends. This typically involves wrapping part of the source model in a PyTorch custom op and using an interface in the converter or compiler to “lower” that custom op into individual operations or kernels in the lower-level IR.
In the case of in-place updates for the Mamba states, my implementation already used the more general dus_utils.dynamic_update_slice op to denote the full-state update. As a result, the lowering via the TFLite/LiteRT converter took care of mapping this to the correct stablehlo op:
# Use torch.library.custom_op to define a new custom operator.
# TODO: Update impl for multiple non-trivial start_indices
@torch.library.custom_op("ai_edge_torch::dynamic_update_slice", mutates_args=())
def dynamic_update_slice(
    in_tensor: torch.Tensor,
    update: torch.Tensor,
    start_indices: Sequence[torch.Tensor],
) -> torch.Tensor:
  ...

@lowerings.lower(torch.ops.ai_edge_torch.dynamic_update_slice)
def _dynamic_update_slice_lower(
    lctx,
    in_tensor: ir.Value,
    update: ir.Value,
    start_indices: Sequence[ir.Value],
):
  return stablehlo.dynamic_update_slice(in_tensor, update, start_indices)
3.b.) Packing multiple functions & shape specializations into the same Program
Each function here can be a different model (e.g., image/audio encoders and an LLM in a multimodal setup) or a different aspect of the same model inference (e.g., prefill and decode paths). Each function is compiled by a lower-level toolchain, aiming to share weights where possible. For an example, see the functions in the Gemma 3n exported artifact.
To maintain performance across varying prompt sizes, converters offer options to generate different variations of functions with different enumerations of well-defined shapes. These shapes are usually based on common use-case prompt sizes. To see an instance of such APIs, look at Flexible Input Shapes from Apple’s coremltools.
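To make that concrete, here is a minimal sketch of the coremltools flavor of this idea: a toy traced module converted with a few enumerated prompt lengths instead of a fully dynamic axis (the module and shapes are made up for illustration):

import coremltools as ct
import torch

class Tiny(torch.nn.Module):
  def forward(self, x):
    return x * 2.0

traced = torch.jit.trace(Tiny().eval(), torch.zeros(1, 64))

# Enumerate a few well-known prompt lengths; the runtime picks the matching specialization.
shapes = ct.EnumeratedShapes(shapes=[[1, 64], [1, 128], [1, 256]], default=[1, 64])
mlmodel = ct.convert(traced, inputs=[ct.TensorType(name="x", shape=shapes)])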
In this case, I was able to re-use most of the converter code from ai-edge-torch that adds signatures for the different prefill lengths & for decode (sequence length 1) into the same Program; I just had to implement some code specific to the Mamba mask and cache IO:
prefill_mamba_cache = None
decode_mamba_cache = None
prefill_mamba_masks = None
decode_mamba_mask = None
if pytorch_model.has_mamba_blocks():
  prefill_mamba_cache = mamba_utils.MambaCache.from_model_config(config)
  decode_mamba_cache = mamba_utils.MambaCache.from_model_config(
      config, batch_size=export_config.decode_batch_size
  )
  prefill_mamba_masks = _build_mamba_mask(prefill_seq_lens)
  decode_mamba_mask = _build_mamba_mask(1)

def _maybe_add_mamba_to_kwargs(sample_kwargs: Dict, is_prefill: bool = False):
  cache = prefill_mamba_cache if is_prefill else decode_mamba_cache
  if cache is not None:
    sample_kwargs['mamba_cache'] = cache
Note that the default conversion in ai-edge-torch uses dynamic 8-bit quantization. I kept it as-is since it didn’t seem to cause much of an (anecdotal) degradation in the model output!
4) Running the model!
Finally, I was ready to run the model on my M1 Pro MacBook. By default, LiteRT uses the XNNPACK delegate on CPU, and I was fine with that.
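(For reference, driving the exported signatures from Python looks roughly like the sketch below; the artifact path, signature names, and zeroed inputs are assumptions for illustration rather than my exact harness.)

import numpy as np
import tensorflow as tf

# Illustrative path; the real artifact comes from the conversion step.
interpreter = tf.lite.Interpreter(model_path="granite_hybrid.tflite")

# Inspect which functions were packed into the Program and what inputs they expect.
print(interpreter.get_signature_list())

# Assumed signature name -- pick one from the list above (e.g. a prefill variant).
runner = interpreter.get_signature_runner("prefill_128")

# Build kwargs keyed by the signature's input names (tokens, positions, cache
# tensors, masks, ...). Zeros just exercise the graph; real values come from
# the tokenizer and cache state.
inputs = {
    name: np.zeros(detail["shape"], dtype=detail["dtype"])
    for name, detail in runner.get_input_details().items()
}
outputs = runner(**inputs)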
The dense (non-Mamba) version:
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Running prefill
Running decode
New York City is a global hub for culture, art, and entertainment. It’s home to iconic landmarks such as the Statue of Liberty, the Empire State Building, and the Broadway theater. The city also offers a rich history, diverse food scene, and a vibrant nightlife.
Metrics:
Time to First Token (TTFT): 56.35 ms
Average Time Between Tokens (TBT): 28.62 ms
The hybrid version with MambaV2 (naively 10% faster TBT, but more terse):
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Running prefill
Running decode
(b) The best place to visit in New York is the Statue of Liberty.
Metrics:
Time to First Token (TTFT): 36.81 ms
Average Time Between Tokens (TBT): 25.72 ms
For an understanding of what happens lower down the stack for last-mile compilation & running, I recommend the executorch documentation, which outlines graph capture, compiler passes, and hardware-specific delegation.
All the code for this experiment is here.
5) Parting thoughts…
There are some things I would have done as follow-ups if I were seriously shipping this to production: running the model on DSPs/NPUs/other accelerators to compare performance, writing tests to check the quality of the converted model against the re-authored PyTorch code, trying out different compression schemes (and maybe even compressing the KV cache), and so on.
Most edge inference frameworks treat the implementation/documentation of a Python-based test harness for their model artifacts as an afterthought, but I don’t think that’s right. Model authors (especially people from the research/optimization world) aren’t experts in C++ or building native code. My code to run the tflite artifacts was taken from this deepseek notebook.
I really like how ai-edge-torch provides great building blocks with just the right amount of abstraction to “assemble” different types of transformers, with good huggingface integration. To add a novel concept (Mamba), I didn’t need to dig through multiple layers of abstraction. The actual lower-level infra (LiteRT) is of course awesome (and I have worked on it in a past life :-) ), but these kinds of exporting frameworks will be valuable for power users.


