jamesoncrate committed on
Commit
37f4150
·
1 Parent(s): a30a4af

add cs180 encoder

Browse files
Files changed (2) hide show
  1. app.py +154 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from diffusers import DiffusionPipeline
5
+ from transformers import T5EncoderModel
6
+ import tempfile
7
+
8
+ # Global variable to store the text pipeline
9
+ text_pipe = None
10
+
11
+ def load_model():
12
+ """Load the T5 text encoder model"""
13
+ global text_pipe
14
+ if text_pipe is None:
15
+ print("Loading T5 text encoder...")
16
+
17
+ # Get token from environment
18
+ import os
19
+ token = os.getenv("HF_TOKEN")
20
+
21
+ text_encoder = T5EncoderModel.from_pretrained(
22
+ "DeepFloyd/IF-I-L-v1.0",
23
+ subfolder="text_encoder",
24
+ load_in_8bit=True,
25
+ variant="8bit",
26
+ device_map="auto",
27
+ token=token # Add this line
28
+ )
29
+ text_pipe = DiffusionPipeline.from_pretrained(
30
+ "DeepFloyd/IF-I-L-v1.0",
31
+ text_encoder=text_encoder,
32
+ unet=None,
33
+ token=token # Add this line
34
+ )
35
+ print("Model loaded successfully!")
36
+ return text_pipe
37
+
38
+ @spaces.GPU
39
+ def generate_embeddings(prompts_text):
40
+ """
41
+ Generate embeddings from text prompts
42
+ Args:
43
+ prompts_text: String with one prompt per line
44
+ Returns:
45
+ Path to the saved .pth file and a status message
46
+ """
47
+ try:
48
+ # Load model if not already loaded
49
+ pipe = load_model()
50
+
51
+ # Parse prompts (one per line)
52
+ prompts = [p.strip() for p in prompts_text.strip().split('\n') if p.strip()]
53
+
54
+ if not prompts:
55
+ return None, "Error: Please enter at least one prompt"
56
+
57
+ # Add empty string for CFG (Classifier Free Guidance)
58
+ if '' not in prompts:
59
+ prompts.append('')
60
+
61
+ # Generate embeddings
62
+ print(f"Generating embeddings for {len(prompts)} prompts...")
63
+ prompt_embeds_list = []
64
+ for prompt in prompts:
65
+ embeds = pipe.encode_prompt(prompt)
66
+ prompt_embeds_list.append(embeds)
67
+
68
+ # Extract positive prompt embeddings
69
+ prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds_list)
70
+
71
+ # Create dictionary
72
+ prompt_embeds_dict = dict(zip(prompts, prompt_embeds))
73
+
74
+ # Save to temporary file
75
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.pth')
76
+ torch.save(prompt_embeds_dict, temp_file.name)
77
+ temp_file.close()
78
+
79
+ status_msg = f"βœ… Successfully generated embeddings for {len(prompts)} prompts!\n"
80
+ status_msg += "Each embedding has shape: [1, 77, 4096]\n"
81
+ status_msg += "Prompts processed:\n" + "\n".join([f" - '{p}'" for p in prompts])
82
+
83
+ return temp_file.name, status_msg
84
+
85
+ except Exception as e:
86
+ return None, f"❌ Error: {str(e)}"
87
+
88
+ # Create Gradio interface
89
+ with gr.Blocks(title="T5 Text Encoder - Embeddings Generator") as demo:
90
+ gr.Markdown("""
91
+ # πŸ”€ CS180 HW5: T5 Text Encoder Embeddings Generator
92
+
93
+ This space uses the **DeepFloyd IF** T5 text encoder to generate embeddings from your text prompts.
94
+
95
+ ### How to use:
96
+ 1. Enter your prompts in the text box (one prompt per line)
97
+ 2. Click "Generate Embeddings"
98
+ 3. Download the generated `.pth` file containing the embeddings
99
+
100
+ ### About the embeddings:
101
+ - Each embedding has shape: `[1, 77, 4096]`
102
+ - `77` = max sequence length
103
+ - `4096` = embedding dimension of the T5 encoder
104
+ - An empty prompt (`''`) is automatically added for Classifier Free Guidance (CFG)
105
+ """)
106
+
107
+ with gr.Row():
108
+ with gr.Column():
109
+ prompts_input = gr.Textbox(
110
+ label="Enter Prompts (one per line)",
111
+ placeholder="an oil painting of a snowy mountain village\na photo of the amalfi coast\na photo of a man\n...",
112
+ lines=15,
113
+ value="""an oil painting of a snowy mountain village
114
+ a photo of the amalfi coast
115
+ a photo of a man
116
+ a photo of a hipster barista
117
+ a photo of a dog
118
+ an oil painting of people around a campfire
119
+ an oil painting of an old man
120
+ a lithograph of waterfalls
121
+ a lithograph of a skull
122
+ a man wearing a hat
123
+ a high quality photo
124
+ a rocket ship
125
+ a pencil"""
126
+ )
127
+
128
+ generate_btn = gr.Button("πŸš€ Generate Embeddings", variant="primary", size="lg")
129
+
130
+ with gr.Column():
131
+ status_output = gr.Textbox(
132
+ label="Status",
133
+ lines=10,
134
+ interactive=False
135
+ )
136
+ file_output = gr.File(
137
+ label="Download Embeddings (.pth file)"
138
+ )
139
+
140
+ generate_btn.click(
141
+ fn=generate_embeddings,
142
+ inputs=[prompts_input],
143
+ outputs=[file_output, status_output]
144
+ )
145
+
146
+ gr.Markdown("""
147
+ ### πŸ“ Note:
148
+ - The first run may take a while as the model needs to download (~8GB)
149
+ - Subsequent runs will be faster
150
+ - The generated `.pth` file can be loaded in PyTorch using: `torch.load('prompt_embeds_dict.pth')`
151
+ """)
152
+
153
+ if __name__ == "__main__":
154
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ diffusers
4
+ transformers
5
+ accelerate
6
+ bitsandbytes
7
+ sentencepiece
8
+ protobuf