---
license: apache-2.0
datasets:
- marmal88/skin_cancer
base_model:
- google/vit-base-patch16-224-in21k
pipeline_tag: image-classification
tags:
- medical
---

## Installation

First, clone the repository:

```bash
git clone https://github.com/ethicalabs-ai/SkinCancerViT.git
cd SkinCancerViT
```

Then, install the package in editable mode using uv (or pip):

```bash
uv sync  # Recommended if you use uv

# Or, if using pip:
# pip install -e .
```

## Quick Start / Usage

This package allows you to load a pre-trained SkinCancerViT model and use it for prediction.

```python
import torch
from datasets import load_dataset  # To get a random sample
from PIL import Image

from skincancer_vit.model import SkinCancerViTModel

# Load the model from the Hugging Face Hub
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkinCancerViTModel.from_pretrained("ethicalabs/SkinCancerViT")
model.to(device)  # Move the model to the selected device
model.eval()  # Set the model to evaluation mode

# Example: prediction from a specific image file
image_file_path = "images/patient-001.jpg"  # Specify your image file path here
specific_image = Image.open(image_file_path).convert("RGB")

# Example tabular data for this prediction
specific_age = 42
specific_localization = "face"  # Must match one of the localization categories used in training

predicted_dx, confidence = model.full_predict(
    raw_image=specific_image,
    raw_age=specific_age,
    raw_localization=specific_localization,
    device=device,
)

print(f"Predicted Diagnosis: {predicted_dx}")
print(f"Confidence: {confidence:.4f}")

# Example: prediction from a random sample of the dataset's test split
dataset = load_dataset("marmal88/skin_cancer", split="test")
random_sample = dataset.shuffle(seed=42).select(range(1))[0]  # First sample after shuffling

sample_image = random_sample["image"]
sample_age = random_sample["age"]
sample_localization = random_sample["localization"]
sample_true_dx = random_sample["dx"]

predicted_dx_sample, confidence_sample = model.full_predict(
    raw_image=sample_image,
    raw_age=sample_age,
    raw_localization=sample_localization,
    device=device,
)

print(f"Predicted Diagnosis: {predicted_dx_sample}")
print(f"Confidence: {confidence_sample:.4f}")
print(f"Correct Prediction: {predicted_dx_sample == sample_true_dx}")
```
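The quick start checks a single test sample. To get a rough idea of how often the model is right, the same call can be repeated over several test samples and the matches counted. The sketch below is not part of the `skincancer_vit` package: `evaluate_samples` is a hypothetical helper, and it assumes only what the example above shows, namely that `full_predict` accepts `raw_image`, `raw_age`, `raw_localization` and `device` keyword arguments and returns a `(diagnosis, confidence)` pair.

```python
from typing import Any, Callable, Iterable, Mapping, Tuple


def evaluate_samples(
    predict_fn: Callable[..., Tuple[str, float]],
    samples: Iterable[Mapping[str, Any]],
) -> float:
    """Hypothetical helper (not part of skincancer_vit): accuracy over samples.

    Each sample is assumed to be a dict with the "image", "age",
    "localization" and "dx" fields used in the quick start. `predict_fn`
    is assumed to behave like `model.full_predict` with `device` already
    bound, returning a (diagnosis, confidence) tuple.
    """
    correct = 0
    total = 0
    for sample in samples:
        predicted_dx, _confidence = predict_fn(
            raw_image=sample["image"],
            raw_age=sample["age"],
            raw_localization=sample["localization"],
        )
        # Count an exact string match against the ground-truth label
        correct += int(predicted_dx == sample["dx"])
        total += 1
    if total == 0:
        raise ValueError("no samples to evaluate")
    return correct / total


# Usage with the model, device and dataset from the quick start:
# from functools import partial
# subset = dataset.shuffle(seed=42).select(range(100))
# accuracy = evaluate_samples(partial(model.full_predict, device=device), subset)
# print(f"Accuracy on 100 test samples: {accuracy:.4f}")
```

Binding `device` with `functools.partial` keeps the helper independent of PyTorch, so it can be checked with a stub predictor. Accuracy on a small random subset is only a rough estimate, not a substitute for a full evaluation on the test split.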