import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel
import tempfile
import os
# Global variable to store the text pipeline
text_pipe = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_model():
"""Load the T5 text encoder model"""
global text_pipe
if text_pipe is None:
print("Loading T5 text encoder...")
# Get token from environment
token = os.getenv("HF_TOKEN")
text_encoder = T5EncoderModel.from_pretrained(
"DeepFloyd/IF-I-L-v1.0",
subfolder="text_encoder",
load_in_8bit=True,
variant="8bit",
token=token
)
text_pipe = DiffusionPipeline.from_pretrained(
"DeepFloyd/IF-I-L-v1.0",
text_encoder=text_encoder,
unet=None,
token=token,
)
text_pipe = text_pipe.to(device)
print("Model loaded successfully!")
return text_pipe
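
# On ZeroGPU Spaces, the @spaces.GPU decorator below requests a GPU only for the
# duration of each decorated call, so the heavy work runs on GPU on demand.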
@spaces.GPU
def generate_embeddings(prompts_text):
"""
Generate embeddings from text prompts
Args:
prompts_text: String with one prompt per line
Returns:
Path to the saved .pth file and a status message
"""
    try:
        # Load the model if it is not already cached
        pipe = load_model()
        # Note: 8-bit models are already on the correct device, no need to move them

        # Parse prompts (one per line)
        prompts = [p.strip() for p in prompts_text.strip().split('\n') if p.strip()]
        if not prompts:
            return None, "Error: Please enter at least one prompt"

        # Add an empty string for CFG (Classifier-Free Guidance)
        if '' not in prompts:
            prompts.append('')

        # Generate embeddings
        print(f"Generating embeddings for {len(prompts)} prompts...")
        prompt_embeds_list = []
        for prompt in prompts:
            # encode_prompt returns a (prompt_embeds, negative_prompt_embeds) pair
            embeds = pipe.encode_prompt(prompt)
            prompt_embeds_list.append(embeds)

        # Separate the positive prompt embeddings from the negative ones
        prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds_list)

        # Move embeddings to CPU before saving
        prompt_embeds_cpu = [emb.cpu() if isinstance(emb, torch.Tensor) else emb for emb in prompt_embeds]

        # Map each prompt string to its embedding tensor
        prompt_embeds_dict = dict(zip(prompts, prompt_embeds_cpu))

        # Save to a temporary file that Gradio can offer for download
        temp_dir = tempfile.gettempdir()
        temp_file_path = os.path.join(temp_dir, 'prompt_embeds_dict.pth')
        torch.save(prompt_embeds_dict, temp_file_path)

        status_msg = f"✅ Successfully generated embeddings for {len(prompts)} prompts!\n"
        status_msg += "Each embedding has shape: [1, 77, 4096]\n"
        status_msg += "Prompts processed:\n" + "\n".join([f" - '{p}'" for p in prompts])

        return temp_file_path, status_msg

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return None, f"❌ Error: {str(e)}\n\nDetails:\n{error_details}"
# Create Gradio interface
with gr.Blocks(title="T5 Text Encoder - Embeddings Generator") as demo:
    gr.Markdown("""
    # 🤗 CS180 HW5: T5 Text Encoder Embeddings Generator

    This Space uses the **DeepFloyd IF** T5 text encoder to generate embeddings from your text prompts.

    ### How to use:
    1. Enter your prompts in the text box (one prompt per line)
    2. Click "Generate Embeddings"
    3. Download the generated `.pth` file containing the embeddings

    ### About the embeddings:
    - Each embedding has shape `[1, 77, 4096]`:
      - `77` = max sequence length
      - `4096` = embedding dimension of the T5 encoder
    - An empty prompt (`''`) is automatically added for Classifier-Free Guidance (CFG)
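
    For context, the null-prompt embedding is what enables CFG at sampling time.
    A minimal sketch of the standard CFG blend (the function name and default
    scale are illustrative, not part of this Space):

    ```python
    import torch

    def cfg_blend(noise_cond: torch.Tensor, noise_uncond: torch.Tensor, scale: float = 7.5) -> torch.Tensor:
        # Push the conditional prediction away from the unconditional one
        return noise_uncond + scale * (noise_cond - noise_uncond)
    ```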
""")
    with gr.Row():
        with gr.Column():
            prompts_input = gr.Textbox(
                label="Enter Prompts (one per line)",
                placeholder="an oil painting of a snowy mountain village\na photo of the amalfi coast\na photo of a man\n...",
                lines=15,
                value="""a high quality picture
an oil painting of a snowy mountain village
a photo of the amalfi coast
a photo of a man
a photo of a hipster barista
a photo of a dog
an oil painting of people around a campfire
an oil painting of an old man
a lithograph of waterfalls
a lithograph of a skull
a man wearing a hat
a high quality photo
a rocket ship
a pencil""",
            )
            generate_btn = gr.Button("🚀 Generate Embeddings", variant="primary", size="lg")

        with gr.Column():
            status_output = gr.Textbox(
                label="Status",
                lines=10,
                interactive=False,
            )
            file_output = gr.File(
                label="Download Embeddings (.pth file)"
            )
    generate_btn.click(
        fn=generate_embeddings,
        inputs=[prompts_input],
        outputs=[file_output, status_output],
    )
    gr.Markdown("""
    ### 📝 Note:
    - The first run may take a while as the model needs to download (~8 GB)
    - Subsequent runs will be faster
    - The generated `.pth` file can be loaded in PyTorch with `torch.load('prompt_embeds_dict.pth')`
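
    For example, downstream code can read the saved dictionary like this (a minimal
    sketch; `'a photo of a dog'` assumes you kept that default prompt in the list):

    ```python
    import torch

    # Maps each prompt string to a tensor of shape [1, 77, 4096]
    prompt_embeds_dict = torch.load('prompt_embeds_dict.pth')
    cond_embed = prompt_embeds_dict['a photo of a dog']  # conditional embedding
    null_embed = prompt_embeds_dict['']                  # empty prompt, used for CFG
    ```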
""")
if __name__ == "__main__":
    demo.launch()