Update model config and README
Browse files- README.md +21 -17
- model.safetensors +3 -0
    	
        README.md
    CHANGED
    
    | @@ -2,7 +2,7 @@ | |
| 2 | 
             
            tags:
         | 
| 3 | 
             
            - image-classification
         | 
| 4 | 
             
            - timm
         | 
| 5 | 
            -
             | 
| 6 | 
             
            license: apache-2.0
         | 
| 7 | 
             
            datasets:
         | 
| 8 | 
             
            - imagenet-1k
         | 
| @@ -14,7 +14,7 @@ A timm specific MaxViT (w/ a MLP Log-CPB (continuous log-coordinate relative pos | |
| 14 |  | 
| 15 | 
             
            ImageNet-12k pretraining and ImageNet-1k fine-tuning performed on 8x GPU [Lambda Labs](https://lambdalabs.com/) cloud instances.
         | 
| 16 |  | 
| 17 | 
            -
            ### Model Variants in [maxxvit.py](https://github.com/ | 
| 18 |  | 
| 19 | 
             
            MaxxViT covers a number of related model architectures that share a common structure including:
         | 
| 20 | 
             
            - CoAtNet - Combining MBConv (depthwise-separable) convolutional blocks in early stages with self-attention transformer blocks in later stages.
         | 
| @@ -46,8 +46,9 @@ from urllib.request import urlopen | |
| 46 | 
             
            from PIL import Image
         | 
| 47 | 
             
            import timm
         | 
| 48 |  | 
| 49 | 
            -
            img = Image.open(
         | 
| 50 | 
            -
                 | 
|  | |
| 51 |  | 
| 52 | 
             
            model = timm.create_model('maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k', pretrained=True)
         | 
| 53 | 
             
            model = model.eval()
         | 
| @@ -67,8 +68,9 @@ from urllib.request import urlopen | |
| 67 | 
             
            from PIL import Image
         | 
| 68 | 
             
            import timm
         | 
| 69 |  | 
| 70 | 
            -
            img = Image.open(
         | 
| 71 | 
            -
                 | 
|  | |
| 72 |  | 
| 73 | 
             
            model = timm.create_model(
         | 
| 74 | 
             
                'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k',
         | 
| @@ -85,12 +87,13 @@ output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batc | |
| 85 |  | 
| 86 | 
             
            for o in output:
         | 
| 87 | 
             
                # print shape of each feature map in output
         | 
| 88 | 
            -
                # e.g.: | 
| 89 | 
            -
                #  torch.Size([1,  | 
| 90 | 
            -
                #  torch.Size([1,  | 
| 91 | 
            -
                #  torch.Size([1,  | 
| 92 | 
            -
                #  torch.Size([1,  | 
| 93 | 
            -
                #  torch.Size([1,  | 
|  | |
| 94 | 
             
                print(o.shape)
         | 
| 95 | 
             
            ```
         | 
| 96 |  | 
| @@ -100,8 +103,9 @@ from urllib.request import urlopen | |
| 100 | 
             
            from PIL import Image
         | 
| 101 | 
             
            import timm
         | 
| 102 |  | 
| 103 | 
            -
            img = Image.open(
         | 
| 104 | 
            -
                 | 
|  | |
| 105 |  | 
| 106 | 
             
            model = timm.create_model(
         | 
| 107 | 
             
                'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k',
         | 
| @@ -119,10 +123,10 @@ output = model(transforms(img).unsqueeze(0))  # output is (batch_size, num_featu | |
| 119 | 
             
            # or equivalently (without needing to set num_classes=0)
         | 
| 120 |  | 
| 121 | 
             
            output = model.forward_features(transforms(img).unsqueeze(0))
         | 
| 122 | 
            -
            # output is unpooled  | 
| 123 |  | 
| 124 | 
             
            output = model.forward_head(output, pre_logits=True)
         | 
| 125 | 
            -
            # output is ( | 
| 126 | 
             
            ```
         | 
| 127 |  | 
| 128 | 
             
            ## Model Comparison
         | 
| @@ -230,7 +234,7 @@ output = model.forward_head(output, pre_logits=True) | |
| 230 | 
             
              publisher = {GitHub},
         | 
| 231 | 
             
              journal = {GitHub repository},
         | 
| 232 | 
             
              doi = {10.5281/zenodo.4414861},
         | 
| 233 | 
            -
              howpublished = {\url{https://github.com/ | 
| 234 | 
             
            }
         | 
| 235 | 
             
            ```
         | 
| 236 | 
             
            ```bibtex
         | 
|  | |
| 2 | 
             
            tags:
         | 
| 3 | 
             
            - image-classification
         | 
| 4 | 
             
            - timm
         | 
| 5 | 
            +
            library_name: timm
         | 
| 6 | 
             
            license: apache-2.0
         | 
| 7 | 
             
            datasets:
         | 
| 8 | 
             
            - imagenet-1k
         | 
|  | |
| 14 |  | 
| 15 | 
             
            ImageNet-12k pretraining and ImageNet-1k fine-tuning performed on 8x GPU [Lambda Labs](https://lambdalabs.com/) cloud instances.
         | 
| 16 |  | 
| 17 | 
            +
            ### Model Variants in [maxxvit.py](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/maxxvit.py)
         | 
| 18 |  | 
| 19 | 
             
            MaxxViT covers a number of related model architectures that share a common structure including:
         | 
| 20 | 
             
            - CoAtNet - Combining MBConv (depthwise-separable) convolutional blocks in early stages with self-attention transformer blocks in later stages.
         | 
|  | |
| 46 | 
             
            from PIL import Image
         | 
| 47 | 
             
            import timm
         | 
| 48 |  | 
| 49 | 
            +
            img = Image.open(urlopen(
         | 
| 50 | 
            +
                'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
         | 
| 51 | 
            +
            ))
         | 
| 52 |  | 
| 53 | 
             
            model = timm.create_model('maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k', pretrained=True)
         | 
| 54 | 
             
            model = model.eval()
         | 
|  | |
| 68 | 
             
            from PIL import Image
         | 
| 69 | 
             
            import timm
         | 
| 70 |  | 
| 71 | 
            +
            img = Image.open(urlopen(
         | 
| 72 | 
            +
                'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
         | 
| 73 | 
            +
            ))
         | 
| 74 |  | 
| 75 | 
             
            model = timm.create_model(
         | 
| 76 | 
             
                'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k',
         | 
|  | |
| 87 |  | 
| 88 | 
             
            for o in output:
         | 
| 89 | 
             
                # print shape of each feature map in output
         | 
| 90 | 
            +
                # e.g.:
         | 
| 91 | 
            +
                #  torch.Size([1, 64, 192, 192])
         | 
| 92 | 
            +
                #  torch.Size([1, 96, 96, 96])
         | 
| 93 | 
            +
                #  torch.Size([1, 192, 48, 48])
         | 
| 94 | 
            +
                #  torch.Size([1, 384, 24, 24])
         | 
| 95 | 
            +
                #  torch.Size([1, 768, 12, 12])
         | 
| 96 | 
            +
             | 
| 97 | 
             
                print(o.shape)
         | 
| 98 | 
             
            ```
         | 
| 99 |  | 
|  | |
| 103 | 
             
            from PIL import Image
         | 
| 104 | 
             
            import timm
         | 
| 105 |  | 
| 106 | 
            +
            img = Image.open(urlopen(
         | 
| 107 | 
            +
                'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
         | 
| 108 | 
            +
            ))
         | 
| 109 |  | 
| 110 | 
             
            model = timm.create_model(
         | 
| 111 | 
             
                'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k',
         | 
|  | |
| 123 | 
             
            # or equivalently (without needing to set num_classes=0)
         | 
| 124 |  | 
| 125 | 
             
            output = model.forward_features(transforms(img).unsqueeze(0))
         | 
| 126 | 
            +
            # output is unpooled, a (1, 768, 12, 12) shaped tensor
         | 
| 127 |  | 
| 128 | 
             
            output = model.forward_head(output, pre_logits=True)
         | 
| 129 | 
            +
            # output is a (1, num_features) shaped tensor
         | 
| 130 | 
             
            ```
         | 
| 131 |  | 
| 132 | 
             
            ## Model Comparison
         | 
|  | |
| 234 | 
             
              publisher = {GitHub},
         | 
| 235 | 
             
              journal = {GitHub repository},
         | 
| 236 | 
             
              doi = {10.5281/zenodo.4414861},
         | 
| 237 | 
            +
              howpublished = {\url{https://github.com/huggingface/pytorch-image-models}}
         | 
| 238 | 
             
            }
         | 
| 239 | 
             
            ```
         | 
| 240 | 
             
            ```bibtex
         | 
    	
        model.safetensors
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:36937552a45c431fbd34e77456c0c2f44c0874632bed47f28c650c1a0fbf8821
         | 
| 3 | 
            +
            size 465234858
         | 

