|
|
#pragma once |
|
|
#include <cstdint>
#include <vector>

#include <cuda_runtime.h>
#include <cufft.h>

#include "fungi.hpp"
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------
// Input image geometry and label space.
// ---------------------------------------------------------------------------
constexpr int IMG_H = 28;
constexpr int IMG_W = 28;
constexpr int IMG_SIZE = IMG_H * IMG_W;  // pixels per image
constexpr int NUM_CLASSES = 10;

// ---------------------------------------------------------------------------
// Multi-scale pyramid.
//
// Each scale is the side length of a square map. Scale 1 covers the full
// image and every following scale halves the previous one (28 -> 14 -> 7).
// The scales are derived from IMG_H rather than repeated as literals, so the
// pyramid cannot silently drift out of sync with the input size.
// ---------------------------------------------------------------------------
constexpr int SCALE_1 = IMG_H;
constexpr int SCALE_2 = SCALE_1 / 2;
constexpr int SCALE_3 = SCALE_2 / 2;

constexpr int SCALE_1_SIZE = SCALE_1 * SCALE_1;
constexpr int SCALE_2_SIZE = SCALE_2 * SCALE_2;
constexpr int SCALE_3_SIZE = SCALE_3 * SCALE_3;

static_assert(IMG_H == IMG_W,
              "multi-scale pyramid uses square maps, so the input image must be square");
static_assert(SCALE_1 % 4 == 0,
              "image side must be divisible by 4 so every pyramid level halves exactly");
static_assert(SCALE_1_SIZE == IMG_SIZE,
              "scale 1 must cover exactly one full input image");

// Number of values obtained by concatenating one map from each scale.
constexpr int SINGLE_SCALE_SIZE = SCALE_1_SIZE + SCALE_2_SIZE + SCALE_3_SIZE;

// NOTE(review): the factor 2 presumably means two feature maps per scale are
// concatenated (DeviceBuffers has both magnitude/phase and direct/mirror
// feature buffers) -- confirm which pair against the feature kernels.
constexpr int MULTISCALE_SIZE = 2 * SINGLE_SCALE_SIZE;

// Presumably the width of the hidden layer (cf. DeviceBuffers::d_hidden and
// OpticalParams::W1) -- confirm against allocate_device_buffers.
constexpr int HIDDEN_SIZE = 1800;
|
|
|
|
|
// Host-side parameter store: flat float vectors only, no device memory.
//
// NOTE(review): W1/b1 and W2/b2 are presumably the weights and biases of two
// fully-connected layers (cf. HIDDEN_SIZE, NUM_CLASSES). Their flattening order
// is not visible in this header -- confirm in init_params.
struct OpticalParams {
    // Trainable tensors, in declaration order W1, b1, W2, b2.
    std::vector<float> W1, b1;
    std::vector<float> W2, b2;

    // Optimizer state, one pair per tensor above. By naming (and the
    // adam_update_gpu signature) m_* / v_* are presumably Adam first / second
    // moment estimates -- confirm.
    std::vector<float> m_W1, v_W1, m_b1, v_b1;
    std::vector<float> m_W2, v_W2, m_b2, v_b2;
};
|
|
|
|
|
// Bundle of raw device pointers shared by the training and inference paths.
//
// Every pointer defaults to nullptr. Allocation and release are presumably
// handled by allocate_device_buffers / free_device_buffers (implementations
// not visible here).
//
// NOTE(review): the struct holds plain pointers, so a copy aliases the same
// device allocations. Releasing through two copies would double-free. Every
// function in this header takes it by reference; keep it that way.
struct DeviceBuffers {
    // ---- Batch staged on the device --------------------------------------
    // Presumably B * IMG_SIZE input values -- confirm in allocate_device_buffers.
    float* d_batch_in = nullptr;
    // Labels stored as std::uint8_t, presumably one class index per sample.
    std::uint8_t* d_batch_lbl = nullptr;

    // ---- Per-scale buffers (scale 1 / 2 / 3) -------------------------------
    // For each scale: a complex field, a second complex buffer (presumably
    // the frequency-domain result of the FFTPlan transforms -- confirm), and
    // a real-valued feature map.
    cufftComplex* d_field_scale1 = nullptr;
    cufftComplex* d_freq_scale1 = nullptr;
    float* d_features_scale1 = nullptr;

    cufftComplex* d_field_scale2 = nullptr;
    cufftComplex* d_freq_scale2 = nullptr;
    float* d_features_scale2 = nullptr;

    cufftComplex* d_field_scale3 = nullptr;
    cufftComplex* d_freq_scale3 = nullptr;
    float* d_features_scale3 = nullptr;

    // NOTE(review): a second real feature map per scale. What "mirror"
    // denotes is not visible in this header -- confirm in the kernels.
    float* d_features_scale1_mirror = nullptr;
    float* d_features_scale2_mirror = nullptr;
    float* d_features_scale3_mirror = nullptr;

    // ---- Aggregated feature buffers ----------------------------------------
    float* d_magnitude_features = nullptr;
    float* d_phase_features = nullptr;

