import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel
import tempfile
import os

# Global cache for the text pipeline so the model is only loaded once per process
text_pipe = None

def load_model():
    """Load the T5 text encoder model"""
    global text_pipe
    if text_pipe is None:
        print("Loading T5 text encoder...")
        
        # Get token from environment
        token = os.getenv("HF_TOKEN")
        
        # bitsandbytes places the 8-bit weights on the GPU via device_map, so the
        # encoder must not be moved with .to() afterwards. (Newer transformers
        # versions may prefer quantization_config=BitsAndBytesConfig(load_in_8bit=True).)
        text_encoder = T5EncoderModel.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            subfolder="text_encoder",
            device_map="auto",
            load_in_8bit=True,
            variant="8bit",
            token=token
        )
        text_pipe = DiffusionPipeline.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            text_encoder=text_encoder,
            unet=None,  # text encoding only; the UNet is not needed
            token=token,
        )
        # No .to(device) here: calling .to() on a pipeline holding an 8-bit
        # quantized encoder is unsupported; bitsandbytes already placed it on GPU
        print("Model loaded successfully!")
    return text_pipe

@spaces.GPU
def generate_embeddings(prompts_text):
    """
    Generate embeddings from text prompts
    Args:
        prompts_text: String with one prompt per line
    Returns:
        Path to the saved .pth file and a status message
    """
    try:
        # Load model if not already loaded
        pipe = load_model()
        
        # Note: 8-bit models are already on the correct device, no need to move them
        
        # Parse prompts (one per line)
        prompts = [p.strip() for p in prompts_text.strip().split('\n') if p.strip()]
        
        if not prompts:
            return None, "Error: Please enter at least one prompt"
        
        # Add empty string for CFG (Classifier Free Guidance)
        if '' not in prompts:
            prompts.append('')
        
        # Generate embeddings
        print(f"Generating embeddings for {len(prompts)} prompts...")
        prompt_embeds_list = []
        for prompt in prompts:
            embeds = pipe.encode_prompt(prompt)
            prompt_embeds_list.append(embeds)
        
        # Unpack (positive, negative) pairs; only the positive embeddings are kept
        prompt_embeds, _negative_prompt_embeds = zip(*prompt_embeds_list)
        
        # Move embeddings to CPU before saving
        prompt_embeds_cpu = [emb.cpu() if isinstance(emb, torch.Tensor) else emb for emb in prompt_embeds]
        
        # Create dictionary
        prompt_embeds_dict = dict(zip(prompts, prompt_embeds_cpu))
        
        # Save to temporary file
        # Save under a fresh temp directory so concurrent requests on the Space
        # don't overwrite each other's files
        temp_dir = tempfile.mkdtemp()
        temp_file_path = os.path.join(temp_dir, 'prompt_embeds_dict.pth')
        torch.save(prompt_embeds_dict, temp_file_path)
        
        status_msg = f"βœ… Successfully generated embeddings for {len(prompts)} prompts!\n"
        status_msg += "Each embedding has shape: [1, 77, 4096]\n"
        status_msg += "Prompts processed:\n" + "\n".join([f"  - '{p}'" for p in prompts])
        
        return temp_file_path, status_msg
        
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return None, f"❌ Error: {str(e)}\n\nDetails:\n{error_details}"

# Create Gradio interface
with gr.Blocks(title="T5 Text Encoder - Embeddings Generator") as demo:
    gr.Markdown("""
    # πŸ”€ CS180 HW5: T5 Text Encoder Embeddings Generator
    
    This space uses the **DeepFloyd IF** T5 text encoder to generate embeddings from your text prompts.
    
    ### How to use:
    1. Enter your prompts in the text box (one prompt per line)
    2. Click "Generate Embeddings"
    3. Download the generated `.pth` file containing the embeddings
    
    ### About the embeddings:
    - Each embedding has shape: `[1, 77, 4096]`
    - `77` = max sequence length
    - `4096` = embedding dimension of the T5 encoder
    - An empty prompt (`''`) is automatically added for Classifier Free Guidance (CFG)
    """)
    
    with gr.Row():
        with gr.Column():
            prompts_input = gr.Textbox(
                label="Enter Prompts (one per line)",
                placeholder="an oil painting of a snowy mountain village\na photo of the amalfi coast\na photo of a man\n...",
                lines=15,
                value="""a high quality picture
an oil painting of a snowy mountain village
a photo of the amalfi coast
a photo of a man
a photo of a hipster barista
a photo of a dog
an oil painting of people around a campfire
an oil painting of an old man
a lithograph of waterfalls
a lithograph of a skull
a man wearing a hat
a high quality photo
a rocket ship
a pencil"""
            )
            
            generate_btn = gr.Button("πŸš€ Generate Embeddings", variant="primary", size="lg")
        
        with gr.Column():
            status_output = gr.Textbox(
                label="Status",
                lines=10,
                interactive=False
            )
            file_output = gr.File(
                label="Download Embeddings (.pth file)"
            )
    
    generate_btn.click(
        fn=generate_embeddings,
        inputs=[prompts_input],
        outputs=[file_output, status_output]
    )
    
    gr.Markdown("""
    ### πŸ“ Note:
    - The first run may take a while as the model needs to download (~8GB)
    - Subsequent runs will be faster
    - The generated `.pth` file can be loaded in PyTorch using: `torch.load('prompt_embeds_dict.pth')`
    """)

if __name__ == "__main__":
    demo.launch()
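
# Local smoke-test sketch (assumption: running this file directly outside Spaces
# with a CUDA GPU and HF_TOKEN set; the @spaces.GPU decorator is expected to be a
# no-op off-Spaces):
#
#     path, msg = generate_embeddings("a photo of a dog\na rocket ship")
#     print(msg)
#     print("saved to:", path)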