---
language:
- "en"
pretty_name: "ModernTrajectoryNet: Transaction Embedding Classifier"
tags:
- embedding
- pytorch
- finance
- transaction-classifier
- contrastive-learning
license: "apache-2.0"
datasets:
- "HighkeyPrxneeth/BusinessTransactions"
library_name: "pytorch"
---
# ModernTrajectoryNet: Transaction Embedding Classifier
A PyTorch embedding classifier for transaction categorization, built with modern deep learning techniques. The model learns to project transaction embeddings toward their target category embeddings through trajectory-based contrastive learning.
## Model Architecture
**ModernTrajectoryNet** combines several modern architectural innovations:
### Core Components
1. **RMSNorm (Root Mean Square Layer Normalization)**
- Simpler and computationally cheaper than LayerNorm (no mean subtraction or bias)
- Used in LLaMA, PaLM, and Gopher
- Provides consistent gradient flow through deep networks
2. **SwiGLU (Swish-Gated Linear Unit)**
- Gated activation widely used in modern feed-forward networks
- Tends to outperform GELU and ReLU in transformer ablations
- Gate mechanism: `SiLU(W1·x) * (W2·x)`, where `SiLU(z) = z * sigmoid(z)`
3. **SEBlock (Squeeze-and-Excitation)**
- Channel attention mechanism
- Allows dynamic weighting of embedding dimensions
- Context-aware feature recalibration
4. **ModernBlock (Pre-Norm Architecture)**
- RMSNorm → SwiGLU → SEBlock → Residual Connection
- Incorporates layer scaling and stochastic depth (DropPath)
- Enables training of very deep networks
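In PyTorch, these components might compose like the following minimal sketch. It is based only on the descriptions above; the class names match the list, but the layer-scale initialization and the simplified per-batch DropPath are assumptions, not the exact repository code.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square normalization (no mean subtraction)."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

class SwiGLU(nn.Module):
    """SiLU(W1 x) * (W2 x), projected back to the input width."""
    def __init__(self, dim: int, expansion: int = 4):
        super().__init__()
        self.gate = nn.Linear(dim, dim * expansion)
        self.up = nn.Linear(dim, dim * expansion)
        self.down = nn.Linear(dim * expansion, dim)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate over embedding dimensions."""
    def __init__(self, dim: int, reduction: int = 16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(dim, dim // reduction), nn.ReLU(),
            nn.Linear(dim // reduction, dim), nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.fc(x)

class ModernBlock(nn.Module):
    """Pre-norm residual block: RMSNorm -> SwiGLU -> SEBlock -> residual."""
    def __init__(self, dim: int, drop_path: float = 0.0):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.ffn = SwiGLU(dim)
        self.se = SEBlock(dim)
        self.scale = nn.Parameter(1e-4 * torch.ones(dim))  # layer scaling
        self.drop_path = drop_path

    def forward(self, x):
        # Stochastic depth: randomly skip the whole block during training.
        if self.training and torch.rand(()).item() < self.drop_path:
            return x
        return x + self.se(self.ffn(self.norm(x))) * self.scale
```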
### Configuration
- **Input dimension**: 768 (embedding size)
- **Hidden layers**: 12 stacked ModernBlocks (pre-norm residual)
- **Expansion ratio**: 4x hidden dimension in SwiGLU
- **Dropout**: 0.1
- **Stochastic depth**: Linear decay across layers (0.0 → 0.1)
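A `Config` holding these values might look like the sketch below; the field names are assumptions and the repository's actual `Config` may differ.
```python
from dataclasses import dataclass

@dataclass
class Config:
    dim: int = 768              # input/output embedding size
    n_blocks: int = 12          # number of stacked ModernBlocks
    expansion: int = 4          # SwiGLU hidden expansion ratio
    dropout: float = 0.1
    max_drop_path: float = 0.1  # stochastic depth upper bound

    def drop_path_rates(self):
        # Drop rate grows linearly with depth: 0.0 -> max_drop_path.
        return [self.max_drop_path * i / (self.n_blocks - 1)
                for i in range(self.n_blocks)]
```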
## Training Objective: Hybrid Trajectory Learning
The model is trained with **HybridTrajectoryLoss**, combining two objectives:
### 1. Adaptive InfoNCE (Contrastive Component)
- Learnable temperature parameter for dynamic scaling
- Contrastive loss with label smoothing (0.1)
- Ensures the model maps input embeddings close to their true target embedding
- Equation: `L_contrastive = CrossEntropy(logits / T, labels)`
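A minimal sketch of this component, assuming in-batch negatives and PyTorch's built-in label smoothing; the temperature is stored as a log so it stays positive during optimization:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveInfoNCE(nn.Module):
    def __init__(self, init_temperature: float = 0.07, label_smoothing: float = 0.1):
        super().__init__()
        # Learn log(T) so the temperature is always positive.
        self.log_t = nn.Parameter(torch.log(torch.tensor(init_temperature)))
        self.label_smoothing = label_smoothing

    def forward(self, outputs, targets):
        # outputs, targets: [B, D]; row i of `targets` is the true target of row i.
        outputs = F.normalize(outputs, dim=-1)
        targets = F.normalize(targets, dim=-1)
        logits = outputs @ targets.t() / self.log_t.exp()  # [B, B]
        labels = torch.arange(outputs.size(0), device=outputs.device)
        return F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
```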
### 2. Monotonic Ranking (Trajectory Component)
- Enforces **monotonically increasing similarity** through the transaction sequence
- Each step in the trajectory should have higher similarity than the previous step
- Final embedding must achieve high similarity (ideally 1.0) with target
- Margin constraint: `sim[i+1] > sim[i] + 0.01`
- Ensures the model learns the **path** to the target, not just the endpoint
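One way to write the monotonicity term as a hinge penalty, assuming the model exposes its intermediate embeddings as a `[B, T, D]` trajectory (a sketch, not the exact repository code):
```python
import torch
import torch.nn.functional as F

def monotonicity_loss(trajectory, target, margin: float = 0.01):
    """trajectory: [B, T, D] intermediate embeddings; target: [B, D]."""
    traj = F.normalize(trajectory, dim=-1)
    tgt = F.normalize(target, dim=-1).unsqueeze(1)  # [B, 1, D]
    sims = (traj * tgt).sum(-1)                     # [B, T] cosine per step
    # Hinge penalty whenever sim[i+1] fails to beat sim[i] by the margin.
    violations = F.relu(sims[:, :-1] + margin - sims[:, 1:])
    # Also push the final step toward similarity 1.0 with the target.
    endpoint = (1.0 - sims[:, -1]).mean()
    return violations.mean() + endpoint
```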
### Loss Formulation
```
Total Loss = InfoNCE Loss + Monotonicity Loss
```
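Combining the two components sketched above with equal weights, per the formula:
```python
def hybrid_trajectory_loss(trajectory, targets, info_nce):
    """trajectory: [B, T, D] intermediate embeddings; targets: [B, D]."""
    # Contrastive term on the final embedding plus the monotonicity term.
    return info_nce(trajectory[:, -1], targets) + monotonicity_loss(trajectory, targets)
```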
**Why Trajectory Learning?**
- Transactions often evolve gradually toward their correct category
- Intermediate embeddings should show progression toward the target
- This inductive bias improves generalization and interpretability
## Training Details
- **Optimizer**: AdamW with weight decay (1e-4)
- **Learning rate**: Cosine annealing from 3e-4 to 1e-6
- **Batch size**: 128
- **Gradient clipping**: 1.0
- **Epochs**: 50 with early stopping (patience=5)
- **EMA (Exponential Moving Average)**: Decay=0.99 for evaluation stability
- **Augmentation**: Input masking (p=0.15) and Gaussian noise (std=0.01) during training
- **Mixed Precision**: AMP enabled for faster training on CUDA
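Put together, the optimization setup might look like this sketch; `train_loader`, `criterion`, and the model/config classes are assumed from the sketches above, and the EMA update and early stopping are omitted for brevity:
```python
import torch

model = ModernTrajectoryNet(Config()).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=50, eta_min=1e-6)
scaler = torch.cuda.amp.GradScaler()  # mixed-precision loss scaling

def augment(x, mask_p=0.15, noise_std=0.01):
    # Randomly zero out embedding dimensions, then add Gaussian noise.
    mask = (torch.rand_like(x) > mask_p).float()
    return x * mask + noise_std * torch.randn_like(x)

for epoch in range(50):
    for inputs, targets in train_loader:
        inputs, targets = augment(inputs.cuda()), targets.cuda()
        with torch.cuda.amp.autocast():
            loss = criterion(model(inputs), targets)  # HybridTrajectoryLoss
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()
```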
## Performance Metrics
The model optimizes for:
1. **Last Similarity**: Similarity of final embedding with target (Target: ≈1.0)
2. **Monotonicity Accuracy**: % of transitions with strictly increasing similarity (Target: 100%)
3. **Contrastive Accuracy**: Ability to distinguish true target from other targets in batch
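A sketch of how these three metrics could be computed from a batch, reusing the `[B, T, D]` trajectory layout assumed above:
```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def trajectory_metrics(trajectory, targets):
    """trajectory: [B, T, D]; targets: [B, D]."""
    traj = F.normalize(trajectory, dim=-1)
    tgt = F.normalize(targets, dim=-1)
    sims = torch.einsum("btd,bd->bt", traj, tgt)  # per-step cosine similarity
    last_sim = sims[:, -1].mean().item()
    mono_acc = (sims[:, 1:] > sims[:, :-1]).float().mean().item()
    # Contrastive accuracy: is the true target ranked first among in-batch targets?
    logits = traj[:, -1] @ tgt.t()  # [B, B]
    hits = logits.argmax(dim=1) == torch.arange(len(tgt), device=logits.device)
    return last_sim, mono_acc, hits.float().mean().item()
```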
## How to Load
```python
from safetensors.torch import load_file
import torch
from config import Config               # from the project GitHub repository
from model import ModernTrajectoryNet   # from the project GitHub repository
# Load weights
weights = load_file("model.safetensors")
# Instantiate model
config = Config()
model = ModernTrajectoryNet(config)
model.load_state_dict(weights)
model.eval()
# Use model
with torch.no_grad():
    input_embedding = torch.randn(1, 768)  # your transaction embedding
    output_embedding = model(input_embedding)
    print(output_embedding.shape)  # torch.Size([1, 768])
```
## Usage Example
```python
import torch
from torch.nn.functional import normalize
# Assuming you have transaction embeddings and category embeddings
transaction_emb = model(input_embedding) # [B, 768]
# Compute similarity with category embeddings
category_embs = normalize(category_embeddings, p=2, dim=1) # [N_cats, 768]
transaction_emb_norm = normalize(transaction_emb, p=2, dim=1) # [B, 768]
similarities = torch.matmul(transaction_emb_norm, category_embs.t()) # [B, N_cats]
predicted_category = torch.argmax(similarities, dim=1) # [B]
```
## Intended Uses
- **Transaction categorization**: Classify business transactions into merchant categories
- **Embedding refinement**: Project raw transaction embeddings to discriminative space
- **Contrastive learning**: Extract improved embeddings for downstream tasks
- **Research**: Study trajectory-based learning for sequential decision problems
## Limitations & Biases
- **Synthetic data**: Trained on synthetic transaction strings generated from Foursquare Open-Source (FSQ OS) business names and categories using the `qwen2.5-4b-instruct` LLM
- **FSQ OS biases**: Inherits biases from the FSQ OS dataset (e.g., geographic coverage, business type distribution)
- **Generation artifacts**: LLM-based synthetic data may not reflect real-world transaction diversity
- **Category coverage**: Limited to categories present in FSQ OS (typically 200-500 merchant types)
- **Language**: Trained on English transaction strings; may not generalize to other languages
**Recommendation**: Validate performance on your specific transaction domain before production deployment.
## Dataset
- **Source**: Foursquare Open-Source (FSQ OS) business names and categories
- **Processing**: LLM-based synthetic transaction generation
- **Size**: ~1M synthetic transaction embeddings
- **Train/Val split**: 90% / 10%
See the [dataset](https://huggingface.co/datasets/HighkeyPrxneeth/BusinessTransactions) for more details.
## Files in This Repository
- `model.safetensors`: Model weights in the Hugging Face safetensors format (160 MB)
- `README.md`: This file
- `LICENSE`: Apache 2.0 license
## License
Apache License 2.0. See LICENSE file for details.
## Citation
If you use this model, please cite:
```bibtex
@software{transactionclassifier2024,
  title={TransactionClassifier: Embedding-based Transaction Categorization},
  author={HighkeyPrxneeth},
  year={2024},
  url={https://huggingface.co/HighkeyPrxneeth/ModernTrajectoryNet}
}
```
## Contact & Support
- **Repository**: [GitHub - TransactionClassifier](https://github.com/HighkeyPrxneeth/TransactionClassifier)
- **Issues**: Open an issue in the main project repository
- **Author**: HighkeyPrxneeth
For questions about the model architecture, training, or usage, feel free to reach out!