// Pure_Optical_CUDA / optical_model.hpp
// (uploaded by Agnuxo - "Upload 36 files", commit db3c893, verified)
#pragma once
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>
#include <cufft.h>
#include "fungi.hpp"
// ---------------------------------------------------------------------------
// Input geometry (Fashion-MNIST)
// ---------------------------------------------------------------------------
constexpr int IMG_H = 28;                // image height
constexpr int IMG_W = 28;                // image width
constexpr int IMG_SIZE = IMG_W * IMG_H;  // 784 values per image
constexpr int NUM_CLASSES = 10;          // Fashion-MNIST has 10 classes

// ---------------------------------------------------------------------------
// Multi-scale FFT feature layout
//
// History of the feature vector, as recorded by the author:
//   original : single-scale FFT, 28x28                      ->  784 features
//   previous : 3-scale FFT, 28x28 + 14x14 + 7x7             -> 1029 features
//   current  : 6-scale mirror FFT (3 normal + 3 mirrored)   -> 2058 features
// ---------------------------------------------------------------------------
constexpr int SCALE_1 = 28;  // full resolution    - fine details
constexpr int SCALE_2 = 14;  // half resolution    - texture patterns
constexpr int SCALE_3 = 7;   // quarter resolution - edge patterns

// Element count of each square scale.
constexpr int SCALE_1_SIZE = SCALE_1 * SCALE_1;  // 784
constexpr int SCALE_2_SIZE = SCALE_2 * SCALE_2;  // 196
constexpr int SCALE_3_SIZE = SCALE_3 * SCALE_3;  // 49

// Features of one (non-mirrored) pass over all three scales.
constexpr int SINGLE_SCALE_SIZE = SCALE_3_SIZE + SCALE_2_SIZE + SCALE_1_SIZE;  // 1029

// Normal + mirrored passes concatenated: the MLP input width.
constexpr int MULTISCALE_SIZE = SINGLE_SCALE_SIZE + SINGLE_SCALE_SIZE;  // 2058

// Width of the MLP hidden layer, chosen by the author for the 2058-feature input.
constexpr int HIDDEN_SIZE = 1800;
// Parameters of the two-layer MLP that consumes the multi-scale features,
// together with the Adam moment buffers for every parameter tensor.
//
//   hidden = ReLU(W1 * multiscale_features + b1)
//   logits = W2 * hidden + b2
//
// Member order is part of the type's layout / aggregate-initialisation order
// and must not be changed.
struct OpticalParams {
    // --- Trainable parameters ---------------------------------------------
    std::vector<float> W1;  // first-layer weights,  [HIDDEN_SIZE, MULTISCALE_SIZE] (2058 inputs)
    std::vector<float> b1;  // first-layer bias,     [HIDDEN_SIZE]
    std::vector<float> W2;  // second-layer weights, [NUM_CLASSES, HIDDEN_SIZE]
    std::vector<float> b2;  // second-layer bias,    [NUM_CLASSES]

    // --- Adam moments (m_* / v_* pair per parameter tensor) ----------------
    std::vector<float> m_W1;
    std::vector<float> v_W1;
    std::vector<float> m_b1;
    std::vector<float> v_b1;
    std::vector<float> m_W2;
    std::vector<float> v_W2;
    std::vector<float> m_b2;
    std::vector<float> v_b2;
};
// Bundle of raw buffer pointers used by the training / inference pipeline.
//
// Every pointer defaults to nullptr. The struct itself performs no allocation
// or deallocation; presumably allocate_device_buffers() / free_device_buffers()
// own the lifetime of these buffers - confirm against their definitions.
//
// Shapes in the comments are the author's documented layouts, with B denoting
// the batch size; the actual allocation sizes are not visible in this header.
struct DeviceBuffers {
    // --- Batch input ---------------------------------------------------------
    float*        d_batch_in  = nullptr;  // [B, IMG_SIZE]
    std::uint8_t* d_batch_lbl = nullptr;  // [B]

