File size: 5,038 Bytes
37f4150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel
import tempfile

# Shared DeepFloyd IF pipeline (text encoder only). Starts as None and is
# filled in by the first call to load_model(); later calls reuse it.
text_pipe = None


def load_model():
    """Return the shared text-encoding pipeline, building it on first use.

    The T5 text encoder is loaded from the ``text_encoder`` subfolder of
    ``DeepFloyd/IF-I-L-v1.0`` (8-bit variant, ``device_map="auto"``). It is
    then wrapped in a ``DiffusionPipeline`` created with ``unet=None``, so the
    pipeline is only used for prompt encoding. The value of the ``HF_TOKEN``
    environment variable (``None`` if unset) is passed as ``token`` to both
    ``from_pretrained`` calls.

    Returns:
        The cached pipeline object stored in the module-level ``text_pipe``.
    """
    global text_pipe

    # Fast path: the pipeline was already built by an earlier call.
    if text_pipe is not None:
        return text_pipe

    print("Loading T5 text encoder...")

    import os
    hf_token = os.getenv("HF_TOKEN")
    repo_id = "DeepFloyd/IF-I-L-v1.0"

    encoder = T5EncoderModel.from_pretrained(
        repo_id,
        subfolder="text_encoder",
        variant="8bit",
        load_in_8bit=True,
        device_map="auto",
        token=hf_token,
    )
    # Reuse the already-loaded encoder and skip the UNet entirely.
    text_pipe = DiffusionPipeline.from_pretrained(
        repo_id,
        unet=None,
        text_encoder=encoder,
        token=hf_token,
    )
    print("Model loaded successfully!")
    return text_pipe

@spaces.GPU
def generate_embeddings(prompts_text):
    """
    Encode each non-blank line of ``prompts_text`` with the T5 text encoder.

    An empty prompt ``''`` is always added so the result also contains the
    unconditional embedding used for Classifier Free Guidance (CFG).
    Duplicate prompts are encoded only once, keeping first-seen order.

    Args:
        prompts_text: String with one prompt per line. Surrounding whitespace
            and blank lines are ignored; ``None`` is treated as empty.

    Returns:
        A ``(path, status)`` tuple. ``path`` is the location of a
        ``prompt_embeds_dict.pth`` file holding a ``{prompt: embedding}``
        dict of CPU tensors, or ``None`` on error. ``status`` is a
        human-readable message for the UI.
    """
    import os
    import traceback

    try:
        # Parse and validate before loading the model so bad input fails
        # fast. dict.fromkeys de-duplicates while preserving order; duplicate
        # keys would collapse in the output dict anyway.
        prompts = list(dict.fromkeys(
            line.strip()
            for line in (prompts_text or "").splitlines()
            if line.strip()
        ))

        if not prompts:
            return None, "Error: Please enter at least one prompt"

        # Blank lines were filtered out above, so '' cannot already be
        # present. Add it as the unconditional prompt for CFG.
        prompts.append('')

        pipe = load_model()

        print(f"Generating embeddings for {len(prompts)} prompts...")
        prompt_embeds_dict = {}
        for prompt in prompts:
            # encode_prompt returns (prompt_embeds, negative_prompt_embeds).
            # Only the positive embedding is kept, so disable CFG to skip
            # computing a negative embedding that would be thrown away.
            prompt_embeds, _ = pipe.encode_prompt(
                prompt, do_classifier_free_guidance=False
            )
            # Move to CPU so the saved file can be torch.load()-ed on
            # machines without CUDA.
            prompt_embeds_dict[prompt] = prompt_embeds.detach().cpu()

        # Save into a fresh temp directory under a stable, documented file
        # name. torch.save opens and closes the path itself, so no handle
        # stays open (which would break on Windows).
        out_dir = tempfile.mkdtemp(prefix="t5_embeddings_")
        out_path = os.path.join(out_dir, "prompt_embeds_dict.pth")
        torch.save(prompt_embeds_dict, out_path)

        # Report the measured shape rather than a hard-coded one.
        shape = list(prompt_embeds_dict[''].shape)
        status_msg = f"✅ Successfully generated embeddings for {len(prompts)} prompts!\n"
        status_msg += f"Each embedding has shape: {shape}\n"
        status_msg += "Prompts processed:\n" + "\n".join(f"  - '{p}'" for p in prompts)

        return out_path, status_msg

    except Exception as e:
        # UI boundary: surface the error to the user, but also log the full
        # traceback so failures are diagnosable from the server logs.
        traceback.print_exc()
        return None, f"❌ Error: {e}"

# Create Gradio interface: prompt entry and a generate button on the left,
# status text and the downloadable .pth file on the right.
with gr.Blocks(title="T5 Text Encoder - Embeddings Generator") as demo:
    gr.Markdown("""
    # 🔀 CS180 HW5: T5 Text Encoder Embeddings Generator
    
    This space uses the **DeepFloyd IF** T5 text encoder to generate embeddings from your text prompts.
    
    ### How to use:
    1. Enter your prompts in the text box (one prompt per line)
    2. Click "Generate Embeddings"
    3. Download the generated `.pth` file containing the embeddings
    
    ### About the embeddings:
    - Each embedding has shape: `[1, 77, 4096]`
    - `77` = max sequence length
    - `4096` = embedding dimension of the T5 encoder
    - An empty prompt (`''`) is automatically added for Classifier Free Guidance (CFG)
    """)
    
    with gr.Row():
        with gr.Column():
            prompts_input = gr.Textbox(
                label="Enter Prompts (one per line)",
                placeholder="an oil painting of a snowy mountain village\na photo of the amalfi coast\na photo of a man\n...",
                lines=15,
                value="""an oil painting of a snowy mountain village
a photo of the amalfi coast
a photo of a man
a photo of a hipster barista
a photo of a dog
an oil painting of people around a campfire
an oil painting of an old man
a lithograph of waterfalls
a lithograph of a skull
a man wearing a hat
a high quality photo
a rocket ship
a pencil"""
            )
            
            generate_btn = gr.Button("🚀 Generate Embeddings", variant="primary", size="lg")
        
        with gr.Column():
            status_output = gr.Textbox(
                label="Status",
                lines=10,
                interactive=False
            )
            file_output = gr.File(
                label="Download Embeddings (.pth file)"
            )
    
    # generate_embeddings returns (file_path, status_message); the output
    # order below must match that tuple.
    generate_btn.click(
        fn=generate_embeddings,
        inputs=[prompts_input],
        outputs=[file_output, status_output]
    )
    
    gr.Markdown("""
    ### 📝 Note:
    - The first run may take a while as the model needs to download (~8GB)
    - Subsequent runs will be faster
    - The generated `.pth` file can be loaded in PyTorch using: `torch.load('prompt_embeds_dict.pth', map_location='cpu')`
    """)

if __name__ == "__main__":
    demo.launch()