import gradio as gr
import timm
import torch
from cods.classif.cp import ClassificationConformalizer
from cods.classif.data import ClassificationDataset
from cods.classif.data.predictions import ClassificationPredictions
from cods.classif.models import ClassificationModel
from datasets import load_dataset
from PIL import Image

from dataset import DatasetWrapper

# Hugging Face dataset identifiers, keyed by the name shown in the UI.
DATASETS = {
    "miniimagenet": "timm/mini-imagenet",
    "imagenette": "frgfm/imagenette",
    "imagenet": "imagenet-1k",
}

# Candidate models per dataset, listed in the model dropdown.
# NOTE(review): only "miniimagenet" has entries, and ``calibrate`` currently
# ignores the selected model in favour of BACKBONE_NAME.
MODELS = {
    "miniimagenet": [
        "QuentinJG/ResNet18-miniimagenet",
        "shahrukhx01/vit-base-patch16-miniimagenet",
    ],
}

# timm backbone actually used for calibration and inference.
BACKBONE_NAME = "resnet34"

# Miscoverage level passed to the conformalizer; the UI reports 1 - ALPHA.
ALPHA = 0.1

classification_conformalizer = ClassificationConformalizer(method="lac", preprocess="softmax")

# Set by ``calibrate`` once calibration has succeeded; ``predict_image``
# refuses to run while either is still None.
pretrained_resnet_34 = None
dataset = None


def calibrate(dataset_name, model_name):
    """Calibrate the shared conformalizer on a dataset's validation split.

    Args:
        dataset_name: Key of ``DATASETS`` selecting the Hugging Face dataset.
        model_name: Model chosen in the UI. Currently ignored: the timm
            backbone ``BACKBONE_NAME`` is always used instead.

    Returns:
        A status message for the UI naming the dataset and the backbone used.
    """
    global pretrained_resnet_34, dataset

    # NOTE(review): the UI's model choice is overridden here (as in the
    # original code); the entries in MODELS are never loaded.
    model_name = BACKBONE_NAME
    backbone = timm.create_model(model_name, pretrained=True)
    classifier = ClassificationModel(model=backbone, model_name=model_name)

    cal_dataset = DatasetWrapper(load_dataset(DATASETS[dataset_name], split="validation"))
    val_preds = classifier.build_predictions(
        cal_dataset,
        dataset_name=dataset_name,
        split_name="cal",
        batch_size=512,
        shuffle=False,
    )
    classification_conformalizer.calibrate(val_preds, alpha=ALPHA)

    # Publish the model and dataset only after calibration succeeded, so a
    # failed run cannot leave predict_image with half-initialised globals.
    pretrained_resnet_34 = backbone
    dataset = cal_dataset
    return f"Calibrated on {dataset_name} with model {model_name}"


def predict_image(img):
    """Return the uploaded image and its conformal prediction set.

    Args:
        img: PIL image from the Gradio upload component, or None when no
            image has been uploaded.

    Returns:
        Tuple ``(image, message)``: a copy of the input image and a message
        listing the class names in the prediction set.

    Raises:
        gr.Error: If no image was uploaded or ``calibrate`` has not completed.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    if pretrained_resnet_34 is None or dataset is None:
        raise gr.Error("Please calibrate a model before predicting.")

    img_old = img.copy()
    batch = dataset.transforms(img).unsqueeze(0)  # add a leading batch dimension

    # Inference only: use eval-mode layers and skip autograd bookkeeping.
    pretrained_resnet_34.eval()
    with torch.no_grad():
        pred = pretrained_resnet_34(batch)

    inference_pred = ClassificationPredictions(
        dataset_name="uploaded",
        split_name="test",
        image_paths=[None],
        idx_to_cls=dataset.idx_to_cls,
        true_cls=torch.tensor([-1]),  # placeholder: the true class is unknown
        pred_cls=pred,  # raw model outputs; the conformalizer was built with preprocess="softmax"
    )
    prediction_sets = classification_conformalizer.conformalize(inference_pred)
    list_of_classes = [dataset.idx_to_cls[i] for i in prediction_sets[0].detach().tolist()]
    return img_old, f"Predicted classes with {1 - ALPHA:.0%} confidence: {list_of_classes}"


def main_function(lbd, img):
    """Placeholder image-processing hook; returns ``img`` unchanged.

    ``lbd`` is accepted but unused. An earlier draft ran a YOLO detector here
    and returned the annotated image; this function is not wired into the UI.
    """
    return img


with gr.Blocks() as demo:
    gr.Markdown("# Image Classification with Conformal Prediction")
    gr.Markdown("## Upload an image and get conformalized classification predictions.")

    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=list(DATASETS),
            label="Select Dataset",
            value=list(DATASETS)[0],
        )
        model_dropdown = gr.Dropdown(
            choices=MODELS[dataset_dropdown.value],
            label="Select Model",
            value=MODELS[dataset_dropdown.value][0],
        )
    calibrate_btn = gr.Button("Calibrate")
    status_text = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("---")

    with gr.Row():
        input_image = gr.Image(label="Upload Image", type="pil")
        output_image = gr.Image(label="Processed Image")
    predict_btn = gr.Button("Predict")
    result_text = gr.Textbox(label="Prediction Result")

    # Connect components (event listeners must be registered inside the Blocks context).
    calibrate_btn.click(
        fn=calibrate,
        inputs=[dataset_dropdown, model_dropdown],
        outputs=status_text,
    )
    predict_btn.click(
        fn=predict_image,
        inputs=input_image,
        outputs=[output_image, result_text],
    )

# Launched at import time, as in the original script.
demo.launch()