surokpro2 committed
Commit 837d48c · verified · 1 parent: 7b003bc

Upload 31 files

Files changed (31)
  1. README.md +7 -6
  2. SAE/__init__.py +1 -0
  3. SAE/config.json +23 -0
  4. SAE/dataset_iterator.py +53 -0
  5. SAE/sae.py +215 -0
  6. SAE/sae_utils.py +48 -0
  7. SDLens/__init__.py +1 -0
  8. SDLens/hooked_scheduler.py +40 -0
  9. SDLens/hooked_sd_pipeline.py +321 -0
  10. app.py +554 -0
  11. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  12. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  13. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  14. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  15. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  16. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  17. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  18. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  19. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  20. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  21. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  22. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  23. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  24. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  25. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  26. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  27. requirements.txt +10 -0
  28. scripts/collect_latents_dataset.py +96 -0
  29. scripts/train_sae.py +308 -0
  30. utils/__init__.py +1 -0
  31. utils/hooks.py +120 -0
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Sdxl Sae Multistep
- emoji: 😻
- colorFrom: gray
- colorTo: gray
  sdk: gradio
- sdk_version: 5.32.0
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  ---
+ title: Sdxlsae
+ emoji: 🔥
+ colorFrom: red
+ colorTo: red
  sdk: gradio
+ sdk_version: 5.23.2
  app_file: app.py
  pinned: false
+ license: mit
  ---

+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SAE/__init__.py ADDED
@@ -0,0 +1 @@
+ from .sae import SparseAutoencoder
SAE/config.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "sae_configs": [
+         {
+             "d_model": 1280,
+             "n_dirs": 5120,
+             "k": 20
+         },
+         {
+             "d_model": 1280,
+             "n_dirs": 640,
+             "k": 20
+         }
+     ],
+     "bs": 4096,
+     "log_interval": 500,
+     "save_interval": 5000,
+
+     "paths_to_latents": [
+         "PASS YOUR PATHS HERE. Example /home/username/latents/<timestamp>. It should contain tar archives with latents."
+     ],
+     "save_path_base": "<Your SAE save path>",
+     "block_name": "unet.down_blocks.2.attentions.1"
+ }
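For orientation, a rough sketch of how this JSON is presumably consumed by the training code (based on SAE/sae_utils.py further down; the file path is illustrative):

import json
from SAE.sae_utils import Config

with open("SAE/config.json") as f:   # illustrative path
    cfg = Config(json.load(f))

for sae_cfg in cfg.saes:
    # e.g. unet.down_blocks.2.attentions.1_k20_hidden5120_auxk256_bs4096_lr0.0001
    print(sae_cfg.sae_name, sae_cfg.save_path)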
SAE/dataset_iterator.py ADDED
@@ -0,0 +1,53 @@
+ import webdataset as wds
+ import os
+ import torch
+
+ class ActivationsDataloader:
+     def __init__(self, paths_to_datasets, block_name, batch_size, output_or_diff='diff', num_in_buffer=50):
+         assert output_or_diff in ['diff', 'output'], "Provide 'output' or 'diff'"
+
+         self.dataset = wds.WebDataset(
+             [os.path.join(path_to_dataset, f"{block_name}.tar")
+              for path_to_dataset in paths_to_datasets]
+         ).decode("torch")
+         self.iter = iter(self.dataset)
+         self.buffer = None
+         self.pointer = 0
+         self.num_in_buffer = num_in_buffer
+         self.output_or_diff = output_or_diff
+         self.batch_size = batch_size
+         self.one_size = None
+
+     def renew_buffer(self, to_retrieve):
+         to_merge = []
+         if self.buffer is not None and self.buffer.shape[0] > self.pointer:
+             to_merge = [self.buffer[self.pointer:].clone()]
+             del self.buffer
+         for _ in range(to_retrieve):
+             sample = next(self.iter)
+             latents = sample['output.pth'] if self.output_or_diff == 'output' else sample['diff.pth']
+             latents = latents.permute((0, 1, 3, 4, 2))
+             latents = latents.reshape((-1, latents.shape[-1]))
+             to_merge.append(latents.to('cuda'))
+             self.one_size = latents.shape[0]
+         self.buffer = torch.cat(to_merge, dim=0)
+         shuffled_indices = torch.randperm(self.buffer.shape[0])
+         self.buffer = self.buffer[shuffled_indices]
+         self.pointer = 0
+
+     def iterate(self):
+         while True:
+             if self.buffer is None or self.buffer.shape[0] - self.pointer < self.num_in_buffer * self.one_size * 4 // 5:
+                 try:
+                     to_retrieve = self.num_in_buffer if self.buffer is None else self.num_in_buffer // 5
+                     self.renew_buffer(to_retrieve)
+                 except StopIteration:
+                     break
+
+             batch = self.buffer[self.pointer: self.pointer + self.batch_size]
+             self.pointer += self.batch_size
+
+             assert batch.shape[0] == self.batch_size
+             yield batch
+
+
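A minimal sketch of driving this loader (the path is a placeholder; assumes tar archives written by scripts/collect_latents_dataset.py below, and a CUDA device since batches are moved to 'cuda'):

from SAE.dataset_iterator import ActivationsDataloader

loader = ActivationsDataloader(
    paths_to_datasets=["/path/to/latents/<timestamp>"],   # placeholder
    block_name="unet.down_blocks.2.attentions.1",
    batch_size=4096,
    output_or_diff="diff",
)
for batch in loader.iterate():
    print(batch.shape)   # (4096, 1280): flattened spatial positions x channels
    break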
SAE/sae.py ADDED
@@ -0,0 +1,215 @@
1
+ '''
2
+ Adapted from
3
+ https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/model.py
4
+ '''
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import os
9
+ import json
10
+ import spaces
11
+ import logging
12
+
13
+ class SparseAutoencoder(nn.Module):
14
+ """
15
+ Top-K Autoencoder with sparse kernels. Implements:
16
+
17
+ latents = relu(topk(encoder(x - pre_bias) + latent_bias))
18
+ recons = decoder(latents) + pre_bias
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ n_dirs_local: int,
24
+ d_model: int,
25
+ k: int,
26
+ auxk: int | None,
27
+ dead_steps_threshold: int,
28
+ ):
29
+ super().__init__()
30
+ self.n_dirs_local = n_dirs_local
31
+ self.d_model = d_model
32
+ self.k = k
33
+ self.auxk = auxk
34
+ self.dead_steps_threshold = dead_steps_threshold
35
+
36
+ self.encoder = nn.Linear(d_model, n_dirs_local, bias=False)
37
+ self.decoder = nn.Linear(n_dirs_local, d_model, bias=False)
38
+
39
+ self.pre_bias = nn.Parameter(torch.zeros(d_model))
40
+ self.latent_bias = nn.Parameter(torch.zeros(n_dirs_local))
41
+
42
+ self.stats_last_nonzero: torch.Tensor
43
+ self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
44
+
45
+ ## initialization
46
+
47
+ # "tied" init
48
+ self.decoder.weight.data = self.encoder.weight.data.T.clone()
49
+
50
+ # store decoder in column major layout for kernel
51
+ self.decoder.weight.data = self.decoder.weight.data.T.contiguous().T
52
+
53
+ unit_norm_decoder_(self)
54
+
55
+ def auxk_mask_fn(self, x):
56
+ dead_mask = self.stats_last_nonzero > self.dead_steps_threshold
57
+ x.data *= dead_mask # inplace to save memory
58
+ return x
59
+
60
+ def save_to_disk(self, path: str):
61
+ PATH_TO_CFG = 'config.json'
62
+ PATH_TO_WEIGHTS = 'state_dict.pth'
63
+
64
+ cfg = {
65
+ "n_dirs_local": self.n_dirs_local,
66
+ "d_model": self.d_model,
67
+ "k": self.k,
68
+ "auxk": self.auxk,
69
+ "dead_steps_threshold": self.dead_steps_threshold,
70
+ }
71
+
72
+ os.makedirs(path, exist_ok=True)
73
+
74
+ with open(os.path.join(path, PATH_TO_CFG), 'w') as f:
75
+ json.dump(cfg, f)
76
+
77
+
78
+ torch.save({
79
+ "state_dict": self.state_dict(),
80
+ }, os.path.join(path, PATH_TO_WEIGHTS))
81
+
82
+
83
+ @classmethod
84
+ def load_from_disk(cls, path: str):
85
+ PATH_TO_CFG = 'config.json'
86
+ PATH_TO_WEIGHTS = 'state_dict.pth'
87
+
88
+ with open(os.path.join(path, PATH_TO_CFG), 'r') as f:
89
+ cfg = json.load(f)
90
+
91
+ ae = cls(
92
+ n_dirs_local=cfg["n_dirs_local"],
93
+ d_model=cfg["d_model"],
94
+ k=cfg["k"],
95
+ auxk=cfg["auxk"],
96
+ dead_steps_threshold=cfg["dead_steps_threshold"],
97
+ )
98
+
99
+ state_dict = torch.load(os.path.join(path, PATH_TO_WEIGHTS))["state_dict"]
100
+ ae.load_state_dict(state_dict)
101
+
102
+ return ae
103
+
104
+ @property
105
+ def n_dirs(self):
106
+ return self.n_dirs_local
107
+
108
+ def encode(self, x):
109
+ x = x.to('cuda') - self.pre_bias
110
+ latents_pre_act = self.encoder(x) + self.latent_bias
111
+
112
+ vals, inds = torch.topk(
113
+ latents_pre_act,
114
+ k=self.k,
115
+ dim=-1
116
+ )
117
+
118
+ latents = torch.zeros_like(latents_pre_act)
119
+ latents.scatter_(-1, inds, torch.relu(vals))
120
+
121
+ return latents
122
+
123
+ def forward(self, x):
124
+ x = x - self.pre_bias
125
+ latents_pre_act = self.encoder(x) + self.latent_bias
126
+ vals, inds = torch.topk(
127
+ latents_pre_act,
128
+ k=self.k,
129
+ dim=-1
130
+ )
131
+
132
+ ## set num nonzero stat ##
133
+ tmp = torch.zeros_like(self.stats_last_nonzero)
134
+ tmp.scatter_add_(
135
+ 0,
136
+ inds.reshape(-1),
137
+ (vals > 1e-3).to(tmp.dtype).reshape(-1),
138
+ )
139
+ self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
140
+ self.stats_last_nonzero += 1
141
+ ## end stats ##
142
+
143
+ ## auxk
144
+ if self.auxk is not None: # for auxk
145
+ # IMPORTANT: has to go after stats update!
146
+ # WARN: auxk_mask_fn can mutate latents_pre_act!
147
+ auxk_vals, auxk_inds = torch.topk(
148
+ self.auxk_mask_fn(latents_pre_act),
149
+ k=self.auxk,
150
+ dim=-1
151
+ )
152
+ else:
153
+ auxk_inds = None
154
+ auxk_vals = None
155
+
156
+ ## end auxk
157
+
158
+ vals = torch.relu(vals)
159
+ if auxk_vals is not None:
160
+ auxk_vals = torch.relu(auxk_vals)
161
+
162
+
163
+ rows, cols = latents_pre_act.size()
164
+ row_indices = torch.arange(rows).unsqueeze(1).expand(-1, self.k).reshape(-1)
165
+ vals = vals.reshape(-1)
166
+ inds = inds.reshape(-1)
167
+
168
+ indices = torch.stack([row_indices.to(inds.device), inds])
169
+
170
+ sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
171
+
172
+ recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
173
+
174
+
175
+ return recons, {
176
+ "inds": inds,
177
+ "vals": vals,
178
+ "auxk_inds": auxk_inds,
179
+ "auxk_vals": auxk_vals,
180
+ }
181
+
182
+ def decode_sparse(self, inds, vals):
183
+ rows, cols = inds.shape[0], self.n_dirs
184
+
185
+ row_indices = torch.arange(rows).unsqueeze(1).expand(-1, inds.shape[1]).reshape(-1)
186
+ vals = vals.reshape(-1)
187
+ inds = inds.reshape(-1)
188
+
189
+ indices = torch.stack([row_indices.to(inds.device), inds])
190
+
191
+ sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
192
+
193
+ recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
194
+ return recons
195
+
196
+ @property
197
+ def device(self):
198
+ return next(self.parameters()).device
199
+
200
+
201
+ def unit_norm_decoder_(autoencoder: SparseAutoencoder) -> None:
202
+ """
203
+ Unit normalize the decoder weights of an autoencoder.
204
+ """
205
+ autoencoder.decoder.weight.data /= autoencoder.decoder.weight.data.norm(dim=0)
206
+
207
+
208
+ def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
209
+ """project out gradient information parallel to the dictionary vectors - assumes that the decoder is already unit normed"""
210
+
211
+ assert autoencoder.decoder.weight.grad is not None
212
+
213
+ autoencoder.decoder.weight.grad +=\
214
+ torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad) *\
215
+ autoencoder.decoder.weight.data * -1
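A small usage sketch of the autoencoder above, with the shapes used by the checkpoints in this commit (d_model=1280, n_dirs=5120, k=10); the random input is illustrative only:

import torch
from SAE import SparseAutoencoder

sae = SparseAutoencoder(n_dirs_local=5120, d_model=1280, k=10,
                        auxk=256, dead_steps_threshold=2441).to("cuda")

x = torch.randn(8, 1280, device="cuda")    # 8 activation vectors
latents = sae.encode(x)                    # (8, 5120), at most k=10 nonzeros per row
recons, info = sae(x)                      # dense reconstruction + top-k indices/values
print(latents.count_nonzero(dim=-1))       # <= 10 per row
print(recons.shape, info["inds"].shape)    # torch.Size([8, 1280]) torch.Size([80])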
SAE/sae_utils.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ from dataclasses import dataclass, field
+ import os
+
+ @dataclass
+ class SAETrainingConfig:
+     d_model: int
+     n_dirs: int
+     k: int
+     block_name: str
+     bs: int
+     save_path_base: str
+     auxk: int = 256
+     lr: float = 1e-4
+     eps: float = 6.25e-10
+     dead_toks_threshold: int = 10_000_000
+     auxk_coef: float = 1/32
+
+     @property
+     def sae_name(self):
+         return f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
+
+     @property
+     def save_path(self):
+         return os.path.join(self.save_path_base, f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}')
+
+
+ @dataclass
+ class Config:
+     saes: list[SAETrainingConfig]
+     paths_to_latents: list[str]
+     log_interval: int
+     save_interval: int
+     bs: int
+     block_name: str
+     wandb_project: str = 'sdxl_sae_train'
+     wandb_name: str = 'multiple_sae'
+
+     def __init__(self, cfg_json):
+         self.saes = [SAETrainingConfig(**sae_cfg, block_name=cfg_json['block_name'], bs=cfg_json['bs'], save_path_base=cfg_json['save_path_base'])
+                      for sae_cfg in cfg_json['sae_configs']]
+
+         self.save_path_base = cfg_json['save_path_base']
+         self.paths_to_latents = cfg_json['paths_to_latents']
+         self.log_interval = cfg_json['log_interval']
+         self.save_interval = cfg_json['save_interval']
+         self.bs = cfg_json['bs']
+         self.block_name = cfg_json['block_name']
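Note that sae_name reproduces the directory names of the checkpoints shipped in this commit (lr=1e-4 formats as 0.0001), e.g.:

from SAE.sae_utils import SAETrainingConfig

cfg = SAETrainingConfig(d_model=1280, n_dirs=5120, k=10,
                        block_name="unet.down_blocks.2.attentions.1",
                        bs=4096, save_path_base="/tmp/sae")   # placeholder path
print(cfg.sae_name)
# unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001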
SDLens/__init__.py ADDED
@@ -0,0 +1 @@
+ from .hooked_sd_pipeline import HookedIFPipeline, HookedStableDiffusionXLPipeline
SDLens/hooked_scheduler.py ADDED
@@ -0,0 +1,40 @@
+ from diffusers import DDPMScheduler
+ import torch
+
+ class HookedNoiseScheduler:
+     scheduler: DDPMScheduler
+     pre_hooks: list
+     post_hooks: list
+
+     def __init__(self, scheduler):
+         object.__setattr__(self, 'scheduler', scheduler)
+         object.__setattr__(self, 'pre_hooks', [])
+         object.__setattr__(self, 'post_hooks', [])
+
+     def step(
+         self,
+         model_output, timestep, sample, generator, return_dict
+     ):
+         assert return_dict == False, "return_dict == True is not implemented"
+         for hook in self.pre_hooks:
+             hook_output = hook(model_output, timestep, sample, generator)
+             if hook_output is not None:
+                 model_output, timestep, sample, generator = hook_output
+
+         (pred_prev_sample, ) = self.scheduler.step(model_output, timestep, sample, generator, return_dict)
+
+         for hook in self.post_hooks:
+             hook_output = hook(pred_prev_sample)
+             if hook_output is not None:
+                 pred_prev_sample = hook_output
+
+         return (pred_prev_sample, )
+
+     def __getattr__(self, name):
+         return getattr(self.scheduler, name)
+
+     def __setattr__(self, name, value):
+         if name in {'scheduler', 'pre_hooks', 'post_hooks'}:
+             object.__setattr__(self, name, value)
+         else:
+             setattr(self.scheduler, name, value)
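A minimal sketch of using the wrapper directly; in practice the pipelines below install it for you when constructed with use_hooked_scheduler=True, and the hook here only records predictions (returning None leaves the sample unchanged):

from diffusers import DDPMScheduler
from SDLens.hooked_scheduler import HookedNoiseScheduler

sched = HookedNoiseScheduler(DDPMScheduler())
captured = []
sched.post_hooks.append(lambda pred_prev_sample: captured.append(pred_prev_sample))
# sched.step(model_output, t, sample, generator, False) now runs the hook after each denoising step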
SDLens/hooked_sd_pipeline.py ADDED
@@ -0,0 +1,321 @@
1
+ import einops
2
+ from diffusers import StableDiffusionXLPipeline, IFPipeline
3
+ from typing import List, Dict, Callable, Union
4
+ import torch
5
+ from .hooked_scheduler import HookedNoiseScheduler
6
+ import spaces
7
+
8
+ def retrieve(io):
9
+ if isinstance(io, tuple):
10
+ if len(io) == 1:
11
+ return io[0]
12
+ else:
13
+ raise ValueError("A tuple should have length of 1")
14
+ elif isinstance(io, torch.Tensor):
15
+ return io
16
+ else:
17
+ raise ValueError("Input/Output must be a tensor, or 1-element tuple")
18
+
19
+
20
+ class HookedDiffusionAbstractPipeline:
21
+ parent_cls = None
22
+ pipe = None
23
+
24
+ def __init__(self, pipe: parent_cls, use_hooked_scheduler: bool = False):
25
+ if use_hooked_scheduler:
26
+ pipe.scheduler = HookedNoiseScheduler(pipe.scheduler)
27
+ self.__dict__['pipe'] = pipe
28
+ self.use_hooked_scheduler = use_hooked_scheduler
29
+
30
+ @classmethod
31
+ def from_pretrained(cls, *args, **kwargs):
32
+ return cls(cls.parent_cls.from_pretrained(*args, **kwargs))
33
+
34
+ def run_with_hooks(self,
35
+ *args,
36
+ position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
37
+ **kwargs
38
+ ):
39
+ '''
40
+ Run the pipeline with hooks at specified positions.
41
+ Returns the final output.
42
+
43
+ Args:
44
+ *args: Arguments to pass to the pipeline.
45
+ position_hook_dict: A dictionary mapping positions to hooks.
46
+ The keys are positions in the pipeline where the hooks should be registered.
47
+ The values are either a single hook or a list of hooks to be registered at the specified position.
48
+ Each hook should be a callable that takes three arguments: (module, input, output).
49
+ **kwargs: Keyword arguments to pass to the pipeline.
50
+ '''
51
+ hooks = []
52
+ for position, hook in position_hook_dict.items():
53
+ if isinstance(hook, list):
54
+ for h in hook:
55
+ hooks.append(self._register_general_hook(position, h))
56
+ else:
57
+ hooks.append(self._register_general_hook(position, hook))
58
+
59
+ hooks = [hook for hook in hooks if hook is not None]
60
+
61
+ try:
62
+ output = self.pipe(*args, **kwargs)
63
+ finally:
64
+ for hook in hooks:
65
+ hook.remove()
66
+ if self.use_hooked_scheduler:
67
+ self.pipe.scheduler.pre_hooks = []
68
+ self.pipe.scheduler.post_hooks = []
69
+
70
+ return output
71
+
72
+
73
+ def run_with_cache(self,
74
+ *args,
75
+ positions_to_cache: List[str],
76
+ save_input: bool = False,
77
+ save_output: bool = True,
78
+ **kwargs
79
+ ):
80
+ '''
81
+ Run the pipeline with caching at specified positions.
82
+
83
+ This method allows you to cache the intermediate inputs and/or outputs of the pipeline
84
+ at certain positions. The final output of the pipeline and a dictionary of cached values
85
+ are returned.
86
+
87
+ Args:
88
+ *args: Arguments to pass to the pipeline.
89
+ positions_to_cache (List[str]): A list of positions in the pipeline where intermediate
90
+ inputs/outputs should be cached.
91
+ save_input (bool, optional): If True, caches the input at each specified position.
92
+ Defaults to False.
93
+ save_output (bool, optional): If True, caches the output at each specified position.
94
+ Defaults to True.
95
+ **kwargs: Keyword arguments to pass to the pipeline.
96
+
97
+ Returns:
98
+ final_output: The final output of the pipeline after execution.
99
+ cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
100
+ and values are dictionaries containing the cached 'input' and/or 'output' at each position,
101
+ depending on the flags `save_input` and `save_output`.
102
+ '''
103
+ cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
104
+ hooks = [
105
+ self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
106
+ ]
107
+ hooks = [hook for hook in hooks if hook is not None]
108
+ output = self.pipe(*args, **kwargs)
109
+ for hook in hooks:
110
+ hook.remove()
111
+ if self.use_hooked_scheduler:
112
+ self.pipe.scheduler.pre_hooks = []
113
+ self.pipe.scheduler.post_hooks = []
114
+
115
+ cache_dict = {}
116
+ if save_input:
117
+ for position, block in cache_input.items():
118
+ cache_input[position] = torch.stack(block, dim=1)
119
+ cache_dict['input'] = cache_input
120
+
121
+ if save_output:
122
+ for position, block in cache_output.items():
123
+ cache_output[position] = torch.stack(block, dim=1)
124
+ cache_dict['output'] = cache_output
125
+ return output, cache_dict
126
+
127
+
128
+ def run_with_hooks_and_cache(self,
129
+ *args,
130
+ position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
131
+ positions_to_cache: List[str] = [],
132
+ save_input: bool = False,
133
+ save_output: bool = True,
134
+ **kwargs
135
+ ):
136
+ '''
137
+ Run the pipeline with hooks and caching at specified positions.
138
+
139
+ This method allows you to register hooks at certain positions in the pipeline and
140
+ cache intermediate inputs and/or outputs at specified positions. Hooks can be used
141
+ for inspecting or modifying the pipeline's execution, and caching stores intermediate
142
+ values for later inspection or use.
143
+
144
+ Args:
145
+ *args: Arguments to pass to the pipeline.
146
+ position_hook_dict Dict[str, Union[Callable, List[Callable]]]:
147
+ A dictionary where the keys are the positions in the pipeline, and the values
148
+ are hooks (either a single hook or a list of hooks) to be registered at those positions.
149
+ Each hook should be a callable that accepts three arguments: (module, input, output).
150
+ positions_to_cache (List[str], optional): A list of positions in the pipeline where
151
+ intermediate inputs/outputs should be cached. Defaults to an empty list.
152
+ save_input (bool, optional): If True, caches the input at each specified position.
153
+ Defaults to False.
154
+ save_output (bool, optional): If True, caches the output at each specified position.
155
+ Defaults to True.
156
+ **kwargs: Additional keyword arguments to pass to the pipeline.
157
+
158
+ Returns:
159
+ final_output: The final output of the pipeline after execution.
160
+ cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
161
+ and values are dictionaries containing the cached 'input' and/or 'output' at each position,
162
+ depending on the flags `save_input` and `save_output`.
163
+ '''
164
+ cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
165
+ hooks = [
166
+ self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
167
+ ]
168
+
169
+ for position, hook in position_hook_dict.items():
170
+ if isinstance(hook, list):
171
+ for h in hook:
172
+ hooks.append(self._register_general_hook(position, h))
173
+ else:
174
+ hooks.append(self._register_general_hook(position, hook))
175
+
176
+ hooks = [hook for hook in hooks if hook is not None]
177
+ output = self.pipe(*args, **kwargs)
178
+ for hook in hooks:
179
+ hook.remove()
180
+ if self.use_hooked_scheduler:
181
+ self.pipe.scheduler.pre_hooks = []
182
+ self.pipe.scheduler.post_hooks = []
183
+
184
+ cache_dict = {}
185
+ if save_input:
186
+ for position, block in cache_input.items():
187
+ cache_input[position] = torch.stack(block, dim=1)
188
+ cache_dict['input'] = cache_input
189
+
190
+ if save_output:
191
+ for position, block in cache_output.items():
192
+ cache_output[position] = torch.stack(block, dim=1)
193
+ cache_dict['output'] = cache_output
194
+
195
+ return output, cache_dict
196
+
197
+
198
+ def _locate_block(self, position: str):
199
+ '''
200
+ Locate the block at the specified position in the pipeline.
201
+ '''
202
+ block = self.pipe
203
+ for step in position.split('.'):
204
+ if step.isdigit():
205
+ step = int(step)
206
+ block = block[step]
207
+ else:
208
+ block = getattr(block, step)
209
+ return block
210
+
211
+
212
+ def _register_cache_hook(self, position: str, cache_input: Dict, cache_output: Dict):
213
+
214
+ if position.endswith('$self_attention') or position.endswith('$cross_attention'):
215
+ return self._register_cache_attention_hook(position, cache_output)
216
+
217
+ if position == 'noise':
218
+ def hook(model_output, timestep, sample, generator):
219
+ if position not in cache_output:
220
+ cache_output[position] = []
221
+ cache_output[position].append(sample)
222
+
223
+ if self.use_hooked_scheduler:
224
+ self.pipe.scheduler.post_hooks.append(hook)
225
+ else:
226
+ raise ValueError('Cannot cache noise without using hooked scheduler')
227
+ return
228
+
229
+ block = self._locate_block(position)
230
+
231
+ def hook(module, input, kwargs, output):
232
+ if cache_input is not None:
233
+ if position not in cache_input:
234
+ cache_input[position] = []
235
+ cache_input[position].append(retrieve(input))
236
+
237
+ if cache_output is not None:
238
+ if position not in cache_output:
239
+ cache_output[position] = []
240
+ cache_output[position].append(retrieve(output))
241
+
242
+ return block.register_forward_hook(hook, with_kwargs=True)
243
+
244
+ def _register_cache_attention_hook(self, position, cache):
245
+ attn_block = self._locate_block(position.split('$')[0])
246
+ if position.endswith('$self_attention'):
247
+ attn_block = attn_block.attn1
248
+ elif position.endswith('$cross_attention'):
249
+ attn_block = attn_block.attn2
250
+ else:
251
+ raise ValueError('Wrong attention type')
252
+
253
+ def hook(module, args, kwargs, output):
254
+ hidden_states = args[0]
255
+ encoder_hidden_states = kwargs['encoder_hidden_states']
256
+ attention_mask = kwargs['attention_mask']
257
+ batch_size, sequence_length, _ = hidden_states.shape
258
+ attention_mask = attn_block.prepare_attention_mask(attention_mask, sequence_length, batch_size)
259
+ query = attn_block.to_q(hidden_states)
260
+
261
+
262
+ if encoder_hidden_states is None:
263
+ encoder_hidden_states = hidden_states
264
+ elif attn_block.norm_cross is not None:
265
+ encoder_hidden_states = attn_block.norm_cross(encoder_hidden_states)
266
+
267
+ key = attn_block.to_k(encoder_hidden_states)
268
+ value = attn_block.to_v(encoder_hidden_states)
269
+
270
+ query = attn_block.head_to_batch_dim(query)
271
+ key = attn_block.head_to_batch_dim(key)
272
+ value = attn_block.head_to_batch_dim(value)
273
+
274
+ attention_probs = attn_block.get_attention_scores(query, key, attention_mask)
275
+ attention_probs = attention_probs.view(
276
+ batch_size,
277
+ attention_probs.shape[0] // batch_size,
278
+ attention_probs.shape[1],
279
+ attention_probs.shape[2]
280
+ )
281
+ if position not in cache:
282
+ cache[position] = []
283
+ cache[position].append(attention_probs)
284
+
285
+ return attn_block.register_forward_hook(hook, with_kwargs=True)
286
+
287
+ def _register_general_hook(self, position, hook):
288
+ if position == 'scheduler_pre':
289
+ if not self.use_hooked_scheduler:
290
+ raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
291
+ self.pipe.scheduler.pre_hooks.append(hook)
292
+ return
293
+ elif position == 'scheduler_post':
294
+ if not self.use_hooked_scheduler:
295
+ raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
296
+ self.pipe.scheduler.post_hooks.append(hook)
297
+ return
298
+
299
+ block = self._locate_block(position)
300
+ return block.register_forward_hook(hook)
301
+
302
+ def to(self, *args, **kwargs):
303
+ self.pipe = self.pipe.to(*args, **kwargs)
304
+ return self
305
+
306
+ def __getattr__(self, name):
307
+ return getattr(self.pipe, name)
308
+
309
+ def __setattr__(self, name, value):
310
+ return setattr(self.pipe, name, value)
311
+
312
+ def __call__(self, *args, **kwargs):
313
+ return self.pipe(*args, **kwargs)
314
+
315
+
316
+ class HookedStableDiffusionXLPipeline(HookedDiffusionAbstractPipeline):
317
+ parent_cls = StableDiffusionXLPipeline
318
+
319
+
320
+ class HookedIFPipeline(HookedDiffusionAbstractPipeline):
321
+ parent_cls = IFPipeline
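A short usage sketch of the two entry points above (model id, block path, and one-step settings mirror scripts/collect_latents_dataset.py; the prompt is illustrative and the hook only inspects, it does not modify the output):

import torch
from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline, retrieve

pipe = HookedStableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo")
pipe.to("cuda")
block = "unet.down_blocks.2.attentions.1"

out, cache = pipe.run_with_cache(
    "a photo of a corgi",                      # illustrative prompt
    positions_to_cache=[block],
    save_input=True, save_output=True,
    num_inference_steps=1, guidance_scale=0.0,
)
print(cache["output"][block].shape)            # (batch, steps, channels, height, width)

def inspect(module, inp, output):
    # forward hooks get (module, input, output); returning None leaves the output untouched
    print(type(module).__name__, retrieve(output).norm().item())

image = pipe.run_with_hooks(
    "a photo of a corgi",
    position_hook_dict={block: inspect},
    num_inference_steps=1, guidance_scale=0.0,
).images[0]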
app.py ADDED
@@ -0,0 +1,554 @@
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ from PIL import Image
5
+ from SDLens import HookedStableDiffusionXLPipeline
6
+ from SAE import SparseAutoencoder
7
+ from utils import TimedHook, add_feature_on_area_base, replace_with_feature_base, add_feature_on_area_turbo, replace_with_feature_turbo
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib.colors import ListedColormap
11
+ import threading
12
+ import spaces
13
+ code_to_block = {
14
+ "down.2.1": "unet.down_blocks.2.attentions.1",
15
+ "mid.0": "unet.mid_block.attentions.0",
16
+ "up.0.1": "unet.up_blocks.0.attentions.1",
17
+ "up.0.0": "unet.up_blocks.0.attentions.0"
18
+ }
19
+ lock = threading.Lock()
20
+
21
+ base_guidance_scale_default = 8.0
22
+ turbo_guidance_scale_default = 0.0
23
+
24
+
25
+ def process_cache(cache, saes_dict, timestep=None):
26
+
27
+ top_features_dict = {}
28
+ sparse_maps_dict = {}
29
+
30
+ for code in code_to_block.keys():
31
+ block = code_to_block[code]
32
+ sae = saes_dict[code]
33
+
34
+ diff = cache["output"][block] - cache["input"][block]
35
+ if diff.shape[0] == 2: # guidance is on and we need to select the second output
36
+ diff = diff[1].unsqueeze(0)
37
+
38
+ # If a specific timestep is provided, select that timestep from the cached activations
39
+ if timestep is not None and timestep < diff.shape[1]:
40
+ diff = diff[:, timestep:timestep+1]
41
+
42
+ diff = diff.permute(0, 1, 3, 4, 2).squeeze(0).squeeze(0)
43
+ with torch.no_grad():
44
+ sparse_maps = sae.encode(diff)
45
+ averages = torch.mean(sparse_maps, dim=(0, 1))
46
+
47
+ top_features = torch.topk(averages, 10).indices
48
+
49
+ top_features_dict[code] = top_features.cpu().tolist()
50
+ sparse_maps_dict[code] = sparse_maps.cpu().numpy()
51
+
52
+ return top_features_dict, sparse_maps_dict
53
+
54
+
55
+ def plot_image_heatmap(cache, block_select, radio):
56
+ code = block_select.split()[0]
57
+ feature = int(radio)
58
+ block = code_to_block[code]
59
+
60
+ heatmap = cache["heatmaps"][code][:, :, feature]
61
+ heatmap = np.kron(heatmap, np.ones((32, 32)))
62
+ image = cache["image"].convert("RGBA")
63
+
64
+ jet = plt.cm.jet
65
+ cmap = jet(np.arange(jet.N))
66
+ cmap[:1, -1] = 0
67
+ cmap[1:, -1] = 0.6
68
+ cmap = ListedColormap(cmap)
69
+ heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))
70
+ heatmap_rgba = cmap(heatmap)
71
+ heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8))
72
+ heatmap_with_transparency = Image.alpha_composite(image, heatmap_image)
73
+
74
+ return heatmap_with_transparency
75
+
76
+
77
+ def create_prompt_part(pipe, saes_dict, demo):
78
+ @spaces.GPU
79
+ def image_gen(prompt, timestep=None, num_steps=None, guidance_scale=None):
80
+ lock.acquire()
81
+ try:
82
+ # Default values
83
+ is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
84
+ default_n_steps = 25 if is_base_model else 1
85
+ default_guidance = base_guidance_scale_default if is_base_model else turbo_guidance_scale_default
86
+
87
+ # Use provided values if available, otherwise use defaults
88
+ n_steps = default_n_steps if num_steps is None else int(num_steps)
89
+ guidance = default_guidance if guidance_scale is None else float(guidance_scale)
90
+
91
+ # Convert timestep to integer if it's not None
92
+ timestep_int = None if timestep is None else int(timestep)
93
+
94
+ images, cache = pipe.run_with_cache(
95
+ prompt,
96
+ positions_to_cache=list(code_to_block.values()),
97
+ num_inference_steps=n_steps,
98
+ generator=torch.Generator(device="cpu").manual_seed(42),
99
+ guidance_scale=guidance,
100
+ save_input=True,
101
+ save_output=True
102
+ )
103
+ finally:
104
+ lock.release()
105
+
106
+ top_features_dict, top_sparse_maps_dict = process_cache(cache, saes_dict, timestep_int)
107
+ return images.images[0], {
108
+ "image": images.images[0],
109
+ "heatmaps": top_sparse_maps_dict,
110
+ "features": top_features_dict
111
+ }
112
+
113
+ def update_radio(cache, block_select):
114
+ code = block_select.split()[0]
115
+ return gr.update(choices=cache["features"][code])
116
+
117
+ def update_img(cache, block_select, radio):
118
+ new_img = plot_image_heatmap(cache, block_select, radio)
119
+ return new_img
120
+
121
+ def update_visibility():
122
+ is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
123
+ return gr.update(visible=is_base_model), gr.update(visible=is_base_model)
124
+
125
+ with gr.Tab("Explore", elem_classes="tabs") as explore_tab:
126
+ cache = gr.State(value={
127
+ "image": None,
128
+ "heatmaps": None,
129
+ "features": []
130
+ })
131
+ with gr.Row():
132
+ with gr.Column(scale=7):
133
+ with gr.Row(equal_height=True):
134
+ prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party and eating a dish with peas.")
135
+ button = gr.Button("Generate", elem_classes="generate_button1")
136
+
137
+ with gr.Row():
138
+ image = gr.Image(width=512, height=512, image_mode="RGB", label="Generated image")
139
+
140
+ with gr.Column(scale=4):
141
+ block_select = gr.Dropdown(
142
+ choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
143
+ value="down.2.1 (composition)",
144
+ label="Select block",
145
+ elem_id="block_select",
146
+ interactive=True
147
+ )
148
+
149
+ # Add SDXL base specific controls
150
+ is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
151
+
152
+ with gr.Group() as sdxl_base_controls:
153
+ steps_slider = gr.Slider(
154
+ minimum=1,
155
+ maximum=50,
156
+ value=25 if is_base_model else 1,
157
+ step=1,
158
+ label="Number of steps",
159
+ elem_id="steps_slider",
160
+ interactive=True,
161
+ visible=is_base_model
162
+ )
163
+
164
+ guidance_slider = gr.Slider(
165
+ minimum=0.0,
166
+ maximum=15.0,
167
+ value=base_guidance_scale_default if is_base_model else turbo_guidance_scale_default,
168
+ step=0.1,
169
+ label="Guidance scale",
170
+ elem_id="guidance_slider",
171
+ interactive=True,
172
+ visible=is_base_model
173
+ )
174
+
175
+ # Add timestep selector
176
+ n_steps = 25 if is_base_model else 1
177
+ timestep_selector = gr.Slider(
178
+ minimum=0,
179
+ maximum=n_steps-1,
180
+ value=None,
181
+ step=1,
182
+ label="Timestep (leave empty for average across all steps)",
183
+ elem_id="timestep_selector",
184
+ interactive=True,
185
+ visible=is_base_model
186
+ )
187
+ recompute_button = gr.Button("Recompute", elem_id="recompute_button",
188
+ visible=is_base_model)
189
+
190
+ # Update max timestep when steps change
191
+ steps_slider.change(lambda s: gr.update(maximum=s-1), [steps_slider], [timestep_selector])
192
+
193
+ radio = gr.Radio(choices=[], label="Select a feature", interactive=True)
194
+
195
+ button.click(image_gen, [prompt_field, timestep_selector, steps_slider, guidance_slider], outputs=[image, cache])
196
+ cache.change(update_radio, [cache, block_select], outputs=[radio])
197
+ block_select.select(update_radio, [cache, block_select], outputs=[radio])
198
+ radio.select(update_img, [cache, block_select, radio], outputs=[image])
199
+ recompute_button.click(image_gen, [prompt_field, timestep_selector, steps_slider, guidance_slider], outputs=[image, cache])
200
+ demo.load(image_gen, [prompt_field, timestep_selector, steps_slider, guidance_slider], outputs=[image, cache])
201
+
202
+ return explore_tab
203
+
204
+ def downsample_mask(image, factor):
205
+ downsampled = image.reshape(
206
+ (image.shape[0] // factor, factor,
207
+ image.shape[1] // factor, factor)
208
+ )
209
+ downsampled = downsampled.mean(axis=(1, 3))
210
+ return downsampled
211
+
212
+ def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
213
+ @spaces.GPU
214
+ def image_gen(prompt, num_steps, guidance_scale=None):
215
+ lock.acquire()
216
+ is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
217
+ default_guidance = base_guidance_scale_default if is_base_model else turbo_guidance_scale_default
218
+ guidance = default_guidance if guidance_scale is None else float(guidance_scale)
219
+ try:
220
+ images = pipe.run_with_hooks(
221
+ prompt,
222
+ position_hook_dict={},
223
+ num_inference_steps=int(num_steps),
224
+ generator=torch.Generator(device="cpu").manual_seed(42),
225
+ guidance_scale=guidance,
226
+ )
227
+ finally:
228
+ lock.release()
229
+ if images.images[0].size == (1024, 1024):
230
+ return images.images[0].resize((512, 512)), images.images[0].resize((512, 512))
231
+ else:
232
+ return images.images[0], images.images[0]
233
+
234
+ @spaces.GPU
235
+ def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image, guidance_scale=None, start_index=None, end_index=None):
236
+ block = block_str.split(" ")[0]
237
+ is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
238
+ mask = (input_image["layers"][0] > 0)[:, :, -1].astype(float)
239
+ if is_base_model:
240
+ mask = downsample_mask(mask, 16)
241
+ else:
242
+ mask = downsample_mask(mask, 32)
243
+ mask = torch.tensor(mask, dtype=torch.float32, device="cuda")
244
+
245
+ if mask.sum() == 0:
246
+ gr.Info("No mask selected, please draw on the input image")
247
+
248
+ if is_base_model:
249
+ # Set default values for start_index and end_index if not provided
250
+ if start_index is None:
251
+ start_index = 0
252
+ if end_index is None:
253
+ end_index = int(num_steps)
254
+
255
+ # Ensure start_index and end_index are within valid ranges
256
+ start_index = max(0, min(int(start_index), int(num_steps)))
257
+ end_index = max(0, min(int(end_index), int(num_steps)))
258
+
259
+ # Ensure start_index is less than end_index
260
+ if start_index >= end_index:
261
+ start_index = max(0, end_index - 1)
262
+ def myhook(module, input, output):
263
+ return add_feature_on_area_base(
264
+ saes_dict[block],
265
+ brush_index,
266
+ mask * means_dict[block][brush_index] * strength,
267
+ module,
268
+ input,
269
+ output)
270
+ hook = TimedHook(myhook, int(num_steps), np.arange(start_index, end_index))
271
+ else:
272
+ def hook(module, input, output):
273
+ return add_feature_on_area_turbo(
274
+ saes_dict[block],
275
+ brush_index,
276
+ mask * means_dict[block][brush_index] * strength,
277
+ module,
278
+ input,
279
+ output)
280
+ lock.acquire()
281
+ is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
282
+ default_guidance = base_guidance_scale_default if is_base_model else turbo_guidance_scale_default
283
+ guidance = default_guidance if guidance_scale is None else float(guidance_scale)
284
+
285
+ try:
286
+ image = pipe.run_with_hooks(
287
+ prompt,
288
+ position_hook_dict={code_to_block[block]: hook},
289
+ num_inference_steps=int(num_steps),
290
+ generator=torch.Generator(device="cpu").manual_seed(42),
291
+ guidance_scale=guidance
292
+ ).images[0]
293
+ finally:
294
+ lock.release()
295
+ return image
296
+
297
+ @spaces.GPU
298
+ def feature_icon(block_str, brush_index, guidance_scale=None):
299
+ block = block_str.split(" ")[0]
300
+ if block in ["mid.0", "up.0.0"]:
301
+ gr.Info("Note that Feature Icon works best with down.2.1 and up.0.1 blocks but feel free to explore", duration=3)
302
+
303
+ def hook(module, input, output):
304
+ if is_base_model:
305
+ return replace_with_feature_base(
306
+ saes_dict[block],
307
+ brush_index,
308
+ means_dict[block][brush_index] * saes_dict[block].k,
309
+ module,
310
+ input,
311
+ output
312
+ )
313
+ else:
314
+ return replace_with_feature_turbo(
315
+ saes_dict[block],
316
+ brush_index,
317
+ means_dict[block][brush_index] * saes_dict[block].k,
318
+ module,
319
+ input,
320
+ output)
321
+ lock.acquire()
322
+ is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
323
+ n_steps = 25 if is_base_model else 1
324
+ default_guidance = base_guidance_scale_default if is_base_model else turbo_guidance_scale_default
325
+ guidance = default_guidance if guidance_scale is None else float(guidance_scale)
326
+
327
+ try:
328
+ image = pipe.run_with_hooks(
329
+ "",
330
+ position_hook_dict={code_to_block[block]: hook},
331
+ num_inference_steps=n_steps,
332
+ generator=torch.Generator(device="cpu").manual_seed(42),
333
+ guidance_scale=guidance,
334
+ ).images[0]
335
+ finally:
336
+ lock.release()
337
+ return image
338
+
339
+ with gr.Tab("Paint!", elem_classes="tabs") as intervene_tab:
340
+ image_state = gr.State(value=None)
341
+ with gr.Row():
342
+ with gr.Column(scale=3):
343
+ # Generation column
344
+ with gr.Row():
345
+ # prompt and num_steps
346
+ is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
347
+ n_steps = 25 if is_base_model else 1
348
+ prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A dog plays with a ball, closeup", elem_id="prompt_input")
349
+
350
+ with gr.Row():
351
+ num_steps = gr.Number(value=n_steps, label="Number of steps", minimum=1, maximum=50, elem_id="num_steps", precision=0)
352
+ guidance_slider = gr.Slider(
353
+ minimum=0.0,
354
+ maximum=15.0,
355
+ value=base_guidance_scale_default if is_base_model else turbo_guidance_scale_default,
356
+ step=0.1,
357
+ label="Guidance scale",
358
+ elem_id="paint_guidance_slider",
359
+ interactive=True,
360
+ visible=is_base_model
361
+ )
362
+
363
+ with gr.Row():
364
+ # Generate button
365
+ button_generate = gr.Button("Generate", elem_id="generate_button")
366
+ with gr.Column(scale=3):
367
+ # Intervention column
368
+ with gr.Row():
369
+ # dropdowns and number inputs
370
+ with gr.Column(scale=7):
371
+ with gr.Row():
372
+ block_select = gr.Dropdown(
373
+ choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
374
+ value="down.2.1 (composition)",
375
+ label="Select block",
376
+ elem_id="block_select"
377
+ )
378
+ brush_index = gr.Number(value=4998, label="Brush index", minimum=0, maximum=5119, elem_id="brush_index", precision=0)
379
+ with gr.Row():
380
+ button_icon = gr.Button('Feature Icon', elem_id="feature_icon_button")
381
+ with gr.Row():
382
+ gr.Markdown("**TimedHook Range** (which steps to apply the feature)", visible=is_base_model)
383
+ with gr.Row():
384
+ start_index = gr.Number(value=5 if is_base_model else 0, label="Start index", minimum=0, maximum=n_steps, elem_id="start_index", precision=0, visible=is_base_model)
385
+ end_index = gr.Number(value=20 if is_base_model else 1, label="End index", minimum=0, maximum=n_steps, elem_id="end_index", precision=0, visible=is_base_model)
386
+ with gr.Column(scale=3):
387
+ with gr.Row():
388
+ strength = gr.Number(value=10, label="Strength", minimum=-40, maximum=40, elem_id="strength", precision=2)
389
+ with gr.Row():
390
+ button = gr.Button('Apply', elem_id="apply_button")
391
+
392
+ with gr.Row():
393
+ with gr.Column():
394
+ # Input image
395
+ i_image = gr.Sketchpad(
396
+ height=600,
397
+ layers=False, transforms=None, placeholder="Generate and paint!",
398
+ container=False,
399
+ brush=gr.Brush(default_size=40, color_mode="fixed", colors=['black']),
400
+ canvas_size=(512, 512),
401
+ label="Input Image")
402
+ clear_button = gr.Button("Clear")
403
+ clear_button.click(lambda x: x, [image_state], [i_image])
404
+ # Output image
405
+ o_image = gr.Image(width=512, height=512, label="Output Image")
406
+
407
+ # Set up the click events
408
+ button_generate.click(image_gen, inputs=[prompt_field, num_steps, guidance_slider], outputs=[image_state, o_image])
409
+ image_state.change(lambda x: x, [image_state], [i_image])
410
+
411
+ if is_base_model:
412
+ # Update max values for start_index and end_index when num_steps changes
413
+ def update_index_maxes(steps):
414
+ return gr.update(maximum=steps), gr.update(maximum=steps)
415
+
416
+ num_steps.change(update_index_maxes, [num_steps], [start_index, end_index])
417
+
418
+ button.click(image_mod,
419
+ inputs=[prompt_field, block_select, brush_index, strength, num_steps, i_image, guidance_slider, start_index, end_index],
420
+ outputs=o_image)
421
+ button_icon.click(feature_icon, inputs=[block_select, brush_index, guidance_slider], outputs=o_image)
422
+ demo.load(image_gen, [prompt_field, num_steps, guidance_slider], outputs=[image_state, o_image])
423
+
424
+
425
+ return intervene_tab
426
+
427
+ def create_top_images_part(demo):
428
+ def update_top_images(block_select, brush_index):
429
+ block = block_select.split(" ")[0]
430
+ url = f"https://huggingface.co/datasets/surokpro2/sdxl_sae_images/resolve/main/{block}/{brush_index}.jpg"
431
+ return url
432
+
433
+ with gr.Tab("Top Images", elem_classes="tabs") as top_images_tab:
434
+ with gr.Row():
435
+ block_select = gr.Dropdown(
436
+ choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
437
+ value="down.2.1 (composition)",
438
+ label="Select blk"
439
+ )
440
+ brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=5119, precision=0)
441
+ with gr.Row():
442
+ image = gr.Image(width=600, height=600, label="Top Images")
443
+
444
+ block_select.select(update_top_images, [block_select, brush_index], outputs=[image])
445
+ brush_index.change(update_top_images, [block_select, brush_index], outputs=[image])
446
+ demo.load(update_top_images, [block_select, brush_index], outputs=[image])
447
+ return top_images_tab
448
+
449
+
450
+ def create_intro_part():
451
+ with gr.Tab("Instructions", elem_classes="tabs") as intro_tab:
452
+ gr.Markdown(
453
+ '''# One-Step is Enough: Sparse Autoencoders for Text-to-Image Diffusion Models
454
+ ## Stable Diffusion XL multistep version
455
+
456
+ ## Note
457
+ If you encounter GPU time limit errors, don't worry, the app still works and you can use it freely.
458
+
459
+ ## Demo Overview
460
+ This demo showcases the use of Sparse Autoencoders (SAEs) to understand the features learned by the Stable Diffusion XL (Turbo) model.
461
+
462
+ ## How to Use
463
+ ### Explore
464
+ * Enter a prompt in the text box and click on the "Generate" button to generate an image.
465
+ * You can observe the active features in different blocks plotted on top of the generated image.
466
+ ### Top Images
467
+ * For each feature, you can view the top images that activate the feature the most.
468
+ ### Paint!
469
+ * Generate an image using the prompt.
470
+ * Paint on the generated image to apply interventions.
471
+ * Use the "Feature Icon" button to understand how the selected brush functions.
472
+
473
+ ### Remarks
474
+ * Not all brushes mix well with all images. Experiment with different brushes and strengths.
475
+ * Feature Icon works best with `down.2.1 (composition)` and `up.0.1 (style)` blocks.
476
+ * This demo is provided for research purposes only. We do not take responsibility for the content generated by the demo.
477
+
478
+ ### Interesting features to try
479
+ To get started, try the following features:
480
+ - down.2.1 (composition): 2301 (evil) 3747 (image frame) 4998 (cartoon)
481
+ - up.0.1 (style): 4977 (tiger stripes) 90 (fur) 2615 (twilight blur)
482
+ '''
483
+ )
484
+
485
+ return intro_tab
486
+
487
+
488
+ def create_demo(pipe, saes_dict, means_dict):
489
+ custom_css = """
490
+ .tabs button {
491
+ font-size: 20px !important; /* Adjust font size for tab text */
492
+ padding: 10px !important; /* Adjust padding to make the tabs bigger */
493
+ font-weight: bold !important; /* Adjust font weight to make the text bold */
494
+ }
495
+ .generate_button1 {
496
+ max-width: 160px !important;
497
+ margin-top: 20px !important;
498
+ margin-bottom: 20px !important;
499
+ }
500
+ """
501
+
502
+ with gr.Blocks(css=custom_css) as demo:
503
+ with create_intro_part():
504
+ pass
505
+ with create_prompt_part(pipe, saes_dict, demo):
506
+ pass
507
+ with create_top_images_part(demo):
508
+ pass
509
+ with create_intervene_part(pipe, saes_dict, means_dict, demo):
510
+ pass
511
+
512
+ return demo
513
+
514
+ if __name__ == "__main__":
515
+ import os
516
+ import gradio as gr
517
+ import torch
518
+ from SDLens import HookedStableDiffusionXLPipeline
519
+ from SAE import SparseAutoencoder
520
+
521
+ dtype=torch.float32
522
+ pipe = HookedStableDiffusionXLPipeline.from_pretrained(
523
+ 'stabilityai/stable-diffusion-xl-base-1.0',
524
+ torch_dtype=dtype,
525
+ variant=("fp16" if dtype==torch.float16 else None)
526
+ )
527
+ pipe.set_progress_bar_config(disable=True)
528
+ pipe.to('cuda')
529
+
530
+ path_to_checkpoints = './checkpoints/'
531
+
532
+ code_to_block = {
533
+ "down.2.1": "unet.down_blocks.2.attentions.1",
534
+ "mid.0": "unet.mid_block.attentions.0",
535
+ "up.0.1": "unet.up_blocks.0.attentions.1",
536
+ "up.0.0": "unet.up_blocks.0.attentions.0"
537
+ }
538
+
539
+ saes_dict = {}
540
+ means_dict = {}
541
+
542
+ for code, block in code_to_block.items():
543
+ sae = SparseAutoencoder.load_from_disk(
544
+ os.path.join(path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final"),
545
+ )
546
+ means = torch.load(
547
+ os.path.join(path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final", "mean.pt"),
548
+ weights_only=True
549
+ )
550
+ saes_dict[code] = sae.to('cuda', dtype=dtype)
551
+ means_dict[code] = means.to('cuda', dtype=dtype)
552
+
553
+ demo = create_demo(pipe, saes_dict, means_dict)
554
+ demo.launch()
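For reference, a rough headless sketch of what the Explore tab computes (it mirrors process_cache above and assumes the pipe, checkpoint layout, and block codes from the __main__ block):

import torch
from SAE import SparseAutoencoder

block = "unet.down_blocks.2.attentions.1"
sae = SparseAutoencoder.load_from_disk(
    "./checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final"
).to("cuda")

_, cache = pipe.run_with_cache(
    "A dog plays with a ball, closeup",
    positions_to_cache=[block],
    save_input=True, save_output=True,
    num_inference_steps=25, guidance_scale=8.0,
    generator=torch.Generator(device="cpu").manual_seed(42),
)
diff = cache["output"][block] - cache["input"][block]    # the block's residual update
if diff.shape[0] == 2:                                   # guidance on: keep the conditional branch
    diff = diff[1].unsqueeze(0)
diff = diff.permute(0, 1, 3, 4, 2)                       # (batch, steps, H, W, channels)
flat = diff.reshape(-1, diff.shape[-1])                  # every spatial position at every step
with torch.no_grad():
    acts = sae.encode(flat)                              # (N, 5120) sparse feature activations
print(torch.topk(acts.mean(dim=0), 10).indices.tolist()) # ten most active feature indices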
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:387f2b6f8c4e4a6f1227921f28f00dfa4beb2bd4e422b7eb592cd8627af0e58f
3
+ size 21581
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39e3c6d17aa572a53368ca8ba8f82757947a3caf14fe654e84b175d0dc0a4650
3
+ size 52497831
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6ca694c9504a7a8aa827004d3fdec5c1cb8fcf3904acc3562d1861fc6e65c19
3
+ size 21576
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80790481d0e56ac3fa36599703cee7a05cfb4cc078db57c8f9180e860c330e1d
3
+ size 21581
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49d38d9178c2a2780e04a5482a2feb9548c6e9a636ed1bf85291acf42e0ffa34
3
+ size 52497831
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb6bfc7ce5e596f8aa048ab262ca56841868c222bf07eb2ed35b6e4f7094fea6
3
+ size 21576
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de036d0fb9ee663f7bdf60e4a5d89d038516dae637531676b53ff75d05eab46b
3
+ size 21581
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14c45efd9cce0258f014c49babdcd0e9ce8b266fe31eed72db1a45b990a1a0f8
3
+ size 52497831
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb9c04499ccae041987cc262894e254c2f04288857a8a0470cfb1b86a8ecfa09
3
+ size 21576
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96dbf6fffe9d62c3b3352f8e4fe48c54dfd69906cf8ad6828d5ce93db9a5f0dc
3
+ size 21581
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8eed82f4bcb2f010ae9075f10a1ece801ee3dec46dba7fadccc35f6c0a7836b
3
+ size 52497831
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe5c5be0c4c2d2b57e7888319053cb64929559f947c8ce445ddd6a397302afab
3
+ size 21576
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ diffusers==0.29.2
+ --extra-index-url https://download.pytorch.org/whl/cu113
+ torch
+ numpy
+ matplotlib
+ pillow
+ wandb
+ einops
+ transformers
+ accelerate
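Note: app.py and the scripts additionally import gradio, spaces, webdataset, datasets, fire, and tqdm, none of which are pinned here; on a Hugging Face Space gradio (per sdk_version in README.md) and spaces should be supplied by the runtime, but running the scripts locally requires installing the others separately.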
scripts/collect_latents_dataset.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import sys
+ import io
+ import tarfile
+ import torch
+ import webdataset as wds
+ import numpy as np
+
+ from tqdm import tqdm
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+ from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline
+
+ import datetime
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader
+ import diffusers
+ import fire
+
+ def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50):
+     blocks_to_save = [
+         'unet.down_blocks.2.attentions.1',
+         'unet.mid_block.attentions.0',
+         'unet.up_blocks.0.attentions.0',
+         'unet.up_blocks.0.attentions.1',
+     ]
+
+     # Initialization
+     dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42)
+     pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
+     pipe.to('cuda')
+     pipe.set_progress_bar_config(disable=True)
+     dataloader = DataLoader(dataset, batch_size=dataset_batch_size)
+
+     ct = datetime.datetime.now()
+     save_path = os.path.join(save_path, str(ct))
+     # Collecting dataset
+     os.makedirs(save_path, exist_ok=True)
+
+     writers = {
+         block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save
+     }
+
+     writers.update({'images': wds.TarWriter(f'{save_path}/images.tar')})
+
+     def to_kwargs(kwargs_to_save):
+         kwargs = kwargs_to_save.copy()
+         seed = kwargs['seed']
+         del kwargs['seed']
+         kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
+         return kwargs
+
+     dataloader_iter = iter(dataloader)
+     for num_document, batch in tqdm(enumerate(dataloader)):
+         if num_document < start_at:
+             continue
+
+         if num_document >= finish_at:
+             break
+
+         kwargs_to_save = {
+             'prompt': batch['caption'],
+             'positions_to_cache': blocks_to_save,
+             'save_input': True,
+             'save_output': True,
+             'num_inference_steps': 1,
+             'guidance_scale': 0.0,
+             'seed': num_document,
+             'output_type': 'pil'
+         }
+
+         kwargs = to_kwargs(kwargs_to_save)
+
+         output, cache = pipe.run_with_cache(
+             **kwargs
+         )
+
+         blocks = cache['input'].keys()
+         for block in blocks:
+             sample = {
+                 "__key__": f"sample_{num_document}",
+                 "output.pth": cache['output'][block],
+                 "diff.pth": cache['output'][block] - cache['input'][block],
+                 "gen_args.json": kwargs_to_save
+             }
+
+             writers[block].write(sample)
+         writers['images'].write({
+             "__key__": f"sample_{num_document}",
+             "images.npy": np.stack(output.images)
+         })
+
+     for block, writer in writers.items():
+         writer.close()
+
+ if __name__ == '__main__':
+     fire.Fire(main)
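Since main is exposed through fire.Fire, its keyword arguments double as command-line flags; an illustrative invocation (the save path is a placeholder) would be

    python scripts/collect_latents_dataset.py --save_path /data/latents --start_at 0 --finish_at 30000

Each run writes a timestamped subdirectory under save_path with one tar archive per hooked block plus images.tar; the per-block archives are what the training script's paths_to_latents entries are expected to point to.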
scripts/train_sae.py ADDED
@@ -0,0 +1,308 @@
+ '''
+ Adapted from
+ https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/train.py
+ '''
+
+
+ import os
+ import sys
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+ from typing import Callable, Iterable, Iterator
+
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.distributed import ReduceOp
+ from SAE.dataset_iterator import ActivationsDataloader
+ from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_
+ from SAE.sae_utils import SAETrainingConfig, Config
+
+ from types import SimpleNamespace
+ from typing import Optional, List
+ import json
+
+ import tqdm
+
+ def weighted_average(points: torch.Tensor, weights: torch.Tensor):
+     weights = weights / weights.sum()
+     return (points * weights.view(-1, 1)).sum(dim=0)
+
+
+ @torch.no_grad()
+ def geometric_median_objective(
+     median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
+ ) -> torch.Tensor:
+
+     norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
+
+     return (norms * weights).sum()
+
+
+ def compute_geometric_median(
+     points: torch.Tensor,
+     weights: Optional[torch.Tensor] = None,
+     eps: float = 1e-6,
+     maxiter: int = 100,
+     ftol: float = 1e-20,
+     do_log: bool = False,
+ ):
+     """
+     :param points: ``torch.Tensor`` of shape ``(n, d)``
+     :param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
+     :param eps: Smallest allowed value of denominator, to avoid divide by zero.
+         Equivalently, this is a smoothing parameter. Default 1e-6.
+     :param maxiter: Maximum number of Weiszfeld iterations. Default 100
+     :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
+     :param do_log: If true will return a log of function values encountered through the course of the algorithm
+     :return: SimpleNamespace object with fields
+         - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
+         - `termination`: string explaining how the algorithm terminated.
+         - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
+     """
+     with torch.no_grad():
+
+         if weights is None:
+             weights = torch.ones((points.shape[0],), device=points.device)
+         # initialize median estimate at mean
+         new_weights = weights
+         median = weighted_average(points, weights)
+         objective_value = geometric_median_objective(median, points, weights)
+         if do_log:
+             logs = [objective_value]
+         else:
+             logs = None
+
+         # Weiszfeld iterations
+         early_termination = False
+         pbar = tqdm.tqdm(range(maxiter))
+         for _ in pbar:
+             prev_obj_value = objective_value
+
+             norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
+             new_weights = weights / torch.clamp(norms, min=eps)
+             median = weighted_average(points, new_weights)
+             objective_value = geometric_median_objective(median, points, weights)
+
+             if logs is not None:
+                 logs.append(objective_value)
+             if abs(prev_obj_value - objective_value) <= ftol * objective_value:
+                 early_termination = True
+                 break
+
+             pbar.set_description(f"Objective value: {objective_value:.4f}")
+
+     median = weighted_average(points, new_weights)  # allow autodiff to track it
+     return SimpleNamespace(
+         median=median,
+         new_weights=new_weights,
+         termination=(
+             "function value converged within tolerance"
+             if early_termination
+             else "maximum iterations reached"
+         ),
+         logs=logs,
+     )
+
+ def maybe_transpose(x):
+     return x.T if not x.is_contiguous() and x.T.is_contiguous() else x
+
+ import wandb
+
+ RANK = 0
+
+ class Logger:
+     def __init__(self, sae_name, **kws):
+         self.vals = {}
+         self.enabled = (RANK == 0) and not kws.pop("dummy", False)
+         self.sae_name = sae_name
+
+     def logkv(self, k, v):
+         if self.enabled:
+             self.vals[f'{self.sae_name}/{k}'] = v.detach() if isinstance(v, torch.Tensor) else v
+         return v
+
+     def dumpkvs(self, step):
+         if self.enabled:
+             wandb.log(self.vals, step=step)
+             self.vals = {}
+
+
+ class FeaturesStats:
+     def __init__(self, dim, logger):
+         self.dim = dim
+         self.logger = logger
+         self.reinit()
+
+     def reinit(self):
+         self.n_activated = torch.zeros(self.dim, dtype=torch.long, device="cuda")
+         self.n = 0
+
+     def update(self, inds):
+         self.n += inds.shape[0]
+         inds = inds.flatten().detach()
+         self.n_activated.scatter_add_(0, inds, torch.ones_like(inds))
+
+     def log(self):
+         self.logger.logkv('activated', (self.n_activated / self.n + 1e-9).log10().cpu().numpy())
+
+ def training_loop_(
+     aes,
+     train_acts_iter,
+     loss_fn,
+     log_interval,
+     save_interval,
+     loggers,
+     sae_cfgs,
+ ):
+     sae_packs = []
+     for ae, cfg, logger in zip(aes, sae_cfgs, loggers):
+         pbar = tqdm.tqdm(unit=" steps", desc="Training Loss: ")
+         fstats = FeaturesStats(ae.n_dirs, logger)
+         opt = torch.optim.Adam(ae.parameters(), lr=cfg.lr, eps=cfg.eps, fused=True)
+         sae_packs.append((ae, cfg, logger, pbar, fstats, opt))
+
+     for i, flat_acts_train_batch in enumerate(train_acts_iter):
+         flat_acts_train_batch = flat_acts_train_batch.cuda()
+
+         for ae, cfg, logger, pbar, fstats, opt in sae_packs:
+             recons, info = ae(flat_acts_train_batch)
+             loss = loss_fn(ae, cfg, flat_acts_train_batch, recons, info, logger)
+
+             fstats.update(info['inds'])
+
+             bs = flat_acts_train_batch.shape[0]
+             logger.logkv('not-activated 1e4', (ae.stats_last_nonzero > 1e4 / bs).mean(dtype=float).item())
+             logger.logkv('not-activated 1e6', (ae.stats_last_nonzero > 1e6 / bs).mean(dtype=float).item())
+             logger.logkv('not-activated 1e7', (ae.stats_last_nonzero > 1e7 / bs).mean(dtype=float).item())
+
+             logger.logkv('explained variance', explained_variance(recons, flat_acts_train_batch))
+             logger.logkv('l2_div', (torch.linalg.norm(recons, dim=1) / torch.linalg.norm(flat_acts_train_batch, dim=1)).mean())
+
+             if (i + 1) % log_interval == 0:
+                 fstats.log()
+                 fstats.reinit()
+
+             if (i + 1) % save_interval == 0:
+                 ae.save_to_disk(f"{cfg.save_path}/{i + 1}")
+
+             loss.backward()
+
+             unit_norm_decoder_(ae)
+             unit_norm_decoder_grad_adjustment_(ae)
+
+             opt.step()
+             opt.zero_grad()
+             logger.dumpkvs(i)
+
+             pbar.set_description(f"Training Loss {loss.item():.4f}")
+             pbar.update(1)
+
+
+     for ae, cfg, logger, pbar, fstats, opt in sae_packs:
+         pbar.close()
+         ae.save_to_disk(f"{cfg.save_path}/final")
+
+
+ def init_from_data_(ae, stats_acts_sample):
+     ae.pre_bias.data = (
+         compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
+     )
+
+
+ def mse(recons, x):
+     # return ((recons - x) ** 2).sum(dim=-1).mean()
+     return ((recons - x) ** 2).mean()
+
+ def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
+     # only used for auxk
+     xs_mu = xs.mean(dim=0)
+
+     loss = mse(recon, xs) / mse(
+         xs_mu[None, :].broadcast_to(xs.shape), xs
+     )
+
+     return loss
+
+ def explained_variance(recons, x):
+     # Compute the variance of the difference
+     diff = x - recons
+     diff_var = torch.var(diff, dim=0, unbiased=False)
+
+     # Compute the variance of the original tensor
+     x_var = torch.var(x, dim=0, unbiased=False)
+
+     # Avoid division by zero
+     explained_var = 1 - diff_var / (x_var + 1e-8)
+
+     return explained_var.mean()
+
+
+ def main():
+     cfg = Config(json.load(open('SAE/config.json')))
+
+     dataloader = ActivationsDataloader(cfg.paths_to_latents, cfg.block_name, cfg.bs)
+
+     acts_iter = dataloader.iterate()
+     stats_acts_sample = torch.cat([
+         next(acts_iter).cpu() for _ in range(10)
+     ], dim=0)
+
+     aes = [
+         SparseAutoencoder(
+             n_dirs_local=sae.n_dirs,
+             d_model=sae.d_model,
+             k=sae.k,
+             auxk=sae.auxk,
+             dead_steps_threshold=sae.dead_toks_threshold // cfg.bs,
+         ).cuda()
+         for sae in cfg.saes
+     ]
+
+     for ae in aes:
+         init_from_data_(ae, stats_acts_sample)
+
+     mse_scale = (
+         1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
+     )
+     mse_scale = mse_scale.item()
+     del stats_acts_sample
+
+     wandb.init(
+         project=cfg.wandb_project,
+         name=cfg.wandb_name,
+     )
+
+     loggers = [Logger(
+         sae_name=cfg_sae.sae_name,
+         dummy=False,
+     ) for cfg_sae in cfg.saes]
+
+     training_loop_(
+         aes,
+         acts_iter,
+         lambda ae, cfg_sae, flat_acts_train_batch, recons, info, logger: (
+             # MSE
+             logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
+             # AuxK
+             + logger.logkv(
+                 "train_maxk_recons",
+                 cfg_sae.auxk_coef
+                 * normalized_mse(
+                     ae.decode_sparse(
+                         info["auxk_inds"],
+                         info["auxk_vals"],
+                     ),
+                     flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
+                 ).nan_to_num(0),
+             )
+         ),
+         sae_cfgs=cfg.saes,
+         loggers=loggers,
+         log_interval=cfg.log_interval,
+         save_interval=cfg.save_interval,
+     )
+
+
+ if __name__ == "__main__":
+     main()
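All hyperparameters are read from SAE/config.json at the top of main, so a training run reduces to filling in paths_to_latents, the save path, and the W&B fields there and launching

    python scripts/train_sae.py

One SparseAutoencoder is trained per entry in cfg.saes on the same activation stream; snapshots are written to <save_path>/<step> every save_interval steps and the final weights to <save_path>/final, which matches the layout of the checkpoints/ directories added above.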
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .hooks import *
utils/hooks.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+
+ class TimedHook:
+     def __init__(self, hook_fn, total_steps, apply_at_steps=None):
+         self.hook_fn = hook_fn
+         self.total_steps = total_steps
+         self.apply_at_steps = apply_at_steps
+         self.current_step = 0
+
+     def identity(self, module, input, output):
+         return output
+
+     def __call__(self, module, input, output):
+         if self.apply_at_steps is not None:
+             if self.current_step in self.apply_at_steps:
+                 self.__increment()
+                 return self.hook_fn(module, input, output)
+             else:
+                 self.__increment()
+                 return self.identity(module, input, output)
+
+         return self.identity(module, input, output)
+
+     def __increment(self):
+         if self.current_step < self.total_steps:
+             self.current_step += 1
+         else:
+             self.current_step = 0
+
+ @torch.no_grad()
+ def add_feature(sae, feature_idx, value, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     mask[..., feature_idx] = value
+     to_add = mask @ sae.decoder.weight.T
+     return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+ @torch.no_grad()
+ def add_feature_on_area_base(sae, feature_idx, activation_map, module, input, output):
+     return add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output)
+
+ @torch.no_grad()
+ def add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output):
+     # add the feature to cond and subtract from uncond
+     # this assumes diff.shape[0] == 2
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     if len(activation_map) == 2:
+         activation_map = activation_map.unsqueeze(0)
+     mask[..., feature_idx] = activation_map.to(mask.device)
+     to_add = mask @ sae.decoder.weight.T
+     to_add = to_add.chunk(2)
+     output[0][0] -= to_add[0].permute(0, 3, 1, 2).to(output[0].device)[0]
+     output[0][1] += to_add[1].permute(0, 3, 1, 2).to(output[0].device)[0]
+     return output
+
+
+ @torch.no_grad()
+ def add_feature_on_area_base_cond(sae, feature_idx, activation_map, module, input, output):
+     # add the feature to cond
+     # this assumes diff.shape[0] == 2
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     diff_uncond, diff_cond = diff.chunk(2)
+     activated = sae.encode(diff_cond)
+     mask = torch.zeros_like(activated, device=diff_cond.device)
+     if len(activation_map) == 2:
+         activation_map = activation_map.unsqueeze(0)
+     mask[..., feature_idx] = activation_map.to(mask.device)
+     to_add = mask @ sae.decoder.weight.T
+     output[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
+     return output
+
+
+ @torch.no_grad()
+ def replace_with_feature_base(sae, feature_idx, value, module, input, output):
+     # this assumes diff.shape[0] == 2
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     diff_uncond, diff_cond = diff.chunk(2)
+     activated = sae.encode(diff_cond)
+     mask = torch.zeros_like(activated, device=diff_cond.device)
+     mask[..., feature_idx] = value
+     to_add = mask @ sae.decoder.weight.T
+     input[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
+     return input
+
+
+ @torch.no_grad()
+ def add_feature_on_area_turbo(sae, feature_idx, activation_map, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     if len(activation_map) == 2:
+         activation_map = activation_map.unsqueeze(0)
+     mask[..., feature_idx] = activation_map.to(mask.device)
+     to_add = mask @ sae.decoder.weight.T
+     return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+ @torch.no_grad()
+ def replace_with_feature_turbo(sae, feature_idx, value, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     mask[..., feature_idx] = value
+     to_add = mask @ sae.decoder.weight.T
+     return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def reconstruct_sae_hook(sae, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     reconstructed = sae.decoder(activated) + sae.pre_bias
+     return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def ablate_block(module, input, output):
+     return input
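These helpers keep the standard PyTorch forward-hook signature in their last three arguments, so the SAE-specific arguments are meant to be bound up front (for example with functools.partial) before registering the hook on the block the SAE was trained for. A minimal sketch, assuming a loaded SparseAutoencoder named sae and an SDXL-Turbo pipeline named pipe (the block, feature index, and strength below are illustrative, not values from this repo):

    from functools import partial

    # Hypothetical usage: boost one SAE feature in the block the SAE was trained on.
    block = pipe.unet.down_blocks[2].attentions[1]
    hook = partial(add_feature, sae, 100, 5.0)  # binds (sae, feature_idx, value)
    handle = block.register_forward_hook(hook)
    try:
        image = pipe(prompt="a cat", num_inference_steps=1, guidance_scale=0.0).images[0]
    finally:
        handle.remove()  # detach the hook so later generations run unmodified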