# train_infer.py
# Usage:
#   Train:   python train_infer.py --mode train --train_dir data/train --val_dir data/val --epochs 3
#   Predict: python train_infer.py --mode predict --image img.jpg --ckpt ckpt.pth
import argparse
import os
from pathlib import Path

from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Per-channel mean/std handed to transforms.Normalize. Defined once so the
# training, validation and prediction pipelines cannot drift apart.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _build_resnet18(num_classes, pretrained):
    """Create a ResNet-18 whose final layer is replaced by a num_classes-way Linear."""
    if hasattr(models, "ResNet18_Weights"):
        # Newer torchvision deprecates `pretrained=` in favour of `weights=`;
        # IMAGENET1K_V1 is the weight set the legacy `pretrained=True` maps to.
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        # Older torchvision only understands the legacy keyword.
        model = models.resnet18(pretrained=pretrained)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of model over loader (0 when loader yields nothing)."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, lbl in loader:
            imgs, lbl = imgs.to(device), lbl.to(device)
            pred = model(imgs).argmax(1)
            correct += (pred == lbl).sum().item()
            total += lbl.size(0)
    return correct / total if total else 0


def get_loaders(train_dir, val_dir, batch=16, img=224):
    """Build the training and validation data loaders.

    Args:
        train_dir: Root directory of the training images (read with ImageFolder).
        val_dir: Root directory of the validation images (read with ImageFolder).
        batch: Batch size used by both loaders.
        img: Side length of the square crop produced by both pipelines.

    Returns:
        A tuple ``(train_loader, val_loader, classes)`` where ``classes`` is
        the class list of the training dataset.

    Raises:
        ValueError: If the two directories do not yield the same class list.
    """
    normalize = transforms.Normalize(_NORM_MEAN, _NORM_STD)
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(img),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize(int(img * 1.14)),
        transforms.CenterCrop(img),
        transforms.ToTensor(),
        normalize,
    ])
    train_ds = ImageFolder(train_dir, transform=train_tf)
    val_ds = ImageFolder(val_dir, transform=val_tf)
    # Each ImageFolder assigns label indices from its own class list, so
    # differing lists would make validation labels disagree with the model's
    # outputs without any error being raised.
    if val_ds.classes != train_ds.classes:
        raise ValueError(
            f"Class folders differ between {train_dir!r} and {val_dir!r}: "
            f"{train_ds.classes} vs {val_ds.classes}"
        )
    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch)
    return train_loader, val_loader, train_ds.classes


def train(args):
    """Fine-tune a pretrained ResNet-18 and checkpoint the best epoch.

    Reads ``train_dir``, ``val_dir``, ``batch``, ``epochs`` and ``ckpt`` from
    ``args``. After every epoch the validation accuracy is printed; whenever it
    improves, the model weights and class list are written to ``args.ckpt``.
    """
    train_loader, val_loader, classes = get_loaders(args.train_dir, args.val_dir, args.batch)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _build_resnet18(len(classes), pretrained=True)
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    # Create the checkpoint directory up front so a missing folder fails
    # before training starts rather than at the first save.
    Path(args.ckpt).parent.mkdir(parents=True, exist_ok=True)

    best = 0
    saved = False
    for e in range(args.epochs):
        model.train()
        for imgs, lbl in train_loader:
            imgs, lbl = imgs.to(device), lbl.to(device)
            opt.zero_grad()
            out = model(imgs)
            loss = loss_fn(out, lbl)
            loss.backward()
            opt.step()

        acc = _evaluate(model, val_loader, device)
        print(f"Epoch {e+1}/{args.epochs} val_acc={acc:.4f}")

        improved = acc > best
        if improved:
            best = acc
        # Always save the first evaluated epoch: with a strict `acc > best`
        # alone, a run whose accuracy never rises above 0 would leave no
        # checkpoint for predict mode to load.
        if improved or not saved:
            torch.save({'model': model.state_dict(), 'classes': classes}, args.ckpt)
            saved = True
    print("Best val acc", best)


def predict(args):
    """Classify the image at ``args.image`` with the checkpoint at ``args.ckpt``.

    Prints the predicted class name taken from the class list stored in the
    checkpoint.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: torch.load is pickle-based; on PyTorch versions where
    # weights_only=True is not the default, only load trusted checkpoints.
    ck = torch.load(args.ckpt, map_location=device)
    classes = ck['classes']
    model = _build_resnet18(len(classes), pretrained=False)
    model.load_state_dict(ck['model'])
    model.to(device).eval()
    # NOTE(review): validation resizes to int(224 * 1.14) == 255 while this
    # pipeline resizes to 256; left unchanged here, confirm which is intended.
    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])
    # Context manager releases the underlying file handle once decoded.
    with Image.open(args.image) as raw:
        img = raw.convert('RGB')
    x = tf(img).unsqueeze(0).to(device)
    with torch.no_grad():
        p = model(x).argmax(1).item()
    print("Predicted:", classes[p])


def main():
    """Parse the command line and run the requested mode."""
    p = argparse.ArgumentParser()
    p.add_argument('--mode', choices=['train', 'predict'], required=True)
    p.add_argument('--train_dir', default='data/train')
    p.add_argument('--val_dir', default='data/val')
    p.add_argument('--epochs', type=int, default=3)
    p.add_argument('--batch', type=int, default=16)
    p.add_argument('--ckpt', default='ckpt.pth')
    p.add_argument('--image')
    args = p.parse_args()
    if args.mode == 'train':
        train(args)
    else:
        if not args.image:
            raise SystemExit("Provide --image for predict")
        predict(args)


if __name__ == "__main__":
    main()