Commit 797dfb1 by mjbommar · verified · 1 Parent(s): a2706cb

Upload magic-bert-50m-mlm model files

README.md ADDED
@@ -0,0 +1,261 @@
+ ---
+ language:
+ - en
+ license: apache-2.0
+ library_name: transformers
+ tags:
+ - binary-analysis
+ - file-type-detection
+ - byte-level
+ - fill-mask
+ - mlm
+ - magic-bytes
+ - security
+ pipeline_tag: fill-mask
+ model-index:
+ - name: magic-bert-50m-mlm
+   results:
+   - task:
+       type: fill-mask
+       name: Masked Language Modeling
+     metrics:
+     - name: Perplexity
+       type: perplexity
+       value: 1.05
+     - name: Fill-mask Top-1 Accuracy
+       type: accuracy
+       value: 58.9
+     - name: Fill-mask Top-5 Accuracy
+       type: accuracy
+       value: 73.5
+     - name: Probing Classification Accuracy
+       type: accuracy
+       value: 87.0
+ ---
+
+ # Magic-BERT 50M MLM
+
+ A BERT-style transformer model trained for binary file understanding using masked language modeling (MLM). This model learns byte-level patterns in binary files, including magic bytes, headers, and structural patterns across 106 file types.
+
+ ## Why Not Just Use libmagic?
+
+ For intact files starting at byte 0, libmagic works well. But libmagic matches *signatures at fixed offsets*. Magic-BERT learns *structural patterns* throughout the file, enabling use cases where you don't have clean file boundaries:
+
+ - **Network streams**: Classifying packet payloads mid-connection, before headers arrive
+ - **Disk forensics**: Identifying file types during carving, when scanning raw disk images without filesystem metadata
+ - **Fragment analysis**: Working with partial files, slack space, or corrupted data
+ - **Adversarial contexts**: Detecting file types when magic bytes are stripped, spoofed, or deliberately misleading
+
+ ## Model Description
+
+ Magic-BERT uses a custom BERT architecture with absolute position embeddings, trained on binary file data using a byte-level BPE tokenizer. The MLM objective teaches the model to predict masked bytes given surrounding context, which implicitly learns file format structure.
+
+ | Property | Value |
+ |----------|-------|
+ | Parameters | 59M |
+ | Hidden Size | 512 |
+ | Layers | 8 |
+ | Attention Heads | 8 |
+ | Max Sequence Length | 512 tokens |
+ | Vocabulary Size | 32,768 (byte-level BPE) |
+ | Position Encoding | Absolute (learned embeddings) |
+
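As a sanity check on the table above, the parameter count can be reproduced from the configuration values alone. The sketch below assumes the layer layout of `modeling_magic_bert.py` in this commit (biased linear layers, two LayerNorms per block, an untied `mlm_head`); the helper names are illustrative and not part of the repository.

```python
# Hypothetical helpers: count parameters using only values from config.json.

def linear(n_in: int, n_out: int) -> int:
    """Weight matrix plus bias vector of a dense layer."""
    return n_in * n_out + n_out


def layer_norm(n: int) -> int:
    """Scale plus shift vectors of a LayerNorm."""
    return 2 * n


def count_parameters(vocab=32768, hidden=512, layers=8, ffn=2048, max_pos=512) -> int:
    # Token embeddings, position embeddings, embedding LayerNorm.
    embeddings = vocab * hidden + max_pos * hidden + layer_norm(hidden)
    per_layer = (
        3 * linear(hidden, hidden)   # query, key, value projections
        + linear(hidden, hidden)     # attention output projection
        + layer_norm(hidden)         # post-attention LayerNorm
        + linear(hidden, ffn)        # feed-forward up-projection
        + linear(ffn, hidden)        # feed-forward down-projection
        + layer_norm(hidden)         # post-feed-forward LayerNorm
    )
    mlm_head = linear(hidden, vocab)  # output projection, not tied to the embeddings
    return embeddings + layers * per_layer + mlm_head


total = count_parameters()
print(f"{total:,} parameters, {total * 4:,} bytes in float32")
```

Under these assumptions the count is 59,069,440, which matches the 59M in the table. At 4 bytes per float32 weight that is 236,277,760 bytes, slightly below the 236,291,992-byte `model.safetensors` pointer in this commit; the small remainder is plausibly the file header.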
+ ### Tokenizer
+
+ The tokenizer uses the Binary BPE methodology introduced in [Bommarito (2025)](https://arxiv.org/abs/2511.17573). The original Binary BPE tokenizers (available at [mjbommar/binary-tokenizer-001-64k](https://huggingface.co/mjbommar/binary-tokenizer-001-64k)) were trained exclusively on executable binaries (ELF, PE, Mach-O). This tokenizer uses the same BPE training approach but was trained on a diverse corpus spanning 106 file types including documents, images, audio/video, archives, and source code.
+
+ ## Intended Uses
+
+ **Primary use cases:**
+ - Fill-mask: Predicting missing bytes in binary files
+ - Magic byte and file signature recognition
+ - Feature extraction for downstream classification
+ - Research on binary file structure
+
+ **Example tasks:**
+ - Completing partial file headers
+ - Identifying file type from structure
+ - Anomaly detection in binary data
+
+ ## Detailed Use Cases
+
+ ### Network Traffic Analysis
+ When inspecting packet payloads, you often see file data mid-stream: TCP reassembly may give you bytes 1500-3000 of a PDF before you ever see byte 0. Traditional signature matching fails here. Magic-BERT's structural understanding can identify file types from interior content.
+
+ ### Disk Forensics & File Carving
+ During disk image analysis, you scan raw bytes looking for file boundaries. Tools like Scalpel rely on header/footer signatures, but many files lack clear footers. Magic-BERT can score byte ranges for file type probability, helping identify carved fragments or validate carving results.
+
+ ### Incident Response
+ Malware often strips or modifies magic bytes to evade detection. Polyglot files (valid as multiple types) exploit signature-based tools. Learning structural patterns provides a second opinion that doesn't rely solely on the first few bytes.
+
+ ### Embedded Content Detection
+ Files within files (email attachments, archive contents, OLE streams) may appear at arbitrary offsets. Embeddings from Magic-BERT enable similarity search: "find all chunks that look structurally like JPEG data" regardless of where they appear.
+
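The similarity search described for embedded content can be sketched with plain cosine similarity. This is a minimal illustration, not the project's retrieval code: the toy 3-dimensional vectors stand in for 512-dimensional embeddings, and the chunk names and the `rank_chunks` helper are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_chunks(query, chunk_embeddings):
    """Return (chunk_id, similarity) pairs, most similar first."""
    scored = [(cid, cosine_similarity(query, emb)) for cid, emb in chunk_embeddings.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Toy embeddings keyed by the (hypothetical) byte offset of each chunk.
chunks = {
    "offset_0": [0.9, 0.1, 0.0],
    "offset_4096": [0.0, 1.0, 0.0],
    "offset_8192": [0.8, 0.2, 0.1],
}
jpeg_reference = [1.0, 0.0, 0.0]  # stand-in for the embedding of known JPEG data

ranking = rank_chunks(jpeg_reference, chunks)
for chunk_id, score in ranking:
    print(f"{chunk_id}: {score:.3f}")
```

In practice the vectors would come from `get_embeddings`, and an approximate nearest-neighbour index would replace the linear scan for large collections.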
+ ## Training
+
+ ### Data
+ Trained on a diverse corpus of binary files spanning 106 MIME types, including:
+ - Documents (PDF, Office formats, OpenDocument)
+ - Images (PNG, JPEG, GIF, WebP, TIFF)
+ - Audio/Video (MP3, MP4, WebM, FLAC)
+ - Archives (ZIP, GZIP, 7z, TAR)
+ - Executables (ELF, PE, Mach-O)
+ - And 90+ additional formats
+
+ ### Procedure
+
+ | Phase | Steps | Learning Rate | Batch Size | Objective |
+ |-------|-------|---------------|------------|-----------|
+ | MLM Pre-training | 100,000 | 1e-4 | 240 | Masked LM (15% masking) |
+
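The 15% masking objective in the table can be sketched as follows. The README states only the masking rate; whether the original BERT 80/10/10 replacement scheme was used is not stated, so this sketch replaces every selected position with the mask token. The token ids and `mask_token_id=4` are made up for illustration.

```python
import random

IGNORE_INDEX = -100  # label value skipped by CrossEntropyLoss(ignore_index=-100)


def mask_tokens(input_ids, mask_token_id, special_ids=frozenset(), rate=0.15, rng=None):
    """Mask a fraction `rate` of non-special tokens; return (masked_ids, labels)."""
    rng = rng or random.Random()
    masked = list(input_ids)
    labels = [IGNORE_INDEX] * len(input_ids)
    candidates = [i for i, tok in enumerate(input_ids) if tok not in special_ids]
    n_mask = max(1, round(rate * len(candidates))) if candidates else 0
    for i in rng.sample(candidates, n_mask):
        labels[i] = input_ids[i]   # the loss is computed only at masked positions
        masked[i] = mask_token_id  # the model sees the mask token instead
    return masked, labels


ids = list(range(100, 200))  # 100 toy token ids
masked_ids, labels = mask_tokens(ids, mask_token_id=4, rng=random.Random(0))
num_masked = sum(label != IGNORE_INDEX for label in labels)
print(num_masked, "of", len(ids), "tokens masked")
```

The `labels` list lines up with the `labels` argument of `MagicBERTForMaskedLM.forward`, which ignores positions set to -100.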
+ **Data augmentation:** 50% of samples use a random byte offset to reduce position bias.
+
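One plausible reading of this augmentation is that half of the training windows start at a random position inside the file instead of at byte 0. The sketch below illustrates that reading only; the actual sampling code is not part of this commit, and `sample_window`, its 512-byte window, and the uniform offset distribution are assumptions.

```python
import random


def sample_window(data: bytes, window: int = 512, offset_prob: float = 0.5, rng=None):
    """Return (start, chunk); with probability `offset_prob` the start is random."""
    rng = rng or random.Random()
    max_start = max(0, len(data) - window)
    if max_start > 0 and rng.random() < offset_prob:
        start = rng.randint(0, max_start)  # inclusive on both ends
    else:
        start = 0                          # default: the window begins at the file header
    return start, data[start:start + window]


data = bytes(range(256)) * 8  # 2048 bytes of toy content
rng = random.Random(0)
samples = [sample_window(data, rng=rng) for _ in range(1000)]
zero_fraction = sum(start == 0 for start, _ in samples) / len(samples)
print(f"fraction of windows starting at byte 0: {zero_fraction:.2f}")
```

With this scheme roughly half of the windows still contain the magic bytes, which is consistent with the position bias the README reports at large offsets.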
+ ## Evaluation Results
+
+ ### Perplexity by Region
+ | Region | Perplexity |
+ |--------|------------|
+ | Magic Bytes (0-9) | 1.07 |
+ | Header (10-49) | 1.06 |
+ | Body (50+) | 1.05 |
+ | **Overall** | **1.05** |
+
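For reference, perplexity is the exponential of the mean negative log-likelihood over the scored tokens. The sketch below shows that relationship; the README does not say which tokens were scored (masked positions only, or all positions), so it makes no assumption about that.

```python
import math


def perplexity(token_nlls):
    """exp of the mean negative log-likelihood (natural log) over the scored tokens."""
    if not token_nlls:
        raise ValueError("need at least one scored token")
    return math.exp(sum(token_nlls) / len(token_nlls))


# A reported perplexity of 1.05 corresponds to this mean loss in nats:
mean_nll = math.log(1.05)
print(f"mean NLL for perplexity 1.05: {mean_nll:.4f}")
print(f"round trip: {perplexity([mean_nll] * 10):.2f}")
```

A perplexity of 1.05 therefore corresponds to a mean loss of about 0.049 nats per scored token.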
+ ### Fill-Mask Accuracy
+ | Metric | Value |
+ |--------|-------|
+ | Top-1 Accuracy | 58.9% |
+ | Top-5 Accuracy | 73.5% |
+ | Mean Reciprocal Rank | 0.67 |
+
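The three fill-mask metrics are all functions of the rank the model assigns to the true token at each masked position. A minimal sketch of how they relate, using made-up score vectors rather than model outputs (`fill_mask_metrics` is a hypothetical helper):

```python
def rank_of_target(scores, target):
    """1-based rank of `target` when candidates are sorted by descending score."""
    return 1 + sum(1 for s in scores if s > scores[target])


def fill_mask_metrics(examples, k=5):
    """examples: list of (scores, target_index) pairs, one per masked position."""
    ranks = [rank_of_target(scores, target) for scores, target in examples]
    n = len(ranks)
    return {
        "top1": sum(r == 1 for r in ranks) / n,
        f"top{k}": sum(r <= k for r in ranks) / n,
        "mrr": sum(1.0 / r for r in ranks) / n,  # mean reciprocal rank
    }


examples = [
    ([0.10, 0.70, 0.20], 1),  # true token ranked 1st
    ([0.50, 0.30, 0.20], 1),  # ranked 2nd
    ([0.40, 0.35, 0.25], 2),  # ranked 3rd
    ([0.20, 0.20, 0.60], 2),  # ranked 1st
]
metrics = fill_mask_metrics(examples, k=2)
print(metrics)
```

Because the reciprocal rank is 1 for a top-1 hit and positive otherwise, MRR always lies at or above top-1 accuracy, as it does in the table (0.67 versus 58.9%).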
+ ### Representation Quality
+ | Metric | Value |
+ |--------|-------|
+ | Linear Probe Accuracy | 87.0% |
+ | Silhouette Score | 0.39 |
+ | Separation Ratio | 2.78 |
+
+ ## Architecture: Absolute vs Rotary Position Embeddings
+
+ This model uses **absolute position embeddings**, where each position (0-511) has a learned embedding vector added to the token embedding. This is the original BERT approach.
+
+ An alternative is **Rotary Position Embeddings (RoPE)**, used by the RoFormer variant. RoPE encodes relative position through rotation matrices applied to query and key vectors in attention, rather than learning absolute position vectors.
+
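To illustrate what "relative position through rotation" means, the sketch below applies the interleaved-pair RoPE rotation to toy query and key vectors and checks that the attention score depends only on the distance between the two positions. It is a from-scratch illustration with the commonly used base of 10000, not the RoFormer variant's actual code, and the vectors are arbitrary.

```python
import math


def rope_rotate(vec, position, base=10000.0):
    """Rotate consecutive pairs (x0, x1), (x2, x3), ... by position-dependent angles."""
    dim = len(vec)  # must be even
    out = []
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)  # lower frequency for later pairs
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.2]

# Same relative distance (2 positions), very different absolute offsets.
score_near_start = dot(rope_rotate(q, 3), rope_rotate(k, 1))
score_far_offset = dot(rope_rotate(q, 1003), rope_rotate(k, 1001))
print(score_near_start, score_far_offset)
```

The two scores agree up to floating-point error, whereas a learned absolute embedding has no such guarantee and has no vector at all for positions beyond `max_position_embeddings`.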
+ **Key finding from our experiments:** Both approaches show similar position bias (~47-48% accuracy drop at offset 1000). Position bias is primarily a data distribution issue (files naturally start at offset 0) rather than an architecture limitation.
+
+ | Aspect | Magic-BERT (this) | RoFormer |
+ |--------|-------------------|----------|
+ | Position Encoding | Absolute (learned) | RoPE (rotary) |
+ | Parameters | 59M | 42.3M |
+ | Perplexity | **1.05** | 1.13 |
+ | Fill-mask Top-1 | 58.9% | **61.8%** |
+ | Probing Accuracy | **87.0%** | 85.0% |
+
+ Magic-BERT achieves slightly better perplexity and probing accuracy, while RoFormer achieves better fill-mask accuracy with fewer parameters.
+
+ ## MLM vs Classification: Two-Phase Training
+
+ This is the **Phase 1 (MLM)** model. The training pipeline has two phases:
+
+ | Phase | Model | Task | Purpose |
+ |-------|-------|------|---------|
+ | **Phase 1** | **This model** | Masked Language Modeling | Learn byte-level patterns and file structure |
+ | Phase 2 | magic-bert-50m-classification | Contrastive Learning | Optimize embeddings for file type discrimination |
+
+ **When to use each:**
+ - Use **this model (MLM)** for: fill-mask tasks, research, or as a base for custom fine-tuning
+ - Use **classification model** for: file type detection, similarity search, production classification
+
+ ## How to Use
+
+ ```python
+ from transformers import AutoTokenizer
+ from safetensors.torch import load_file
+ import torch
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("path/to/magic-bert-50m-mlm")
+
+ # For custom MagicBERT architecture, load directly
+ from modeling_magic_bert import MagicBERTForMaskedLM
+ from configuration_magic_bert import MagicBERTConfig
+
+ config = MagicBERTConfig.from_pretrained("path/to/magic-bert-50m-mlm")
+ model = MagicBERTForMaskedLM(config)
+ state_dict = load_file("path/to/magic-bert-50m-mlm/model.safetensors")
+ model.load_state_dict(state_dict)
+ model.eval()  # Disable dropout so predictions are deterministic
+
+ # Fill-mask example
+ with open("example.pdf", "rb") as f:
+     data = f.read(512)
+
+ # Decode bytes to string using latin-1 (preserves all byte values 0-255)
+ text = data.decode("latin-1")
+
+ # Tokenize and mask
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+ mask_pos = 0  # Mask first token
+ inputs["input_ids"][0, mask_pos] = tokenizer.mask_token_id
+
+ # Predict
+ with torch.no_grad():
+     outputs = model(**inputs)
+     predictions = outputs.logits[0, mask_pos].topk(5)
+
+ print("Top-5 predictions:", tokenizer.convert_ids_to_tokens(predictions.indices.tolist()))
+ ```
+
+ ### Getting Embeddings
+
+ ```python
+ # Get CLS embeddings for downstream tasks
+ with torch.no_grad():
+     embeddings = model.get_embeddings(inputs["input_ids"], inputs["attention_mask"])
+ # embeddings shape: [batch_size, 512]
+ ```
+
+ ## Limitations
+
+ 1. **Position bias:** The model performs best when file content starts at position 0. Accuracy drops ~48% when content starts at offset 1000. This reflects training data distribution, not architectural limitations.
+
+ 2. **Sequence length:** Limited to 512 tokens. Longer files require truncation or chunking.
+
+ 3. **High-entropy content:** Lower performance on high-entropy or highly variable content (e.g., encrypted data, random bytes).
+
+ 4. **Domain specificity:** Trained on common file formats; may not generalize to rare or proprietary formats.
+
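The chunking mentioned under the sequence-length limitation can be done with overlapping byte windows. This is a minimal sketch: `chunk_bytes`, the 512-byte window, and the 256-byte stride are illustrative choices, and since the model limit is expressed in tokens rather than bytes, tokenizer truncation should stay enabled as in the usage example.

```python
def chunk_bytes(data: bytes, size: int = 512, stride: int = 256):
    """Yield (offset, chunk) windows; consecutive windows overlap by size - stride bytes."""
    if size <= 0 or stride <= 0:
        raise ValueError("size and stride must be positive")
    if not data:
        return
    offset = 0
    while True:
        yield offset, data[offset:offset + size]
        if offset + size >= len(data):
            break  # this window already reached the end of the data
        offset += stride


chunks = list(chunk_bytes(bytes(1200)))
for offset, chunk in chunks:
    print(offset, len(chunk))
```

Each chunk can then be decoded with latin-1 and embedded separately; keeping the offset alongside the chunk makes it possible to report where in the file a match was found.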
+ ## Model Selection Guide
+
+ | Use Case | Recommended Model | Reason |
+ |----------|-------------------|--------|
+ | Fill-mask / byte prediction | **This model** | Best perplexity (1.05) |
+ | Research baseline | **This model** | Established BERT architecture |
+ | Classification + fill-mask | magic-bert-50m-classification | Retains 41.8% fill-mask capability |
+ | **Production classification** | **magic-bert-50m-roformer-classification** | Highest accuracy (93.7%), efficient (42M params) |
+
+ ## Related Models
+
+ - **magic-bert-50m-classification**: Same architecture fine-tuned for classification (89.7% accuracy)
+ - **magic-bert-50m-roformer-mlm**: RoFormer variant with rotary position embeddings
+ - **magic-bert-50m-roformer-classification**: RoFormer variant fine-tuned for classification (93.7% accuracy, recommended for production)
+
+ ## Related Work
+
+ This model builds on the Binary BPE tokenization approach:
+
+ - **Binary BPE Paper**: [Bommarito (2025)](https://arxiv.org/abs/2511.17573) introduced byte-level BPE tokenization for binary analysis, demonstrating 2-3x compression over raw bytes for executable content.
+ - **Binary BPE Tokenizers**: Pre-trained tokenizers for executables are available at [mjbommar/binary-tokenizer-001-64k](https://huggingface.co/mjbommar/binary-tokenizer-001-64k).
+
+ **Key difference**: The original Binary BPE work focused on executable binaries (ELF, PE, Mach-O). Magic-BERT extends this to general file type understanding across 106 diverse formats, using a tokenizer trained on the broader dataset.
+
+ ## Citation
+
+ A paper describing Magic-BERT, the training methodology, and the dataset is forthcoming.
+
+ ```bibtex
+ @article{bommarito2025binarybpe,
+   title={Binary BPE: A Family of Cross-Platform Tokenizers for Binary Analysis},
+   author={Bommarito, Michael J., II},
+   journal={arXiv preprint arXiv:2511.17573},
+   year={2025}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "model_type": "magic-bert",
+   "architectures": [
+     "MagicBERTForMaskedLM"
+   ],
+   "vocab_size": 32768,
+   "hidden_size": 512,
+   "num_hidden_layers": 8,
+   "num_attention_heads": 8,
+   "intermediate_size": 2048,
+   "hidden_dropout_prob": 0.1,
+   "attention_probs_dropout_prob": 0.1,
+   "max_position_embeddings": 512,
+   "pad_token_id": 2,
+   "hidden_act": "gelu",
+   "layer_norm_eps": 1e-12,
+   "torch_dtype": "float32",
+   "transformers_version": "4.57.0",
+   "auto_map": {
+     "AutoConfig": "configuration_magic_bert.MagicBERTConfig",
+     "AutoModel": "modeling_magic_bert.MagicBERTModel",
+     "AutoModelForMaskedLM": "modeling_magic_bert.MagicBERTForMaskedLM"
+   }
+ }
configuration_magic_bert.py ADDED
@@ -0,0 +1,40 @@
+ """MagicBERT configuration for HuggingFace transformers."""
+
+ from transformers import PretrainedConfig
+
+
+ class MagicBERTConfig(PretrainedConfig):
+     """Configuration class for MagicBERT model.
+
+     MagicBERT is a BERT-style transformer model designed for binary file
+     type classification. It uses a byte-level BPE tokenizer with a 32K vocabulary.
+     """
+
+     model_type = "magic-bert"
+
+     def __init__(
+         self,
+         vocab_size: int = 32768,
+         hidden_size: int = 512,
+         num_hidden_layers: int = 8,
+         num_attention_heads: int = 8,
+         intermediate_size: int = 2048,
+         hidden_dropout_prob: float = 0.1,
+         attention_probs_dropout_prob: float = 0.1,
+         max_position_embeddings: int = 512,
+         pad_token_id: int = 2,
+         hidden_act: str = "gelu",
+         layer_norm_eps: float = 1e-12,
+         **kwargs,
+     ):
+         super().__init__(pad_token_id=pad_token_id, **kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_act = hidden_act
+         self.layer_norm_eps = layer_norm_eps
mime_type_mapping.json ADDED
@@ -0,0 +1,108 @@
+ {
+   "0": "application/SIMH-tape-data",
+   "1": "application/encrypted",
+   "2": "application/gzip",
+   "3": "application/javascript",
+   "4": "application/json",
+   "5": "application/msword",
+   "6": "application/mxf",
+   "7": "application/octet-stream",
+   "8": "application/pdf",
+   "9": "application/pgp-keys",
+   "10": "application/postscript",
+   "11": "application/vnd.microsoft.portable-executable",
+   "12": "application/vnd.ms-excel",
+   "13": "application/vnd.ms-opentype",
+   "14": "application/vnd.ms-powerpoint",
+   "15": "application/vnd.oasis.opendocument.spreadsheet",
+   "16": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+   "17": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+   "18": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+   "19": "application/vnd.rn-realmedia",
+   "20": "application/vnd.wordperfect",
+   "21": "application/wasm",
+   "22": "application/x-7z-compressed",
+   "23": "application/x-archive",
+   "24": "application/x-bzip2",
+   "25": "application/x-coff",
+   "26": "application/x-dbf",
+   "27": "application/x-dosexec",
+   "28": "application/x-executable",
+   "29": "application/x-gettext-translation",
+   "30": "application/x-ms-ne-executable",
+   "31": "application/x-ndjson",
+   "32": "application/x-object",
+   "33": "application/x-ole-storage",
+   "34": "application/x-sharedlib",
+   "35": "application/x-shockwave-flash",
+   "36": "application/x-tar",
+   "37": "application/x-wine-extension-ini",
+   "38": "application/zip",
+   "39": "application/zlib",
+   "40": "application/zstd",
+   "41": "audio/amr",
+   "42": "audio/flac",
+   "43": "audio/mpeg",
+   "44": "audio/ogg",
+   "45": "audio/x-ape",
+   "46": "audio/x-hx-aac-adts",
+   "47": "audio/x-m4a",
+   "48": "audio/x-wav",
+   "49": "biosig/atf",
+   "50": "font/sfnt",
+   "51": "font/woff",
+   "52": "font/woff2",
+   "53": "image/bmp",
+   "54": "image/fits",
+   "55": "image/gif",
+   "56": "image/heif",
+   "57": "image/jpeg",
+   "58": "image/png",
+   "59": "image/svg+xml",
+   "60": "image/tiff",
+   "61": "image/vnd.adobe.photoshop",
+   "62": "image/vnd.microsoft.icon",
+   "63": "image/webp",
+   "64": "image/x-eps",
+   "65": "image/x-exr",
+   "66": "image/x-jp2-codestream",
+   "67": "image/x-portable-bitmap",
+   "68": "image/x-portable-greymap",
+   "69": "image/x-tga",
+   "70": "image/x-xpixmap",
+   "71": "inode/x-empty",
+   "72": "message/rfc822",
+   "73": "text/csv",
+   "74": "text/html",
+   "75": "text/plain",
+   "76": "text/rtf",
+   "77": "text/troff",
+   "78": "text/x-Algol68",
+   "79": "text/x-asm",
+   "80": "text/x-c",
+   "81": "text/x-c++",
+   "82": "text/x-diff",
+   "83": "text/x-file",
+   "84": "text/x-fortran",
+   "85": "text/x-java",
+   "86": "text/x-m4",
+   "87": "text/x-makefile",
+   "88": "text/x-msdos-batch",
+   "89": "text/x-perl",
+   "90": "text/x-php",
+   "91": "text/x-po",
+   "92": "text/x-ruby",
+   "93": "text/x-script.python",
+   "94": "text/x-shellscript",
+   "95": "text/x-tex",
+   "96": "text/xml",
+   "97": "video/3gpp",
+   "98": "video/mp4",
+   "99": "video/mpeg",
+   "100": "video/quicktime",
+   "101": "video/webm",
+   "102": "video/x-ivf",
+   "103": "video/x-matroska",
+   "104": "video/x-ms-asf",
+   "105": "video/x-msvideo"
+ }
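A mapping like this is typically used to turn a predicted class index into a MIME type. The sketch below shows that lookup on a three-entry excerpt; it assumes, without confirmation from this commit, that a classifier's output indices follow the keys of `mime_type_mapping.json`, and the score values are invented.

```python
import json

# Excerpt of mime_type_mapping.json from this commit; note that the keys are strings.
MAPPING_JSON = '{"8": "application/pdf", "57": "image/jpeg", "58": "image/png"}'


def load_id2label(text: str) -> dict:
    """Parse the mapping and convert the string keys to integer class ids."""
    return {int(key): value for key, value in json.loads(text).items()}


def predicted_mime(scores_by_id: dict, id2label: dict) -> str:
    """Return the MIME type of the highest-scoring class id."""
    best_id = max(scores_by_id, key=scores_by_id.get)
    return id2label[best_id]


id2label = load_id2label(MAPPING_JSON)
prediction = predicted_mime({8: 0.2, 57: 2.9, 58: 1.1}, id2label)
print(prediction)
```

With the full file, `load_id2label` yields 106 entries, which matches the default `num_labels` in `MagicBERTForSequenceClassification`.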
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:494f918a228fcd32d8967cb5decebafdf3b2d0e9a34601e6f9771387e0080d1f
+ size 236291992
modeling_magic_bert.py ADDED
@@ -0,0 +1,346 @@
+ """MagicBERT model implementation for HuggingFace transformers.
+
+ This module provides HuggingFace-compatible implementations of MagicBERT,
+ a BERT-style model trained for binary file type understanding.
+ """
+
+ import math
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import (
+     MaskedLMOutput,
+     SequenceClassifierOutput,
+     BaseModelOutput,
+ )
+
+ try:
+     from .configuration_magic_bert import MagicBERTConfig
+ except ImportError:
+     from configuration_magic_bert import MagicBERTConfig
+
+
+ class MagicBERTEmbeddings(nn.Module):
+     """MagicBERT embeddings: token + position embeddings."""
+
+     def __init__(self, config: MagicBERTConfig):
+         super().__init__()
+         self.token_embeddings = nn.Embedding(
+             config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+         )
+         self.position_embeddings = nn.Embedding(
+             config.max_position_embeddings, config.hidden_size
+         )
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+         self.register_buffer(
+             "position_ids",
+             torch.arange(config.max_position_embeddings).expand((1, -1)),
+             persistent=False,
+         )
+
+     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+         batch_size, seq_length = input_ids.shape
+         token_embeds = self.token_embeddings(input_ids)
+         position_ids = self.position_ids[:, :seq_length]
+         position_embeds = self.position_embeddings(position_ids)
+         embeddings = token_embeds + position_embeds
+         embeddings = self.layer_norm(embeddings)
+         embeddings = self.dropout(embeddings)
+         return embeddings
+
+
+ class MagicBERTAttention(nn.Module):
+     """Multi-head self-attention."""
+
+     def __init__(self, config: MagicBERTConfig):
+         super().__init__()
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = config.hidden_size // config.num_attention_heads
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+         self.query = nn.Linear(config.hidden_size, self.all_head_size)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+         new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+         x = x.view(new_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         query_layer = self.transpose_for_scores(self.query(hidden_states))
+         key_layer = self.transpose_for_scores(self.key(hidden_states))
+         value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :]
+             attention_scores = attention_scores + (1.0 - attention_mask) * -10000.0
+
+         attention_probs = F.softmax(attention_scores, dim=-1)
+         attention_probs = self.dropout(attention_probs)
+         context = torch.matmul(attention_probs, value_layer)
+         context = context.permute(0, 2, 1, 3).contiguous()
+         new_shape = context.size()[:-2] + (self.all_head_size,)
+         context = context.view(new_shape)
+         return context
+
+
+ class MagicBERTLayer(nn.Module):
+     """Single transformer layer."""
+
+     def __init__(self, config: MagicBERTConfig):
+         super().__init__()
+         self.attention = MagicBERTAttention(config)
+         self.attention_output = nn.Linear(config.hidden_size, config.hidden_size)
+         self.attention_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
+
+         self.intermediate = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.output = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.output_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         # Self-attention with residual
+         attention_output = self.attention(hidden_states, attention_mask)
+         attention_output = self.attention_output(attention_output)
+         attention_output = self.attention_dropout(attention_output)
+         attention_output = self.attention_norm(hidden_states + attention_output)
+
+         # Feed-forward with residual
+         intermediate_output = self.intermediate(attention_output)
+         intermediate_output = F.gelu(intermediate_output)
+         layer_output = self.output(intermediate_output)
+         layer_output = self.output_dropout(layer_output)
+         layer_output = self.output_norm(attention_output + layer_output)
+         return layer_output
+
+
+ class MagicBERTEncoder(nn.Module):
+     """Stack of transformer layers."""
+
+     def __init__(self, config: MagicBERTConfig):
+         super().__init__()
+         self.layers = nn.ModuleList(
+             [MagicBERTLayer(config) for _ in range(config.num_hidden_layers)]
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, attention_mask)
+         return hidden_states
+
+
+ class MagicBERTPreTrainedModel(PreTrainedModel):
+     """Base class for MagicBERT models."""
+
+     config_class = MagicBERTConfig
+     base_model_prefix = "magic_bert"
+     supports_gradient_checkpointing = False
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+
+ class MagicBERTModel(MagicBERTPreTrainedModel):
+     """MagicBERT base model outputting raw hidden states."""
+
+     def __init__(self, config: MagicBERTConfig):
+         super().__init__(config)
+         self.config = config
+         self.embeddings = MagicBERTEmbeddings(config)
+         self.encoder = MagicBERTEncoder(config)
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,  # Ignored, for tokenizer compatibility
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor, torch.Tensor], BaseModelOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         hidden_states = self.embeddings(input_ids)
+         sequence_output = self.encoder(hidden_states, attention_mask)
+         pooled_output = sequence_output[:, 0, :]
+
+         if not return_dict:
+             return (sequence_output, pooled_output)
+
+         return BaseModelOutput(
+             last_hidden_state=sequence_output,
+             hidden_states=None,
+             attentions=None,
+         )
+
+
+ class MagicBERTForMaskedLM(MagicBERTPreTrainedModel):
+     """MagicBERT for masked language modeling (fill-mask task)."""
+
+     def __init__(self, config: MagicBERTConfig):
+         super().__init__(config)
+         self.config = config
+         self.embeddings = MagicBERTEmbeddings(config)
+         self.encoder = MagicBERTEncoder(config)
+         self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,  # Ignored, for tokenizer compatibility
+         labels: Optional[torch.Tensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor, ...], MaskedLMOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         hidden_states = self.embeddings(input_ids)
+         sequence_output = self.encoder(hidden_states, attention_mask)
+         logits = self.mlm_head(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,)
+             return ((loss,) + output) if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=None,
+             attentions=None,
+         )
+
+     def get_embeddings(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         pooling: str = "cls",
+     ) -> torch.Tensor:
+         """Get embeddings for downstream tasks.
+
+         Args:
+             input_ids: Input token IDs
+             attention_mask: Attention mask
+             pooling: Pooling strategy ("cls" or "mean")
+
+         Returns:
+             Pooled embeddings [batch_size, hidden_size]
+         """
+         hidden_states = self.embeddings(input_ids)
+         sequence_output = self.encoder(hidden_states, attention_mask)
+
+         if pooling == "cls":
+             return sequence_output[:, 0, :]
+         elif pooling == "mean":
+             if attention_mask is not None:
+                 mask = attention_mask.unsqueeze(-1).float()
+                 return (sequence_output * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
+             return sequence_output.mean(dim=1)
+         else:
+             raise ValueError(f"Unknown pooling: {pooling}")
+
+
+ class MagicBERTForSequenceClassification(MagicBERTPreTrainedModel):
+     """MagicBERT for sequence classification (file type classification)."""
+
+     def __init__(self, config: MagicBERTConfig):
+         super().__init__(config)
+         self.config = config
+         self.num_labels = getattr(config, "num_labels", 106)
+
+         self.embeddings = MagicBERTEmbeddings(config)
+         self.encoder = MagicBERTEncoder(config)
+
+         # Projection head (for contrastive learning compatibility)
+         projection_dim = getattr(config, "contrastive_projection_dim", 256)
+         self.projection = nn.Sequential(
+             nn.Linear(config.hidden_size, config.hidden_size),
+             nn.ReLU(),
+             nn.Linear(config.hidden_size, projection_dim),
+         )
+         self.classifier = nn.Linear(projection_dim, self.num_labels)
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,  # Ignored, for tokenizer compatibility
+         labels: Optional[torch.Tensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         hidden_states = self.embeddings(input_ids)
+         sequence_output = self.encoder(hidden_states, attention_mask)
+         pooled_output = sequence_output[:, 0, :]
+
+         projections = self.projection(pooled_output)
+         projections = F.normalize(projections, p=2, dim=1)
+         logits = self.classifier(projections)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,)
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=None,
+             attentions=None,
+         )
+
+     def get_embeddings(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """Get normalized projection embeddings for similarity search."""
+         hidden_states = self.embeddings(input_ids)
+         sequence_output = self.encoder(hidden_states, attention_mask)
+         pooled_output = sequence_output[:, 0, :]
+         projections = self.projection(pooled_output)
+         return F.normalize(projections, p=2, dim=1)
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "mask_token": "[MASK]",
+   "cls_token": "[CLS]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }