GokseninYuksel committed on
Commit fefd7ae · verified · 1 Parent(s): 7473c99

Upload model

Files changed (9)
  1. audio_extractor.py +197 -0
  2. config.json +52 -0
  3. configuration_wavjepa.py +83 -0
  4. model.py +182 -0
  5. model.safetensors +3 -0
  6. modeling_wavjepa.py +33 -0
  7. pos_embed.py +267 -0
  8. types.py +51 -0
  9. utils.py +33 -0
audio_extractor.py ADDED
@@ -0,0 +1,197 @@
from math import prod

import torch
from torch import nn

from einops.layers.torch import Rearrange
from einops import rearrange

from typing import List, Optional

from abc import ABC, abstractmethod


class Extractor(ABC):
    """Abstract base class for feature extractors."""

    # Implementers are expected to expose this attribute.
    embedding_dim: int

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the extractor."""
        pass

    @abstractmethod
    def total_patches(self, time: int) -> int:
        """Return the total number of patches for a given input time dimension."""
        pass


class ConvFeatureExtractor(Extractor, nn.Module):
    """
    Convolutional feature extractor for raw audio.

    Applies successive 1D convolutions (with activations) over the time
    dimension of the audio signal. In depthwise mode, each input channel is
    convolved with its own kernels, so the ``in_channels`` argument is required.

    Inspired by https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
    and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py

    Args:
        conv_layers_spec: list of tuples (dim, k, stride) where:
            * dim: number of output channels of the layer (unrelated to the audio channels);
            * k: temporal length of the layer's kernel;
            * stride: temporal stride of the layer's kernel.
        in_channels: int
            Number of audio channels.
        dropout: float
        mode: str
            Normalisation mode. Either ``default`` or ``layer_norm``.
        conv_bias: bool
        depthwise: bool
            Perform depthwise convolutions rather than the full convolution.
    """

    def __init__(
        self,
        *args,
        conv_layers_spec: list[tuple[int, int, int]],
        in_channels: int = 2,
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        depthwise: bool = False,
        **kwargs,
    ):
        assert mode in {"default", "layer_norm"}
        super().__init__()  # type: ignore

        def block(
            n_in: int,
            n_out: int,
            k: int,
            stride: int,
            is_layer_norm: bool = False,
            is_group_norm: bool = False,
            conv_bias: bool = False,
            depthwise: bool = True,
        ):

            def make_conv():
                if depthwise:
                    assert n_out % n_in == 0, f"For depthwise convolutions, n_out ({n_out}) must be a multiple of n_in ({n_in})"
                    conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, groups=n_in)
                else:
                    conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)

                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        Rearrange("... channels time -> ... time channels"),
                        nn.LayerNorm(n_out, elementwise_affine=True),
                        Rearrange("... time channels -> ... channels time"),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.in_channels = in_channels
        self.depthwise = depthwise
        in_d = in_channels
        conv_layers = []
        for i, cl in enumerate(conv_layers_spec):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl
            conv_layers.append(  # type: ignore
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                    depthwise=self.depthwise,
                )
            )
            in_d = dim
        self.conv_layers_spec = conv_layers_spec
        self.cnn: nn.Module = nn.Sequential(*conv_layers)  # type: ignore
        self.embedding_dim = conv_layers_spec[-1][0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, n_chans, n_times)
                Batched audio signal.

        Returns:
            local_features: (batch_size, n_times_out, emb_dim)
                Local features extracted from the audio signal.
                ``emb_dim`` corresponds to the ``dim`` of the last element of
                ``conv_layers_spec``.
        """
        x = self.cnn(x)
        x = rearrange(x, "batch_size n_channels n_time -> batch_size n_time n_channels")
        return x

    def total_patches(self, time: int) -> int:
        """Calculate the number of output time steps for a given input length."""
        x = torch.zeros((1, self.in_channels, time), device=next(self.cnn[0].parameters()).device)
        x = self.cnn(x)
        x = rearrange(x, "batch_size n_channels n_time -> batch_size n_time n_channels")
        return x.shape[1]  # Return the size of the time dimension

    @property
    def receptive_fields(self) -> List[int]:
        rf = 1
        receptive_fields = [rf]
        for _, width, stride in reversed(self.conv_layers_spec):
            rf = (rf - 1) * stride + width  # assumes no padding and no dilation
            receptive_fields.append(rf)
        return list(reversed(receptive_fields))

    def description(self, sfreq: Optional[int] = None, dummy_time: Optional[int] = None) -> str:
        dims, _, strides = zip(*self.conv_layers_spec)
        receptive_fields = self.receptive_fields
        rf = receptive_fields[0]
        desc = f"Receptive field: {rf} samples"
        if sfreq is not None:
            desc += f", {rf / sfreq:.2f} seconds"

        ds_factor = prod(strides)
        desc += f" | Downsampled by {ds_factor}"
        if sfreq is not None:
            desc += f", new sfreq: {sfreq / ds_factor:.2f} Hz"
        desc += f" | Overlap of {rf - ds_factor} samples"
        if dummy_time is not None:
            n_times_out = self.total_patches(dummy_time)
            desc += f" | {n_times_out} encoded samples/trial"

        n_features = [
            f"{dim}*{rf}" for dim, rf in zip([self.in_channels] + list(dims), receptive_fields)
        ]
        n_features_values = [
            dim * rf for dim, rf in zip([self.in_channels] + list(dims), receptive_fields)
        ]
        desc += f" | #features/sample at each layer (n_channels*n_times): [{', '.join(n_features)}] = {n_features_values}"
        return desc
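A quick sanity-check sketch (not part of the uploaded file): it assumes torch and einops are installed and that audio_extractor.py has been copied next to the script (inside the repository it is imported relatively), and it uses the conv spec from config.json below.

import torch
from audio_extractor import ConvFeatureExtractor

# Conv spec from config.json: six strided 1D conv blocks, 512 output channels each.
spec = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)]
extractor = ConvFeatureExtractor(conv_layers_spec=spec, in_channels=1)

wav = torch.randn(2, 1, 32160)   # (batch, channels, samples): 2.01 s of mono audio at 16 kHz
feats = extractor(wav)           # (2, 200, 512): 200 patches of 512-dim local features
print(feats.shape, extractor.total_patches(32160))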
config.json ADDED
@@ -0,0 +1,52 @@
{
  "architectures": [
    "WavJEPAModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_wavjepa.WavJEPAConfig",
    "AutoModel": "modeling_wavjepa.WavJEPAModel"
  },
  "decoder_cfg": {
    "enable_nested_tensor": false,
    "mask_check": true,
    "num_layers": 12
  },
  "decoder_layers_cfg": {
    "activation": "gelu",
    "batch_first": true,
    "bias": true,
    "d_model": 384,
    "dim_feedforward": 1536,
    "dropout": 0.0,
    "layer_norm_eps": 1e-06,
    "nhead": 12,
    "norm_first": true
  },
  "dtype": "float32",
  "encoder_cfg": {
    "enable_nested_tensor": false,
    "mask_check": true,
    "num_layers": 12
  },
  "encoder_layers_cfg": {
    "activation": "gelu",
    "batch_first": true,
    "bias": true,
    "d_model": 768,
    "dim_feedforward": 3072,
    "dropout": 0.0,
    "layer_norm_eps": 1e-06,
    "nhead": 12,
    "norm_first": true
  },
  "extractor_config": {
    "conv_bias": false,
    "conv_layers_spec": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)]",
    "depthwise": false,
    "dropout": 0.0,
    "in_channels": 1,
    "mode": "default"
  },
  "model_type": "wavjepa-base",
  "transformers_version": "4.57.1"
}
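The auto_map above is what lets transformers resolve the custom classes shipped in this repository. A hypothetical loading sketch follows; the repository id is a placeholder (this commit does not state it), and trust_remote_code=True is required for the remote-code path.

from transformers import AutoConfig

# Placeholder repo id; substitute the actual Hub id of this model.
cfg = AutoConfig.from_pretrained("<namespace>/wavjepa-base", trust_remote_code=True)
print(cfg.encoder_layers_cfg["d_model"], cfg.decoder_layers_cfg["d_model"])  # 768, 384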
configuration_wavjepa.py ADDED
@@ -0,0 +1,83 @@
from transformers import PretrainedConfig
from torch import nn
from .types import TransformerLayerCFG, TransformerEncoderCFG


class WavJEPAConfig(PretrainedConfig):
    model_type = "wavjepa-base"
    model_size = "base"
    in_channels: int = 1

    def __init__(
        self,
        extractor_layers_spec: str = "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)]",
        extractor_dropout: float = 0.0,
        extractor_mode: str = "default",
        extractor_conv_bias: bool = False,
        extractor_depthwise: bool = False,
        encoder_d_model: int = 768,
        encoder_nhead: int = 12,
        encoder_batch_first=True,
        encoder_norm_first=True,
        encoder_bias=True,
        encoder_mlp_ratio=4.0,
        encoder_dropout=0.0,
        encoder_num_layers: int = 12,
        encoder_enable_nested_tensor=False,
        encoder_mask_check=True,
        encoder_activation='gelu',
        decoder_d_model: int = 384,
        decoder_nhead: int = 12,
        decoder_batch_first=True,
        decoder_norm_first=True,
        decoder_bias=True,
        decoder_mlp_ratio=4.0,
        decoder_dropout=0.0,
        decoder_num_layers: int = 12,
        decoder_enable_nested_tensor=False,
        decoder_mask_check=True,
        decoder_activation='gelu',
        **kwargs
    ):
        self.encoder_cfg = TransformerEncoderCFG.create(
            num_layers=encoder_num_layers,
            enable_nested_tensor=encoder_enable_nested_tensor,
            mask_check=encoder_mask_check,
        )
        self.decoder_cfg = TransformerEncoderCFG.create(
            num_layers=decoder_num_layers,
            enable_nested_tensor=decoder_enable_nested_tensor,
            mask_check=decoder_mask_check,
        )
        self.encoder_layers_cfg = TransformerLayerCFG.create(
            d_model=encoder_d_model,
            nhead=encoder_nhead,
            batch_first=encoder_batch_first,
            norm_first=encoder_norm_first,
            bias=encoder_bias,
            mlp_ratio=encoder_mlp_ratio,
            dropout=encoder_dropout,
            activation=encoder_activation,
            layer_norm_eps=1e-6,
        )
        self.decoder_layers_cfg = TransformerLayerCFG.create(
            d_model=decoder_d_model,
            nhead=decoder_nhead,
            batch_first=decoder_batch_first,
            norm_first=decoder_norm_first,
            bias=decoder_bias,
            mlp_ratio=decoder_mlp_ratio,
            dropout=decoder_dropout,
            activation=decoder_activation,
            layer_norm_eps=1e-6,
        )
        self.extractor_config = dict(
            conv_layers_spec=extractor_layers_spec,
            in_channels=self.in_channels,
            dropout=extractor_dropout,
            mode=extractor_mode,
            conv_bias=extractor_conv_bias,
            depthwise=extractor_depthwise,
        )

        super().__init__(**kwargs)
model.py ADDED
@@ -0,0 +1,182 @@
import copy
import numpy as np

from typing import Any, Optional

import torch
from torch import nn


from .pos_embed import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed, get_binaural_pos_embed
from .audio_extractor import Extractor
from .types import TransformerLayerCFG, TransformerEncoderCFG
from .utils import normalize, calculate_padding_mask, get_timestamps


class WavJEPA(nn.Module):
    """
    Joint-Embedding Predictive Architecture (JEPA) for waveforms.

    This implementation is inspired by:
    * I-JEPA: http://arxiv.org/abs/2301.08243
    * data2vec 2.0: http://arxiv.org/abs/2212.07525
    """

    teacher_encoder: nn.Module
    sample_rate: int = 16000
    process_audio_seconds: float = 2.01
    in_channels: int = 1

    def __init__(
        self,
        feature_extractor: Extractor,
        transformer_encoder_layers_cfg: TransformerLayerCFG,
        transformer_encoder_cfg: TransformerEncoderCFG,
        transformer_decoder_layers_cfg: TransformerLayerCFG,
        transformer_decoder_cfg: TransformerEncoderCFG,
        size: str = "base",
        **kwargs: dict[str, Any],
    ):
        super().__init__(**kwargs)

        self.is_spectrogram = False
        self.target_length = int(self.sample_rate * self.process_audio_seconds)
        self.extract_audio = feature_extractor
        self.total_patches = 200
        self.feature_norms: nn.Module = nn.LayerNorm(self.extract_audio.embedding_dim)

        self.n_encoder_heads = transformer_encoder_layers_cfg["nhead"]
        self.encoder_embedding_dim = transformer_encoder_layers_cfg["d_model"]
        self.n_decoder_heads = transformer_decoder_layers_cfg["nhead"]
        self.decoder_embedding_dim = transformer_decoder_layers_cfg["d_model"]

        encoder_layer = nn.TransformerEncoderLayer(**transformer_encoder_layers_cfg)
        self.encoder = nn.TransformerEncoder(encoder_layer, norm=nn.LayerNorm(self.encoder_embedding_dim), **transformer_encoder_cfg)
        self.post_extraction_mapper: Optional[nn.Module] = (
            nn.Linear(feature_extractor.embedding_dim, self.encoder_embedding_dim)
            if feature_extractor.embedding_dim != self.encoder_embedding_dim
            else None
        )
        decoder_layer = nn.TransformerEncoderLayer(**transformer_decoder_layers_cfg)
        self.decoder = nn.TransformerEncoder(decoder_layer, norm=nn.LayerNorm(self.decoder_embedding_dim), **transformer_decoder_cfg)
        self.decoder_to_encoder_mapper = nn.Linear(self.decoder_embedding_dim, self.encoder_embedding_dim, bias=True)
        self.encoder_to_decoder_mapper = nn.Linear(self.encoder_embedding_dim, self.decoder_embedding_dim)

        # For the autocast, add batch dimensions.
        self.mask_token = nn.Parameter(
            torch.zeros(1, 1, self.decoder_embedding_dim, requires_grad=True)
        )
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.pos_encoding_encoder = self._get_pos_embed_params(self.encoder_embedding_dim)
        self.pos_encoding_decoder = self._get_pos_embed_params(self.decoder_embedding_dim)
        self.output_steps = self.extract_audio.total_patches(self.target_length) // self.in_channels

        self._init_teacher()

    def _get_pos_embed_params(self, embedding_dim):
        """Calculate the positional embedding parameters and return them."""
        pos_embed = nn.Parameter(
            torch.zeros(
                1,
                self.total_patches,
                embedding_dim,
            ),
            requires_grad=False,
        )
        positions = np.arange(self.total_patches, dtype=np.float64)
        if self.is_spectrogram:
            # If it is a spectrogram, we use 2D sincos embeddings.
            pos_embed_data = get_2d_sincos_pos_embed(
                embedding_dim, self.extract_audio.grid_size, cls_token_num=0
            )
        # TODO: remove this total_patches check later.
        elif not self.is_spectrogram and self.in_channels == 2 and (self.total_patches == 400):
            # Use 1D sincos embeddings with the channel index encoded in the last half of the dimensions.
            pos_embed_data = get_binaural_pos_embed(
                embedding_dim, time_steps=self.total_patches // self.in_channels
            )
        elif not self.is_spectrogram and self.in_channels == 2 and (self.total_patches == 200):
            # Use 1D positional embeddings for a channel-mixing feature extractor.
            pos_embed_data = get_1d_sincos_pos_embed_from_grid(
                embedding_dim,
                positions,
            )
        elif not self.is_spectrogram and self.in_channels == 1 and (self.total_patches == 200):
            # For plain audio, we use 1D sincos embeddings.
            pos_embed_data = get_1d_sincos_pos_embed_from_grid(
                embedding_dim,
                positions,
            )
        else:
            raise Exception(f"Not implemented for in_channels={self.in_channels}, total_patches={self.total_patches}")
        pos_embed.data.copy_(torch.from_numpy(pos_embed_data).float().unsqueeze(0))
        return pos_embed

    def _init_teacher(self):
        self.teacher_encoder = copy.deepcopy(self.encoder)
        self.teacher_encoder.requires_grad_(False)

    @torch.inference_mode()
    def _get_segment_representation(self, audio: torch.Tensor, padding_mask: torch.Tensor):
        # Get the audio representation of waveform x.
        self.eval()
        local_features = self.extract_audio(audio)
        local_features = self.feature_norms(local_features)
        if self.post_extraction_mapper:
            local_features = self.post_extraction_mapper(local_features)
        local_features = local_features + self.pos_encoding_encoder
        # Encoder forward
        contextual_features = self.encoder(local_features, src_key_padding_mask=padding_mask)
        return contextual_features

    @torch.inference_mode()
    def get_audio_representation(self, audio: torch.Tensor):
        self.eval()
        B = audio.shape[0]
        input_audio_len = audio.shape[-1]
        # Assert that the audio has the correct shape.
        if audio.ndim != 3:
            raise ValueError(
                "audio input tensor must be 3D with shape (n_sounds, n_channels, num_samples)"
            )
        cur_frames = audio.shape[-1]
        pad_frames = self.target_length - (cur_frames % self.target_length)

        if pad_frames > 0:
            # Pad the time (last) dimension with constant zeros.
            pad_arg = (
                0,
                pad_frames,
            )
            audio = torch.nn.functional.pad(audio, pad_arg, mode="constant")
        embeddings = []
        padding_mask, cut_off = calculate_padding_mask(pad_frames=pad_frames,
                                                       total_frames=audio.shape[-1],
                                                       sr=self.sample_rate,
                                                       output_steps=self.total_patches,
                                                       process_seconds=self.target_length // self.sample_rate,
                                                       device=audio.device,
                                                       B=B)
        mask_idx = 0
        masked_mean = torch.zeros(audio.shape, dtype=torch.bool)
        masked_mean[..., cur_frames:] = True
        mt = torch.masked.masked_tensor(audio, masked_mean)
        # Now get the embeddings of the model, one fixed-length segment at a time.
        for i in range(audio.shape[-1] // self.target_length):
            mt = audio[..., i * self.target_length : (i + 1) * self.target_length]
            mask = padding_mask[..., mask_idx : mask_idx + self.output_steps]
            with torch.no_grad():
                # We do not include padding tokens in the mean and std calculation.
                embedding = self._get_segment_representation(
                    normalize(mt),
                    mask
                )
            mask_idx = mask_idx + self.output_steps
            embeddings.append(embedding)

        x = torch.hstack(embeddings)
        x = x[:, :cut_off, :]
        ts = get_timestamps(self.sample_rate, B, input_audio_len, x)
        return x, ts
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:988546976c453353c14ebb0798f275ea5eb95270d75eed506ea455c5b49f6be0
size 785248328
modeling_wavjepa.py ADDED
@@ -0,0 +1,33 @@
from transformers import PreTrainedModel

from .model import WavJEPA
from .configuration_wavjepa import WavJEPAConfig
from .audio_extractor import ConvFeatureExtractor
import torch
from typing import Tuple


class WavJEPAModel(PreTrainedModel):
    config_class = WavJEPAConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = WavJEPA(
            feature_extractor=ConvFeatureExtractor(
                # The conv spec is stored in the config as a Python expression string.
                conv_layers_spec=eval(config.extractor_config['conv_layers_spec']),
                in_channels=config.extractor_config['in_channels'],
                dropout=config.extractor_config['dropout'],
                mode=config.extractor_config['mode'],
                conv_bias=config.extractor_config['conv_bias'],
                depthwise=config.extractor_config['depthwise'],
            ),
            transformer_encoder_layers_cfg=config.encoder_layers_cfg,
            transformer_encoder_cfg=config.encoder_cfg,
            transformer_decoder_layers_cfg=config.decoder_layers_cfg,
            transformer_decoder_cfg=config.decoder_cfg,
            size=config.model_size,
        )

    def forward(self, tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.model.get_audio_representation(tensor)
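Putting the pieces together, a hedged end-to-end sketch: the repository id is again a placeholder (not given in this commit), trust_remote_code=True pulls in the classes above, and the input shape follows the check in model.get_audio_representation.

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<namespace>/wavjepa-base", trust_remote_code=True)

wav = torch.randn(1, 1, 5 * 16000)   # (n_sounds, n_channels, num_samples): 5 s of mono 16 kHz audio
embeddings, timestamps = model(wav)  # (1, n_tokens, 768) features and per-token timestamps in ms
print(embeddings.shape, timestamps.shape)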
pos_embed.py ADDED
@@ -0,0 +1,267 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------


# https://github.com/facebookresearch/AudioMAE/blob/main/util/pos_embed.py
import numpy as np
import torch


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token_num):
    """
    grid_size: int or (height, width) tuple of the grid
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [cls_token_num+grid_size*grid_size, embed_dim] (w/ or w/o cls tokens)
    """
    if isinstance(grid_size, int):
        gH = grid_size
        gW = grid_size
    else:
        gH = grid_size[0]
        gW = grid_size[1]
    grid_h = np.arange(gH, dtype=np.float64)
    grid_w = np.arange(gW, dtype=np.float64)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, gH, gW])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    for _ in range(cls_token_num):
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """
    grid_size: (height, width) tuple of the grid
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size[0], dtype=np.float64)
    grid_w = np.arange(grid_size[1], dtype=np.float64)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_1d_sincos_pos_embed(embed_dim, length):
    """
    Create 1D sinusoidal positional embeddings.

    Args:
        embed_dim: embedding dimension
        length: sequence length

    Returns:
        pos_embed: [length, embed_dim]
    """
    assert embed_dim % 2 == 0

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = np.arange(length, dtype=np.float64)  # (length,)
    out = np.einsum("m,d->md", pos, omega)  # (length, D/2)

    emb_sin = np.sin(out)  # (length, D/2)
    emb_cos = np.cos(out)  # (length, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (length, D)
    return emb


def get_binaural_pos_embed(embed_dim, time_steps=100):
    """
    Create positional embeddings for binaural audio.
    Same time encoding, different channel encoding.

    Args:
        embed_dim: embedding dimension
        time_steps: number of time steps per channel

    Returns:
        pos_embed: [2*time_steps, embed_dim] - for concatenated L+R channels
    """
    assert embed_dim % 2 == 0

    # Time dimension encoding (same for both channels)
    time_embed = get_1d_sincos_pos_embed(embed_dim // 2, time_steps)

    # Channel dimension encoding (different for L and R)
    channel_embed_left = np.zeros((time_steps, embed_dim // 2))  # Left channel = 0
    channel_embed_right = get_1d_sincos_pos_embed(embed_dim // 2, 1)  # Right channel = different
    channel_embed_right = np.tile(channel_embed_right, (time_steps, 1))

    # Combine time and channel embeddings
    left_pos_embed = np.concatenate([time_embed, channel_embed_left], axis=1)
    right_pos_embed = np.concatenate([time_embed, channel_embed_right], axis=1)

    # Concatenate left and right channel embeddings
    binaural_pos_embed = np.concatenate([left_pos_embed, right_pos_embed], axis=0)

    return binaural_pos_embed


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        # new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size[0], orig_size[1], new_size[0], new_size[1])
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size[0], orig_size[1], embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size[0], new_size[1]),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size[0], orig_size[1], new_size[0], new_size[1])
            )
            # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)
            pos_tokens = pos_embed_checkpoint[:, 1:, :]  # remove
            pos_tokens = pos_tokens.reshape(
                -1, orig_size[0], orig_size[1], embedding_size
            )  # .permute(0, 3, 1, 2)
            # pos_tokens = torch.nn.functional.interpolate(
            #     pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)

            # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            pos_tokens = pos_tokens[:, :, : new_size[1], :]  # assume only time diff
            pos_tokens = pos_tokens.flatten(1, 2)
            new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


def interpolate_patch_embed_audio(
    model,
    checkpoint_model,
    orig_channel,
    new_channel=1,
    kernel_size=(16, 16),
    stride=(16, 16),
    padding=(0, 0),
):
    if orig_channel != new_channel:
        if "patch_embed.proj.weight" in checkpoint_model:
            # aggregate 3 channels in rgb ckpt to 1 channel for audio
            new_proj_weight = torch.nn.Parameter(
                torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze(1)
            )
            checkpoint_model["patch_embed.proj.weight"] = new_proj_weight
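A small check sketch for the 1D table WavJEPA actually uses (pos_embed.py has no relative imports, so it can be copied next to a script; numpy is assumed to be installed).

import numpy as np
from pos_embed import get_1d_sincos_pos_embed_from_grid

positions = np.arange(200, dtype=np.float64)               # 200 patch positions, as in model.py
table = get_1d_sincos_pos_embed_from_grid(768, positions)  # encoder width from config.json
print(table.shape)                                         # (200, 768): first 384 dims sine, last 384 cosine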
types.py ADDED
@@ -0,0 +1,51 @@
from typing import TypedDict, Union
from torch import nn


class TransformerLayerCFG(TypedDict):
    d_model: int
    nhead: int
    batch_first: bool
    norm_first: bool
    bias: bool
    dim_feedforward: int
    dropout: float
    activation: Union[str, nn.Module]
    layer_norm_eps: float

    @classmethod
    def create(cls,
               d_model: int = 768,
               nhead: int = 12,
               batch_first: bool = True,
               norm_first: bool = False,
               bias: bool = True,
               mlp_ratio: float = 4.0,
               dropout: float = 0.0,
               activation: Union[str, nn.Module] = nn.GELU(),
               layer_norm_eps: float = 1e-6) -> 'TransformerLayerCFG':
        return TransformerLayerCFG(d_model=d_model,
                                   nhead=nhead,
                                   batch_first=batch_first,
                                   norm_first=norm_first,
                                   bias=bias,
                                   dim_feedforward=int(d_model * mlp_ratio),
                                   dropout=dropout,
                                   activation=activation,
                                   layer_norm_eps=layer_norm_eps)


# The final norm needs to be supplied by the user when building the encoder.
class TransformerEncoderCFG(TypedDict):
    num_layers: int
    enable_nested_tensor: bool
    mask_check: bool

    @classmethod
    def create(cls,
               num_layers: int = 12,
               enable_nested_tensor: bool = False,
               mask_check: bool = True) -> 'TransformerEncoderCFG':
        return TransformerEncoderCFG(num_layers=num_layers,
                                     enable_nested_tensor=enable_nested_tensor,
                                     mask_check=mask_check)
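These TypedDicts are unpacked straight into torch's transformer classes in model.py. The sketch below mirrors that with plain dicts (decoder values from config.json) so it stays standalone, since importing the local types module would shadow the standard-library types in a loose script; it assumes a recent PyTorch (the bias argument of TransformerEncoderLayer).

from torch import nn

layer_cfg = dict(d_model=384, nhead=12, batch_first=True, norm_first=True, bias=True,
                 dim_feedforward=1536, dropout=0.0, activation="gelu", layer_norm_eps=1e-6)
encoder_cfg = dict(num_layers=12, enable_nested_tensor=False, mask_check=True)

layer = nn.TransformerEncoderLayer(**layer_cfg)                      # one decoder block
decoder = nn.TransformerEncoder(layer, norm=nn.LayerNorm(384), **encoder_cfg)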
utils.py ADDED
@@ -0,0 +1,33 @@
import torch


def normalize(audio):
    mean = audio.mean(dim=(-2, -1), keepdim=True)
    std = audio.std(dim=(-2, -1), keepdim=True)
    audio = (audio - mean) / (std + 1e-5)  # Add epsilon for numerical stability
    return audio


def calculate_padding_mask(pad_frames, total_frames, sr, output_steps, process_seconds, device, B):
    # How many fixed-length (process_seconds) chunks does this audio have?
    # Find it, then multiply by output_steps to get the total number of output tokens.
    total_frames = int((total_frames / sr) / process_seconds)
    total_output_steps = output_steps * total_frames
    mask = torch.zeros((B, total_output_steps), dtype=torch.bool, device=device)

    # Check how many padding tokens the audio contains.
    output_sr = int(output_steps / process_seconds)
    pad_seconds = pad_frames / sr
    pad_steps = int(pad_seconds * output_sr)

    # Create the mask: True marks padded output steps.
    mask[..., total_output_steps - pad_steps:] = True
    return mask, total_output_steps - pad_steps


def get_timestamps(sample_rate, B, input_audio_len, x):
    audio_len = input_audio_len
    sec = audio_len / sample_rate
    x_len = x.shape[1]
    step = sec / x_len * 1000  # sec -> ms
    ts = torch.tensor([step * i for i in range(x_len)]).unsqueeze(0)
    ts = ts.repeat(B, 1)
    return ts
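A worked example of the padding arithmetic, spelled out inline rather than imported, mirroring how get_audio_representation in model.py calls calculate_padding_mask for a 5 s mono clip at 16 kHz (only torch is required; the specific numbers follow from the code above).

import torch

sr, output_steps = 16000, 200
target_length = int(sr * 2.01)                               # 32160 samples per segment
process_seconds = target_length // sr                        # 2 (integer seconds, as passed by model.py)
cur_frames = 5 * sr                                          # 80000-sample clip
pad_frames = target_length - (cur_frames % target_length)    # 16480 samples of zero padding
total_frames = cur_frames + pad_frames                       # 96480 samples -> 3 segments

n_segments = int((total_frames / sr) / process_seconds)      # 3
total_output_steps = output_steps * n_segments               # 600 encoder tokens
pad_steps = int((pad_frames / sr) * int(output_steps / process_seconds))   # ~103 padded tokens
mask = torch.zeros((1, total_output_steps), dtype=torch.bool)
mask[..., total_output_steps - pad_steps:] = True            # True marks padding positions
print(total_output_steps, pad_steps, total_output_steps - pad_steps)       # 600 total, ~103 padded, ~497 kept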