midwestern-simulation/essence-3b-v2

https://midwestern-simulation.neocities.org/main/library/Essense%209-16-2025

we've trained an ai model that compresses sequences of token embeddings into shorter sequences of token embeddings, then attempts to reconstruct the original text from them, with varying degrees of success.

the main dial we can turn is the number of "embedding tokens" used to represent a text. here's how it works (a shape-level sketch follows the list):

  1. in the encoder, these special embedding tokens are appended to the original text.
  2. the hidden states at those token positions are extracted.
  3. they are placed at the start of the decoder's context window.
  4. the decoder then attempts to reconstruct the original text from this compressed representation.
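
a rough shape walkthrough of that pipeline, assuming 4 embedding tokens; the hidden size of 2048 comes from the projector in the usage code below, and the token counts are illustrative:

# shape walkthrough for a document of n_text tokens compressed into
# 4 embedding tokens (hidden size 2048, as in the projector below)
#
# encoder input ids:      [1, n_text + n_marker + 4]   document, header markers, 4 embedding tokens
# encoder hidden states:  [1, n_text + n_marker + 4, 2048]
# extracted encoding:     [1, 4, 2048]                  hidden states of the last 4 positions
# projected encoding:     [1, 4, 2048]                  mapped into the decoder's embedding space
# decoder inputs_embeds:  [1, 4 + n_prefix, 2048]       encoding + "[[/END SUMMARY]]..." prefix
# decoder generate:       reconstructs the ~n_text-token document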

usage

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch import nn
import torch
from huggingface_hub import hf_hub_download

device = torch.device("cuda:0")
dtype = torch.bfloat16
base_model_id = "HuggingFaceTB/SmolLM3-3B-Base"
compressor_id = "midwestern-simulation/essence-3b-v2"

# === MODEL LOADING ===

tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side='left')
# encoder and decoder are two copies of the base model, each with its own peft adapter
encoder = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"":device}, torch_dtype=dtype)
decoder = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"":device}, torch_dtype=dtype)

encoder = PeftModel.from_pretrained(encoder, compressor_id, subfolder="encoder")
decoder = PeftModel.from_pretrained(decoder, compressor_id, subfolder="decoder")

# linear projector mapping encoder hidden states into the decoder's input embedding space
projector = nn.Linear(2048, 2048).to(device).to(dtype)
projector.load_state_dict(torch.load(hf_hub_download(repo_id=compressor_id, filename="projector.pt")))


# === MODEL INFERENCE ===

text = "mary had a little lamb, little lamb, little lamb, mary had a little lamb whose fleece was white as snow"
n_embed_tokens = 4 # 1-256 for best performance; may exhibit limited generalization outside this range

# encoder input: the document, an end-of-document marker, and a summary
# header followed by n_embed_tokens placeholder tokens
encoder_input = text.strip() + f"\n[[/END DOCUMENT]]\n[[START SUMMARY ntoks={n_embed_tokens}]]" + "<|im_end|>" * n_embed_tokens

tokenized = tokenizer(encoder_input, return_tensors='pt', add_special_tokens=False)
tokenized = {k: v.to(device) for k, v in tokenized.items()}
# the hidden states at the last n_embed_tokens positions are the compressed
# representation; the projector maps them into the decoder's embedding space
encoding = encoder.model.model(**tokenized).last_hidden_state[:, -n_embed_tokens:, :]  # [1, n_embed_tokens, 2048]
encoding = projector(encoding)

# decoder context: the projected embedding tokens, then a prefix marking
# the start of the reconstructed document
tokenized_prefix = tokenizer("\n[[/END SUMMARY]]\n[[START DOCUMENT]]\n", return_tensors="pt", add_special_tokens=False)
prefix_embeds = decoder.model.model.embed_tokens(tokenized_prefix['input_ids'].to(device))
inputs_embeds = torch.cat([encoding, prefix_embeds], 1)
output = decoder.generate(
    inputs_embeds=inputs_embeds,
    temperature=0.7,
    max_new_tokens=1024,
    do_sample=True,
    top_k=128,
    min_new_tokens=8,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id
)
print(tokenizer.decode(output[0]))
# mary had a little lamb, little lamb, little lamb, mary had a little lamb whose fleece was white as snow
# [[/END DOCUMENT]]<|end_of_text|>
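
as the sample output shows, the decoder emits the document followed by an end-of-document marker and an eos token. a small helper to recover just the text (the name extract_reconstruction is ours, not from the model card):

def extract_reconstruction(decoded: str) -> str:
    # cut at the end-of-document marker if the model produced one,
    # then drop any trailing eos token text
    text = decoded.split("[[/END DOCUMENT]]")[0]
    return text.replace("<|end_of_text|>", "").strip()

print(extract_reconstruction(tokenizer.decode(output[0])))
# mary had a little lamb, little lamb, little lamb, mary had a little lamb whose fleece was white as snow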

tips:

  • over 12k training steps, this model was trained not only on plain reconstruction but also on reconstruction with span corruption and masked language modelling; the mask token id used was 128005
  • due to compute constraints, the model was only trained on contexts up to 2k tokens and embedding-token counts up to 256, though it may exhibit limited generalization outside these bounds
  • during training, the average sample length was around 702 tokens; performance is best around this range
  • for short samples, around 12 embedding tokens strikes a good balance between reconstruction quality and the ability to play around with concepts across multiple texts' embeddings: the lower you go, the better time you will have playing with concepts, and the higher you go, the better time you will have with reconstruction. a sketch of such concept mixing follows
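
a minimal sketch of the concept mixing described in the last tip, assuming the loading code above has already run; the encode() helper, the second text, and the 0.5 mix weight are illustrative choices, not part of the model card:

def encode(text, n_embed_tokens=12):
    # same encoding steps as the usage example above
    enc_input = text.strip() + f"\n[[/END DOCUMENT]]\n[[START SUMMARY ntoks={n_embed_tokens}]]" + "<|im_end|>" * n_embed_tokens
    toks = tokenizer(enc_input, return_tensors='pt', add_special_tokens=False)
    toks = {k: v.to(device) for k, v in toks.items()}
    hidden = encoder.model.model(**toks).last_hidden_state[:, -n_embed_tokens:, :]
    return projector(hidden)

# encode two texts with the same number of embedding tokens, then lerp
a = encode("mary had a little lamb whose fleece was white as snow")
b = encode("the quick brown fox jumps over the lazy dog")
mixed = 0.5 * a + 0.5 * b

# decode the blended embedding tokens exactly as in the usage example
prefix = tokenizer("\n[[/END SUMMARY]]\n[[START DOCUMENT]]\n", return_tensors="pt", add_special_tokens=False)
prefix_embeds = decoder.model.model.embed_tokens(prefix['input_ids'].to(device))
output = decoder.generate(
    inputs_embeds=torch.cat([mixed, prefix_embeds], 1),
    temperature=0.7, max_new_tokens=256, do_sample=True, top_k=128,
    eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(output[0]))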