language:
  - en
license: mit
tags:
  - chain-of-thought
  - implicit-reasoning
  - multimodal
  - gpt2
  - instruction-tuned
datasets:
  - gsm8k
  - svamp
  - multi_arith
model-index:
  - name: SIM_COT-GPT2-CODI
    results:
      - task:
          type: math-word-problems
          name: Arithmetic Reasoning
        dataset:
          name: GSM8K
          type: gsm8k
        metrics:
          - type: accuracy
            value: xx.x
      - task:
          type: math-word-problems
          name: MultiArith
        dataset:
          name: MultiArith
          type: multi_arith
        metrics:
          - type: accuracy
            value: xx.x
      - task:
          type: math-word-problems
          name: SVAMP
        dataset:
          name: SVAMP
          type: svamp
        metrics:
          - type: accuracy
            value: xx.x

🚀 SIM_COT-GPT2-CODI

🤗 Model Repo 📂 GitHub
📄 Paper

Teaser Figure

📖 Introduction

Chain-of-Thought (CoT) prompting has become a widely adopted strategy for enhancing the reasoning capabilities of Large Language Models (LLMs). By decomposing problems into intermediate steps, explicit CoT improves accuracy across a variety of reasoning tasks. However, the token cost of explicit reasoning severely limits its scalability, especially when applied to long-horizon tasks or deployed under strict computational budgets.

Implicit CoT methods attempt to address this issue by replacing explicit intermediate steps with continuous latent representations. These approaches achieve higher token efficiency while retaining some of the benefits of step-wise reasoning. Despite this promise, a persistent performance gap remains: implicit CoT methods often underperform compared to explicit reasoning, especially as the number of latent tokens is scaled. Our analysis identifies a fundamental latent instability problem: as more implicit reasoning tokens are introduced, training frequently becomes unstable, with latent representations collapsing into homogeneous states that lack semantic diversity. This failure is largely due to the absence of fine-grained, step-level supervision in existing approaches.
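One simple way to make "collapsing into homogeneous states" concrete is to measure how similar the latent vectors of a single example are to each other. The sketch below is an illustrative diagnostic, not necessarily the analysis used in the paper; the function names and the toy vectors are assumptions made for this example.

```python
import math
from itertools import combinations


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def mean_pairwise_cosine(latents):
    """Average cosine similarity over all distinct pairs of latent vectors."""
    pairs = list(combinations(latents, 2))
    return sum(cosine(u, v) for u, v in pairs) / len(pairs)


# Toy 3-dimensional latents (made-up numbers, not model outputs).
diverse = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
collapsed = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [0.5, 0.5, 0.5]]

print(mean_pairwise_cosine(diverse))    # orthogonal vectors: 0.0
print(mean_pairwise_cosine(collapsed))  # close to 1.0: all latents point the same way
```

A mean pairwise similarity that drifts toward 1 as more latent tokens are added would be one symptom of the homogeneous latent states described above.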

To overcome this limitation, we introduce SIM-CoT, a plug-and-play training module designed to stabilize and enrich the latent reasoning space. SIM-CoT leverages an auxiliary decoder during training that aligns each implicit token with its corresponding explicit reasoning step. This step-level supervision ensures that latent states encode distinct and meaningful information. Importantly, the auxiliary decoder is removed during inference, meaning that SIM-CoT preserves the computational efficiency of implicit CoT without adding runtime overhead.
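The step-level supervision can be summarized as a combined training objective. This is a sketch under assumed notation, not the exact loss from the paper: $z_k$ is the $k$-th implicit token, $s_k$ is the token sequence of the $k$-th explicit reasoning step, $p_{\text{dec}}$ is the auxiliary decoder, $\mathcal{L}_{\text{base}}$ is the loss of the underlying implicit method (Coconut or CODI), and $\lambda$ is a hypothetical weighting coefficient.

```latex
\mathcal{L}_{\text{total}}
  = \mathcal{L}_{\text{base}}
  + \lambda \sum_{k=1}^{K} \mathcal{L}_{\text{step}}^{(k)},
\qquad
\mathcal{L}_{\text{step}}^{(k)}
  = -\sum_{t=1}^{|s_k|} \log p_{\text{dec}}\!\left(s_{k,t} \mid z_k,\, s_{k,<t}\right)
```

Only the second term involves the auxiliary decoder, which is why removing that decoder at inference leaves the backbone's forward pass unchanged.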

Empirical results demonstrate that SIM-CoT substantially improves both in-domain accuracy and out-of-domain stability. On smaller models such as GPT-2, SIM-CoT not only boosts implicit baselines like Coconut by +8.2% but also surpasses explicit CoT by +2.1% while being 2.3× more token-efficient. On larger models, including LLaMA-3.1 8B, SIM-CoT delivers consistent gains, improving CODI by +3.0% and significantly narrowing the performance gap with explicit reasoning. These findings highlight SIM-CoT as an effective and scalable solution for advancing implicit reasoning in LLMs.


SIM_COT-GPT2-CODI is an implicit-reasoning language model based on GPT2, fine-tuned with SIM-CoT (Supervised Implicit Chain-of-Thought) on top of the CODI latent reasoning framework.
It is designed to improve ✨ implicit reasoning and 🧮 multi-step arithmetic problem solving on benchmarks such as GSM8K, GSM-Hard, MultiArith, and SVAMP.


📊 Experimental Results

We evaluate SIM-CoT across both in-domain (GSM8K-Aug) and out-of-domain (GSM-Hard, MultiArith, SVAMP) benchmarks, using GPT-2, LLaMA-3.2 1B, LLaMA-3.2 3B, and LLaMA-3.1 8B as backbones, applied to both Coconut and CODI frameworks.

Main Results on GPT2

Main results on GPT-2. We report accuracy (%) on in-domain (GSM8K-Aug) and out-of-domain (GSM-Hard, MultiArith, SVAMP) benchmarks. SIM-CoT provides accuracy gains on top of existing methods such as Coconut and CODI.

Main Results on LLaMA3 1B

Main results on LLaMA-3.2 1B. We report accuracy (%) on in-domain (GSM8K-Aug) and out-of-domain (GSM-Hard, MultiArith, SVAMP) benchmarks. SIM-CoT builds on CODI to set a new state of the art in implicit reasoning while achieving performance comparable to explicit CoT.

Main Results on LLaMA3 3B and 8B

Main results on LLaMA-3.2 3B and LLaMA-3.1 8B. We report accuracy (%) on in-domain (GSM8K-Aug) and out-of-domain (GSM-Hard, MultiArith, SVAMP) benchmarks.


📌 Model Details

  • 🏗️ Base model: GPT2
  • ⚡ Fine-tuning method: LoRA (r=128, alpha=32)
  • 🔑 Latent reasoning: 6 latent steps, projection dimension = 768
  • 🎯 Dropout: 0.0 (projection layer)
  • 🖥️ Precision: bf16
  • 📏 Context length: 512 tokens
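For scripting experiments, the settings listed above can be collected in one place. The sketch below is only a convenience container: the class name `SimCotGpt2Config` is hypothetical, the field names mirror the command-line flags used in the Usage section, and the `lora_scaling` property assumes the standard LoRA convention of scaling the low-rank update by alpha / r.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SimCotGpt2Config:
    """Hyperparameters listed in this model card (class name is illustrative)."""

    base_model: str = "GPT2"
    lora_r: int = 128
    lora_alpha: int = 32
    num_latent: int = 6
    prj_dim: int = 768
    prj_dropout: float = 0.0
    model_max_length: int = 512
    bf16: bool = True

    @property
    def lora_scaling(self) -> float:
        # Assumes the standard LoRA scaling factor alpha / r.
        return self.lora_alpha / self.lora_r


cfg = SimCotGpt2Config()
print(cfg.lora_scaling)  # 0.25
```

With r=128 and alpha=32, the assumed scaling factor is 0.25, so the adapter has a high rank but its update is scaled down.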

The model uses implicit reasoning tokens during both training and inference.
Unlike standard explicit CoT models, SIM-CoT trains the model to produce structured latent thoughts that the auxiliary decoder turns into explicit steps only during training; at inference these thoughts remain implicit.
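The control flow of implicit reasoning at inference can be sketched with a toy loop: each latent step produces a continuous vector that is fed back as the next input instead of being decoded into a token. This is a schematic illustration under stated assumptions, not the repository's implementation. `step_fn` and `project` are hypothetical stand-ins for one backbone forward pass and the projection layer, and the toy functions only exist to make the loop runnable.

```python
def latent_reasoning(question_state, step_fn, project, num_latent=6):
    """Run `num_latent` continuous reasoning steps without emitting tokens.

    step_fn: hypothetical stand-in for one forward pass of the backbone,
             mapping an input vector to a last hidden state.
    project: hypothetical stand-in for the projection layer that maps the
             hidden state to the next input embedding.
    """
    latents = []
    current = question_state
    for _ in range(num_latent):
        hidden = step_fn(current)
        current = project(hidden)  # fed back as the next input, never decoded
        latents.append(current)
    return latents


def toy_step(vec):
    # Toy replacement for the backbone: shift every coordinate by 1.
    return [x + 1.0 for x in vec]


def toy_project(vec):
    # Toy replacement for the projection layer: halve every coordinate.
    return [0.5 * x for x in vec]


latents = latent_reasoning([0.0, 2.0], toy_step, toy_project, num_latent=6)
print(len(latents))  # 6 latent states, one per latent step
```

No auxiliary decoder appears in this loop, which matches the description above: the latent thoughts are decoded into explicit steps only during training.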


🎯 Intended Uses

  • 🔬 AI-related research (reasoning, representation learning, interpretability)
  • 📊 Benchmarking on arithmetic reasoning datasets (e.g., GSM8K, SVAMP, MultiArith, GSM-Hard)
  • 🧩 Studying latent representation learning and reasoning generalization

⚠️ Not intended for deployment in production without careful alignment and safety evaluation.


💻 Usage

To reproduce our results, follow the steps below:

1. Clone the repository

git clone https://github.com/InternLM/SIM-CoT.git
cd SIM-CoT/CODI

2. Run the evaluation script

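We provide shell scripts for different backbones and datasets. For example, to evaluate on the SVAMP dataset, run the script that matches your backbone. The script name below is the LLaMA 1B one; for this GPT2 checkpoint, substitute the GPT2 script provided in the repository: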

bash test_llama1b.sh

This will internally call the following command:

python test.py \
    --data_name "svamp" \
    --output_dir "$SAVE_DIR" \
    --model_name_or_path path/to/gpt2 \
    --seed 11 \
    --model_max_length 512 \
    --bf16 \
    --lora_r 128 --lora_alpha 32 --lora_init \
    --batch_size 128 \
    --greedy True \
    --num_latent 6 \
    --use_prj True \
    --prj_dim 768 \
    --prj_no_ln False \
    --prj_dropout 0.0 \
    --inf_latent_iterations 6 \
    --inf_num_iterations 1 \
    --remove_eos True \
    --use_lora True \
    --ckpt_dir path/to/sim_cot-checkpoints
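If you prefer to launch the evaluation from Python, the same command can be assembled as an argument list. This is a minimal sketch: the helper name `build_eval_command` and the output directory are hypothetical, and only flags that appear in the command above are used.

```python
import shlex


def build_eval_command(data_name, model_path, ckpt_dir, output_dir):
    """Assemble the test.py argument list using only the flags shown above."""
    return [
        "python", "test.py",
        "--data_name", data_name,
        "--output_dir", output_dir,
        "--model_name_or_path", model_path,
        "--seed", "11",
        "--model_max_length", "512",
        "--bf16",
        "--lora_r", "128", "--lora_alpha", "32", "--lora_init",
        "--batch_size", "128",
        "--greedy", "True",
        "--num_latent", "6",
        "--use_prj", "True",
        "--prj_dim", "768",
        "--prj_no_ln", "False",
        "--prj_dropout", "0.0",
        "--inf_latent_iterations", "6",
        "--inf_num_iterations", "1",
        "--remove_eos", "True",
        "--use_lora", "True",
        "--ckpt_dir", ckpt_dir,
    ]


cmd = build_eval_command(
    data_name="svamp",
    model_path="path/to/gpt2",
    ckpt_dir="path/to/sim_cot-checkpoints",
    output_dir="outputs/svamp",  # hypothetical directory name
)
print(shlex.join(cmd))
# To run it from the SIM-CoT/CODI directory: subprocess.run(cmd, check=True)
```

Keeping the arguments in a list avoids shell quoting problems when paths contain spaces.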

3. Expected output

After running, the script will print the evaluation summary. An example output format is:

adapter: None | GSM8K test accuracy: xxx% | 
average length of COT: xxx
Average accuracy over 1 sampling: xxx
  • test accuracy: accuracy on the specified benchmark.
  • average length of COT: average number of latent reasoning tokens.
  • average accuracy: aggregated accuracy across sampled runs.
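When collecting results over several runs, the summary lines can be parsed programmatically. The sketch below assumes the log follows the example format shown above; the helper name `parse_summary` is hypothetical, and the numbers in `sample_log` are placeholders for illustration, not reported results.

```python
import re

# One pattern per summary field, matching the example output format above.
SUMMARY_PATTERNS = {
    "test_accuracy": re.compile(r"test accuracy:\s*([\d.]+)%"),
    "avg_cot_length": re.compile(r"average length of COT:\s*([\d.]+)"),
    "avg_accuracy": re.compile(r"Average accuracy over \d+ sampling:\s*([\d.]+)"),
}


def parse_summary(log_text):
    """Extract the summary numbers found in the evaluation log text."""
    parsed = {}
    for key, pattern in SUMMARY_PATTERNS.items():
        match = pattern.search(log_text)
        if match:
            parsed[key] = float(match.group(1))
    return parsed


# Placeholder numbers for illustration only; they are not reported results.
sample_log = """adapter: None | GSM8K test accuracy: 12.5% |
average length of COT: 6.0
Average accuracy over 1 sampling: 12.5
"""
print(parse_summary(sample_log))
```

Fields that do not appear in the log are simply left out of the returned dictionary, so a missing key signals a run that did not finish printing its summary.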