Yuan Gao
commited on
Commit
·
38b8a8c
1
Parent(s):
b7bc2ce
preprocessing code
Browse files- .gitignore +15 -0
- preprocessor.py +86 -0
.gitignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.log/
|
| 2 |
+
.vscode/
|
| 3 |
+
__pycache__/
|
| 4 |
+
assets_local/
|
| 5 |
+
*/__pycache__/
|
| 6 |
+
.venv/
|
| 7 |
+
_t*.py
|
| 8 |
+
_*.yaml
|
| 9 |
+
_develop_*.py
|
| 10 |
+
outputs/
|
| 11 |
+
jupyter/*
|
| 12 |
+
*.csv
|
| 13 |
+
third_party/*
|
| 14 |
+
dev_*.py
|
| 15 |
+
.venv
|
preprocessor.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from math import gcd
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
import joblib
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from scipy import signal
|
| 8 |
+
|
| 9 |
+
def load_scaler_joblib(path: str) -> tuple[torch.Tensor, torch.Tensor]:
|
| 10 |
+
"""
|
| 11 |
+
Load ecg_scaler.pkl and return center and scale as torch tensors.
|
| 12 |
+
Args:
|
| 13 |
+
path: Path to the joblib file.
|
| 14 |
+
Returns:
|
| 15 |
+
center: torch.Tensor
|
| 16 |
+
scale: torch.Tensor
|
| 17 |
+
"""
|
| 18 |
+
sc = joblib.load(path)
|
| 19 |
+
center = torch.from_numpy(sc.mean_.astype(np.float32))
|
| 20 |
+
scale = torch.from_numpy(sc.scale_.astype(np.float32)).clamp_min(1e-8)
|
| 21 |
+
return center, scale
|
| 22 |
+
|
| 23 |
+
class ECGTransform:
|
| 24 |
+
"""
|
| 25 |
+
Unified ECG preprocessing: downsampling and scaling.
|
| 26 |
+
Usage:
|
| 27 |
+
transform = ECGTransform(center, scale, src_fs=512, target_fs=100)
|
| 28 |
+
ecg_out = transform(ecg_in)
|
| 29 |
+
"""
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
center: Union[np.ndarray, torch.Tensor],
|
| 33 |
+
scale: Union[np.ndarray, torch.Tensor],
|
| 34 |
+
src_fs: int = 100, #we assume the input ECG is already at 100Hz
|
| 35 |
+
target_fs: int = 100,
|
| 36 |
+
band: Optional[tuple[float, float]] = (0.5, 40.0),
|
| 37 |
+
bp_order: int = 4,
|
| 38 |
+
axis: int = -1,
|
| 39 |
+
) -> None:
|
| 40 |
+
self.center = torch.as_tensor(center, dtype=torch.float32)
|
| 41 |
+
self.scale = torch.as_tensor(scale, dtype=torch.float32).clamp_min(1e-8)
|
| 42 |
+
self.src_fs = src_fs
|
| 43 |
+
self.target_fs = target_fs
|
| 44 |
+
self.band = band
|
| 45 |
+
self.bp_order = bp_order
|
| 46 |
+
self.axis = axis
|
| 47 |
+
|
| 48 |
+
def downsample(self, x: np.ndarray) -> np.ndarray:
|
| 49 |
+
x = np.asarray(x)
|
| 50 |
+
if self.band is not None:
|
| 51 |
+
lowcut, highcut = self.band
|
| 52 |
+
max_high = 0.45 * self.target_fs
|
| 53 |
+
highcut = min(highcut, max_high)
|
| 54 |
+
nyq = self.src_fs / 2.0
|
| 55 |
+
if lowcut <= 0:
|
| 56 |
+
wn = highcut / nyq
|
| 57 |
+
sos = signal.butter(self.bp_order, wn, btype="low", output="sos")
|
| 58 |
+
else:
|
| 59 |
+
wn = (lowcut / nyq, highcut / nyq)
|
| 60 |
+
sos = signal.butter(self.bp_order, wn, btype="band", output="sos")
|
| 61 |
+
x = signal.sosfiltfilt(sos, x, axis=self.axis)
|
| 62 |
+
g = gcd(self.src_fs, self.target_fs)
|
| 63 |
+
up = self.target_fs // g
|
| 64 |
+
down = self.src_fs // g
|
| 65 |
+
y = signal.resample_poly(x, up, down, axis=self.axis, window=("kaiser", 5.0), padtype="median")
|
| 66 |
+
return y
|
| 67 |
+
|
| 68 |
+
def scale(self, ecg: torch.Tensor) -> torch.Tensor:
|
| 69 |
+
ecg = ecg.to(torch.float32)
|
| 70 |
+
ecg = (ecg - self.center[:, None]) / self.scale[:, None]
|
| 71 |
+
return ecg
|
| 72 |
+
|
| 73 |
+
def __call__(self, x: np.ndarray) -> torch.Tensor:
|
| 74 |
+
"""
|
| 75 |
+
Downsample and scale ECG data.
|
| 76 |
+
Args:
|
| 77 |
+
x: np.ndarray, shape (leads, time)
|
| 78 |
+
Returns:
|
| 79 |
+
torch.Tensor, shape (leads, time)
|
| 80 |
+
"""
|
| 81 |
+
if self.src_fs != self.target_fs:
|
| 82 |
+
x = self.downsample(x)
|
| 83 |
+
if not isinstance(x, torch.Tensor):
|
| 84 |
+
x = torch.from_numpy(x)
|
| 85 |
+
x = self.scale(x)
|
| 86 |
+
return x
|