adymaharana committed
Commit · 908bed5
1 Parent(s): 1cac669

fp16 version
Browse files
- app.py +113 -18
- dalle/models/__init__.py +40 -36
- dalle/models/__pycache__/__init__.cpython-38.pyc +0 -0
- dalle/models/stage2/__pycache__/layers.cpython-38.pyc +0 -0
- dalle/models/stage2/layers.py +7 -2
- gradio_demo_pororo.png +3 -0
- requirements.txt +1 -0
app.py CHANGED

@@ -6,9 +6,14 @@ from dalle.models import StoryDalle
 import argparse
 from PIL import Image
 from torchvision.utils import save_image
+import tensorflow as tf
 import tensorflow_hub as hub
 import gdown
+from allennlp.predictors.predictor import Predictor
+import random
 
+torch.set_grad_enabled(False)
+tf.config.set_visible_devices([], 'GPU')  # setting TensorFlow's GPU visibility to None to constrain the embedding model to CPU
 
 source_frame_paths = {
     'Pororo': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_2/Pororo_ENGLISH1_2_ep6/12.png',
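The two new module-level calls make the inference-only setup explicit: autograd is switched off for the whole process, and TensorFlow (used only for the Universal Sentence Encoder) is kept off the GPU so PyTorch gets the device to itself. A minimal, self-contained sketch of the same pattern:

# sketch: disable autograd and hide the GPU from TensorFlow before loading any models
import tensorflow as tf
import torch

tf.config.set_visible_devices([], 'GPU')      # TensorFlow now sees no GPUs
print(tf.config.get_visible_devices('GPU'))   # -> []

torch.set_grad_enabled(False)                 # the demo only runs inference
print(torch.is_grad_enabled())                # -> False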
@@ -23,6 +28,51 @@ source_frame_paths = {
 }
 
 
+def get_span_words(span, document):
+    return ' '.join(document[span[0]:span[1]+1])
+
+
+def print_clusters(prediction):
+    document, clusters = prediction['document'], prediction['clusters']
+    for cluster in clusters:
+        print(get_span_words(cluster[0], document) + ': ', end='')
+        print(f"[{'; '.join([get_span_words(span, document) for span in cluster])}]")
+
+
+def resolve_coref(captions, captions_mask, coref_predictor):
+    sent_counts = []
+    doc = ''
+    for cap, mask in zip(captions, captions_mask):
+        if mask == 0:
+            sent_counts.append(0)
+        else:
+            print(cap)
+            count = len([c.strip() for c in cap.split('.') if c.strip()])
+            sent_counts.append(count)
+            doc += cap + ' '
+
+    # print(doc)
+
+    doc = doc.strip()
+    resolved_doc = coref_predictor.coref_resolved(doc)
+    # print(resolved_doc)
+    # print(sent_counts)
+
+    sents = resolved_doc.split('. ')
+    resolved_captions = []
+    for i, (count, mask) in enumerate(zip(sent_counts, captions_mask)):
+        if mask == 0:
+            resolved_captions.append('')
+        else:
+            new_cap = '. '.join(sents[sum(sent_counts[:i]):sum(sent_counts[:i]) + count])
+            new_cap = new_cap.strip()
+            if new_cap[-1] not in ['!', '?', '.']:
+                new_cap += '.'
+            resolved_captions.append(new_cap)
+
+    return resolved_captions
+
+
 def inverse_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
     mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
     std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
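resolve_coref concatenates the non-empty captions into one document, runs AllenNLP's SpanBERT coreference model over it, and splits the resolved text back into per-frame captions, so pronouns are replaced by character names before tokenization. A usage sketch with made-up captions (the model URL is the one downloaded later in app.py):

# sketch: pronoun resolution across the four frame captions (captions are illustrative)
# resolve_coref is the helper defined above in app.py
from allennlp.predictors.predictor import Predictor

coref_predictor = Predictor.from_path(
    'https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz')

captions = ["Pororo is reading a book.",
            "He falls asleep on the sofa.",
            "",                              # empty slot stays empty (mask == 0)
            "Crong wakes him up."]
mask = [1 if c != '' else 0 for c in captions]

print(resolve_coref(captions, mask, coref_predictor))
# expected shape of the output, roughly:
# ['Pororo is reading a book.', 'Pororo falls asleep on the sofa.', '', 'Crong wakes Pororo up.']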
@@ -66,9 +116,10 @@ def save_story_results(images, video_len=4, n_candidates=1, mask=None):
 
 
 def main(args):
+
     #device = 'cuda:0'
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    #device = torch.device('cpu')
+    # device = torch.device('cpu')
 
     model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
 

@@ -81,6 +132,9 @@
     #assert os.path.exists("./ckpt/25.pth")
     gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
 
+    coref_model_url = 'https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz'
+    coref_predictor = Predictor.from_path(coref_model_url)
+
     if args.debug:
         model = None
         embed = None

@@ -88,13 +142,20 @@
         model, config = StoryDalle.from_pretrained(args)
         model.tokenizer.add_tokens(['pororo', 'loopy', 'eddy', 'harry', 'poby', 'tongtong', 'crong', 'rody', 'petty'])
         model.eval()
-
+        # split_model into CPU and GPU
+        if args.split_memory:
+            model.stage2.to(device=device)
+            model.story_linear.to(device=device)
+            model.story_block.to(device=device)
+        else:
+            model.to(device=device)
+        if model.config.story.condition:
+            for i in range(len(model.cross_attention_layers)):
+                model.cross_attention_layers[i].to(device)
+            print("Cross-attention layers are in cuda:", next(model.cross_attention_layers[0].parameters()).is_cuda)
+
         embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
 
-        if model.config.story.condition:
-            for i in range(len(model.cross_attention_layers)):
-                model.cross_attention_layers[i].to(device)
-            print("Cross-attention layers are in cuda:", next(model.cross_attention_layers[0].parameters()).is_cuda)
 
     valid_transform = transforms.Compose(
         [transforms.Resize(config.dataset.image_resolution),
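With --split_memory only the parts needed during token sampling (stage 2, the story head and, if conditioning is enabled, the cross-attention layers) are moved to the GPU, while the rest of the checkpoint stays in CPU RAM. The idea in isolation, on a toy module (names are stand-ins, not the StoryDalle API):

# sketch: keeping part of a model on CPU and part on GPU (illustrative module names)
import torch
import torch.nn as nn

class TwoStage(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Linear(16, 16)   # stand-in for the VQGAN, stays on CPU
        self.stage2 = nn.Linear(16, 16)   # stand-in for the transformer, goes to GPU

model = TwoStage().eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.stage2.to(device)                   # only the sampling path is moved

x = torch.randn(1, 16)
h = model.stage2(x.to(device)).cpu()      # inputs follow the submodule's device
out = model.stage1(h)                     # decoding happens on CPU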
@@ -103,6 +164,8 @@
          transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
    )
 
+    print("Model is in ", model.device)
+
    #torch.save(model, './ckpt/checkpoint.pt')
    #sys.exit()
 

@@ -110,32 +173,62 @@
                 supercondition=False):
 
        if not args.debug:
-
+
+            suffix = random.randint(0, 1000)
+            img_file_path = "./demo/images/gradio_demo_pororo_%s.png" % suffix
+            txt_file_path = "./demo/texts/gradio_demo_pororo_%s.txt" % suffix
+
+            captions = [caption_1.strip(), caption_2.strip(), caption_3.strip(), caption_4.strip()]
+            for i in range(len(captions)):
+                if captions[i][-1] not in ['!', '?', '.']:
+                    captions[i] = captions[i] + '.'
            mask = [1 if caption != '' else 0 for caption in captions]
+
+            with open(txt_file_path, 'w') as f:
+                f.write('\n'.join(captions))
+
            print(captions, mask, source, n_candidates)
+            captions = resolve_coref(captions, mask, coref_predictor)
+            print(captions)
+
            for i, caption in enumerate(captions):
                if caption == "":
-                    captions[i] = "Pororo is reading a book."
+                    captions[i] = "Pororo is reading a book." # filler for shorter captions
+
            tokens = [model.tokenizer.encode(caption) for caption in captions]
            texts = torch.stack([torch.LongTensor(token.ids) for token in tokens]).unsqueeze(0)
            sent_embeds = torch.tensor(embed(captions).numpy())
-            # sent_embeds = torch.tensor(description_vecs[source_frame_paths[source].
-            # replace('/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/', '')[:-4]][0]).unsqueeze(0).repeat(4, 1)
-
            src_image = valid_transform(Image.open('./demo/%s.png' % source).convert('RGB'))
 
            stories = []
            with torch.no_grad():
                for i in range(texts.shape[0]):
+                    candidates = []
+                    # for _ in range(n_candidates):
+                    #     if args.split_memory: # if splitting model into CPU/GPU, send src_image from CPU memory
+                    #         pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0),
+                    #             sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
+                    #             prompt=None, n_candidates=1, device=device).cpu()
+                    #     else:
+                    #         pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0).to(device),
+                    #             sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
+                    #             prompt=None, n_candidates=1).cpu()
+                    #     print(pixels.shape)
+                    #     candidates.append(pixels.squeeze())
+                    # stories.append(torch.stack(candidates))
+                    #with torch.cuda.amp.autocast():
+
                    pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0).to(device),
-
-
+                                                  sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
+                                                  prompt=None, n_candidates=n_candidates).cpu()
                    stories.append(pixels)
-
            img = save_story_results(stories, video_len=4, n_candidates=n_candidates, mask=mask)
-            save_image(img,
+            save_image(img, img_file_path, normalize=True)
+
+        else:
+            img_file_path = "gradio_demo_pororo.png"
 
-        return
+        return img_file_path
 
    with gr.Blocks(css='#output {width:750px; height:750px; float:left;}') as demo:
        gr.Markdown('''
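Since the predict closure now returns img_file_path, the Gradio output can simply be an image component fed by that path. A hypothetical wiring of the closure into the Blocks layout; only gr.Blocks, gr.Markdown and demo.launch are visible in this diff, so the component names, choices and argument order below are assumptions:

# hypothetical sketch of wiring predict() into the Blocks demo
# predict is the closure defined above; component layout is not shown in this commit
import gradio as gr

with gr.Blocks(css='#output {width:750px; height:750px; float:left;}') as demo:
    gr.Markdown("StoryDALL-E demo")
    captions = [gr.Textbox(label='Caption %d' % (i + 1)) for i in range(4)]
    source = gr.Dropdown(['Pororo', 'Loopy', 'Crong'], value='Pororo', label='Source character')
    top_k = gr.Slider(16, 128, value=32, step=16, label='top_k')
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label='top_p')
    n_candidates = gr.Slider(1, 4, value=1, step=1, label='n_candidates')
    output = gr.Image(elem_id='output')
    gr.Button('Generate').click(
        predict,
        inputs=captions + [source, top_k, top_p, n_candidates],
        outputs=output)

demo.launch(share=False)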
@@ -170,7 +263,7 @@
        Here are some examples of generated visual stories for the above-mentioned settings.
 
        <p align="center">
-            <img src="file/
+            <img src="file/demo_pororo_good_v1.png" width="1000">
        </p>
 
        Due to the small training dataset size for story visualization, the model has poor generalization to some unseen settings. The model struggles to generate coherent images in the following scenarios.

@@ -236,10 +329,11 @@
        \[4\] Sharma, Piyush, et al. "Conceptual captions: A cleaned, hypernymed, image alt-text dataset for automatic image captioning." Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2018.
        ''')
 
-    demo.launch(share=
+    demo.launch(share=False)
 
 
 if __name__ == "__main__":
+
    args_list = ['--model_name_or_path', './ckpt/25.pth',
                 '--prefix_model_name_or_path', './1.3B/',
                 '--dataset_name', 'pororo',

@@ -351,6 +445,7 @@
    )
 
    parser.add_argument("--debug", action="store_true", help="Whether to debug the demo.")
+    parser.add_argument("--split_memory", action="store_true", help="Whether to split the model into GPU & CPU in the demo.")
 
    args = parser.parse_args(args_list)
 
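The demo still builds its configuration from a hard-coded args_list rather than the command line, so new flags such as --split_memory only take effect if they are appended to that list. A minimal sketch of the pattern (flag names other than --debug and --split_memory are illustrative):

# sketch: parsing a fixed argument list instead of sys.argv, as the demo does
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="pororo")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--split_memory", action="store_true")

args_list = ['--dataset_name', 'pororo', '--split_memory']  # hypothetical list
args = parser.parse_args(args_list)   # parse_args accepts an explicit list of strings
print(args.split_memory)              # -> True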
dalle/models/__init__.py CHANGED

@@ -1094,7 +1094,7 @@ class PromptConditionalDalle(Dalle):
         prompt = self.get_prompt(bsz=5, eval=True)
 
         images = []
-        for t in texts:
+        for i, t in enumerate(texts):
             pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
             pixels = np.transpose(pixels, (0, 2, 3, 1))
             images.append(pixels)

@@ -1211,7 +1211,6 @@ class StoryDalle(Dalle):
                                       lowercase=True,
                                       dropout=None)
 
-
         return model, config_update
 
 

@@ -1224,6 +1223,7 @@
                                                resid_pdrop=hparams.resid_pdrop,
                                                attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
 
+
     def get_prompt_p5(self, bsz=None, eval=False):
         input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
         temp_control = self.wte(input_tokens)

@@ -1232,6 +1232,7 @@
         past_key_values = self.dropout(past_key_values)
         return past_key_values
 
+
     def forward(self,
                 images: torch.FloatTensor,
                 src_images: Optional[torch.FloatTensor],

@@ -1287,6 +1288,7 @@
         # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
         return logits_img, logits_txt, codes
 
+
     @torch.no_grad()
     def sampling(self,
                  tokens: torch.LongTensor,

@@ -1327,6 +1329,7 @@
 
         #with autocast(enabled=False):
         src_codes = self.stage1.get_codes(source).detach()
+
         src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
         print(tokens.shape, src_codes.shape, prompt.shape)
         if self.config.story.condition:

@@ -1355,6 +1358,7 @@
         pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
         return pixels
 
+
     @torch.no_grad()
     def sampling_batch(self,
                        tokens: torch.LongTensor,

@@ -1363,10 +1367,8 @@
                       top_k: int = 256,
                       top_p: Optional[float] = None,
                       softmax_temperature: float = 1.0,
-                      num_candidates: int = 96,
                       device: str = 'cuda:0',
                       use_fp16: bool = True,
-                      labels=None,
                       prompt=None, n_candidates=1) -> torch.FloatTensor:
 
        self.stage1.eval()
@@ -1396,37 +1398,40 @@
 
        #with autocast(enabled=False):
        src_codes = self.stage1.get_codes(source).detach()
-        # repeat inputs to adjust to n_candidates and story length
-        src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
-        prompt = prompt.repeat(n_candidates, 1, 1)
-        pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
-        tokens = tokens.repeat(n_candidates, 1)
-        print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)
-        if self.config.story.condition:
-            codes = sampling_conditional(self.stage2,
-                                         self.cross_attention_idxs,
-                                         self.cross_attention_layers,
-                                         tokens,
-                                         src_codes,
-                                         top_k=top_k,
-                                         top_p=top_p,
-                                         softmax_temperature=softmax_temperature,
-                                         use_fp16=use_fp16,
-                                         prompt=prompt,
-                                         pos_prompt=pos_enc_prompt)
-        else:
-            codes = sampling(self.stage2,
-                             tokens,
-                             top_k=top_k,
-                             top_p=top_p,
-                             softmax_temperature=softmax_temperature,
-                             use_fp16=use_fp16,
-                             prompt=prompt,
-                             pos_prompt=pos_enc_prompt)
-
-        codes = codes.view(self.config.story.story_len * n_candidates, 16, 16) # [B, 16, 16]
-        print(codes.shape)
+        # src_codes = src_codes.to(device=device) #ensure that src_codes is moved to GPU in case VQGAN was kept in CPU
+
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
+            # repeat inputs to adjust to n_candidates and story length
+            src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
+            prompt = prompt.repeat(n_candidates, 1, 1)
+            pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
+            tokens = tokens.repeat(n_candidates, 1)
+            print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)
+            if self.config.story.condition:
+                codes = sampling_conditional(self.stage2,
+                                             self.cross_attention_idxs,
+                                             self.cross_attention_layers,
+                                             tokens,
+                                             src_codes,
+                                             top_k=top_k,
+                                             top_p=top_p,
+                                             softmax_temperature=softmax_temperature,
+                                             use_fp16=use_fp16,
+                                             prompt=prompt,
+                                             pos_prompt=pos_enc_prompt)
+            else:
+                codes = sampling(self.stage2,
+                                 tokens,
+                                 top_k=top_k,
+                                 top_p=top_p,
+                                 softmax_temperature=softmax_temperature,
+                                 use_fp16=use_fp16,
+                                 prompt=prompt,
+                                 pos_prompt=pos_enc_prompt)
+
+            codes = codes.view(self.config.story.story_len * n_candidates, 16, 16) # [B, 16, 16]
+            print(codes.shape)
 
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 3, 256, 256]
        print(pixels.shape)
        return pixels.view(n_candidates, self.config.story.story_len, pixels.shape[-3], pixels.shape[-2], pixels.shape[-1])
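This hunk is the "fp16 version" of the commit message: everything from input tiling to stage-2 token sampling now runs under torch.autocast with float16, while the stage-1 VQGAN encode/decode stays outside the context and therefore in full precision. The pattern in isolation, with stand-in modules:

# sketch: mixed-precision inference with autocast around the expensive transformer pass only
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = nn.Linear(256, 256).to(device)       # stand-in for stage 1 (kept in fp32)
decoder = nn.Linear(256, 256).to(device)
transformer = nn.Linear(256, 256).to(device)   # stand-in for stage 2

x = torch.randn(4, 256, device=device)
with torch.no_grad():
    codes = encoder(x)                                            # fp32
    if device.type == 'cuda':
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            codes = transformer(codes)                            # matmuls run in fp16
        codes = codes.float()                                     # back to fp32 before decoding
    else:
        codes = transformer(codes)                                # CPU fallback stays in fp32
    pixels = decoder(codes)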
@@ -1444,11 +1449,10 @@
         pred = pred.view(bs, 16, 16) # [B, 16, 16]
         pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
         pixels = np.transpose(pixels, (0, 2, 3, 1))
-
         prompt = self.get_prompt(bsz=5, eval=True)
 
         images = []
-        for t in texts:
+        for i, t in enumerate(texts):
             pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
             pixels = np.transpose(pixels, (0, 2, 3, 1))
             images.append(pixels)
dalle/models/__pycache__/__init__.cpython-38.pyc CHANGED

Binary files a/dalle/models/__pycache__/__init__.cpython-38.pyc and b/dalle/models/__pycache__/__init__.cpython-38.pyc differ

dalle/models/stage2/__pycache__/layers.cpython-38.pyc CHANGED

Binary files a/dalle/models/stage2/__pycache__/layers.cpython-38.pyc and b/dalle/models/stage2/__pycache__/layers.cpython-38.pyc differ
dalle/models/stage2/layers.py CHANGED

@@ -182,8 +182,13 @@ class Block(nn.Module):
     def sample_with_context(self, x, context, context_mask, cross_attn_layer, layer_past=None):
         attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
         x = x + attn
-
-
+
+        c_attn = cross_attn_layer(x.to(device=context.device),
+                                  context,
+                                  context_mask.to(device=context.device))
+
+        x = x + c_attn.to(device=x.device)
+
         x = x + self.mlp(self.ln2(x))
         return x, present
 
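sample_with_context now moves the hidden states onto the device where the context (and hence the cross-attention layer) lives and moves the attention output back before the residual addition, so the block keeps working when --split_memory puts the backbone and the cross-attention layers on different devices. A reduced sketch of the device hopping:

# sketch: routing activations to the device of a layer that may not share the backbone's device
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cross_attn = nn.Linear(64, 64).to(device)   # cross-attention stand-in, possibly on GPU
x = torch.randn(2, 64)                      # backbone activations, possibly on CPU

layer_device = next(cross_attn.parameters()).device
c_attn = cross_attn(x.to(layer_device))     # move input to the layer's device
x = x + c_attn.to(x.device)                 # move result back before the residual add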
gradio_demo_pororo.png ADDED

Git LFS Details
requirements.txt CHANGED

@@ -10,3 +10,4 @@ pytorch-lightning
 einops
 tokenizers
 tensorflow
+allennlp==2.10.0