prasannareddyp commited on
Commit
350923f
·
verified ·
1 Parent(s): b4f6d16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -89
app.py CHANGED
@@ -1,89 +1,89 @@
1
- # app.py — minimal Gradio app that loads your GAN and outputs a 10x10 image grid per click
2
- import os
3
- import numpy as np
4
- import torch
5
- from torchvision.utils import make_grid
6
- from PIL import Image
7
- import gradio as gr
8
- from scipy.stats import truncnorm
9
-
10
- # Your repo modules
11
- import archs
12
- import cfg
13
-
14
- # ---- tweak these if your filenames differ ----
15
- ARCH = "arch_cifar10" # e.g. "arch_cifar10"
16
- DATASET = "celeba" # used to build ckpt path below
17
- CKPT_PATH = os.path.join("model", f"{DATASET}.pth")
18
- GENOTYPE_PATH = os.path.join("nas", "latest_G.npy")
19
- N_SAMPLES = 50 # 5x10
20
- N_COLS = 10
21
-
22
- def _to_pil_grid(t):
23
- # t: (3,H,W) in [0,1] (we'll pass normalize=True to make_grid)
24
- arr = (t.clamp(0,1) * 255).byte().permute(1,2,0).cpu().numpy()
25
- return Image.fromarray(arr)
26
-
27
- # Cache the model so it loads only once
28
- _GEN = None
29
- _DEVICE = "cpu"
30
- _LATENT_DIM = 120 # fallback if not present in args
31
-
32
- def load_model_once():
33
- global _GEN, _DEVICE, _LATENT_DIM
34
- if _GEN is not None:
35
- return _GEN, _DEVICE, _LATENT_DIM
36
-
37
- args = cfg.parse_args()
38
- # space may or may not have GPU; handle both
39
- has_cuda = torch.cuda.is_available()
40
- args.gpu_ids = [0] if has_cuda else []
41
-
42
- _DEVICE = "cuda" if has_cuda else "cpu"
43
-
44
- # 1) Load genotype
45
- genotype_G = np.load(GENOTYPE_PATH, allow_pickle=True)
46
-
47
- # 2) Build generator
48
- G_base = eval(f"archs.{ARCH}.Generator")(args, genotype_G)
49
- if has_cuda:
50
- CKPT_PATH = os.path.join("model", f"{DATASET}.pth")
51
- G = torch.nn.DataParallel(G_base, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])
52
- else:
53
- CKPT_PATH = os.path.join("model", f"{DATASET}_cpu.pth")
54
- G = G_base
55
-
56
- # 3) Load checkpoint (expects {'gen_state_dict': ...})
57
- ckpt = torch.load(CKPT_PATH, map_location=_DEVICE)
58
- state = ckpt["gen_state_dict"] if "gen_state_dict" in ckpt else ckpt
59
- G.load_state_dict(state)
60
- G.eval()
61
-
62
- # 4) Latent dim from args, fallback to common default
63
- _LATENT_DIM = int(getattr(args, "latent_dim", 120))
64
-
65
- _GEN = G
66
- return _GEN, _DEVICE, _LATENT_DIM
67
-
68
- def generate_grid():
69
- G, device, latent_dim = load_model_once()
70
-
71
- with torch.no_grad():
72
- z_np = truncnorm.rvs(-1, 1, loc=0, scale=1, size=(N_SAMPLES, latent_dim)).astype(np.float32)
73
- z = torch.from_numpy(z_np).to(device)
74
- imgs = G(z) # (B,3,H,W) in [-1,1] or [0,1]
75
-
76
- # make a grid; normalize handles [-1,1]
77
- grid = make_grid(imgs, nrow=N_COLS, normalize=True, scale_each=True)
78
- return _to_pil_grid(grid)
79
-
80
- # ---------------- Gradio UI ----------------
81
- with gr.Blocks(title="MMD-PMish-NAS-GAN Image Generation") as demo:
82
- gr.Markdown("## MMD-PMish-NAS-GAN Image Generation\nClick **Generate** to sample a fresh 10×10 grid from your model.")
83
- generate_btn = gr.Button("Generate", variant="primary")
84
- out = gr.Image(label="Output Grid", type="pil")
85
-
86
- generate_btn.click(fn=generate_grid, inputs=None, outputs=out, queue=True)
87
-
88
- if __name__ == "__main__":
89
- demo.queue().launch()
 
