"""Gradio demo: side-by-side unsteered vs. SAE-steered OLMo-2 chat."""

import json
from pathlib import Path

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoConfig
import spaces
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from model import SAE, SteerableOlmo2ForCausalLM

# ---------------------------------------------------------------------------
# Model and tokenizer
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "allenai/OLMo-2-1124-7B-Instruct"

print("Loading model and tokenizer...")
model = SteerableOlmo2ForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_config = AutoConfig.from_pretrained(model_name)

# ---------------------------------------------------------------------------
# Sparse autoencoder (SAE), fetched from the Hugging Face Hub
# ---------------------------------------------------------------------------
print("Loading SAE from Hugging Face Hub...")

# Single source of truth for the repository both SAE artifacts come from.
SAE_REPO_ID = "open-concept-steering/olmo2-7b-sae-65k-v1"

sae_weights_path = hf_hub_download(
    repo_id=SAE_REPO_ID,
    filename="sae_weights.safetensors",
)
sae_config_path = hf_hub_download(
    repo_id=SAE_REPO_ID,
    filename="sae_config.json",
)

# Tensors are loaded directly onto the target device.
sae_weights = load_file(sae_weights_path, device=device)

# JSON is UTF-8 by specification; do not depend on the platform's locale.
with open(sae_config_path, "r", encoding="utf-8") as f:
    sae_config = json.load(f)

sae = SAE(sae_config['input_size'], sae_config['hidden_size']).to(device).to(torch.bfloat16)
sae.load_state_dict(sae_weights)

# ---------------------------------------------------------------------------
# Steering setup
# ---------------------------------------------------------------------------
# Layer index handed to the model together with the SAE.
# NOTE(review): presumably 0-based, i.e. the layer just before the middle of
# the stack -- confirm against SteerableOlmo2ForCausalLM.set_sae_and_layer.
steering_layer = model_config.num_hidden_layers // 2 - 1
model.set_sae_and_layer(sae, steering_layer)

# Steering options keyed by the label offered in the UI dropdown.
#   "feature": SAE feature index passed to model.set_steering (None = off)
#   "default": base strength, multiplied by the UI strength slider value
#   "name":    human-readable label
STEERING_FEATURES = {
    "None": {"feature": None, "default": 0, "name": "No Steering"},
    "Batman / Bruce Wayne": {"feature": 758, "default": 9, "name": "🦸 Superhero/Batman"},
    "Japan": {"feature": 29940, "default": 8, "name": "🗾 Japan"},
    "Baseball": {"feature": 65023, "default": 6, "name": "⚾ Baseball"},
}