    // --- Multi-scale optical processing buffers ------------------------------
    cufftComplex* d_field_scale1    = nullptr;  // [B, SCALE_1_SIZE] - full resolution field
    cufftComplex* d_freq_scale1     = nullptr;  // [B, SCALE_1_SIZE] - full resolution frequency
    float*        d_features_scale1 = nullptr;  // [B, SCALE_1_SIZE] - full resolution features
    cufftComplex* d_field_scale2    = nullptr;  // [B, SCALE_2_SIZE] - half resolution field (14x14)
    cufftComplex* d_freq_scale2     = nullptr;  // [B, SCALE_2_SIZE] - half resolution frequency
    float*        d_features_scale2 = nullptr;  // [B, SCALE_2_SIZE] - half resolution features
    cufftComplex* d_field_scale3    = nullptr;  // [B, SCALE_3_SIZE] - quarter resolution field (7x7)
    cufftComplex* d_freq_scale3     = nullptr;  // [B, SCALE_3_SIZE] - quarter resolution frequency
    float*        d_features_scale3 = nullptr;  // [B, SCALE_3_SIZE] - quarter resolution features

    // --- Mirror architecture: features of the flipped inputs ------------------
    float* d_features_scale1_mirror = nullptr;  // [B, SCALE_1_SIZE] - mirrored scale-1 features
    float* d_features_scale2_mirror = nullptr;  // [B, SCALE_2_SIZE] - mirrored scale-2 features
    float* d_features_scale3_mirror = nullptr;  // [B, SCALE_3_SIZE] - mirrored scale-3 features

    // --- Dual-channel processing: separate magnitude and phase ----------------
    // The original comments sized these with `MIRROR_SCALE_SIZE`, which is not
    // defined in this header; the 2058-element constant is MULTISCALE_SIZE.
    float* d_magnitude_features  = nullptr;  // [B, MULTISCALE_SIZE] - all magnitude features (2058)
    float* d_phase_features      = nullptr;  // [B, MULTISCALE_SIZE] - all phase features (2058)
    float* d_multiscale_features = nullptr;  // [B, MULTISCALE_SIZE] - enhanced mirror features (2058)

    // --- MLP activations and gradients ---------------------------------------
    float* d_hidden          = nullptr;  // [B, HIDDEN_SIZE]     - hidden layer activations
    float* d_logits          = nullptr;  // [B, NUM_CLASSES]
    float* d_probs           = nullptr;  // [B, NUM_CLASSES]
    float* d_grad_logits     = nullptr;  // [B, NUM_CLASSES]
    float* d_grad_hidden     = nullptr;  // [B, HIDDEN_SIZE]     - hidden layer gradients
    float* d_grad_multiscale = nullptr;  // [B, MULTISCALE_SIZE] - multi-scale gradients

    // --- Per-pixel maps -------------------------------------------------------
    // NOTE(review): the meaning of A / P is not stated in this header
    // (possibly amplitude / phase masks) - confirm against the implementation.
    float* d_A        = nullptr;  // [IMG_SIZE]
    float* d_P        = nullptr;  // [IMG_SIZE]
    float* d_grad_map = nullptr;  // [IMG_SIZE]

    // --- Persistent weights and gradients kept in GPU memory ------------------
    float* d_W1  = nullptr;  // [HIDDEN_SIZE, MULTISCALE_SIZE] - persistent weights
    float* d_b1  = nullptr;  // [HIDDEN_SIZE]                  - persistent biases
    float* d_W2  = nullptr;  // [NUM_CLASSES, HIDDEN_SIZE]     - persistent weights
    float* d_b2  = nullptr;  // [NUM_CLASSES]                  - persistent biases
    float* d_gW1 = nullptr;  // [HIDDEN_SIZE, MULTISCALE_SIZE] - persistent gradients
    float* d_gb1 = nullptr;  // [HIDDEN_SIZE]                  - persistent gradients
    float* d_gW2 = nullptr;  // [NUM_CLASSES, HIDDEN_SIZE]     - persistent gradients
    float* d_gb2 = nullptr;  // [NUM_CLASSES]                  - persistent gradients
    float* d_loss_scalar = nullptr;  // [1] - persistent loss buffer

