arxiv:2509.22906

Extract-0: A Specialized Language Model for Document Information Extraction

Published on Sep 26

Abstract

Extract-0, a 7-billion parameter language model optimized for document information extraction, outperforms larger models using synthetic data, Low-Rank Adaptation, and Group Relative Policy Optimization.

AI-generated summary

This paper presents Extract-0, a 7-billion parameter language model specifically optimized for document information extraction that achieves performance exceeding models with parameter counts several orders of magnitude larger. Through a novel combination of synthetic data generation, supervised fine-tuning with Low-Rank Adaptation (LoRA), and reinforcement learning via Group Relative Policy Optimization (GRPO), Extract-0 achieves a mean reward of 0.573 on a benchmark of 1,000 diverse document extraction tasks, outperforming GPT-4.1 (0.457), o3 (0.464), and GPT-4.1-2025 (0.459). The training methodology employs a memory-preserving synthetic data generation pipeline that produces 280,128 training examples from diverse document sources, followed by parameter-efficient fine-tuning that modifies only 0.53% of model weights (40.4M out of 7.66B parameters). The reinforcement learning phase introduces a novel semantic similarity-based reward function that handles the inherent ambiguity in information extraction tasks. This research demonstrates that task-specific optimization can yield models that surpass general-purpose systems while requiring substantially fewer computational resources.
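
For context, the LoRA stage described above can be sketched with the peft library; this is a minimal illustration, and the rank, alpha, dropout, and target modules below are assumptions rather than the paper's exact configuration:

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Minimal LoRA sketch (hyperparameters are assumed, not taken from the paper).
base = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    torch_dtype=torch.bfloat16,
)
lora_cfg = LoraConfig(
    r=16,                     # low-rank dimension (assumed)
    lora_alpha=32,            # scaling factor (assumed)
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed attention projections
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # trainable share should be well under 1% of all weights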

Community

Hi,

I’ve been attempting to replicate the results reported in the paper for the base model. While the paper mentions achieving a reward of around 23%, I’m consistently obtaining only ~9% (±19%).

================================================================================
2025-10-14 07:08:06,143 - INFO - EVALUATION COMPLETE
2025-10-14 07:08:06,143 - INFO - Mean Reward: 0.0899 ± 0.1965
2025-10-14 07:08:06,143 - INFO - JSON Validity: 90.10%
2025-10-14 07:08:06,143 - INFO - Samples evaluated: 1000
2025-10-14 07:08:06,143 - INFO - ================================================================================

Here’s what I did and where I might be diverging from your setup:

  1. Evaluation pipeline:
    I modified the official evaluation script to query the model via the vLLM API instead of loading it directly into memory.

  2. Dataset sampling:
    In the original evaluation code, the test set appears to be taken as a simple slice:

    df = pd.read_csv('data/extraction_training_data.csv')
    df = df[100000:100000 + NUM_TEST_SAMPLES]

    However, only 47 rows in that range (indices 100000 to 100000+1000) seem to be valid samples. To ensure I had 1000 valid samples, I instead wrote a custom loop that collects the first 1000 valid samples after index 100000 (see load_test_data in the attached code).
    
  3. Dataset source:
    I used the dataset from this repo (HenriqueGodoy/extract-0), since I couldn’t find instructions for generating the original evaluation data.
    I understand this may slightly change the reward, but a drop from 23% to 9% seems too large to be due to the dataset alone.

  4. Max tokens:
    I increased the max output tokens to 4096 from the roughly 1,500 mentioned in the paper.

I’ll attach my evaluation code below so you can verify if I might have missed something subtle.
Would you be able to clarify whether the paper’s results were obtained using a different data slice, preprocessing step, or reward normalization?

Thanks!

import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import pandas as pd
import json
import orjson
from pathlib import Path
from datetime import datetime
import gc
import numpy as np
from dateutil import parser as date_parser
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import warnings
import logging
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed

warnings.filterwarnings('ignore')

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'evaluation_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

BASE_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
ASSISTANT_TAG = "\n\n### Assistant:\n"
MAX_SEQ_LEN = 2048
NUM_TEST_SAMPLES = 1000
TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
BATCH_SIZE = 64

def clear_gpu_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()

def extract_json_from_text(text):
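    # Scan from the first '{' with brace/quote tracking and parse the first complete
    # top-level JSON object; returns None if no balanced, parsable object is found.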
    text = text.strip()
    start_idx = text.find('{')
    if start_idx == -1:
        return None
    brace_count = 0
    in_string = False
    escape_next = False
    for i, char in enumerate(text[start_idx:], start_idx):
        if escape_next:
            escape_next = False
            continue
        if char == '\\' and in_string:
            escape_next = True
            continue
        if char == '"' and not escape_next:
            in_string = not in_string
            continue
        if not in_string:
            if char == '{':
                brace_count += 1
            elif char == '}':
                brace_count -= 1
                if brace_count == 0:
                    try:
                        json_str = text[start_idx:i+1]
                        return orjson.loads(json_str)
                    except Exception:
                        return None
    return None

def get_expected_keys(schema):
    if isinstance(schema, dict):
        if schema.get('type') == 'object' and 'properties' in schema:
            return set(schema['properties'].keys())
        else:
            return set(schema.keys()) - {'type', 'properties', 'required', 'additionalProperties'}
    return set()

def calculate_field_similarity(pred_value, gold_value):
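    # Type-aware similarity in [0, 1]: relative difference for numbers, date proximity
    # (falling back to embedding cosine similarity) for strings, greedy one-to-one
    # matching scored as F1 for lists, key-wise mean for dicts, exact match otherwise.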
    if pred_value is None and gold_value is None:
        return 1.0
    if pred_value is None or gold_value is None:
        return 0.0
    
    if isinstance(gold_value, (int, float)) and isinstance(pred_value, (int, float)):
        if gold_value == 0:
            return 1.0 if pred_value == 0 else 0.0
        diff = abs(pred_value - gold_value) / abs(gold_value)
        return max(0.0, 1.0 - diff)
    
    if isinstance(gold_value, str) and isinstance(pred_value, str):
        try:
            gold_date = date_parser.parse(gold_value)
            pred_date = date_parser.parse(pred_value)
            diff_days = abs((pred_date - gold_date).total_seconds()) / 86400
            return max(0.0, 1.0 - diff_days / 365)
        except Exception:
            if not gold_value and not pred_value:
                return 1.0
            if not gold_value or not pred_value:
                return 0.0
            gold_emb = embedding_model.encode(gold_value, convert_to_numpy=True)
            pred_emb = embedding_model.encode(pred_value, convert_to_numpy=True)
            similarity = np.dot(gold_emb, pred_emb) / (np.linalg.norm(gold_emb) * np.linalg.norm(pred_emb))
            return max(0.0, similarity)

    if isinstance(gold_value, list) and isinstance(pred_value, list):
        if not gold_value and not pred_value:
            return 1.0
        if not gold_value or not pred_value:
            return 0.0

        m = len(pred_value)
        n = len(gold_value)
        if m == 0 and n == 0:
            return 1.0

        pairs = []
        for i, pred_item in enumerate(pred_value):
            for j, gold_item in enumerate(gold_value):
                s = calculate_field_similarity(pred_item, gold_item)
                pairs.append((s, i, j))
        pairs.sort(key=lambda x: x[0], reverse=True)

        used_pred = set()
        used_gold = set()
        matched_sum = 0.0
        threshold = 0.35
        for s, i, j in pairs:
            if s < threshold:
                break
            if i in used_pred or j in used_gold:
                continue
            used_pred.add(i)
            used_gold.add(j)
            matched_sum += float(max(0.0, min(1.0, s)))

        eps = 1e-8
        precision = matched_sum / (m + eps)
        recall = matched_sum / (n + eps)
        if precision <= 0.0 and recall <= 0.0:
            return 0.0
        return float((2.0 * precision * recall) / (precision + recall + eps))
    
    if isinstance(gold_value, dict) and isinstance(pred_value, dict):
        if not gold_value and not pred_value:
            return 1.0
        if not gold_value or not pred_value:
            return 0.0
        scores = []
        all_keys = set(gold_value.keys()) | set(pred_value.keys())
        for key in all_keys:
            scores.append(calculate_field_similarity(pred_value.get(key), gold_value.get(key)))
        return np.mean(scores) if scores else 0.0
    
    if isinstance(gold_value, bool) and isinstance(pred_value, bool):
        return 1.0 if pred_value == gold_value else 0.0
    
    return 1.0 if str(pred_value) == str(gold_value) else 0.0

def compute_extraction_reward(response_text, schema_str, gold_str):
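    # Reward = mean field similarity over the schema's expected keys; returns 0.0 if the
    # response has no parsable JSON or is missing any expected key.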
    try:
        pred_json = extract_json_from_text(response_text)
        if pred_json is None:
            return 0.0
        
        try:
            schema = orjson.loads(schema_str) if isinstance(schema_str, str) else schema_str
            gold = orjson.loads(gold_str) if isinstance(gold_str, str) else gold_str
        except Exception:
            return 0.0
        
        expected_keys = get_expected_keys(schema)
        if not expected_keys:
            return 0.0
        
        pred_keys = set(pred_json.keys())
        if not expected_keys.issubset(pred_keys):
            return 0.0
        
        scores = []
        for k in expected_keys:
            pred_val = pred_json.get(k)
            gold_val = gold.get(k)
            similarity = calculate_field_similarity(pred_val, gold_val)
            scores.append(similarity)
        
        if not scores:
            return 0.0
        
        return float(np.mean(scores))
        
    except Exception as e:
        return 0.0

def load_test_data():
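    # Scan forward from index 100000, keeping rows whose referenced chunks resolve and
    # trimming chunks until the prompt fits the length budget, until NUM_TEST_SAMPLES
    # valid examples have been collected.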
    logger.info("Loading test data...")
    
    ref_df = pd.read_csv('data/extract-0/documents.csv')
    chunk_texts = dict(zip(ref_df['chunk_id'], ref_df['text']))
    
    df = pd.read_csv('data/extract-0/train.csv')
    
    system_prompt = """You are an expert data extraction system. Extract structured information from documents according to the provided schema.
Return only valid JSON that matches the schema exactly.
Be precise and accurate in your extractions."""
    
    test_examples = []
    start_idx = 100000
    max_attempts = NUM_TEST_SAMPLES * 100  # Try up to 100x the target to find valid samples
    
    logger.info(f"Scanning rows starting from index {start_idx}...")
    
    for offset in tqdm(range(max_attempts), desc="Preprocessing data"):
        if len(test_examples) >= NUM_TEST_SAMPLES:
            break
            
        idx = start_idx + offset
        if idx >= len(df):
            logger.warning(f"Reached end of dataset at index {idx}, only found {len(test_examples)} valid examples")
            break
            
        row = df.iloc[idx]
        
        try:
            if pd.notna(row.get('chunk_refs', None)):
                chunk_refs = json.loads(row['chunk_refs'])
            elif pd.notna(row.get('reference_text', None)):
                chunk_refs = [str(row['reference_text']).strip()]
            else:
                continue
            
            schema = json.loads(row['input'])
            result = json.loads(row['output'])
            
            if 'chunk_refs' in schema:
                del schema['chunk_refs']
            
            compact_schema = json.dumps(schema, separators=(',', ':'))
            
            chunks = []
            for chunk_id in chunk_refs:
                if chunk_id in chunk_texts:
                    chunks.append(chunk_texts[chunk_id])
            
            if not chunks:
                continue
            
            # Try with all chunks first, then reduce if needed
            for k in range(len(chunks), 0, -1):
                selected_chunks = chunks[:k]
                joined_chunks = "\n---\n".join(selected_chunks)
                
                prompt = f"### System:\n{system_prompt}\n\n### User:\nSchema:\n{compact_schema}\n\nDocument:\n{joined_chunks}{ASSISTANT_TAG}"
                
                if len(prompt) <= MAX_SEQ_LEN * 3:
                    test_examples.append({
                        "prompt": prompt,
                        "schema": compact_schema,
                        "gold_output": json.dumps(result, separators=(',', ':')),
                        "source_index": idx,
                        "system_prompt": system_prompt,
                        "user_content": f"Schema:\n{compact_schema}\n\nDocument:\n{joined_chunks}"
                    })
                    break
                    
        except Exception as e:
            continue
    
    logger.info(f"Loaded {len(test_examples)} test examples from {offset+1} rows scanned")
    return test_examples

def evaluate_api_model(base_url, model_id, api_key=None, workers=20):
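    # Query an OpenAI-compatible endpoint (here served by vLLM) with a thread pool and
    # score each completion with compute_extraction_reward.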
    logger.info("=" * 80)
    logger.info(f"EVALUATING API MODEL")
    logger.info(f"Base URL: {base_url}")
    logger.info(f"Model ID: {model_id}")
    logger.info(f"API Key: {'PROVIDED' if api_key else 'NOT REQUIRED'}")
    logger.info(f"Workers: {workers}")
    logger.info(f"Test samples: {NUM_TEST_SAMPLES}")
    logger.info("=" * 80)
    
    # Initialize OpenAI client
    client = OpenAI(
        base_url=base_url,
        api_key=api_key if api_key else "not-needed"
    )
    
    test_data = load_test_data()
    
    results = [None] * len(test_data)
    json_valid_count = 0
    completed_count = 0
    
    logger.info(f"Evaluating {len(test_data)} examples with {workers} workers...")
    
    def process_example(idx, example):
        try:
            # Create chat completion request
            response = client.chat.completions.create(
                model=model_id,
                messages=[
                    {"role": "system", "content": example["system_prompt"]},
                    {"role": "user", "content": example["user_content"]}
                ],
                max_tokens=4096,
                # reasoning_effort = "medium"
            )
            
            # Extract generated text
            generated_text = response.choices[0].message.content
            
            # Compute reward
            reward = compute_extraction_reward(
                generated_text, 
                example["schema"], 
                example["gold_output"]
            )
            
            # Check JSON validity
            is_valid = extract_json_from_text(generated_text) is not None
            
            return idx, reward, is_valid
                
        except Exception as e:
            logger.error(f"Error processing example {idx}: {str(e)}")
            return idx, 0.0, False
    
    # Use ThreadPoolExecutor for parallel processing
    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Submit all tasks
        futures = {
            executor.submit(process_example, i, test_data[i]): i 
            for i in range(len(test_data))
        }
        
        # Process completed tasks with progress bar
        with tqdm(total=len(test_data), desc="Evaluating") as pbar:
            for future in as_completed(futures):
                idx, reward, is_valid = future.result()
                results[idx] = reward
                if is_valid:
                    json_valid_count += 1
                
                completed_count += 1
                pbar.update(1)
                
                # Log progress every 100 samples
                if completed_count % 100 == 0:
                    valid_results = [r for r in results if r is not None]
                    current_mean = np.mean(valid_results) if valid_results else 0.0
                    logger.info(f"Progress: {completed_count}/{len(test_data)}, Current mean reward: {current_mean:.4f}")
    
    mean_reward = np.mean(results)
    std_reward = np.std(results)
    json_validity = json_valid_count / len(test_data)
    
    logger.info("=" * 80)
    logger.info("EVALUATION COMPLETE")
    logger.info(f"Mean Reward: {mean_reward:.4f} ± {std_reward:.4f}")
    logger.info(f"JSON Validity: {json_validity:.2%}")
    logger.info(f"Samples evaluated: {len(results)}")
    logger.info("=" * 80)
    
    return {
        "model": model_id,
        "base_url": base_url,
        "mean_reward": mean_reward,
        "std_reward": std_reward,
        "json_validity": json_validity,
        "num_samples": len(results)
    }

def evaluate_local_model(model_path, model_type="extract0"):
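    # Load the model locally (as a LoRA adapter on the base model for "extract0",
    # otherwise as a plain causal LM) and evaluate in sampled batches.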
    logger.info("=" * 80)
    logger.info(f"EVALUATING {model_type.upper()} MODEL")
    logger.info(f"Model: {model_path}")
    logger.info(f"Test samples: {NUM_TEST_SAMPLES}")
    logger.info("=" * 80)
    
    clear_gpu_memory()
    
    logger.info(f"Loading {model_type} model...")
    
    if model_type == "extract0":
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        model = PeftModel.from_pretrained(base_model, model_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
    
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model.eval()
    
    test_data = load_test_data()
    
    results = []
    json_valid_count = 0
    
    logger.info(f"Evaluating {len(test_data)} examples...")
    
    for i in tqdm(range(0, len(test_data), BATCH_SIZE), desc="Evaluating"):
        batch = test_data[i:i+BATCH_SIZE]
        
        prompts = [ex["prompt"] for ex in batch]
        schemas = [ex["schema"] for ex in batch]
        gold_outputs = [ex["gold_output"] for ex in batch]
        
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LEN)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=532,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        
        for j, output in enumerate(outputs):
            generated_tokens = output[inputs["input_ids"].shape[1]:]
            generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            
            reward = compute_extraction_reward(generated_text, schemas[j], gold_outputs[j])
            results.append(reward)
            
            if extract_json_from_text(generated_text) is not None:
                json_valid_count += 1
        
        if (i // BATCH_SIZE + 1) % 10 == 0:
            current_mean = np.mean(results)
            logger.info(f"Progress: {len(results)}/{len(test_data)}, Current mean reward: {current_mean:.4f}")
        
        torch.cuda.empty_cache()
    
    mean_reward = np.mean(results)
    std_reward = np.std(results)
    json_validity = json_valid_count / len(test_data)
    
    logger.info("=" * 80)
    logger.info("EVALUATION COMPLETE")
    logger.info(f"Mean Reward: {mean_reward:.4f} ± {std_reward:.4f}")
    logger.info(f"JSON Validity: {json_validity:.2%}")
    logger.info(f"Samples evaluated: {len(results)}")
    logger.info("=" * 80)
    
    return {
        "model": model_type,
        "mean_reward": mean_reward,
        "std_reward": std_reward,
        "json_validity": json_validity,
        "num_samples": len(results)
    }

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Evaluate extraction models")
    parser.add_argument("--model", type=str, help="Path to local model checkpoint")
    parser.add_argument("--base-url", type=str, help="Base URL for API model")
    parser.add_argument("--model-id", type=str, help="Model ID for API")
    parser.add_argument("--api-key", type=str, help="API key (optional)")
    parser.add_argument("--workers", type=int, default=20, help="Number of worker threads for API evaluation (default: 20)")
    args = parser.parse_args()
    
    if args.base_url and args.model_id:
        # Evaluate API model
        result = evaluate_api_model(args.base_url, args.model_id, args.api_key, args.workers)
        
        # Save results
        results_df = pd.DataFrame([result])
        output_file = f"evaluation_results_{TIMESTAMP}.csv"
        results_df.to_csv(output_file, index=False)
        logger.info(f"Results saved to {output_file}")
        
    elif args.model:
        # Evaluate local model
        result = evaluate_local_model(args.model, "base")
        
        # Save results
        results_df = pd.DataFrame([result])
        output_file = f"evaluation_results_{TIMESTAMP}.csv"
        results_df.to_csv(output_file, index=False)
        logger.info(f"Results saved to {output_file}")
    else:
        logger.error("Please provide either --base-url and --model-id for API evaluation, or --model for local evaluation")
        parser.print_help()
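
(For reference, I served the base model behind vLLM's OpenAI-compatible endpoint and pointed this script at it, along the lines of: python evaluate.py --base-url http://localhost:8000/v1 --model-id deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --workers 20. The script file name and port here are placeholders.)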
