Tasfiya025 commited on
Commit
3afd67c
·
verified ·
1 Parent(s): 3f4c0bb

Create train_infer.py

Browse files
Files changed (1) hide show
  1. train_infer.py +84 -0
train_infer.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train_infer.py
2
+ # Usage:
3
+ # Train: python train_infer.py --mode train --train_dir data/train --val_dir data/val --epochs 3
4
+ # Predict: python train_infer.py --mode predict --image img.jpg --ckpt ckpt.pth
5
+
6
+ import argparse, os
7
+ from pathlib import Path
8
+ from PIL import Image
9
+ import torch, torch.nn as nn, torch.optim as optim
10
+ from torchvision import transforms, models
11
+ from torchvision.datasets import ImageFolder
12
+ from torch.utils.data import DataLoader
13
+
14
+ def get_loaders(train_dir, val_dir, batch=16, img=224):
15
+ tr = transforms.Compose([transforms.RandomResizedCrop(img), transforms.RandomHorizontalFlip(),
16
+ transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
17
+ val = transforms.Compose([transforms.Resize(int(img*1.14)), transforms.CenterCrop(img),
18
+ transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
19
+ train_ds = ImageFolder(train_dir, transform=tr)
20
+ val_ds = ImageFolder(val_dir, transform=val)
21
+ return DataLoader(train_ds, batch_size=batch, shuffle=True), DataLoader(val_ds, batch_size=batch), train_ds.classes
22
+
23
+ def train(args):
24
+ train_loader, val_loader, classes = get_loaders(args.train_dir, args.val_dir, args.batch)
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ model = models.resnet18(pretrained=True)
27
+ model.fc = nn.Linear(model.fc.in_features, len(classes))
28
+ model.to(device)
29
+ opt = optim.Adam(model.parameters(), lr=1e-4)
30
+ loss_fn = nn.CrossEntropyLoss()
31
+ best=0
32
+ for e in range(args.epochs):
33
+ model.train()
34
+ for imgs, lbl in train_loader:
35
+ imgs, lbl = imgs.to(device), lbl.to(device)
36
+ opt.zero_grad()
37
+ out = model(imgs)
38
+ loss = loss_fn(out, lbl)
39
+ loss.backward(); opt.step()
40
+ # val
41
+ model.eval()
42
+ correct=total=0
43
+ with torch.no_grad():
44
+ for imgs,lbl in val_loader:
45
+ imgs,lbl = imgs.to(device), lbl.to(device)
46
+ out = model(imgs).argmax(1)
47
+ correct += (out==lbl).sum().item(); total += lbl.size(0)
48
+ acc = correct/total if total else 0
49
+ print(f"Epoch {e+1}/{args.epochs} val_acc={acc:.4f}")
50
+ if acc>best:
51
+ best=acc
52
+ torch.save({'model':model.state_dict(),'classes':classes}, args.ckpt)
53
+ print("Best val acc", best)
54
+
55
+ def predict(args):
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ ck = torch.load(args.ckpt, map_location=device)
58
+ classes = ck['classes']
59
+ model = models.resnet18(pretrained=False)
60
+ model.fc = nn.Linear(model.fc.in_features, len(classes))
61
+ model.load_state_dict(ck['model']); model.to(device).eval()
62
+ tf = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224),
63
+ transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
64
+ img = Image.open(args.image).convert('RGB')
65
+ x = tf(img).unsqueeze(0).to(device)
66
+ with torch.no_grad():
67
+ p = model(x).argmax(1).item()
68
+ print("Predicted:", classes[p])
69
+
70
+ if __name__=="__main__":
71
+ p=argparse.ArgumentParser()
72
+ p.add_argument('--mode', choices=['train','predict'], required=True)
73
+ p.add_argument('--train_dir', default='data/train')
74
+ p.add_argument('--val_dir', default='data/val')
75
+ p.add_argument('--epochs', type=int, default=3)
76
+ p.add_argument('--batch', type=int, default=16)
77
+ p.add_argument('--ckpt', default='ckpt.pth')
78
+ p.add_argument('--image')
79
+ args=p.parse_args()
80
+ if args.mode=='train':
81
+ train(args)
82
+ else:
83
+ if not args.image: raise SystemExit("Provide --image for predict")
84
+ predict(args)