| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| import streamlit as st | |
| import numpy as np | |
| import requests | |
| from io import BytesIO | |
| from kan_linear import KANLinear | |
class CNNKAN(nn.Module):
    """Binary image classifier: four conv stages followed by two KANLinear layers.

    Every stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> SELU -> MaxPool2d(2).
    The 3x3/padding=1 convolutions keep the spatial size and each pool halves it
    (floor), so a 200x200 input becomes 100 -> 50 -> 25 -> 12, which is where the
    ``256 * 12 * 12`` input width of ``kan1`` comes from. Other input sizes will
    not match that width.

    ``forward`` returns the raw output of ``kan2`` (one value per sample,
    presumably shape (batch, 1) -- depends on KANLinear); callers apply sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (conv1, bn1, pool1, ..., kan2) are the state_dict keys
        # of the trained checkpoint, and the creation order below matches the
        # original explicit assignments, so nothing about loading changes.
        stage_channels = ((3, 32), (32, 64), (64, 128), (128, 256))
        for stage, (in_ch, out_ch) in enumerate(stage_channels, start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm2d(out_ch))
            setattr(self, f"pool{stage}", nn.MaxPool2d(2))
        # Dropout is shared by both KAN layers; nn.Dropout is the identity in eval mode.
        self.dropout = nn.Dropout(0.5)
        self.kan1 = KANLinear(256 * 12 * 12, 512)
        self.kan2 = KANLinear(512, 1)

    def forward(self, x):
        # Conv -> BN -> SELU -> pool, once per stage.
        for stage in range(1, 5):
            conv = getattr(self, f"conv{stage}")
            norm = getattr(self, f"bn{stage}")
            pool = getattr(self, f"pool{stage}")
            x = pool(F.selu(norm(conv(x))))
        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)
        hidden = self.kan1(self.dropout(x))
        return self.kan2(self.dropout(hidden))
def load_model(weights_path, device):
    """Build a ``CNNKAN`` on *device*, load weights from *weights_path*, set eval mode.

    Keys that start with ``'module.'`` have that prefix removed before loading
    (such prefixes typically appear when a model is saved while wrapped in
    ``nn.DataParallel``); all other keys are used unchanged.

    Args:
        weights_path: Path of a file written by ``torch.save`` containing a
            state dict for ``CNNKAN``.
        device: Device (``torch.device`` or device string) for the model and
            the loaded tensors.

    Returns:
        The ``CNNKAN`` instance with weights loaded, in eval mode.
    """
    model = CNNKAN().to(device)
    # NOTE(security): torch.load unpickles the file, which can execute code;
    # only point this at trusted weight files.
    state_dict = torch.load(weights_path, map_location=device)
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model.eval()
    return model
def load_image_from_url(url, timeout=10):
    """Download *url* and decode the response body as an RGB PIL image.

    Args:
        url: Address of the image to fetch (user-supplied in this app).
        timeout: Seconds ``requests`` waits to connect / for the server to
            respond. Without a timeout ``requests.get`` can block indefinitely,
            freezing the app on an unresponsive host.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.RequestException: Network failure, timeout, or a non-2xx
            HTTP status.
        PIL.UnidentifiedImageError: The response body is not a decodable image.
    """
    response = requests.get(url, timeout=timeout)
    # Surface 404/500 etc. as an HTTP error rather than feeding the error page
    # to PIL, which would fail with a misleading "cannot identify image" message.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')
def preprocess_image(image):
    """Turn a PIL image into a single-item batch tensor for ``CNNKAN``.

    The image is resized to exactly 200x200 (aspect ratio not preserved), then
    converted with torchvision's ``ToTensor`` (float values scaled to [0, 1],
    channels first). No normalization is applied. A leading batch dimension of
    size 1 is added; for RGB input that presumably gives shape (1, 3, 200, 200).
    """
    resized = transforms.Resize((200, 200))(image)
    tensor = transforms.ToTensor()(resized)
    return tensor.unsqueeze(0)
# Streamlit app
@st.cache_resource
def _load_model_cached(weights_path, device_name):
    """Load the classifier once per process and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction; without
    caching, the model would be rebuilt and its weights re-read from disk each
    time. Arguments are plain strings so Streamlit can hash them as cache keys.
    """
    return load_model(weights_path, torch.device(device_name))


st.title("Cat and Dog Classification with CNN-KAN")
st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")

device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model = _load_model_cached('weights/best_model_weights_KAN.pth', device_name)

# An uploaded file takes precedence over the URL field.
img = None
if uploaded_file is not None:
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except Exception as e:
        # Corrupt / non-image uploads: report in the sidebar instead of a traceback.
        st.sidebar.error(f"Error reading uploaded image: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except Exception as e:
        st.sidebar.error(f"Error loading image from URL: {e}")

st.sidebar.write("-----")
# Define your information for the footer
name = "Wayan Dadang"
st.sidebar.write("Follow me on:")
# Create a footer section with links and copyright information
st.sidebar.markdown(f"""
[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
[GitHub](https://github.com/Wayan123)
[Resume](https://wayan123.github.io/)
© {name} - {2024}
""", unsafe_allow_html=True)

if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)
        with torch.no_grad():
            output = model(img_tensor)
        # The model emits a logit; sigmoid maps it to a probability in (0, 1).
        prob = torch.sigmoid(output).item()
        st.write(f"Prediction: {prob:.4f}")
        # NOTE(review): assumes training used label 0 = cat, 1 = dog — confirm.
        if prob < 0.5:
            st.write("This image is classified as a Cat.")
        else:
            st.write("This image is classified as a Dog")