|
|
import gradio as gr |
|
|
import timm |
|
|
import torch |
|
|
from cods.classif.cp import ClassificationConformalizer |
|
|
from cods.classif.data import ClassificationDataset |
|
|
from cods.classif.data.predictions import ClassificationPredictions |
|
|
from cods.classif.models import ClassificationModel |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
from dataset import DatasetWrapper |
|
|
|
|
|
# UI label -> Hugging Face Hub dataset identifier handed to
# `datasets.load_dataset` by calibrate(). Insertion order matters: the first
# key is used as the default selection in the dataset dropdown.
DATASETS = dict(
    miniimagenet="timm/mini-imagenet",
    imagenette="frgfm/imagenette",
    imagenet="imagenet-1k",
)

# Dataset label -> candidate model repositories offered in the model dropdown.
# NOTE(review): only "miniimagenet" has entries so far.
MODELS = dict(
    miniimagenet=[
        "QuentinJG/ResNet18-miniimagenet",
        "shahrukhx01/vit-base-patch16-miniimagenet",
    ],
)
|
|
|
|
|
# Single conformalizer shared by calibrate() (which fits it on validation
# predictions) and predict_image() (which applies it to one uploaded image).
# method="lac" presumably selects the LAC nonconformity score and
# preprocess="softmax" presumably converts raw logits to probabilities first —
# NOTE(review): confirm both against the cods documentation.
classification_conformalizer = ClassificationConformalizer(method="lac", preprocess="softmax")
|
|
|
|
|
|
|
|
def calibrate(dataset_name, model_name):
    """Load a pretrained classifier and calibrate the conformal predictor.

    Builds predictions of a timm ``resnet34`` on the ``validation`` split of
    the selected Hugging Face dataset, then calibrates the module-level
    ``classification_conformalizer`` on them with ``alpha=0.1``.

    Side effects: (re)binds the module globals ``pretrained_resnet_34`` and
    ``dataset``, which ``predict_image`` reads.

    Args:
        dataset_name: Key into ``DATASETS`` selecting the calibration data.
        model_name: Model chosen in the UI. Currently ignored: a timm
            ``resnet34`` is always used (TODO: honour the selection).

    Returns:
        A status string naming the dataset and the model actually used.

    Raises:
        KeyError: if ``dataset_name`` is not a key of ``DATASETS``.
    """
    global pretrained_resnet_34, dataset

    # NOTE(review): the MODELS entries are Hub repos that may not be loadable
    # through timm, so the UI selection is not used yet. Keep the real
    # architecture in its own local instead of overwriting `model_name`.
    backbone_name = "resnet34"

    pretrained_resnet_34 = timm.create_model(backbone_name, pretrained=True)
    classifier = ClassificationModel(model=pretrained_resnet_34, model_name=backbone_name)

    # Only publish the wrapped dataset to the global, never the raw split.
    raw_split = load_dataset(DATASETS[dataset_name], split="validation")
    dataset = DatasetWrapper(raw_split)

    val_preds = classifier.build_predictions(
        dataset,
        dataset_name=dataset_name,
        split_name="cal",
        batch_size=512,
        shuffle=False,
    )

    # alpha=0.1 corresponds to the "90%" advertised in predict_image's message.
    classification_conformalizer.calibrate(val_preds, alpha=0.1)

    return f"Calibrated on {dataset_name} with model {backbone_name}"
|
|
|
|
|
|
|
|
def predict_image(img):
    """Return the conformal prediction set for one uploaded image.

    Uses the model and dataset transforms set up by ``calibrate`` and the
    module-level ``classification_conformalizer``.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.

    Returns:
        ``(image, message)``: the input image (copied) for display, and a
        string listing the predicted class names, or an explanation of why no
        prediction was made.
    """
    # calibrate() creates these module globals; before it has run they do not exist.
    if "dataset" not in globals() or "pretrained_resnet_34" not in globals():
        return img, "Please click 'Calibrate' before predicting."
    if img is None:
        return None, "Please upload an image first."

    img_old = img.copy()

    # Uploads may be RGBA, greyscale or palette images; convert so the
    # transform always sees 3 channels (assumes dataset.transforms targets
    # RGB input — confirm against DatasetWrapper).
    batch = dataset.transforms(img.convert("RGB")).unsqueeze(0)

    # eval() fixes BatchNorm/Dropout to inference behaviour (a batch of one in
    # train mode would also update running stats); no_grad() avoids building
    # an autograd graph that is never backpropagated.
    pretrained_resnet_34.eval()
    with torch.no_grad():
        pred = pretrained_resnet_34(batch)

    inference_pred = ClassificationPredictions(
        dataset_name="uploaded",
        split_name="test",
        image_paths=[None],
        idx_to_cls=dataset.idx_to_cls,
        true_cls=torch.tensor([-1]),  # placeholder label: uploads have no ground truth
        pred_cls=pred,
    )

    result = classification_conformalizer.conformalize(inference_pred)
    # result[0] is iterated as the class indices in the set for our single image;
    # .cpu() keeps this working if the tensor lives on a GPU.
    list_of_classes = [dataset.idx_to_cls[i] for i in result[0].detach().cpu().tolist()]

    # "90%" mirrors alpha=0.1 used in calibrate(); change both together.
    result = f"Predicted classes with 90% confidence: {list_of_classes}"
    return img_old, result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main_function(lbd, img):
    """Placeholder processing step that hands ``img`` back unchanged.

    ``lbd`` is accepted so callers can pass it, but it is not used.
    """
    return img
|
|
|
|
|
|
|
|
# Gradio UI: dataset/model selection and calibration on top, image upload and
# prediction below. The app is launched as soon as this module runs.
with gr.Blocks() as demo:
    gr.Markdown("# Image Classification with Conformal Prediction")
    gr.Markdown("## Upload an image and get conformalized classification predictions.")

    with gr.Row():
        # Materialise the keys: Dropdown expects a list of choices, not a dict view.
        dataset_choices = list(DATASETS)
        default_dataset = dataset_choices[0]
        dataset_dropdown = gr.Dropdown(
            choices=dataset_choices, label="Select Dataset", value=default_dataset
        )

        # MODELS only has entries for some datasets; fall back to an empty list
        # instead of raising KeyError/IndexError at startup.
        default_models = MODELS.get(default_dataset, [])
        model_dropdown = gr.Dropdown(
            choices=default_models,
            label="Select Model",
            value=default_models[0] if default_models else None,
        )
        # NOTE(review): the model list is not refreshed when the dataset changes,
        # and calibrate() currently ignores the model selection anyway.

    calibrate_btn = gr.Button("Calibrate")
    status_text = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("---")

    with gr.Row():
        input_image = gr.Image(label="Upload Image", type="pil")
        output_image = gr.Image(label="Processed Image")

    predict_btn = gr.Button("Predict")
    result_text = gr.Textbox(label="Prediction Result")

    calibrate_btn.click(
        fn=calibrate, inputs=[dataset_dropdown, model_dropdown], outputs=status_text
    )
    predict_btn.click(fn=predict_image, inputs=input_image, outputs=[output_image, result_text])

demo.launch()
|
|
|