Spaces:
Runtime error
Runtime error
import random
from threading import Thread

import gradio as gr
import spaces  # Hugging Face Spaces helpers (provides @spaces.GPU)
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
# Constants for the model and configuration
MODEL_ID = "AstroMLab/AstroSage-8B"
WINDOW_SIZE = 2048  # token budget for prompt + generated reply
DEVICE = "cuda"  # device the tokenized prompt tensors are moved to

# Load model configuration, tokenizer, and model.
config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_ID)

# 4-bit bitsandbytes quantization. Passing `load_in_4bit=True` straight to
# `from_pretrained` is deprecated in favour of an explicit BitsAndBytesConfig.
# The compute dtype is set to bfloat16 to match `torch_dtype`; bitsandbytes
# otherwise defaults to float32 compute, which is slower.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
# Canned assistant openers. One is drawn at random (random.choice) whenever the
# chat needs a greeting instead of a model-generated reply.
GREETING_MESSAGES = [
    (
        "Greetings! I am AstroSage, your guide to the cosmos. "
        "What would you like to explore today?"
    ),
    (
        "Welcome to our cosmic journey! I am AstroSage. "
        "How may I assist you in understanding the universe?"
    ),
    (
        "AstroSage here. Ready to explore the mysteries of space and time. "
        "How may I be of assistance?"
    ),
    (
        "The universe awaits! I'm AstroSage. "
        "What astronomical wonders shall we discuss?"
    ),
]
def format_message(role: str, content: str) -> str:
    """Render one chat turn using the Llama-3 header / end-of-turn tokens.

    The result is ``<|start_header_id|>ROLE<|end_header_id|>`` followed by a
    blank line, the message text, and a closing ``<|eot_id|>``.
    """
    header = "<|start_header_id|>" + role + "<|end_header_id|>"
    return header + "\n\n" + content + "<|eot_id|>"
def _encode_llama3_prompt(prompt, history):
    """Render `history` + `prompt` in the Llama-3 chat layout and tokenize it."""
    parts = ["<|begin_of_text|>"]
    parts.extend(format_message(msg["role"], msg["content"]) for msg in history)
    parts.append(format_message("user", prompt))
    # Open the assistant turn so the model continues as the assistant.
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    # Turns are concatenated with no separator, as in the Llama-3 template (each
    # already ends in <|eot_id|>). The text already starts with
    # <|begin_of_text|>, so the tokenizer must not prepend a second BOS.
    return tokenizer(["".join(parts)], return_tensors="pt", add_special_tokens=False)


@spaces.GPU
def generate_text(prompt: str, history: list, max_new_tokens=512, temperature=0.7, top_p=0.95,
                  reserve_tokens=128):
    """Stream a model reply to `prompt`, formatted with the Llama-3 chat layout.

    Args:
        prompt: The current user message.
        history: Earlier turns as ``{"role": ..., "content": ...}`` dicts, not
            including `prompt` itself. May be empty or None; it is not modified.
        max_new_tokens: Upper bound on generated tokens. It is further reduced so
            that prompt + reply fit in ``WINDOW_SIZE`` tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        reserve_tokens: Minimum room (capped at `max_new_tokens`) kept free for
            the reply. If the rendered conversation is longer than that allows,
            the oldest non-system turns are dropped. If it is still too long, the
            prompt is trimmed from the left.

    Yields:
        The accumulated response text after each streamed chunk.
    """
    turns = list(history or [])
    reply_floor = min(max_new_tokens, reserve_tokens)
    max_prompt_tokens = max(1, WINDOW_SIZE - reply_floor)

    # Shrink the conversation until the prompt fits the window budget.
    while True:
        inputs = _encode_llama3_prompt(prompt, turns)
        input_length = inputs["input_ids"].shape[-1]
        if input_length <= max_prompt_tokens:
            break
        oldest = next((i for i, msg in enumerate(turns) if msg["role"] != "system"), None)
        if oldest is None:
            # Nothing left to drop (e.g. one huge message): keep the most recent
            # tokens so the assistant header at the end survives.
            inputs = {key: val[:, -max_prompt_tokens:] for key, val in inputs.items()}
            input_length = max_prompt_tokens
            break
        del turns[oldest]

    inputs = {key: val.to(DEVICE) for key, val in inputs.items()}
    max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)

    # Streamer yields decoded text chunks as generate() produces tokens.
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )
    # model.generate blocks until done, so run it in a worker thread and consume
    # the streamer on this one.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    thread.join()
def user(user_message, history):
    """Record the submitted text as a new user turn.

    Returns a pair ``("", new_history)``. The empty string clears the input
    textbox. `new_history` is a fresh list: the prior turns (none if `history`
    is None) followed by ``{"role": "user", "content": user_message}``.
    """
    if history is not None:
        prior = history
    else:
        prior = []
    new_turn = {"role": "user", "content": user_message}
    updated = prior + [new_turn]
    return "", updated
