---
license: cc-by-nc-4.0
tags:
- audio
- music
- music-information-retrieval
- music-representation-learning
- foundation-model
- continual-pretraining
- cross-cultural-MIR
metrics:
- roc_auc
- f1
pipeline_tag: audio-classification
---

# CultureMERT: Continual Pre-Training for Cross-Cultural Music Representation Learning
CultureMERT-95M is a multi-culturally adapted 95M-parameter music foundation model based on MERT-v1-95M. It is developed through a two-stage continual pre-training strategy on 650 hours of culturally diverse audio spanning Greek, Turkish, and Indian musical traditions. The model significantly improves representation quality for "non-Western" music, achieving an average ROC-AUC improvement of 4.43% across culturally diverse music tagging tasks, surpassing prior state-of-the-art, while maintaining strong performance on Western-centric benchmarks such as MagnaTagATune and FMA-medium.
## Model Details
- Architecture: 12-layer Transformer encoder (768-dim) with a 7-layer 1D CNN frontend
- Input: Raw mono audio at 24 kHz
- Pretraining Objective: Multi-task masked prediction of discrete EnCodec acoustic tokens and continuous CQT frame reconstruction (75 Hz)
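These architectural details can be checked against the published checkpoint's configuration. A minimal sketch, assuming the HuBERT/MERT-style config field names shipped with the remote code (exact names may differ slightly):

```python
from transformers import AutoConfig

# Inspect the released checkpoint's configuration.
config = AutoConfig.from_pretrained("ntua-slp/CultureMERT-95M", trust_remote_code=True)
print(config.num_hidden_layers)  # 12 Transformer layers
print(config.hidden_size)        # 768-dimensional representations
print(config.conv_dim)           # channel sizes of the 7-layer CNN frontend
```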
## Training Data
| Dataset | Music Tradition | Hours Used |
|---|---|---|
| Lyra | Greek traditional/folk | 50h |
| Turkish-makam | Turkish/Ottoman classical | 200h |
| Hindustani | North Indian classical | 200h |
| Carnatic | South Indian classical | 200h |
The datasets used were obtained under research-use agreements and are not redistributed.
## Evaluation
We evaluate CultureMERT-95M on both Western and non-Western music auto-tagging tasks to assess its cross-cultural generalization. The evaluation uses standard multi-label classification metrics:
- ROC-AUC (Receiver Operating Characteristic - Area Under Curve)
- Average Precision (AP)
- Micro-F1 and Macro-F1
All results are averaged over five random seeds.
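For reference, these metrics can be computed with scikit-learn for a multi-label tagging probe. A minimal sketch with synthetic predictions; the macro averaging and the 0.5 decision threshold are illustrative assumptions, not necessarily the exact setup used for the reported results:

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

# Synthetic stand-ins for a multi-label tagging probe:
# y_true  -- binary tag matrix, shape (num_clips, num_tags)
# y_score -- predicted per-tag probabilities, same shape
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(100, 20))
y_score = rng.random(size=(100, 20))
y_pred = (y_score >= 0.5).astype(int)  # illustrative 0.5 threshold for the F1 scores

roc_auc = roc_auc_score(y_true, y_score, average="macro")
ap = average_precision_score(y_true, y_score, average="macro")
micro_f1 = f1_score(y_true, y_pred, average="micro")
macro_f1 = f1_score(y_true, y_pred, average="macro")
print(f"ROC-AUC {roc_auc:.3f} | AP {ap:.3f} | micro-F1 {micro_f1:.3f} | macro-F1 {macro_f1:.3f}")
```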
### Evaluation Datasets
- Non-Western traditions:
  - Turkish-makam (Ottoman classical)
  - Hindustani (North Indian classical)
  - Carnatic (South Indian classical)
  - Lyra (Greek traditional/folk music)
- Western benchmarks:
  - MagnaTagATune (MTAT)
  - FMA-medium
### ROC-AUC and Average Precision (AP)

Each cell reports ROC-AUC / AP (%).
| Model | Turkish-makam | Hindustani | Carnatic | Lyra | FMA | MTAT | Avg. |
|---|---|---|---|---|---|---|---|
| MERT-v1-95M | 83.2 / 53.3 | 82.4 / 52.9 | 74.9 / 39.7 | 85.7 / 56.5 | 90.7 / 48.1 | 89.6 / 35.9 | 66.1 |
| CultureMERT-95M | 89.6 / 60.6 | 88.2 / 63.5 | 79.2 / 43.1 | 86.9 / 56.7 | 90.7 / 48.1 | 89.4 / 35.9 | 69.3 |
### Micro-F1 and Macro-F1

Each cell reports Micro-F1 / Macro-F1 (%).
| Model | Turkish-makam | Hindustani | Carnatic | Lyra | FMA | MTAT | Avg. |
|---|---|---|---|---|---|---|---|
| MERT-v1-95M | 73.0 / 38.9 | 71.1 / 33.2 | 80.1 / 30.0 | 72.4 / 42.6 | 57.0 / 36.9 | 35.7 / 21.2 | 49.3 |
| CultureMERT-95M | 77.4 / 45.8 | 77.8 / 50.4 | 82.7 / 32.5 | 73.1 / 43.1 | 58.3 / 36.6 | 35.6 / 22.9 | 52.9 |
CultureMERT-95M outperforms the original MERT-v1-95M by an average of 4.43% in ROC-AUC across the non-Western traditions, alongside consistent average improvements of 5.4% in AP, 3.6% in Micro-F1, and 6.8% in Macro-F1, while exhibiting minimal forgetting on Western benchmarks.
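As a quick check, the 4.43% figure corresponds to the mean per-dataset ROC-AUC gain over the four non-Western tagging tasks in the table above:

```python
# Per-dataset ROC-AUC values (%) taken from the table above, non-Western tasks only.
mert_v1     = {"Turkish-makam": 83.2, "Hindustani": 82.4, "Carnatic": 74.9, "Lyra": 85.7}
culturemert = {"Turkish-makam": 89.6, "Hindustani": 88.2, "Carnatic": 79.2, "Lyra": 86.9}

gains = [culturemert[k] - mert_v1[k] for k in mert_v1]
avg_gain = sum(gains) / len(gains)
print(avg_gain)  # ~4.425 from the rounded table values, i.e. the reported 4.43% average gain
```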
## Model Usage
```python
from transformers import Wav2Vec2FeatureExtractor, AutoModel
import torch
from torch import nn
import torchaudio.transforms as T
from datasets import load_dataset

# Load the model weights
model = AutoModel.from_pretrained("ntua-slp/CultureMERT-95M", trust_remote_code=True)

# Load the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("ntua-slp/CultureMERT-95M", trust_remote_code=True)

# Load a demo audio clip
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate
resample_rate = processor.sampling_rate

# Make sure the sampling rate matches what the model expects (24 kHz)
if resample_rate != sampling_rate:
    print(f"setting rate from {sampling_rate} to {resample_rate}")
    resampler = T.Resample(sampling_rate, resample_rate)
else:
    resampler = None

# The audio file is decoded on the fly
if resampler is None:
    input_audio = dataset[0]["audio"]["array"]
else:
    # cast to float32 so the resampling kernel and input dtypes match
    input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]).float())

inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# Take a look at the output shape: there are 13 layers of representations
# (CNN frontend output + 12 Transformer layers).
# Each layer performs differently on different downstream tasks, so choose empirically.
all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
print(all_layer_hidden_states.shape)  # [13 layers, time steps, 768 feature_dim]

# For utterance-level classification tasks, you can simply average the representation over time
time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
print(time_reduced_hidden_states.shape)  # [13, 768]

# You can also use a learnable weighted average over layers
aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
print(weighted_avg_hidden_states.shape)  # [768]
```
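Building on the features extracted above, one way to adapt them to a downstream tagging task is a lightweight probe over the layer-weighted embedding. This is only an illustrative sketch: the tag count and head design are hypothetical, not the exact probing setup used for the reported results.

```python
import torch
from torch import nn

num_tags = 50  # hypothetical number of tags in the downstream dataset

# Learnable layer weighting followed by a linear tagging head over frozen features
probe = nn.Sequential(
    nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1),  # weight the 13 layers
    nn.Flatten(),                                              # [1, 1, 768] -> [1, 768]
    nn.Linear(768, num_tags),
)

features = time_reduced_hidden_states.unsqueeze(0)  # [1, 13, 768], from the snippet above
logits = probe(features)                            # [1, num_tags]
probs = torch.sigmoid(logits)                       # per-tag probabilities
# Train with nn.BCEWithLogitsLoss() on the logits against binary tag targets.
```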
## Ethical Considerations

Careful consideration is advised before deploying this model in real-world contexts, as it may still reflect cultural and dataset biases. This model should not be used for commercial or generative applications without explicit attention to cultural representation, appropriate licensing, and the consent of the relevant communities or dataset curators.
## Citation
...
