Spaces:
Runtime error
Runtime error
Remove token
Browse files
app.py
CHANGED
|
@@ -34,8 +34,6 @@ Related Apps:
|
|
| 34 |
'''
|
| 35 |
ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.tadne-interpolation" alt="visitor badge"/></center>'
|
| 36 |
|
| 37 |
-
TOKEN = os.environ['TOKEN']
|
| 38 |
-
|
| 39 |
|
| 40 |
def parse_args() -> argparse.Namespace:
|
| 41 |
parser = argparse.ArgumentParser()
|
|
@@ -54,8 +52,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 54 |
def load_model(device: torch.device) -> nn.Module:
|
| 55 |
model = Generator(512, 1024, 4, channel_multiplier=2)
|
| 56 |
path = hf_hub_download('hysts/TADNE',
|
| 57 |
-
'models/aydao-anime-danbooru2019s-512-5268480.pt'
|
| 58 |
-
use_auth_token=TOKEN)
|
| 59 |
checkpoint = torch.load(path)
|
| 60 |
model.load_state_dict(checkpoint['g_ema'])
|
| 61 |
model.eval()
|
|
@@ -84,10 +81,10 @@ def generate_image(model: nn.Module, z: torch.Tensor, truncation_psi: float,
|
|
| 84 |
|
| 85 |
|
| 86 |
@torch.inference_mode()
|
| 87 |
-
def generate_interpolated_images(
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
|
| 92 |
seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
|
| 93 |
|
|
|
|
| 34 |
'''
|
| 35 |
ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.tadne-interpolation" alt="visitor badge"/></center>'
|
| 36 |
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def parse_args() -> argparse.Namespace:
|
| 39 |
parser = argparse.ArgumentParser()
|
|
|
|
| 52 |
def load_model(device: torch.device) -> nn.Module:
|
| 53 |
model = Generator(512, 1024, 4, channel_multiplier=2)
|
| 54 |
path = hf_hub_download('hysts/TADNE',
|
| 55 |
+
'models/aydao-anime-danbooru2019s-512-5268480.pt')
|
|
|
|
| 56 |
checkpoint = torch.load(path)
|
| 57 |
model.load_state_dict(checkpoint['g_ema'])
|
| 58 |
model.eval()
|
|
|
|
| 81 |
|
| 82 |
|
| 83 |
@torch.inference_mode()
|
| 84 |
+
def generate_interpolated_images(seed0: int, seed1: int, num_intermediate: int,
|
| 85 |
+
psi0: float, psi1: float,
|
| 86 |
+
randomize_noise: bool, model: nn.Module,
|
| 87 |
+
device: torch.device) -> list[np.ndarray]:
|
| 88 |
seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
|
| 89 |
seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
|
| 90 |
|