# image-classifier-simple / train_infer.py
# Uploaded by Tasfiya025
# Commit message: Create train_infer.py
# Commit 3afd67c (verified)
# train_infer.py
# Usage:
# Train: python train_infer.py --mode train --train_dir data/train --val_dir data/val --epochs 3
# Predict: python train_infer.py --mode predict --image img.jpg --ckpt ckpt.pth
import argparse, os
from pathlib import Path
from PIL import Image
import torch, torch.nn as nn, torch.optim as optim
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
def get_loaders(train_dir, val_dir, batch=16, img=224):
    """Build the training and validation data loaders.

    Both directories are read with ``ImageFolder``, i.e. each is expected to
    contain one sub-directory per class.

    Args:
        train_dir: Root directory of the training images.
        val_dir: Root directory of the validation images.
        batch: Batch size used by both loaders.
        img: Side length, in pixels, of the square crop fed to the model.

    Returns:
        A ``(train_loader, val_loader, classes)`` tuple, where ``classes`` is
        the list of class names discovered in ``train_dir``.

    Raises:
        ValueError: If the two directories do not map class names to the same
            label indices.
    """
    # Same normalisation constants for both pipelines (the values commonly
    # used with torchvision's pretrained models).
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Training pipeline: random crop + flip as light augmentation.
    tr = transforms.Compose([
        transforms.RandomResizedCrop(img),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation pipeline: deterministic resize + centre crop. The resize uses
    # the 256/224 ratio so that the default img=224 yields Resize(256), exactly
    # what predict() applies; int(img * 1.14) truncated to 255.
    resize_to = int(round(img * 256 / 224))
    val = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.CenterCrop(img),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = ImageFolder(train_dir, transform=tr)
    val_ds = ImageFolder(val_dir, transform=val)

    # Label indices are derived from the folder names of each directory
    # independently. If the two sets differ, validation labels would silently
    # refer to different classes than the model's outputs.
    if train_ds.class_to_idx != val_ds.class_to_idx:
        raise ValueError(
            f"Class folders differ between {train_dir!r} ({train_ds.classes}) "
            f"and {val_dir!r} ({val_ds.classes})"
        )

    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch)
    return train_loader, val_loader, train_ds.classes
def train(args):
    """Fine-tune a ResNet-18 classifier and checkpoint the best epoch.

    After every epoch the model is evaluated on the validation loader; the
    checkpoint at ``args.ckpt`` is overwritten whenever validation accuracy
    improves. The first evaluated epoch is always saved, so a checkpoint
    exists even if accuracy never rises above zero.

    Args:
        args: Parsed command-line namespace. Reads ``train_dir``, ``val_dir``,
            ``batch``, ``epochs`` and ``ckpt``.

    Side effects:
        Prints the validation accuracy per epoch and the best accuracy at the
        end; writes the checkpoint file ``args.ckpt``.
    """
    train_loader, val_loader, classes = get_loaders(args.train_dir, args.val_dir, args.batch)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): `pretrained=` is deprecated in recent torchvision releases
    # in favour of `weights=`; kept as-is so older installations keep working.
    model = models.resnet18(pretrained=True)
    # Replace the final layer so the output size matches our class count.
    model.fc = nn.Linear(model.fc.in_features, len(classes))
    model.to(device)

    opt = optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    best = 0
    saved = False  # becomes True once a checkpoint has been written
    for e in range(args.epochs):
        # --- training pass ---
        model.train()
        for imgs, lbl in train_loader:
            imgs, lbl = imgs.to(device), lbl.to(device)
            opt.zero_grad()
            out = model(imgs)
            loss = loss_fn(out, lbl)
            loss.backward()
            opt.step()

        # --- validation pass ---
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for imgs, lbl in val_loader:
                imgs, lbl = imgs.to(device), lbl.to(device)
                out = model(imgs).argmax(1)
                correct += (out == lbl).sum().item()
                total += lbl.size(0)
        # Guard against an empty validation set.
        acc = correct / total if total else 0
        print(f"Epoch {e+1}/{args.epochs} val_acc={acc:.4f}")

        # Save on improvement, and unconditionally the first time: with a
        # strict `acc > best` alone, a run whose accuracy stays at 0 would
        # never produce a checkpoint for predict() to load.
        if acc > best or not saved:
            best = acc
            saved = True
            # predict() reads both keys: the weights and the class-name list.
            torch.save({'model': model.state_dict(), 'classes': classes}, args.ckpt)
    print("Best val acc", best)
def predict(args):
    """Classify a single image with a checkpoint written by ``train``.

    Args:
        args: Parsed command-line namespace. Reads ``ckpt`` (checkpoint path)
            and ``image`` (path of the image to classify).

    Side effects:
        Prints ``Predicted: <class name>`` to standard output.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    ck = torch.load(args.ckpt, map_location=device)
    classes = ck['classes']

    # Rebuild the architecture used in train(); the weights come from the
    # checkpoint, so no pretrained download is needed here.
    model = models.resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, len(classes))
    model.load_state_dict(ck['model'])
    model.to(device).eval()

    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Open inside a context manager so the underlying file handle is closed
    # promptly; convert() returns a fully loaded copy that outlives the block.
    with Image.open(args.image) as raw:
        img = raw.convert('RGB')

    # Add a leading batch dimension of size 1 before the forward pass.
    x = tf(img).unsqueeze(0).to(device)
    with torch.no_grad():
        p = model(x).argmax(1).item()
    print("Predicted:", classes[p])
if __name__ == "__main__":
    # Command-line entry point: parse the options, then dispatch on --mode.
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'predict'], required=True)
    parser.add_argument('--train_dir', default='data/train')
    parser.add_argument('--val_dir', default='data/val')
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch', type=int, default=16)
    parser.add_argument('--ckpt', default='ckpt.pth')
    parser.add_argument('--image')
    args = parser.parse_args()

    # `choices` restricts --mode to exactly 'train' or 'predict'.
    if args.mode == 'predict':
        # Prediction needs an input image; abort with a message otherwise.
        if not args.image:
            raise SystemExit("Provide --image for predict")
        predict(args)
    else:
        train(args)