Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from dataclasses import dataclass, field | |
| import os | |
@dataclass
class SAETrainingConfig:
    """Settings for training a single SAE and for naming its outputs.

    Instances are built from keyword arguments (one per field below).
    ``block_name``, ``k``, ``n_dirs``, ``auxk``, ``bs`` and ``lr`` are
    embedded in the run name returned by :meth:`sae_name`;
    ``save_path_base`` is the directory under which :meth:`save_path`
    places that run.
    """

    # NOTE(review): the field meanings below are inferred from names only;
    # confirm against the training code that consumes this config.
    d_model: int  # presumably the width of the activations being encoded
    n_dirs: int  # appears as "hidden{n_dirs}" in the run name
    k: int
    block_name: str
    bs: int  # presumably the batch size
    save_path_base: str  # root directory joined with the run name
    auxk: int = 256
    lr: float = 1e-4
    eps: float = 6.25e-10
    dead_toks_threshold: int = 10_000_000
    auxk_coef: float = 1 / 32

    def sae_name(self) -> str:
        """Return the run identifier derived from the key hyperparameters."""
        return (
            f'{self.block_name}_k{self.k}_hidden{self.n_dirs}'
            f'_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
        )

    def save_path(self) -> str:
        """Return ``save_path_base`` joined with :meth:`sae_name`."""
        # Read the base directory from the instance (a bare
        # ``save_path_base`` is an undefined name here) and reuse
        # sae_name() so the directory name and run name cannot drift apart.
        return os.path.join(self.save_path_base, self.sae_name())
class Config:
    """Top-level training configuration populated from a parsed JSON mapping.

    ``cfg_json`` must provide the keys ``sae_configs``, ``save_path_base``,
    ``paths_to_latents``, ``log_interval``, ``save_interval``, ``bs`` and
    ``block_name``. Every entry of ``sae_configs`` is a mapping of keyword
    arguments for one ``SAETrainingConfig``; the shared ``block_name``,
    ``bs`` and ``save_path_base`` values are passed to each of them as well.
    ``wandb_project`` and ``wandb_name`` are class-level defaults and are
    not read from ``cfg_json``.
    """

    saes: list[SAETrainingConfig]
    paths_to_latents: list[str]
    log_interval: int
    save_interval: int
    bs: int
    block_name: str
    wandb_project: str = 'sdxl_sae_train'
    wandb_name: str = 'multiple_sae'

    def __init__(self, cfg_json):
        """Copy the settings out of ``cfg_json`` onto this instance."""
        # One SAETrainingConfig per entry; the shared values are looked up
        # per entry, after the entry's own keyword arguments are unpacked.
        built = []
        for per_sae_kwargs in cfg_json['sae_configs']:
            built.append(
                SAETrainingConfig(
                    **per_sae_kwargs,
                    block_name=cfg_json['block_name'],
                    bs=cfg_json['bs'],
                    save_path_base=cfg_json['save_path_base'],
                )
            )
        self.saes = built

        # The remaining settings are stored under their JSON key names.
        for key in (
            'save_path_base',
            'paths_to_latents',
            'log_interval',
            'save_interval',
            'bs',
            'block_name',
        ):
            setattr(self, key, cfg_json[key])