Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import mysql.connector | |
| from mysql.connector import Error | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hub checkpoint id of the text-to-SQL model used by this app.
model_name = "premai-io/prem-1B-SQL"

# Load the tokenizer first, then the causal-LM weights, both from the same
# checkpoint. These module-level objects are used by generate_sql() on every call.
tokenizer, model = (
    auto_cls.from_pretrained(model_name)
    for auto_cls in (AutoTokenizer, AutoModelForCausalLM)
)
# DDL of the only table the model is told about; it is embedded verbatim in
# every prompt so the model knows the available columns and their types.
_SALES_SCHEMA = """
CREATE TABLE sales (
pizza_id DECIMAL(8,2) PRIMARY KEY,
order_id DECIMAL(8,2),
pizza_name_id VARCHAR(14),
quantity DECIMAL(4,2),
order_date DATE,
order_time VARCHAR(8),
unit_price DECIMAL(5,2),
total_price DECIMAL(5,2),
pizza_size VARCHAR(3),
pizza_category VARCHAR(7),
pizza_ingredients VARCHAR(97),
pizza_name VARCHAR(42)
);
"""

# Budget for tokens generated *after* the prompt. Unlike `max_length`, this
# does not shrink when the question (and thus the prompt) gets longer.
_MAX_NEW_TOKENS = 256


def _build_sql_prompt(natural_language_query):
    """Return the text-to-SQL prompt embedding the schema and the user's question."""
    return f"""### Task: Generate a SQL query to answer the following question.
### Database Schema:
{_SALES_SCHEMA}
### Question: {natural_language_query}
### SQL Query:"""


def _extract_sql(completion):
    """Return the SQL part of a completion, dropping any later '###' section."""
    return completion.split("###", 1)[0].strip()


def generate_sql(natural_language_query):
    """Generate a SQL query for the ``sales`` table from a natural-language question.

    The question is placed, together with the ``sales`` table DDL, into a
    "### Task / ### Database Schema / ### Question / ### SQL Query:" prompt and
    completed by the module-level ``model`` using ``tokenizer``.

    Args:
        natural_language_query: The user's question. ``None`` or a
            whitespace-only string returns ``""`` without invoking the model.

    Returns:
        The generated SQL text with surrounding whitespace stripped. Only the
        tokens produced after the prompt are decoded, and anything from a
        following ``###`` header onward is discarded.
    """
    if natural_language_query is None or not natural_language_query.strip():
        return ""

    prompt = _build_sql_prompt(natural_language_query)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # Many causal-LM tokenizers define no pad token; reuse EOS instead of
    # passing None and letting generate() guess.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        # Pass the mask explicitly so generate() does not have to infer it.
        attention_mask=inputs["attention_mask"],
        max_new_tokens=_MAX_NEW_TOKENS,
        temperature=0.1,
        do_sample=True,
        top_p=0.95,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )

    # outputs[0] is prompt + continuation; decode only the continuation.
    prompt_length = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return _extract_sql(completion)
def main():
    """Build the Gradio text-in/text-out front end for generate_sql() and serve it."""
    interface_options = {
        "fn": generate_sql,
        "inputs": "text",
        "outputs": "text",
        "title": "Natural Language to SQL Query Generator",
        "description": "Enter a natural language query to generate the corresponding SQL query.",
    }
    gr.Interface(**interface_options).launch()


# Start the UI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()