Working on a KV cache compression library, cacheshrink, would love feedback
Hey everyone,
I’ve been working on a library called cacheshrink that compresses KV caches in HuggingFace transformer models. The core idea is converting attention layers to use Multi-Head Latent Attention (MLA): instead of caching full-dimensional keys and values, the model caches smaller latent representations and reconstructs K/V on the fly.
Standard: h → W_k → K (cached at full size)
MLA: h → W_down → c (small, cached) → W_u → K (reconstructed on the fly)
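To make that concrete, here’s a tiny shape-level sketch of the idea (not cacheshrink’s internals; all dimensions and names are made up). Only the small latent c is cached per token, and K/V are rebuilt from it when attention runs:

import torch

# Toy shapes, just to show where the memory savings come from
d_model, d_head, d_latent = 4096, 128, 64   # d_latent < d_head is the whole point

W_down = torch.randn(d_model, d_latent)     # compress: h -> c
W_u_k = torch.randn(d_latent, d_head)       # reconstruct K from c
W_u_v = torch.randn(d_latent, d_head)       # reconstruct V from c

h = torch.randn(1, 10, d_model)             # hidden states for 10 tokens

c = h @ W_down                              # (1, 10, d_latent) -- this is what gets cached
k = c @ W_u_k                               # (1, 10, d_head)   -- rebuilt at attention time
v = c @ W_u_v                               # the same latent serves both K and V

In this toy single-head view, the cache goes from storing K and V at d_head each per token to storing one c at d_latent per token.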
The decompression matrices are constrained to have orthonormal columns (Stiefel manifold) and trained with Riemannian optimization via geoopt, which keeps things numerically stable.
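For the orthonormality constraint, here’s roughly how a Stiefel-constrained matrix can be trained with geoopt. This is a hedged sketch with a toy objective and made-up shapes, not cacheshrink’s actual training loop:

import torch
import geoopt

d_latent, d_head = 64, 128

stiefel = geoopt.manifolds.Stiefel()
# A (d_head x d_latent) matrix with orthonormal columns is a point on the Stiefel manifold
W_u = geoopt.ManifoldParameter(stiefel.random(d_head, d_latent), manifold=stiefel)

opt = geoopt.optim.RiemannianAdam([W_u], lr=1e-3)

c = torch.randn(256, d_latent)        # stand-in latents
target = torch.randn(256, d_head)     # stand-in reconstruction target
loss = ((c @ W_u.t() - target) ** 2).mean()

opt.zero_grad()
loss.backward()
opt.step()  # the Riemannian update retracts back onto the manifold,
            # so W_u keeps orthonormal columns after every step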
How it works in practice
import torch
from cacheshrink import convert_to_mla

model, tokenizer = convert_to_mla(
    "Qwen/Qwen2.5-7B",
    compression_ratio=2.0,
    compression_method="auto",  # detects GQA and picks cross-layer compression
    device="cuda",
    dtype=torch.bfloat16,
    use_calibration=True,
)

# works like the original model after this
outputs = model.generate(
    tokenizer("Hello", return_tensors="pt").to("cuda").input_ids,
    max_new_tokens=50,
)
It auto-detects whether the model uses MHA or GQA and picks the right strategy: per-layer compression for MHA models (GPT-2, LLaMA 2) or a cross-layer shared basis (xKV) for GQA models (Qwen, Mistral, LLaMA 3). I’ve tested it on 25+ models so far, including Mixtral 8x7B.
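For anyone unfamiliar with the cross-layer idea, here’s a rough illustration of a shared-basis scheme in the xKV spirit (illustration only, not cacheshrink’s implementation; shapes made up): stack K from a group of layers, take a truncated SVD, and cache only coefficients in the shared basis.

import torch

n_layers, n_tokens, d_head, rank = 4, 256, 128, 64

k_per_layer = [torch.randn(n_tokens, d_head) for _ in range(n_layers)]
stacked = torch.cat(k_per_layer, dim=0)      # (n_layers * n_tokens, d_head)

U, S, Vh = torch.linalg.svd(stacked, full_matrices=False)
basis = Vh[:rank].T                          # (d_head, rank), shared across the layer group

coeffs = [k @ basis for k in k_per_layer]    # cached: (n_tokens, rank) per layer
k_recon = [c @ basis.T for c in coeffs]      # reconstructed at attention time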
Converted models are drop-in compatible with model.generate() and the rest of the HuggingFace API. You can save/load them with safetensors.
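For example, something like this should work given the HF compatibility (the directory name is just an example):

model.save_pretrained("qwen2.5-7b-mla", safe_serialization=True)  # writes .safetensors shards
tokenizer.save_pretrained("qwen2.5-7b-mla")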
Things I’m still figuring out
- How well this stacks with KV cache quantization (FP8/INT8 on top of dimensional compression); rough sketch of what I mean after this list
- vLLM or llama.cpp integration for actual serving workloads
- Whether custom Triton kernels for fused decompress+attention would eliminate the reconstruction overhead
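On the quantization point, this is the kind of stacking I have in mind: per-token symmetric INT8 on the cached latents (toy shapes; c stands for the small latent from the diagram above, not real cacheshrink internals):

import torch

c = torch.randn(1, 10, 64, dtype=torch.bfloat16)   # cached latents, toy shape

# one scale per token, symmetric int8
scale = c.abs().amax(dim=-1, keepdim=True).float().clamp_min(1e-8) / 127.0
c_int8 = torch.clamp((c.float() / scale).round(), -127, 127).to(torch.int8)

# at attention time: dequantize, then reconstruct K/V through W_u as usual
c_deq = (c_int8.float() * scale).to(torch.bfloat16)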
If anyone’s dealing with KV cache memory bottlenecks for long-context or high-throughput serving, I’d really appreciate hearing what compression ratios and quality trade-offs would actually be useful in practice.
Install: pip install cacheshrink
Happy to answer any questions about the approach or results.