Image Classification
mlx-image
Safetensors
MLX
vision
riccardomusmeci commited on
Commit
c717703
·
verified ·
1 Parent(s): b856510

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +56 -59
README.md CHANGED
@@ -1,81 +1,78 @@
1
- ---
2
- {}
3
  ---
4
 
5
- ---
6
- license: apache-2.0
7
- tags:
8
- - mlx
9
- - mlx-image
10
- - vision
11
- - image-classification
12
- datasets:
13
- - imagenet-1k
14
- library_name: mlx-image
15
- ---
16
- # vit_small_patch8_224.dino
17
 
18
- A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).
19
 
20
- The model was trained in self-supervised fashion on ImageNet-1k dataset. No classification head was trained, only the backbone.
21
 
22
- Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
23
 
24
- <div align="center">
25
- <img width="100%" alt="DINO illustration" src="dino.gif">
26
- </div>
27
 
28
 
29
- ## How to use
30
- ```bash
31
- pip install mlx-image
32
- ```
33
 
34
- Here is how to use this model for image classification:
35
 
36
- ```python
37
- from mlxim.model import create_model
38
- from mlxim.io import read_rgb
39
- from mlxim.transform import ImageNetTransform
40
 
41
- transform = ImageNetTransform(train=False, img_size=224)
42
- x = transform(read_rgb("cat.png"))
43
- x = mx.expand_dims(x, 0)
44
 
45
- model = create_model("vit_small_patch8_224.dino")
46
- model.eval()
47
 
48
- logits, attn_masks = model(x, attn_masks=True)
49
- ```
50
 
51
- You can also use the embeds from layer before head:
52
- ```python
53
- from mlxim.model import create_model
54
- from mlxim.io import read_rgb
55
- from mlxim.transform import ImageNetTransform
56
 
57
- transform = ImageNetTransform(train=False, img_size=512)
58
- x = transform(read_rgb("cat.png"))
59
- x = mx.expand_dims(x, 0)
60
 
61
- # first option
62
- model = create_model("vit_small_patch8_224.dino", num_classes=0)
63
- model.eval()
64
 
65
- embeds = model(x)
66
 
67
- # second option
68
- model = create_model("vit_small_patch8_224.dino")
69
- model.eval()
70
 
71
- embeds, attn_masks = model.get_features(x)
72
- ```
73
 
74
- ## Attention maps
75
- You can visualize the attention maps using the `attn_masks` returned by the model. Go check the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/blob/main/notebooks/dino_attention.ipynb).
76
 
77
- <div align="center">
78
- <img width="100%" alt="Attention Map" src="attention_maps.png">
79
- </div>
80
 
81
-
 
 
 
1
  ---
2
 
3
+ ---
4
+ license: apache-2.0
5
+ tags:
6
+ - mlx
7
+ - mlx-image
8
+ - vision
9
+ - image-classification
10
+ datasets:
11
+ - imagenet-1k
12
+ library_name: mlx-image
13
+ ---
14
+ # vit_small_patch8_224.dino
15
 
16
+ A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).
17
 
18
+ The model was trained in self-supervised fashion on ImageNet-1k dataset. No classification head was trained, only the backbone.
19
 
20
+ Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
21
 
22
+ <div align="center">
23
+ <img width="100%" alt="DINO illustration" src="dino.gif">
24
+ </div>
25
 
26
 
27
+ ## How to use
28
+ ```bash
29
+ pip install mlx-image
30
+ ```
31
 
32
+ Here is how to use this model for image classification:
33
 
34
+ ```python
35
+ from mlxim.model import create_model
36
+ from mlxim.io import read_rgb
37
+ from mlxim.transform import ImageNetTransform
38
 
39
+ transform = ImageNetTransform(train=False, img_size=224)
40
+ x = transform(read_rgb("cat.png"))
41
+ x = mx.expand_dims(x, 0)
42
 
43
+ model = create_model("vit_small_patch8_224.dino")
44
+ model.eval()
45
 
46
+ logits, attn_masks = model(x, attn_masks=True)
47
+ ```
48
 
49
+ You can also use the embeds from layer before head:
50
+ ```python
51
+ from mlxim.model import create_model
52
+ from mlxim.io import read_rgb
53
+ from mlxim.transform import ImageNetTransform
54
 
55
+ transform = ImageNetTransform(train=False, img_size=512)
56
+ x = transform(read_rgb("cat.png"))
57
+ x = mx.expand_dims(x, 0)
58
 
59
+ # first option
60
+ model = create_model("vit_small_patch8_224.dino", num_classes=0)
61
+ model.eval()
62
 
63
+ embeds = model(x)
64
 
65
+ # second option
66
+ model = create_model("vit_small_patch8_224.dino")
67
+ model.eval()
68
 
69
+ embeds, attn_masks = model.get_features(x)
70
+ ```
71
 
72
+ ## Attention maps
73
+ You can visualize the attention maps using the `attn_masks` returned by the model. Go check the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/blob/main/notebooks/dino_attention.ipynb).
74
 
75
+ <div align="center">
76
+ <img width="100%" alt="Attention Map" src="attention_maps.png">
77
+ </div>
78