""" Kirim-1-Math Inference Script Mathematical reasoning with tool calling capabilities """ import torch import json import re from transformers import AutoModelForCausalLM, AutoTokenizer from typing import List, Dict, Any, Optional import warnings warnings.filterwarnings('ignore') class MathToolExecutor: """Execute mathematical tools called by the model""" def __init__(self): try: import sympy as sp import numpy as np self.sp = sp self.np = np except ImportError: print("Warning: SymPy or NumPy not installed. Tool execution limited.") self.sp = None self.np = None def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str: """Execute a tool and return results""" try: if tool_name == "calculator": return self._calculator(arguments) elif tool_name == "symbolic_solver": return self._symbolic_solver(arguments) elif tool_name == "derivative": return self._derivative(arguments) elif tool_name == "integrate": return self._integrate(arguments) elif tool_name == "simplify": return self._simplify(arguments) elif tool_name == "latex_formatter": return self._latex_formatter(arguments) else: return f"Unknown tool: {tool_name}" except Exception as e: return f"Tool execution error: {str(e)}" def _calculator(self, args: Dict) -> str: """Precise calculator""" expr = args.get("expression", "") precision = args.get("precision", 15) if not self.sp: return "SymPy not available" try: result = self.sp.sympify(expr) result = self.sp.N(result, precision) return f"Result: {result}" except Exception as e: return f"Calculation error: {e}" def _symbolic_solver(self, args: Dict) -> str: """Solve equations symbolically""" equation = args.get("equation", "") variable = args.get("variable", "x") if not self.sp: return "SymPy not available" try: var = self.sp.Symbol(variable) eq = self.sp.sympify(equation) solutions = self.sp.solve(eq, var) return f"Solutions: {solutions}" except Exception as e: return f"Solver error: {e}" def _derivative(self, args: Dict) -> str: """Calculate derivatives""" function = args.get("function", "") variable = args.get("variable", "x") order = args.get("order", 1) if not self.sp: return "SymPy not available" try: var = self.sp.Symbol(variable) func = self.sp.sympify(function) result = self.sp.diff(func, var, order) return f"Derivative: {result}" except Exception as e: return f"Derivative error: {e}" def _integrate(self, args: Dict) -> str: """Calculate integrals""" function = args.get("function", "") variable = args.get("variable", "x") lower = args.get("lower_bound") upper = args.get("upper_bound") if not self.sp: return "SymPy not available" try: var = self.sp.Symbol(variable) func = self.sp.sympify(function) if lower is not None and upper is not None: result = self.sp.integrate(func, (var, lower, upper)) else: result = self.sp.integrate(func, var) return f"Integral: {result}" except Exception as e: return f"Integration error: {e}" def _simplify(self, args: Dict) -> str: """Simplify expressions""" expression = args.get("expression", "") if not self.sp: return "SymPy not available" try: expr = self.sp.sympify(expression) result = self.sp.simplify(expr) return f"Simplified: {result}" except Exception as e: return f"Simplification error: {e}" def _latex_formatter(self, args: Dict) -> str: """Format as LaTeX""" expression = args.get("expression", "") inline = args.get("inline", False) if not self.sp: return "SymPy not available" try: expr = self.sp.sympify(expression) latex = self.sp.latex(expr) if inline: return f"${latex}$" else: return f"$$\n{latex}\n$$" except Exception as e: return f"LaTeX formatting error: {e}" class KirimMath: """Kirim-1-Math inference with tool calling""" def __init__( self, model_path: str = "Kirim-ai/Kirim-1-Math", device: str = "auto", load_in_8bit: bool = False, load_in_4bit: bool = False ): print(f"Loading Kirim-1-Math from {model_path}...") # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained( model_path, trust_remote_code=True, use_fast=True ) # Configure model loading model_kwargs = { "trust_remote_code": True, "torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True, } if load_in_8bit: model_kwargs["load_in_8bit"] = True print("Loading in 8-bit mode (30GB VRAM)") elif load_in_4bit: model_kwargs["load_in_4bit"] = True print("Loading in 4-bit mode (20GB VRAM)") else: print("Loading in full precision (80GB VRAM)") if device == "auto": model_kwargs["device_map"] = "auto" # Load model self.model = AutoModelForCausalLM.from_pretrained( model_path, **model_kwargs ) if device not in ["auto"] and not (load_in_8bit or load_in_4bit): self.model = self.model.to(device) self.model.eval() # Initialize tool executor self.tool_executor = MathToolExecutor() print("✓ Model loaded successfully!") print("✓ Tool calling enabled\n") def solve_problem( self, problem: str, show_work: bool = True, use_tools: bool = True, max_new_tokens: int = 4096, temperature: float = 0.1 ) -> str: """ Solve a mathematical problem Args: problem: Math problem to solve show_work: Show step-by-step solution use_tools: Enable tool calling max_new_tokens: Maximum tokens to generate temperature: Sampling temperature (lower = more deterministic) Returns: Solution with reasoning """ # Construct prompt system_prompt = "You are Kirim-1-Math, an advanced mathematical reasoning AI. " if show_work: system_prompt += "Show your work step-by-step. " if use_tools: system_prompt += "You can use tools for calculations. Available tools: calculator, symbolic_solver, derivative, integrate, simplify." messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": problem} ] # Generate initial response response = self._generate(messages, max_new_tokens, temperature) # Check for tool calls if use_tools and "" in response: response = self._handle_tool_calls(response, messages, max_new_tokens, temperature) return response def _generate(self, messages: List[Dict], max_new_tokens: int, temperature: float) -> str: """Generate response from model""" formatted_prompt = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = self.tokenizer( formatted_prompt, return_tensors="pt", truncation=True, max_length=28672 ) if hasattr(self.model, 'device'): inputs = {k: v.to(self.model.device) for k, v in inputs.items()} gen_kwargs = { "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": 0.95, "do_sample": temperature > 0, "pad_token_id": self.tokenizer.pad_token_id, "eos_token_id": self.tokenizer.eos_token_id, } with torch.no_grad(): outputs = self.model.generate(**inputs, **gen_kwargs) full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=False) # Extract assistant response if "<|assistant|>" in full_response: response = full_response.split("<|assistant|>")[-1] response = response.replace("<|end_of_text|>", "").strip() return response return full_response.strip() def _handle_tool_calls(self, response: str, messages: List[Dict], max_new_tokens: int, temperature: float) -> str: """Process tool calls in response""" # Extract tool calls tool_pattern = r'(.*?)' tool_calls = re.findall(tool_pattern, response, re.DOTALL) if not tool_calls: return response # Execute each tool call for tool_call_str in tool_calls: try: tool_call = json.loads(tool_call_str.strip()) tool_name = tool_call.get("name", "") arguments = tool_call.get("arguments", {}) print(f"\n🔧 Executing tool: {tool_name}") print(f" Arguments: {arguments}") # Execute tool result = self.tool_executor.execute_tool(tool_name, arguments) print(f" Result: {result}\n") # Add tool result to messages messages.append({"role": "assistant", "content": response}) messages.append({"role": "tool", "content": f"{result}"}) # Generate continuation with tool result response = self._generate(messages, max_new_tokens, temperature) except json.JSONDecodeError: print(f"⚠️ Failed to parse tool call: {tool_call_str}") continue return response def interactive_math(self): """Interactive math problem solver""" print("\n" + "="*60) print(" Kirim-1-Math - Interactive Mode") print(" First model with tool calling!") print("="*60) print("\nCommands:") print(" 'quit' or 'exit' - End session") print(" 'tools off/on' - Toggle tool calling") print(" 'work off/on' - Toggle showing work") print("\n" + "="*60 + "\n") use_tools = True show_work = True while True: try: user_input = input("Problem: ").strip() if user_input.lower() in ['quit', 'exit', 'q']: print("\nGoodbye! Happy solving! 🧮\n") break if user_input.lower().startswith('tools'): use_tools = 'on' in user_input.lower() print(f"✓ Tool calling: {'enabled' if use_tools else 'disabled'}\n") continue if user_input.lower().startswith('work'): show_work = 'on' in user_input.lower() print(f"✓ Show work: {'enabled' if show_work else 'disabled'}\n") continue if not user_input: continue # Solve problem print("\n" + "-"*60) solution = self.solve_problem( user_input, show_work=show_work, use_tools=use_tools ) print(solution) print("-"*60 + "\n") except KeyboardInterrupt: print("\n\nGoodbye! 🧮\n") break except Exception as e: print(f"\n❌ Error: {e}\n") def main(): import argparse parser = argparse.ArgumentParser(description="Kirim-1-Math Inference") parser.add_argument("--model_path", type=str, default="Kirim-ai/Kirim-1-Math") parser.add_argument("--device", type=str, default="auto") parser.add_argument("--load_in_8bit", action="store_true") parser.add_argument("--load_in_4bit", action="store_true") parser.add_argument("--interactive", action="store_true") parser.add_argument("--problem", type=str, help="Single problem to solve") args = parser.parse_args() # Initialize model kirim_math = KirimMath( model_path=args.model_path, device=args.device, load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit ) if args.interactive: kirim_math.interactive_math() elif args.problem: solution = kirim_math.solve_problem(args.problem) print(f"\nProblem: {args.problem}") print(f"\nSolution:\n{solution}\n") else: # Demo examples print("="*60) print(" Demo Examples") print("="*60 + "\n") demos = [ "Solve: x² - 5x + 6 = 0", "Calculate the derivative of x³ + 2x² - x + 1", "解方程: 2x + 3y = 12, 4x - y = 5", "Integrate: ∫(x² + 1)dx" ] for problem in demos: print(f"\nProblem: {problem}") print("-" * 60) solution = kirim_math.solve_problem(problem) print(solution) print("=" * 60) if __name__ == "__main__": main()