File size: 6,505 Bytes
db3c893
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#pragma once
#include <cstdint>
#include <vector>

#include <cuda_runtime.h>
#include <cufft.h>

#include "fungi.hpp"

// Image geometry - Fashion-MNIST
constexpr int IMG_H = 28;
constexpr int IMG_W = 28;
constexpr int IMG_SIZE = IMG_H * IMG_W;
constexpr int NUM_CLASSES = 10; // Fashion-MNIST -> 10 classes

// Multi-scale "mirror" FFT feature layout.
// History: single-scale FFT (28x28 = 784 features) -> 3-scale FFT
// (28x28 + 14x14 + 7x7 = 1029 features) -> current layout, which keeps the three
// scales and adds a mirrored copy of each (2 x 1029 = 2058 features).
//
// The scales are derived from the image resolution instead of being hard-coded,
// so the per-scale buffer sizes cannot drift away from IMG_SIZE.
constexpr int SCALE_1 = IMG_H;       // 28: full resolution - fine details
constexpr int SCALE_2 = SCALE_1 / 2; // 14: half resolution - texture patterns
constexpr int SCALE_3 = SCALE_1 / 4; // 7: quarter resolution - edge patterns
constexpr int SCALE_1_SIZE = SCALE_1 * SCALE_1; // 784
constexpr int SCALE_2_SIZE = SCALE_2 * SCALE_2; // 196
constexpr int SCALE_3_SIZE = SCALE_3 * SCALE_3; // 49
constexpr int SINGLE_SCALE_SIZE = SCALE_1_SIZE + SCALE_2_SIZE + SCALE_3_SIZE; // 1029 features per copy
constexpr int MULTISCALE_SIZE = 2 * SINGLE_SCALE_SIZE; // 2058 total features (normal + mirrored)

// Every scale buffer is square (SCALE_n x SCALE_n). The full-resolution scale therefore
// only matches the [B, IMG_SIZE] input when the image itself is square.
static_assert(IMG_H == IMG_W, "multi-scale FFT layout assumes square images");
static_assert(SCALE_1_SIZE == IMG_SIZE, "scale 1 must cover exactly one full image");
static_assert(SCALE_1 % 4 == 0, "SCALE_1 must be divisible by 4 so the half/quarter scales are exact");

constexpr int HIDDEN_SIZE = 1800; // Hidden-layer width of the MLP head (W1 is [HIDDEN_SIZE, MULTISCALE_SIZE])

/// Host-side parameters of the multi-scale MLP classifier head, together with the
/// optimizer state Adam keeps for each of them.
///
/// Forward pass (per the original design note):
///   hidden = ReLU(W1 * multiscale_features + b1)
///   logits = W2 * hidden + b2
///
/// NOTE: member order matters for aggregate initialization; keep it stable.
struct OpticalParams {
    // ---- Trainable parameters ------------------------------------------------
    std::vector<float> W1; // [HIDDEN_SIZE, MULTISCALE_SIZE] - first layer weights
    std::vector<float> b1; // [HIDDEN_SIZE]                  - first layer bias
    std::vector<float> W2; // [NUM_CLASSES, HIDDEN_SIZE]     - second layer weights
    std::vector<float> b2; // [NUM_CLASSES]                  - second layer bias

    // ---- Adam moment buffers (m_ / v_ follow the usual first / second moment
    //      naming; presumably sized like the parameter they track) -----------
    std::vector<float> m_W1;
    std::vector<float> v_W1;
    std::vector<float> m_b1;
    std::vector<float> v_b1;
    std::vector<float> m_W2;
    std::vector<float> v_W2;
    std::vector<float> m_b2;
    std::vector<float> v_b2;
};

/// Raw device-memory pointers used by the training / inference pipeline.
///
/// Every pointer starts as nullptr. allocate_device_buffers() and
/// free_device_buffers() are presumably responsible for their lifetime.
/// This is a plain aggregate: copying a DeviceBuffers copies the pointers, not the
/// allocations, so only one copy should ever be freed.
///
/// Shape comments use B for the batch size; they document intent and are not
/// enforced by this struct.
struct DeviceBuffers {
    float* d_batch_in = nullptr;         // [B, IMG_SIZE]
    std::uint8_t* d_batch_lbl = nullptr; // [B]

    // Multi-scale optical processing buffers
    cufftComplex* d_field_scale1 = nullptr;  // [B, SCALE_1_SIZE] - Full resolution field (28x28)
    cufftComplex* d_freq_scale1 = nullptr;   // [B, SCALE_1_SIZE] - Full resolution frequency
    float* d_features_scale1 = nullptr;      // [B, SCALE_1_SIZE] - Full resolution features

    cufftComplex* d_field_scale2 = nullptr;  // [B, SCALE_2_SIZE] - Half resolution field (14x14)
    cufftComplex* d_freq_scale2 = nullptr;   // [B, SCALE_2_SIZE] - Half resolution frequency
    float* d_features_scale2 = nullptr;      // [B, SCALE_2_SIZE] - Half resolution features

    cufftComplex* d_field_scale3 = nullptr;  // [B, SCALE_3_SIZE] - Quarter resolution field (7x7)
    cufftComplex* d_freq_scale3 = nullptr;   // [B, SCALE_3_SIZE] - Quarter resolution frequency
    float* d_features_scale3 = nullptr;      // [B, SCALE_3_SIZE] - Quarter resolution features

    // Mirror architecture: flipped versions of each scale's features
    float* d_features_scale1_mirror = nullptr;  // [B, SCALE_1_SIZE] - Mirrored scale1 features
    float* d_features_scale2_mirror = nullptr;  // [B, SCALE_2_SIZE] - Mirrored scale2 features
    float* d_features_scale3_mirror = nullptr;  // [B, SCALE_3_SIZE] - Mirrored scale3 features

