import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel
import tempfile
import os
import traceback

# Global variable to cache the loaded text pipeline across requests
text_pipe = None

def load_model():
    """Load the DeepFloyd IF T5 text encoder (8-bit) and wrap it in a pipeline."""
    global text_pipe
    if text_pipe is None:
        print("Loading T5 text encoder...")
        # Read the Hugging Face access token from the environment
        token = os.getenv("HF_TOKEN")
        text_encoder = T5EncoderModel.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            subfolder="text_encoder",
            device_map="auto",  # places the 8-bit weights on the available GPU
            load_in_8bit=True,
            variant="8bit",
            token=token,
        )
        text_pipe = DiffusionPipeline.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            text_encoder=text_encoder,
            unet=None,  # text encoding only; the UNet is not needed
            token=token,
        )
        # Do not call .to("cuda") here: moving 8-bit bitsandbytes models is not
        # supported, and device_map="auto" has already placed the weights.
        print("Model loaded successfully!")
    return text_pipe
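
# If bitsandbytes/CUDA were unavailable, a plain fp16 load would be one possible
# fallback (an assumption, not what this Space uses; it needs roughly twice the
# memory of the 8-bit variant):
#
#   text_encoder = T5EncoderModel.from_pretrained(
#       "DeepFloyd/IF-I-L-v1.0", subfolder="text_encoder",
#       variant="fp16", torch_dtype=torch.float16, token=token,
#   )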


@spaces.GPU
def generate_embeddings(prompts_text):
    """
    Generate T5 embeddings from text prompts.

    Args:
        prompts_text: String with one prompt per line

    Returns:
        Path to the saved .pth file and a status message
    """
    try:
        # Load the model if it is not already cached
        pipe = load_model()
        # Note: the 8-bit text encoder is already on the correct device,
        # so there is no need to move anything here

        # Parse prompts (one per line), skipping blank lines
        prompts = [p.strip() for p in prompts_text.strip().split('\n') if p.strip()]
        if not prompts:
            return None, "Error: Please enter at least one prompt"

        # Add the empty prompt needed for Classifier-Free Guidance (CFG)
        if '' not in prompts:
            prompts.append('')

        # Encode each prompt; encode_prompt returns a (positive, negative) pair
        print(f"Generating embeddings for {len(prompts)} prompts...")
        prompt_embeds_list = []
        for prompt in prompts:
            embeds = pipe.encode_prompt(prompt)
            prompt_embeds_list.append(embeds)

        # Keep the positive embeddings; the negatives are not needed here
        prompt_embeds, _negative_prompt_embeds = zip(*prompt_embeds_list)

        # Move embeddings to CPU before saving so the file loads on any machine
        prompt_embeds_cpu = [emb.cpu() if isinstance(emb, torch.Tensor) else emb for emb in prompt_embeds]

        # Map each prompt string to its embedding tensor
        prompt_embeds_dict = dict(zip(prompts, prompt_embeds_cpu))

        # Save to a temporary file for download
        temp_dir = tempfile.gettempdir()
        temp_file_path = os.path.join(temp_dir, 'prompt_embeds_dict.pth')
        torch.save(prompt_embeds_dict, temp_file_path)

        status_msg = f"βœ… Successfully generated embeddings for {len(prompts)} prompts!\n"
        status_msg += "Each embedding has shape: [1, 77, 4096]\n"
        status_msg += "Prompts processed:\n" + "\n".join([f"  - '{p}'" for p in prompts])
        return temp_file_path, status_msg
    except Exception as e:
        error_details = traceback.format_exc()
        return None, f"❌ Error: {str(e)}\n\nDetails:\n{error_details}"

# Create the Gradio interface
with gr.Blocks(title="T5 Text Encoder - Embeddings Generator") as demo:
    gr.Markdown("""
# πŸ”€ CS180 HW5: T5 Text Encoder Embeddings Generator

This space uses the **DeepFloyd IF** T5 text encoder to generate embeddings from your text prompts.

### How to use:
1. Enter your prompts in the text box (one prompt per line)
2. Click "Generate Embeddings"
3. Download the generated `.pth` file containing the embeddings

### About the embeddings:
- Each embedding has shape: `[1, 77, 4096]`
- `77` = max sequence length
- `4096` = embedding dimension of the T5 encoder
- An empty prompt (`''`) is automatically added for Classifier-Free Guidance (CFG)
    """)
    with gr.Row():
        with gr.Column():
            prompts_input = gr.Textbox(
                label="Enter Prompts (one per line)",
                placeholder="an oil painting of a snowy mountain village\na photo of the amalfi coast\na photo of a man\n...",
                lines=15,
                value="""a high quality picture
an oil painting of a snowy mountain village
a photo of the amalfi coast
a photo of a man
a photo of a hipster barista
a photo of a dog
an oil painting of people around a campfire
an oil painting of an old man
a lithograph of waterfalls
a lithograph of a skull
a man wearing a hat
a high quality photo
a rocket ship
a pencil"""
            )
            generate_btn = gr.Button("πŸš€ Generate Embeddings", variant="primary", size="lg")

        with gr.Column():
            status_output = gr.Textbox(
                label="Status",
                lines=10,
                interactive=False
            )
            file_output = gr.File(
                label="Download Embeddings (.pth file)"
            )

    generate_btn.click(
        fn=generate_embeddings,
        inputs=[prompts_input],
        outputs=[file_output, status_output]
    )
gr.Markdown("""
### πŸ“ Note:
- The first run may take a while as the model needs to download (~8GB)
- Subsequent runs will be faster
- The generated `.pth` file can be loaded in PyTorch using: `torch.load('prompt_embeds_dict.pth')`
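
For example (the keys assume you kept the default prompt list above):

```python
import torch

embeds = torch.load('prompt_embeds_dict.pth')  # dict: prompt string -> tensor
dog_embeds = embeds['a photo of a dog']        # tensor of shape [1, 77, 4096]
uncond_embeds = embeds['']                     # empty-prompt embedding, used for CFG
```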
""")

if __name__ == "__main__":
    demo.launch()