1
+ # app.py — minimal Gradio app that loads your GAN and outputs a 10x10 image grid per click
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ from torchvision.utils import make_grid
6
+ from PIL import Image
7
+ import gradio as gr
8
+ from scipy.stats import truncnorm
9
+
10
+ # Your repo modules
11
+ import archs
12
+ import cfg
13
+
14
+ # ---- tweak these if your filenames differ ----
15
+ ARCH = "arch_cifar10" # e.g. "arch_cifar10"
16
+ DATASET = "celeba" # used to build ckpt path below
17
+ CKPT_PATH = os.path.join("model", f"{DATASET}.pth")
18
+ GENOTYPE_PATH = os.path.join("nas", "latest_G.npy")
19
+ N_SAMPLES = 50 # 5x10
20
+ N_COLS = 10
21
+
22
+ def _to_pil_grid(t):
23
+ # t: (3,H,W) in [0,1] (we'll pass normalize=True to make_grid)
24
+ arr = (t.clamp(0,1) * 255).byte().permute(1,2,0).cpu().numpy()
25
+ return Image.fromarray(arr)
26
+
27
+ # Cache the model so it loads only once
28
+ _GEN = None
29
+ _DEVICE = "cpu"
30
+ _LATENT_DIM = 120 # fallback if not present in args
31
+
32
+ def load_model_once():
33
+ global _GEN, _DEVICE, _LATENT_DIM
34
+ if _GEN is not None:
35
+ return _GEN, _DEVICE, _LATENT_DIM
36
+
37
+ args = cfg.parse_args()
38
+ # space may or may not have GPU; handle both
39
+ has_cuda = torch.cuda.is_available()
40
+ args.gpu_ids = [0] if has_cuda else []
41
+
42
+ _DEVICE = "cuda" if has_cuda else "cpu"
43
+
44
+ # 1) Load genotype
45
+ genotype_G = np.load(GENOTYPE_PATH, allow_pickle=True)
46
+
47
+ # 2) Build generator
48
+ G_base = eval(f"archs.{ARCH}.Generator")(args, genotype_G)
49
+ if has_cuda:
50
+ CKPT_PATH = os.path.join("model", f"{DATASET}.pth")
51
+ G = torch.nn.DataParallel(G_base, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])
52
+ else:
53
+ CKPT_PATH = os.path.join("model", f"{DATASET}_cpu.pth")
54
+ G = G_base
55
+
56
+ # 3) Load checkpoint (expects {'gen_state_dict': ...})
57
+ ckpt = torch.load(CKPT_PATH, map_location=_DEVICE)
58
+ state = ckpt["gen_state_dict"] if "gen_state_dict" in ckpt else ckpt
59
+ G.load_state_dict(state)
60
+ G.eval()
61
+
62
+ # 4) Latent dim from args, fallback to common default
63
+ _LATENT_DIM = int(getattr(args, "latent_dim", 120))
64
+
65
+ _GEN = G
66
+ return _GEN, _DEVICE, _LATENT_DIM
67
+
68
+ def generate_grid():
69
+ G, device, latent_dim = load_model_once()
70
+
71
+ with torch.no_grad():
72
+ z_np = truncnorm.rvs(-1, 1, loc=0, scale=1, size=(N_SAMPLES, latent_dim)).astype(np.float32)
73
+ z = torch.from_numpy(z_np).to(device)
74
+ imgs = G(z) # (B,3,H,W) in [-1,1] or [0,1]
75
+
76
+ # make a grid; normalize handles [-1,1]
77
+ grid = make_grid(imgs, nrow=N_COLS, normalize=True, scale_each=True)
78
+ return _to_pil_grid(grid)
79
+
80
+ # ---------------- Gradio UI ----------------
81
+ with gr.Blocks(title="MMD-PMish-NAS-GAN Image Generation") as demo:
82
+ gr.Markdown("## MMD-PMish-NAS-GAN Image Generation\nClick **Generate** to sample a fresh grid of images using the model.")
83
+ generate_btn = gr.Button("Generate", variant="primary")
84
+ out = gr.Image(label="Output Grid", type="pil")
85
+
86
+ generate_btn.click(fn=generate_grid, inputs=None, outputs=out, queue=True)
87
+
88
+ if __name__ == "__main__":
89
+ demo.queue().launch()