Extract-0: A Specialized Language Model for Document Information Extraction
Abstract
Extract-0, a 7-billion parameter language model optimized for document information extraction, outperforms larger models using synthetic data, Low-Rank Adaptation, and Group Relative Policy Optimization.
This paper presents Extract-0, a 7-billion parameter language model specifically optimized for document information extraction that achieves performance exceeding models with parameter counts several orders of magnitude larger. Through a novel combination of synthetic data generation, supervised fine-tuning with Low-Rank Adaptation (LoRA), and reinforcement learning via Group Relative Policy Optimization (GRPO), Extract-0 achieves a mean reward of 0.573 on a benchmark of 1,000 diverse document extraction tasks, outperforming GPT-4.1 (0.457), o3 (0.464), and GPT-4.1-2025 (0.459). The training methodology employs a memory-preserving synthetic data generation pipeline that produces 280,128 training examples from diverse document sources, followed by parameter-efficient fine-tuning that modifies only 0.53% of model weights (40.4M out of 7.66B parameters). The reinforcement learning phase introduces a novel semantic similarity-based reward function that handles the inherent ambiguity in information extraction tasks. This research demonstrates that task-specific optimization can yield models that surpass general-purpose systems while requiring substantially fewer computational resources.
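For context on the 0.53% figure: it corresponds to LoRA adapter weights trained on top of a frozen base model. The sketch below is illustrative only; the rank, alpha, dropout, and target modules are assumptions rather than the paper's stated configuration, although with these values a Qwen2-7B-scale model comes out close to the reported ~40.4M trainable parameters.

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Base model named in the evaluation script further down this page.
base = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    torch_dtype=torch.bfloat16,
)

# Assumed LoRA hyperparameters (illustrative, not taken from the paper).
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

model = get_peft_model(base, lora_config)
# Prints trainable vs. total parameter counts, i.e. the "0.53% of weights" style figure.
model.print_trainable_parameters()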
Community
Hi,
I’ve been attempting to replicate the results reported in the paper for the base model. While the paper mentions achieving a reward of around 23%, I’m consistently obtaining only ~9% (±19%).
================================================================================
2025-10-14 07:08:06,143 - INFO - EVALUATION COMPLETE
2025-10-14 07:08:06,143 - INFO - Mean Reward: 0.0899 ± 0.1965
2025-10-14 07:08:06,143 - INFO - JSON Validity: 90.10%
2025-10-14 07:08:06,143 - INFO - Samples evaluated: 1000
2025-10-14 07:08:06,143 - INFO - ================================================================================
Here’s what I did and where I might be diverging from your setup:

Evaluation pipeline:
I modified the official evaluation script to query the model through a vLLM API server (OpenAI-compatible endpoint) instead of loading it directly into memory.

Dataset sampling:
In the original evaluation code, it seems there are only 47 valid samples between indices 100000 and 100000 + 1000:
df = pd.read_csv('data/extraction_training_data.csv')
df = df[100000:100000 + NUM_TEST_SAMPLES]
To ensure I had 1000 valid samples, I wrote a custom loop that collects the first 1000 valid samples after index 100000 (a rough counting sketch follows this list).

Dataset source:
I used the dataset from this repo (HenriqueGodoy/extract-0), since I couldn’t find instructions for generating the original evaluation data. I understand this may slightly change the reward, but a drop from 23% to 9% seems too large to be explained by the dataset alone.

Max tokens:
I increased the max output tokens to 4096 from the ~1500 mentioned in the paper.
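For reference, here is roughly how I counted the valid samples in that slice. This is only an approximation of the filter in load_test_data below (it ignores the chunk-availability and prompt-length checks), and the CSV path is the one from the original evaluation snippet above:

# Rough count of "valid" rows in the original 100000:100000+1000 slice:
# the row must have chunk_refs or reference_text, and its input/output
# fields must parse as JSON.
import json
import pandas as pd

df = pd.read_csv('data/extraction_training_data.csv')
valid = 0
for _, row in df.iloc[100000:101000].iterrows():
    try:
        if not (pd.notna(row.get('chunk_refs')) or pd.notna(row.get('reference_text'))):
            continue
        json.loads(row['input'])   # schema must be parseable JSON
        json.loads(row['output'])  # gold output must be parseable JSON
        valid += 1
    except Exception:
        continue
print(f"valid samples in slice: {valid}")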
I’ll attach my evaluation code below so you can check whether I’ve missed something subtle.
Would you be able to clarify whether the paper’s results were obtained using a different data slice, preprocessing step, or reward normalization?
Thanks!
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import pandas as pd
import json
import orjson
from pathlib import Path
from datetime import datetime
import gc
import numpy as np
from dateutil import parser as date_parser
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import warnings
import logging
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed
warnings.filterwarnings('ignore')
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'evaluation_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
BASE_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
ASSISTANT_TAG = "\n\n### Assistant:\n"
MAX_SEQ_LEN = 2048
NUM_TEST_SAMPLES = 1000
TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
BATCH_SIZE = 64
def clear_gpu_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()
def extract_json_from_text(text):
    # Scan for the first balanced top-level {...} object in the text and parse it.
    text = text.strip()
    start_idx = text.find('{')
    if start_idx == -1:
        return None
    brace_count = 0
    in_string = False
    escape_next = False
    for i, char in enumerate(text[start_idx:], start_idx):
        if escape_next:
            escape_next = False
            continue
        if char == '\\' and in_string:
            escape_next = True
            continue
        if char == '"' and not escape_next:
            in_string = not in_string
            continue
        if not in_string:
            if char == '{':
                brace_count += 1
            elif char == '}':
                brace_count -= 1
                if brace_count == 0:
                    try:
                        json_str = text[start_idx:i+1]
                        return orjson.loads(json_str)
                    except Exception:
                        return None
    return None
def get_expected_keys(schema):
    if isinstance(schema, dict):
        if schema.get('type') == 'object' and 'properties' in schema:
            return set(schema['properties'].keys())
        else:
            return set(schema.keys()) - {'type', 'properties', 'required', 'additionalProperties'}
    return set()
def calculate_field_similarity(pred_value, gold_value):
    if pred_value is None and gold_value is None:
        return 1.0
    if pred_value is None or gold_value is None:
        return 0.0
    # Numeric fields: score by relative error against the gold value
    if isinstance(gold_value, (int, float)) and isinstance(pred_value, (int, float)):
        if gold_value == 0:
            return 1.0 if pred_value == 0 else 0.0
        diff = abs(pred_value - gold_value) / abs(gold_value)
        return max(0.0, 1.0 - diff)
    # String fields: try a date comparison first, otherwise fall back to
    # cosine similarity of sentence embeddings
    if isinstance(gold_value, str) and isinstance(pred_value, str):
        try:
            gold_date = date_parser.parse(gold_value)
            pred_date = date_parser.parse(pred_value)
            diff_days = abs((pred_date - gold_date).total_seconds()) / 86400
            return max(0.0, 1.0 - diff_days / 365)
        except Exception:
            if not gold_value and not pred_value:
                return 1.0
            if not gold_value or not pred_value:
                return 0.0
            gold_emb = embedding_model.encode(gold_value, convert_to_numpy=True)
            pred_emb = embedding_model.encode(pred_value, convert_to_numpy=True)
            similarity = np.dot(gold_emb, pred_emb) / (np.linalg.norm(gold_emb) * np.linalg.norm(pred_emb))
            return max(0.0, similarity)
    # List fields: greedily match predicted and gold items above a similarity
    # threshold, then score the matched mass as an F1 over both list lengths
    if isinstance(gold_value, list) and isinstance(pred_value, list):
        if not gold_value and not pred_value:
            return 1.0
        if not gold_value or not pred_value:
            return 0.0
        m = len(pred_value)
        n = len(gold_value)
        if m == 0 and n == 0:
            return 1.0
        pairs = []
        for i, pred_item in enumerate(pred_value):
            for j, gold_item in enumerate(gold_value):
                s = calculate_field_similarity(pred_item, gold_item)
                pairs.append((s, i, j))
        pairs.sort(key=lambda x: x[0], reverse=True)
        used_pred = set()
        used_gold = set()
        matched_sum = 0.0
        threshold = 0.35
        for s, i, j in pairs:
            if s < threshold:
                break
            if i in used_pred or j in used_gold:
                continue
            used_pred.add(i)
            used_gold.add(j)
            matched_sum += float(max(0.0, min(1.0, s)))
        eps = 1e-8
        precision = matched_sum / (m + eps)
        recall = matched_sum / (n + eps)
        if precision <= 0.0 and recall <= 0.0:
            return 0.0
        return float((2.0 * precision * recall) / (precision + recall + eps))
    # Dict fields: recurse over the union of keys and average the scores
    if isinstance(gold_value, dict) and isinstance(pred_value, dict):
        if not gold_value and not pred_value:
            return 1.0
        if not gold_value or not pred_value:
            return 0.0
        scores = []
        all_keys = set(gold_value.keys()) | set(pred_value.keys())
        for key in all_keys:
            scores.append(calculate_field_similarity(pred_value.get(key), gold_value.get(key)))
        return np.mean(scores) if scores else 0.0
    # Booleans and any remaining types: exact match
    if isinstance(gold_value, bool) and isinstance(pred_value, bool):
        return 1.0 if pred_value == gold_value else 0.0
    return 1.0 if str(pred_value) == str(gold_value) else 0.0
def compute_extraction_reward(response_text, schema_str, gold_str):
    # Reward is 0 unless the response contains parseable JSON with every
    # expected key; otherwise it is the mean per-field similarity.
    try:
        pred_json = extract_json_from_text(response_text)
        if pred_json is None:
            return 0.0
        try:
            schema = orjson.loads(schema_str) if isinstance(schema_str, str) else schema_str
            gold = orjson.loads(gold_str) if isinstance(gold_str, str) else gold_str
        except Exception:
            return 0.0
        expected_keys = get_expected_keys(schema)
        if not expected_keys:
            return 0.0
        pred_keys = set(pred_json.keys())
        if not expected_keys.issubset(pred_keys):
            return 0.0
        scores = []
        for k in expected_keys:
            pred_val = pred_json.get(k)
            gold_val = gold.get(k)
            similarity = calculate_field_similarity(pred_val, gold_val)
            scores.append(similarity)
        if not scores:
            return 0.0
        return float(np.mean(scores))
    except Exception:
        return 0.0
def load_test_data():
    logger.info("Loading test data...")
    ref_df = pd.read_csv('data/extract-0/documents.csv')
    chunk_texts = dict(zip(ref_df['chunk_id'], ref_df['text']))
    df = pd.read_csv('data/extract-0/train.csv')
    system_prompt = """You are an expert data extraction system. Extract structured information from documents according to the provided schema.
Return only valid JSON that matches the schema exactly.
Be precise and accurate in your extractions."""
    test_examples = []
    start_idx = 100000
    max_attempts = NUM_TEST_SAMPLES * 100  # Try up to 100x the target to find valid samples
    logger.info(f"Scanning rows starting from index {start_idx}...")
    for offset in tqdm(range(max_attempts), desc="Preprocessing data"):
        if len(test_examples) >= NUM_TEST_SAMPLES:
            break
        idx = start_idx + offset
        if idx >= len(df):
            logger.warning(f"Reached end of dataset at index {idx}, only found {len(test_examples)} valid examples")
            break
        row = df.iloc[idx]
        try:
            if pd.notna(row.get('chunk_refs', None)):
                chunk_refs = json.loads(row['chunk_refs'])
            elif pd.notna(row.get('reference_text', None)):
                chunk_refs = [str(row['reference_text']).strip()]
            else:
                continue
            schema = json.loads(row['input'])
            result = json.loads(row['output'])
            if 'chunk_refs' in schema:
                del schema['chunk_refs']
            compact_schema = json.dumps(schema, separators=(',', ':'))
            chunks = []
            for chunk_id in chunk_refs:
                if chunk_id in chunk_texts:
                    chunks.append(chunk_texts[chunk_id])
            if not chunks:
                continue
            # Try with all chunks first, then reduce if needed
            for k in range(len(chunks), 0, -1):
                selected_chunks = chunks[:k]
                joined_chunks = "\n---\n".join(selected_chunks)
                prompt = f"### System:\n{system_prompt}\n\n### User:\nSchema:\n{compact_schema}\n\nDocument:\n{joined_chunks}{ASSISTANT_TAG}"
                if len(prompt) <= MAX_SEQ_LEN * 3:
                    test_examples.append({
                        "prompt": prompt,
                        "schema": compact_schema,
                        "gold_output": json.dumps(result, separators=(',', ':')),
                        "source_index": idx,
                        "system_prompt": system_prompt,
                        "user_content": f"Schema:\n{compact_schema}\n\nDocument:\n{joined_chunks}"
                    })
                    break
        except Exception:
            continue
    logger.info(f"Loaded {len(test_examples)} test examples from {offset+1} rows scanned")
    return test_examples
def evaluate_api_model(base_url, model_id, api_key=None, workers=20):
    logger.info("=" * 80)
    logger.info(f"EVALUATING API MODEL")
    logger.info(f"Base URL: {base_url}")
    logger.info(f"Model ID: {model_id}")
    logger.info(f"API Key: {'PROVIDED' if api_key else 'NOT REQUIRED'}")
    logger.info(f"Workers: {workers}")
    logger.info(f"Test samples: {NUM_TEST_SAMPLES}")
    logger.info("=" * 80)
    # Initialize OpenAI client
    client = OpenAI(
        base_url=base_url,
        api_key=api_key if api_key else "not-needed"
    )
    test_data = load_test_data()
    results = [None] * len(test_data)
    json_valid_count = 0
    completed_count = 0
    logger.info(f"Evaluating {len(test_data)} examples with {workers} workers...")

    def process_example(idx, example):
        try:
            # Create chat completion request
            response = client.chat.completions.create(
                model=model_id,
                messages=[
                    {"role": "system", "content": example["system_prompt"]},
                    {"role": "user", "content": example["user_content"]}
                ],
                max_tokens=4096,
                # reasoning_effort = "medium"
            )
            # Extract generated text
            generated_text = response.choices[0].message.content
            # Compute reward
            reward = compute_extraction_reward(
                generated_text,
                example["schema"],
                example["gold_output"]
            )
            # Check JSON validity
            is_valid = extract_json_from_text(generated_text) is not None
            return idx, reward, is_valid
        except Exception as e:
            logger.error(f"Error processing example {idx}: {str(e)}")
            return idx, 0.0, False

    # Use ThreadPoolExecutor for parallel processing
    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Submit all tasks
        futures = {
            executor.submit(process_example, i, test_data[i]): i
            for i in range(len(test_data))
        }
        # Process completed tasks with progress bar
        with tqdm(total=len(test_data), desc="Evaluating") as pbar:
            for future in as_completed(futures):
                idx, reward, is_valid = future.result()
                results[idx] = reward
                if is_valid:
                    json_valid_count += 1
                completed_count += 1
                pbar.update(1)
                # Log progress every 100 samples
                if completed_count % 100 == 0:
                    valid_results = [r for r in results if r is not None]
                    current_mean = np.mean(valid_results) if valid_results else 0.0
                    logger.info(f"Progress: {completed_count}/{len(test_data)}, Current mean reward: {current_mean:.4f}")

    mean_reward = np.mean(results)
    std_reward = np.std(results)
    json_validity = json_valid_count / len(test_data)
    logger.info("=" * 80)
    logger.info("EVALUATION COMPLETE")
    logger.info(f"Mean Reward: {mean_reward:.4f} ± {std_reward:.4f}")
    logger.info(f"JSON Validity: {json_validity:.2%}")
    logger.info(f"Samples evaluated: {len(results)}")
    logger.info("=" * 80)
    return {
        "model": model_id,
        "base_url": base_url,
        "mean_reward": mean_reward,
        "std_reward": std_reward,
        "json_validity": json_validity,
        "num_samples": len(results)
    }
def evaluate_local_model(model_path, model_type="extract0"):
    logger.info("=" * 80)
    logger.info(f"EVALUATING {model_type.upper()} MODEL")
    logger.info(f"Model: {model_path}")
    logger.info(f"Test samples: {NUM_TEST_SAMPLES}")
    logger.info("=" * 80)
    clear_gpu_memory()
    logger.info(f"Loading {model_type} model...")
    if model_type == "extract0":
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        model = PeftModel.from_pretrained(base_model, model_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    test_data = load_test_data()
    results = []
    json_valid_count = 0
    logger.info(f"Evaluating {len(test_data)} examples...")
    for i in tqdm(range(0, len(test_data), BATCH_SIZE), desc="Evaluating"):
        batch = test_data[i:i+BATCH_SIZE]
        prompts = [ex["prompt"] for ex in batch]
        schemas = [ex["schema"] for ex in batch]
        gold_outputs = [ex["gold_output"] for ex in batch]
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LEN)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=532,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        for j, output in enumerate(outputs):
            generated_tokens = output[inputs["input_ids"].shape[1]:]
            generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            reward = compute_extraction_reward(generated_text, schemas[j], gold_outputs[j])
            results.append(reward)
            if extract_json_from_text(generated_text) is not None:
                json_valid_count += 1
        if (i // BATCH_SIZE + 1) % 10 == 0:
            current_mean = np.mean(results)
            logger.info(f"Progress: {len(results)}/{len(test_data)}, Current mean reward: {current_mean:.4f}")
        torch.cuda.empty_cache()
    mean_reward = np.mean(results)
    std_reward = np.std(results)
    json_validity = json_valid_count / len(test_data)
    logger.info("=" * 80)
    logger.info("EVALUATION COMPLETE")
    logger.info(f"Mean Reward: {mean_reward:.4f} ± {std_reward:.4f}")
    logger.info(f"JSON Validity: {json_validity:.2%}")
    logger.info(f"Samples evaluated: {len(results)}")
    logger.info("=" * 80)
    return {
        "model": model_type,
        "mean_reward": mean_reward,
        "std_reward": std_reward,
        "json_validity": json_validity,
        "num_samples": len(results)
    }
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Evaluate extraction models")
    parser.add_argument("--model", type=str, help="Path to local model checkpoint")
    parser.add_argument("--base-url", type=str, help="Base URL for API model")
    parser.add_argument("--model-id", type=str, help="Model ID for API")
    parser.add_argument("--api-key", type=str, help="API key (optional)")
    parser.add_argument("--workers", type=int, default=20, help="Number of worker threads for API evaluation (default: 20)")
    args = parser.parse_args()
    if args.base_url and args.model_id:
        # Evaluate API model
        result = evaluate_api_model(args.base_url, args.model_id, args.api_key, args.workers)
        # Save results
        results_df = pd.DataFrame([result])
        output_file = f"evaluation_results_{TIMESTAMP}.csv"
        results_df.to_csv(output_file, index=False)
        logger.info(f"Results saved to {output_file}")
    elif args.model:
        # Evaluate local model
        result = evaluate_local_model(args.model, "base")
        # Save results
        results_df = pd.DataFrame([result])
        output_file = f"evaluation_results_{TIMESTAMP}.csv"
        results_df.to_csv(output_file, index=False)
        logger.info(f"Results saved to {output_file}")
    else:
        logger.error("Please provide either --base-url and --model-id for API evaluation, or --model for local evaluation")
        parser.print_help()