default_system_prompt = "You are OLMo 2, a helpful and harmless AI Assistant built by the Allen Institute for AI."
def _generate_reply(history, message, system_prompt):
    """Sample one assistant reply for ``message`` given prior ``history``.

    Uses whatever steering is currently configured on the global ``model``;
    the caller is responsible for setting/clearing it.

    Args:
        history: Prior turns as dicts with ``"role"`` and ``"content"`` keys.
            Only those two keys are forwarded to the chat template.
        message: The new user message.
        system_prompt: Optional system prompt; skipped when empty.

    Returns:
        The generated reply text with surrounding whitespace stripped.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(
        {"role": turn["role"], "content": turn["content"]} for turn in history
    )
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    ).to(device)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Decode only what
    # was generated after them instead of string-splitting on chat markers,
    # which would mis-cut if the marker text appeared in the reply itself.
    prompt_length = inputs.input_ids.shape[1]
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


@spaces.GPU
def generate_responses(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt):
    """Generate both unsteered and steered responses with conversation history.

    Args:
        message: The new user message. Empty or whitespace-only messages are
            ignored and the histories are returned unchanged.
        history_unsteered: Message-dict history of the unsteered chat.
        history_steered: Message-dict history of the steered chat.
        steering_type: A key of ``STEERING_FEATURES``; ``"None"`` disables
            steering, in which case the steered chat mirrors the unsteered reply.
        steering_strength: Multiplier applied to the feature's ``"default"``.
        system_prompt: Optional system prompt used for both conversations.

    Returns:
        ``(history_unsteered, history_steered, "")`` -- the updated histories
        and an empty string used to clear the input textbox.
    """
    # Work on copies so the caller's lists are never mutated; tolerate None.
    history_unsteered = list(history_unsteered or [])
    history_steered = list(history_steered or [])

    if not message or not message.strip():
        return history_unsteered, history_steered, ""

    # Make sure no steering from an earlier request leaks into the baseline.
    model.clear_steering()
    unsteered_response = _generate_reply(history_unsteered, message, system_prompt)

    if steering_type != "None":
        feature_config = STEERING_FEATURES[steering_type]
        steering_value = feature_config["default"] * steering_strength
        model.set_steering(feature_config["feature"], steering_value)
        try:
            steered_response = _generate_reply(history_steered, message, system_prompt)
        finally:
            # Always reset, even when generation fails part-way.
            model.clear_steering()
    else:
        steered_response = unsteered_response

    history_unsteered.append({"role": "user", "content": message})
    history_unsteered.append({"role": "assistant", "content": unsteered_response})
    history_steered.append({"role": "user", "content": message})
    history_steered.append({"role": "assistant", "content": steered_response})

    return history_unsteered, history_steered, ""


def clear_chats():
    """Clear both chat histories."""
    return [], []


# Create Gradio interface
with gr.Blocks(title="OLMo-2 Feature Steering Demo", theme=gr.themes.Default()) as demo:
    gr.Markdown("""
    # 🎛️ OLMo-2 Feature Steering Demo

    This demo showcases how sparse autoencoders (SAEs) can steer OLMo-2's responses by manipulating specific features.
    Have a conversation and see how steering changes the model's behavior across multiple turns!
    """)

    with gr.Row():
        with gr.Column(scale=1):
            steering_type = gr.Dropdown(
                choices=list(STEERING_FEATURES.keys()),
                value="None",
                label="Steering Type",
                info="Choose a feature to steer the model's response",
            )
            steering_strength = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Steering Strength",
                info="Adjust the intensity of the steering effect (higher = more steering, very high values may cause gobbledygook)",
                # Starts hidden because the initial steering type is "None";
                # update_slider_visibility reveals it once a feature is chosen.
                visible=False,
            )
            system_prompt = gr.Textbox(
                label="System Prompt",
                value=default_system_prompt,
                lines=3,
            )
            clear_btn = gr.Button("🗑️ Clear Chats", variant="secondary")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🤖 Original OLMo")
            chatbot_unsteered = gr.Chatbot(
                label="Unsteered",
                height=500,
                show_copy_button=True,
                type="messages",
            )
        with gr.Column():
            gr.Markdown("### 🎯 Steered OLMo")
            chatbot_steered = gr.Chatbot(
                label="Steered",
                height=500,
                show_copy_button=True,
                type="messages",
            )

    with gr.Row():
        user_input = gr.Textbox(
            label="Your Message",
            # A multi-line Textbox submits on Shift+Enter; plain Enter
            # inserts a newline.
            placeholder="Type your message here... (Shift+Enter to send, Enter for new line)",
            lines=2,
            scale=4,
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Example questions
    gr.Examples(
        examples=[
            "What's an interesting way to spend a weekend?",
            "Tell me about your favorite subject.",
            "What should I do with $5?",
            "How do you approach solving difficult problems?",
            "What's something that makes you excited?",
            "Tell me a story about adventure.",
            "What advice would you give to someone feeling stuck?",
        ],
        inputs=user_input,
        label="Example Questions",
    )

    # Handle submission
    def submit_message(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt):
        """Forward a submitted message to ``generate_responses``."""
        return generate_responses(message, history_unsteered, history_steered, steering_type, steering_strength, system_prompt)

    # Wire up the interface: pressing the submit key in the textbox and
    # clicking "Send" trigger exactly the same handler.
    for send_event in (user_input.submit, submit_btn.click):
        send_event(
            fn=submit_message,
            inputs=[user_input, chatbot_unsteered, chatbot_steered, steering_type, steering_strength, system_prompt],
            outputs=[chatbot_unsteered, chatbot_steered, user_input],
        )

    clear_btn.click(
        fn=clear_chats,
        outputs=[chatbot_unsteered, chatbot_steered],
    )

    # Update slider visibility based on steering selection
    def update_slider_visibility(steering_type):
        """Show the strength slider only while a steering feature is selected."""
        return gr.update(visible=(steering_type != "None"))

    steering_type.change(
        fn=update_slider_visibility,
        inputs=steering_type,
        outputs=steering_strength,
    )

if __name__ == "__main__":
    demo.launch()