prasannareddyp commited on
Commit
e091f33
·
verified ·
1 Parent(s): 610f867

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -68,11 +68,11 @@ def generate_grid():
68
  args = cfg.parse_args()
69
 
70
  with torch.no_grad():
71
- # z = torch.randn(N_SAMPLES, latent_dim, device=device)
72
- z = torch.cuda.FloatTensor(truncnorm.rvs(-1, 1, loc=0, scale=1, size=(N_SAMPLES, args.latent_dim)))
73
  imgs = G(z) # (B,3,H,W) in [-1,1] or [0,1]
74
 
75
- # make a 10x10 grid; normalize handles [-1,1]
76
  grid = make_grid(imgs, nrow=N_COLS, normalize=True, scale_each=True)
77
  return _to_pil_grid(grid)
78
 
 
68
  args = cfg.parse_args()
69
 
70
  with torch.no_grad():
71
+ z_np = truncnorm.rvs(-1, 1, loc=0, scale=1, size=(N_SAMPLES, latent_dim)).astype(np.float32)
72
+ z = torch.from_numpy(z_np).to(device)
73
  imgs = G(z) # (B,3,H,W) in [-1,1] or [0,1]
74
 
75
+ # make a grid; normalize handles [-1,1]
76
  grid = make_grid(imgs, nrow=N_COLS, normalize=True, scale_each=True)
77
  return _to_pil_grid(grid)
78