def bot(history):
    """Stream the assistant's reply to the most recent user message.

    `history` is the Chatbot value in "messages" form: a list of
    ``{"role": ..., "content": ...}`` dicts whose last entry should be the user
    turn just added by `user`. An empty assistant turn is appended to this list
    and filled in place as text streams back. The list is yielded after every
    chunk so the UI can re-render it.
    """
    if not history:
        # Nothing to answer: show a greeting instead of sending the model its
        # own greeting as though the user had written it.
        yield [{"role": "assistant", "content": random.choice(GREETING_MESSAGES)}]
        return
    if history[-1].get("role") != "user":
        # The last turn is not from the user, so there is nothing to reply to.
        yield history
        return

    last_user_message = history[-1]["content"]
    # generate_text appends `prompt` as the final user turn itself, so give it
    # only the earlier turns. The slice is taken now, before the placeholder
    # below is appended: generate_text is a lazy generator and would otherwise
    # also see that empty assistant turn.
    prior_turns = history[:-1]
    response_generator = generate_text(last_user_message, prior_turns)

    history.append({"role": "assistant", "content": ""})
    # Stream the response back
    for partial_response in response_generator:
        history[-1]["content"] = partial_response
        yield history
def initial_greeting():
    """Build the conversation shown when the page loads.

    Returns a two-item list of message dicts. The first is a system prompt that
    sets up the AstroSage persona. The second is one greeting picked at random
    from GREETING_MESSAGES.
    """
    persona = (
        "You are AstroSage, an intelligent AI assistant specializing in astronomy, "
        "astrophysics, and cosmology. Provide accurate, scientific information while "
        "making complex concepts accessible. You're enthusiastic about space exploration "
        "and maintain a sense of wonder about the cosmos. Start by introducing yourself."
    )
    opener = random.choice(GREETING_MESSAGES)
    system_turn = {"role": "system", "content": persona}
    greeting_turn = {"role": "assistant", "content": opener}
    return [system_turn, greeting_turn]
# Space-theme CSS passed to gr.Blocks:
# - a dark, rounded, padded panel for the element with id component-0;
# - a darker background under the .dark class;
# - a 1200px max width for .contain.
custom_css = (
    "\n"
    "#component-0 {\n"
    "background-color: #1a1a2e;\n"
    "border-radius: 15px;\n"
    "padding: 20px;\n"
    "}\n"
    ".dark {\n"
    "background-color: #0f0f1a;\n"
    "}\n"
    ".contain {\n"
    "max-width: 1200px !important;\n"
    "}\n"
)
# Create the Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate")) as demo:
    # Header / intro text. Blank lines keep Markdown from merging the intro
    # lines into one paragraph and the closing sentence into the last bullet.
    gr.Markdown(
        """
# 🌌 AstroSage-8B: Your Cosmic AI Companion

Welcome to AstroSage-8B, an advanced AI assistant specializing in astronomy, astrophysics, and cosmology.

Powered by the AstroSage-Llama-3.1-8B model, I'm here to help you explore the wonders of the universe!

### What Can I Help You With?

- 🪐 Explanations of astronomical phenomena
- 🚀 Space exploration and missions
- ⭐ Stars, galaxies, and cosmology
- 🌍 Planetary science and exoplanets
- 🔬 Astrophysics concepts and theories
- 🔭 Astronomical instruments and observations

Just type your question below and let's embark on a cosmic journey together!
"""
    )

    # `bubble_full_width` was dropped: it is deprecated and has no effect in
    # Gradio 5.
    chatbot = gr.Chatbot(
        label="Chat with AstroSage",
        show_label=True,
        height=450,
        type="messages",
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Type your message here",
            placeholder="Ask me anything about space and astronomy...",
            scale=9,
        )
        clear = gr.Button("Clear Chat", scale=1)

    # Example questions for quick start
    gr.Examples(
        examples=[
            "What is a black hole and how does it form?",
            "Can you explain the life cycle of a star?",
            "What are exoplanets and how do we detect them?",
            "Tell me about the James Webb Space Telescope.",
            "What is dark matter and why is it important?",
        ],
        inputs=msg,
        label="Example Questions",
    )

    # Submit: record the user turn immediately (unqueued), then stream the
    # assistant reply into the same chatbot.
    msg.submit(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=False,
    ).then(
        bot,
        chatbot,
        chatbot,
    )

    # Clear resets to the initial system prompt + greeting rather than an empty
    # chat, so the AstroSage system prompt is still present afterwards.
    clear.click(initial_greeting, None, chatbot, queue=False)

    # Initial greeting
    demo.load(initial_greeting, None, chatbot, queue=False)

# Launch the app
if __name__ == "__main__":
    demo.launch()