import re
import torch
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from datasets import load_dataset

from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast

import fla
from fla.models import kda  # make sure the KDA architecture is imported/registered for AutoModel loading

# Patch the KDA model to handle 4D attention masks during generation
import fla.models.kda.modeling_kda

original_forward = fla.models.kda.modeling_kda.KDAModel.forward

def patched_forward(self, input_ids=None, attention_mask=None, **kwargs):
    # The KDA blocks only consume 2D padding masks; drop any expanded 4D
    # causal mask that generate() may pass in (causal masking is handled
    # inside the linear-attention layers themselves).
    if attention_mask is not None and attention_mask.dim() == 4:
        attention_mask = None

    return original_forward(self, input_ids=input_ids, attention_mask=attention_mask, **kwargs)

# Apply the patch
fla.models.kda.modeling_kda.KDAModel.forward = patched_forward

# --- Configuration ---
MODEL_ID = "THIS REPO"  # replace with this model repository's Hub id or a local path
DATASET_ID = "kreasof-ai/ECA-Zero"
BATCH_SIZE = 128
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# From the dataset generation script
WOLFRAM_CLASSES_MAP = {
    1: [0, 8, 32, 40, 128, 136, 160, 168],
    2: [1, 19, 23, 29, 37, 50, 108, 178],
    3: [30, 45, 60, 90, 105, 126, 150],
    4: [54, 106, 110, 124, 137, 147, 193]
}

# Invert for fast lookup: Rule -> Class
RULE_TO_CLASS = {}
for cls, rules in WOLFRAM_CLASSES_MAP.items():
    for r in rules:
        RULE_TO_CLASS[r] = cls
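# Quick sanity check of the inverted lookup, e.g.:
#   assert RULE_TO_CLASS[110] == 4 and RULE_TO_CLASS[30] == 3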

class ECAVerifier:
    def __init__(self):
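        # Field formats implied by these patterns: prompts are expected to carry
        # snippets like "Rule: 110", "Start: 0110...", "End: 1001...", "Steps: 5",
        # "Hint: Class 4", and truth-table rows like "110->1".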
        self.re_rule = re.compile(r"Rule: (\d+)")
        self.re_start = re.compile(r"Start: ([01]+)")
        self.re_end = re.compile(r"End: ([01]+)")
        self.re_steps = re.compile(r"Steps: (\d+)")
        self.re_hint_class = re.compile(r"Hint: Class (\d)")
        self.re_tt = re.compile(r"([01]{3})->([01])")

    def get_wolfram_class(self, prompt):
        # 1. Check for explicit Hint (Induction tasks)
        m = self.re_hint_class.search(prompt)
        if m:
            return int(m.group(1))

        # 2. Check for Rule ID (Deduction/Abduction) and look up
        m = self.re_rule.search(prompt)
        if m:
            rule = int(m.group(1))
            return RULE_TO_CLASS.get(rule, 0) # 0 = Unknown/Other

        return 0

    def get_next_state(self, state, rule):
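        # One synchronous update on a ring (periodic boundaries), using the
        # standard Wolfram encoding: the neighbourhood (l, c, r) is a 3-bit
        # index into the rule number. Worked example: for rule 110 = 0b01101110,
        # the neighbourhood (1, 1, 0) gives pattern 0b110 = 6, and bit 6 of 110
        # is 1, so the cell turns on.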
        next_state = []
        L = len(state)
        for i in range(L):
            l, c, r = state[(i - 1) % L], state[i], state[(i + 1) % L]
            pattern = (l << 2) | (c << 1) | r
            bit = 1 if (rule & (1 << pattern)) else 0
            next_state.append(bit)
        return next_state

    def simulate(self, start_state, rule, steps):
        current = list(start_state)
        for _ in range(steps):
            current = self.get_next_state(current, rule)
        return current

    def parse_rule_string(self, text):
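        # Rebuild a rule number from truth-table fragments such as "110->1":
        # each "pat->1" entry sets bit int(pat, 2), so a full 8-row table for
        # rule 110 reconstructs the integer 110.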
        matches = self.re_tt.findall(text)
        if not matches: return None
        rule = 0
        for pat, res in matches:
            if res == '1': rule |= (1 << int(pat, 2))
        return rule

    def verify(self, task_type, prompt, model_output_str):
        try:
            steps = int(self.re_steps.search(prompt).group(1))
            start_match = self.re_start.search(prompt)
            start_state = [int(x) for x in start_match.group(1)] if start_match else None
            end_match = self.re_end.search(prompt)
            end_state = [int(x) for x in end_match.group(1)] if end_match else None
            rule_match = self.re_rule.search(prompt)
            rule = int(rule_match.group(1)) if rule_match else None
        except AttributeError:
            return False

        answer = model_output_str.strip()
        try:
            if task_type == 'deduction':
                pred_state = [int(x) for x in answer if x in '01']
                if not pred_state: return False
                expected = self.simulate(start_state, rule, steps)
                return pred_state == expected

            elif task_type == 'induction':
                pred_rule = self.parse_rule_string(answer)
                if pred_rule is None: return False
                sim_end = self.simulate(start_state, pred_rule, steps)
                return sim_end == end_state

            elif task_type == 'abduction':
                pred_start = [int(x) for x in answer if x in '01']
                if not pred_start or len(pred_start) != len(end_state): return False
                sim_end = self.simulate(pred_start, rule, steps)
                return sim_end == end_state
        except Exception:
            return False
        return False
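
# Minimal, hypothetical smoke test for the verifier; the prompt below only
# mimics the field formats parsed above and is not taken from ECA-Zero:
#   v = ECAVerifier()
#   v.verify('deduction', 'Rule: 110 Start: 0010 Steps: 1', '0110')  # -> True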

def main():
    print(f"Loading tokenizer from {MODEL_ID}...")
    try:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_ID)
    except Exception:
        # Fall back to AutoTokenizer if the fast tokenizer cannot be loaded directly
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model from {MODEL_ID}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map=DEVICE,
    )
    print("Compiling the model")
    model = torch.compile(model)
    model.eval()

    print("Loading Test Set...")
    dataset = load_dataset(DATASET_ID, split="test")
    verifier = ECAVerifier()

    # Storage: results[task][class_id] = [True, False, ...]
    results = defaultdict(lambda: defaultdict(list))

    print("Starting Stratified Evaluation...")

    for i in tqdm(range(0, len(dataset), BATCH_SIZE)):
        batch = dataset[i : i + BATCH_SIZE]
        tasks = batch['task']
        inputs = batch['input']

        prompts = [f"{tokenizer.bos_token}{inp}\n<think>\n" for inp in inputs]

        # return_token_type_ids=False keeps the encodings compatible with the
        # model's forward signature, which does not accept token_type_ids
        encodings = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
            return_token_type_ids=False,
        ).to(DEVICE)

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=encodings['input_ids'],
                max_new_tokens=2048,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

        for j, raw_output in enumerate(decoded_outputs):
            if "</think>" in raw_output:
                final_answer = raw_output.split("</think>")[-1].replace(tokenizer.eos_token, "").strip()
            else:
                final_answer = ""

            # Determine Class
            w_class = verifier.get_wolfram_class(inputs[j])

            # Verify
            is_correct = verifier.verify(tasks[j], inputs[j], final_answer)

            # Store
            results[tasks[j]][w_class].append(is_correct)
            results[tasks[j]]["ALL"].append(is_correct)

    # --- Print Report ---
    print("\n" + "="*60)
    print("STRATIFIED RESULTS (Accuracy by Wolfram Class)")
    print("="*60)

    # Define column headers
    print(f"{'Task':<12} | {'Class 1':<10} | {'Class 2':<10} | {'Class 3':<10} | {'Class 4':<10} | {'OVERALL':<10}")
    print("-" * 75)

    for task in ["deduction", "induction", "abduction"]:
        row_str = f"{task.capitalize():<12} | "

        for c in [1, 2, 3, 4]:
            outcomes = results[task][c]
            if outcomes:
                acc = sum(outcomes) / len(outcomes)
                row_str += f"{acc:.1%} ({len(outcomes):<3}) | " # concise
            else:
                row_str += "N/A        | "

        # Overall
        all_outcomes = results[task]["ALL"]
        if all_outcomes:
            total_acc = sum(all_outcomes) / len(all_outcomes)
            row_str += f"{total_acc:.1%} ({len(all_outcomes)})"

        print(row_str)

    print("="*60)
    print("Class Legend:")
    print("1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)")

if __name__ == "__main__":
    main()
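
Stratified evaluation output for this checkpoint on the ECA-Zero test split: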
============================================================
STRATIFIED RESULTS (Accuracy by Wolfram Class)
============================================================
Task         | Class 1    | Class 2    | Class 3    | Class 4    | OVERALL   
---------------------------------------------------------------------------
Deduction    | 50.4% (113) | 44.2% (226) | 36.7% (412) | 38.8% (410) | 40.2% (1161)
Induction    | 37.2% (113) | 8.8% (227) | 1.9% (414) | 0.2% (411) | 6.1% (1165)
Abduction    | 19.1% (47 ) | 23.8% (185) | 5.7% (388) | 4.4% (387) | 9.1% (1007)
============================================================
Class Legend:
1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)