# app.py — minimal Gradio app that loads your GAN and outputs an image grid
# (N_SAMPLES images, N_COLS per row: 5 rows x 10 columns by default) per click.
import importlib
import os

import numpy as np
import torch
from torchvision.utils import make_grid
from PIL import Image
import gradio as gr
from scipy.stats import truncnorm

# Your repo modules
import archs
import cfg

# ---- tweak these if your filenames differ ----
ARCH = "arch_cifar10"  # submodule of `archs` that defines `Generator`
DATASET = "celeba"  # used to build the checkpoint paths below
CKPT_PATH = os.path.join("model", f"{DATASET}.pth")  # loaded when CUDA is available
CKPT_CPU_PATH = os.path.join("model", f"{DATASET}_cpu.pth")  # loaded on CPU-only hosts
GENOTYPE_PATH = os.path.join("nas", "latest_G.npy")
N_SAMPLES = 50  # 5 rows x 10 columns
N_COLS = 10
DEFAULT_LATENT_DIM = 120  # used when the parsed args have no `latent_dim`


def _to_pil_grid(t):
    """Convert a (3, H, W) float tensor in [0, 1] to a PIL image.

    Values outside [0, 1] are clamped. Each value is rounded to the nearest
    8-bit level, as torchvision's ``save_image`` does, rather than truncated,
    which would bias every pixel slightly darker.
    """
    arr = (t.clamp(0, 1) * 255 + 0.5).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(arr)


# Cache the model so it loads only once per process.
_GEN = None
_DEVICE = "cpu"
_LATENT_DIM = DEFAULT_LATENT_DIM  # overwritten from args on first load


def _generator_class():
    """Return ``archs.<ARCH>.Generator`` without going through ``eval``."""
    arch_module = getattr(archs, ARCH, None)
    if arch_module is None:
        # `archs/__init__.py` may not import every submodule; load it explicitly.
        arch_module = importlib.import_module(f"{archs.__name__}.{ARCH}")
    return arch_module.Generator


def load_model_once():
    """Build the generator, load its weights, and cache it for later calls.

    Returns:
        A ``(generator, device, latent_dim)`` tuple, where ``device`` is
        ``"cuda"`` or ``"cpu"``. After the first successful call the cached
        tuple is returned without rebuilding or reloading anything.
    """
    global _GEN, _DEVICE, _LATENT_DIM
    if _GEN is not None:
        return _GEN, _DEVICE, _LATENT_DIM

    args = cfg.parse_args()

    # The host may or may not have a GPU; handle both.
    has_cuda = torch.cuda.is_available()
    args.gpu_ids = [0] if has_cuda else []
    _DEVICE = "cuda" if has_cuda else "cpu"

    # 1) Load genotype.
    # NOTE: allow_pickle=True can execute code embedded in the file; only load
    # genotype files you trust.
    genotype_G = np.load(GENOTYPE_PATH, allow_pickle=True)

    # 2) Build generator. On GPU it is wrapped in DataParallel, whose
    # state_dict keys carry a "module." prefix; presumably that is why a
    # separate CPU checkpoint file is used for the unwrapped model.
    G_base = _generator_class()(args, genotype_G)
    if has_cuda:
        ckpt_path = CKPT_PATH
        G = torch.nn.DataParallel(G_base, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])
    else:
        ckpt_path = CKPT_CPU_PATH
        G = G_base

    # 3) Load checkpoint: either {'gen_state_dict': ...} or a bare state dict.
    ckpt = torch.load(ckpt_path, map_location=_DEVICE)
    state = ckpt["gen_state_dict"] if "gen_state_dict" in ckpt else ckpt
    G.load_state_dict(state)
    G.eval()

    # 4) Latent dim from args, falling back to the common default.
    _LATENT_DIM = int(getattr(args, "latent_dim", DEFAULT_LATENT_DIM))

    _GEN = G
    return _GEN, _DEVICE, _LATENT_DIM


def generate_grid():
    """Sample N_SAMPLES images and return them tiled as a single PIL image.

    Latent vectors are drawn from a standard normal truncated to [-1, 1].
    """
    G, device, latent_dim = load_model_once()
    with torch.no_grad():
        z_np = truncnorm.rvs(-1, 1, loc=0, scale=1, size=(N_SAMPLES, latent_dim)).astype(np.float32)
        z = torch.from_numpy(z_np).to(device)
        imgs = G(z)  # expected (B, 3, H, W); value range depends on the generator
        # normalize=True with scale_each=True min-max rescales every image to
        # [0, 1] independently, so [-1, 1] or [0, 1] generator outputs both work.
        grid = make_grid(imgs, nrow=N_COLS, normalize=True, scale_each=True)
    return _to_pil_grid(grid)


# ---------------- Gradio UI ----------------
with gr.Blocks(title="MMD-PMish-NAS-GAN Image Generation") as demo:
    gr.Markdown("## MMD-PMish-NAS-GAN Image Generation\nClick **Generate** to sample a fresh grid of images using the model.")
    generate_btn = gr.Button("Generate", variant="primary")
    out = gr.Image(label="Output Grid", type="pil")
    generate_btn.click(fn=generate_grid, inputs=None, outputs=out, queue=True)

if __name__ == "__main__":
    demo.queue().launch()