    // Dual-channel processing: magnitude and phase kept in separate buffers
    float* d_magnitude_features = nullptr;   // [B, MULTISCALE_SIZE] - All magnitude features (2058)
    float* d_phase_features = nullptr;       // [B, MULTISCALE_SIZE] - All phase features (2058)

    float* d_multiscale_features = nullptr;  // [B, MULTISCALE_SIZE] - Combined mirror features (2058)
    float* d_hidden = nullptr;               // [B, HIDDEN_SIZE] - Hidden layer activations
    float* d_logits = nullptr;               // [B, NUM_CLASSES]
    float* d_probs = nullptr;                // [B, NUM_CLASSES]
    float* d_grad_logits = nullptr;          // [B, NUM_CLASSES]
    float* d_grad_hidden = nullptr;          // [B, HIDDEN_SIZE] - Hidden layer gradients
    float* d_grad_multiscale = nullptr;      // [B, MULTISCALE_SIZE] - Multi-scale feature gradients

    float* d_A = nullptr;           // [IMG_SIZE]
    float* d_P = nullptr;           // [IMG_SIZE]
    float* d_grad_map = nullptr;    // [IMG_SIZE]

    // Weights, biases and their gradients kept resident in GPU memory between
    // batches (avoids re-uploading parameters every step)
    float* d_W1 = nullptr;          // [HIDDEN_SIZE, MULTISCALE_SIZE] - Persistent weights
    float* d_b1 = nullptr;          // [HIDDEN_SIZE] - Persistent biases
    float* d_W2 = nullptr;          // [NUM_CLASSES, HIDDEN_SIZE] - Persistent weights
    float* d_b2 = nullptr;          // [NUM_CLASSES] - Persistent biases
    float* d_gW1 = nullptr;         // [HIDDEN_SIZE, MULTISCALE_SIZE] - Persistent gradients
    float* d_gb1 = nullptr;         // [HIDDEN_SIZE] - Persistent gradients
    float* d_gW2 = nullptr;         // [NUM_CLASSES, HIDDEN_SIZE] - Persistent gradients
    float* d_gb2 = nullptr;         // [NUM_CLASSES] - Persistent gradients
    float* d_loss_scalar = nullptr; // [1] - Persistent loss buffer

    // Bottleneck detection buffer
    float* d_bottleneck_metrics = nullptr; // [4] - bottleneck metrics; meaning of each slot is set by the implementation (TODO document)
};

/// cuFFT plan handles, one forward and one inverse plan per feature scale.
///
/// All handles are value-initialized. create_fft_plan() / destroy_fft_plan() are
/// presumably responsible for populating and releasing them.
struct FFTPlan {
    // Scale 1: 28x28 (full resolution)
    cufftHandle plan_fwd_scale1 = {};
    cufftHandle plan_inv_scale1 = {};

    // Scale 2: 14x14 (half resolution)
    cufftHandle plan_fwd_scale2 = {};
    cufftHandle plan_inv_scale2 = {};

    // Scale 3: 7x7 (quarter resolution)
    cufftHandle plan_fwd_scale3 = {};
    cufftHandle plan_inv_scale3 = {};
};

// ---------------------------------------------------------------------------
// Resource management
// ---------------------------------------------------------------------------

/// Allocates the device buffers held in `db` for batch size `batch`.
void allocate_device_buffers(DeviceBuffers& db, int batch);

/// Releases the device buffers held in `db`.
void free_device_buffers(DeviceBuffers& db);

/// Creates the cuFFT plans stored in `fft` for batch size `batch`.
void create_fft_plan(FFTPlan& fft, int batch);

/// Destroys the cuFFT plans stored in `fft`.
void destroy_fft_plan(FFTPlan& fft);

// ---------------------------------------------------------------------------
// Parameters
// ---------------------------------------------------------------------------

/// Initializes the host-side parameters in `p`, using `seed` for reproducibility.
void init_params(OpticalParams& p, unsigned seed);

/// Initializes the persistent GPU weight buffers in `db` from host `params`.
/// Intended to run once, so the weights then stay resident in GPU memory.
void upload_params_to_gpu(const OpticalParams& params, DeviceBuffers& db);

/// Reads the persistent GPU weights in `db` back into host `params`.
void download_params_from_gpu(OpticalParams& params, const DeviceBuffers& db);

/// Adam update executed on the GPU, avoiding host<->device transfers.
/// @param lr     learning rate (by naming convention)
/// @param wd     weight decay (by naming convention)
/// @param t_adam Adam timestep, presumably used for bias correction (TODO confirm)
void adam_update_gpu(OpticalParams& params, const DeviceBuffers& db, float lr, float wd, int t_adam);

// ---------------------------------------------------------------------------
// Training / inference
// ---------------------------------------------------------------------------

/// Runs one training step on a host batch of `B` images and labels.
/// @return presumably the batch loss (cf. DeviceBuffers::d_loss_scalar) - TODO confirm.
float train_batch(const float* h_batch_in,
                  const uint8_t* h_batch_lbl,
                  int B,
                  FungiSoA& fungi,
                  OpticalParams& params,
                  DeviceBuffers& db,
                  FFTPlan& fft,
                  float lr,
                  float wd,
                  int t_adam);

/// Runs inference on a host batch of `B` images and writes the results into
/// `out_predictions` (presumably one predicted class index per image).
void infer_batch(const float* h_batch_in,
                 int B,
                 const FungiSoA& fungi,
                 const OpticalParams& params,
                 DeviceBuffers& db,
                 FFTPlan& fft,
                 std::vector<int>& out_predictions);