import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel
import tempfile
import os

# Global variable to store the text pipeline
text_pipe = None
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_model():
    """Load the T5 text encoder model."""
    global text_pipe
    if text_pipe is None:
        print("Loading T5 text encoder...")
        # Get the Hugging Face token from the environment
        token = os.getenv("HF_TOKEN")
        # Note: newer transformers versions prefer
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True)
        # over the legacy load_in_8bit flag used here.
        text_encoder = T5EncoderModel.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            subfolder="text_encoder",
            load_in_8bit=True,
            variant="8bit",
            token=token,
        )
        text_pipe = DiffusionPipeline.from_pretrained(
            "DeepFloyd/IF-I-L-v1.0",
            text_encoder=text_encoder,
            unet=None,
            token=token,
        )
        # The 8-bit text encoder is already placed on the GPU by bitsandbytes,
        # so the pipeline is not moved with .to() here (doing so can raise an
        # error for quantized models).
        print("Model loaded successfully!")
    return text_pipe
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def generate_embeddings(prompts_text):
    """
    Generate embeddings from text prompts.

    Args:
        prompts_text: String with one prompt per line

    Returns:
        Path to the saved .pth file and a status message
    """
    try:
        # Load the model if it is not already loaded
        pipe = load_model()
        # Note: 8-bit models are already on the correct device, no need to move them

        # Parse prompts (one per line)
        prompts = [p.strip() for p in prompts_text.strip().split('\n') if p.strip()]
        if not prompts:
            return None, "Error: Please enter at least one prompt"

        # Add an empty string for CFG (Classifier-Free Guidance)
        if '' not in prompts:
            prompts.append('')
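        # For reference (not executed here): a downstream sampler would use the
        # empty-prompt embedding as the unconditional branch of CFG, roughly
        #   noise_pred = uncond + guidance_scale * (cond - uncond)
        # This app only exports the embeddings; guidance happens in the caller.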
        # Generate embeddings
        print(f"Generating embeddings for {len(prompts)} prompts...")
        prompt_embeds_list = []
        for prompt in prompts:
            embeds = pipe.encode_prompt(prompt)
            prompt_embeds_list.append(embeds)

        # encode_prompt returns (prompt_embeds, negative_prompt_embeds);
        # keep only the positive embeddings
        prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds_list)

        # Move embeddings to CPU before saving
        prompt_embeds_cpu = [emb.cpu() if isinstance(emb, torch.Tensor) else emb for emb in prompt_embeds]

        # Create a dictionary mapping each prompt to its embedding
        prompt_embeds_dict = dict(zip(prompts, prompt_embeds_cpu))

        # Save to a temporary file
        temp_dir = tempfile.gettempdir()
        temp_file_path = os.path.join(temp_dir, 'prompt_embeds_dict.pth')
        torch.save(prompt_embeds_dict, temp_file_path)

        status_msg = f"✅ Successfully generated embeddings for {len(prompts)} prompts!\n"
        status_msg += "Each embedding has shape: [1, 77, 4096]\n"
        status_msg += "Prompts processed:\n" + "\n".join([f"  - '{p}'" for p in prompts])

        return temp_file_path, status_msg
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return None, f"❌ Error: {str(e)}\n\nDetails:\n{error_details}"
# Create the Gradio interface
with gr.Blocks(title="T5 Text Encoder - Embeddings Generator") as demo:
    gr.Markdown("""
    # 🤗 CS180 HW5: T5 Text Encoder Embeddings Generator

    This Space uses the **DeepFloyd IF** T5 text encoder to generate embeddings from your text prompts.

    ### How to use:
    1. Enter your prompts in the text box (one prompt per line)
    2. Click "Generate Embeddings"
    3. Download the generated `.pth` file containing the embeddings

    ### About the embeddings:
    - Each embedding has shape `[1, 77, 4096]`
    - `77` = max sequence length
    - `4096` = embedding dimension of the T5 encoder
    - An empty prompt (`''`) is automatically added for Classifier-Free Guidance (CFG)
    """)
    with gr.Row():
        with gr.Column():
            prompts_input = gr.Textbox(
                label="Enter Prompts (one per line)",
                placeholder="an oil painting of a snowy mountain village\na photo of the amalfi coast\na photo of a man\n...",
                lines=15,
                value="""a high quality picture
an oil painting of a snowy mountain village
a photo of the amalfi coast
a photo of a man
a photo of a hipster barista
a photo of a dog
an oil painting of people around a campfire
an oil painting of an old man
a lithograph of waterfalls
a lithograph of a skull
a man wearing a hat
a high quality photo
a rocket ship
a pencil"""
            )
            generate_btn = gr.Button("🚀 Generate Embeddings", variant="primary", size="lg")
        with gr.Column():
            status_output = gr.Textbox(
                label="Status",
                lines=10,
                interactive=False
            )
            file_output = gr.File(
                label="Download Embeddings (.pth file)"
            )
    generate_btn.click(
        fn=generate_embeddings,
        inputs=[prompts_input],
        outputs=[file_output, status_output]
    )

    gr.Markdown("""
    ### 📝 Notes:
    - The first run may take a while because the model weights (~8 GB) must be downloaded
    - Subsequent runs will be faster
    - The generated `.pth` file can be loaded in PyTorch with `torch.load('prompt_embeds_dict.pth')`
    """)
if __name__ == "__main__":
    demo.launch()
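
# Example of consuming the downloaded file in a separate script (a minimal
# sketch; the prompt key and local path below are illustrative):
#
#   import torch
#   embeds = torch.load('prompt_embeds_dict.pth')
#   cond = embeds['a photo of a dog']  # tensor of shape [1, 77, 4096]
#   uncond = embeds['']                # empty-prompt embedding used for CFG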