# app.py for Hugging Face Space
# Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'),
# and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository.

import gradio as gr
import torch  # Or tensorflow/flax depending on backend
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login, hf_hub_download  # Hub login and file download helpers
import json  # For parsing the prompts JSON file
import os  # For environment variables and path handling

# --- Hugging Face login ---
# Set a Hugging Face token if needed (this model is gated, so authentication is required).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("Warning: HF_TOKEN is not set; loading a gated model may fail.")
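
# Note (added remark): in a Hugging Face Space, HF_TOKEN is typically stored as a
# repository secret (Space settings -> "Variables and secrets") and is then exposed
# to the app as an environment variable, which is what os.getenv reads above.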

# --- Configuration ---
MODEL_NAME = "google/txgemma-2b-predict"
PROMPT_FILENAME = "tdc_prompts.json"
MODEL_CACHE = "model_cache"  # Optional: define a cache directory
MAX_EXAMPLES = 100  # Limit the number of examples loaded from the JSON
EXAMPLE_SMILES = "C1=CC=CC=C1"  # Default SMILES for examples (benzene)

# --- Load Model, Tokenizer, and Prompts ---
print(f"Loading model: {MODEL_NAME}...")
tdc_prompts_data = None  # Initialize as None
examples_list = []  # Initialize empty list for examples

try:
    # Check if a GPU is available and use it, otherwise fall back to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE)
    print("Tokenizer loaded.")

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=MODEL_CACHE,
        device_map="auto",  # Automatically distribute the model across available devices (GPU/CPU)
    )
    print("Model loaded.")

    # Download and load the prompts JSON file from the model repository
    print(f"Downloading {PROMPT_FILENAME}...")
    prompts_file_path = hf_hub_download(
        repo_id=MODEL_NAME,
        filename=PROMPT_FILENAME,
        cache_dir=MODEL_CACHE,
        # force_download=True,  # Uncomment to force a re-download if needed
    )
    print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}")

    # Load the JSON data
    with open(prompts_file_path, "r") as f:
        tdc_prompts_data = json.load(f)
    print(f"Loaded prompts data from {PROMPT_FILENAME}.")

    # --- Prepare examples for Gradio ---
    # The JSON is expected to be a dictionary whose values are prompt templates.
    if isinstance(tdc_prompts_data, dict):
        print(f"Processing {len(tdc_prompts_data)} prompts from dictionary...")
        count = 0
        for prompt_template in tdc_prompts_data.values():
            if count >= MAX_EXAMPLES:
                break
            if isinstance(prompt_template, str):
                # Replace the placeholder with the example SMILES string
                example_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES)
                # Add to the examples list with default parameters (max_new_tokens=100, temperature=0.7)
                examples_list.append([example_prompt, 100, 0.7])
                count += 1
            else:
                print(f"Warning: Skipping non-string value in prompts dictionary: {prompt_template}")
        print(f"Prepared {len(examples_list)} examples for Gradio.")
    else:
        print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(tdc_prompts_data)}. Cannot load examples.")
        # examples_list remains empty

except Exception as e:
    print(f"Error loading model, tokenizer, or prompts: {e}")
    # Ensure examples_list is empty if setup fails
    examples_list = []
    raise gr.Error(f"Failed during setup. Check the logs for details. Error: {e}")

# --- Prediction Function ---
def predict(prompt, max_new_tokens=100, temperature=0.7):
    """
    Generates text from the input prompt using the loaded model.

    Args:
        prompt (str): The input text prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        temperature (float): Controls the randomness of the generation. Lower is more deterministic.

    Returns:
        str: The generated text.
    """
    print(f"Received prompt: {prompt}")
    print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}")
    try:
        # Prepare the input for the model and move it to the model's device
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate text
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),   # Ensure it's an integer
                temperature=float(temperature),       # Ensure it's a float
                do_sample=float(temperature) > 0,     # Only sample if temperature > 0
                pad_token_id=tokenizer.eos_token_id,  # Set pad token id
            )

        # Decode the generated tokens
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated text (raw): {generated_text}")

        # Remove the prompt from the beginning of the generated text
        if generated_text.startswith(prompt):
            result_text = generated_text[len(prompt):].lstrip()
        else:
            # Handle cases where the model slightly alters the prompt start.
            # This is a basic check; more robust matching may be needed.
            common_prefix = os.path.commonprefix([prompt, generated_text])
            # If a significant portion of the prompt (about 80%) appears at the start, strip it
            if len(prompt) > 0 and len(common_prefix) / len(prompt) > 0.8:
                result_text = generated_text[len(common_prefix):].lstrip()
            else:
                result_text = generated_text  # Assume the prompt is not included or was significantly altered

        print(f"Generated text (processed): {result_text}")
        return result_text

    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"An error occurred during generation: {e}"

# --- Gradio Interface ---
print("Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 🤖 TXGemma-2B-Predict Text Generation
        Enter a prompt below or select an example, and the model ({MODEL_NAME}) will generate text based on it.
        Adjust the parameters for different results. Examples are loaded from `{PROMPT_FILENAME}`.
        Example prompts use the SMILES string `{EXAMPLE_SMILES}` (benzene) as a placeholder.
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Your Prompt",
                placeholder="Enter your text prompt here, optionally including a specific Drug SMILES string...",
                lines=5,
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10,
                    maximum=500,  # Adjust the upper limit if needed
                    value=100,
                    step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate after the prompt.",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,  # Allow deterministic generation
                    maximum=1.5,
                    value=0.7,
                    step=0.05,  # Finer control for temperature
                    label="Temperature",
                    info="Controls randomness (0 = deterministic, >0 = random).",
                )
            submit_button = gr.Button("Generate Text", variant="primary")
        with gr.Column(scale=3):
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=False,  # Output is not editable by the user
            )

    # --- Connect Components ---
    submit_button.click(
        fn=predict,
        inputs=[prompt_input, max_tokens_slider, temperature_slider],
        outputs=output_text,
        api_name="predict",  # Name for the API endpoint if needed
    )

    # Use the loaded examples if available
    if examples_list:
        gr.Examples(
            examples=examples_list,
            # Inputs must match the order expected by 'predict' and the structure of examples_list
            inputs=[prompt_input, max_tokens_slider, temperature_slider],
            outputs=output_text,
            fn=predict,  # The function to run when an example is clicked
            cache_examples=False,  # Caching can be slow/problematic for LLMs
        )
    else:
        gr.Markdown("_(Could not load examples from the JSON file, or the file format was unexpected.)_")

# --- Launch the App ---
print("Launching Gradio app...")
# queue() enables handling multiple users concurrently.
# Set share=True if you need a public link, otherwise False or omit it.
demo.queue().launch(debug=True)  # Set debug=False for production
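
# Usage note (illustrative addition, not part of the app itself): because the click
# handler is registered with api_name="predict", the Space can also be called
# programmatically with the gradio_client package. The Space id below is a placeholder.
#
#   from gradio_client import Client
#   client = Client("your-username/your-space-name")
#   result = client.predict("Your prompt here", 100, 0.7, api_name="/predict")
#   print(result)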