    // ---- Dense head: activations and their gradients -----------------------
    float* d_multiscale_features = nullptr;  // presumably MULTISCALE_SIZE per sample
    float* d_hidden = nullptr;               // presumably HIDDEN_SIZE per sample
    float* d_logits = nullptr;               // presumably NUM_CLASSES per sample
    float* d_probs = nullptr;
    float* d_grad_logits = nullptr;
    float* d_grad_hidden = nullptr;
    float* d_grad_multiscale = nullptr;

    // NOTE(review): the meaning of d_A, d_P and d_grad_map is not visible in
    // this header -- document once confirmed against the kernels.
    float* d_A = nullptr;
    float* d_P = nullptr;
    float* d_grad_map = nullptr;

    // ---- Device copies of OpticalParams tensors and their gradients --------
    float* d_W1 = nullptr;
    float* d_b1 = nullptr;
    float* d_W2 = nullptr;
    float* d_b2 = nullptr;
    float* d_gW1 = nullptr;
    float* d_gb1 = nullptr;
    float* d_gW2 = nullptr;
    float* d_gb2 = nullptr;
    float* d_loss_scalar = nullptr;

    // NOTE(review): contents and size not visible in this header -- confirm.
    float* d_bottleneck_metrics = nullptr;
};
|
|
|
|
|
// cuFFT plan handles: one "fwd" and one "inv" handle for each of the three
// pyramid scales (presumably forward / inverse transforms -- confirm in
// create_fft_plan). Every handle is value-initialized with `{}`.
struct FFTPlan {
    cufftHandle plan_fwd_scale1{}, plan_inv_scale1{};
    cufftHandle plan_fwd_scale2{}, plan_inv_scale2{};
    cufftHandle plan_fwd_scale3{}, plan_inv_scale3{};
};
|
|
|
|
|
// ---- Device resource lifecycle ---------------------------------------------
// NOTE(review): the implementations live elsewhere. The descriptions below
// are read off the signatures and should be confirmed against the .cu file.

// Presumably fills the pointers in `db` with allocations sized for `batch` samples.
void allocate_device_buffers(DeviceBuffers &db, int batch);
// Presumably releases what allocate_device_buffers acquired.
void free_device_buffers(DeviceBuffers &db);

// Presumably builds the cuFFT plans in `fft` for a batch of `batch` transforms.
void create_fft_plan(FFTPlan &fft, int batch);
// Presumably destroys the plans built by create_fft_plan.
void destroy_fft_plan(FFTPlan &fft);
|
|
|
|
|
// Initializes the host-side parameters in `p`. `seed` presumably seeds the
// random initializer so runs are reproducible -- confirm in the definition.
void init_params(OpticalParams &p, unsigned int seed);
|
|
|
|
|
|
|
|
// Host -> device: reads `params` (const) and writes into `db`, presumably the
// d_W1 / d_b1 / d_W2 / d_b2 buffers.
void upload_params_to_gpu(const OpticalParams &params, DeviceBuffers &db);

// Device -> host: reads `db` (const) and writes into `params`, presumably the
// inverse of upload_params_to_gpu.
void download_params_from_gpu(OpticalParams &params, const DeviceBuffers &db);
|
|
|
|
|
|
|
|
// Optimizer step. `params` is non-const and `db` is const, so the update
// is written back through `params`.
//
// NOTE(review): by naming, lr = learning rate, wd = weight decay and t_adam =
// Adam timestep (presumably used for bias correction; check whether it is
// expected to start at 1). Gradients presumably come from db.d_gW1 ... d_gb2
// -- confirm.
void adam_update_gpu(OpticalParams &params,
                     const DeviceBuffers &db,
                     float lr,
                     float wd,
                     int t_adam);
|
|
|
|
|
// One training step on a host-resident batch.
//
//   h_batch_in   host inputs (presumably B * IMG_SIZE floats -- confirm)
//   h_batch_lbl  host labels, presumably one uint8_t per sample
//   B            batch size
//   fungi        taken by non-const reference, so the step may modify it
//   params       host parameters / optimizer state, taken non-const
//   db, fft      preallocated device buffers and cuFFT plans
//   lr, wd       presumably learning rate and weight decay
//   t_adam       presumably the Adam timestep
//
// Returns a float, presumably the batch loss (cf. DeviceBuffers::d_loss_scalar)
// -- confirm in the definition.
float train_batch(const float   *h_batch_in,
                  const uint8_t *h_batch_lbl,
                  int            B,
                  FungiSoA      &fungi,
                  OpticalParams &params,
                  DeviceBuffers &db,
                  FFTPlan       &fft,
                  float          lr,
                  float          wd,
                  int            t_adam);
|
|
|
|
|
// Forward-only pass over a host-resident batch.
//
// `fungi` and `params` are const here (unlike train_batch); `db` and `fft`
// are still non-const, so device scratch buffers may be written.
// Results are written to `out_predictions`, presumably one class index per
// sample (B entries) -- confirm whether the vector is resized or must be
// presized by the caller.
void infer_batch(const float         *h_batch_in,
                 int                  B,
                 const FungiSoA      &fungi,
                 const OpticalParams &params,
                 DeviceBuffers       &db,
                 FFTPlan             &fft,
                 std::vector<int>    &out_predictions);
|
|
|