Alejandro Pirola committed
Commit b1f8e6c · 0 parents

Initial commit: SAM2 finetuned checkpoint + config
.gitattributes ADDED
# Track large model weights with LFS
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
# Text files with LF normalization
*.json text eol=lf
README.md text eol=lf
*.md text eol=lf
.gitignore ADDED
# Outputs / artifacts
outputs/
logs/
*.log

# Python cache
__pycache__/
*.pyc

# Datasets (do not upload private data)
data/
README.md ADDED
# SAM2 ID Segmenter

Lightweight wrapper and fine‑tuning scaffold around Meta's Segment Anything 2 (SAM2), adapted to segment structured regions in ID / document images (e.g. portrait, number field, security areas). The repository currently focuses on: (1) reproducible loading of a fine‑tuned SAM2 checkpoint, (2) automatic multi‑mask generation plus tight cropping, and (3) configuration‑file‑driven training/inference settings.

> Status: the inference wrapper (`SamSegmentator`) is implemented. An end‑to‑end training loop is a planned addition; the config already anticipates training hyper‑parameters.

---

## Contents
1. Motivation & Scope
2. Intended Use & Non‑Goals
3. Repository Structure
4. Configuration (`config.json`)
5. Installation
6. Inference Usage (`SamSegmentator`)
7. Dataset & Mask Format (planned training)
8. Checkpoints & Auto‑Download
9. Metrics (recommended)
10. Limitations & Risks
11. Roadmap
12. License & Citation

---
## 1. Motivation & Scope
Document / ID workflows often need fast, class‑agnostic region extraction (for OCR, redaction, or downstream classifiers). SAM2 provides strong general mask proposals; this project wraps it to directly yield cropped image + mask pairs, ordered by area and optionally padded.

## 2. Intended Use & Non‑Goals
Intended:
- Pre‑segmentation of ID / document fields prior to OCR.
- Selective anonymization / redaction pipelines (masking faces, MRZ, barcodes, etc.).
- Rapid prototyping for custom fine‑tuning of SAM2 on a small set of document classes.

Non‑Goals:
- Biometric identity verification or authoritative fraud detection.
- Legal decision‑making without human review.
- Full multi‑modal extraction (text recognition is out of scope here).
## 3. Repository Structure
```
model_repo/
  config.json          # Central hyper‑parameter & path config
  README.md            # (this file)
  checkpoints/         # Locally downloaded / fine‑tuned checkpoints
  samples/
    sample_us_passport.jpg
  src/
    sam_segmentator.py # Inference wrapper (SamSegmentator)
    main.py            # Placeholder entry point
```
Planned: `train/` scripts for fine‑tuning (not yet implemented).
## 4. Configuration (`model_repo/config.json`)
Key fields (example values included in the repo):
- `model_type`: Always `sam2` here.
- `checkpoint_path`: Path relative to the project root, or absolute; if omitted and `auto_download=True`, the code will attempt a remote download.
- `image_size`: Target square size used during training (future). The inference wrapper accepts the raw image size.
- `num_classes`, `class_names`: For supervised training (future); not required by the current automatic mask generator, but kept for consistency.
- `augmentation`, `loss`, `optimizer`, `lr_scheduler`: Reserved for training‑loop integration.
- `paths`: Expected dataset layout for training: `data/train/images`, `data/train/masks`, etc.
- `mixed_precision`: Will enable `torch.autocast` during training.

Even if not all fields are consumed now, keeping them centralized avoids breaking refactors later.
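Reading these fields is plain `json` parsing. A minimal sketch, using an inline excerpt that mirrors a few fields from the repo's `config.json` (inlined so the snippet is self‑contained):

```python
import json

# Inline excerpt mirroring model_repo/config.json (values copied from the repo).
config_text = '''
{
  "model_type": "sam2",
  "checkpoint_path": "weights/sam2_base.pth",
  "class_names": ["ID1", "ID3", "IDCOVER"],
  "learning_rate": 1e-5,
  "paths": {"train_images": "data/train/images", "output_dir": "outputs"}
}
'''
config = json.loads(config_text)

assert config["model_type"] == "sam2"

# Fields reserved for the future training loop can be read with defaults,
# so inference code does not break if they are absent:
lr = config.get("learning_rate", 1e-5)
paths = config.get("paths", {})
print(config["class_names"])  # ['ID1', 'ID3', 'IDCOVER']
```

In practice you would load the real file with `json.loads(Path("model_repo/config.json").read_text())`.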
## 5. Installation

### Prerequisites
- Python 3.10+ (recommended)
- CUDA GPU (optional but recommended for speed)

### Using uv (preferred fast resolver)
A `pyproject.toml` is included, so you can simply run:
```
uv sync
```
This creates or updates the virtual environment and installs dependencies.

### Using pip (alternative)
```
python -m venv .venv
.venv\Scripts\activate   # Windows; on macOS/Linux: source .venv/bin/activate
pip install -U pip
pip install -e .
```

If SAM2 is not a published package in your environment, you may need to install it from source (instructions depend on the upstream SAM2 repository; add them here once finalized).
## 6. Inference Usage (`SamSegmentator`)
Minimal example using the sample passport image:
```python
import cv2
from pathlib import Path
from src.sam_segmentator import SamSegmentator

image_path = Path("samples/sample_us_passport.jpg")
img_bgr = cv2.imread(str(image_path))  # BGR (OpenCV)

segmentator = SamSegmentator(
    checkpoint_path="checkpoints/sam2.1_hiera_base_plus_ft_ids.pt",  # or None to auto-download if configured
    pred_iou_thresh=0.88,  # forwarded to SAM2AutomaticMaskGenerator
    stability_score_thresh=0.90,
)

segments = segmentator.infer(img_bgr, pad_percent=0.05)
print(f"Total segments: {len(segments)}")

# Each segment is (crop_bgr, mask_255)
Path("outputs").mkdir(exist_ok=True)  # cv2.imwrite fails silently if the directory is missing
for i, (crop, mask) in enumerate(segments[:3]):
    cv2.imwrite(f"outputs/segment_{i}_crop.png", crop)
    cv2.imwrite(f"outputs/segment_{i}_mask.png", mask)
```
Output: pairs of tightly cropped images and their binary masks (0 background, 255 foreground), sorted by mask area in descending order.

### Parameter Notes
- `pad_percent`: Relative padding (default 5%) added around each tight bounding box.
- The deprecated `pad` (absolute pixels) is still accepted but will emit a warning.
- All additional kwargs are forwarded to `SAM2AutomaticMaskGenerator` (e.g., `box_nms_thresh`, `min_mask_region_area`).
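The effect of `pad_percent` can be sketched as follows (a simplified illustration of relative padding with clamping to the image borders, not the wrapper's exact internals):

```python
def pad_bbox(x, y, w, h, img_w, img_h, pad_percent=0.05):
    """Expand a tight bbox by a fraction of its own size, clamped to the image."""
    pad_x = int(round(w * pad_percent))
    pad_y = int(round(h * pad_percent))
    x0 = max(0, x - pad_x)
    y0 = max(0, y - pad_y)
    x1 = min(img_w, x + w + pad_x)
    y1 = min(img_h, y + h + pad_y)
    return x0, y0, x1 - x0, y1 - y0

# A 100x200 box with 5% padding grows by 5 px / 10 px per side (until clamped):
print(pad_bbox(50, 50, 100, 200, 640, 480, 0.05))  # (45, 40, 110, 220)
```

Relative padding keeps the margin proportional for both tiny fields and full‑page regions, which is why it replaced the absolute‑pixel `pad`.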
## 7. Dataset & Mask Format (For Future Training)
Expected layout (mirrors `paths` in the config):
```
data/
  train/
    images/*.jpg|png
    masks/*.png   # Single‑channel, integer indices (0=background)
  val/
    images/
    masks/
```
Class index mapping (example):
```
class_names = ["ID1", "ID3", "IDCOVER"]
0 -> background
1 -> ID1
2 -> ID3
3 -> IDCOVER
```
Masks should be stored in a lossless format (PNG) and resized, if necessary, with nearest‑neighbor interpolation. Avoid palette mismatches; explicit integer pixel values are recommended.
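A quick sanity check for masks following this convention can look like the sketch below (`validate_mask` is a hypothetical helper, assuming single‑channel arrays and the 3‑class mapping above):

```python
import numpy as np

NUM_CLASSES = 3  # ID1, ID3, IDCOVER; index 0 is background

def validate_mask(mask: np.ndarray, num_classes: int = NUM_CLASSES) -> None:
    """Raise if a mask is not single-channel or contains out-of-range class indices."""
    if mask.ndim != 2:
        raise ValueError(f"expected single-channel mask, got shape {mask.shape}")
    values = np.unique(mask)
    if values.max() > num_classes or values.min() < 0:
        raise ValueError(f"unexpected class indices: {values.tolist()}")

# Indices 0..3 cover background + the 3 classes, so this passes:
mask = np.array([[0, 1], [2, 3]], dtype=np.uint8)
validate_mask(mask)
```

Running such a check over `data/train/masks` before training catches palette‑mode PNGs and stray indices early.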
## 8. Checkpoints & Auto‑Download
`SamSegmentator` will:
1. Use the provided `checkpoint_path` if it exists.
2. If none is provided and `auto_download=True`, download the default checkpoint to `checkpoints/` using an environment‑configured URL (`SAM2_CHECKPOINT_URL`).
3. (Optional) Validate the SHA256 if `SAM2_CHECKPOINT_SHA256` is set.

Environment variables:
```
SAM2_CHECKPOINT_URL=<direct_download_url>
SAM2_CHECKPOINT_SHA256=<hex>
SAM2_CHECKPOINT_DIR=checkpoints
```
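The download‑and‑verify flow described above could look roughly like this. This is a sketch, not the wrapper's actual code; only the environment variable names are taken from the table above:

```python
import hashlib
import os
import urllib.request
from pathlib import Path

def fetch_checkpoint() -> Path:
    """Download the default checkpoint if missing, optionally verifying its SHA256."""
    url = os.environ["SAM2_CHECKPOINT_URL"]
    ckpt_dir = Path(os.environ.get("SAM2_CHECKPOINT_DIR", "checkpoints"))
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    dest = ckpt_dir / url.rsplit("/", 1)[-1]
    if not dest.exists():
        urllib.request.urlretrieve(url, dest)
    expected = os.environ.get("SAM2_CHECKPOINT_SHA256")
    if expected:
        digest = hashlib.sha256(dest.read_bytes()).hexdigest()
        if digest != expected.lower():
            raise RuntimeError(f"checksum mismatch for {dest}")
    return dest
```

Note that `read_bytes()` loads the whole file; for the ~900 MB checkpoint a chunked `hashlib.sha256` update would be kinder to memory.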
## 9. Metrics (Recommended Once Training Is Added)
- Mean IoU (per class & macro average)
- Dice coefficient
- Pixel accuracy
- Class frequency distribution (to inform potential class weighting)

Store per‑epoch metrics as JSON for reproducibility.
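The first two metrics follow directly from per‑class overlap counts; a minimal NumPy sketch (not tied to any training code in this repo):

```python
import numpy as np

def iou_dice(pred: np.ndarray, target: np.ndarray, num_classes: int):
    """Per-class IoU and Dice from integer-indexed masks (0 = background)."""
    ious, dices = [], []
    for c in range(num_classes):
        p, t = pred == c, target == c
        inter = np.logical_and(p, t).sum()
        union = np.logical_or(p, t).sum()
        ious.append(inter / union if union else float("nan"))
        denom = p.sum() + t.sum()
        dices.append(2 * inter / denom if denom else float("nan"))
    return ious, dices

pred = np.array([[0, 1], [1, 1]])
target = np.array([[0, 1], [1, 0]])
ious, dices = iou_dice(pred, target, num_classes=2)
# class 1: intersection=2, union=3 -> IoU = 2/3; Dice = 2*2/(3+2) = 0.8
```

Classes absent from both masks yield NaN rather than a misleading 0 or 1, so use `np.nanmean` for the macro average.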
## 10. Limitations & Risks
Technical:
- The current version does not include a fine‑tuning script; only the inference wrapper is provided.
- The automatic mask generator is class‑agnostic; without fine‑tuning it may over‑segment or miss tiny fields.

Ethical / Compliance:
- Processing ID documents may involve PII; ensure secure storage and compliant handling.
- Not intended for biometric decisions or identity verification pipelines without human oversight.
## 11. Roadmap
- [ ] Add training script (supervised fine‑tuning using `config.json`).
- [ ] Optional class‑guided prompting (points / boxes) pipeline.
- [ ] Export to ONNX / TorchScript.
- [ ] CLI interface for batch folder inference.
- [ ] Lightweight web demo (Gradio / FastAPI).

## 12. License & Citation
Specify a license in a top‑level `LICENSE` file (e.g., MIT or Apache‑2.0), ensuring compatibility with SAM2's original license.

Please cite SAM / SAM2 in academic work. Example (placeholder):
```
@article{kirillov2023segmentanything,
  title={Segment Anything},
  author={Kirillov, Alexander and others},
  journal={arXiv preprint arXiv:2304.02643},
  year={2023}
}
```
Add the updated SAM2 citation once the official reference is finalized.

## Acknowledgments
- Meta AI for releasing Segment Anything & SAM2.
- OpenCV, PyTorch, and the broader CV community.

---
If you have questions or need feature prioritization, open an Issue or start a Discussion.
config.json ADDED
{
  "model_type": "sam2",
  "checkpoint_path": "weights/sam2_base.pth",
  "image_size": [1024, 1024],
  "num_classes": 10,
  "class_names": [
    "ID1",
    "ID3",
    "IDCOVER"
  ],
  "input_channels": 3,
  "learning_rate": 1e-5,
  "weight_decay": 0.01,
  "batch_size": 2,
  "gradient_accumulation_steps": 8,
  "num_epochs": 100,
  "optimizer": "adamw",
  "lr_scheduler": {
    "type": "cosine",
    "warmup_epochs": 5,
    "min_lr": 1e-7
  },
  "loss": {
    "primary": "cross_entropy",
    "auxiliary": ["dice"],
    "dice_smooth": 1.0,
    "class_weights": null
  },
  "mixed_precision": true,
  "early_stopping": {
    "patience": 15,
    "metric": "val_loss",
    "mode": "min"
  },
  "dropout_rate": 0.0,
  "augmentation": {
    "horizontal_flip": true,
    "vertical_flip": false,
    "rotation_deg": 15,
    "random_crop": true,
    "scale_range": [0.9, 1.1],
    "brightness": 0.1,
    "contrast": 0.1,
    "color_jitter_prob": 0.3
  },
  "normalization": {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225]
  },
  "dataloader": {
    "num_workers": 4,
    "pin_memory": true,
    "shuffle": true
  },
  "paths": {
    "train_images": "data/train/images",
    "train_masks": "data/train/masks",
    "val_images": "data/val/images",
    "val_masks": "data/val/masks",
    "output_dir": "outputs"
  },
  "logging": {
    "log_interval": 50,
    "save_checkpoint_every": 1
  },
  "seed": 42
}
processor_config.json ADDED
{
  "preprocessing": {
    "resize": {
      "height": 256,
      "width": 256
    },
    "normalization": {
      "mean": [0.485, 0.456, 0.406],
      "std": [0.229, 0.224, 0.225]
    },
    "augmentation": {
      "random_flip": true,
      "random_crop": {
        "height": 224,
        "width": 224
      }
    }
  },
  "tokenization": {
    "do_lower_case": true,
    "max_length": 512,
    "padding": "max_length"
  }
}
sam2.1_hiera_base_plus_ft_ids.pt ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:64f76f41204b7694ea59200d85d8b742e1808532aa063118d3d043d79aa285b3
size 910662494
sam_checkpoint.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:c30cc8d0758ccf4154a7857ae971917f379a2b781a4149c88c3b2d1bc654a452
size 40