---
language:
- "en"
pretty_name: "ModernTrajectoryNet: Transaction Embedding Classifier"
tags:
- embedding
- pytorch
- finance
- transaction-classifier
- contrastive-learning
license: "apache-2.0"
datasets:
- "HighkeyPrxneeth/BusinessTransactions"
library_name: "pytorch"
---
|
|
|
|
|
# ModernTrajectoryNet: Transaction Embedding Classifier

A PyTorch embedding classifier built with modern deep learning techniques for transaction categorization. The model learns to project transaction embeddings toward their target category embeddings through trajectory-based contrastive learning.
|
|
|
|
|
## Model Architecture

**ModernTrajectoryNet** combines several modern architectural innovations (a minimal sketch of how they compose follows the component list below):

### Core Components

1. **RMSNorm (Root Mean Square Layer Normalization)**
   - More stable and computationally efficient than LayerNorm
   - Used in LLaMA, PaLM, and Gopher
   - Provides consistent gradient flow through deep networks
|
|
|
|
|
2. **SwiGLU (Swish-Gated Linear Unit)**
   - Widely adopted activation function for feed-forward networks
   - Empirically outperforms GELU and ReLU in Transformer feed-forward layers
   - Gate mechanism: `SiLU(W_gate @ x) * (W_value @ x)`, a Swish-activated gate multiplied by a parallel linear projection
|
|
|
|
|
3. **SEBlock (Squeeze-and-Excitation)**
   - Channel attention mechanism
   - Allows dynamic weighting of embedding dimensions
   - Context-aware feature recalibration

4. **ModernBlock (Pre-Norm Architecture)**
   - RMSNorm → SwiGLU → SEBlock → Residual Connection
   - Incorporates layer scaling and stochastic depth (DropPath)
   - Enables training of very deep networks
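
The exact implementation ships with the project repository (see Contact & Support). As a rough guide, here is a minimal sketch of how these components might compose; the SE reduction ratio, layer-scale initialization, final-norm placement, and class names are assumptions, not the trained model's exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square normalization: rescales by the RMS, no mean subtraction."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class SwiGLU(nn.Module):
    """Swish-gated feed-forward: SiLU(W_gate x) * (W_value x), projected back to dim."""
    def __init__(self, dim: int, expansion: int = 4):
        super().__init__()
        hidden = dim * expansion
        self.gate = nn.Linear(dim, hidden)
        self.value = nn.Linear(dim, hidden)
        self.out = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.out(F.silu(self.gate(x)) * self.value(x))

class SEBlock(nn.Module):
    """Squeeze-and-excitation: input-dependent reweighting of embedding dimensions."""
    def __init__(self, dim: int, reduction: int = 16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(dim, dim // reduction), nn.ReLU(),
            nn.Linear(dim // reduction, dim), nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.fc(x)

class ModernBlock(nn.Module):
    """Pre-norm residual block: RMSNorm -> SwiGLU -> SEBlock -> residual."""
    def __init__(self, dim: int, dropout: float = 0.1,
                 drop_path: float = 0.0, layer_scale_init: float = 1e-4):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.ffn = SwiGLU(dim)
        self.se = SEBlock(dim)
        self.drop = nn.Dropout(dropout)
        self.gamma = nn.Parameter(layer_scale_init * torch.ones(dim))  # layer scaling
        self.drop_path = drop_path

    def forward(self, x):
        h = self.drop(self.se(self.ffn(self.norm(x)))) * self.gamma
        if self.training and self.drop_path > 0:
            # Stochastic depth: drop the residual branch per sample, rescale survivors
            keep = (torch.rand(x.shape[0], 1, device=x.device) > self.drop_path).to(h.dtype)
            h = h * keep / (1.0 - self.drop_path)
        return x + h

class ModernTrajectoryNetSketch(nn.Module):
    """Stack of blocks with DropPath rates linearly spaced from 0.0 to 0.1."""
    def __init__(self, dim: int = 768, depth: int = 12, max_drop_path: float = 0.1):
        super().__init__()
        rates = torch.linspace(0.0, max_drop_path, depth).tolist()
        self.blocks = nn.Sequential(*[ModernBlock(dim, drop_path=r) for r in rates])
        self.final_norm = RMSNorm(dim)

    def forward(self, x):
        return self.final_norm(self.blocks(x))
```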
|
|
|
|
|
### Configuration

- **Input dimension**: 768 (embedding size)
- **Hidden layers**: 12 transformer-style blocks
- **Expansion ratio**: 4x hidden dimension in SwiGLU
- **Dropout**: 0.1
- **Stochastic depth**: Linear decay across layers (0.0 → 0.1)
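
In code, this configuration could be captured in a small dataclass. The field names below are illustrative assumptions; the real `Config` lives in the project repository:

```python
from dataclasses import dataclass

@dataclass
class Config:
    input_dim: int = 768         # embedding size
    num_layers: int = 12         # transformer-style blocks
    expansion_ratio: int = 4     # SwiGLU hidden expansion
    dropout: float = 0.1
    max_drop_path: float = 0.1   # stochastic depth, linear decay 0.0 -> 0.1
```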
|
|
|
|
|
## Training Objective: Hybrid Trajectory Learning

The model is trained with **HybridTrajectoryLoss**, combining two objectives:

### 1. Adaptive InfoNCE (Contrastive Component)

- Learnable temperature parameter for dynamic scaling
- Contrastive loss with label smoothing (0.1)
- Ensures the model maps input embeddings close to their true target embedding
- Equation: `L_contrastive = CrossEntropy(logits / T, labels)`
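
A minimal sketch of this component, assuming in-batch negatives and an initial temperature of 0.07 (both assumptions; the released loss may differ in detail):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveInfoNCE(nn.Module):
    def __init__(self, init_temperature: float = 0.07, label_smoothing: float = 0.1):
        super().__init__()
        # Learnable temperature, stored as a log to keep it positive
        self.log_temp = nn.Parameter(torch.log(torch.tensor(init_temperature)))
        self.label_smoothing = label_smoothing

    def forward(self, outputs, targets):
        # outputs: [B, D] model embeddings; targets: [B, D] category embeddings
        outputs = F.normalize(outputs, dim=-1)
        targets = F.normalize(targets, dim=-1)
        logits = outputs @ targets.t() / self.log_temp.exp()  # [B, B] similarities / T
        labels = torch.arange(outputs.shape[0], device=outputs.device)
        # Each row's positive is its own target; other rows' targets are negatives
        return F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
```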
|
|
|
|
|
### 2. Monotonic Ranking (Trajectory Component)

- Enforces **monotonically increasing similarity** through the transaction sequence
- Each step in the trajectory should have higher similarity than the previous step
- The final embedding must achieve high similarity (ideally 1.0) with the target
- Margin constraint: `sim[i+1] > sim[i] + 0.01`
- Ensures the model learns the **path** to the target, not just the endpoint
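
A sketch of the ranking term, assuming the model exposes per-step trajectory embeddings of shape `[B, T, D]`; the hinge form and the final-similarity term are assumptions consistent with the constraints above:

```python
import torch.nn.functional as F

def monotonicity_loss(trajectory, target, margin: float = 0.01):
    """trajectory: [B, T, D] per-step embeddings; target: [B, D] category embedding."""
    traj = F.normalize(trajectory, dim=-1)
    tgt = F.normalize(target, dim=-1).unsqueeze(1)   # [B, 1, D]
    sim = (traj * tgt).sum(dim=-1)                   # [B, T] cosine similarity per step
    # Hinge penalty whenever sim[i+1] fails to exceed sim[i] + margin
    violations = F.relu(sim[:, :-1] + margin - sim[:, 1:])
    # Pull the final embedding toward similarity ~1.0 with the target
    final_term = (1.0 - sim[:, -1]).mean()
    return violations.mean() + final_term
```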
|
|
|
|
|
### Loss Formulation

```
Total Loss = InfoNCE Loss + Monotonicity Loss
```
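
Composed as a module, reusing `AdaptiveInfoNCE` and `monotonicity_loss` from the sketches above (feeding the final trajectory step to the contrastive term is an assumption):

```python
import torch.nn as nn

class HybridTrajectoryLoss(nn.Module):
    """Sum of the contrastive and trajectory terms sketched above."""
    def __init__(self):
        super().__init__()
        self.info_nce = AdaptiveInfoNCE()

    def forward(self, trajectory, target):
        # trajectory: [B, T, D]; the last step is the model's final embedding
        contrastive = self.info_nce(trajectory[:, -1, :], target)
        mono = monotonicity_loss(trajectory, target)
        return contrastive + mono
```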
|
|
|
|
|
**Why Trajectory Learning?**

- Transactions often evolve gradually toward their correct category
- Intermediate embeddings should show progression toward the target
- This inductive bias improves generalization and interpretability
|
|
|
|
|
## Training Details

- **Optimizer**: AdamW with weight decay (1e-4)
- **Learning rate**: Cosine annealing from 3e-4 to 1e-6
- **Batch size**: 128
- **Gradient clipping**: 1.0
- **Epochs**: 50 with early stopping (patience=5)
- **EMA (Exponential Moving Average)**: Decay=0.99 for evaluation stability
- **Augmentation**: Input masking (p=0.15) and Gaussian noise (std=0.01) during training
- **Mixed Precision**: AMP enabled for faster training on CUDA
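
Putting these settings together, a condensed training step might look like the following; the data loader is assumed, and the EMA update and early stopping are omitted for brevity:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = ModernTrajectoryNet(Config()).cuda()
criterion = HybridTrajectoryLoss()                 # as sketched above
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
scaler = GradScaler()                              # mixed precision (AMP)

for epoch in range(50):
    for emb, target in train_loader:               # train_loader is assumed
        emb, target = emb.cuda(), target.cuda()
        # Augmentation: random input masking (p=0.15) plus Gaussian noise (std=0.01)
        emb = emb * (torch.rand_like(emb) > 0.15) + torch.randn_like(emb) * 0.01
        optimizer.zero_grad(set_to_none=True)
        with autocast():
            loss = criterion(model(emb), target)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)                 # so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()
```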
|
|
|
|
|
## Performance Metrics

The model optimizes for:

1. **Last Similarity**: Similarity of the final embedding with the target (target: ≈1.0)
2. **Monotonicity Accuracy**: Percentage of transitions with strictly increasing similarity (target: 100%)
3. **Contrastive Accuracy**: Ability to distinguish the true target from the other targets in a batch
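
As a sketch, the first two metrics can be computed directly from the per-step similarity matrix (same shape assumptions as in the monotonicity loss above):

```python
def trajectory_metrics(sim):
    """sim: [B, T] cosine similarity of each trajectory step with the target."""
    last_similarity = sim[:, -1].mean().item()                           # target: ~1.0
    monotonicity_acc = (sim[:, 1:] > sim[:, :-1]).float().mean().item()  # target: 100%
    return last_similarity, monotonicity_acc
```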
|
|
|
|
|
## How to Load

```python
from safetensors.torch import load_file
import torch

# Config and ModernTrajectoryNet are defined in the project repository
# (see Contact & Support below), not shipped in this model repo
from config import Config
from model import ModernTrajectoryNet

# Load weights
weights = load_file("model.safetensors")

# Instantiate model
config = Config()
model = ModernTrajectoryNet(config)
model.load_state_dict(weights)
model.eval()

# Use model
with torch.no_grad():
    input_embedding = torch.randn(1, 768)  # your transaction embedding
    output_embedding = model(input_embedding)
    print(output_embedding.shape)  # torch.Size([1, 768])
```
|
|
|
|
|
## Usage Example

```python
import torch
from torch.nn.functional import normalize

# Assuming you have transaction embeddings and category embeddings
transaction_emb = model(input_embedding)  # [B, 768]

# Compute similarity with category embeddings
category_embs = normalize(category_embeddings, p=2, dim=1)     # [N_cats, 768]
transaction_emb_norm = normalize(transaction_emb, p=2, dim=1)  # [B, 768]

similarities = torch.matmul(transaction_emb_norm, category_embs.t())  # [B, N_cats]
predicted_category = torch.argmax(similarities, dim=1)                # [B]
```
|
|
|
|
|
## Intended Uses

- **Transaction categorization**: Classify business transactions into merchant categories
- **Embedding refinement**: Project raw transaction embeddings into a discriminative space
- **Contrastive learning**: Extract improved embeddings for downstream tasks
- **Research**: Study trajectory-based learning for sequential decision problems
|
|
|
|
|
## Limitations & Biases |
|
|
|
|
|
- **Synthetic data**: Trained on synthetic transaction strings generated from Foursquare Open-Source (FSQ OS) business names and categories using `qwen2.5-4b-instruct` LLM |
|
|
- **FSQ OS biases**: Inherits biases from the FSQ OS dataset (e.g., geographic coverage, business type distribution) |
|
|
- **Generation artifacts**: LLM-based synthetic data may not reflect real-world transaction diversity |
|
|
- **Category coverage**: Limited to categories present in FSQ OS (typically 200-500 merchant types) |
|
|
- **Language**: Trained on English transaction strings; may not generalize to other languages |
|
|
|
|
|
**Recommendation**: Validate performance on your specific transaction domain before production deployment. |
|
|
|
|
|
## Dataset

- **Source**: Foursquare Open-Source (FSQ OS) business names and categories
- **Processing**: LLM-based synthetic transaction generation
- **Size**: ~1M synthetic transaction embeddings
- **Train/Val split**: 90% / 10%

See the [dataset](https://huggingface.co/datasets/HighkeyPrxneeth/BusinessTransactions) for more details.
|
|
|
|
|
## Files in This Repository

- `model.safetensors`: Model weights in HuggingFace SafeTensors format (160MB)
- `README.md`: This file
- `LICENSE`: Apache 2.0 license
|
|
|
|
|
## License

Apache License 2.0. See the LICENSE file for details.
|
|
|
|
|
## Citation

If you use this model, please cite:

```bibtex
@software{transactionclassifier2024,
  title={TransactionClassifier: Embedding-based Transaction Categorization},
  author={HighkeyPrxneeth},
  year={2024},
  url={https://huggingface.co/HighkeyPrxneeth/ModernTrajectoryNet}
}
```
|
|
|
|
|
## Contact & Support

- **Repository**: [GitHub - TransactionClassifier](https://github.com/HighkeyPrxneeth/TransactionClassifier)
- **Issues**: Open an issue in the main project repository
- **Author**: HighkeyPrxneeth

For questions about the model architecture, training, or usage, feel free to reach out!
|
|
|