diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..a9246b2c3fce9b244d9b54b0420511b3b1bf818b
Binary files /dev/null and b/.DS_Store differ
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..f97b94d099f966dfa97ac9a6bbff039626256b2f
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,55 @@
+# Dockerfile
+
+# --- Stage 1: Build the static frontend ---
+# Use a lightweight Node.js image to build the React application
+FROM node:18-alpine AS frontend-builder
+WORKDIR /app/frontend
+
+# Copy only the package manifests first so the dependency install is cached
+COPY frontend/package.json ./
+COPY frontend/package-lock.json ./
+RUN npm install
+
+# Copy the rest of the frontend code and run the build
+COPY frontend/ ./
+# Important: make sure your package.json defines a "build" script,
+# typically "build": "vite build" or "react-scripts build"
+RUN npm run build
+
+
+# --- Stage 2: Set up the Python environment and backend ---
+# Use the official Python image
+FROM python:3.10-slim
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+
+WORKDIR /app
+
+# Install system dependencies if needed
+# (e.g. for compiling the C++ extensions in ReConV2)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python dependencies
+COPY backend/requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the backend code
+COPY backend/ .
+
+# If ReConV2 ships C++ extensions that need to be compiled,
+# uncomment and adapt the following line:
+# RUN cd /app/ReConV2/extensions/ && python setup.py install
+
+
+# --- Stage 3: Assemble the final image ---
+# Copy the built frontend from the first stage into the 'static' folder;
+# FastAPI serves files from this folder automatically
+COPY --from=frontend-builder /app/frontend/dist ./static
+
+# Expose the port FastAPI will listen on (the default for HF Spaces)
+EXPOSE 7860
+
+# Launch the API server with Uvicorn
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 51c824306c9a07a5590f4899392a4db2d2c3820b..a687b24140b9ca53e7be86473d0599fbdbf444b4 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,143 @@
 ---
-title: Aic25 V2
-emoji: 💻
-colorFrom: pink
-colorTo: gray
+title: Cross-Modal Object Comparison Tool
+emoji: 👀
+colorFrom: green
+colorTo: yellow
 sdk: docker
-pinned: false
+pinned: true
+short_description: Demo of Image <-> 3D <-> Text retrieval tool for AI Challenge
+license: mit
 ---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# 🚀 Cross-Modal 3D Asset Retrieval & Comparison Tool
+
+[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+[![React](https://img.shields.io/badge/React-19-blue?logo=react)](https://react.dev/)
+[![FastAPI](https://img.shields.io/badge/FastAPI-0.110-green?logo=fastapi)](https://fastapi.tiangolo.com/)
+[![PyTorch](https://img.shields.io/badge/PyTorch-2.0-orange?logo=pytorch)](https://pytorch.org/)
+
+An advanced, full-stack application designed to manage and analyze multi-modal datasets containing 3D models, images, and text descriptions.
The tool leverages deep learning models to compute and compare embeddings across different modalities, enabling powerful cross-modal search and retrieval. + +The interface allows users to upload their own datasets, explore a pre-loaded shared dataset, and perform detailed comparisons to find the most similar assets, regardless of their original format. + +--- + +## ✨ Key Features + +- **🗂️ Multi-Modal Dataset Management**: Upload `.zip` archives containing images (`.png`), text (`.txt`), and 3D models (`.stl`). The system automatically processes and indexes them. +- **☁️ Cloud & Local Datasets**: Seamlessly switch between a large, pre-processed shared dataset hosted on the server and local datasets stored securely in your browser's IndexedDB. +- **👁️ Interactive Content Viewer**: + - A high-performance 3D viewer for `.stl` models with zoom/pan/rotate controls, powered by **Three.js**. + - Integrated image and text viewers. + - Fullscreen mode for detailed inspection of any asset. +- **🧠 Powerful Cross-Modal Comparison**: + - **Dataset Item Search**: Select any item within a dataset to instantly see its top matches across all other modalities based on semantic similarity. + - **Ad-Hoc Search**: Upload a new, external image, 3D model, or text snippet to find the most similar items within a selected dataset. +- **📊 Full Analysis Export**: Download the complete, pre-computed similarity matrix for any processed dataset as a `.json` or `.csv` file for offline analysis and reporting. +- **⚡ Responsive & Modern UI**: A clean, fast, and intuitive user interface built with **React**, **TypeScript**, and **TailwindCSS**. +- **🚀 High-Performance Backend**: Powered by **FastAPI** and **PyTorch**, the backend is optimized for asynchronous operations and efficient deep learning inference. + +--- + +## 🛠️ Technical Stack + +| Area | Technology | +| :-------- | :---------------------------------------------------------------------------------------------------------- | +| **Frontend** | [React 19](https://react.dev/), [TypeScript](https://www.typescriptlang.org/), [TailwindCSS](https://tailwindcss.com/), [Three.js](https://threejs.org/), [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API) | +| **Backend** | [Python 3.10](https://www.python.org/), [FastAPI](https://fastapi.tiangolo.com/), [PyTorch](https://pytorch.org/), [Uvicorn](https://www.uvicorn.org/), [scikit-learn](https://scikit-learn.org/) | +| **Deployment**| [Docker](https://www.docker.com/), [Hugging Face Spaces](https://huggingface.co/spaces) (or any container-based platform) | + +--- + +## 🏛️ Project Architecture + +The application is architected as a modern monorepo with a clear separation between the frontend and backend services, designed for containerization and easy deployment. + +### Frontend (`/frontend`) + +A standalone Single-Page Application (SPA) built with React. +- **`components/`**: Contains reusable UI components, organized by feature (e.g., `DatasetManager`, `ComparisonTool`, `common/`). +- **`services/`**: Handles all side effects and external communication. + - `apiService.ts`: Manages all HTTP requests to the backend API. + - `dbService.ts`: Provides a simple interface for interacting with the browser's IndexedDB for local dataset persistence. + - `comparisonService.ts`: Logic for handling client-side interactions with pre-computed similarity data. +- **`types.ts`**: Centralized TypeScript type definitions for robust data modeling. 
+- **`App.tsx`**: The main application component that orchestrates state and views. + +### Backend (`/backend`) + +A high-performance API server built with FastAPI. +- **`main.py`**: The main entry point for the FastAPI application. It defines all API endpoints, manages application lifecycle events (like model loading on startup), and serves the static frontend files. +- **`inference_utils.py`**: The core of the AI logic. It handles ZIP file processing, asset parsing, embedding generation using the PyTorch models, and similarity calculation (cosine similarity). It also manages an in-memory cache for embeddings to ensure fast retrieval. +- **`download_utils.py`**: A utility module for downloading model weights and shared datasets from external storage (e.g., Yandex.Disk) during the startup phase. +- **`cad_retrieval_utils/`**: A proprietary library containing the core model definitions, data loaders, and training/inference configurations for the cross-modal retrieval task. +- **`ReConV2/`**: A dependency containing model architectures and potentially C++ extensions for efficient 3D point cloud processing. + +--- + +## ⚙️ How It Works + +The core workflow for processing a new dataset is as follows: + +1. **Upload**: The user uploads a `.zip` file via the React frontend. +2. **API Request**: The frontend sends the file to the `/api/process-dataset` endpoint on the FastAPI backend. +3. **Unpacking & Preprocessing**: The backend saves the archive to a temporary directory and extracts all image, text, and mesh files. +4. **Embedding Generation**: For each file, a specialized PyTorch model generates a high-dimensional vector embedding: + - An **Image Encoder** processes `.png` files. + - A **Text Encoder** processes `.txt` files. + - A **Point Cloud (PC) Encoder** processes `.stl` files after converting them to point clouds. +5. **Caching**: The generated embeddings and asset metadata are stored in an in-memory cache on the server for instant access. +6. **Full Comparison**: The backend pre-computes a full N x N similarity matrix by calculating the cosine similarity between every pair of embeddings. +7. **Response & Client-Side Storage**: The fully processed dataset object, including the comparison matrix, is sent back to the client. The frontend then saves this complete dataset to IndexedDB, making it available for future sessions without needing to re-upload. + +--- + +## 🚀 Getting Started + +You can run this project locally using Docker, which encapsulates both the frontend and backend services. + +### Prerequisites + +- [Docker](https://www.docker.com/get-started) installed on your machine. + +### Local Installation & Startup + +1. **Clone the repository:** + ```bash + git clone + cd + ``` + +2. **Check Model & Data URLs:** + The application is configured to download pre-trained models and a shared dataset from public URLs. Please verify the links inside `backend/main.py` and replace them with your own if necessary. + +3. **Build and run with Docker:** + The provided `Dockerfile` is a multi-stage build that compiles the frontend and sets up the Python backend in a single, optimized image. + + ```bash + # Build the Docker image + docker build -t cross-modal-retrieval . + + # Run the container + docker run -p 7860:7860 cross-modal-retrieval + ``` + +4. **Access the application:** + Open your browser and navigate to [http://localhost:7860](http://localhost:7860). 
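+
+### Calling the API Directly (Optional)
+
+If you want to script the workflow described in "How It Works" instead of going through the UI, the dataset-processing endpoint can be exercised with a few lines of Python. This is a minimal sketch: the `/api/process-dataset` path and the pre-computed similarity matrix come from this README, but the multipart field name (`file`) and the response keys (`items`, `comparison_matrix`) are assumptions — check `backend/main.py` for the exact request and response schema.
+
+```python
+import requests
+
+BASE_URL = "http://localhost:7860"  # port exposed by the Docker image
+
+# Upload a dataset archive (.png / .txt / .stl files inside a .zip)
+with open("my_dataset.zip", "rb") as f:
+    resp = requests.post(
+        f"{BASE_URL}/api/process-dataset",
+        # the multipart field name "file" is an assumption -- see backend/main.py
+        files={"file": ("my_dataset.zip", f, "application/zip")},
+        timeout=600,  # embedding generation can take a while on CPU
+    )
+resp.raise_for_status()
+dataset = resp.json()
+
+# The backend pre-computes the full N x N cosine-similarity matrix,
+# so the response (assumed keys below) can be inspected offline.
+print(len(dataset.get("items", [])), "items processed")
+matrix = dataset.get("comparison_matrix")
+if matrix:
+    print("similarity of item 0 vs item 1:", matrix[0][1])
+```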
+
+---
+
+## 💡 Future Improvements
+
+- **Support for More Formats**: Extend file support to `.obj`/`.glb` for 3D models and `.jpeg`/`.webp` for images.
+- **Advanced Search**: Implement more complex filtering and search options within the dataset viewer (e.g., by similarity score, item count).
+- **Embedding Visualization**: Add a new section to visualize the high-dimensional embedding space using techniques like t-SNE or UMAP.
+- **User Authentication**: Introduce user accounts to manage private datasets and share them with collaborators.
+- **Model Fine-tuning**: Allow users to fine-tune the retrieval models on their own datasets to improve domain-specific accuracy.
+
+---
+
+## 📜 License
+
+This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
\ No newline at end of file
diff --git a/backend/.DS_Store b/backend/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..2c6b64df67ba6e5ceeb79373d57650e4d656ff4b
Binary files /dev/null and b/backend/.DS_Store differ
diff --git a/backend/ReConV2/.DS_Store b/backend/ReConV2/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..7dcee206a3b5967c676e9862c50557e45d289e8d
Binary files /dev/null and b/backend/ReConV2/.DS_Store differ
diff --git a/backend/ReConV2/extensions/chamfer_distance/__init__.py b/backend/ReConV2/extensions/chamfer_distance/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e15be7028d12ddc55b29752ac718c5284200203
--- /dev/null
+++ b/backend/ReConV2/extensions/chamfer_distance/__init__.py
@@ -0,0 +1 @@
+from .chamfer_distance import ChamferDistance
diff --git a/backend/ReConV2/extensions/chamfer_distance/chamfer_distance.cpp b/backend/ReConV2/extensions/chamfer_distance/chamfer_distance.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0fd269649e88febb1ba808d7c7654c4b461acff0
--- /dev/null
+++ b/backend/ReConV2/extensions/chamfer_distance/chamfer_distance.cpp
@@ -0,0 +1,185 @@
+#include <torch/extension.h>
+
+// CUDA forward declarations
+void ChamferDistanceKernelLauncher(
+    const int b, const int n,
+    const float* xyz,
+    const int m,
+    const float* xyz2,
+    float* result,
+    int* result_i,
+    float* result2,
+    int* result2_i);
+
+void ChamferDistanceGradKernelLauncher(
+    const int b, const int n,
+    const float* xyz1,
+    const int m,
+    const float* xyz2,
+    const float* grad_dist1,
+    const int* idx1,
+    const float* grad_dist2,
+    const int* idx2,
+    float* grad_xyz1,
+    float* grad_xyz2);
+
+
+void chamfer_distance_forward_cuda(
+    const at::Tensor xyz1,
+    const at::Tensor xyz2,
+    const at::Tensor dist1,
+    const at::Tensor dist2,
+    const at::Tensor idx1,
+    const at::Tensor idx2)
+{
+    ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
+                                  xyz2.size(1), xyz2.data<float>(),
+                                  dist1.data<float>(), idx1.data<int>(),
+                                  dist2.data<float>(), idx2.data<int>());
+}
+
+void chamfer_distance_backward_cuda(
+    const at::Tensor xyz1,
+    const at::Tensor xyz2,
+    at::Tensor gradxyz1,
+    at::Tensor gradxyz2,
+    at::Tensor graddist1,
+    at::Tensor graddist2,
+    at::Tensor idx1,
+    at::Tensor idx2)
+{
+    ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
+                                      xyz2.size(1), xyz2.data<float>(),
+                                      graddist1.data<float>(), idx1.data<int>(),
+                                      graddist2.data<float>(), idx2.data<int>(),
+                                      gradxyz1.data<float>(), gradxyz2.data<float>());
+}
+
+
+void nnsearch(
+    const int b, const int n, const int m,
+    const float* xyz1,
+    const float* xyz2,
+    float* dist,
+    int* idx)
+{
+    for (int i = 0; i < b; i++) {
+        for (int j = 0; j < n; j++) {
+            const float x1 = xyz1[(i*n+j)*3+0];
+            const float y1 = xyz1[(i*n+j)*3+1];
+            const float z1 = xyz1[(i*n+j)*3+2];
+            double best = 0;
+            int besti = 0;
+            for (int k = 0; k < m; k++) {
+                const float x2 = xyz2[(i*m+k)*3+0] - x1;
+                const float y2 = xyz2[(i*m+k)*3+1] - y1;
+                const float z2 = xyz2[(i*m+k)*3+2] - z1;
+                const double d=x2*x2+y2*y2+z2*z2;
+                if (k==0 || d < best){
+                    best = d;
+                    besti = k;
+                }
+            }
+            dist[i*n+j] = best;
+            idx[i*n+j] = besti;
+        }
+    }
+}
+
+
+void chamfer_distance_forward(
+    const at::Tensor xyz1,
+    const at::Tensor xyz2,
+    const at::Tensor dist1,
+    const at::Tensor dist2,
+    const at::Tensor idx1,
+    const at::Tensor idx2)
+{
+    const int batchsize = xyz1.size(0);
+    const int n = xyz1.size(1);
+    const int m = xyz2.size(1);
+
+    const float* xyz1_data = xyz1.data<float>();
+    const float* xyz2_data = xyz2.data<float>();
+    float* dist1_data = dist1.data<float>();
+    float* dist2_data = dist2.data<float>();
+    int* idx1_data = idx1.data<int>();
+    int* idx2_data = idx2.data<int>();
+
+    nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
+    nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
+}
+
+
+void chamfer_distance_backward(
+    const at::Tensor xyz1,
+    const at::Tensor xyz2,
+    at::Tensor gradxyz1,
+    at::Tensor gradxyz2,
+    at::Tensor graddist1,
+    at::Tensor graddist2,
+    at::Tensor idx1,
+    at::Tensor idx2)
+{
+    const int b = xyz1.size(0);
+    const int n = xyz1.size(1);
+    const int m = xyz2.size(1);
+
+    const float* xyz1_data = xyz1.data<float>();
+    const float* xyz2_data = xyz2.data<float>();
+    float* gradxyz1_data = gradxyz1.data<float>();
+    float* gradxyz2_data = gradxyz2.data<float>();
+    float* graddist1_data = graddist1.data<float>();
+    float* graddist2_data = graddist2.data<float>();
+    const int* idx1_data = idx1.data<int>();
+    const int* idx2_data = idx2.data<int>();
+
+    for (int i = 0; i < b*n*3; i++)
+        gradxyz1_data[i] = 0;
+    for (int i = 0; i < b*m*3; i++)
+        gradxyz2_data[i] = 0;
+    for (int i = 0; i < b; i++) {
+        for (int j = 0; j < n; j++) {
+            const float x1 = xyz1_data[(i*n+j)*3+0];
+            const float y1 = xyz1_data[(i*n+j)*3+1];
+            const float z1 = xyz1_data[(i*n+j)*3+2];
+            const int j2 = idx1_data[i*n+j];
+
+            const float x2 = xyz2_data[(i*m+j2)*3+0];
+            const float y2 = xyz2_data[(i*m+j2)*3+1];
+            const float z2 = xyz2_data[(i*m+j2)*3+2];
+            const float g = graddist1_data[i*n+j]*2;
+
+            gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
+            gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
+            gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
+            gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
+            gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
+            gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
+        }
+        for (int j = 0; j < m; j++) {
+            const float x1 = xyz2_data[(i*m+j)*3+0];
+            const float y1 = xyz2_data[(i*m+j)*3+1];
+            const float z1 = xyz2_data[(i*m+j)*3+2];
+            const int j2 = idx2_data[i*m+j];
+            const float x2 = xyz1_data[(i*n+j2)*3+0];
+            const float y2 = xyz1_data[(i*n+j2)*3+1];
+            const float z2 = xyz1_data[(i*n+j2)*3+2];
+            const float g = graddist2_data[i*m+j]*2;
+            gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
+            gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
+            gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
+            gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
+            gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
+            gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
+        }
+    }
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
+    m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
+    m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
+    m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
+}
diff --git
a/backend/ReConV2/extensions/chamfer_distance/chamfer_distance.cu b/backend/ReConV2/extensions/chamfer_distance/chamfer_distance.cu new file mode 100644 index 0000000000000000000000000000000000000000..f10f2ba854883d7f590236bb69e3598e8a4ef379 --- /dev/null +++ b/backend/ReConV2/extensions/chamfer_distance/chamfer_distance.cu @@ -0,0 +1,209 @@ +#include + +#include +#include + +__global__ +void ChamferDistanceKernel( + int b, + int n, + const float* xyz, + int m, + const float* xyz2, + float* result, + int* result_i) +{ + const int batch=512; + __shared__ float buf[batch*3]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} + +void ChamferDistanceKernelLauncher( + const int b, const int n, + const float* xyz, + const int m, + const float* xyz2, + float* result, + int* result_i, + float* result2, + int* result2_i) +{ + ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); + ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); +} + + +__global__ +void ChamferDistanceGradKernel( + int b, int n, + const float* xyz1, + int m, + const float* xyz2, + const float* grad_dist1, + const int* idx1, + float* grad_xyz1, + float* grad_xyz2) +{ + for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); + ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); +} diff --git a/backend/ReConV2/extensions/chamfer_distance/chamfer_distance.py b/backend/ReConV2/extensions/chamfer_distance/chamfer_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..776ed3cae4caea5f596d1e53a5e0b03cc02b1909 --- /dev/null +++ b/backend/ReConV2/extensions/chamfer_distance/chamfer_distance.py @@ -0,0 +1,71 @@ +import os + +import torch + +script_path = os.path.dirname(os.path.abspath(__file__)) + +from torch.utils.cpp_extension import load + +if torch.cuda.is_available(): + cd = load( + name="cd", + sources=[ + os.path.join(script_path, "chamfer_distance.cpp"), + os.path.join(script_path, "chamfer_distance.cu"), + ], + ) + + +class ChamferDistanceFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, _ = xyz1.size() + _, m, _ = xyz2.size() + xyz1 = xyz1.contiguous() + xyz2 = xyz2.contiguous() + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n, dtype=torch.int) + idx2 = torch.zeros(batchsize, m, dtype=torch.int) + + if not xyz1.is_cuda: + cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + else: + dist1 = dist1.cuda() + dist2 = dist2.cuda() + idx1 = idx1.cuda() + idx2 = idx2.cuda() + cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) + + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + + return dist1, dist2, idx1 + + @staticmethod + def backward(ctx, graddist1, graddist2, _): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + if not graddist1.is_cuda: + cd.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + else: + gradxyz1 = gradxyz1.cuda() + gradxyz2 = gradxyz2.cuda() + 
cd.backward_cuda( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + + return gradxyz1, gradxyz2 + + +class ChamferDistance(torch.nn.Module): + def forward(self, xyz1, xyz2): + return ChamferDistanceFunction.apply(xyz1, xyz2) diff --git a/backend/ReConV2/models/ReCon.py b/backend/ReConV2/models/ReCon.py new file mode 100644 index 0000000000000000000000000000000000000000..aa733daedc47e3b76132ac4de0d8473fa5067175 --- /dev/null +++ b/backend/ReConV2/models/ReCon.py @@ -0,0 +1,630 @@ +import numpy as np +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment +from timm.layers import trunc_normal_ + +from ReConV2.extensions.chamfer_distance import ChamferDistance +from ReConV2.models.transformer import ( + GPTExtractor, + GPTGenerator, + Group, + MAEExtractor, + MAEGenerator, + PatchEmbedding, + PositionEmbeddingCoordsSine, + ZGroup, +) +from ReConV2.utils.checkpoint import ( + get_missing_parameters_message, + get_unexpected_parameters_message, +) +from ReConV2.utils.logger import * + +from .build import MODELS + + +# Pretrain model +class MaskTransformer(nn.Module): + def __init__(self, config): + super().__init__() + + self.embed_dim = config.embed_dim + self.num_group = config.num_group + self.group_size = config.group_size + self.with_color = config.with_color + self.input_channel = 6 if self.with_color else 3 + self.img_queries = config.img_queries + self.text_queries = config.text_queries + self.global_query_num = self.img_queries + self.text_queries + self.mask_type = config.mask_type + self.mask_ratio = config.mask_ratio + self.stop_grad = config.stop_grad + + self.embed = PatchEmbedding( + embed_dim=self.embed_dim, + input_channel=self.input_channel, + large=config.large_embedding, + ) + + print_log( + f"[ReCon] divide point cloud into G{config.num_group} x S{config.group_size} points ...", + logger="ReCon", + ) + + if self.mask_type == "causal": + self.group_divider = ZGroup( + num_group=config.num_group, group_size=config.group_size + ) + self.encoder = GPTExtractor( + embed_dim=config.embed_dim, + num_heads=config.num_heads, + depth=config.depth, + group_size=config.group_size, + drop_path_rate=config.drop_path_rate, + stop_grad=self.stop_grad, + pretrained_model_name=config.pretrained_model_name, + ) + self.decoder = GPTGenerator( + embed_dim=config.embed_dim, + depth=config.decoder_depth, + drop_path_rate=config.drop_path_rate, + num_heads=config.num_heads, + group_size=config.group_size, + input_channel=self.input_channel, + ) + self.pos_embed = PositionEmbeddingCoordsSine(3, self.embed_dim, 1.0) + + else: + self.group_divider = Group( + num_group=config.num_group, group_size=config.group_size + ) + self.encoder = MAEExtractor( + embed_dim=config.embed_dim, + num_heads=config.num_heads, + depth=config.depth, + group_size=config.group_size, + drop_path_rate=config.drop_path_rate, + stop_grad=self.stop_grad, + pretrained_model_name=config.pretrained_model_name, + ) + self.decoder = MAEGenerator( + embed_dim=config.embed_dim, + depth=config.decoder_depth, + drop_path_rate=config.drop_path_rate, + num_heads=config.num_heads, + group_size=config.group_size, + input_channel=self.input_channel, + ) + self.pos_embed = nn.Sequential( + nn.Linear(3, 128), nn.GELU(), nn.Linear(128, self.embed_dim) + ) + self.decoder_pos_embed = nn.Sequential( + nn.Linear(3, 128), nn.GELU(), nn.Linear(128, self.embed_dim) + ) + + self.norm = nn.LayerNorm(self.embed_dim) + self.global_query = nn.Parameter( + 
torch.zeros(1, self.global_query_num, self.embed_dim) + ) + self.apply(self._init_weights) + + # do not perform additional mask on the first (self.keep_attend) tokens + self.keep_attend = 10 + self.num_group = config.num_group + self.num_mask = int((self.num_group - self.keep_attend) * self.mask_ratio) + + if config.pretrained_model_name == "": + print_log("[ReCon] No pretrained model is loaded.", logger="ReCon") + elif config.pretrained_model_name in timm.list_models(pretrained=True): + self.encoder.blocks.load_pretrained_timm_weights() + print_log( + f"[ReCon] Timm pretrained model {config.pretrained_model_name} is successful loaded.", + logger="ReCon", + ) + else: + print_log( + f"[ReCon] Pretrained model {config.pretrained_model_name} is not found in Timm.", + logger="ReCon", + ) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0.02, 0.01) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _mask_center_rand(self, center): + """ + center : B G 3 + -------------- + mask : B G (bool) + """ + B, G, _ = center.shape + num_mask = int(self.mask_ratio * G) + + overall_mask = np.zeros([B, G]) + for i in range(B): + mask = np.hstack([ + np.zeros(G - num_mask), + np.ones(num_mask), + ]) + np.random.shuffle(mask) + overall_mask[i, :] = mask + overall_mask = torch.from_numpy(overall_mask).to(torch.bool) + + return num_mask, overall_mask.to(center.device) + + def inference(self, pts): + with torch.no_grad(): + neighborhood, center = self.group_divider(pts) + group_input_tokens = self.embed(neighborhood) # B G C + batch_size, seq_len, C = group_input_tokens.size() + + global_query = self.global_query.expand(batch_size, -1, -1) + pos = self.pos_embed(center.to(group_input_tokens.dtype)) + + mask = torch.full( + (seq_len, seq_len), + -float("Inf"), + device=group_input_tokens.device, + dtype=group_input_tokens.dtype, + ).to(torch.bool) + if self.mask_type == "causal": + mask = torch.triu(mask, diagonal=1) + else: + mask = None + + local_features, global_features = self.encoder( + group_input_tokens, pos, mask, global_query + ) + + return pos, local_features, global_features + + def forward_mae(self, pts): + neighborhood, center = self.group_divider(pts) + num_mask, mask = self._mask_center_rand(center) + group_input_tokens = self.embed(neighborhood) # B G C + batch_size, seq_len, C = group_input_tokens.size() + global_query = self.global_query.expand(batch_size, -1, -1) + + pos = self.pos_embed(center.reshape(batch_size, -1, 3)) + decoder_pos = self.decoder_pos_embed(center.reshape(batch_size, -1, 3)) + x_vis, global_features = self.encoder( + group_input_tokens, pos, mask, global_query + ) + generated_points = self.decoder(x_vis, decoder_pos, mask) + + gt_points = neighborhood[mask].reshape( + batch_size * num_mask, self.group_size, self.input_channel + ) + + return generated_points, gt_points, global_features + + def forward_gpt(self, pts): + neighborhood, center = self.group_divider(pts) + group_input_tokens = self.embed(neighborhood) # B G C + batch_size, seq_len, C = group_input_tokens.size() + + global_query = self.global_query.expand(batch_size, -1, -1) + pos_absolute = self.pos_embed(center).to(group_input_tokens.dtype) + + relative_position = center[:, 1:, :] - center[:, :-1, :] + relative_norm = torch.norm(relative_position, dim=-1, keepdim=True) + relative_direction = relative_position / (relative_norm + 
1e-5) + position = torch.cat([center[:, 0, :].unsqueeze(1), relative_direction], dim=1) + pos_relative = self.pos_embed(position).to(group_input_tokens.dtype) + + attn_mask = torch.full( + (seq_len, seq_len), + -float("Inf"), + device=group_input_tokens.device, + dtype=group_input_tokens.dtype, + ).to(torch.bool) + + with torch.no_grad(): + attn_mask = torch.triu(attn_mask, diagonal=1) + + # column wise + overall_mask = np.hstack([ + np.zeros(self.num_group - self.keep_attend - self.num_mask), + np.ones(self.num_mask), + ]) + np.random.shuffle(overall_mask) + overall_mask = np.hstack([ + np.zeros(self.keep_attend), + overall_mask, + ]) + overall_mask = ( + torch.from_numpy(overall_mask) + .to(torch.bool) + .to(group_input_tokens.device) + ) + eye_mask = torch.eye( + self.num_group, device=group_input_tokens.device, dtype=torch.bool + ) + attn_mask = attn_mask | overall_mask.unsqueeze(0) & ~eye_mask + + local_features, global_features = self.encoder( + group_input_tokens, pos_absolute, attn_mask, global_query + ) + generated_points = self.decoder(local_features, pos_relative, attn_mask) + + gt_points = neighborhood.reshape( + batch_size * self.num_group, self.group_size, self.input_channel + ) + + return generated_points, gt_points, global_features + + def forward(self, pts): + if self.mask_type == "causal": + generated_points, gt_points, global_query = self.forward_gpt(pts) + else: + generated_points, gt_points, global_query = self.forward_mae(pts) + + return generated_points, gt_points, global_query + + +@MODELS.register_module() +class ReCon2(nn.Module): + def __init__(self, config): + super().__init__() + print_log("[ReCon V2]", logger="ReCon V2") + self.config = config + self.embed_dim = config.embed_dim + self.with_color = config.with_color + self.img_queries = config.img_queries + self.text_queries = config.text_queries + self.global_query_num = self.img_queries + self.text_queries + self.input_channel = 6 if self.with_color else 3 + self.contrast_type = config.contrast_type + + self.model = MaskTransformer(config) + self.cd_loss = ChamferDistance() + self.l1_loss = torch.nn.SmoothL1Loss() + + self.img_proj = nn.Linear(self.embed_dim, 1280) + self.img_proj.apply(self._init_weights) + self.text_proj = nn.Linear(self.embed_dim, 1280) + self.text_proj.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0.02, 0.01) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def info_nce_loss(self, feat1, feat2, logit_scale=1, mask=None): + feat1 = F.normalize(feat1, dim=1) + feat2 = F.normalize(feat2, dim=1) + all_feat1 = torch.cat(torch.distributed.nn.all_gather(feat1), dim=0) + all_feat2 = torch.cat(torch.distributed.nn.all_gather(feat2), dim=0) + logits = logit_scale * all_feat1 @ all_feat2.T + if mask is not None: + logits = logits * mask + labels = torch.arange(logits.shape[0]).to(self.config.device) + accuracy = (logits.argmax(dim=1) == labels).float().mean() + loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2 + return loss, accuracy + + def distillation_loss(self, token, feature): + B = token.shape[0] + loss = 0.0 + for i in range(B): + pred = token[i] + feat = feature[i][torch.any(feature[i] != 0, dim=1)] + feat = F.normalize(feat, dim=-1) + similarity_matrix = torch.mm(pred, feat.T).cpu().detach().numpy() + row_ind, col_ind = 
linear_sum_assignment(-similarity_matrix) + loss = loss + self.l1_loss(pred[row_ind], feat[col_ind]) + + return loss * 5 + + def contrast_loss(self, token, feature): + if self.contrast_type == "simclr": + return self.info_nce_loss( + token, feature, logit_scale=self.logit_scale, mask=self.mask + ) + elif self.contrast_type == "byol": + return self.distillation_loss(token, feature) + else: + raise ValueError("Unknown contrast type") + + def inference(self, pts): + _, encoded_features, global_token = self.model.inference(pts) + + img_token = global_token[:, : self.img_queries] + img_token = self.img_proj(img_token) + img_token = F.normalize(img_token, dim=-1) + + text_token = global_token[:, self.img_queries :] + text_token = self.text_proj(text_token) + text_token = F.normalize(text_token, dim=-1) + + return encoded_features, global_token, img_token, text_token + + def forward_features(self, pts): + generated_points, gt_points, global_token = self.model(pts) + + img_token = global_token[:, : self.img_queries] + img_token = self.img_proj(img_token) + img_token = F.normalize(img_token, dim=-1) + + text_token = global_token[:, self.img_queries :] + text_token = self.text_proj(text_token) + text_token = F.normalize(text_token, dim=-1) + + return img_token, text_token, gt_points, generated_points + + def forward_reconstruct(self, pts): + _, _, gt_points, generated_points = self.forward_features(pts) + + generated_xyz = generated_points[:, :, :3] + gt_xyz = gt_points[:, :, :3] + dist1, dist2, idx = self.cd_loss(generated_xyz, gt_xyz) + if self.with_color: + generated_color = generated_points[:, :, 3:] + gt_color = gt_points[:, :, 3:] + color_l1_loss = self.l1_loss( + generated_color, + torch.gather(gt_color, 1, idx.unsqueeze(-1).expand(-1, -1, 3).long()), + ) + else: + color_l1_loss = 0 + cd_l2_loss = (torch.mean(dist1)) + (torch.mean(dist2)) + cd_l1_loss = (torch.mean(torch.sqrt(dist1)) + torch.mean(torch.sqrt(dist2))) / 2 + + loss = cd_l1_loss + cd_l2_loss + color_l1_loss + + return loss + + def forward_contrast(self, pts, img, text): + img_token, text_token, _, _ = self.forward_features(pts) + img_loss = self.contrast_loss(img_token, img) + text_loss = self.contrast_loss(text_token, text) + loss = img_loss + text_loss + + return loss + + def forward_all(self, pts, img, text): + img_token, text_token, gt_points, generated_points = self.forward_features(pts) + + losses = {"mdm": 0, "csc_img": 0, "csc_text": 0} + + generated_xyz = generated_points[:, :, :3] + gt_xyz = gt_points[:, :, :3] + dist1, dist2, idx = self.cd_loss(generated_xyz, gt_xyz) + if self.with_color: + generated_color = generated_points[:, :, 3:] + gt_color = gt_points[:, :, 3:] + color_l1_loss = self.l1_loss( + generated_color, + torch.gather(gt_color, 1, idx.unsqueeze(-1).expand(-1, -1, 3).long()), + ) + else: + color_l1_loss = 0 + cd_l2_loss = (torch.mean(dist1)) + (torch.mean(dist2)) + cd_l1_loss = (torch.mean(torch.sqrt(dist1)) + torch.mean(torch.sqrt(dist2))) / 2 + + losses["mdm"] = cd_l1_loss + cd_l2_loss + color_l1_loss + losses["csc_img"] = self.contrast_loss(img_token, img) + losses["csc_text"] = self.contrast_loss(text_token, text) + + print(losses) + loss = sum(losses.values()) + return loss + + def forward(self, pts, img, text, type="all"): + if type == "all": + return self.forward_all(pts, img, text) + elif type == "reconstruct": + return self.forward_reconstruct(pts) + elif type == "contrast": + return self.forward_contrast(pts, img, text) + else: + raise ValueError("Unknown type") + + @property + def device(self): 
+ return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +# finetune model +@MODELS.register_module() +class PointTransformer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.cls_dim = config.cls_dim + self.embed_dim = config.embed_dim + self.with_color = config.with_color + self.input_channel = 6 if self.with_color else 3 + self.num_group = config.num_group + self.group_size = config.group_size + self.img_queries = config.img_queries + self.text_queries = config.text_queries + self.global_query_num = self.img_queries + self.text_queries + self.large_embedding = config.large_embedding + + self.embed = PatchEmbedding( + embed_dim=self.embed_dim, + input_channel=self.input_channel, + large=self.large_embedding, + ) + self.pos_embed = PositionEmbeddingCoordsSine(3, self.embed_dim, 1.0) + + self.group_divider = ZGroup( + num_group=config.num_group, group_size=config.group_size + ) + print_log( + f"[PointTransformer] divide point cloud into G{config.num_group} x S{config.group_size} points ...", + logger="PointTransformer", + ) + + self.encoder = GPTExtractor( + embed_dim=config.embed_dim, + num_heads=config.num_heads, + depth=config.depth, + group_size=config.group_size, + drop_path_rate=config.drop_path_rate, + stop_grad=False, + ) + + self.decoder = GPTGenerator( + embed_dim=config.embed_dim, + depth=config.decoder_depth, + drop_path_rate=config.drop_path_rate, + num_heads=config.num_heads, + group_size=config.group_size, + input_channel=self.input_channel, + ) + self.global_query = nn.Parameter( + torch.zeros(1, self.global_query_num, self.embed_dim) + ) + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + feature_dim = self.embed_dim * 4 + self.cls_head_finetune = nn.Sequential( + nn.Linear(feature_dim, 256), + nn.BatchNorm1d(256), + nn.ReLU(inplace=True), + nn.Dropout(0.5), + nn.Linear(256, 256), + nn.BatchNorm1d(256), + nn.ReLU(inplace=True), + nn.Dropout(0.5), + nn.Linear(256, self.cls_dim), + ) + + self.loss_ce = nn.CrossEntropyLoss() + # chamfer distance loss + self.cd_loss = ChamferDistance() + self.apply(self._init_weights) + + def get_loss_acc(self, ret, gt): + loss = self.loss_ce(ret, gt.long()) + pred = ret.argmax(-1) + acc = (pred == gt).sum() / float(gt.size(0)) + return loss, acc * 100 + + def load_model_from_ckpt(self, ckpt_path, log=True): + if ckpt_path is not None: + ckpt = torch.load(ckpt_path) + base_ckpt = { + k.replace("module.", ""): v for k, v in ckpt["base_model"].items() + } + + for k in list(base_ckpt.keys()): + if k.startswith("model"): + base_ckpt[k[len("model.") :]] = base_ckpt[k] + del base_ckpt[k] + elif k.startswith("cls_head_finetune"): + del base_ckpt[k] + + incompatible = self.load_state_dict(base_ckpt, strict=False) + if log: + if incompatible.missing_keys: + print_log("missing_keys", logger="PointTransformer") + print_log( + get_missing_parameters_message(incompatible.missing_keys), + logger="PointTransformer", + ) + if incompatible.unexpected_keys: + print_log("unexpected_keys", logger="PointTransformer") + print_log( + get_unexpected_parameters_message(incompatible.unexpected_keys), + logger="PointTransformer", + ) + + print_log( + f"[PointTransformer] Successful Loading the ckpt from {ckpt_path}", + logger="PointTransformer", + ) + else: + print_log("Training from scratch!!!", logger="PointTransformer") + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) 
and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv1d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, pts): + neighborhood, center = self.group_divider(pts) + group_input_tokens = self.embed(neighborhood) # B G C + batch_size, seq_len, C = group_input_tokens.size() + + global_query = self.global_query.expand(batch_size, -1, -1) + cls_query = self.cls_token.expand(batch_size, -1, -1) + query = torch.cat([global_query, cls_query], dim=1) + + relative_position = center[:, 1:, :] - center[:, :-1, :] + relative_norm = torch.norm(relative_position, dim=-1, keepdim=True) + relative_direction = relative_position / (relative_norm + 1e-5) + position = torch.cat([center[:, 0, :].unsqueeze(1), relative_direction], dim=1) + pos_relative = self.pos_embed(position).to(group_input_tokens.dtype) + + pos = self.pos_embed(center).to(group_input_tokens.dtype) + + attn_mask = torch.full( + (seq_len, seq_len), + -float("Inf"), + device=group_input_tokens.device, + dtype=group_input_tokens.dtype, + ).to(torch.bool) + attn_mask = torch.triu(attn_mask, diagonal=1) + + # transformer + encoded_features, global_tokens = self.encoder( + group_input_tokens, pos, attn_mask, query + ) + generated_points = self.decoder(encoded_features, pos_relative, attn_mask) + + # neighborhood[:, :, :, :3] = neighborhood[:, :, :, :3] + center.unsqueeze(2) + gt_points = neighborhood.reshape( + batch_size * self.num_group, self.group_size, self.input_channel + ) + + generated_xyz = generated_points[:, :, :3] + gt_xyz = gt_points[:, :, :3] + dist1, dist2, idx = self.cd_loss(generated_xyz, gt_xyz) + + cd_l2_loss = (torch.mean(dist1)) + (torch.mean(dist2)) + cd_l1_loss = (torch.mean(torch.sqrt(dist1)) + torch.mean(torch.sqrt(dist2))) / 2 + + img_token = global_tokens[:, : self.img_queries] + text_token = global_tokens[:, self.img_queries : -1] + cls_token = global_tokens[:, -1] + + concat_f = torch.cat( + [ + cls_token, + img_token.max(1)[0], + text_token.max(1)[0], + encoded_features.max(1)[0], + ], + dim=-1, + ) + ret = self.cls_head_finetune(concat_f) + + return ret, cd_l1_loss + cd_l2_loss diff --git a/backend/ReConV2/models/__init__.py b/backend/ReConV2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c403dffaec7ed3e8827329dafae98d3c233e7ce0 --- /dev/null +++ b/backend/ReConV2/models/__init__.py @@ -0,0 +1,4 @@ +import ReConV2.models.ReCon +import ReConV2.models.transformer + +from .build import build_model_from_cfg diff --git a/backend/ReConV2/models/build.py b/backend/ReConV2/models/build.py new file mode 100644 index 0000000000000000000000000000000000000000..c2356f8d640770b1fcbab77cd57cd6844a49a1b5 --- /dev/null +++ b/backend/ReConV2/models/build.py @@ -0,0 +1,14 @@ +from ReConV2.utils import registry + +MODELS = registry.Registry("models") + + +def build_model_from_cfg(cfg, **kwargs): + """ + Build a dataset, defined by `dataset_name`. + Args: + cfg (eDICT): + Returns: + Dataset: a constructed dataset specified by dataset_name. 
+ """ + return MODELS.build(cfg, **kwargs) diff --git a/backend/ReConV2/models/transformer.py b/backend/ReConV2/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..69a754f6c5ce48b31913ca5c2671454d6f0c8f86 --- /dev/null +++ b/backend/ReConV2/models/transformer.py @@ -0,0 +1,788 @@ +import math + +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.layers import DropPath, Mlp + +from ReConV2.utils import misc +from ReConV2.utils.knn import knn_point +from ReConV2.utils.logger import * + + +class PatchEmbedding(nn.Module): # Embedding module + def __init__(self, embed_dim, input_channel=3, large=False): + super().__init__() + self.embed_dim = embed_dim + self.input_channel = input_channel + + # embed_dim_list = [c * (embed_dim // 512 + 1) for c in [128, 256, 512]] + # + # self.first_conv = nn.Sequential( + # nn.Conv1d(self.input_channel, embed_dim_list[0], 1), + # nn.BatchNorm1d(embed_dim_list[0]), + # nn.ReLU(inplace=True), + # nn.Conv1d(embed_dim_list[0], embed_dim_list[1], 1) + # ) + # self.second_conv = nn.Sequential( + # nn.Conv1d(embed_dim_list[2], embed_dim_list[2], 1), + # nn.BatchNorm1d(embed_dim_list[2]), + # nn.ReLU(inplace=True), + # nn.Conv1d(embed_dim_list[2], self.embed_dim, 1) + # ) + + if large: + self.first_conv = nn.Sequential( + nn.Conv1d(self.input_channel, 256, 1), + nn.BatchNorm1d(256), + nn.ReLU(inplace=True), + nn.Conv1d(256, 512, 1), + nn.BatchNorm1d(512), + nn.ReLU(inplace=True), + nn.Conv1d(512, 1024, 1), + ) + self.second_conv = nn.Sequential( + nn.Conv1d(2048, 2048, 1), + nn.BatchNorm1d(2048), + nn.ReLU(inplace=True), + nn.Conv1d(2048, embed_dim, 1), + ) + else: + self.first_conv = nn.Sequential( + nn.Conv1d(self.input_channel, 128, 1), + nn.BatchNorm1d(128), + nn.ReLU(inplace=True), + nn.Conv1d(128, 256, 1), + ) + self.second_conv = nn.Sequential( + nn.Conv1d(512, 512, 1), + nn.BatchNorm1d(512), + nn.ReLU(inplace=True), + nn.Conv1d(512, embed_dim, 1), + ) + + def forward(self, point_groups): + """ + point_groups : B G N 3/6 + ----------------- + feature_global : B G C + """ + bs, g, n, _ = point_groups.shape + point_groups = point_groups.reshape(bs * g, n, self.input_channel) + # encoder + feature = self.first_conv(point_groups.transpose(2, 1)) + feature_global = torch.max(feature, dim=2, keepdim=True)[0] + feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) + feature = self.second_conv(feature) + feature_global = torch.max(feature, dim=2, keepdim=False)[0] + return feature_global.reshape(bs, g, self.embed_dim) + + +class PositionEmbeddingCoordsSine(nn.Module): + """Similar to transformer's position encoding, but generalizes it to + arbitrary dimensions and continuous coordinates. + + Args: + n_dim: Number of input dimensions, e.g. 2 for image coordinates. 
+ d_model: Number of dimensions to encode into + temperature: + scale: + """ + + def __init__(self, n_dim: int = 1, d_model: int = 256, temperature=1.0, scale=None): + super().__init__() + + self.n_dim = n_dim + self.num_pos_feats = d_model // n_dim // 2 * 2 + self.temperature = temperature + self.padding = d_model - self.num_pos_feats * self.n_dim + + if scale is None: + scale = 1.0 + self.scale = scale * 2 * math.pi + + def forward(self, xyz: torch.Tensor) -> torch.Tensor: + """ + Args: + xyz: Point positions (*, d_in) + + Returns: + pos_emb (*, d_out) + """ + assert xyz.shape[-1] == self.n_dim + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=xyz.device) + dim_t = self.temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="trunc") / self.num_pos_feats + ) + + xyz = xyz * self.scale + pos_divided = xyz.unsqueeze(-1) / dim_t + pos_sin = pos_divided[..., 0::2].sin() + pos_cos = pos_divided[..., 1::2].cos() + pos_emb = torch.stack([pos_sin, pos_cos], dim=-1).reshape(*xyz.shape[:-1], -1) + + # Pad unused dimensions with zeros + pos_emb = F.pad(pos_emb, (0, self.padding)) + return pos_emb + + +class Group(nn.Module): # FPS + KNN + def __init__(self, num_group, group_size): + super().__init__() + self.num_group = num_group + self.group_size = group_size + + def forward(self, pts): + """ + input: B N 3/6 + --------------------------- + output: B G M 3/6 + center : B G 3 + """ + xyz = pts[:, :, :3] + c = pts.shape[2] + batch_size, num_points, _ = xyz.shape + # fps the centers out + xyz = xyz.float() + center = misc.fps(xyz.contiguous(), self.num_group) # B G 3 + # knn to get the neighborhood + idx = knn_point(self.group_size, xyz, center) + assert idx.size(1) == self.num_group + assert idx.size(2) == self.group_size + idx_base = ( + torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points + ) + idx = idx + idx_base + idx = idx.view(-1) + neighborhood = pts.view(batch_size * num_points, -1)[idx, :] + neighborhood = neighborhood.view( + batch_size, self.num_group, self.group_size, c + ).contiguous() + # normalize + neighborhood[:, :, :, :3] = neighborhood[:, :, :, :3] - center.unsqueeze(2) + return neighborhood, center + + +class ZGroup(nn.Module): + def __init__(self, num_group, group_size): + super().__init__() + self.num_group = num_group + self.group_size = group_size + + def simplied_morton_sorting(self, xyz, center): + """ + Simplifying the Morton code sorting to iterate and set the nearest patch to the last patch as the next patch, we found this to be more efficient. 
+ """ + batch_size, num_points, _ = xyz.shape + distances_batch = torch.cdist(center, center) + distances_batch[:, torch.eye(self.num_group).bool()] = float("inf") + idx_base = torch.arange(0, batch_size, device=xyz.device) * self.num_group + sorted_indices_list = [idx_base] + distances_batch = ( + distances_batch.view(batch_size, self.num_group, self.num_group) + .transpose(1, 2) + .contiguous() + .view(batch_size * self.num_group, self.num_group) + ) + distances_batch[idx_base] = float("inf") + distances_batch = ( + distances_batch.view(batch_size, self.num_group, self.num_group) + .transpose(1, 2) + .contiguous() + ) + for i in range(self.num_group - 1): + distances_batch = distances_batch.view( + batch_size * self.num_group, self.num_group + ) + distances_to_last_batch = distances_batch[sorted_indices_list[-1]] + closest_point_idx = torch.argmin(distances_to_last_batch, dim=-1) + closest_point_idx = closest_point_idx + idx_base + sorted_indices_list.append(closest_point_idx) + distances_batch = ( + distances_batch.view(batch_size, self.num_group, self.num_group) + .transpose(1, 2) + .contiguous() + .view(batch_size * self.num_group, self.num_group) + ) + distances_batch[closest_point_idx] = float("inf") + distances_batch = ( + distances_batch.view(batch_size, self.num_group, self.num_group) + .transpose(1, 2) + .contiguous() + ) + sorted_indices = torch.stack(sorted_indices_list, dim=-1) + sorted_indices = sorted_indices.view(-1) + return sorted_indices + + def forward(self, pts): + """ + input: B N 3/6 + --------------------------- + output: B G M 3/6 + center : B G 3 + """ + xyz = pts[:, :, :3] + c = pts.shape[2] + batch_size, num_points, _ = xyz.shape + # fps the centers out + xyz = xyz.float() + center = misc.fps(xyz.contiguous(), self.num_group) # B G 3 + # knn to get the neighborhood + idx = knn_point(self.group_size, xyz, center) + assert idx.size(1) == self.num_group + assert idx.size(2) == self.group_size + idx_base = ( + torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points + ) + idx = idx + idx_base + idx = idx.view(-1) + neighborhood = pts.view(batch_size * num_points, -1)[idx, :] + neighborhood = neighborhood.view( + batch_size, self.num_group, self.group_size, c + ).contiguous() + # normalize + neighborhood[:, :, :, :3] = neighborhood[:, :, :, :3] - center.unsqueeze(2) + + # can utilize morton_sorting by choosing morton_sorting function + sorted_indices = self.simplied_morton_sorting(xyz, center) + + neighborhood = neighborhood.view( + batch_size * self.num_group, self.group_size, c + )[sorted_indices, :, :] + neighborhood = neighborhood.view( + batch_size, self.num_group, self.group_size, c + ).contiguous() + center = center.view(batch_size * self.num_group, 3)[sorted_indices, :] + center = center.view(batch_size, self.num_group, 3).contiguous() + + return neighborhood, center + + +# Transformers +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = 
nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, x: torch.Tensor, mask: torch.Tensor | None = None + ) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if mask is not None: + attn = attn.masked_fill(mask, float("-inf")) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor | None = None + ) -> torch.Tensor: + B, N, C = y.shape + kv = ( + self.kv(y) + .reshape(B, N, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv.unbind(0) + + B, N, C = x.shape + q = ( + self.q(x) + .reshape(B, N, 1, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4)[0] + ) + + q, k = self.q_norm(q), self.k_norm(k) + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if mask is not None: + attn = attn.masked_fill(mask, float("-inf")) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: float | None = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if 
init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x, attn_mask=None): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class CrossBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: float | None = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + stop_grad: bool = False, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = CrossAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.stop_grad = stop_grad + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + if self.stop_grad: + x = x + self.drop_path1( + self.ls1(self.attn(self.norm1(x), self.norm1(y.detach()))) + ) + else: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), self.norm1(y)))) + + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ReConBlocks(nn.Module): + def __init__( + self, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: float | None = None, + proj_drop: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: list = [], + norm_layer: nn.Module = nn.LayerNorm, + act_layer: nn.Module = nn.GELU, + stop_grad: bool = False, + pretrained_model_name: str = "vit_base_patch32_clip_224.openai", + every_layer_add_pos: bool = True, + ): + super().__init__() + + self.depth = depth + self.stop_grad = stop_grad + self.pretrained_model_name = pretrained_model_name + self.every_layer_add_pos = every_layer_add_pos + if "dino" in self.pretrained_model_name: + init_values = 1e-5 + if "giant" in self.pretrained_model_name: + mlp_ratio = 48 / 11 + self.local_blocks = nn.Sequential(*[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate[i], + norm_layer=norm_layer, + act_layer=act_layer, + ) + for i in range(depth) + ]) + + self.global_blocks = nn.Sequential(*[ + CrossBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate[i], + norm_layer=norm_layer, + act_layer=act_layer, + stop_grad=stop_grad, + ) + for i in range(depth) + ]) + + def load_pretrained_timm_weights(self): + model = timm.create_model(self.pretrained_model_name, pretrained=True) + state_dict = model.blocks.state_dict() + self.local_blocks.load_state_dict(state_dict, 
strict=True) + + cross_state_dict = {} + for k, v in state_dict.items(): + if "qkv" in k: + cross_state_dict[k.replace("qkv", "q")] = v[: int(v.shape[0] / 3)] + cross_state_dict[k.replace("qkv", "kv")] = v[int(v.shape[0] / 3) :] + else: + cross_state_dict[k] = v + self.global_blocks.load_state_dict(cross_state_dict, strict=True) + + def forward(self, x, pos, attn_mask=None, query=None): + if self.every_layer_add_pos: + for i in range(self.depth): + x = self.local_blocks[i](x + pos, attn_mask) + if query is not None: + query = self.global_blocks[i](query, x) + else: + x = x + pos + for i in range(self.depth): + x = self.local_blocks[i](x, attn_mask) + if query is not None: + query = self.global_blocks[i](query, x) + return x, query + + +class GPTExtractor(nn.Module): + def __init__( + self, + embed_dim: int = 768, + num_heads: int = 12, + depth: int = 12, + group_size: int = 32, + drop_path_rate: float = 0.0, + stop_grad: bool = False, + pretrained_model_name: str = "vit_base_patch32_clip_224.openai", + ): + super().__init__() + + self.embed_dim = embed_dim + self.group_size = group_size + + # start of sequence token + self.sos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.sos_pos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + nn.init.normal_(self.sos) + nn.init.normal_(self.sos_pos) + + drop_path_rate = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self.blocks = ReConBlocks( + embed_dim=embed_dim, + num_heads=num_heads, + depth=depth, + drop_path_rate=drop_path_rate, + stop_grad=stop_grad, + pretrained_model_name=pretrained_model_name, + ) + + self.ln_f1 = nn.LayerNorm(embed_dim) + self.ln_f2 = nn.LayerNorm(embed_dim) + + def forward(self, x, pos, attn_mask, query): + """ + Expect input as shape [sequence len, batch] + """ + + batch, length, _ = x.shape + + # prepend sos token + sos = self.sos.expand(batch, -1, -1) + sos_pos = self.sos_pos.expand(batch, -1, -1) + + x = torch.cat([sos, x[:, :-1]], dim=1) + pos = torch.cat([sos_pos, pos[:, :-1]], dim=1) + + # transformer + x, query = self.blocks(x, pos, attn_mask, query) + + encoded_points = self.ln_f1(x) + query = self.ln_f2(query) + + return encoded_points, query + + +class GPTGenerator(nn.Module): + def __init__( + self, + embed_dim: int = 768, + num_heads: int = 12, + depth: int = 4, + group_size: int = 32, + drop_path_rate: float = 0.0, + input_channel: int = 3, + ): + super().__init__() + + self.embed_dim = embed_dim + self.input_channel = input_channel + + drop_path_rate = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self.blocks = nn.ModuleList([ + Block(dim=embed_dim, num_heads=num_heads, drop_path=drop_path_rate[i]) + for i in range(depth) + ]) + + self.ln_f = nn.LayerNorm(embed_dim) + self.increase_dim = nn.Sequential( + nn.Conv1d(embed_dim, input_channel * group_size, 1) + ) + + def forward(self, x, pos, attn_mask): + batch, length, C = x.shape + + # transformer + for block in self.blocks: + x = block(x + pos, attn_mask) + + x = self.ln_f(x) + + rebuild_points = ( + self.increase_dim(x.transpose(1, 2)) + .transpose(1, 2) + .reshape(batch * length, -1, self.input_channel) + ) + + return rebuild_points + + +class MAEExtractor(nn.Module): + def __init__( + self, + embed_dim: int = 768, + num_heads: int = 12, + depth: int = 12, + group_size: int = 32, + drop_path_rate: float = 0.0, + stop_grad: bool = False, + pretrained_model_name: str = "vit_base_patch32_clip_224.openai", + ): + super().__init__() + + self.embed_dim = embed_dim + self.group_size = group_size + + drop_path_rate = [x.item() for 
x in torch.linspace(0, drop_path_rate, depth)] + self.blocks = ReConBlocks( + embed_dim=embed_dim, + num_heads=num_heads, + depth=depth, + drop_path_rate=drop_path_rate, + stop_grad=stop_grad, + pretrained_model_name=pretrained_model_name, + ) + + self.ln_f1 = nn.LayerNorm(embed_dim) + self.ln_f2 = nn.LayerNorm(embed_dim) + + def forward(self, x, pos, mask=None, query=None): + """ + Expect input as shape [sequence len, batch] + """ + + batch, length, C = x.shape + if mask is not None: + x_vis = x[~mask].reshape(batch, -1, C) + pos_vis = pos[~mask].reshape(batch, -1, C) + else: + x_vis = x + pos_vis = pos + + # transformer + x_vis, query = self.blocks(x_vis, pos_vis, None, query) + + encoded_points = self.ln_f1(x_vis) + query = self.ln_f2(query) + + return encoded_points, query + + +class MAEGenerator(nn.Module): + def __init__( + self, + embed_dim: int = 768, + num_heads: int = 12, + depth: int = 4, + group_size: int = 32, + drop_path_rate: float = 0.0, + input_channel: int = 3, + ): + super().__init__() + + self.embed_dim = embed_dim + self.input_channel = input_channel + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + drop_path_rate = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self.blocks = nn.ModuleList([ + Block(dim=embed_dim, num_heads=num_heads, drop_path=drop_path_rate[i]) + for i in range(depth) + ]) + + self.ln_f = nn.LayerNorm(embed_dim) + self.increase_dim = nn.Sequential( + nn.Conv1d(embed_dim, input_channel * group_size, 1) + ) + + def forward(self, x_vis, pos, mask): + batch, length, C = x_vis.shape + + pos_vis = pos[~mask].reshape(batch, -1, C) + pos_mask = pos[mask].reshape(batch, -1, C) + pos_full = torch.cat([pos_vis, pos_mask], dim=1) + mask_token = self.mask_token.expand(batch, pos_mask.shape[1], -1) + x = torch.cat([x_vis, mask_token], dim=1) + + # transformer + for block in self.blocks: + x = block(x + pos_full) + + x = self.ln_f(x[:, -pos_mask.shape[1] :]) + + rebuild_points = ( + self.increase_dim(x.transpose(1, 2)) + .transpose(1, 2) + .reshape(batch * pos_mask.shape[1], -1, self.input_channel) + ) + + return rebuild_points diff --git a/backend/ReConV2/utils/checkpoint.py b/backend/ReConV2/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..1b78a8936a58c4a12488f65924064aba5b2958ac --- /dev/null +++ b/backend/ReConV2/utils/checkpoint.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from collections import defaultdict +from collections.abc import Iterable +from typing import Any + +import torch.nn as nn +from termcolor import colored + + +def get_missing_parameters_message(keys: list[str]) -> str: + """ + Get a logging-friendly message to report parameter names (keys) that are in + the model but not found in a checkpoint. + Args: + keys (list[str]): List of keys that were not found in the checkpoint. + Returns: + str: message. + """ + groups = _group_checkpoint_keys(keys) + msg = "Some model parameters or buffers are not found in the checkpoint:\n" + msg += "\n".join( + " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items() + ) + return msg + + +def get_unexpected_parameters_message(keys: list[str]) -> str: + """ + Get a logging-friendly message to report parameter names (keys) that are in + the checkpoint but not found in the model. + Args: + keys (list[str]): List of keys that were not found in the model. + Returns: + str: message. 
+ """ + groups = _group_checkpoint_keys(keys) + msg = "The checkpoint state_dict contains keys that are not used by the model:\n" + msg += "\n".join( + " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items() + ) + return msg + + +def _strip_prefix_if_present(state_dict: dict[str, Any], prefix: str) -> None: + """ + Strip the prefix in metadata, if any. + Args: + state_dict (OrderedDict): a state-dict to be loaded to the model. + prefix (str): prefix. + """ + keys = sorted(state_dict.keys()) + if not all(len(key) == 0 or key.startswith(prefix) for key in keys): + return + + for key in keys: + newkey = key[len(prefix) :] + state_dict[newkey] = state_dict.pop(key) + + # also strip the prefix in metadata, if any.. + try: + metadata = state_dict._metadata # pyre-ignore + except AttributeError: + pass + else: + for key in list(metadata.keys()): + # for the metadata dict, the key can be: + # '': for the DDP module, which we want to remove. + # 'module': for the actual model. + # 'module.xx.xx': for the rest. + + if len(key) == 0: + continue + newkey = key[len(prefix) :] + metadata[newkey] = metadata.pop(key) + + +def _group_checkpoint_keys(keys: list[str]) -> dict[str, list[str]]: + """ + Group keys based on common prefixes. A prefix is the string up to the final + "." in each key. + Args: + keys (list[str]): list of parameter names, i.e. keys in the model + checkpoint dict. + Returns: + dict[list]: keys with common prefixes are grouped into lists. + """ + groups = defaultdict(list) + for key in keys: + pos = key.rfind(".") + if pos >= 0: + head, tail = key[:pos], [key[pos + 1 :]] + else: + head, tail = key, [] + groups[head].extend(tail) + return groups + + +def _group_to_str(group: list[str]) -> str: + """ + Format a group of parameter name suffixes into a loggable string. + Args: + group (list[str]): list of parameter name suffixes. + Returns: + str: formated string. + """ + if len(group) == 0: + return "" + + if len(group) == 1: + return "." + group[0] + + return ".{" + ", ".join(group) + "}" + + +def _named_modules_with_dup( + model: nn.Module, prefix: str = "" +) -> Iterable[tuple[str, nn.Module]]: + """ + The same as `model.named_modules()`, except that it includes + duplicated modules that have more than one name. + """ + yield prefix, model + for name, module in model._modules.items(): # pyre-ignore + if module is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + yield from _named_modules_with_dup(module, submodule_prefix) diff --git a/backend/ReConV2/utils/config.py b/backend/ReConV2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..33d8319a010b3891bb65204cf7f4902c8dcefb8f --- /dev/null +++ b/backend/ReConV2/utils/config.py @@ -0,0 +1,73 @@ +import os + +import yaml +from easydict import EasyDict + +from .logger import print_log + + +def log_args_to_file(args, pre="args", logger=None): + for key, val in args.__dict__.items(): + print_log(f"{pre}.{key} : {val}", logger=logger) + + +def log_config_to_file(cfg, pre="cfg", logger=None): + for key, val in cfg.items(): + if isinstance(cfg[key], EasyDict): + print_log(f"{pre}.{key} = edict()", logger=logger) + log_config_to_file(cfg[key], pre=pre + "." 
+ key, logger=logger) + continue + print_log(f"{pre}.{key} : {val}", logger=logger) + + +def merge_new_config(config, new_config): + for key, val in new_config.items(): + if not isinstance(val, dict): + if key == "_base_": + with open(new_config["_base_"]) as f: + try: + val = yaml.load(f, Loader=yaml.FullLoader) + except: + val = yaml.load(f) + config[key] = EasyDict() + merge_new_config(config[key], val) + else: + config[key] = val + continue + if key not in config: + config[key] = EasyDict() + merge_new_config(config[key], val) + return config + + +def cfg_from_yaml_file(cfg_file): + config = EasyDict() + with open(cfg_file) as f: + try: + new_config = yaml.load(f, Loader=yaml.FullLoader) + except: + new_config = yaml.load(f) + merge_new_config(config=config, new_config=new_config) + return config + + +def get_config(args, logger=None): + if args.resume: + cfg_path = os.path.join(args.experiment_path, "config.yaml") + if not os.path.exists(cfg_path): + print_log("Failed to resume", logger=logger) + raise FileNotFoundError() + print_log(f"Resume yaml from {cfg_path}", logger=logger) + args.config = cfg_path + config = cfg_from_yaml_file(args.config) + if not args.resume and args.local_rank == 0: + save_experiment_config(args, config, logger) + return config + + +def save_experiment_config(args, config, logger=None): + config_path = os.path.join(args.experiment_path, "config.yaml") + os.system(f"cp {args.config} {config_path}") + print_log( + f"Copy the Config file from {args.config} to {config_path}", logger=logger + ) diff --git a/backend/ReConV2/utils/knn.py b/backend/ReConV2/utils/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..55add0b39056eae2d2ca15ae737f1039b09f4057 --- /dev/null +++ b/backend/ReConV2/utils/knn.py @@ -0,0 +1,37 @@ +import torch + + +def square_distance(src, dst): + """ + Calculate Euclid distance between each two points. + src^T * dst = xn * xm + yn * ym + zn * zm; + sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; + sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; + dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 + = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst + Input: + src: source points, [B, N, C] + dst: target points, [B, M, C] + Output: + dist: per-point square distance, [B, N, M] + """ + B, N, _ = src.shape + _, M, _ = dst.shape + dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) + dist += torch.sum(src**2, -1).view(B, N, 1) + dist += torch.sum(dst**2, -1).view(B, 1, M) + return dist + + +def knn_point(nsample, xyz, new_xyz): + """ + Input: + nsample: max sample number in local region + xyz: all points, [B, N, C] + new_xyz: query points, [B, S, C] + Return: + group_idx: grouped points index, [B, S, nsample] + """ + sqrdists = square_distance(new_xyz, xyz) + _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) + return group_idx diff --git a/backend/ReConV2/utils/logger.py b/backend/ReConV2/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..97ef82b07ab550354bec4e6bf28338635b91d51c --- /dev/null +++ b/backend/ReConV2/utils/logger.py @@ -0,0 +1,130 @@ +import logging + +import torch.distributed as dist + +logger_initialized = {} + + +def get_root_logger(log_file=None, log_level=logging.INFO, name="main"): + """Get root logger and add a keyword filter to it. + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. 
The name of the root logger is the top-level package name, + e.g., "mmdet3d". + Args: + log_file (str, optional): File path of log. Defaults to None. + log_level (int, optional): The level of logger. + Defaults to logging.INFO. + name (str, optional): The name of the root logger, also used as a + filter keyword. Defaults to 'mmdet3d'. + Returns: + :obj:`logging.Logger`: The obtained logger + """ + logger = get_logger(name=name, log_file=log_file, log_level=log_level) + # add a logging filter + logging_filter = logging.Filter(name) + logging_filter.filter = lambda record: record.find(name) != -1 + + return logger + + +def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="w"): + """Initialize and get a logger by name. + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified and the process rank is 0, a FileHandler + will also be added. + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + file_mode (str): The file mode used in opening log file. + Defaults to 'w'. + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + # handle duplicate logs to the console + # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) + # to the root logger. As logger.propagate is True by default, this root + # level handler causes logging messages from rank>0 processes to + # unexpectedly show up on the console, creating much unwanted clutter. + # To fix this issue, we set the root logger's StreamHandler, if any, to log + # at the ERROR level. + for handler in logger.root.handlers: + if type(handler) is logging.StreamHandler: + handler.setLevel(logging.ERROR) + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # only rank 0 will add a FileHandler + if rank == 0 and log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + if rank == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + + logger_initialized[name] = True + + return logger + + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + Args: + msg (str): The message to be logged. + logger (logging.Logger | str | None): The logger to be used. + Some special loggers are: + - "silent": no message will be printed. 
+ - other str: the logger obtained with `get_root_logger(logger)`. + - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". + """ + if logger is None: + print(msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger == "silent": + pass + elif isinstance(logger, str): + _logger = get_logger(logger) + _logger.log(level, msg) + else: + raise TypeError( + "logger should be either a logging.Logger object, str, " + f'"silent" or None, but got {type(logger)}' + ) diff --git a/backend/ReConV2/utils/misc.py b/backend/ReConV2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..569c736481d95e20a3b7708767150ef54c4dafa0 --- /dev/null +++ b/backend/ReConV2/utils/misc.py @@ -0,0 +1,287 @@ +import os +import random +from collections import abc + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mpl_toolkits.mplot3d import Axes3D + + +def fps(data: torch.Tensor, number: int) -> torch.Tensor: + B, N, _ = data.shape + device = data.device + + centroids = torch.empty(B, number, dtype=torch.long, device=device) + distances = torch.full((B, N), float("inf"), device=device) + farthest = torch.randint(0, N, (B,), device=device) # случайная первая + + for i in range(number): + centroids[:, i] = farthest + + centroid = data[torch.arange(B, device=device), farthest] # (B,3) + dist = torch.sum((data - centroid[:, None, :]) ** 2, dim=-1) + + distances = torch.minimum(distances, dist) + farthest = torch.max(distances, dim=1).indices # чуть короче + # (или .indices в ≥1.10) + return data.gather(1, centroids[..., None].expand(-1, -1, 3)) + + +def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + + +def build_lambda_sche(opti, config): + if config.get("decay_step") is not None: + + def lr_lbmd(e): + return max(config.lr_decay ** (e / config.decay_step), config.lowest_decay) + + scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd) + else: + raise NotImplementedError() + return scheduler + + +def build_lambda_bnsche(model, config): + if config.get("decay_step") is not None: + + def bnm_lmbd(e): + return max( + config.bn_momentum * config.bn_decay ** (e / config.decay_step), + config.lowest_decay, + ) + + bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd) + else: + raise NotImplementedError() + return bnm_scheduler + + +def set_random_seed(seed, deterministic=False): + """Set random seed. + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + + # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html + if cuda_deterministic: # slower, more reproducible + cudnn.deterministic = True + cudnn.benchmark = False + else: # faster, less reproducible + cudnn.deterministic = False + cudnn.benchmark = True + + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. 
+ seq_type (type, optional): Expected sequence type. + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def set_bn_momentum_default(bn_momentum): + def fn(m): + if isinstance(m, nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d): + m.momentum = bn_momentum + + return fn + + +class BNMomentumScheduler: + def __init__(self, model, bn_lambda, last_epoch=-1, setter=set_bn_momentum_default): + if not isinstance(model, nn.Module): + raise RuntimeError( + f"Class '{type(model).__name__}' is not a PyTorch nn Module" + ) + + self.model = model + self.setter = setter + self.lmbd = bn_lambda + + self.step(last_epoch + 1) + self.last_epoch = last_epoch + + def step(self, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + + self.last_epoch = epoch + self.model.apply(self.setter(self.lmbd(epoch))) + + def get_momentum(self, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + return self.lmbd(epoch) + + +def seprate_point_cloud(xyz, num_points, crop, fixed_points=None, padding_zeros=False): + """ + seprate point cloud: usage : using to generate the incomplete point cloud with a setted number. + """ + _, n, c = xyz.shape + + assert n == num_points + assert c == 3 + if crop == num_points: + return xyz, None + + INPUT = [] + CROP = [] + for points in xyz: + if isinstance(crop, list): + num_crop = random.randint(crop[0], crop[1]) + else: + num_crop = crop + + points = points.unsqueeze(0) + + if fixed_points is None: + center = F.normalize(torch.randn(1, 1, 3), p=2, dim=-1).cuda() + else: + if isinstance(fixed_points, list): + fixed_point = random.sample(fixed_points, 1)[0] + else: + fixed_point = fixed_points + center = fixed_point.reshape(1, 1, 3).cuda() + + distance_matrix = torch.norm( + center.unsqueeze(2) - points.unsqueeze(1), p=2, dim=-1 + ) # 1 1 2048 + + idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0, 0] # 2048 + + if padding_zeros: + input_data = points.clone() + input_data[0, idx[:num_crop]] = input_data[0, idx[:num_crop]] * 0 + + else: + input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3 + + crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0) + + if isinstance(crop, list): + INPUT.append(fps(input_data, 2048)) + CROP.append(fps(crop_data, 2048)) + else: + INPUT.append(input_data) + CROP.append(crop_data) + + input_data = torch.cat(INPUT, dim=0) # B N 3 + crop_data = torch.cat(CROP, dim=0) # B M 3 + + return input_data.contiguous(), crop_data.contiguous() + + +def get_ptcloud_img(ptcloud, roll, pitch): + fig = plt.figure(figsize=(8, 8)) + + x, z, y = ptcloud.transpose(1, 0) + ax = fig.gca(projection=Axes3D.name, adjustable="box") + ax.axis("off") + # ax.axis('scaled') + ax.view_init(roll, pitch) + max, min = np.max(ptcloud), np.min(ptcloud) + ax.set_xbound(min, max) + ax.set_ybound(min, max) + ax.set_zbound(min, max) + ax.scatter(x, y, z, zdir="z", c=y, cmap="jet") + + fig.canvas.draw() + img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return img + + +def visualize_KITTI( + path, + data_list, + titles=["input", "pred"], + cmap=["bwr", "autumn"], + zdir="y", + xlim=(-1, 1), + ylim=(-1, 1), + zlim=(-1, 1), +): + fig = plt.figure(figsize=(6 * len(data_list), 6)) + cmax = 
data_list[-1][:, 0].max() + + for i in range(len(data_list)): + data = data_list[i][:-2048] if i == 1 else data_list[i] + color = data[:, 0] / cmax + ax = fig.add_subplot(1, len(data_list), i + 1, projection="3d") + ax.view_init(30, -120) + ax.scatter( + data[:, 0], + data[:, 1], + data[:, 2], + zdir=zdir, + c=color, + vmin=-1, + vmax=1, + cmap=cmap[0], + s=4, + linewidth=0.05, + edgecolors="black", + ) + ax.set_title(titles[i]) + + ax.set_axis_off() + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.set_zlim(zlim) + plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0) + if not os.path.exists(path): + os.makedirs(path) + + pic_path = path + ".png" + fig.savefig(pic_path) + + np.save(os.path.join(path, "input.npy"), data_list[0].numpy()) + np.save(os.path.join(path, "pred.npy"), data_list[1].numpy()) + plt.close(fig) + + +def random_dropping(pc, e): + up_num = max(64, 768 // (e // 50 + 1)) + pc = pc + random_num = torch.randint(1, up_num, (1, 1))[0, 0] + pc = fps(pc, random_num) + padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device) + pc = torch.cat([pc, padding], dim=1) + return pc + + +def random_scale(partial, scale_range=[0.8, 1.2]): + scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0] + return partial * scale diff --git a/backend/ReConV2/utils/registry.py b/backend/ReConV2/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..34e6949ed67194691a0a6d0fa952e55fafe1b567 --- /dev/null +++ b/backend/ReConV2/utils/registry.py @@ -0,0 +1,290 @@ +import inspect +import warnings +from functools import partial + +from ReConV2.utils import config + + +class Registry: + """A registry to map strings to classes. + Registered object could be built from registry. + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(NAME='ResNet')) + Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for + advanced useage. + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. 
build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = ( + self.__class__.__name__ + f"(name={self._name}, items={self._module_dict})" + ) + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. + The name of the package where registry is defined will be returned. + Example: + # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + Returns: + scope (str): The inferred scope name. + """ + # inspect.stack() trace where this function is called, the index-2 + # indicates the frame where `infer_scope()` is called + filename = inspect.getmodule(inspect.stack()[2][0]).__name__ + split_filename = filename.split(".") + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + The first scope will be split from key. + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + Return: + scope (str, None): The first scope. + key (str): The remaining key. + """ + split_index = key.find(".") + if split_index != -1: + return key[:split_index], key[split_index + 1 :] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + Args: + key (str): The class name in string format. + Returns: + class: The corresponding class. + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. 
+ Example: + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> @mmdet_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(NAME='mmdet.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert registry.scope not in self.children, ( + f"scope {registry.scope} exists in {self.name} registry" + ) + self.children[registry.scope] = registry + + def _register_module(self, module_class, module_name=None, force=False): + if not inspect.isclass(module_class): + raise TypeError(f"module must be a class, but got {type(module_class)}") + + if module_name is None: + module_name = module_class.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f"{name} is already registered in {self.name}") + self._module_dict[name] = module_class + + def deprecated_register_module(self, cls=None, force=False): + warnings.warn( + "The old API of register_module(module, force=False) " + "is deprecated and will be removed, please use the new API " + "register_module(name=None, force=False, module=None) instead." + ) + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + Example: + >>> backbones = Registry('backbone') + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + >>> backbones = Registry('backbone') + >>> class ResNet: + >>> pass + >>> backbones.register_module(ResNet) + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + # NOTE: This is a walkaround to be compatible with the old api, + # while it may introduce unexpected bugs. + if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)): + raise TypeError( + "name must be either of None, an instance of str or a sequence" + f" of str, but got {type(name)}" + ) + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module(module_class=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module(module_class=cls, module_name=name, force=force) + return cls + + return _register + + +def build_from_cfg(cfg, registry, default_args=None): + """Build a module from config dict. + Args: + cfg (edict): Config dict. It should at least contain the key "NAME". + registry (:obj:`Registry`): The registry to search the type from. + Returns: + object: The constructed object. 
+ """ + if not isinstance(cfg, dict): + raise TypeError(f"cfg must be a dict, but got {type(cfg)}") + if "NAME" not in cfg: + if default_args is None or "NAME" not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "NAME", ' + f"but got {cfg}\n{default_args}" + ) + if not isinstance(registry, Registry): + raise TypeError( + f"registry must be an mmcv.Registry object, but got {type(registry)}" + ) + + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError( + f"default_args must be a dict or None, but got {type(default_args)}" + ) + + if default_args is not None: + cfg = config.merge_new_config(cfg, default_args) + + obj_type = cfg.get("NAME") + + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError(f"{obj_type} is not in the {registry.name} registry") + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}") + try: + return obj_cls(cfg) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f"{obj_cls.__name__}: {e}") diff --git a/backend/cad_retrieval_utils/__init__.py b/backend/cad_retrieval_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71fb043e1f2a8628d929959c3ad63bdef2935bd2 --- /dev/null +++ b/backend/cad_retrieval_utils/__init__.py @@ -0,0 +1,3 @@ +from .inference import make_submission + +__all__ = ["make_submission"] diff --git a/backend/cad_retrieval_utils/augmentations.py b/backend/cad_retrieval_utils/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..0ee4df51134a06244b44fa5dc9b972ca7544cd85 --- /dev/null +++ b/backend/cad_retrieval_utils/augmentations.py @@ -0,0 +1,15 @@ +from typing import cast + +import torchvision.transforms as T + +from .type_defs import ImageTransform + + +def build_img_transforms(img_size: int) -> ImageTransform: + transform = T.Compose([ + T.Resize(img_size), + T.CenterCrop(img_size), + T.ToTensor(), + T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ]) + return cast(ImageTransform, transform) diff --git a/backend/cad_retrieval_utils/configs/config.py b/backend/cad_retrieval_utils/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..601531018f9e7c1c0302e73b31372eb67591c0b1 --- /dev/null +++ b/backend/cad_retrieval_utils/configs/config.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import torch +from easydict import EasyDict as edict + +CONFIG = edict() + +# --- Конфиг pretrained модели Recon для загрузки --- +CONFIG.model = edict({ + "NAME": "ReCon2", + "group_size": 32, + "num_group": 512, + "mask_ratio": 0.7, + "mask_type": "rand", + "embed_dim": 1024, + "depth": 24, + "drop_path_rate": 0.2, + "num_heads": 16, + "decoder_depth": 4, + "with_color": True, + "stop_grad": False, + "large_embedding": False, + "img_queries": 13, + "text_queries": 3, + "contrast_type": "byol", + "pretrained_model_name": "eva_large_patch14_336.in22k_ft_in22k_in1k", +}) + +# --- Общие параметры --- +CONFIG.npoints = 10_000 +CONFIG.emb_dim = 1280 +CONFIG.img_size = 336 +CONFIG.seed = 42 +CONFIG.device = torch.device("cpu") +CONFIG.text_ratio = 0.3 + +# --- Параметры инференса --- +CONFIG.infer_img_batch_size = 32 +CONFIG.infer_pc_batch_size = 16 +CONFIG.infer_text_batch_size = 32 + +# --- Параметры для MoE --- +CONFIG.train_params = edict() +CONFIG.train_params.n_experts = 8 + +# --- Пути --- +CONFIG.paths = edict() 
+CONFIG.paths.test_data_root = Path("/kaggle/input/test-final/test") +CONFIG.paths.submission_save_file = Path("./submission.csv") + +# Эти пути будут перезаписаны из командной строки inference_runner.py +CONFIG.paths.model_spec = { + "text_proj": None, + "text_encoder": None, + "moe": None, + "pc_encoder": None +} diff --git a/backend/cad_retrieval_utils/datasets.py b/backend/cad_retrieval_utils/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..915842df7627571c069f5379cb1e702fb95441ef --- /dev/null +++ b/backend/cad_retrieval_utils/datasets.py @@ -0,0 +1,80 @@ +from pathlib import Path + +import numpy as np +import torch +import trimesh +from PIL import Image +from torch.utils.data import Dataset + +from .type_defs import ImageTransform + + +def normalize_pc(pc: np.ndarray) -> np.ndarray: + centroid = np.mean(pc, axis=0) + pc = pc - centroid + m = np.max(np.sqrt(np.sum(pc**2, axis=1))) + if m < 1e-6: + return pc + pc = pc / m + return pc + + +def create_pc_tensor_with_dummy_color(pc: np.ndarray, npoints: int) -> torch.Tensor: + pc_with_dummy_color = np.zeros((npoints, 6), dtype=np.float32) + pc_with_dummy_color[:, :3] = pc + # Модель ReConV2 ожидает 6 каналов (XYZ + RGB), добавляем нейтральный серый + pc_with_dummy_color[:, 3:6] = 0.5 + return torch.from_numpy(pc_with_dummy_color).float() + + +def load_mesh_safe(mesh_path: Path, npoints: int, seed: int) -> np.ndarray: + """Безопасная загрузка меша с обработкой Scene объектов""" + mesh_data = trimesh.load(str(mesh_path)) + mesh = mesh_data.to_mesh() if isinstance(mesh_data, trimesh.Scene) else mesh_data + pc, _ = trimesh.sample.sample_surface(mesh, npoints, seed=seed) + return np.array(pc, dtype=np.float32) + + +class InferenceMeshDataset(Dataset): + def __init__(self, file_paths: list[str], npoints: int, base_seed: int = 42) -> None: + self.file_paths = file_paths + self.npoints = npoints + self.base_seed = base_seed + + def __len__(self) -> int: + return len(self.file_paths) + + def __getitem__(self, idx: int) -> torch.Tensor: + pc_path = Path(self.file_paths[idx]) + sample_seed = self.base_seed + idx + pc = load_mesh_safe(pc_path, self.npoints, sample_seed) + pc = normalize_pc(pc) + return create_pc_tensor_with_dummy_color(pc, self.npoints) + + +class InferenceImageDataset(Dataset): + def __init__(self, file_paths: list[str], transform: ImageTransform) -> None: + self.file_paths = file_paths + self.transform = transform + + def __len__(self) -> int: + return len(self.file_paths) + + def __getitem__(self, idx: int) -> torch.Tensor: + img_path = self.file_paths[idx] + img = Image.open(img_path).convert("RGB") + return self.transform(img) + + +class InferenceTextDataset(Dataset): + def __init__(self, file_paths: list[str]) -> None: + self.texts = [] + for path in file_paths: + with open(path) as f: + self.texts.append(f.read().strip()) + + def __len__(self) -> int: + return len(self.texts) + + def __getitem__(self, idx: int) -> str: + return self.texts[idx] diff --git a/backend/cad_retrieval_utils/evaluation.py b/backend/cad_retrieval_utils/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1ddf707bc60c2a110392b1ca8c39d0c24bfd20 --- /dev/null +++ b/backend/cad_retrieval_utils/evaluation.py @@ -0,0 +1,43 @@ +import numpy as np +import torch +from easydict import EasyDict as edict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +from .models import ImageEncoder, InferencePcEncoder, InferenceTextEncoder +from .type_defs import EmbeddingArray + + 
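+# The three helpers below share one pattern: iterate over a DataLoader, encode
+# each batch with normalize=True (image/mesh batches are moved to config.device
+# first; text batches are plain lists of strings and are tokenized inside the
+# encoder), then stack the per-batch results into a single (N, emb_dim) array.
+#
+# Minimal usage sketch (assumes the loaders built by inference.prepare_all_data
+# and an encoder loaded via inference.load_text_encoder):
+#   text_encoder = load_text_encoder(CONFIG.paths.model_spec, CONFIG)
+#   query_embs = get_inference_embeddings_text(text_encoder, loaders["q_text"], CONFIG)
+#   # query_embs.shape == (len(loaders["q_text"].dataset), CONFIG.emb_dim)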
+@torch.no_grad() +def get_inference_embeddings_text( + model: InferenceTextEncoder, loader: DataLoader, config: edict +) -> EmbeddingArray: + all_embs = [] + for batch in tqdm(loader, desc="Извлечение text эмбеддингов"): + embs = model.encode_text(batch, normalize=True) + all_embs.append(embs.cpu().numpy()) + return np.vstack(all_embs) + + +@torch.no_grad() +def get_inference_embeddings_mesh( + model: InferencePcEncoder, loader: DataLoader, config: edict +) -> EmbeddingArray: + all_embs = [] + for batch in tqdm(loader, desc="Извлечение mesh эмбеддингов"): + batch = batch.to(config.device) + embs = model.encode_pc(batch, normalize=True) + all_embs.append(embs.cpu().numpy()) + return np.vstack(all_embs) + + +@torch.no_grad() +def get_inference_embeddings_image( + model: ImageEncoder, loader: DataLoader, config: edict +) -> EmbeddingArray: + all_embs = [] + for batch in tqdm(loader, desc="Извлечение image эмбеддингов"): + batch = batch.to(config.device) + embs = model.encode_image(batch, normalize=True) + all_embs.append(embs.cpu().numpy()) + return np.vstack(all_embs) diff --git a/backend/cad_retrieval_utils/inference.py b/backend/cad_retrieval_utils/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d18b2a36805620840fddeea0d7a78b72277bfe56 --- /dev/null +++ b/backend/cad_retrieval_utils/inference.py @@ -0,0 +1,242 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +from easydict import EasyDict as edict +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.preprocessing import normalize +from torch.utils.data import DataLoader + +from .augmentations import build_img_transforms +from .datasets import InferenceImageDataset, InferenceMeshDataset, InferenceTextDataset +from .evaluation import ( + get_inference_embeddings_image, + get_inference_embeddings_mesh, + get_inference_embeddings_text, +) +from .models import ImageEncoder, InferencePcEncoder, InferenceTextEncoder +from .type_defs import CheckpointSpec + + +# --- Загрузчики моделей --- +def load_text_encoder(spec: CheckpointSpec, config: edict) -> InferenceTextEncoder: + text_encoder = InferenceTextEncoder(config).to(config.device) + text_encoder.load_text_weights(str(spec["text_proj"]), str(spec["text_encoder"])) + text_encoder.eval() + return text_encoder + + +def load_pc_encoder(spec: CheckpointSpec, config: edict) -> InferencePcEncoder: + pc_encoder = InferencePcEncoder(config).to(config.device) + pc_encoder.load_pc_encoder_weights(str(spec["pc_encoder"])) + pc_encoder.eval() + return pc_encoder + + +def load_image_encoder(spec: CheckpointSpec, config: edict) -> ImageEncoder: + img_encoder = ImageEncoder(config).to(config.device) + img_encoder.load_moe_weights(str(spec["moe"])) + img_encoder.eval() + return img_encoder + + +# --- Подготовка данных --- +def prepare_all_data(config: edict) -> dict: + test_root = Path(config.paths.test_data_root) + img_transform = build_img_transforms(config.img_size) + data_loaders = {} + data_ids = {} + + # Image-to-Mesh + q_img_paths = sorted(test_root.joinpath("queries_image_to_mesh").glob("*.png")) + g_mesh_for_img_paths = sorted(test_root.joinpath("gallery_mesh_for_image").glob("*.stl")) + data_loaders['q_img'] = DataLoader(InferenceImageDataset([str(p) for p in q_img_paths], img_transform), + batch_size=config.infer_img_batch_size, shuffle=False) + data_loaders['g_mesh_for_img'] = DataLoader( + InferenceMeshDataset([str(p) for p in g_mesh_for_img_paths], config.npoints, config.seed), + batch_size=config.infer_pc_batch_size, 
shuffle=False) + data_ids['q_img'] = [p.stem for p in q_img_paths] + data_ids['g_mesh_for_img'] = [p.stem for p in g_mesh_for_img_paths] + + # Mesh-to-Image + q_mesh_to_img_paths = sorted(test_root.joinpath("queries_mesh_to_image").glob("*.stl")) + g_img_for_mesh_paths = sorted(test_root.joinpath("gallery_image_for_mesh").glob("*.png")) + data_loaders['q_mesh_to_img'] = DataLoader( + InferenceMeshDataset([str(p) for p in q_mesh_to_img_paths], config.npoints, config.seed), + batch_size=config.infer_pc_batch_size, shuffle=False) + data_loaders['g_img_for_mesh'] = DataLoader( + InferenceImageDataset([str(p) for p in g_img_for_mesh_paths], img_transform), + batch_size=config.infer_img_batch_size, shuffle=False) + data_ids['q_mesh_to_img'] = [p.stem for p in q_mesh_to_img_paths] + data_ids['g_img_for_mesh_paths'] = g_img_for_mesh_paths # Нужны полные пути для группировки + + # Text-to-Mesh + q_text_paths = sorted(test_root.joinpath("queries_text_to_mesh").glob("*.txt")) + g_mesh_for_text_paths = sorted(test_root.joinpath("gallery_mesh_for_text").glob("*.stl")) + data_loaders['q_text'] = DataLoader(InferenceTextDataset([str(p) for p in q_text_paths]), + batch_size=config.infer_text_batch_size, shuffle=False) + data_loaders['g_mesh_for_text'] = DataLoader( + InferenceMeshDataset([str(p) for p in g_mesh_for_text_paths], config.npoints, config.seed), + batch_size=config.infer_pc_batch_size, shuffle=False) + data_ids['q_text'] = [p.stem for p in q_text_paths] + data_ids['g_mesh_for_text'] = [p.stem for p in g_mesh_for_text_paths] + + # Mesh-to-Text + q_mesh_to_text_paths = sorted(test_root.joinpath("queries_mesh_to_text").glob("*.stl")) + g_text_for_mesh_paths = sorted(test_root.joinpath("gallery_text_for_mesh").glob("*.txt")) + data_loaders['q_mesh_to_text'] = DataLoader( + InferenceMeshDataset([str(p) for p in q_mesh_to_text_paths], config.npoints, config.seed), + batch_size=config.infer_pc_batch_size, shuffle=False) + data_loaders['g_text_for_mesh'] = DataLoader(InferenceTextDataset([str(p) for p in g_text_for_mesh_paths]), + batch_size=config.infer_text_batch_size, shuffle=False) + data_ids['q_mesh_to_text'] = [p.stem for p in q_mesh_to_text_paths] + data_ids['g_text_for_mesh'] = [p.stem for p in g_text_for_mesh_paths] + + return data_loaders, data_ids + + +# --- Решатели задач --- +def solve_img2mesh(loaders, ids, model_spec, config) -> pd.DataFrame: + print(" 🖼️ → 📦 Image-to-Mesh: получение эмбеддингов...") + img_encoder = load_image_encoder(model_spec, config) + pc_encoder = load_pc_encoder(model_spec, config) + + query_embs = get_inference_embeddings_image(img_encoder, loaders['q_img'], config) + gallery_embs = get_inference_embeddings_mesh(pc_encoder, loaders['g_mesh_for_img'], config) + + sims = cosine_similarity(query_embs, gallery_embs) + top_indices = np.argsort(sims, axis=1)[:, ::-1][:, :3] + + results = {q_id: [ids['g_mesh_for_img'][j] for j in top_indices[i]] for i, q_id in enumerate(ids['q_img'])} + df = pd.DataFrame(list(results.items()), columns=["image_to_mesh_image", "image_to_mesh_mesh"]) + return df.sort_values("image_to_mesh_image").reset_index(drop=True) + + +def solve_mesh2img(loaders, ids, model_spec, config) -> pd.DataFrame: + print(" 📦 → 🖼️ Mesh-to-Image: получение эмбеддингов...") + pc_encoder = load_pc_encoder(model_spec, config) + img_encoder = load_image_encoder(model_spec, config) + + query_embs = get_inference_embeddings_mesh(pc_encoder, loaders['q_mesh_to_img'], config) + gallery_embs = get_inference_embeddings_image(img_encoder, loaders['g_img_for_mesh'], config) 
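+    # The image gallery may contain several views per CAD model. Their
+    # embeddings are grouped by model_id (the filename prefix before the first
+    # "_"), mean-pooled, and L2-normalized again, so each gallery model is
+    # represented by a single unit-length embedding before the cosine-similarity
+    # ranking below.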
+ + gallery_img_model_ids = [p.name.split("_")[0] for p in ids['g_img_for_mesh_paths']] + df_gallery = pd.DataFrame(gallery_embs) + df_gallery["model_id"] = gallery_img_model_ids + mean_embs_df = df_gallery.groupby("model_id").mean() + + avg_gallery_embs = normalize(mean_embs_df.to_numpy(), axis=1) + avg_gallery_ids = mean_embs_df.index.tolist() + + sims = cosine_similarity(query_embs, avg_gallery_embs) + top_indices = np.argsort(sims, axis=1)[:, ::-1][:, :3] + + results = {q_id: [avg_gallery_ids[j] for j in top_indices[i]] for i, q_id in enumerate(ids['q_mesh_to_img'])} + df = pd.DataFrame(list(results.items()), columns=["mesh_to_image_mesh", "mesh_to_image_image"]) + return df.sort_values("mesh_to_image_mesh").reset_index(drop=True) + + +def solve_text2mesh(loaders, ids, model_spec, config) -> pd.DataFrame: + print(" 📝 → 📦 Text-to-Mesh: получение эмбеддингов...") + text_encoder = load_text_encoder(model_spec, config) + pc_encoder = load_pc_encoder(model_spec, config) + + query_embs = get_inference_embeddings_text(text_encoder, loaders['q_text'], config) + gallery_embs = get_inference_embeddings_mesh(pc_encoder, loaders['g_mesh_for_text'], config) + + sims = cosine_similarity(query_embs, gallery_embs) + top_indices = np.argsort(sims, axis=1)[:, ::-1][:, :3] + + results = {q_id: [ids['g_mesh_for_text'][j] for j in top_indices[i]] for i, q_id in enumerate(ids['q_text'])} + df = pd.DataFrame(list(results.items()), columns=["text_to_mesh_text", "text_to_mesh_mesh"]) + return df.sort_values("text_to_mesh_text").reset_index(drop=True) + + +def solve_mesh2text(loaders, ids, model_spec, config) -> pd.DataFrame: + print(" 📦 → 📝 Mesh-to-Text: получение эмбеддингов...") + pc_encoder = load_pc_encoder(model_spec, config) + text_encoder = load_text_encoder(model_spec, config) + + query_embs = get_inference_embeddings_mesh(pc_encoder, loaders['q_mesh_to_text'], config) + gallery_embs = get_inference_embeddings_text(text_encoder, loaders['g_text_for_mesh'], config) + + sims = cosine_similarity(query_embs, gallery_embs) + top_indices = np.argsort(sims, axis=1)[:, ::-1][:, :3] + + results = {q_id: [ids['g_text_for_mesh'][j] for j in top_indices[i]] for i, q_id in + enumerate(ids['q_mesh_to_text'])} + df = pd.DataFrame(list(results.items()), columns=["mesh_to_text_mesh", "mesh_to_text_text"]) + return df.sort_values("mesh_to_text_mesh").reset_index(drop=True) + + +# --- Главная функция --- +def make_submission(config: edict) -> None: + print("\n" + "=" * 60) + print("🚀 Создание submission файла для всех 4 задач") + print("=" * 60) + + model_spec = config.paths.model_spec + loaders, ids = prepare_all_data(config) + + # Решаем все задачи + text2mesh_df = solve_text2mesh(loaders, ids, model_spec, config) + mesh2text_df = solve_mesh2text(loaders, ids, model_spec, config) + img2mesh_df = solve_img2mesh(loaders, ids, model_spec, config) + mesh2img_df = solve_mesh2img(loaders, ids, model_spec, config) + + # Создаем финальный DataFrame с правильной структурой + # 2187 строк для image задач + 100 строк для text задач = 2287 строк + total_rows = 2287 + final_df = pd.DataFrame(index=range(total_rows)) + + # Добавляем колонку id + final_df["id"] = final_df.index + + # Инициализируем все колонки как None + for col in ["image_to_mesh_image", "image_to_mesh_mesh", + "mesh_to_image_mesh", "mesh_to_image_image", + "text_to_mesh_text", "text_to_mesh_mesh", + "mesh_to_text_mesh", "mesh_to_text_text"]: + final_df[col] = None + + # Заполняем image задачи (первые 2187 строк) + # Используем .at для присвоения списков + for i in 
range(len(img2mesh_df)): + final_df.at[i, "image_to_mesh_image"] = img2mesh_df.loc[i, "image_to_mesh_image"] + final_df.at[i, "image_to_mesh_mesh"] = img2mesh_df.loc[i, "image_to_mesh_mesh"] + + for i in range(len(mesh2img_df)): + final_df.at[i, "mesh_to_image_mesh"] = mesh2img_df.loc[i, "mesh_to_image_mesh"] + final_df.at[i, "mesh_to_image_image"] = mesh2img_df.loc[i, "mesh_to_image_image"] + + # Заполняем text задачи (последние 100 строк, начиная с индекса 2187) + text_start_idx = 2187 + for i in range(len(text2mesh_df)): + final_df.at[text_start_idx + i, "text_to_mesh_text"] = text2mesh_df.loc[i, "text_to_mesh_text"] + final_df.at[text_start_idx + i, "text_to_mesh_mesh"] = text2mesh_df.loc[i, "text_to_mesh_mesh"] + + for i in range(len(mesh2text_df)): + final_df.at[text_start_idx + i, "mesh_to_text_mesh"] = mesh2text_df.loc[i, "mesh_to_text_mesh"] + final_df.at[text_start_idx + i, "mesh_to_text_text"] = mesh2text_df.loc[i, "mesh_to_text_text"] + + # Статистика + print(f"\n📊 Статистика submission:") + print(f" Заполненных image_to_mesh: {final_df['image_to_mesh_image'].notna().sum()}") + print(f" Заполненных mesh_to_image: {final_df['mesh_to_image_mesh'].notna().sum()}") + print(f" Заполненных text_to_mesh: {final_df['text_to_mesh_text'].notna().sum()}") + print(f" Заполненных mesh_to_text: {final_df['mesh_to_text_mesh'].notna().sum()}") + + # Преобразуем списки в строки для CSV + for col in final_df.columns: + if col != "id": # Не трогаем колонку id + mask = final_df[col].apply(lambda x: isinstance(x, list)) + final_df.loc[mask, col] = final_df.loc[mask, col].apply(str) + + # Сохраняем результат + output_path = config.paths.submission_save_file + final_df.to_csv(output_path, index=False) + print(f"\n✅ Файл для сабмита успешно создан: {output_path}") + print(f" Всего строк: {len(final_df)}") + print(f" Image задачи: строки 0-2186") + print(f" Text задачи: строки 2187-2286") + print("=" * 60) diff --git a/backend/cad_retrieval_utils/inference_runner.py b/backend/cad_retrieval_utils/inference_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8c50ee18bfe7a4050ae3e95ef48fba2c6fe84f --- /dev/null +++ b/backend/cad_retrieval_utils/inference_runner.py @@ -0,0 +1,45 @@ +import argparse +from pathlib import Path + +from cad_retrieval_utils.utils import init_environment, load_config + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Inference runner for all 4 tasks", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--config", required=True, help="Путь к .py-файлу с CONFIG") + parser.add_argument("--pc_encoder", required=True, help="Путь к весам PC encoder") + parser.add_argument("--img_moe", required=True, help="Путь к весам Image MoE head") + parser.add_argument("--text_proj", required=True, help="Путь к весам text projection") + parser.add_argument("--text_encoder", required=True, help="Путь к весам text encoder") + parser.add_argument("--output", default="submission.csv", help="Путь для сохранения submission.csv") + args = parser.parse_args() + + CONFIG = load_config(args.config) + print(f"Using config: {args.config}") + + # Обновляем конфиг путями из аргументов + CONFIG.paths.model_spec = { + "pc_encoder": args.pc_encoder, + "moe": args.img_moe, + "text_proj": args.text_proj, + "text_encoder": args.text_encoder, + } + CONFIG.paths.submission_save_file = Path(args.output) + + # Проверка существования файлов + for key, path in CONFIG.paths.model_spec.items(): + if path and not Path(path).exists(): + raise 
FileNotFoundError(f"Файл не найден: {key} -> {path}") + + init_environment(CONFIG) + + # Импортируем после инициализации + from cad_retrieval_utils.inference import make_submission + make_submission(CONFIG) + + +if __name__ == "__main__": + main() diff --git a/backend/cad_retrieval_utils/models.py b/backend/cad_retrieval_utils/models.py new file mode 100644 index 0000000000000000000000000000000000000000..fd60ca15eb1bb3547dae200d3d0660473fd56ce8 --- /dev/null +++ b/backend/cad_retrieval_utils/models.py @@ -0,0 +1,124 @@ +from typing import cast + +import open_clip +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F +from easydict import EasyDict as edict + +from ReConV2.models.ReCon import ReCon2 + + +# --- Базовый PC Encoder (общий для всех) --- +class BasePcEncoder(nn.Module): + def __init__(self, config: edict): + super().__init__() + self.text_ratio = config.text_ratio + self.pc_encoder_base = ReCon2(config.model) + self.config = config + + def encode_pc(self, pc: torch.Tensor, normalize: bool) -> torch.Tensor: + img_token, text_token, _, _ = self.pc_encoder_base.forward_features(pc) + img_pred_feat = torch.mean(img_token, dim=1) + text_pred_feat = torch.mean(text_token, dim=1) + pc_feats = img_pred_feat + text_pred_feat * self.text_ratio + return F.normalize(pc_feats, dim=-1) if normalize else pc_feats + + +# --- Модели для Text-Mesh --- +class TextEncoder(nn.Module): + def __init__(self, config: edict) -> None: + super().__init__() + self.config = config + model, _, _ = open_clip.create_model_and_transforms( + 'EVA02-L-14-336', + pretrained='merged2b_s6b_b61k' + ) + self.text_encoder = model + self.tokenizer = open_clip.get_tokenizer('EVA02-L-14-336') + + text_dim = 768 + self.text_proj = nn.Sequential( + nn.Linear(text_dim, config.emb_dim), + nn.ReLU(), + nn.Linear(config.emb_dim, config.emb_dim) + ) + + def encode_text(self, texts: list[str], normalize: bool = True) -> torch.Tensor: + tokens = self.tokenizer(texts).to(self.config.device) + text_features = self.text_encoder.encode_text(tokens) + text_embeddings = self.text_proj(text_features.float()) + return F.normalize(text_embeddings, dim=-1) if normalize else text_embeddings + + +class InferenceTextEncoder(nn.Module): + def __init__(self, config: edict) -> None: + super().__init__() + self.encoder = TextEncoder(config) + + def load_text_weights(self, text_proj_path: str, text_encoder_path: str) -> None: + self.encoder.text_proj.load_state_dict(torch.load(text_proj_path, map_location="cpu"), strict=True) + print(f"✅ Text projection weights loaded from {text_proj_path}") + + + checkpoint = torch.load(text_encoder_path, map_location="cpu") + # Загружаем только те параметры, которые есть в чекпоинте + missing, unexpected = self.encoder.text_encoder.load_state_dict(checkpoint, strict=False) + print(f"✅ Text encoder weights loaded from {text_encoder_path}") + if missing: + print(f" ℹ️ Missing keys (expected, frozen params): {len(missing)}") + if unexpected: + raise Exception(f" ⚠️ Unexpected keys: {unexpected}") + #strict=False так как последние слои грузим(при обучении 4 слоя размораживали, их сохранили и их же грузим в инференсе) + + + + def encode_text(self, texts: list[str], normalize: bool = True) -> torch.Tensor: + return self.encoder.encode_text(texts, normalize) + + +# --- Модель для PC --- +class InferencePcEncoder(BasePcEncoder): + def __init__(self, config: edict) -> None: + super().__init__(config) + + def load_pc_encoder_weights(self, checkpoint_path: str) -> None: + 
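+        # strict=True: the checkpoint is expected to contain every parameter of
+        # the ReCon2 backbone; a mismatched or partial checkpoint raises a
+        # RuntimeError here instead of loading silently.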
self.pc_encoder_base.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True) + print(f"✅ PC encoder weights loaded from {checkpoint_path}") + + +# --- Модели для Image-Mesh --- +class MoEImgHead(nn.Module): + def __init__(self, in_dim: int, out_dim: int, n_experts: int = 8) -> None: + super().__init__() + self.experts = nn.ModuleList([nn.Linear(in_dim, out_dim) for _ in range(n_experts)]) + self.gate = nn.Sequential(nn.LayerNorm(in_dim), nn.Linear(in_dim, n_experts)) + + def forward(self, feats: torch.Tensor, normalize: bool) -> torch.Tensor: + logits = self.gate(feats) + w = torch.softmax(logits, dim=-1) + outs = torch.stack([e(feats) for e in self.experts], dim=1) + out = (w.unsqueeze(-1) * outs).sum(1) + return F.normalize(out, dim=-1) if normalize else out + + +class ImageEncoder(nn.Module): + def __init__(self, config: edict) -> None: + super().__init__() + self.model = timm.create_model(config.model.pretrained_model_name, pretrained=True, num_classes=0) + + self.img_proj = MoEImgHead( + config.model.embed_dim, + config.emb_dim, + n_experts=config.train_params.n_experts, + ) + + def encode_image(self, image: torch.Tensor, normalize: bool = True) -> torch.Tensor: + image_features = self.model(image) #вызываем под декоратором @torch.no_grad в evolution + image_embeddings = self.img_proj(image_features.float(), normalize=normalize) + return cast(torch.Tensor, image_embeddings) + + def load_moe_weights(self, checkpoint_path: str) -> None: + self.img_proj.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True) + print(f"✅ MoE weights loaded from {checkpoint_path}") diff --git a/backend/cad_retrieval_utils/type_defs.py b/backend/cad_retrieval_utils/type_defs.py new file mode 100644 index 0000000000000000000000000000000000000000..7045eeba297dcbedff4adc75526e6943fd3257c0 --- /dev/null +++ b/backend/cad_retrieval_utils/type_defs.py @@ -0,0 +1,27 @@ +import os +from collections.abc import Callable +from pathlib import Path +from typing import TypeAlias, TypedDict + +import numpy as np +import torch +from PIL import Image + +# --- Примитивные псевдонимы --- +ModelID: TypeAlias = str +PathLike: TypeAlias = str | Path | os.PathLike[str] +ImageTransform: TypeAlias = Callable[[Image.Image], torch.Tensor] + +# --- Типы для NumPy массивов --- +EmbeddingArray: TypeAlias = np.ndarray + + +# --- Спецификация чекпоинтов для инференса --- +class CheckpointSpec(TypedDict): + # Пути для text-to-mesh + text_proj: PathLike + text_encoder: PathLike + + # Пути для image-to-mesh + moe: PathLike + pc_encoder: PathLike diff --git a/backend/cad_retrieval_utils/utils.py b/backend/cad_retrieval_utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4a815c68f5fbf1de713ab2b22a89fc69426e89d9 --- /dev/null +++ b/backend/cad_retrieval_utils/utils.py @@ -0,0 +1,91 @@ +import importlib.util +import os +import random +from pathlib import Path + +import numpy as np +import torch +from easydict import EasyDict as edict + + +def load_config(config_path: str) -> edict: + CONFIG = edict() + + # --- Конфиг pretrained модели Recon для загрузки --- + CONFIG.model = edict({ + "NAME": "ReCon2", + "group_size": 32, + "num_group": 512, + "mask_ratio": 0.7, + "mask_type": "rand", + "embed_dim": 1024, + "depth": 24, + "drop_path_rate": 0.2, + "num_heads": 16, + "decoder_depth": 4, + "with_color": True, + "stop_grad": False, + "large_embedding": False, + "img_queries": 13, + "text_queries": 3, + "contrast_type": "byol", + "pretrained_model_name": 
"eva_large_patch14_336.in22k_ft_in22k_in1k", + }) + + # --- Общие параметры --- + CONFIG.npoints = 10_000 + CONFIG.emb_dim = 1280 + CONFIG.img_size = 336 + CONFIG.seed = 42 + CONFIG.device = torch.device("cpu") + CONFIG.text_ratio = 0.3 + + # --- Параметры инференса --- + CONFIG.infer_img_batch_size = 32 + CONFIG.infer_pc_batch_size = 16 + CONFIG.infer_text_batch_size = 32 + + # --- Параметры для MoE --- + CONFIG.train_params = edict() + CONFIG.train_params.n_experts = 8 + + # --- Пути --- + CONFIG.paths = edict() + CONFIG.paths.test_data_root = Path("/kaggle/input/test-final/test") + CONFIG.paths.submission_save_file = Path("./submission.csv") + + # Эти пути будут перезаписаны из командной строки inference_runner.py + CONFIG.paths.model_spec = { + "text_proj": None, + "text_encoder": None, + "moe": None, + "pc_encoder": None + } + return CONFIG + + +def init_environment(config: edict) -> None: + SEED = config.seed + + # Все используют один и тот же базовый сид + random.seed(SEED) + os.environ["PYTHONHASHSEED"] = str(SEED) + np.random.seed(SEED) + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + torch.cuda.manual_seed_all(SEED) + + # CuDNN настройки + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # Отключение TF32 + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + # Детерминированные алгоритмы + torch.use_deterministic_algorithms(True) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + + print(f"✅ Детерминированная среда установлена с seed = {SEED}") diff --git a/backend/config.py b/backend/config.py new file mode 100644 index 0000000000000000000000000000000000000000..601531018f9e7c1c0302e73b31372eb67591c0b1 --- /dev/null +++ b/backend/config.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import torch +from easydict import EasyDict as edict + +CONFIG = edict() + +# --- Конфиг pretrained модели Recon для загрузки --- +CONFIG.model = edict({ + "NAME": "ReCon2", + "group_size": 32, + "num_group": 512, + "mask_ratio": 0.7, + "mask_type": "rand", + "embed_dim": 1024, + "depth": 24, + "drop_path_rate": 0.2, + "num_heads": 16, + "decoder_depth": 4, + "with_color": True, + "stop_grad": False, + "large_embedding": False, + "img_queries": 13, + "text_queries": 3, + "contrast_type": "byol", + "pretrained_model_name": "eva_large_patch14_336.in22k_ft_in22k_in1k", +}) + +# --- Общие параметры --- +CONFIG.npoints = 10_000 +CONFIG.emb_dim = 1280 +CONFIG.img_size = 336 +CONFIG.seed = 42 +CONFIG.device = torch.device("cpu") +CONFIG.text_ratio = 0.3 + +# --- Параметры инференса --- +CONFIG.infer_img_batch_size = 32 +CONFIG.infer_pc_batch_size = 16 +CONFIG.infer_text_batch_size = 32 + +# --- Параметры для MoE --- +CONFIG.train_params = edict() +CONFIG.train_params.n_experts = 8 + +# --- Пути --- +CONFIG.paths = edict() +CONFIG.paths.test_data_root = Path("/kaggle/input/test-final/test") +CONFIG.paths.submission_save_file = Path("./submission.csv") + +# Эти пути будут перезаписаны из командной строки inference_runner.py +CONFIG.paths.model_spec = { + "text_proj": None, + "text_encoder": None, + "moe": None, + "pc_encoder": None +} diff --git a/backend/download_utils.py b/backend/download_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..29f761ed6db60b615592344947f5aef1dd3e7efe --- /dev/null +++ b/backend/download_utils.py @@ -0,0 +1,56 @@ +import os +import requests +from tqdm.auto import tqdm + +def download_yandex_file(public_file_url: str, destination_path: str, filename: str): + 
""" + Скачивает один файл с публичного Яндекс.Диска по прямой ссылке на файл. + """ + api_url = "https://cloud-api.yandex.net/v1/disk/public/resources/download" + params = {'public_key': public_file_url} + + print(f"🔎 Получение информации о файле: {filename}...") + try: + response = requests.get(api_url, params=params) + response.raise_for_status() + data = response.json() + download_url = data.get('href') + + if not download_url: + print(f"❌ Не удалось получить URL для скачивания файла '{filename}'. Ответ API: {data}") + return False + + except requests.exceptions.RequestException as e: + print(f"❌ Ошибка при получении информации о файле '{filename}': {e}") + return False + except KeyError as e: + print(f"❌ Ошибка при разборе ответа API для '{filename}': отсутствует ключ {e}. Ответ: {data}") + return False + + full_path = os.path.join(destination_path, filename) + os.makedirs(destination_path, exist_ok=True) + + print(f"📥 Скачивание '{filename}' в '{full_path}'...") + try: + size_response = requests.head(download_url) + total_size = int(size_response.headers.get('content-length', 0)) + + download_response = requests.get(download_url, stream=True) + download_response.raise_for_status() + + with open(full_path, 'wb') as f: + with tqdm(total=total_size, unit='B', unit_scale=True, desc=filename) as pbar: + for chunk in download_response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + pbar.update(len(chunk)) + + except requests.exceptions.RequestException as e: + print(f"\n❌ Ошибка при скачивании файла '{filename}': {e}") + return False + except Exception as e: + print(f"\n❌ Неожиданная ошибка при скачивании '{filename}': {e}") + return False + + print(f"🎉 Файл '{filename}' успешно скачан.") + return True \ No newline at end of file diff --git a/backend/inference_utils.py b/backend/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ebeb0a0aa36d6aa9faa1e06e925f18efd5bff4a0 --- /dev/null +++ b/backend/inference_utils.py @@ -0,0 +1,355 @@ +# app/inference_utils.py + +import base64 +import tempfile +import uuid +import zipfile +from io import BytesIO +from pathlib import Path +import datetime +from typing import Callable # Add this import + +import numpy as np +import torch +from easydict import EasyDict as edict +from PIL import Image +from sklearn.metrics.pairwise import cosine_similarity +from torch.utils.data import DataLoader + +from cad_retrieval_utils.augmentations import build_img_transforms +from cad_retrieval_utils.datasets import (InferenceImageDataset, + InferenceMeshDataset, + InferenceTextDataset) +from cad_retrieval_utils.evaluation import (get_inference_embeddings_image, + get_inference_embeddings_mesh, + get_inference_embeddings_text) +from cad_retrieval_utils.inference import (load_image_encoder, load_pc_encoder, + load_text_encoder) +from cad_retrieval_utils.models import (ImageEncoder, InferencePcEncoder, + InferenceTextEncoder) +from cad_retrieval_utils.utils import init_environment, load_config + +CONFIG: edict = None +IMG_TRANSFORM = None +PC_ENCODER: InferencePcEncoder = None +IMG_ENCODER: ImageEncoder = None +TEXT_ENCODER: InferenceTextEncoder = None +DATASET_CACHE = {} +TOP_K_MATCHES = 5 + +def load_models_and_config(config_path: str, model_paths: dict) -> None: + # This function is unchanged + global CONFIG, IMG_TRANSFORM, PC_ENCODER, IMG_ENCODER, TEXT_ENCODER + print("🚀 Загрузка конфигурации и моделей...") + if CONFIG is not None: + print(" Модели уже загружены.") + return + try: + CONFIG = load_config(config_path) + 
CONFIG.paths.model_spec = model_paths + init_environment(CONFIG) + PC_ENCODER = load_pc_encoder(CONFIG.paths.model_spec, CONFIG) + IMG_ENCODER = load_image_encoder(CONFIG.paths.model_spec, CONFIG) + TEXT_ENCODER = load_text_encoder(CONFIG.paths.model_spec, CONFIG) + IMG_TRANSFORM = build_img_transforms(CONFIG.img_size) + print("✅ Все модели успешно загружены в память.") + except Exception as e: + print(f"🔥 Критическая ошибка при загрузке моделей: {e}") + raise + +@torch.no_grad() +def get_embedding_for_single_item(modality: str, content_bytes: bytes) -> np.ndarray: + # This function is unchanged + if modality == "image": + image = Image.open(BytesIO(content_bytes)).convert("RGB") + tensor = IMG_TRANSFORM(image).unsqueeze(0).to(CONFIG.device) + emb = IMG_ENCODER.encode_image(tensor, normalize=True) + return emb.cpu().numpy() + if modality == "text": + text = content_bytes.decode("utf-8") + emb = TEXT_ENCODER.encode_text([text], normalize=True) + return emb.cpu().numpy() + if modality == "mesh": + with tempfile.NamedTemporaryFile(suffix=".stl", delete=True) as tmp: + tmp.write(content_bytes) + tmp.flush() + dataset = InferenceMeshDataset([tmp.name], CONFIG.npoints, CONFIG.seed) + tensor = dataset[0].unsqueeze(0).to(CONFIG.device) + emb = PC_ENCODER.encode_pc(tensor, normalize=True) + return emb.cpu().numpy() + raise ValueError(f"Неизвестная модальность: {modality}") + +def process_uploaded_zip( + zip_file_bytes: bytes, + original_filename: str, + update_status: Callable[[str, int], None] +) -> dict: + """ + Основная функция для обработки ZIP-архива с обратными вызовами для обновления статуса. + """ + dataset_id = str(uuid.uuid4()) + print(f"⚙️ Начало обработки нового датасета: {original_filename} (ID: {dataset_id})") + update_status("Starting", 0) + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) + zip_path = tmp_path / "data.zip" + zip_path.write_bytes(zip_file_bytes) + + update_status("Unpacking Files", 5) + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(tmp_path) + print(f" 🗂️ Архив распакован в {tmpdir}") + + update_status("Preparing Data", 10) + image_paths = sorted(list(tmp_path.glob("**/*.png"))) + text_paths = sorted(list(tmp_path.glob("**/*.txt"))) + mesh_paths = sorted(list(tmp_path.glob("**/*.stl"))) + + image_ds = InferenceImageDataset([str(p) for p in image_paths], IMG_TRANSFORM) + text_ds = InferenceTextDataset([str(p) for p in text_paths]) + mesh_ds = InferenceMeshDataset([str(p) for p in mesh_paths], CONFIG.npoints, CONFIG.seed) + + image_loader = DataLoader(image_ds, batch_size=CONFIG.infer_img_batch_size, shuffle=False) + text_loader = DataLoader(text_ds, batch_size=CONFIG.infer_text_batch_size, shuffle=False) + mesh_loader = DataLoader(mesh_ds, batch_size=CONFIG.infer_pc_batch_size, shuffle=False) + + print(" 🧠 Вычисление эмбеддингов...") + update_status("Processing Images", 15) + image_embs = get_inference_embeddings_image(IMG_ENCODER, image_loader, CONFIG) + + update_status("Processing Texts", 50) + text_embs = get_inference_embeddings_text(TEXT_ENCODER, text_loader, CONFIG) + + update_status("Processing 3D Models", 55) + mesh_embs = get_inference_embeddings_mesh(PC_ENCODER, mesh_loader, CONFIG) + print(" ✅ Эмбеддинги вычислены.") + + update_status("Caching Data", 90) + image_names = [p.name for p in image_paths] + text_names = [p.name for p in text_paths] + mesh_names = [p.name for p in mesh_paths] + + image_items = [{"id": f"image_{i}", "name": name, "content": base64.b64encode(p.read_bytes()).decode('utf-8')} for i, (p, 
name) in enumerate(zip(image_paths, image_names))] + text_items = [{"id": f"text_{i}", "name": name, "content": p.read_text()} for i, (p, name) in enumerate(zip(text_paths, text_names))] + mesh_items = [{"id": f"mesh_{i}", "name": name, "content": base64.b64encode(p.read_bytes()).decode('utf-8')} for i, (p, name) in enumerate(zip(mesh_paths, mesh_names))] + + dataset_data = {"images": image_items, "texts": text_items, "meshes": mesh_items} + + DATASET_CACHE[dataset_id] = { + "data": dataset_data, + "embeddings": { + "image": (image_names, image_embs), + "text": (text_names, text_embs), + "mesh": (mesh_names, mesh_embs) + } + } + print(f" 💾 Датасет {dataset_id} сохранен в кэш.") + + print(" ⚖️ Вычисление полной матрицы схожести...") + update_status("Building Matrix", 95) + full_comparison = {"images": [], "texts": [], "meshes": []} + + all_embeddings = { + "image": (image_names, image_embs), + "text": (text_names, text_embs), + "mesh": (mesh_names, mesh_embs) + } + + for source_modality, (source_names, source_embs) in all_embeddings.items(): + for i, source_name in enumerate(source_names): + source_emb = source_embs[i:i+1] + matches = {} + for target_modality, (target_names, target_embs) in all_embeddings.items(): + if not target_names: continue + sims = cosine_similarity(source_emb, target_embs).flatten() + + if source_modality == target_modality: + sims[i] = -1 + + top_indices = np.argsort(sims)[::-1][:TOP_K_MATCHES] + matches[target_modality] = [ + {"item": target_names[j], "confidence": float(sims[j])} for j in top_indices if sims[j] > -1 + ] + + key_name = "meshes" if source_modality == "mesh" else source_modality + 's' + full_comparison[key_name].append({"source": source_name, "matches": matches}) + + print(" ✅ Матрица схожести готова.") + + final_response = { + "id": dataset_id, + "name": original_filename, + "uploadDate": datetime.datetime.utcnow().isoformat() + "Z", + "data": dataset_data, + "processingState": "processed", + "processingProgress": 100, + "fullComparison": full_comparison + } + + print(f"✅ Обработка датасета {dataset_id} завершена.") + return final_response + + +def process_shared_dataset_directory(directory_path: Path, embeddings_path: Path, dataset_id: str, dataset_name: str) -> dict: + # This function is unchanged + print(f"⚙️ Начало обработки общего датасета: {dataset_name} (ID: {dataset_id})") + print(" 📂 Сканирование файлов данных...") + image_paths = sorted(list(directory_path.glob("**/*.png"))) + text_paths = sorted(list(directory_path.glob("**/*.txt"))) + mesh_paths = sorted(list(directory_path.glob("**/*.stl"))) + if not any([image_paths, text_paths, mesh_paths]): + print(f"⚠️ В директории общего датасета '{directory_path}' не найдено файлов.") + return None + print(f" ✅ Найдено: {len(image_paths)} изображений, {len(text_paths)} текстов, {len(mesh_paths)} моделей.") + print(" 🧠 Индексирование предварительно вычисленных эмбеддингов...") + all_embedding_paths = list(embeddings_path.glob("**/*.npy")) + embedding_map = {p.stem: p for p in all_embedding_paths} + print(f" ✅ Найдено {len(embedding_map)} файлов эмбеддингов.") + def load_embeddings_for_paths(data_paths: list[Path]): + names = [] + embs_list = [] + for data_path in data_paths: + file_stem = data_path.stem + if file_stem in embedding_map: + embedding_path = embedding_map[file_stem] + try: + emb = np.load(embedding_path) + embs_list.append(emb) + names.append(data_path.name) + except Exception as e: + print(f" ⚠️ Не удалось загрузить или разобрать эмбеддинг для {data_path.name}: {e}") + else: + 
print(f" ⚠️ Внимание: не найден соответствующий эмбеддинг для {data_path.name}") + return names, np.array(embs_list) if embs_list else np.array([]) + print(" 🚚 Загрузка и сопоставление эмбеддингов...") + image_names, image_embs = load_embeddings_for_paths(image_paths) + text_names, text_embs = load_embeddings_for_paths(text_paths) + mesh_names, mesh_embs = load_embeddings_for_paths(mesh_paths) + print(" ✅ Эмбеддинги для общего датасета загружены.") + static_root = Path("static") + image_items = [{"id": f"image_{i}", "name": p.name, "content": None, "contentUrl": f"/{p.relative_to(static_root)}"} for i, p in enumerate(image_paths)] + text_items = [{"id": f"text_{i}", "name": p.name, "content": None, "contentUrl": f"/{p.relative_to(static_root)}"} for i, p in enumerate(text_paths)] + mesh_items = [{"id": f"mesh_{i}", "name": p.name, "content": None, "contentUrl": f"/{p.relative_to(static_root)}"} for i, p in enumerate(mesh_paths)] + dataset_data = {"images": image_items, "texts": text_items, "meshes": mesh_items} + DATASET_CACHE[dataset_id] = {"data": dataset_data, "embeddings": {"image": (image_names, image_embs), "text": (text_names, text_embs), "mesh": (mesh_names, mesh_embs)}} + print(f" 💾 Эмбеддинги для общего датасета {dataset_id} сохранены в кэш.") + print(" ⚖️ Вычисление полной матрицы схожести для общего датасета...") + full_comparison = {"images": [], "texts": [], "meshes": []} + all_embeddings = {"image": (image_names, image_embs), "text": (text_names, text_embs), "mesh": (mesh_names, mesh_embs)} + for source_modality, (source_names, source_embs) in all_embeddings.items(): + if len(source_names) == 0: continue + for i, source_name in enumerate(source_names): + source_emb = source_embs[i:i+1] + matches = {} + for target_modality, (target_names, target_embs) in all_embeddings.items(): + if len(target_names) == 0: continue + sims = cosine_similarity(source_emb, target_embs).flatten() + if source_modality == target_modality: + sims[i] = -1 + top_indices = np.argsort(sims)[::-1][:TOP_K_MATCHES] + matches[target_modality] = [{"item": target_names[j], "confidence": float(sims[j])} for j in top_indices if sims[j] > -1] + key_name = "meshes" if source_modality == "mesh" else source_modality + 's' + full_comparison[key_name].append({"source": source_name, "matches": matches}) + print(" ✅ Матрица схожести для общего датасета готова.") + try: + creation_time = datetime.datetime.fromtimestamp(directory_path.stat().st_ctime) + except Exception: + creation_time = datetime.datetime.utcnow() + final_response = {"id": dataset_id, "name": dataset_name, "uploadDate": creation_time.isoformat() + "Z", "data": dataset_data, "processingState": "processed", "processingProgress": 100, "fullComparison": full_comparison, "isShared": True} + print(f"✅ Обработка общего датасета {dataset_id} завершена.") + return final_response + +def find_matches_for_item(modality: str, content_base64: str, dataset_id: str) -> dict: + # This function is unchanged + print(f"🔍 Поиск совпадений для объекта ({modality}) в датасете {dataset_id}...") + if dataset_id not in DATASET_CACHE: + raise ValueError(f"Датасет с ID {dataset_id} не найден в кэше.") + content_bytes = base64.b64decode(content_base64) + source_emb = get_embedding_for_single_item(modality, content_bytes) + cached_dataset = DATASET_CACHE[dataset_id] + results = {} + for target_modality, (target_names, target_embs) in cached_dataset["embeddings"].items(): + key_name = "meshes" if target_modality == "mesh" else target_modality + 's' + if not target_names: continue + 
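Both the precomputed comparison matrices above and the ad-hoc search below rely on the same retrieval primitive: cosine similarity of the query embedding against every cached embedding, `argsort` in descending order, and, for same-modality comparisons, masking the query's own entry with `-1` so it never matches itself. A toy version with made-up 3-D embeddings (values chosen only for illustration):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

TOP_K = 2
names = ["a.stl", "b.stl", "c.stl"]
embs = np.array([[1.0, 0.0, 0.0],
                 [0.9, 0.1, 0.0],
                 [0.0, 1.0, 0.0]])

query_idx = 0
sims = cosine_similarity(embs[query_idx:query_idx + 1], embs).flatten()
sims[query_idx] = -1                      # mask the self-match
top = np.argsort(sims)[::-1][:TOP_K]
print([(names[j], round(float(sims[j]), 3)) for j in top if sims[j] > -1])
# -> [('b.stl', 0.994), ('c.stl', 0.0)]
```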
sims = cosine_similarity(source_emb, target_embs).flatten() + top_indices = np.argsort(sims)[::-1][:TOP_K_MATCHES] + target_items_map = {item['name']: item for item in cached_dataset['data'][key_name]} + matches = [] + for j in top_indices: + item_name = target_names[j] + if item_name in target_items_map: + matches.append({"item": target_items_map[item_name], "confidence": float(sims[j])}) + results[key_name] = matches + print(" ✅ Поиск завершен.") + return {"results": results} + +def cache_local_dataset(dataset: dict) -> None: + """ + Receives a full dataset object from the frontend, computes embeddings, + and loads it into the in-memory cache. + """ + dataset_id = dataset.get('id') + if not dataset_id: + print("⚠️ Attempted to cache a dataset without an ID.") + return + + if dataset_id in DATASET_CACHE: + print(f"✅ Dataset {dataset_id} is already in the backend cache. Skipping re-hydration.") + return + + print(f"🧠 Re-hydrating backend cache for local dataset ID: {dataset_id}") + + try: + all_embeddings = {} + all_names = {} + + # The content comes in different formats (data URL for images, text for text, etc.) + # We need to decode it before sending to the embedding function. + def get_bytes_from_content(content_str: str, modality: str) -> bytes: + if modality in ['image', 'mesh']: + # Handle data URLs (e.g., "data:image/png;base64,...") or raw base64 + if ',' in content_str: + header, encoded = content_str.split(',', 1) + return base64.b64decode(encoded) + else: + return base64.b64decode(content_str) + else: # text + return content_str.encode('utf-8') + + + for modality_plural, items in dataset.get('data', {}).items(): + modality_singular = "mesh" if modality_plural == "meshes" else modality_plural[:-1] + + names = [] + embs_list = [] + + print(f" ⚙️ Processing {len(items)} items for modality: {modality_singular}") + + for item in items: + item_content = item.get('content') + if not item_content: + continue + + content_bytes = get_bytes_from_content(item_content, modality_singular) + embedding = get_embedding_for_single_item(modality_singular, content_bytes) + + embs_list.append(embedding[0]) # get_embedding returns shape (1, D) + names.append(item.get('name')) + + all_names[modality_singular] = names + all_embeddings[modality_singular] = np.array(embs_list) if embs_list else np.array([]) + + # Structure the cache entry exactly like process_uploaded_zip does + DATASET_CACHE[dataset_id] = { + "data": dataset.get('data'), + "embeddings": { + mod: (all_names[mod], all_embeddings[mod]) for mod in all_embeddings + } + } + print(f" ✅ Successfully cached {dataset_id} with embeddings.") + + except Exception as e: + print(f"🔥 CRITICAL ERROR while re-hydrating cache for {dataset_id}: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/backend/main.py b/backend/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e612c86a0e064504158f2b7157459c8314e90c10 --- /dev/null +++ b/backend/main.py @@ -0,0 +1,278 @@ +# app/main.py + +import os +import asyncio +from pathlib import Path +import zipfile +import io +import requests +import uuid # Add this import +from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks # Add BackgroundTasks +from fastapi.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from typing import List, Dict, Any + +# Импортируем утилиты +from inference_utils import ( + load_models_and_config, + process_uploaded_zip, + 
find_matches_for_item, + process_shared_dataset_directory, + cache_local_dataset, +) +# Импортируем нашу новую функцию для скачивания +from download_utils import download_yandex_file + +# --- Инициализация --- +app = FastAPI() + +# Разрешаем CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# --- Глобальные кэши --- +SHARED_DATASET_FULL_DATA = {} +SHARED_DATASET_ID = "shared_dataset_1" +PROCESSING_STATUS = {} # NEW: For tracking progress + +# --- Helper Functions --- + +def download_and_unzip_yandex_archive(public_url: str, destination_dir: Path, description: str): + # This function is unchanged + print(f"--- 📥 Checking for {description} ---") + if destination_dir.exists() and any(destination_dir.iterdir()): + print(f"✅ {description} already exists in '{destination_dir}'. Skipping download.") + return True + print(f"⏳ {description} not found. Starting download from Yandex.Disk...") + destination_dir.mkdir(parents=True, exist_ok=True) + if "YOUR_" in public_url or "ВАША_" in public_url: + print(f"🔥 WARNING: Placeholder URL detected for {description}. Download skipped.") + return False + try: + api_url = "https://cloud-api.yandex.net/v1/disk/public/resources/download" + params = {'public_key': public_url} + response = requests.get(api_url, params=params) + response.raise_for_status() + download_url = response.json().get('href') + if not download_url: + raise RuntimeError(f"Could not retrieve download URL for {description} from Yandex.Disk API.") + print(f" 🔗 Got download link. Fetching ZIP archive for {description}...") + zip_response = requests.get(download_url, stream=True) + zip_response.raise_for_status() + zip_in_memory = io.BytesIO(zip_response.content) + print(f" 🗂️ Unzipping archive for {description}...") + with zipfile.ZipFile(zip_in_memory, 'r') as zip_ref: + zip_ref.extractall(destination_dir) + print(f"🎉 {description} successfully downloaded and extracted to '{destination_dir}'.") + return True + except Exception as e: + print(f"🔥 CRITICAL ERROR downloading or unzipping {description}: {e}") + return False + +# --- NEW: Background Processing Wrapper --- +def background_process_zip(zip_bytes: bytes, original_filename: str, job_id: str): + """Wrapper function to run processing and update status.""" + def update_status(stage: str, progress: int): + """Callback to update the global status dictionary.""" + print(f"Job {job_id}: {stage} - {progress}%") + PROCESSING_STATUS[job_id] = {"stage": stage, "progress": progress, "status": "processing"} + + try: + processed_data = process_uploaded_zip( + zip_bytes, original_filename, update_status + ) + PROCESSING_STATUS[job_id] = { + "status": "complete", + "result": processed_data + } + except Exception as e: + import traceback + traceback.print_exc() + PROCESSING_STATUS[job_id] = { + "status": "error", + "message": f"An error occurred during processing: {e}" + } + +class SingleMatchRequest(BaseModel): + modality: str + content: str + dataset_id: str + +# --- MODIFIED: process-dataset endpoint --- +class ProcessDatasetResponse(BaseModel): + job_id: str + +class DataItemModel(BaseModel): + id: str + name: str + content: str | None = None # Frontend sends content as string (base64 or text) + contentUrl: str | None = None + +class DatasetDataModel(BaseModel): + images: List[DataItemModel] + texts: List[DataItemModel] + meshes: List[DataItemModel] + +class LocalDatasetModel(BaseModel): + id: str + name: str + data: DatasetDataModel + # We only need the 
core data for re-hydration, other fields are optional + # Use 'Any' for complex fields we don't need to strictly validate here + fullComparison: Dict[str, Any] | None = None + +@app.post("/api/cache-local-dataset") +async def cache_local_dataset_endpoint(dataset: LocalDatasetModel): + """ + Receives a local dataset from the frontend to re-hydrate the server's in-memory cache. + """ + try: + # Pydantic's .dict() is deprecated, use .model_dump() + dataset_dict = dataset.model_dump() + await asyncio.to_thread(cache_local_dataset, dataset_dict) + return {"status": "cached", "id": dataset.id} + except Exception as e: + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=f"Failed to cache dataset: {e}") + +# --- Startup Event --- +@app.on_event("startup") +def startup_event(): + # This function is unchanged + SHARED_DATASET_DIR = Path("static/shared_dataset") + SHARED_EMBEDDINGS_DIR = Path("static/shared_embeddings") + + SHARED_DATASET_ZIP_URL = "https://disk.yandex.ru/d/G9C3_FGGzSLAXw" + SHARED_EMBEDDINGS_ZIP_URL = "https://disk.yandex.ru/d/aVTX6n2pc0hrCw" + dataset_ready = download_and_unzip_yandex_archive(SHARED_DATASET_ZIP_URL, SHARED_DATASET_DIR, "shared dataset files") + embeddings_ready = download_and_unzip_yandex_archive(SHARED_EMBEDDINGS_ZIP_URL, SHARED_EMBEDDINGS_DIR, "pre-computed embeddings") + DATA_DIR = Path("data/") + MODEL_URLS = { + "text_proj.pth": "https://disk.yandex.ru/d/uMH1ls0nYM4txw", + "text_encoder.pth": "https://disk.yandex.ru/d/R0BBLPXj828OhA", + "moe.pth": "https://disk.yandex.ru/d/vDfuIPziuO45wg", + "pc_encoder.pth": "https://disk.yandex.ru/d/03Ps2TMcWAKkww", + } + print("--- 📥 Checking and loading models ---") + DATA_DIR.mkdir(parents=True, exist_ok=True) + all_models_present = True + for filename, url in MODEL_URLS.items(): + destination_file = DATA_DIR / filename + if not destination_file.exists(): + print(f"⏳ Модель '{filename}' не найдена. Начинаю загрузку...") + if not "ВАША_ССЫЛКА" in url: + success = download_yandex_file(public_file_url=url, destination_path=str(DATA_DIR), filename=filename) + if not success: + all_models_present = False + print(f"🔥 Критическая ошибка: не удалось скачать модель '{filename}'.") + else: + all_models_present = False + print(f"🔥 ВНИМАНИЕ: Пропущена загрузка '{filename}', т.к. ссылка является плейсхолдером.") + else: + print(f"✅ Модель '{filename}' уже существует. Пропускаю загрузку.") + if not all_models_present: + raise RuntimeError("Не удалось загрузить все необходимые модели. 
Приложение не может запуститься.") + print("--- ✅ Все модели готовы к использованию ---") + model_paths = {"text_proj": str(DATA_DIR / "text_proj.pth"), "text_encoder": str(DATA_DIR / "text_encoder.pth"), "moe": str(DATA_DIR / "moe.pth"), "pc_encoder": str(DATA_DIR / "pc_encoder.pth")} + config_path = "cad_retrieval_utils/config/config.py" + try: + load_models_and_config(config_path=config_path, model_paths=model_paths) + print("✅ Все модели успешно загружены в память.") + except Exception as e: + print(f"🔥 Ошибка при загрузке моделей: {e}") + import traceback + traceback.print_exc() + raise RuntimeError(f"Ошибка загрузки моделей, приложение не может запуститься.") from e + if dataset_ready and embeddings_ready: + print("--- 🧠 Loading pre-computed embeddings for shared dataset ---") + try: + full_data = process_shared_dataset_directory(directory_path=SHARED_DATASET_DIR, embeddings_path=SHARED_EMBEDDINGS_DIR, dataset_id=SHARED_DATASET_ID, dataset_name="Cloud Multi-Modal Dataset") + if full_data: + SHARED_DATASET_FULL_DATA[SHARED_DATASET_ID] = full_data + print("--- ✅ Shared dataset processed and cached successfully. ---") + else: + print("--- ⚠️ Shared dataset processing returned no data. Caching skipped. ---") + except Exception as e: + print(f"🔥 CRITICAL ERROR processing shared dataset: {e}") + import traceback + traceback.print_exc() + else: + print("--- ⚠️ Shared dataset or embeddings not available. Processing skipped. ---") + + +# --- API Endpoints --- +@app.get("/api/shared-dataset-metadata") +async def get_shared_dataset_metadata(): + # This function is unchanged + metadata_list = [] + for dataset_id, full_data in SHARED_DATASET_FULL_DATA.items(): + metadata = {"id": full_data["id"], "name": full_data["name"], "uploadDate": full_data["uploadDate"], "processingState": full_data["processingState"], "itemCounts": {"images": len(full_data["data"]["images"]), "texts": len(full_data["data"]["texts"]), "meshes": len(full_data["data"]["meshes"])}, "isShared": True} + metadata_list.append(metadata) + return metadata_list + +@app.get("/api/shared-dataset") +async def get_shared_dataset(id: str): + # This function is unchanged + dataset = SHARED_DATASET_FULL_DATA.get(id) + if not dataset: + raise HTTPException(status_code=404, detail=f"Shared dataset with id '{id}' not found.") + return dataset + +@app.post("/api/process-dataset", response_model=ProcessDatasetResponse) +async def process_dataset_endpoint( + background_tasks: BackgroundTasks, file: UploadFile = File(...) 
+): + if not file.filename or not file.filename.endswith('.zip'): + raise HTTPException(status_code=400, detail="A ZIP archive is required.") + + zip_bytes = await file.read() + job_id = str(uuid.uuid4()) + PROCESSING_STATUS[job_id] = {"status": "starting", "stage": "Queued", "progress": 0} + + background_tasks.add_task( + background_process_zip, zip_bytes, file.filename, job_id + ) + return {"job_id": job_id} + +# --- NEW: processing-status endpoint --- +class StatusResponse(BaseModel): + status: str + stage: str | None = None + progress: int | None = None + message: str | None = None + result: dict | None = None + +@app.get("/api/processing-status/{job_id}", response_model=StatusResponse) +async def get_processing_status(job_id: str): + """Poll this endpoint to get the status of a processing job.""" + status = PROCESSING_STATUS.get(job_id) + if not status: + raise HTTPException(status_code=404, detail="Job ID not found.") + return status + + +@app.post("/api/find-matches") +async def find_matches_endpoint(request: SingleMatchRequest): + # This function is unchanged + try: + match_results = await asyncio.to_thread( + find_matches_for_item, request.modality, request.content, request.dataset_id + ) + source_item_data = {"id": "source_item", "name": "Source Item", "content": request.content} + final_response = {"sourceItem": source_item_data, "sourceModality": request.modality, **match_results} + return final_response + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) + except Exception as e: + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=f"Ошибка при поиске совпадений: {e}") + +app.mount("/", StaticFiles(directory="static", html=True), name="static") \ No newline at end of file diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2eec2b52754365d40ab4a7c9657ec4603628c169 --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1,31 @@ +--extra-index-url https://download.pytorch.org/whl/cpu + +# Framework +fastapi +uvicorn[standard] +python-multipart + +# ML & Data Science +easydict +matplotlib +ninja +numpy +open_clip_torch +pandas +Pillow +PyYAML +scikit-learn +scipy +seaborn +termcolor +timm +torch +torchaudio +torchvision +tqdm +trimesh +umap-learn + +# Other +requests +wandb \ No newline at end of file diff --git a/frontend/.DS_Store b/frontend/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..696c532944e5cb91d7cd9cb065e7a7437c2e7ffa Binary files /dev/null and b/frontend/.DS_Store differ diff --git a/frontend/App.tsx b/frontend/App.tsx new file mode 100644 index 0000000000000000000000000000000000000000..36e034ce3edbacde7ca59cfada1160167d7ec22e --- /dev/null +++ b/frontend/App.tsx @@ -0,0 +1,220 @@ + +import React, { useState, useMemo, useEffect, useCallback, lazy, Suspense } from 'react'; +import { DatasetManager } from './components/DatasetManager'; +import type { Dataset, DatasetMetadata } from './types'; +import * as db from './services/dbService'; +import * as apiService from './services/apiService'; +import { Spinner } from './components/common/Spinner'; + +const ComparisonTool = lazy(() => import('./components/ComparisonTool').then(module => ({ default: module.ComparisonTool }))); + +type View = 'manager' | 'comparison'; + +const App: React.FC = () => { + const [datasets, setDatasets] = useState([]); + const [selectedDatasetId, setSelectedDatasetId] = useState(null); + const [activeDataset, 
setActiveDataset] = useState(null); + const [view, setView] = useState('manager'); + const [isLoading, setIsLoading] = useState(true); + const [isNavigating, setIsNavigating] = useState(false); + const [error, setError] = useState(null); + + useEffect(() => { + const loadInitialData = async () => { + try { + const localMeta = await db.getAllDatasetMetadata(); + let sharedMeta: DatasetMetadata[] = []; + try { + sharedMeta = await apiService.getSharedDatasetMetadata(); + } catch (e) { + console.error("Could not load shared datasets, continuing with local.", e); + setError("Could not load cloud datasets. The backend service may be unavailable. Local datasets are still accessible."); + } + + const allMeta = [...sharedMeta, ...localMeta]; + + setDatasets(allMeta); + + if (allMeta.length > 0) { + // Select the most recent dataset by default + const sortedMeta = [...allMeta].sort((a, b) => new Date(b.uploadDate).getTime() - new Date(a.uploadDate).getTime()); + setSelectedDatasetId(sortedMeta[0].id); + } + } catch (error) { + console.error("Failed to load initial data", error); + setError("A critical error occurred while loading local datasets."); + } finally { + setIsLoading(false); + } + }; + loadInitialData(); + }, []); + + const addDataset = async (newDataset: Dataset) => { + await db.addDataset(newDataset); + const localMeta = await db.getAllDatasetMetadata(); + const sharedMeta = datasets.filter(d => d.isShared); // Keep existing shared meta + setDatasets([...sharedMeta, ...localMeta]); + setSelectedDatasetId(newDataset.id); + }; + + const deleteDataset = async (id: string) => { + await db.deleteDataset(id); + setDatasets(prevDatasets => { + const newDatasets = prevDatasets.filter(d => d.id !== id); + if (selectedDatasetId === id) { + const sortedMeta = [...newDatasets].sort((a, b) => new Date(b.uploadDate).getTime() - new Date(a.uploadDate).getTime()); + setSelectedDatasetId(sortedMeta.length > 0 ? sortedMeta[0].id : null); + } + return newDatasets; + }); + }; + + const renameDataset = async (id: string, newName: string) => { + await db.renameDataset(id, newName); + setDatasets(prev => prev.map(d => d.id === id ? { ...d, name: newName } : d)); + }; + + const processedDatasets = useMemo(() => { + return datasets.filter(d => d.processingState === 'processed'); + }, [datasets]); + + const getFullDataset = async (id: string): Promise => { + const meta = datasets.find(d => d.id === id); + if (!meta) return null; + + if (meta.isShared) { + return apiService.getSharedDataset(id); + } else { + return db.getDataset(id); + } + }; + + const handleOpenComparisonTool = useCallback(async () => { + if (!selectedDatasetId) return; + const selectedMeta = datasets.find(d => d.id === selectedDatasetId); + if (!selectedMeta || selectedMeta.processingState !== 'processed') return; + + setView('comparison'); + setActiveDataset(null); + setIsNavigating(true); + + try { + const fullDataset = await getFullDataset(selectedDatasetId); + if (!fullDataset) { + throw new Error(`Failed to load dataset ${selectedDatasetId}.`); + } + + // *** NEW LOGIC *** + // If it's a local dataset, ensure it's in the backend's cache before proceeding. + if (!fullDataset.isShared) { + console.log("Local dataset selected. Ensuring it's cached on the backend..."); + await apiService.ensureDatasetInCache(fullDataset); + console.log("Backend cache confirmed."); + } + + setActiveDataset(fullDataset); + + } catch (error) { + console.error("Error preparing comparison tool:", error); + alert(`Error: Could not load the selected dataset. 
${error instanceof Error ? error.message : ''}`); + setView('manager'); // Go back on error + } finally { + setIsNavigating(false); + } + }, [selectedDatasetId, datasets]); + + const handleDatasetChange = useCallback(async (newId: string) => { + setSelectedDatasetId(newId); + + setActiveDataset(null); + setIsNavigating(true); + + try { + const fullDataset = await getFullDataset(newId); + if (!fullDataset) { + throw new Error(`Failed to load dataset ${newId}.`); + } + + // Also ensure cache is hydrated when switching datasets inside the tool + if (!fullDataset.isShared) { + await apiService.ensureDatasetInCache(fullDataset); + } + + setActiveDataset(fullDataset); + + } catch (error) { + console.error(`Error switching dataset to ${newId}:`, error); + setActiveDataset(null); + } finally { + setIsNavigating(false); + } + }, []); + + + const mainContent = () => { + if (isLoading) { + return
Loading Datasets...
; + } + + const errorBanner = error ? ( +
+

+ Cloud Connection Error: {error} +

+
+ ) : null; + + if (view === 'manager') { + return ( + + ); + } + + if (view === 'comparison') { + const fallbackUI =
Loading Comparison Tool...
; + if (isNavigating || !activeDataset) { + return fallbackUI; + } + return ( + + { setView('manager'); setActiveDataset(null); }} + /> + + ); + } + return null; + }; + + return ( +
+
+
+

+ Cross-Modal Object Comparison Tool +

+
+
+
+ {mainContent()} +
+
+ ); +}; + +export default App; diff --git a/frontend/components/ComparisonTool.tsx b/frontend/components/ComparisonTool.tsx new file mode 100644 index 0000000000000000000000000000000000000000..ac16d8eaa3427fc7451ada47be4255807cb18d0e --- /dev/null +++ b/frontend/components/ComparisonTool.tsx @@ -0,0 +1,558 @@ + +import React, { useState, useMemo, useCallback, useEffect, useRef, lazy, Suspense } from 'react'; +import type { Dataset, Modality, DataItem, SingleComparisonResult, DatasetMetadata } from '../types'; +import { findTopMatches } from '../services/apiService'; +import { findTopMatchesFromLocal } from '../services/comparisonService'; +import { downloadJson } from '../services/fileService'; +import { FullscreenViewer } from './common/FullscreenViewer'; +import { Spinner } from './common/Spinner'; +import { getItemContent } from '../services/sharedDatasetService'; + +const MeshViewer = lazy(() => import('./common/MeshViewer').then(module => ({ default: module.MeshViewer }))); + +interface ItemCardBaseProps { + modality: Modality; + isSource?: boolean; + confidence?: number; + onView: (item: DataItem, modality: Modality) => void; + onClick?: () => void; + className?: string; +} + +const LazyItemCard: React.FC = ({ item, ...props }) => { + const [loadedItem, setLoadedItem] = useState(item); + const [isLoading, setIsLoading] = useState(!item.content && !!item.contentUrl); + + useEffect(() => { + let isMounted = true; + // Images are loaded by the browser directly via `contentUrl`, so we only fetch for text/mesh. + if (!item.content && item.contentUrl && props.modality !== 'image') { + setIsLoading(true); + getItemContent(item.contentUrl) + .then(content => { + if (isMounted) { + setLoadedItem({ ...item, content }); + setIsLoading(false); + } + }) + .catch(err => { + console.error("Failed to load item content", err); + if (isMounted) setIsLoading(false); + }); + } else { + setLoadedItem(item); + setIsLoading(false); + } + return () => { isMounted = false; } + }, [item, props.modality]); + + if (isLoading) { + return ( +
+ +
+ ); + } + + return ; +}; + + +const ItemCard: React.FC = ({ item, modality, isSource, confidence, onView, onClick }) => { + const isText = modality === 'text'; + + // Use contentUrl for images directly, fallback to loaded content if available. + const imageUrl = modality === 'image' ? (item.contentUrl || (item.content as string)) : null; + + const content = useMemo(() => { + switch (modality) { + case 'image': + if (!imageUrl) return null; + return ( +
+
+ {item.name} +
+
+ ); + case 'text': + if (typeof item.content !== 'string') return null; + return

{item.content}

; + case 'mesh': + return ( +
+
+
}> + + +
+ + ); + } + }, [item, modality, imageUrl]); + + return ( +
onClick ? onClick() : onView(item, modality)} + > +
{content}
+
+

{item.name}

+ {confidence !== undefined &&

Confidence: {confidence.toFixed(4)}

} +
+
+ ); +}; + +const pluralToSingular = (plural: string): Modality => { + if (plural === 'meshes') return 'mesh'; + return plural.slice(0, -1) as Modality; +} + +const ResultsDisplay: React.FC<{ + results: SingleComparisonResult; + onViewItem: (item: DataItem, modality: Modality) => void; +}> = ({ results, onViewItem }) => { + const sourcePluralModality = results.sourceModality === 'mesh' ? 'meshes' : `${results.sourceModality}s`; + + return ( +
+

Comparison Results

+
+ {/* Source Item */} +
+

Source Item

+
+ +
+
+ {/* Matches */} +
+ {Object.entries(results.results) + .filter(([pluralModality, matches]) => pluralModality !== sourcePluralModality && matches && matches.length > 0) + .map(([pluralModality, matches]) => ( +
+

{pluralModality} Matches

+
+ {(matches || []).slice(0, 3).map(match => ( +
+ +
+ ))} +
+
+ ))} +
+
+
+ ); +}; + +interface ComparisonToolProps { + dataset: Dataset; + allDatasets: DatasetMetadata[]; // Use metadata for the dropdown + onDatasetChange: (id: string) => void; + onBack: () => void; +} + +const FileUploader: React.FC<{ + onFileSelect: (file: File) => void | Promise; + accept: string; + modality: Modality; + clear: () => void; +}> = ({ onFileSelect, accept, modality, clear }) => { + const [file, setFile] = useState(null); + const inputRef = React.useRef(null); + + const handleFileChange = (e: React.ChangeEvent) => { + const selectedFile = e.target.files?.[0]; + if (selectedFile) { + setFile(selectedFile); + onFileSelect(selectedFile); + } + }; + + const handleClear = () => { + setFile(null); + if(inputRef.current) inputRef.current.value = ""; + clear(); + } + + return ( +
+ + + {file && ( +
+

Selected: {file.name}

+ +
+ )} +
+ ); +}; + +export const ComparisonTool: React.FC = ({ dataset, allDatasets, onDatasetChange, onBack }) => { + const [activeTab, setActiveTab] = useState('image'); + const [selectedItem, setSelectedItem] = useState(null); + const [newItem, setNewItem] = useState<{file?: File, text?: string} | null>(null); + const [newItemPreview, setNewItemPreview] = useState(null); + const [newItemMeshPreviewContent, setNewItemMeshPreviewContent] = useState(null); + const [comparisonResult, setComparisonResult] = useState(null); + const [isComparing, setIsComparing] = useState(false); + const [searchTerm, setSearchTerm] = useState(''); + const [viewingItem, setViewingItem] = useState<{ item: DataItem; modality: Modality } | null>(null); + const resultsRef = useRef(null); + + const MAX_ITEMS_TO_DISPLAY = 30; + + useEffect(() => { + // Cleanup object URLs to prevent memory leaks + return () => { + if(newItemPreview && newItemPreview.startsWith('blob:')) { + URL.revokeObjectURL(newItemPreview); + } + } + }, [newItemPreview]); + + useEffect(() => { + // Scroll to results when they appear + if (comparisonResult && resultsRef.current) { + // A small delay to ensure the element is rendered and painted before scrolling + const timer = setTimeout(() => { + resultsRef.current?.scrollIntoView({ behavior: 'smooth', block: 'start' }); + }, 100); + return () => clearTimeout(timer); + } + }, [comparisonResult]); + + const handleItemSelect = async (item: DataItem) => { + setSelectedItem(item); + setNewItem(null); + setNewItemPreview(null); + setNewItemMeshPreviewContent(null); + + // Optimistic UI update to show the newly selected source item immediately + setComparisonResult({ + sourceItem: item, + sourceModality: activeTab, + results: {}, // No matches yet, will be populated below + }); + + let itemWithContent = item; + // For cloud items, the content is not loaded yet. We must fetch it. + // The individual card will show a spinner, so we don't need a global one. 
+ if (!item.content && item.contentUrl) { + try { + const content = await getItemContent(item.contentUrl); + itemWithContent = { ...item, content }; + } catch (e) { + console.error("Failed to lazy-load content for comparison:", e); + alert("Could not load item content from the server for comparison."); + setComparisonResult(null); // Clear results on error + setSelectedItem(null); + return; + } + } + + // Use fast, local search for existing dataset items and update the results + const results = findTopMatchesFromLocal(itemWithContent, activeTab, dataset); + setComparisonResult(results); + }; + + const handleNewItemSearch = async () => { + if (!newItem) return; + + setSelectedItem(null); // Deselect grid item + setIsComparing(true); + setComparisonResult(null); // Clear previous results + + try { + let content: string | ArrayBuffer; + let name: string; + + if (newItem.file) { + name = newItem.file.name; + if (activeTab === 'image') { + content = await new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result as string); + reader.onerror = reject; + reader.readAsDataURL(newItem.file); + }); + } else if (activeTab === 'mesh') { + content = await new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result as ArrayBuffer); + reader.onerror = reject; + reader.readAsArrayBuffer(newItem.file); + }); + } else { // text file + content = await new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result as string); + reader.onerror = reject; + reader.readAsText(newItem.file); + }); + } + } else if (newItem.text) { + name = 'Custom Text Input'; + content = newItem.text; + } else { + throw new Error("No new item content found"); + } + + const sourceItem: DataItem = { id: `new_${Date.now()}`, name, content }; + + // Optimistic UI: Show the source item while waiting for matches. + setComparisonResult({ + sourceItem, + sourceModality: activeTab, + results: {}, + }); + + const results = await findTopMatches(sourceItem, activeTab, dataset.id); + setComparisonResult(results); + + } catch (error) { + console.error("Failed to find matches for new item:", error); + alert(`Error finding matches: ${error instanceof Error ? 
error.message : String(error)}`); + setComparisonResult(null); // Clear on error + } finally { + setIsComparing(false); + } + } + + const handleFullComparison = () => { + if (dataset.fullComparison) { + downloadJson(dataset.fullComparison, `${dataset.name}-full-comparison.json`); + } else { + alert("Full comparison data is not available for this dataset."); + } + }; + + const clearNewItem = useCallback(() => { + setNewItem(null); + if(newItemPreview && newItemPreview.startsWith('blob:')) URL.revokeObjectURL(newItemPreview); + setNewItemPreview(null); + setNewItemMeshPreviewContent(null); + }, [newItemPreview]); + + const handleFileSelected = async (file: File) => { + setNewItem({ file }); + setSelectedItem(null); + if(newItemPreview && newItemPreview.startsWith('blob:')) URL.revokeObjectURL(newItemPreview); + setNewItemPreview(null); + setNewItemMeshPreviewContent(null); + + if(activeTab === 'image') { + setNewItemPreview(URL.createObjectURL(file)); + } else if (activeTab === 'mesh') { + try { + const buffer = await file.arrayBuffer(); + setNewItemMeshPreviewContent(buffer); + } catch (error) { + console.error("Error reading STL file for preview:", error); + alert("Could not read the file for preview."); + } + } + } + + const handleTabChange = (mod: Modality) => { + setActiveTab(mod); + setComparisonResult(null); + setSelectedItem(null); + clearNewItem(); + } + + const pluralKey = activeTab === 'mesh' ? 'meshes' : `${activeTab}s`; + const items = dataset.data[pluralKey as keyof typeof dataset.data]; + + const filteredItems = useMemo(() => { + if (!searchTerm) return items; + return items.filter(item => item.name.toLowerCase().includes(searchTerm.toLowerCase())); + }, [items, searchTerm]); + + const displayedItems = useMemo(() => { + return filteredItems.slice(0, MAX_ITEMS_TO_DISPLAY); + }, [filteredItems]); + + return ( + <> +
+
+
+

Comparison Tool:

+ +
+ +
+ +
+ +
+ + {/* Single Element Search */} +
+
+

Search with New Item

+
+ {/* Left side: Uploader Controls */} +
+ {activeTab === 'image' && ( +
+ +
+ )} + {activeTab === 'mesh' && ( +
+ +
+ )} + {activeTab === 'text' && ( +
+