import os
import sys

import torch
import torch.nn.functional as F

# --- Ensure the project root is on the path for imports ---
# This lets `from src.model import ...` resolve even when the script is
# launched from a different working directory.
sys.path.append(os.path.abspath(os.path.dirname(__file__)))

# --- Import all project components ---
from src.tokenizer import generate_v1_data, CharacterTokenizer
from src.model import TinyLLM, n_embed, n_head, n_layer, dropout  # model + hyperparameters

# --- Configuration (CHECK THIS PATH!) ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Use the file name confirmed in your last successful training run
WEIGHTS_PATH = 'data/tinyllm_v1_weights1.pt'


@torch.no_grad()
def generate(model, idx, max_new_tokens):
    """
    Takes a (B, T) tensor of token indices `idx` and appends
    `max_new_tokens` new indices, sampling autoregressively.
    """
    model.eval()  # Set model to evaluation mode (disables dropout)
    for _ in range(max_new_tokens):
        # Crop the context to the model's block size
        idx_cond = idx[:, -model.block_size:]
        # Get predictions; the model returns (logits, loss)
        logits, _ = model(idx_cond)
        # Focus only on the last time step (the next-token logits)
        logits = logits[:, -1, :]
        # Apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1)
        # Sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)
        # Append the sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)
    return idx


def setup_inference():
    """Builds the model and tokenizer and loads trained weights."""
    try:
        # 1. Set up the data pipeline to recover the vocabulary and sequence length
        raw_data = generate_v1_data()
        tokenizer = CharacterTokenizer(raw_data)

        # block_size is the maximum sequence length (T) the model can handle.
        # It must match the value used during training: the longest V1
        # sample (14 characters).
        block_size = max(len(s) for s in raw_data)

        # 2. Initialize the model architecture
        model = TinyLLM(
            vocab_size=tokenizer.vocab_size,
            n_embed=n_embed,
            n_head=n_head,
            n_layer=n_layer,
            block_size=block_size,
            dropout=dropout,
        ).to(DEVICE)

        # 3. Load the trained weights
        model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
        print(f"\nSuccessfully loaded model weights from {WEIGHTS_PATH}")
        return model, tokenizer, block_size

    except FileNotFoundError:
        print(f"Error: Weights file not found at {WEIGHTS_PATH}. Please run train.py first.")
        return None, None, None
    except RuntimeError as e:
        print(f"Runtime error during loading: {e}")
        print("Please ensure your src/model.py hyperparameters match the saved weights.")
        return None, None, None


def solve_problem(model, tokenizer, question_str, block_size):
    """Encodes a question, generates the answer, and prints the result."""
    # 1. Encode the question string (e.g., "5 + 3") and append a trailing
    # space so the model continues with the "= ANS" part
    context_tokens = tokenizer.encode(question_str)
    context_tokens.append(tokenizer.encode(' ')[0])

    # Convert the list of token IDs to a (1, T) PyTorch tensor
    idx = torch.tensor([context_tokens], dtype=torch.long, device=DEVICE)

    # 2. Generate the rest of the sequence (the "= ANS" part), filling the
    # remaining room up to the block size
    max_new_tokens = block_size - idx.shape[1]
    if max_new_tokens <= 0:
        print("Error: Input sequence is too long.")
        return
    generated_idx = generate(model, idx, max_new_tokens=max_new_tokens)

    # 3. Decode the result and print
    generated_sequence = tokenizer.decode(generated_idx[0].tolist())
    print(f"Question: '{question_str}'")
    print(f"Model Output: '{generated_sequence}'")
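
# --- Optional: exhaustive accuracy sweep (illustrative sketch) ---
# This helper is an addition for sanity-checking a trained model, not part of
# the original pipeline. Assumptions: the V1 dataset uses the "N op N = ANS"
# format implied by solve_problem() above; division is skipped because the
# expected answer format (integer vs. decimal) is not known here. Note that
# generate() samples stochastically, so scores can vary between runs, and the
# prefix match below would also accept an answer with extra trailing digits.
def evaluate_accuracy(model, tokenizer, block_size):
    """Scores the model on every single-digit +, -, and * problem."""
    import operator
    ops = {'+': operator.add, '-': operator.sub, '*': operator.mul}
    correct, total = 0, 0
    for op_char, op_fn in ops.items():
        for a in range(10):
            for b in range(10):
                question = f"{a} {op_char} {b}"
                # Same prompt construction as solve_problem()
                context_tokens = tokenizer.encode(question)
                context_tokens.append(tokenizer.encode(' ')[0])
                idx = torch.tensor([context_tokens], dtype=torch.long, device=DEVICE)
                out = generate(model, idx, max_new_tokens=block_size - idx.shape[1])
                decoded = tokenizer.decode(out[0].tolist())
                if decoded.startswith(f"{question} = {op_fn(a, b)}"):
                    correct += 1
                total += 1
    print(f"Accuracy on {total} problems: {correct / total:.1%}")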

# --- Main Interactive User Loop ---
if __name__ == '__main__':
    model, tokenizer, block_size = setup_inference()

    if model is not None:
        print("\n--- TinyLLM Math Chatbot Initialized ---")
        print("Enter a single-digit math problem (e.g., 4 + 5, 8 / 2).")
        print("Type 'exit' to quit.")

        while True:
            # 1. Get user input (stripped first, so 'exit ' still exits)
            question_str = input("Input: ").strip()
            if question_str.lower() == 'exit':
                break

            # 2. Basic input validation: 'N op N' with single-digit operands
            parts = question_str.split()
            is_valid = (
                len(parts) == 3
                and parts[0].isdigit() and len(parts[0]) == 1
                and parts[2].isdigit() and len(parts[2]) == 1
                and parts[1] in ['+', '-', '*', '/']
            )
            if not is_valid:
                print("Error: Please enter a problem in the format 'N op N' "
                      "with single-digit operands (e.g., 2 + 3).\n")
                continue

            # 3. Solve the problem using the trained model
            solve_problem(model, tokenizer, question_str, block_size)
            print("-" * 30)

        print("\n--- Chatbot Shutting Down ---")
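
# --- Example programmatic usage (illustrative) ---
# The pieces above can also be driven from a notebook or test without the
# interactive loop. 'infer' below is a placeholder for this file's name:
#
#   >>> from infer import setup_inference, solve_problem
#   >>> model, tokenizer, block_size = setup_inference()
#   >>> solve_problem(model, tokenizer, "4 + 5", block_size)
#   Question: '4 + 5'
#   Model Output: '4 + 5 = ...'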