import os
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from dataclasses import dataclass
import spaces

# This dataclass definition is required for the processor
@dataclass
class VLChatProcessorOutput:
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)

def process_image(image_paths, vl_chat_processor):
    """Processes a list of image paths into pixel values."""
    images = [Image.open(image_path).convert("RGB") for image_path in image_paths]
    images_outputs = vl_chat_processor.image_processor(images, return_tensors="pt")
    return images_outputs['pixel_values']

# === Load Janus model and processor ===
# This setup assumes the necessary model files are accessible.
model_path = "FreedomIntelligence/Janus-4o-7B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
)
vl_gpt = vl_gpt.cuda().eval()
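# NOTE: this setup assumes a single CUDA GPU with bfloat16 support; adjust dtype/device for other hardware.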

# === Text-and-Image-to-Image generation ===
def text_and_image_to_image_generate(input_prompt, input_image_path, output_path, vl_chat_processor, vl_gpt,
                                     temperature=1.0, parallel_size=2, cfg_weight=5, cfg_weight2=5):
    """Generates an image from a text prompt and an input image."""
    torch.cuda.empty_cache()

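    # The prompt reserves two slots per input image: a block of image tokens (filled from
    # pixel_values inside prepare_inputs_embeds) and a block of pad tokens (later overwritten
    # with the VQ generation embeddings of the same image; see the splice below).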
    input_img_tokens = (
        vl_chat_processor.image_start_tag + vl_chat_processor.image_tag * vl_chat_processor.num_image_tokens + vl_chat_processor.image_end_tag
        + vl_chat_processor.image_start_tag + vl_chat_processor.pad_tag * vl_chat_processor.num_image_tokens + vl_chat_processor.image_end_tag
    )
    output_img_tokens = vl_chat_processor.image_start_tag

    pre_data = []
    input_images = [input_image_path]
    img_len = len(input_images)
    prompts = input_img_tokens * img_len + input_prompt
    conversation = [
        {"role": "<|User|>", "content": prompts},
        {"role": "<|Assistant|>", "content": ""}
    ]
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    )

    sft_format = sft_format + output_img_tokens

    image_token_num_per_image = 576
    img_size = 384
    patch_size = 16

    with torch.inference_mode():
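        # Encode the input image with the VQ generation model to get its discrete token ids,
        # then map those ids to generation-space embeddings for splicing into the prompt.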
        input_image_pixel_values = process_image(input_images, vl_chat_processor).to(torch.bfloat16).cuda()
        _, _, info_input = vl_gpt.gen_vision_model.encode(input_image_pixel_values)
        image_tokens_input = info_input[2].detach().reshape(input_image_pixel_values.shape[0], -1)
        image_embeds_input = vl_gpt.prepare_gen_img_embeds(image_tokens_input)

        input_ids = torch.LongTensor(vl_chat_processor.tokenizer.encode(sft_format))

        encoder_pixel_values = process_image(input_images, vl_chat_processor).cuda()
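        # Three rows per generated sample, matching the logits[0::3] / [1::3] / [2::3] split in the
        # sampling loop: two conditioned rows and one unconditional row whose prompt is masked to pad ids.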
        tokens = torch.zeros((parallel_size * 3, len(input_ids)), dtype=torch.long)
        for i in range(parallel_size * 3):
            tokens[i, :] = input_ids
            if i % 3 == 2:
                tokens[i, 1:-1] = vl_chat_processor.pad_id
                pre_data.append(VLChatProcessorOutput(sft_format=sft_format, pixel_values=encoder_pixel_values,
                                input_ids=tokens[i-2], num_image_tokens=[vl_chat_processor.num_image_tokens] * img_len))
                pre_data.append(VLChatProcessorOutput(sft_format=sft_format, pixel_values=encoder_pixel_values,
                                input_ids=tokens[i-1], num_image_tokens=[vl_chat_processor.num_image_tokens] * img_len))
                pre_data.append(VLChatProcessorOutput(sft_format=sft_format, pixel_values=None, input_ids=tokens[i],
                                num_image_tokens=[]))

        prepare_inputs = vl_chat_processor.batchify(pre_data)

        inputs_embeds = vl_gpt.prepare_inputs_embeds(
            input_ids=tokens.cuda(),
            pixel_values=prepare_inputs['pixel_values'].to(torch.bfloat16).cuda(),
            images_emb_mask=prepare_inputs['images_emb_mask'].cuda(),
            images_seq_mask=prepare_inputs['images_seq_mask'].cuda()
        )

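        # Splice the VQ generation embeddings of the input image into the reserved pad-token slot
        # of the first (fully conditioned) row of each sample group.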
        image_gen_indices = (tokens == vl_chat_processor.image_end_id).nonzero()

        for ii, ind in enumerate(image_gen_indices):
            if ii % 4 == 0:
                offset = ind[1] + 2
                inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[(ii // 2) % img_len]

        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
        
        # --- FIX: Initialize past_key_values for cached generation ---
        past_key_values = None

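        # Autoregressive sampling: generate one image token per step, feed its embedding back as the
        # next input, and reuse the KV cache so only the new token is processed each iteration.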
        for i in range(image_token_num_per_image):
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=past_key_values # Pass cached values
            )
            hidden_states = outputs.last_hidden_state

            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond_full = logits[0::3, :]
            logit_cond_part = logits[1::3, :]
            logit_uncond = logits[2::3, :]

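            # Two-stage classifier-free guidance: blend the two conditioned branches, then push the
            # result away from the unconditional branch.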
            logit_cond = (logit_cond_full + cfg_weight2 * logit_cond_part) / (1 + cfg_weight2)
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)
            
            # --- FIX: Update past_key_values with the output from the current step ---
            past_key_values = outputs.past_key_values

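        # Decode the 576-token grid (384 / 16 = 24 latents per side) back into 384x384 RGB images.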
        dec = vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
        )
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        output_images = []
        for i in range(parallel_size):
            save_path = output_path.replace('.png', f'_{i}.png')
            Image.fromarray(visual_img[i]).save(save_path)
            output_images.append(save_path)

    torch.cuda.empty_cache()
    return output_images

# === Text-to-Image generation ===
def text_to_image_generate(input_prompt, output_path, vl_chat_processor, vl_gpt, temperature=1.0, parallel_size=2, cfg_weight=5.0):
    """Generates an image from a text prompt only."""
    torch.cuda.empty_cache()

    conversation = [
        {"role": "<|User|>", "content": input_prompt},
        {"role": "<|Assistant|>", "content": ""},
    ]

    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    )
    prompt = sft_format + vl_chat_processor.image_start_tag

    image_token_num_per_image = 576
    img_size = 384
    patch_size = 16

    with torch.inference_mode():
        input_ids = vl_chat_processor.tokenizer.encode(prompt)
        input_ids = torch.LongTensor(input_ids)
        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()

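        # Interleave conditional / unconditional rows: odd rows have the prompt masked to pad ids and
        # act as the unconditional branch for classifier-free guidance.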
        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = vl_chat_processor.pad_id

        inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

        # --- FIX: Initialize past_key_values for cached generation ---
        past_key_values = None

        for i in range(image_token_num_per_image):
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=past_key_values # Pass cached values
            )

            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])

            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]

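            # Classifier-free guidance: extrapolate the conditional logits away from the unconditional ones.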
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            next_token_expanded = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_expanded)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

            # --- FIX: Update past_key_values with the output from the current step ---
            past_key_values = outputs.past_key_values

        dec = vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
        )
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        output_images = []
        for i in range(parallel_size):
            save_path = output_path.replace('.png', f'_{i}.png')
            Image.fromarray(visual_img[i]).save(save_path)
            output_images.append(save_path)

    torch.cuda.empty_cache()
    return output_images

# === Unified Gradio handler for ChatInterface ===
@spaces.GPU(duration=120)
def janus_chat_responder(message, history):
    """
    Handles both text-only and multimodal (text+image) inputs from the ChatInterface.
    'message' is a dictionary with 'text' and 'files' keys.
    """
    output_path = "./output/chat_image.png"
    prompt = message["text"]
    uploaded_files = message["files"]

    try:
        if uploaded_files:
            # Handle text+image to image generation
            temp_image_path = uploaded_files[0]
            images = text_and_image_to_image_generate(
                prompt, temp_image_path, output_path, vl_chat_processor, vl_gpt
            )
        else:
            # Handle text-to-image generation
            images = text_to_image_generate(prompt, output_path, vl_chat_processor, vl_gpt)
        
        # Return a gallery component to display all generated images
        return gr.Gallery(value=images, label="Generated Images")

    except Exception as e:
        # Raise a Gradio error so the message is surfaced to the user in the UI
        raise gr.Error(f"An error occurred during generation: {str(e)}")


# === Gradio UI with a single ChatInterface ===
with gr.Blocks(theme="soft", title="Janus Image Generation") as demo:
    gr.Markdown("# Janus Multi-Modal Image Generation")
    gr.Markdown("Generate images from text prompts, or upload an image and a prompt to transform it.")

    # Using gr.ChatInterface which handles the chat history and input box automatically
    gr.ChatInterface(
        fn=janus_chat_responder,
        multimodal=True, # Enables file uploads
        title="Janus-4o-7B",
        chatbot=gr.Chatbot(height=400, label="Chat", show_label=False),
        textbox=gr.MultimodalTextbox(
            file_types=["image"],
            placeholder="Type a prompt or upload an image...",
            label="Input"
        ),
        examples=[
            {"text": "A cat made of glass, sitting on a table.", "files": []},
            {"text": "A futuristic city at sunset, with flying cars.", "files": []},
            {"text": "A dragon breathing fire over a medieval castle.", "files": []},
            {"text": "Turn this into a watercolor painting.", "files": ["./assets/example_image.jpg"]}
        ]
    )

if __name__ == "__main__":
    # Create a dummy image for the example if it doesn't exist to prevent errors
    assets_dir = "./assets"
    example_image_path = os.path.join(assets_dir, "example_image.jpg")
    if not os.path.exists(example_image_path):
        os.makedirs(assets_dir, exist_ok=True)
        try:
            dummy_image = Image.new('RGB', (384, 384), color='red')
            dummy_image.save(example_image_path)
            print(f"Created dummy example image at: {example_image_path}")
        except Exception as e:
            print(f"Could not create dummy image: {e}")

    demo.launch()