import re
import torch
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from datasets import load_dataset
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast
import fla
from fla.models import kda # <-- Add this line
# Patch the KDA model to handle 4D attention masks during generation
import fla.models.kda.modeling_kda
# HF generation utilities may build a 4D causal mask, but the stock KDA
# forward only handles 2D padding masks. When generating without padding
# the mask is unnecessary, so we simply drop 4D masks before delegating.
original_forward = fla.models.kda.modeling_kda.KDAModel.forward

def patched_forward(self, input_ids=None, attention_mask=None, **kwargs):
    """Delegate to the stock KDA forward, discarding 4D attention masks."""
    mask = attention_mask
    if mask is not None and mask.dim() == 4:
        # Generation path without padding: safe to run mask-free.
        mask = None
    return original_forward(self, input_ids=input_ids, attention_mask=mask, **kwargs)

# Install the patch on the model class.
fla.models.kda.modeling_kda.KDAModel.forward = patched_forward
# --- Configuration ---
MODEL_ID = "THIS REPO"
DATASET_ID = "kreasof-ai/ECA-Zero"
BATCH_SIZE = 128
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Rule numbers grouped by Wolfram class, mirroring the dataset
# generation script.
WOLFRAM_CLASSES_MAP = {
    1: [0, 8, 32, 40, 128, 136, 160, 168],
    2: [1, 19, 23, 29, 37, 50, 108, 178],
    3: [30, 45, 60, 90, 105, 126, 150],
    4: [54, 106, 110, 124, 137, 147, 193],
}

# Inverted view for O(1) lookup: rule number -> Wolfram class.
RULE_TO_CLASS = {
    rule: cls
    for cls, rules in WOLFRAM_CLASSES_MAP.items()
    for rule in rules
}
class ECAVerifier:
    """Parses ECA task prompts and verifies model answers by simulation.

    Prompts carry structured fields (Rule, Start, End, Steps, Hint) that
    are extracted with pre-compiled regexes; answers are validated by
    running the elementary cellular automaton forward and comparing.
    """

    def __init__(self):
        # Pre-compiled field extractors for the prompt format.
        self.re_rule = re.compile(r"Rule: (\d+)")
        self.re_start = re.compile(r"Start: ([01]+)")
        self.re_end = re.compile(r"End: ([01]+)")
        self.re_steps = re.compile(r"Steps: (\d+)")
        self.re_hint_class = re.compile(r"Hint: Class (\d)")
        # Truth-table entries like "110->1" used by induction answers.
        self.re_tt = re.compile(r"([01]{3})->([01])")

    def get_wolfram_class(self, prompt):
        """Return the Wolfram class for a prompt, or 0 when unknown."""
        # Explicit hint (induction tasks) wins over the rule lookup.
        hint = self.re_hint_class.search(prompt)
        if hint is not None:
            return int(hint.group(1))
        # Otherwise derive the class from the rule id (deduction/abduction).
        rule_m = self.re_rule.search(prompt)
        if rule_m is not None:
            return RULE_TO_CLASS.get(int(rule_m.group(1)), 0)  # 0 = other
        return 0

    def get_next_state(self, state, rule):
        """Apply one synchronous ECA update on a circular lattice."""
        n = len(state)
        # Neighborhood (l, c, r) encodes a 3-bit pattern; the rule's
        # corresponding bit is the cell's next value.
        return [
            (rule >> ((state[i - 1] << 2) | (state[i] << 1) | state[(i + 1) % n])) & 1
            for i in range(n)
        ]

    def simulate(self, start_state, rule, steps):
        """Evolve `start_state` under `rule` for `steps` updates."""
        configuration = list(start_state)
        for _ in range(steps):
            configuration = self.get_next_state(configuration, rule)
        return configuration

    def parse_rule_string(self, text):
        """Reconstruct a rule number from truth-table entries, or None."""
        entries = self.re_tt.findall(text)
        if not entries:
            return None
        rule = 0
        for pattern, output in entries:
            if output == '1':
                rule |= 1 << int(pattern, 2)
        return rule

    def verify(self, task_type, prompt, model_output_str):
        """Check whether a model answer solves the task in `prompt`.

        Returns False on any parse/simulation failure rather than raising.
        """
        # Steps is mandatory; a missing field raises AttributeError on
        # .group() of the None match.
        try:
            steps = int(self.re_steps.search(prompt).group(1))
        except AttributeError:
            return False

        start_m = self.re_start.search(prompt)
        end_m = self.re_end.search(prompt)
        rule_m = self.re_rule.search(prompt)
        start_state = [int(c) for c in start_m.group(1)] if start_m else None
        end_state = [int(c) for c in end_m.group(1)] if end_m else None
        rule = int(rule_m.group(1)) if rule_m else None

        answer = model_output_str.strip()
        try:
            if task_type == 'deduction':
                # Answer is the final state; compare against ground truth.
                guess = [int(c) for c in answer if c in '01']
                if not guess:
                    return False
                return guess == self.simulate(start_state, rule, steps)
            if task_type == 'induction':
                # Answer is a rule; accept any rule that reproduces End.
                guess_rule = self.parse_rule_string(answer)
                if guess_rule is None:
                    return False
                return self.simulate(start_state, guess_rule, steps) == end_state
            if task_type == 'abduction':
                # Answer is a start state; accept any preimage of End.
                guess_start = [int(c) for c in answer if c in '01']
                if not guess_start or len(guess_start) != len(end_state):
                    return False
                return self.simulate(guess_start, rule, steps) == end_state
        except Exception:
            # Missing fields (None states) surface here as TypeErrors.
            return False
        return False
