import os
import tempfile

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel

# Global variable to cache the text pipeline between calls
text_pipe = None
def load_model():
    """Load the T5 text encoder and wrap it in a text-only pipeline."""
    global text_pipe
    if text_pipe is None:
        print("Loading T5 text encoder...")
        # Read the Hugging Face access token from the Space's environment
        token = os.getenv("HF_TOKEN")
        text_encoder = T5EncoderModel.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            subfolder="text_encoder",
            load_in_8bit=True,
            variant="8bit",
            device_map="auto",
            token=token,  # required: DeepFloyd IF is a gated model
        )
        # unet=None: only the text encoder is needed, not the diffusion UNet
        text_pipe = DiffusionPipeline.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            text_encoder=text_encoder,
            unet=None,
            token=token,
        )
        print("Model loaded successfully!")
    return text_pipe
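
# Minimal usage sketch (an assumption-laden illustration, not part of the app
# flow: assumes a GPU, access to the gated DeepFloyd weights, and diffusers'
# convention that encode_prompt returns (prompt_embeds, negative_prompt_embeds)):
#
#   pipe = load_model()
#   prompt_embeds, negative_embeds = pipe.encode_prompt("a rocket ship")
#   print(prompt_embeds.shape)  # expected: torch.Size([1, 77, 4096])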
@spaces.GPU
def generate_embeddings(prompts_text):
    """
    Generate T5 embeddings from text prompts.

    Args:
        prompts_text: String with one prompt per line.

    Returns:
        Path to the saved .pth file and a status message.
    """
    try:
        # Load model if not already loaded
        pipe = load_model()

        # Parse prompts (one per line), dropping blank lines
        prompts = [p.strip() for p in prompts_text.strip().split('\n') if p.strip()]
        if not prompts:
            return None, "Error: Please enter at least one prompt"

        # Add the empty string for Classifier-Free Guidance (CFG)
        if '' not in prompts:
            prompts.append('')

        # Generate embeddings
        print(f"Generating embeddings for {len(prompts)} prompts...")
        prompt_embeds_list = []
        for prompt in prompts:
            embeds = pipe.encode_prompt(prompt)
            prompt_embeds_list.append(embeds)

        # encode_prompt returns (prompt_embeds, negative_prompt_embeds);
        # keep only the positive embeddings
        prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds_list)

        # Map each prompt string to its embedding tensor
        prompt_embeds_dict = dict(zip(prompts, prompt_embeds))

        # Save to a temporary file for download
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.pth')
        torch.save(prompt_embeds_dict, temp_file.name)
        temp_file.close()

        status_msg = f"✅ Successfully generated embeddings for {len(prompts)} prompts!\n"
        status_msg += "Each embedding has shape: [1, 77, 4096]\n"
        status_msg += "Prompts processed:\n" + "\n".join([f"  - '{p}'" for p in prompts])
        return temp_file.name, status_msg
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
# Create the Gradio interface
with gr.Blocks(title="T5 Text Encoder - Embeddings Generator") as demo:
    gr.Markdown("""
# 🤖 CS180 HW5: T5 Text Encoder Embeddings Generator

This space uses the **DeepFloyd IF** T5 text encoder to generate embeddings from your text prompts.

### How to use:
1. Enter your prompts in the text box (one prompt per line)
2. Click "Generate Embeddings"
3. Download the generated `.pth` file containing the embeddings

### About the embeddings:
- Each embedding has shape: `[1, 77, 4096]`
- `77` = max sequence length
- `4096` = embedding dimension of the T5 encoder
- An empty prompt (`''`) is automatically added for Classifier-Free Guidance (CFG)
""")
    with gr.Row():
        with gr.Column():
            prompts_input = gr.Textbox(
                label="Enter Prompts (one per line)",
                placeholder="an oil painting of a snowy mountain village\na photo of the amalfi coast\na photo of a man\n...",
                lines=15,
                value="""an oil painting of a snowy mountain village
a photo of the amalfi coast
a photo of a man
a photo of a hipster barista
a photo of a dog
an oil painting of people around a campfire
an oil painting of an old man
a lithograph of waterfalls
a lithograph of a skull
a man wearing a hat
a high quality photo
a rocket ship
a pencil""",
            )
            generate_btn = gr.Button("🚀 Generate Embeddings", variant="primary", size="lg")
        with gr.Column():
            status_output = gr.Textbox(
                label="Status",
                lines=10,
                interactive=False,
            )
            file_output = gr.File(
                label="Download Embeddings (.pth file)"
            )

    generate_btn.click(
        fn=generate_embeddings,
        inputs=[prompts_input],
        outputs=[file_output, status_output],
    )
    gr.Markdown("""
### 📝 Note:
- The first run may take a while as the model weights need to download (~8 GB)
- Subsequent runs will be faster
- The generated `.pth` file can be loaded in PyTorch with `torch.load('prompt_embeds_dict.pth')`
""")
if __name__ == "__main__":
    demo.launch()