#pragma once #include #include #include #include "fungi.hpp" // Image geometry - Fashion-MNIST constexpr int IMG_H = 28; constexpr int IMG_W = 28; constexpr int IMG_SIZE = IMG_H * IMG_W; constexpr int NUM_CLASSES = 10; // Fashion-MNIST -> 10 classes // INTELLIGENT ROLLBACK: Enhanced FFT with stable 2058-feature architecture // ORIGINAL: Single scale FFT (28x28 = 784 features) // PREVIOUS: 3-scale FFT (28x28 + 14x14 + 7x7 = 1029 features) // CURRENT: Enhanced 6-scale mirror FFT (3 normal + 3 mirrored = 2058 features) - OPTIMIZED INFORMATION constexpr int SCALE_1 = 28; // Full resolution - fine details constexpr int SCALE_2 = 14; // Half resolution - texture patterns constexpr int SCALE_3 = 7; // Quarter resolution - edge patterns constexpr int SCALE_1_SIZE = SCALE_1 * SCALE_1; // 784 constexpr int SCALE_2_SIZE = SCALE_2 * SCALE_2; // 196 constexpr int SCALE_3_SIZE = SCALE_3 * SCALE_3; // 49 constexpr int SINGLE_SCALE_SIZE = SCALE_1_SIZE + SCALE_2_SIZE + SCALE_3_SIZE; // 1029 single features constexpr int MULTISCALE_SIZE = 2 * SINGLE_SCALE_SIZE; // 2058 total features (normal + mirrored) constexpr int HIDDEN_SIZE = 1800; // BALANCED: Optimal capacity for enhanced 2058-feature architecture struct OpticalParams { // Multi-scale MLP: hidden = ReLU(W1 * multiscale_features + b1), logits = W2 * hidden + b2 std::vector W1; // [HIDDEN_SIZE, MULTISCALE_SIZE] - First layer weights (2058 inputs) std::vector b1; // [HIDDEN_SIZE] - First layer bias std::vector W2; // [NUM_CLASSES, HIDDEN_SIZE] - Second layer weights std::vector b2; // [NUM_CLASSES] - Second layer bias // Adam moments for all parameters std::vector m_W1, v_W1, m_b1, v_b1; std::vector m_W2, v_W2, m_b2, v_b2; }; struct DeviceBuffers { float* d_batch_in = nullptr; // [B, IMG_SIZE] uint8_t* d_batch_lbl = nullptr; // [B] // Multi-scale optical processing buffers cufftComplex* d_field_scale1 = nullptr; // [B, SCALE_1_SIZE] - Full resolution field cufftComplex* d_freq_scale1 = nullptr; // [B, SCALE_1_SIZE] - Full resolution frequency float* d_features_scale1 = nullptr; // [B, SCALE_1_SIZE] - Full resolution features cufftComplex* d_field_scale2 = nullptr; // [B, SCALE_2_SIZE] - Half resolution field (14x14) cufftComplex* d_freq_scale2 = nullptr; // [B, SCALE_2_SIZE] - Half resolution frequency float* d_features_scale2 = nullptr; // [B, SCALE_2_SIZE] - Half resolution features cufftComplex* d_field_scale3 = nullptr; // [B, SCALE_3_SIZE] - Quarter resolution field (7x7) cufftComplex* d_freq_scale3 = nullptr; // [B, SCALE_3_SIZE] - Quarter resolution frequency float* d_features_scale3 = nullptr; // [B, SCALE_3_SIZE] - Quarter resolution features // Mirror architecture: flipped versions for enhanced feature extraction float* d_features_scale1_mirror = nullptr; // [B, SCALE_1_SIZE] - Mirrored scale1 features float* d_features_scale2_mirror = nullptr; // [B, SCALE_2_SIZE] - Mirrored scale2 features float* d_features_scale3_mirror = nullptr; // [B, SCALE_3_SIZE] - Mirrored scale3 features // BREAKTHROUGH: Rich dual-channel processing - separate magnitude and phase float* d_magnitude_features = nullptr; // [B, MIRROR_SCALE_SIZE] - All magnitude features (2058) float* d_phase_features = nullptr; // [B, MIRROR_SCALE_SIZE] - All phase features (2058) float* d_multiscale_features = nullptr; // [B, MULTISCALE_SIZE] - Enhanced mirror features (2058) float* d_hidden = nullptr; // [B, HIDDEN_SIZE] - Hidden layer activations float* d_logits = nullptr; // [B, NUM_CLASSES] float* d_probs = nullptr; // [B, NUM_CLASSES] float* d_grad_logits = nullptr; // [B, NUM_CLASSES] float* d_grad_hidden = nullptr; // [B, HIDDEN_SIZE] - Hidden layer gradients float* d_grad_multiscale = nullptr; // [B, MULTISCALE_SIZE] - Multi-scale gradients float* d_A = nullptr; // [IMG_SIZE] float* d_P = nullptr; // [IMG_SIZE] float* d_grad_map = nullptr; // [IMG_SIZE] // C++ OPTIMIZATION: Persistent weight buffers in GPU memory float* d_W1 = nullptr; // [HIDDEN_SIZE, MULTISCALE_SIZE] - Persistent weights float* d_b1 = nullptr; // [HIDDEN_SIZE] - Persistent biases float* d_W2 = nullptr; // [NUM_CLASSES, HIDDEN_SIZE] - Persistent weights float* d_b2 = nullptr; // [NUM_CLASSES] - Persistent biases float* d_gW1 = nullptr; // [HIDDEN_SIZE, MULTISCALE_SIZE] - Persistent gradients float* d_gb1 = nullptr; // [HIDDEN_SIZE] - Persistent gradients float* d_gW2 = nullptr; // [NUM_CLASSES, HIDDEN_SIZE] - Persistent gradients float* d_gb2 = nullptr; // [NUM_CLASSES] - Persistent gradients float* d_loss_scalar = nullptr; // [1] - Persistent loss buffer // CRITICAL: Bottleneck detection buffers float* d_bottleneck_metrics = nullptr; // [4] - Real-time bottleneck analysis }; struct FFTPlan { cufftHandle plan_fwd_scale1{}; // 28x28 FFT plan cufftHandle plan_inv_scale1{}; cufftHandle plan_fwd_scale2{}; // 14x14 FFT plan cufftHandle plan_inv_scale2{}; cufftHandle plan_fwd_scale3{}; // 7x7 FFT plan cufftHandle plan_inv_scale3{}; }; void allocate_device_buffers(DeviceBuffers& db, int batch); void free_device_buffers(DeviceBuffers& db); void create_fft_plan(FFTPlan& fft, int batch); void destroy_fft_plan(FFTPlan& fft); void init_params(OpticalParams& p, unsigned seed); // C++ OPTIMIZATION: Initialize weights in GPU memory once void upload_params_to_gpu(const OpticalParams& params, DeviceBuffers& db); void download_params_from_gpu(OpticalParams& params, const DeviceBuffers& db); // C++ OPTIMIZATION: GPU-side Adam updates (no CPU transfers!) void adam_update_gpu(OpticalParams& params, const DeviceBuffers& db, float lr, float wd, int t_adam); float train_batch(const float* h_batch_in, const uint8_t* h_batch_lbl, int B, FungiSoA& fungi, OpticalParams& params, DeviceBuffers& db, FFTPlan& fft, float lr, float wd, int t_adam); void infer_batch(const float* h_batch_in, int B, const FungiSoA& fungi, const OpticalParams& params, DeviceBuffers& db, FFTPlan& fft, std::vector& out_predictions);