def main():
    """Run stratified ECA evaluation of MODEL_ID on DATASET_ID's test split.

    Generates greedy completions in batches, verifies each answer by
    simulating the cellular automaton, and prints accuracy broken down by
    task type and Wolfram class.
    """
    print(f"Loading tokenizer from {MODEL_ID}...")
    try:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_ID)
    except Exception:
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt
        # and SystemExit; fall back to the auto-resolving tokenizer class.
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # FIX: decoder-only models must be left-padded for batched generation;
    # with right padding the new tokens are conditioned on pad tokens.
    tokenizer.padding_side = "left"

    print(f"Loading model from {MODEL_ID}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map=DEVICE,
    )
    print("Compiling the model")
    model = torch.compile(model)
    model.eval()

    print("Loading Test Set...")
    dataset = load_dataset(DATASET_ID, split="test")
    verifier = ECAVerifier()
    # Storage: results[task][class_id] = [True, False, ...]
    results = defaultdict(lambda: defaultdict(list))

    print("Starting Stratified Evaluation...")
    for i in tqdm(range(0, len(dataset), BATCH_SIZE)):
        batch = dataset[i : i + BATCH_SIZE]
        tasks = batch['task']
        inputs = batch['input']
        # Open the reasoning tag so the model answers after "</think>".
        prompts = [f"{tokenizer.bos_token}{inp}\n<think>\n" for inp in inputs]
        # return_token_type_ids=False: generate() rejects token_type_ids.
        encodings = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
            return_token_type_ids=False,
        ).to(DEVICE)
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=encodings['input_ids'],
                # FIX: pass the padding mask so pad positions are ignored;
                # previously only input_ids was forwarded, letting the
                # model attend to padding in every non-full batch.
                attention_mask=encodings['attention_mask'],
                max_new_tokens=2048,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
        for j, raw_output in enumerate(decoded_outputs):
            # The answer is whatever follows the final closing reasoning tag.
            if "</think>" in raw_output:
                final_answer = raw_output.split("</think>")[-1].replace(tokenizer.eos_token, "").strip()
            else:
                # Never closed its reasoning: scored as incorrect.
                final_answer = ""
            w_class = verifier.get_wolfram_class(inputs[j])
            is_correct = verifier.verify(tasks[j], inputs[j], final_answer)
            # Bucket by task and by class, plus an "ALL" aggregate.
            results[tasks[j]][w_class].append(is_correct)
            results[tasks[j]]["ALL"].append(is_correct)

    _print_report(results)


def _print_report(results):
    """Print per-class and overall accuracy for each task type."""
    print("\n" + "=" * 60)
    print("STRATIFIED RESULTS (Accuracy by Wolfram Class)")
    print("=" * 60)
    print(f"{'Task':<12} | {'Class 1':<10} | {'Class 2':<10} | {'Class 3':<10} | {'Class 4':<10} | {'OVERALL':<10}")
    print("-" * 75)
    for task in ["deduction", "induction", "abduction"]:
        row_str = f"{task.capitalize():<12} | "
        for c in [1, 2, 3, 4]:
            outcomes = results[task][c]
            if outcomes:
                acc = sum(outcomes) / len(outcomes)
                row_str += f"{acc:.1%} ({len(outcomes):<3}) | "  # concise
            else:
                row_str += "N/A | "
        all_outcomes = results[task]["ALL"]
        if all_outcomes:
            total_acc = sum(all_outcomes) / len(all_outcomes)
            row_str += f"{total_acc:.1%} ({len(all_outcomes)})"
        print(row_str)
    print("=" * 60)
    print("Class Legend:")
    print("1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)")
# Entry point: run the evaluation only when executed as a script.
if __name__ == "__main__":
    main()
============================================================
STRATIFIED RESULTS (Accuracy by Wolfram Class)
============================================================
Task | Class 1 | Class 2 | Class 3 | Class 4 | OVERALL
---------------------------------------------------------------------------
Deduction | 50.4% (113) | 44.2% (226) | 36.7% (412) | 38.8% (410) | 40.2% (1161)
Induction | 37.2% (113) | 8.8% (227) | 1.9% (414) | 0.2% (411) | 6.1% (1165)
Abduction | 19.1% (47 ) | 23.8% (185) | 5.7% (388) | 4.4% (387) | 9.1% (1007)
============================================================
Class Legend:
1: Uniform (Trivial) | 2: Periodic (Easy) | 3: Chaotic (Hard) | 4: Complex (Hardest)
Downloads last month: 58
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support