Yuan Gao committed on
Commit
38b8a8c
·
1 Parent(s): b7bc2ce

preprocessing code

Browse files
Files changed (2) hide show
  1. .gitignore +15 -0
  2. preprocessor.py +86 -0
.gitignore ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .log/
2
+ .vscode/
3
+ __pycache__/
4
+ assets_local/
5
+ */__pycache__/
6
+ .venv/
7
+ _t*.py
8
+ _*.yaml
9
+ _develop_*.py
10
+ outputs/
11
+ jupyter/*
12
+ *.csv
13
+ third_party/*
14
+ dev_*.py
15
+ .venv
preprocessor.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from math import gcd
3
+ from typing import Optional, Union
4
+ import joblib
5
+
6
+ import numpy as np
7
+ from scipy import signal
8
+
9
def load_scaler_joblib(path: str) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Read a fitted scaler from a joblib file and expose its statistics as tensors.

    The loaded object must provide ``mean_`` and ``scale_`` NumPy arrays.

    NOTE(review): ``joblib.load`` unpickles the file, so only load scaler
    files that come from a trusted source.

    Args:
        path: Path to the joblib file (e.g. ``ecg_scaler.pkl``).
    Returns:
        center: float32 tensor built from the scaler's ``mean_``.
        scale: float32 tensor built from the scaler's ``scale_``, with every
            entry raised to at least 1e-8 so it is safe to divide by.
    """
    fitted = joblib.load(path)

    # ``astype`` copies, so the returned tensors do not alias the scaler's arrays.
    mean_f32 = fitted.mean_.astype(np.float32)
    spread_f32 = fitted.scale_.astype(np.float32)

    center = torch.from_numpy(mean_f32)
    scale = torch.from_numpy(spread_f32).clamp_min(1e-8)
    return center, scale
22
+
23
class ECGTransform:
    """
    Unified ECG preprocessing: optional band-limiting + resampling, then scaling.

    Usage:
        transform = ECGTransform(center, scale, src_fs=512, target_fs=100)
        ecg_out = transform(ecg_in)

    Attributes:
        center: float32 tensor subtracted from the signal (indexed per row).
        scale: float32 tensor the signal is divided by, clamped to >= 1e-8.
        src_fs / target_fs: input and output sampling rates in Hz.
        band: (lowcut, highcut) in Hz for the pre-resampling filter, or None
            to skip filtering.
        bp_order: order passed to the Butterworth filter design.
        axis: array axis that is filtered and resampled.

    Note:
        ``scale`` is both an instance attribute (the tensor) and a method name.
        On an instance the attribute wins, so ``transform.scale`` is always the
        tensor; the method stays reachable as ``ECGTransform.scale(transform, ecg)``.
        Internally the scaling is done by ``_standardize``.
    """

    def __init__(
        self,
        center: Union[np.ndarray, torch.Tensor],
        scale: Union[np.ndarray, torch.Tensor],
        src_fs: int = 100,  # by default the input ECG is assumed to already be at 100 Hz
        target_fs: int = 100,
        band: Optional[tuple[float, float]] = (0.5, 40.0),
        bp_order: int = 4,
        axis: int = -1,
    ) -> None:
        """
        Store the scaling statistics and resampling configuration.

        Args:
            center: Values subtracted from the signal; converted to float32.
            scale: Values the signal is divided by; converted to float32 and
                clamped to at least 1e-8 to avoid division by zero.
            src_fs: Sampling rate of the incoming signal in Hz.
            target_fs: Sampling rate of the returned signal in Hz.
            band: (lowcut, highcut) in Hz, or None to disable filtering.
                A lowcut <= 0 selects a low-pass instead of a band-pass.
            bp_order: Butterworth filter order.
            axis: Axis along which filtering and resampling operate.
        """
        self.center = torch.as_tensor(center, dtype=torch.float32)
        # This instance attribute intentionally keeps its historical name even
        # though it shadows the ``scale`` method; see the class docstring.
        self.scale = torch.as_tensor(scale, dtype=torch.float32).clamp_min(1e-8)
        self.src_fs = src_fs
        self.target_fs = target_fs
        self.band = band
        self.bp_order = bp_order
        self.axis = axis

    def downsample(self, x: np.ndarray) -> np.ndarray:
        """
        Band-limit (optional) and resample ``x`` from ``src_fs`` to ``target_fs``.

        Args:
            x: Array-like signal; ``self.axis`` is the time axis.
        Returns:
            The resampled array.
        """
        x = np.asarray(x)
        if self.band is not None:
            lowcut, highcut = self.band
            # Keep the passband below 0.45x the lower of the two rates. Using
            # only target_fs would push the normalised cutoff to >= 1 (and make
            # the filter design fail) whenever src_fs < target_fs.
            max_high = 0.45 * min(self.src_fs, self.target_fs)
            highcut = min(highcut, max_high)
            nyq = self.src_fs / 2.0
            if lowcut <= 0:
                # No usable lower edge: fall back to a plain low-pass.
                wn = highcut / nyq
                sos = signal.butter(self.bp_order, wn, btype="low", output="sos")
            else:
                wn = (lowcut / nyq, highcut / nyq)
                sos = signal.butter(self.bp_order, wn, btype="band", output="sos")
            # Forward-backward filtering, so no phase shift is introduced.
            x = signal.sosfiltfilt(sos, x, axis=self.axis)
        # Reduce the rate ratio to lowest terms for the polyphase resampler.
        g = gcd(self.src_fs, self.target_fs)
        up = self.target_fs // g
        down = self.src_fs // g
        y = signal.resample_poly(
            x, up, down, axis=self.axis, window=("kaiser", 5.0), padtype="median"
        )
        return y

    def _standardize(self, ecg: torch.Tensor) -> torch.Tensor:
        """Return ``(ecg - center) / scale`` as float32, broadcasting per row."""
        ecg = ecg.to(torch.float32)
        # [:, None] turns the per-row statistics into column vectors so they
        # broadcast across the trailing (time) dimension.
        return (ecg - self.center[:, None]) / self.scale[:, None]

    def scale(self, ecg: torch.Tensor) -> torch.Tensor:
        """
        Scale ``ecg`` with the stored center/scale statistics.

        Shadowed on instances by the ``scale`` tensor attribute, so it can only
        be invoked through the class: ``ECGTransform.scale(transform, ecg)``.
        """
        return self._standardize(ecg)

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """
        Downsample and scale ECG data.

        Filtering/resampling is skipped entirely when ``src_fs == target_fs``.

        Args:
            x: np.ndarray, shape (leads, time)
        Returns:
            torch.Tensor, shape (leads, time)
        """
        if self.src_fs != self.target_fs:
            x = self.downsample(x)
        if not isinstance(x, torch.Tensor):
            x = torch.from_numpy(x)
        # ``self.scale`` is the tensor attribute here, not the method, so the
        # scaling has to go through the private helper.
        return self._standardize(x)