---
language:
- en
license: mit
pipeline_tag: feature-extraction
library_name: dictionary-learning
---
# Model Card for MAIRA-2-SAE
This is a collection of sparse autoencoders (SAEs) trained on the residual stream of layer 15 of [MAIRA-2](https://huggingface.co/microsoft/maira-2), and described in the preprint ['Insights into a radiology-specialised multimodal large language model with sparse autoencoders'](https://arxiv.org/abs/2507.12950), presented at the [Actionable Interpretability Workshop @ ICML 2025](https://actionable-interpretability.github.io/).
In the preprint, we primarily study an SAE with expansion factor 4. Here we also release SAEs with expansion factors 2 and 8 to enable additional analyses. For expansion factors 2 and 4, we also provide LLM-generated interpretations of each feature and their corresponding interpretability scores.
## Model Details
A sparse autoencoder is a model that performs two functions:
- Encoding an input (in this case, model activations) into a "latent space" (in this case, one of higher dimension than the input)
- Decoding from the "latent space" back into the input space
SAEs are trained such that only a small number of latent dimensions (which we call features) are active for any given input.
Specifically, these are Matryoshka BatchTopK SAEs, described in [Learning Multi-Level Features with Matryoshka Sparse Autoencoders](https://arxiv.org/abs/2503.17547). Importantly, the decoder is linear, so the SAE reconstructs model activations as a linear combination of (putatively) interpretable feature directions.
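Schematically (omitting the Matryoshka grouping and the details of how the top-k selection is applied across a batch, which follow the formulation in the linked paper), the SAE maps an input activation $x$ to a sparse feature vector $f$ and a linear reconstruction $\hat{x}$:

$$ f = \mathrm{TopK}\!\left(W_{\mathrm{enc}} x + b_{\mathrm{enc}}\right), \qquad \hat{x} = W_{\mathrm{dec}} f + b_{\mathrm{dec}} = b_{\mathrm{dec}} + \sum_i f_i\, d_i, $$

where the columns $d_i$ of $W_{\mathrm{dec}}$ are the dictionary elements (feature directions) and only the $k$ largest feature activations are kept nonzero.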
### Model Description
<!-- Provide a longer summary of what this model is. -->
- **Developed by:** Microsoft Research Health Futures
- **Model type:** Autoencoder
- **License:** MIT
## Uses
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
These SAEs are shared for research purposes only. Their intended use is interpretability analysis of MAIRA-2. Given MAIRA-2 and a data example (e.g. from MIMIC-CXR), one can retrieve the activation strength of all SAE features. This can be used to ascribe interpretations to SAE features, or to use such feature interpretations to analyse the workings of MAIRA-2.
### Direct Use
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
Use of these SAEs requires access to MAIRA-2; see the [MAIRA-2 model card](https://huggingface.co/microsoft/maira-2) for details.
Assuming one has extracted the residual stream from layer 15 of MAIRA-2, and processed the activations as described in [the preprint](https://arxiv.org/abs/2507.12950), the SAE can be used to encode this representation into a higher-dimensional space more suitable for interpretation.
We provide a usage example below.
Analyses of the SAEs themselves are also possible, for example by inspecting the learned dictionary elements (the decoder layer). In this case, the provided feature interpretations may be useful; however, we stress that only a subset of features has meaningful interpretations.
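As a minimal sketch of such an analysis (assuming an SAE has already been loaded as `ae`, as in the usage example below; `dictionary_direction` is an illustrative helper that recovers feature directions via `decode` alone, so it does not depend on internal attribute names):
```python
import torch

# ae: a loaded MatryoshkaBatchTopKSAE (see "Use SAE to get activations" below)
def dictionary_direction(ae, feature_idx: int, dict_size: int) -> torch.Tensor:
    # The decoder is linear, so decoding a one-hot feature vector and subtracting
    # the decoding of the all-zeros vector (the decoder bias) recovers the
    # dictionary element for that feature.
    one_hot = torch.zeros(1, dict_size)
    one_hot[0, feature_idx] = 1.0
    return (ae.decode(one_hot) - ae.decode(torch.zeros(1, dict_size))).squeeze(0)

# Example: cosine similarity between two feature directions
# (dict_size = 8192 for the expansion factor 2 SAE: 2 x 4096).
d_a = dictionary_direction(ae, feature_idx=10, dict_size=8192)
d_b = dictionary_direction(ae, feature_idx=42, dict_size=8192)
cosine = torch.nn.functional.cosine_similarity(d_a, d_b, dim=0)
print(f"Cosine similarity between features 10 and 42: {cosine.item():.3f}")
```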
### Out-of-Scope Use
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
These SAEs were trained on MAIRA-2 activations collected from the MIMIC-CXR findings-generation subset of the original MAIRA-2 training dataset. Hence, they may not perform well (in the sense of reconstruction) on other datasets or tasks, whether within MAIRA-2's training distribution (e.g. PadChest, [PadChest-GR](https://ai.nejm.org/doi/full/10.1056/AIdbp2401120)) or on datasets MAIRA-2 was not trained on. Any non-research use of these SAEs is out of scope.
## Bias, Risks, and Limitations
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
As above, the SAEs were trained and interpreted using the MIMIC-CXR subset of the MAIRA-2 training data. MIMIC-CXR represents a cohort of patients from a single hospital in the USA. Inferences made about MAIRA-2 using these SAEs will necessarily be limited to concepts which could plausibly be discovered using MIMIC-CXR.
## How to Get Started with the Model
### Setup
Install [dictionary_learning](https://github.com/saprmarks/dictionary_learning):
`pip install dictionary-learning`
or
`uv add dictionary-learning`.
We used `dictionary_learning` as a submodule at commit `07975f7`, which is version `0.1.0`.
#### Download weights from the hub
Option 1: Download a single SAE with a specified expansion factor
```python
from huggingface_hub import hf_hub_download
expansion_factor = 2
model_name = f"layer15_res_matryoshka_k256_ef{expansion_factor}.pt"
# Each expansion factor has its own subfolder
ef_subfolder = f"ef{expansion_factor}"
# Specify your own local download directory here if you want
local_dir = "./"
local_path = hf_hub_download(repo_id="microsoft/maira-2-sae", subfolder=ef_subfolder, filename=model_name, local_dir=local_dir)
```
Option 2: Download all SAEs
```python
from huggingface_hub import snapshot_download
# Specify your own local download directory here if you want
local_dir = "./"
snapshot_download(repo_id="microsoft/maira-2-sae", local_dir=local_dir)
```
### Use SAE to get activations
```python
import torch
from dictionary_learning.trainers.matryoshka_batch_top_k import MatryoshkaBatchTopKSAE
# local_path is the path to the dictionary weights (.pt file), however you downloaded them
ae = MatryoshkaBatchTopKSAE.from_pretrained(local_path)
# get NN activations using your preferred method: hooks, transformer_lens, nnsight, etc. ...
# for now we'll just use random activations
activation_dim = 4096
activations = torch.randn(64, activation_dim)
features = ae.encode(activations) # get features from activations
reconstructed_activations = ae.decode(features)
# you can also just get the reconstruction ...
reconstructed_activations = ae(activations)
# ... or get the features and reconstruction at the same time
reconstructed_activations, features = ae(activations, output_features=True)
```
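To relate features back to data, a natural next step is to inspect which features activate most strongly for each token; the resulting feature indices can then be cross-referenced with the automated interpretations described below. A minimal sketch, continuing from the example above:
```python
# For each token (row of `features`), list its most strongly activating SAE features.
top_values, top_indices = features.topk(k=5, dim=-1)
for token_idx in range(3):  # first few tokens only, for illustration
    active = [
        f"feature {int(i)}: {float(v):.2f}"
        for i, v in zip(top_indices[token_idx], top_values[token_idx])
        if float(v) > 0  # BatchTopK leaves most feature activations exactly zero
    ]
    print(f"token {token_idx}: " + ", ".join(active))
```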
## Training Details
### Training Data
We collected activations from the residual stream of layer 15 of MAIRA-2 using the MIMIC-CXR subset of the [MAIRA-2 training/validation set](https://arxiv.org/abs/2406.04449). As detailed in [our preprint](https://arxiv.org/abs/2507.12950), we used activations from all tokens in the sequence, excluding image tokens and boilerplate/templated subsequences. This resulted in 34.7M tokens for training and 1.7M for validation (respecting the splits used to train MAIRA-2). Following [Gao et al.](https://arxiv.org/abs/2406.04093), we scaled all token activations by a normalization factor of 22.34, the mean l2 norm of the training samples.
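For illustration, a minimal sketch of applying this scaling when encoding new activations. The direction of the scaling here (dividing raw activations by the factor before encoding, and multiplying reconstructions back by it) is an assumption; check it against the preprint and your activation-collection code:
```python
import torch

NORMALIZATION_FACTOR = 22.34  # mean l2 norm of the training activations, as reported above

def encode_normalized(ae, raw_activations: torch.Tensor) -> torch.Tensor:
    # Assumption: raw activations are divided by the factor before encoding.
    return ae.encode(raw_activations / NORMALIZATION_FACTOR)

def reconstruct_raw(ae, raw_activations: torch.Tensor) -> torch.Tensor:
    # Reconstruct in the normalized space, then undo the scaling.
    return ae(raw_activations / NORMALIZATION_FACTOR) * NORMALIZATION_FACTOR
```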
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
### Training Procedure
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
We trained the SAEs using the open-source [dictionary learning](https://github.com/saprmarks/dictionary_learning) library, using the `MatryoshkaBatchTopKTrainer`.
#### Training Hyperparameters
- Matryoshka group fractions: [1/2, 1/4, 1/8, 1/16, 1/16]
- k (mean l0 per batch): 256
- Batch size: 8192
- Epochs: 1
- Expansion factors: 2, 4, 8 (multiple models)
Further hyperparameters are listed in [the preprint](https://arxiv.org/abs/2507.12950).
## Automated Interpretation
For SAEs with expansion factor 2 and 4, we also provide automatically-generated interpretations of each feature, again as described in [our preprint](https://arxiv.org/abs/2507.12950). These are the files `autointerp_layer15_res_matryoshka_k256_ef{2,4}.csv`.
These interpretations were generated by showing GPT-4o data samples selected based on the activation strength for that feature. Note that we did not show GPT-4o the images, so these interpretations are necessarily limited. We did not run full automated interpretation on expansion factor 8 due to the large number of features (32,768).
We scored the quality of the interpretations using the detection scoring approach from [Automatically Interpreting Millions of Features in Large Language Models](https://arxiv.org/abs/2410.13928), wherein the interpretation is provided to an LLM judge (again, GPT-4o), which predicts whether a new sample will activate the feature. As a measure of interpretability, we provide binary classification metrics (accuracy, precision, recall, and F1) for each feature on both the 'train' samples (those used to generate the interpretation) and 'validation' samples (held out). We also provide statistics on how often each feature was observed to activate in a random subset of the training set (n), to facilitate further analyses.
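As a sketch of how these files might be used (the column names below are assumptions for illustration only; inspect the CSV header for the actual names):
```python
import pandas as pd

interp = pd.read_csv("autointerp_layer15_res_matryoshka_k256_ef4.csv")
print(interp.columns.tolist())  # check the real column names first

# e.g. keep features whose held-out detection F1 exceeds a threshold
# (replace "f1_validation" with the actual column name).
reliable = interp[interp["f1_validation"] > 0.8]
print(f"{len(reliable)} features with validation F1 > 0.8")
```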
## Citation
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
**BibTeX:**
```bibtex
@article{maira2sae,
title={Insights into a radiology-specialised multimodal large language model with sparse autoencoders},
author={Kenza Bouzid and Shruthi Bannur and Felix Meissen and Daniel Coelho de Castro and Anton Schwaighofer and Javier Alvarez-Valle and Stephanie L. Hyland},
journal={Actionable Interpretability Workshop @ ICML 2025},
year={2025},
url={https://arxiv.org/abs/2507.12950}
}
```
**APA:**
> Bouzid, K., Bannur, S., Meissen, F., Coelho de Castro, D., Schwaighofer, A., Alvarez-Valle, J., & Hyland, S. L. (2025). Insights into a radiology-specialised multimodal large language model with sparse autoencoders. *Actionable Interpretability Workshop @ ICML 2025*. [arXiv](https://arxiv.org/abs/2507.12950).
## Model Card Contact
- Stephanie Hyland ([`[email protected]`](mailto:[email protected]))
- Kenza Bouzid ([`[email protected]`](mailto:[email protected]))