import gradio as gr import spaces import torch import torch.nn as nn from torch.nn import functional as F import numpy as np import math import os import pickle import requests import textwrap import subprocess import shutil import time from dataclasses import dataclass from typing import Optional # --- 1. Automated Environment and Data Setup --- def setup_environment(): """ Checks for and sets up the necessary data and code. - Clones nanoGPT if not present. - Copies the shakespeare_char dataset directory. - Runs the data preparation script to create meta.pkl and binary files. This function makes the script self-contained. """ nano_gpt_repo_path = 'nanoGPT' data_dir_path = 'shakespeare_char' meta_path = os.path.join(data_dir_path, 'meta.pkl') if os.path.exists(meta_path): print("Dataset and metadata found. Skipping setup.") return print("Required data not found. Starting one-time setup...") if not os.path.exists(nano_gpt_repo_path): print(f"Cloning nanoGPT repository...") try: subprocess.run( ['git', 'clone', 'https://github.com/karpathy/nanoGPT.git'], check=True, capture_output=True, text=True ) print("Cloned successfully.") except subprocess.CalledProcessError as e: print(f"Error cloning repository: {e.stderr}") raise else: print("nanoGPT repository already exists.") source_data_dir = os.path.join(nano_gpt_repo_path, 'data', 'shakespeare_char') if not os.path.exists(data_dir_path): print(f"Copying '{source_data_dir}' to '{data_dir_path}'...") shutil.copytree(source_data_dir, data_dir_path) print("Copied successfully.") else: print(f"'{data_dir_path}' directory already exists.") prepare_script_path = os.path.join(data_dir_path, 'prepare.py') if not os.path.exists(meta_path): print(f"Running data preparation script: '{prepare_script_path}'...") try: subprocess.run( ['python', 'prepare.py'], check=True, cwd=data_dir_path, capture_output=True, text=True ) print("Data preparation script finished successfully.") except subprocess.CalledProcessError as e: print(f"Error running prepare.py: {e.stderr}") raise print("Setup complete.") setup_environment() # --- 2. Global Setup & Helper Functions --- data_dir = './shakespeare_char/' meta_path = os.path.join(data_dir, 'meta.pkl') with open(meta_path, 'rb') as f: meta = pickle.load(f) vocab_size = meta['vocab_size'] itos = meta['itos'] stoi = meta['stoi'] context_length = 256 device = 'cuda' if torch.cuda.is_available() else 'cpu' def decode(indices_tensor: torch.Tensor): if indices_tensor.dim() > 1: indices_tensor = indices_tensor.squeeze(0) indices = indices_tensor.cpu().numpy() return ''.join([itos.get(i, '?') for i in indices]) def wrap_text(long_text, width=80): paragraphs = long_text.splitlines() wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs] return "\n".join(wrapped) # --- 3. Model Architecture (Identical to Notebook) --- @dataclass class GPTConfig: block_size: int = 1024 vocab_size: int = 50304 n_layer: int = 12 n_head: int = 12 n_embd: int = 768 cond_dim: int = 64 dropout: float = 0.0 bias: bool = False class MLP(nn.Module): def __init__(self, config): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) self.gelu = nn.GELU() self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x): x = self.c_fc(x) x = self.gelu(x) x = self.c_proj(x) x = self.dropout(x) return x class SelfAttention(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) self.n_head = config.n_head self.n_embd = config.n_embd self.dropout = config.dropout self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') def forward(self, x): B, T, C = x.size() q, k, v = self.c_attn(x).split(self.n_embd, dim=2) k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) if self.flash: y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False) else: att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v y = y.transpose(1, 2).contiguous().view(B, T, C) y = self.resid_dropout(self.c_proj(y)) return y def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return x * (1 + scale) + shift def bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor: if bias is not None: out = scale * (x + bias) else: out = scale * x if residual is not None: out = residual + out return out class DDiTBlock(nn.Module): def __init__(self, config): super().__init__() self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias) self.attn = SelfAttention(config) self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias) self.mlp = MLP(config) self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd) self.adaLN_modulation.weight.data.zero_() self.adaLN_modulation.bias.data.zero_() def forward(self, x, c): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2) x_skip = x x = modulate(self.ln_1(x), shift_msa, scale_msa) x = self.attn(x) x = bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip) x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x) return x class DDitFinalLayer(nn.Module): def __init__(self, config): super().__init__() self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias) self.linear = nn.Linear(config.n_embd, config.vocab_size) self.linear.weight.data.zero_() self.linear.bias.data.zero_() self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd) self.adaLN_modulation.weight.data.zero_() self.adaLN_modulation.bias.data.zero_() def forward(self, x, c): shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x class TimestepEmbedder(nn.Module): def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True), ) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq) return t_emb class GPT(nn.Module): def __init__(self, config): super().__init__() assert config.vocab_size is not None assert config.block_size is not None self.config = config self.sigma_map = TimestepEmbedder(config.cond_dim) self.transformer = nn.ModuleDict(dict( wte = nn.Embedding(config.vocab_size, config.n_embd), wpe = nn.Embedding(config.block_size, config.n_embd), drop = nn.Dropout(config.dropout), h = nn.ModuleList([DDiTBlock(config) for _ in range(config.n_layer)]), ln_f = nn.LayerNorm(config.n_embd, bias=config.bias), )) self.lm_head = DDitFinalLayer(config) self.apply(self._init_weights) for pn, p in self.named_parameters(): if pn.endswith('c_proj.weight'): torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx, sigma): sigma = sigma.reshape(-1) b, t = idx.size() c = F.silu(self.sigma_map(sigma)) assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" pos = torch.arange(0, t, dtype=torch.long, device=device) tok_emb = self.transformer.wte(idx) pos_emb = self.transformer.wpe(pos) x = self.transformer.drop(tok_emb + pos_emb) for block in self.transformer.h: x = block(x, c) x = self.transformer.ln_f(x) x = self.lm_head(x, c) x = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1])) return x class GeometricNoise: def __init__(self, sigma_min=1e-4, sigma_max=20): self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max]).to(device) def rate_noise(self, t): return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (self.sigmas[1].log() - self.sigmas[0].log()) def total_noise(self, t): return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t def __call__(self, t): return self.total_noise(t), self.rate_noise(t) # --- 4. Inference & Sampling Logic (Identical to Notebook) --- def transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor: base_prob = (1 - torch.exp(-delta_sigma[..., None])) / vocab_size trans = torch.ones(*x_t.shape, vocab_size, device=x_t.device) * base_prob trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans)) diag_fill = 1 - trans.sum(dim=-1, keepdim=True) trans = trans.scatter(-1, x_t[..., None], diag_fill) return trans def staggered_score(score, delta_sigma): exp_factor = torch.exp(-delta_sigma)[..., None] correction = ((exp_factor - 1) / (vocab_size * exp_factor)) * score.sum(dim=-1, keepdim=True) return correction + score / exp_factor def sample_categorical(probs: torch.Tensor) -> torch.Tensor: eps = 1e-10 gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps) return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1) # --- 5. Model Initialization and Loading --- print("Initializing and loading the pretrained model...") model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64, bias=False, vocab_size=vocab_size, block_size=context_length, dropout=0.2) config = GPTConfig(**model_args) model = GPT(config) model.load_state_dict( torch.hub.load_state_dict_from_url( 'https://raw.githubusercontent.com/ash80/diffusion-gpt/master/pretrained_model/model_epoch_25.pth', map_location=device ) ) model.to(device) model.eval() noise = GeometricNoise(sigma_min=1e-4, sigma_max=20) print("Model loaded successfully.") # --- 6. Gradio Interface Logic --- @spaces.GPU def generate_text(steps): """ Fast generation phase. Runs the diffusion process and stores all intermediate frames in a list, then returns the final text and the list. """ steps = int(steps) eps = 1e-5 # List to store each frame of the diffusion process diffusion_frames = [] # Start with a random sample x = torch.randint(0, vocab_size, (1, context_length), device=device) initial_text = f"--- Initial Random Noise ---\n\n{wrap_text(decode(x[0]))}" diffusion_frames.append(initial_text) timesteps = torch.linspace(1, eps, steps + 1, device=device) step_size = (1 - eps) / steps with torch.no_grad(): for i in range(steps): t = timesteps[i] * torch.ones(x.shape[0], 1, device=device) curr_sigma_bar = noise(t)[0] next_sigma_bar = noise(t - step_size)[0] delta_sigma = curr_sigma_bar - next_sigma_bar log_score = model(x, curr_sigma_bar) score = torch.exp(log_score) stag_score = staggered_score(score, delta_sigma) probs = stag_score * transition(x, delta_sigma) x = sample_categorical(probs) # Store the frame progress_text = f"--- Denoising Step {i + 1}/{steps} ---\n\n{wrap_text(decode(x[0]))}" diffusion_frames.append(progress_text) # Final denoising step t = timesteps[steps] * torch.ones(x.shape[0], 1, device=device) curr_sigma_bar = noise(t)[0] delta_sigma = curr_sigma_bar log_score = model(x, curr_sigma_bar) score = torch.exp(log_score) stag_score = staggered_score(score, delta_sigma) probs = stag_score * transition(x, delta_sigma) x = sample_categorical(probs) final_text = f"--- Final Denoised Text (Step {steps}) ---\n\n{wrap_text(decode(x[0]))}" diffusion_frames.append(final_text) # Return the final text and the complete list of frames return final_text, diffusion_frames def replay_diffusion(frames, replay_speed): """ Slow replay phase. Iterates through the stored frames and yields them with a delay to create an animation effect. """ delay = 0.5 / replay_speed # Calculate delay based on speed multiplier for frame in frames: yield frame time.sleep(delay) # Define the Gradio UI css = '''.gradio-container > .fillable {max-width: 720px !important} h3{margin-top: 1em} p{margin-top: 0} textarea{font-family: monospace;background-color: black} ''' with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo: gr.Markdown( """ # The Annotated Discrete Diffusion Models ### Tiny 7.23M parameters Shakespeare character diffusion model by [Ashwani Kumar](https://x.com/ash_at_tt/status/1977376958859092250) [GitHub](https://github.com/ash80/diffusion-gpt), [Colab](https://colab.research.google.com/github/ash80/diffusion-gpt/blob/master/The_Annotated_Discrete_Diffusion_Models.ipynb) """ ) generate_button = gr.Button("Generate", variant="primary") output_textbox = gr.Textbox( label="Generated Text", lines=15, interactive=False, show_copy_button=True, placeholder="Generation will appear here..." ) with gr.Row(): steps_slider = gr.Slider( minimum=64, maximum=512, value=128, step=1, label="Denoising Steps", info="Number of steps in the generation process." ) speed_slider = gr.Slider( minimum=1, maximum=20, value=10, step=1, label="Replay Speed", info="Controls the speed of the animation after generation.", visible=False ) diffusion_frames_state = gr.State([]) generate_event = generate_button.click( fn=generate_text, inputs=[steps_slider], outputs=[output_textbox, diffusion_frames_state] ).then( fn=replay_diffusion, inputs=[diffusion_frames_state, speed_slider], outputs=[output_textbox] ) if __name__ == "__main__": demo.launch()