KV Cache Compression

Working on a KV cache compression library, cacheshrink, would love feedback

Hey everyone,

I’ve been working on a library called cacheshrink that compresses KV caches in HuggingFace transformer models. The core idea is converting attention layers to use Multi-Head Latent Attention (MLA): instead of caching full-dimensional keys and values, the model caches smaller latent representations and reconstructs K/V on the fly.


Standard: h → W_k → K (cached at full size)

MLA: h → W_down → c (small, cached) → W_u → K (reconstructed on the fly)

The decompression matrices are constrained to have orthonormal columns (Stiefel manifold) and trained with Riemannian optimization via geoopt, which keeps things numerically stable.
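For intuition, here’s a minimal self-contained sketch of that idea (not cacheshrink’s actual internals; the class name and dimensions are made up for illustration): one head’s K path split into a down-projection whose output gets cached and a Stiefel-constrained up-projection, with geoopt’s Riemannian optimizer keeping the up-projection orthonormal during training.

import torch
import torch.nn as nn
import geoopt

class LatentKProjection(nn.Module):
    """Toy version of the MLA idea for a single head's K path."""
    def __init__(self, hidden_dim=4096, head_dim=128, latent_dim=64):
        super().__init__()
        # h -> c: the small latent that actually gets cached
        self.w_down = nn.Linear(hidden_dim, latent_dim, bias=False)
        # c -> K: up-projection constrained to orthonormal columns (Stiefel manifold)
        init = torch.linalg.qr(torch.randn(head_dim, latent_dim)).Q
        self.w_up = geoopt.ManifoldParameter(init, manifold=geoopt.Stiefel())

    def forward(self, h):
        c = self.w_down(h)        # cached at latent_dim instead of head_dim
        k = c @ self.w_up.t()     # K reconstructed on the fly at attention time
        return c, k

proj = LatentKProjection()
# Riemannian optimizer keeps w_up on the manifold throughout training
opt = geoopt.optim.RiemannianAdam(proj.parameters(), lr=1e-3)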

How it works in practice


import torch
from cacheshrink import convert_to_mla

model, tokenizer = convert_to_mla(
    "Qwen/Qwen2.5-7B",
    compression_ratio=2.0,
    compression_method="auto",  # detects GQA and picks cross-layer compression
    device="cuda",
    dtype=torch.bfloat16,
    use_calibration=True,
)

# works like the original model after this
outputs = model.generate(
    tokenizer("Hello", return_tensors="pt").to("cuda").input_ids,
    max_new_tokens=50,
)

It auto-detects whether the model uses MHA or GQA and picks the right strategy: per-layer compression for MHA models (GPT-2, LLaMA 2) or a cross-layer shared basis (xKV) for GQA models (Qwen, Mistral, LLaMA 3). I’ve tested it on 25+ models so far, including Mixtral 8x7B.
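To make the MHA/GQA distinction concrete, this is roughly the check the “auto” mode implies (uses_gqa is my own illustrative helper, not a cacheshrink function): GQA models expose a num_key_value_heads smaller than num_attention_heads in their HuggingFace config.

from transformers import AutoConfig

def uses_gqa(model_name: str) -> bool:
    # GQA: fewer KV heads than attention heads; MHA configs either match
    # or don't define num_key_value_heads at all
    cfg = AutoConfig.from_pretrained(model_name)
    n_kv = getattr(cfg, "num_key_value_heads", cfg.num_attention_heads)
    return n_kv < cfg.num_attention_heads

uses_gqa("Qwen/Qwen2.5-7B")   # True  -> cross-layer shared basis (xKV)
uses_gqa("gpt2")              # False -> per-layer compression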

Converted models are drop-in compatible with model.generate() and the rest of the HuggingFace API. You can save/load them with safetensors.
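Assuming the converted model really does keep the standard PreTrainedModel surface (the post says it is drop-in compatible; I haven’t checked whether loading back needs a cacheshrink-specific call), saving would look like the usual HuggingFace pattern:

# directory name is arbitrary
model.save_pretrained("qwen2.5-7b-mla", safe_serialization=True)
tokenizer.save_pretrained("qwen2.5-7b-mla")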

Things I’m still figuring out

  • How well this stacks with KV cache quantization (FP8/INT8 on top of dimensional compression); there’s a toy sketch of the idea after this list

  • vLLM or llama.cpp integration for actual serving workloads

  • Whether custom Triton kernels for fused decompress+attention would eliminate the reconstruction overhead
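On the first point, here is a toy sketch of what stacking INT8 on top of the latent cache could look like (purely illustrative, not library code): per-token symmetric quantization of the cached latent c, dequantized just before the up-projection.

import torch

def quantize_latent(c: torch.Tensor):
    # c: (seq_len, latent_dim); symmetric per-token INT8
    scale = c.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp((c / scale).round(), -127, 127).to(torch.int8)
    return q, scale

def dequantize_latent(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(scale.dtype) * scale

c = torch.randn(1024, 64)               # a cached latent: 1024 tokens, 64 dims
q, scale = quantize_latent(c)
c_hat = dequantize_latent(q, scale)     # fed to W_u to reconstruct K/V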

If anyone’s dealing with KV cache memory bottlenecks for long-context or high-throughput serving, I’d really appreciate hearing what compression ratios and quality trade-offs would actually be useful in practice.

Install: pip install cacheshrink

GitHub: https://github.com/Kranium2002/cacheshrink

Happy to answer any questions about the approach or results.


Hi, updated the link


The Stiefel manifold approach is mathematically elegant, but without a fused kernel to handle the up-projection in SRAM, aren’t you just trading VRAM capacity for HBM bandwidth latency?
