sadhaklal commited on
Commit
d6c1cae
·
verified ·
1 Parent(s): df87f01

added "Usage" section in README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -1
README.md CHANGED
@@ -19,6 +19,60 @@ Code: https://github.com/sambitmukherjee/dlwpt-exercises/blob/main/chapter_7/exe
19
 
20
  Experiment tracking: https://wandb.ai/sadhaklal/mlp-cifar2-v2
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ## Metric
23
 
24
- Accuracy on cifar2_val: 0.829
 
19
 
20
  Experiment tracking: https://wandb.ai/sadhaklal/mlp-cifar2-v2
21
 
22
+ ## Usage
23
+
24
+ !pip install -q datasets
25
+
26
+ from datasets import load_dataset
27
+
28
+ cifar10 = load_dataset("cifar10")
29
+ label_map = {0: 0.0, 2: 1.0}
30
+ class_names = ['airplane', 'bird']
31
+ cifar2_train = [(example['img'], label_map[example['label']]) for example in cifar10['train'] if example['label'] in [0, 2]]
32
+ cifar2_val = [(example['img'], label_map[example['label']]) for example in cifar10['test'] if example['label'] in [0, 2]]
33
+
34
+ example = cifar2_val[0]
35
+ img, label = example
36
+
37
+ import torch
38
+ from torchvision.transforms import v2
39
+
40
+ val_tfms = v2.Compose([
41
+ v2.ToImage(),
42
+ v2.ToDtype(torch.float32, scale=True),
43
+ v2.Normalize(mean=[0.4915, 0.4823, 0.4468], std=[0.2470, 0.2435, 0.2616])
44
+ ])
45
+ img = val_tfms(img)
46
+ batch = img.reshape(-1).unsqueeze(0) # Flatten.
47
+
48
+ import torch.nn as nn
49
+ from huggingface_hub import PyTorchModelHubMixin
50
+
51
+ class MLPForCIFAR2(nn.Module, PyTorchModelHubMixin):
52
+ def __init__(self):
53
+ super().__init__()
54
+ self.mlp = nn.Sequential(
55
+ nn.Linear(3072, 64), # Hidden layer.
56
+ nn.Tanh(),
57
+ nn.Linear(64, 1) # Output layer.
58
+ )
59
+
60
+ def forward(self, x):
61
+ return self.mlp(x)
62
+
63
+ model = MLPForCIFAR2.from_pretrained("sadhaklal/mlp-cifar2-v2")
64
+ model.eval()
65
+
66
+ import torch.nn.functional as F
67
+
68
+ with torch.no_grad():
69
+ logits = model(batch)
70
+ proba = F.sigmoid(logits.squeeze())
71
+ pred = int(proba.item() > 0.5)
72
+
73
+ print(f"Predicted class: {class_names[pred]}")
74
+ print(f"Predicted class probabilities ('airplane' vs. 'bird'): {[proba.item(), 1 - proba.item()]}")
75
+
76
  ## Metric
77
 
78
+ Accuracy on `cifar2_val`: 0.829