    // --- Bottleneck detection -------------------------------------------------
    float* d_bottleneck_metrics = nullptr;  // [4] - real-time bottleneck analysis
};
struct FFTPlan {
cufftHandle plan_fwd_scale1{}; // 28x28 FFT plan
cufftHandle plan_inv_scale1{};
cufftHandle plan_fwd_scale2{}; // 14x14 FFT plan
cufftHandle plan_inv_scale2{};
cufftHandle plan_fwd_scale3{}; // 7x7 FFT plan
cufftHandle plan_inv_scale3{};
};
// Lifecycle entry points. Only the declarations live in this header; the
// behavior described below is inferred from the signatures and should be
// confirmed against the definitions.

// Sets up the buffers referenced by `db`.
// `batch` is presumably the batch size B used in the [B, ...] buffer shapes.
void allocate_device_buffers(DeviceBuffers& db, int batch);
// Counterpart of allocate_device_buffers() for the same `db`.
void free_device_buffers(DeviceBuffers& db);
// Sets up the plan handles in `fft`; `batch` is presumably the batch size
// the plans are built for - TODO confirm.
void create_fft_plan(FFTPlan& fft, int batch);
// Counterpart of create_fft_plan() for the same `fft`.
void destroy_fft_plan(FFTPlan& fft);
// Initialises the parameters in `p`; `seed` presumably seeds the random
// number generator used for initialisation - TODO confirm.
void init_params(OpticalParams& p, unsigned seed);
// C++ OPTIMIZATION: Initialize weights in GPU memory once.
// Reads `params` (const) and writes through `db`; presumably copies the
// host-side weights into the persistent d_W1/d_b1/d_W2/d_b2 buffers - confirm.
void upload_params_to_gpu(const OpticalParams& params, DeviceBuffers& db);
// Reverse direction: reads through `db` (const) and writes into `params`.
void download_params_from_gpu(OpticalParams& params, const DeviceBuffers& db);
// C++ OPTIMIZATION: GPU-side Adam updates (no CPU transfers!)
// `lr` / `wd` are presumably learning rate and weight decay, and `t_adam`
// the Adam step counter used for bias correction - TODO confirm.
// NOTE(review): `params` is taken by non-const reference even though the
// update is described as GPU-side; verify what, if anything, is written to it.
void adam_update_gpu(OpticalParams& params, const DeviceBuffers& db, float lr, float wd, int t_adam);
// Runs one training step on a batch of B samples.
//
// Parameters (semantics inferred from names and types - confirm against the
// definition, which is not part of this header):
//   h_batch_in  - input images, presumably a host pointer laid out [B, IMG_SIZE]
//   h_batch_lbl - class labels, presumably a host pointer laid out [B]
//   B           - number of samples in the batch
//   fungi       - FungiSoA state (declared in fungi.hpp), taken by mutable reference
//   params      - model parameters and Adam moments, taken by mutable reference
//   db, fft     - working buffers and cuFFT plans
//   lr, wd      - presumably learning rate and weight decay
//   t_adam      - presumably the Adam time step
//
// Returns a float, presumably the batch loss - TODO confirm.
//
// The label pointer is spelled std::uint8_t (from <cstdint>) so the header
// does not depend on a transitive include; it is the same type as uint8_t.
float train_batch(const float* h_batch_in, const std::uint8_t* h_batch_lbl,
                  int B, FungiSoA& fungi, OpticalParams& params,
                  DeviceBuffers& db, FFTPlan& fft,
                  float lr, float wd, int t_adam);
// Runs inference on a batch of B samples and reports results through
// `out_predictions`.
// `fungi` and `params` are read-only here (const references); `db` and `fft`
// are taken by mutable reference.
// Presumably `h_batch_in` is a host pointer laid out [B, IMG_SIZE] and
// `out_predictions` receives one class index per sample - TODO confirm
// against the definition, which is not part of this header.
void infer_batch(const float* h_batch_in, int B,
const FungiSoA& fungi, const OpticalParams& params,
DeviceBuffers& db, FFTPlan& fft,
std::vector<int>& out_predictions);