#include "data_loader.hpp"

#include <cstddef>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>

#include "optical_model.hpp"  // For IMG_SIZE

namespace {

/// @brief Measures an already-open binary stream and rewinds it to the start.
/// @param stream Open input stream to measure.
/// @param path   Path used only to build the error message.
/// @return Size of the stream in bytes.
/// @throws std::runtime_error if the stream position cannot be queried.
std::size_t stream_size_bytes(std::ifstream& stream, const std::string& path) {
    stream.seekg(0, std::ios::end);
    // tellg() reports -1 on failure; check before converting to an unsigned
    // type, otherwise the failure would turn into an enormous byte count.
    const std::streamoff end_pos = stream.tellg();
    stream.seekg(0, std::ios::beg);
    if (end_pos < 0 || !stream) {
        throw std::runtime_error("Cannot determine size of: " + path);
    }
    return static_cast<std::size_t>(end_pos);
}

}  // namespace

/// @brief Loads a Fashion-MNIST split from raw binary files.
///
/// Reads `<data_dir>/<prefix>-images.bin` and `<data_dir>/<prefix>-labels.bin`,
/// where `<prefix>` is "train" or "test". The image file is interpreted as a
/// flat sequence of floats, IMG_SIZE per sample; the label file must contain
/// exactly one byte per sample.
/// NOTE(review): the element types of `set.images` / `set.labels` are declared
/// in data_loader.hpp and are not visible here; this code assumes float images
/// and single-byte labels, as the size arithmetic implies -- confirm.
///
/// @param data_dir Directory containing the `.bin` files.
/// @param is_train Selects the "train" split when true, "test" otherwise.
/// @return The populated set (`N`, `images`, `labels`).
/// @throws std::runtime_error if a file cannot be opened, measured or fully
///         read, if the image file size is not a whole number of images, or
///         if the label count does not match the image count.
FashionMNISTSet load_fashion_mnist_data(const std::string& data_dir, bool is_train) {
    FashionMNISTSet set;
    const std::string prefix = is_train ? "train" : "test";
    const std::string images_path = data_dir + "/" + prefix + "-images.bin";
    const std::string labels_path = data_dir + "/" + prefix + "-labels.bin";

    // Load images
    std::ifstream f_images(images_path, std::ios::binary);
    if (!f_images) throw std::runtime_error("Cannot open: " + images_path);

    const std::size_t image_bytes = stream_size_bytes(f_images, images_path);
    const std::size_t bytes_per_image = IMG_SIZE * sizeof(float);
    // A trailing partial image would make the read below larger than the
    // buffer sized from the (floored) image count, so reject it up front.
    if (image_bytes % bytes_per_image != 0) {
        throw std::runtime_error("Image file size is not a multiple of the image size: " +
                                 images_path);
    }

    set.N = image_bytes / bytes_per_image;
    set.images.resize(set.N * IMG_SIZE);
    if (!f_images.read(reinterpret_cast<char*>(set.images.data()),
                       static_cast<std::streamsize>(image_bytes))) {
        throw std::runtime_error("Failed to read: " + images_path);
    }

    // Load labels
    std::ifstream f_labels(labels_path, std::ios::binary);
    if (!f_labels) throw std::runtime_error("Cannot open: " + labels_path);

    const std::size_t label_bytes = stream_size_bytes(f_labels, labels_path);
    // One byte per label, so the byte count must equal the sample count.
    if (static_cast<std::size_t>(set.N) != label_bytes) {
        throw std::runtime_error("Image and label count mismatch!");
    }

    set.labels.resize(set.N);
    if (!f_labels.read(reinterpret_cast<char*>(set.labels.data()),
                       static_cast<std::streamsize>(label_bytes))) {
        throw std::runtime_error("Failed to read: " + labels_path);
    }

    std::cout << "[INFO] Loaded " << set.N << " " << prefix << " samples.\n";
    return set;
}