import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import BitsAndBytesConfig, T5EncoderModel
import tempfile
import os

# Global variable caching the text pipeline across requests
text_pipe = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_model():
    """Load the DeepFloyd IF pipeline with only its T5 text encoder (no UNet)."""
    global text_pipe
    if text_pipe is None:
        print("Loading T5 text encoder...")
        # DeepFloyd IF is a gated repo; read the access token from the environment
        token = os.getenv("HF_TOKEN")
        text_encoder = T5EncoderModel.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            subfolder="text_encoder",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            variant="8bit",
            device_map="auto",
            token=token,
        )
        text_pipe = DiffusionPipeline.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            text_encoder=text_encoder,
            unet=None,  # text encoding only; skip loading the diffusion UNet
            token=token,
        )
        # bitsandbytes already places the 8-bit encoder on the GPU, and calling
        # .to() on an 8-bit model is unsupported, so only move the pipeline
        # when the encoder is not quantized.
        if not getattr(text_pipe.text_encoder, "is_loaded_in_8bit", False):
            text_pipe = text_pipe.to(device)
        print("Model loaded successfully!")
    return text_pipe


@spaces.GPU
def generate_embeddings(prompts_text):
    """
    Generate embeddings from text prompts.

    Args:
        prompts_text: String with one prompt per line.

    Returns:
        Path to the saved .pth file and a status message.
    """
    try:
        # Load the model if it is not already cached
        pipe = load_model()

        # Parse prompts (one per line), skipping blank lines
        prompts = [p.strip() for p in prompts_text.strip().split('\n') if p.strip()]
        if not prompts:
            return None, "Error: Please enter at least one prompt"

        # Add the empty string for Classifier-Free Guidance (CFG)
        if '' not in prompts:
            prompts.append('')

        # Encode each prompt; encode_prompt returns a
        # (prompt_embeds, negative_prompt_embeds) pair
        print(f"Generating embeddings for {len(prompts)} prompts...")
        prompt_embeds_list = []
        for prompt in prompts:
            embeds = pipe.encode_prompt(prompt)
            prompt_embeds_list.append(embeds)

        # Keep only the positive prompt embeddings
        prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds_list)

        # Move embeddings to CPU so the file can be loaded without a GPU
        prompt_embeds_cpu = [emb.cpu() if isinstance(emb, torch.Tensor) else emb
                             for emb in prompt_embeds]

        # Map each prompt string to its embedding
        prompt_embeds_dict = dict(zip(prompts, prompt_embeds_cpu))

        # Save to a temporary file for download
        temp_dir = tempfile.gettempdir()
        temp_file_path = os.path.join(temp_dir, 'prompt_embeds_dict.pth')
        torch.save(prompt_embeds_dict, temp_file_path)

        status_msg = f"✅ Successfully generated embeddings for {len(prompts)} prompts!\n"
        status_msg += "Each embedding has shape: [1, 77, 4096]\n"
        status_msg += "Prompts processed:\n" + "\n".join([f"  - '{p}'" for p in prompts])

        return temp_file_path, status_msg

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return None, f"❌ Error: {str(e)}\n\nDetails:\n{error_details}"
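
# The saved .pth file is a plain dict mapping each prompt string to its
# positive T5 embedding. A quick sanity check after downloading (a sketch;
# the key is one of the default prompts below, and the shape matches the
# status message above):
#
#   embeds = torch.load('prompt_embeds_dict.pth')
#   print(embeds['a photo of a dog'].shape)  # torch.Size([1, 77, 4096])
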
# Create Gradio interface
with gr.Blocks(title="T5 Text Encoder - Embeddings Generator") as demo:
    gr.Markdown("""
    # 🔤 CS180 HW5: T5 Text Encoder Embeddings Generator

    This space uses the **DeepFloyd IF** T5 text encoder to generate embeddings from your text prompts.

    ### How to use:
    1. Enter your prompts in the text box (one prompt per line)
    2. Click "Generate Embeddings"
    3. Download the generated `.pth` file containing the embeddings

    ### About the embeddings:
    - Each embedding has shape: `[1, 77, 4096]`
    - `77` = max sequence length
    - `4096` = embedding dimension of the T5 encoder
    - An empty prompt (`''`) is automatically added for Classifier-Free Guidance (CFG)
    """)

    with gr.Row():
        with gr.Column():
            prompts_input = gr.Textbox(
                label="Enter Prompts (one per line)",
                placeholder="an oil painting of a snowy mountain village\na photo of the amalfi coast\na photo of a man\n...",
                lines=15,
                value="""a high quality picture
an oil painting of a snowy mountain village
a photo of the amalfi coast
a photo of a man
a photo of a hipster barista
a photo of a dog
an oil painting of people around a campfire
an oil painting of an old man
a lithograph of waterfalls
a lithograph of a skull
a man wearing a hat
a high quality photo
a rocket ship
a pencil""",
            )
            generate_btn = gr.Button("🚀 Generate Embeddings", variant="primary", size="lg")
        with gr.Column():
            status_output = gr.Textbox(
                label="Status",
                lines=10,
                interactive=False,
            )
            file_output = gr.File(
                label="Download Embeddings (.pth file)",
            )

    generate_btn.click(
        fn=generate_embeddings,
        inputs=[prompts_input],
        outputs=[file_output, status_output],
    )

    gr.Markdown("""
    ### 📝 Note:
    - The first run may take a while as the model weights need to download (~8GB)
    - Subsequent runs will be faster
    - The generated `.pth` file can be loaded in PyTorch with `torch.load('prompt_embeds_dict.pth')`
    """)

if __name__ == "__main__":
    demo.launch()
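
# Downstream usage sketch: the saved embeddings (plus the empty-prompt
# embedding as the unconditional input for CFG) can drive the IF UNet stage
# without reloading the T5 encoder. This is illustrative, not part of the
# app; the fp16 variant and dtype are assumptions:
#
#   import torch
#   from diffusers import DiffusionPipeline
#
#   embeds = torch.load('prompt_embeds_dict.pth')
#   if_stage_1 = DiffusionPipeline.from_pretrained(
#       "DeepFloyd/IF-I-L-v1.0", text_encoder=None,
#       variant="fp16", torch_dtype=torch.float16,
#   ).to("cuda")
#   image = if_stage_1(
#       prompt_embeds=embeds["a photo of a dog"],
#       negative_prompt_embeds=embeds[""],  # empty prompt saved above for CFG
#   ).images[0]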