Spaces:
Build error
Build error
Update e4e/models/psp.py
Browse files- e4e/models/psp.py +3 -1
e4e/models/psp.py
CHANGED
|
@@ -40,9 +40,11 @@ class pSp(nn.Module):
|
|
| 40 |
def load_weights(self):
|
| 41 |
if self.opts.checkpoint_path is not None:
|
| 42 |
print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
|
| 43 |
-
ckpt = torch.load(self.opts.checkpoint_path, map_location='
|
| 44 |
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
|
|
|
|
| 45 |
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
|
|
|
|
| 46 |
self.__load_latent_avg(ckpt)
|
| 47 |
else:
|
| 48 |
print('Loading encoders weights from irse50!')
|
|
|
|
| 40 |
def load_weights(self):
|
| 41 |
if self.opts.checkpoint_path is not None:
|
| 42 |
print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
|
| 43 |
+
ckpt = torch.load(self.opts.checkpoint_path, map_location='cuda:0' if torch.cuda.is_available() else "cpu")
|
| 44 |
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
|
| 45 |
+
self.encoder.to(self.device)
|
| 46 |
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
|
| 47 |
+
self.decoder.to(self.device)
|
| 48 |
self.__load_latent_avg(ckpt)
|
| 49 |
else:
|
| 50 |
print('Loading encoders weights from irse50!')
|