This model is trained with the approach described in DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management. The associated GitHub repository is available here. The model has 596M parameters and is the pre-trained (PT) version, trained only on unlabeled data with an in-batch-negative objective.
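
For reference, pre-training with in-batch negatives typically uses an InfoNCE-style contrastive objective: each query is scored against every passage in the batch, with its paired passage as the positive and the remaining passages as negatives. Below is a minimal PyTorch sketch of this standard objective; the temperature value is illustrative, and the exact loss configuration used in DMRetriever is described in the paper.

import torch
import torch.nn.functional as F

def in_batch_negative_loss(query_emb: torch.Tensor,
                           passage_emb: torch.Tensor,
                           temperature: float = 0.05) -> torch.Tensor:
    """InfoNCE-style loss: the i-th passage is the positive for the i-th query;
    all other passages in the batch serve as negatives."""
    q = F.normalize(query_emb, dim=-1)
    p = F.normalize(passage_emb, dim=-1)
    logits = q @ p.T / temperature                     # [B, B] similarity matrix
    labels = torch.arange(q.size(0), device=q.device)  # positives lie on the diagonal
    return F.cross_entropy(logits, labels)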

🧠 Model Overview

DMRetriever-596M-PT has the following features:

  • Model Type: Text Embedding
  • Supported Languages: English
  • Number of Parameters: 0.6B
  • Embedding Dimension: 1024

For more details on model training, benchmark evaluation, and inference performance, please refer to our paper and GitHub repository.

📦 DMRetriever Series Model List

| Model | Description | Backbone | Backbone Type | Hidden Size | #Layers |
|---|---|---|---|---|---|
| DMRetriever-33M | Base 33M variant | MiniLM | Encoder-only | 384 | 12 |
| DMRetriever-33M-PT | Pre-trained version of 33M | MiniLM | Encoder-only | 384 | 12 |
| DMRetriever-109M | Base 109M variant | BERT-base-uncased | Encoder-only | 768 | 12 |
| DMRetriever-109M-PT | Pre-trained version of 109M | BERT-base-uncased | Encoder-only | 768 | 12 |
| DMRetriever-335M | Base 335M variant | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| DMRetriever-335M-PT | Pre-trained version of 335M | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| DMRetriever-596M | Base 596M variant | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| DMRetriever-596M-PT | Pre-trained version of 596M | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| DMRetriever-4B | Base 4B variant | Qwen3-4B | Decoder-only | 2560 | 36 |
| DMRetriever-4B-PT | Pre-trained version of 4B | Qwen3-4B | Decoder-only | 2560 | 36 |
| DMRetriever-7.6B | Base 7.6B variant | Qwen3-8B | Decoder-only | 4096 | 36 |
| DMRetriever-7.6B-PT | Pre-trained version of 7.6B | Qwen3-8B | Decoder-only | 4096 | 36 |

🚀 Usage

Using HuggingFace Transformers:

# pip install torch "transformers>4.51.0"
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from bidirectional_qwen3 import Qwen3BiModel  # custom bidirectional backbone

MODEL_ID = "DMIR01/DMRetriever-596M-PT"

# Device & dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# --- Tokenizer (needs remote code for custom modules) ---
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_fast=False,
)
# Ensure pad token and right padding (matches training)
if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Bidirectional encoder (non-autoregressive; for retrieval/embedding) ---
model = Qwen3BiModel.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()

# --- Mean pooling over valid tokens ---
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                   # [B, H]
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # [B, 1]
    return summed / counts

# --- Batch encoder: returns L2-normalized embeddings ---
def encode_texts(texts, batch_size=32, max_length=512):
    vecs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        with torch.no_grad():
            inputs = tokenizer(
                batch,
                max_length=max_length,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(device)
            hidden = model(**inputs).last_hidden_state
            emb = mean_pool(hidden, inputs["attention_mask"])
            emb = F.normalize(emb, p=2, dim=1)  # cosine-ready
            vecs.append(emb.cpu())
    return torch.cat(vecs, dim=0) if vecs else torch.empty(0, model.config.hidden_size)

# --- Task instructions (apply to queries only) ---
TASK2PREFIX = {
    "FactCheck": "Given the claim, retrieve most relevant document that supports or refutes the claim",
    "NLI":       "Given the premise, retrieve most relevant hypothesis that is entailed by the premise",
    "QA":        "Given the question, retrieve most relevant passage that best answers the question",
    "QAdoc":     "Given the question, retrieve the most relevant document that answers the question",
    "STS":       "Given the sentence, retrieve the sentence with the same meaning",
    "Twitter":   "Given the user query, retrieve the most relevant Twitter text that meets the request",
}

def apply_task_prefix(queries, task: str):
    """Add instruction to queries; corpus texts remain unchanged."""
    prefix = TASK2PREFIX.get(task, "")
    if prefix:
        return [f"{prefix}: {q.strip()}" for q in queries]
    return [q.strip() for q in queries]

# ========================= Usage =========================
# Queries need task instruction
task = "QA"
queries_raw = [
    "Who wrote The Little Prince?",
    "What is the capital of France?",
]
queries = apply_task_prefix(queries_raw, task)

# Corpus: no instruction
corpus_passages = [
    "The Little Prince is a novella by Antoine de Saint-Exupรฉry, first published in 1943.",
    "Paris is the capital and most populous city of France.",
    "Transformers are neural architectures that rely on attention mechanisms.",
]

# Encode
query_emb  = encode_texts(queries,         batch_size=32, max_length=512)  # [Q, H]
corpus_emb = encode_texts(corpus_passages, batch_size=32, max_length=512)  # [D, H]
print("Query embeddings:",  tuple(query_emb.shape))
print("Corpus embeddings:", tuple(corpus_emb.shape))

# Retrieval demo: cosine similarity via dot product (embeddings are normalized)
scores = query_emb @ corpus_emb.T  # [Q, D]
topk = scores.topk(k=min(3, corpus_emb.size(0)), dim=1)

for i, q in enumerate(queries_raw):
    print(f"\nQuery[{i}] {q}")
    for rank, (score, idx) in enumerate(zip(topk.values[i].tolist(), topk.indices[i].tolist()), start=1):
        print(f"  Top{rank}: doc#{idx} | score={score:.4f} | text={corpus_passages[idx]}")

⚠️ Notice

  1. The backbone used in DMRetriever is Bidirectional Qwen3, not the standard causal Qwen3.
    Please ensure that the bidirectional_qwen3 module (included in the released model checkpoint folder) is correctly placed inside your model directory so it can be imported; see the sketch after this list.

  2. Make sure that your transformers library version is > 4.51.0 to avoid the error:
    KeyError: 'qwen3'.
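
A minimal sketch covering both points, assuming huggingface_hub and packaging are installed (the download-and-import approach below is illustrative, not part of the released code):

# pip install "transformers>4.51.0" huggingface_hub packaging
import sys
import transformers
from packaging import version
from huggingface_hub import snapshot_download

# 1) Fetch the full checkpoint folder (includes the bidirectional_qwen3 module)
local_dir = snapshot_download("DMIR01/DMRetriever-596M-PT")
sys.path.insert(0, local_dir)  # make `bidirectional_qwen3` importable

# 2) Verify the transformers version supports the qwen3 architecture
assert version.parse(transformers.__version__) > version.parse("4.51.0"), \
    "Upgrade transformers (> 4.51.0) to avoid KeyError: 'qwen3'"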

🧾 Citation

If you find this repository helpful, please consider citing the corresponding paper. Thanks!

@article{yin2025dmretriever,
  title={DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management},
  author={Yin, Kai and Dong, Xiangjue and Liu, Chengkai and Lin, Allen and Shi, Lingfeng and Mostafavi, Ali and Caverlee, James},
  journal={arXiv preprint arXiv:2510.15087},
  year={2025}
}