Commit 0eb79a8
Parent(s): init

Files changed:
- Dockerfile +25 -0
- LICENSE +21 -0
- docs/README.md +144 -0
- docs/data.mseed +0 -0
- docs/example_batch_prediction.ipynb +211 -0
- docs/example_fastapi.ipynb +0 -0
- docs/example_gradio.ipynb +0 -0
- docs/test_api.py +37 -0
- env.yml +17 -0
- mkdocs.yml +18 -0
- model/190703-214543/checkpoint +3 -0
- model/190703-214543/config.log +3 -0
- model/190703-214543/loss.log +3 -0
- model/190703-214543/model_95.ckpt.data-00000-of-00001 +3 -0
- model/190703-214543/model_95.ckpt.index +3 -0
- model/190703-214543/model_95.ckpt.meta +3 -0
- phasenet/__init__.py +1 -0
- phasenet/app.py +341 -0
- phasenet/data_reader.py +1010 -0
- phasenet/detect_peaks.py +207 -0
- phasenet/model.py +489 -0
- phasenet/postprocess.py +377 -0
- phasenet/predict.py +262 -0
- phasenet/slide_window.py +88 -0
- phasenet/test_app.py +47 -0
- phasenet/train.py +246 -0
- phasenet/util.py +238 -0
- phasenet/visulization.py +481 -0
Dockerfile
ADDED
@@ -0,0 +1,25 @@
+FROM tensorflow/tensorflow
+
+# Create the environment:
+# COPY env.yml /app
+# RUN conda env create --name cs329s --file=env.yml
+# Make RUN commands use the new environment:
+# SHELL ["conda", "run", "-n", "cs329s", "/bin/bash", "-c"]
+
+RUN pip install tqdm obspy pandas
+RUN pip install uvicorn fastapi
+
+WORKDIR /opt
+
+# Copy files
+COPY phasenet /opt/phasenet
+COPY model /opt/model
+
+# Expose API port
+EXPOSE 8000
+
+ENV PYTHONUNBUFFERED=1
+
+# Start API server
+#ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "cs329s", "uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "8000", "--host", "0.0.0.0"]
+ENTRYPOINT ["uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "7860", "--host", "0.0.0.0"]
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Weiqiang Zhu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
docs/README.md
ADDED
@@ -0,0 +1,144 @@
+# PhaseNet: A Deep-Neural-Network-Based Seismic Arrival Time Picking Method
+
+[](https://ai4eps.github.io/PhaseNet)
+
+## 1. Install [miniconda](https://docs.conda.io/en/latest/miniconda.html) and requirements
+- Download the PhaseNet repository
+```bash
+git clone https://github.com/wayneweiqiang/PhaseNet.git
+cd PhaseNet
+```
+- Install into the default environment
+```bash
+conda env update -f=env.yml -n base
+```
+- Install into a "phasenet" virtual environment
+```bash
+conda env create -f env.yml
+conda activate phasenet
+```
+
+## 2. Pre-trained model
+Located in directory: **model/190703-214543**
+
+## 3. Related papers
+- Zhu, Weiqiang, and Gregory C. Beroza. "PhaseNet: A Deep-Neural-Network-Based Seismic Arrival Time Picking Method." arXiv preprint arXiv:1803.03211 (2018).
+- Liu, Min, et al. "Rapid characterization of the July 2019 Ridgecrest, California, earthquake sequence from raw seismic data using machine‐learning phase picker." Geophysical Research Letters 47.4 (2020): e2019GL086189.
+- Park, Yongsoo, et al. "Machine‐learning‐based analysis of the Guy‐Greenbrier, Arkansas earthquakes: A tale of two sequences." Geophysical Research Letters 47.6 (2020): e2020GL087032.
+- Chai, Chengping, et al. "Using a deep neural network and transfer learning to bridge scales for seismic phase picking." Geophysical Research Letters 47.16 (2020): e2020GL088651.
+- Tan, Yen Joe, et al. "Machine‐Learning‐Based High‐Resolution Earthquake Catalog Reveals How Complex Fault Structures Were Activated during the 2016–2017 Central Italy Sequence." The Seismic Record 1.1 (2021): 11-19.
+
+## 4. Batch prediction
+See examples in the [notebook](https://github.com/wayneweiqiang/PhaseNet/blob/master/docs/example_batch_prediction.ipynb): [example_batch_prediction.ipynb](example_batch_prediction.ipynb)
+
+PhaseNet currently supports four data formats: mseed, sac, hdf5, and numpy. The test data can be downloaded here:
+```
+wget https://github.com/wayneweiqiang/PhaseNet/releases/download/test_data/test_data.zip
+unzip test_data.zip
+```
+
+- For mseed format:
+```
+python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed.csv --data_dir=test_data/mseed --format=mseed --amplitude --response_xml=test_data/stations.xml --batch_size=1 --sampling_rate=100 --plot_figure
+```
+```
+python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed2.csv --data_dir=test_data/mseed --format=mseed --amplitude --response_xml=test_data/stations.xml --batch_size=1 --sampling_rate=100 --plot_figure
+```
+
+- For sac format:
+```
+python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/sac.csv --data_dir=test_data/sac --format=sac --batch_size=1 --plot_figure
+```
+
+- For numpy format:
+```
+python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/npz.csv --data_dir=test_data/npz --format=numpy --plot_figure
+```
+
+- For hdf5 format:
+```
+python phasenet/predict.py --model=model/190703-214543 --hdf5_file=test_data/data.h5 --hdf5_group=data --format=hdf5 --plot_figure
+```
+
+- For a seismic array (used by [QuakeFlow](https://github.com/wayneweiqiang/QuakeFlow)):
+```
+python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed_array.csv --data_dir=test_data/mseed_array --stations=test_data/stations.json --format=mseed_array --amplitude
+```
+
+Notes:
+
+1. We use "--batch_size=1" because mseed or sac files are usually not the same length. If you want a larger batch size for better prediction speed, you need to cut the data to the same length.
+
+2. Remove the "--plot_figure" argument for large datasets, because plotting can be very slow.
+
+Optional arguments:
+```
+usage: predict.py [-h] [--batch_size BATCH_SIZE] [--model_dir MODEL_DIR]
+                  [--data_dir DATA_DIR] [--data_list DATA_LIST]
+                  [--hdf5_file HDF5_FILE] [--hdf5_group HDF5_GROUP]
+                  [--result_dir RESULT_DIR] [--result_fname RESULT_FNAME]
+                  [--min_p_prob MIN_P_PROB] [--min_s_prob MIN_S_PROB]
+                  [--mpd MPD] [--amplitude] [--format FORMAT]
+                  [--s3_url S3_URL] [--stations STATIONS] [--plot_figure]
+                  [--save_prob]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --batch_size BATCH_SIZE
+                        batch size
+  --model_dir MODEL_DIR
+                        Checkpoint directory (default: None)
+  --data_dir DATA_DIR   Input file directory
+  --data_list DATA_LIST
+                        Input csv file
+  --hdf5_file HDF5_FILE
+                        Input hdf5 file
+  --hdf5_group HDF5_GROUP
+                        data group name in hdf5 file
+  --result_dir RESULT_DIR
+                        Output directory
+  --result_fname RESULT_FNAME
+                        Output file
+  --min_p_prob MIN_P_PROB
+                        Probability threshold for P pick
+  --min_s_prob MIN_S_PROB
+                        Probability threshold for S pick
+  --mpd MPD             Minimum peak distance
+  --amplitude           whether to return amplitude values
+  --format FORMAT       input format
+  --stations STATIONS   seismic station info
+  --plot_figure         whether to plot figures for testing
+  --save_prob           whether to save results for testing
+```
+
+- The output picks are saved to "results/picks.csv" by default
+
+|file_name        |begin_time             |station_id|phase_index|phase_time             |phase_score|phase_amp             |phase_type|
+|-----------------|-----------------------|----------|-----------|-----------------------|-----------|----------------------|----------|
+|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|14734      |2020-10-01T00:02:27.343|0.708      |2.4998866231208325e-14|P         |
+|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|15487      |2020-10-01T00:02:34.873|0.416      |2.4998866231208325e-14|S         |
+|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.COA..HH|319        |2020-10-01T00:00:03.193|0.762      |3.708662269972206e-14 |P         |
+
+Notes:
+1. *phase_index* is the index of the pick within the original data sequence, so *phase_time* = *begin_time* + *phase_index* / *sampling_rate*. The default *sampling_rate* is 100 Hz.
+
+## 5. QuakeFlow example
+A complete earthquake detection workflow can be found in the [QuakeFlow](https://wayneweiqiang.github.io/QuakeFlow/) project.
+
+## 6. Interactive example
+See details in the [notebook](https://github.com/wayneweiqiang/PhaseNet/blob/master/docs/example_gradio.ipynb): [example_interactive.ipynb](example_gradio.ipynb)
+
+## 7. Training
+- Download a small sample dataset:
+```bash
+wget https://github.com/wayneweiqiang/PhaseNet/releases/download/test_data/test_data.zip
+unzip test_data.zip
+```
+- Start training from the pre-trained model
+```
+python phasenet/train.py --model_dir=model/190703-214543/ --train_dir=test_data/npz --train_list=test_data/npz.csv --plot_figure --epochs=10 --batch_size=10
+```
+- Check results in the **log** folder
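As a quick sanity check of the *phase_index* note above, the first row of the sample picks table can be reproduced directly (a minimal Python sketch; the 100 Hz sampling rate is the documented default):

```python
from datetime import datetime, timedelta

# phase_time = begin_time + phase_index / sampling_rate (see the README note above)
begin_time = datetime.strptime("2020-10-01T00:00:00.003", "%Y-%m-%dT%H:%M:%S.%f")
phase_index = 14734    # first row of the picks table
sampling_rate = 100    # Hz, PhaseNet's default

phase_time = begin_time + timedelta(seconds=phase_index / sampling_rate)
print(phase_time.isoformat(timespec="milliseconds"))
# 2020-10-01T00:02:27.343 -- matches the phase_time column
```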
docs/data.mseed
ADDED
Binary file (73.7 kB).
docs/example_batch_prediction.ipynb
ADDED
@@ -0,0 +1,211 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Batch Prediction\n",
+    "\n",
+    "## 1. Download demo data\n",
+    "\n",
+    "```\n",
+    "cd PhaseNet\n",
+    "wget https://github.com/wayneweiqiang/PhaseNet/releases/download/test_data/test_data.zip\n",
+    "unzip test_data.zip\n",
+    "```\n",
+    "\n",
+    "## 2. Run batch prediction\n",
+    "\n",
+    "PhaseNet currently supports four data formats: mseed, sac, hdf5, and numpy.\n",
+    "\n",
+    "- For mseed format:\n",
+    "```\n",
+    "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed.csv --data_dir=test_data/mseed --format=mseed --plot_figure\n",
+    "```\n",
+    "\n",
+    "- For sac format:\n",
+    "```\n",
+    "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/sac.csv --data_dir=test_data/sac --format=sac --plot_figure\n",
+    "```\n",
+    "\n",
+    "- For numpy format:\n",
+    "```\n",
+    "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/npz.csv --data_dir=test_data/npz --format=numpy --plot_figure\n",
+    "```\n",
+    "\n",
+    "- For hdf5 format:\n",
+    "```\n",
+    "python phasenet/predict.py --model=model/190703-214543 --hdf5_file=test_data/data.h5 --hdf5_group=data --format=hdf5 --plot_figure\n",
+    "```\n",
+    "\n",
+    "- For a seismic array (used by [QuakeFlow](https://github.com/wayneweiqiang/QuakeFlow)):\n",
+    "```\n",
+    "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed_array.csv --data_dir=test_data/mseed_array --stations=test_data/stations.json --format=mseed_array --amplitude\n",
+    "```\n",
+    "```\n",
+    "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed2.csv --data_dir=test_data/mseed --stations=test_data/stations.json --format=mseed_array --amplitude\n",
+    "```\n",
+    "\n",
+    "Notes:\n",
+    "1. Remove the \"--plot_figure\" argument for large datasets, because plotting can be very slow.\n",
+    "\n",
+    "Optional arguments:\n",
+    "```\n",
+    "usage: predict.py [-h] [--batch_size BATCH_SIZE] [--model_dir MODEL_DIR]\n",
+    "                  [--data_dir DATA_DIR] [--data_list DATA_LIST]\n",
+    "                  [--hdf5_file HDF5_FILE] [--hdf5_group HDF5_GROUP]\n",
+    "                  [--result_dir RESULT_DIR] [--result_fname RESULT_FNAME]\n",
+    "                  [--min_p_prob MIN_P_PROB] [--min_s_prob MIN_S_PROB]\n",
+    "                  [--mpd MPD] [--amplitude] [--format FORMAT]\n",
+    "                  [--s3_url S3_URL] [--stations STATIONS] [--plot_figure]\n",
+    "                  [--save_prob]\n",
+    "\n",
+    "optional arguments:\n",
+    "  -h, --help            show this help message and exit\n",
+    "  --batch_size BATCH_SIZE\n",
+    "                        batch size\n",
+    "  --model_dir MODEL_DIR\n",
+    "                        Checkpoint directory (default: None)\n",
+    "  --data_dir DATA_DIR   Input file directory\n",
+    "  --data_list DATA_LIST\n",
+    "                        Input csv file\n",
+    "  --hdf5_file HDF5_FILE\n",
+    "                        Input hdf5 file\n",
+    "  --hdf5_group HDF5_GROUP\n",
+    "                        data group name in hdf5 file\n",
+    "  --result_dir RESULT_DIR\n",
+    "                        Output directory\n",
+    "  --result_fname RESULT_FNAME\n",
+    "                        Output file\n",
+    "  --min_p_prob MIN_P_PROB\n",
+    "                        Probability threshold for P pick\n",
+    "  --min_s_prob MIN_S_PROB\n",
+    "                        Probability threshold for S pick\n",
+    "  --mpd MPD             Minimum peak distance\n",
+    "  --amplitude           whether to return amplitude values\n",
+    "  --format FORMAT       input format\n",
+    "  --stations STATIONS   seismic station info\n",
+    "  --plot_figure         whether to plot figures for testing\n",
+    "  --save_prob           whether to save results for testing\n",
+    "```\n",
+    "\n",
+    "## 3. Output picks\n",
+    "- The output picks are saved to \"results/picks.csv\" by default\n",
+    "\n",
+    "|file_name        |begin_time             |station_id|phase_index|phase_time             |phase_score|phase_amp             |phase_type|\n",
+    "|-----------------|-----------------------|----------|-----------|-----------------------|-----------|----------------------|----------|\n",
+    "|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|14734      |2020-10-01T00:02:27.343|0.708      |2.4998866231208325e-14|P         |\n",
+    "|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|15487      |2020-10-01T00:02:34.873|0.416      |2.4998866231208325e-14|S         |\n",
+    "|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.COA..HH|319        |2020-10-01T00:00:03.193|0.762      |3.708662269972206e-14 |P         |\n",
+    "\n",
+    "Notes:\n",
+    "1. *phase_index* is the index of the pick within the original data sequence, so *phase_time* = *begin_time* + *phase_index* / *sampling_rate*. The default *sampling_rate* is 100 Hz.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. Read P/S picks\n",
+    "\n",
+    "PhaseNet currently outputs two formats: **CSV** and **JSON**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "import json\n",
+    "import os\n",
+    "PROJECT_ROOT = os.path.realpath(os.path.join(os.path.abspath(''), \"..\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "fname     NC.MCV..EH.0361339.npz\n",
+      "t0       1970-01-01T00:00:00.000\n",
+      "p_idx               [5999, 9015]\n",
+      "p_prob            [0.987, 0.981]\n",
+      "s_idx               [6181, 9205]\n",
+      "s_prob            [0.553, 0.873]\n",
+      "Name: 1, dtype: object\n",
+      "fname     NN.LHV..EH.0384064.npz\n",
+      "t0       1970-01-01T00:00:00.000\n",
+      "p_idx                         []\n",
+      "p_prob                        []\n",
+      "s_idx                         []\n",
+      "s_prob                        []\n",
+      "Name: 0, dtype: object\n"
+     ]
+    }
+   ],
+   "source": [
+    "picks_csv = pd.read_csv(os.path.join(PROJECT_ROOT, \"results/picks.csv\"), sep=\"\\t\")\n",
+    "picks_csv.loc[:, 'p_idx'] = picks_csv[\"p_idx\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
+    "picks_csv.loc[:, 'p_prob'] = picks_csv[\"p_prob\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
+    "picks_csv.loc[:, 's_idx'] = picks_csv[\"s_idx\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
+    "picks_csv.loc[:, 's_prob'] = picks_csv[\"s_prob\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
+    "print(picks_csv.iloc[1])\n",
+    "print(picks_csv.iloc[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'id': 'NC.MCV..EH.0361339.npz', 'timestamp': '1970-01-01T00:01:30.150', 'prob': 0.9811667799949646, 'type': 'p'}\n",
+      "{'id': 'NC.MCV..EH.0361339.npz', 'timestamp': '1970-01-01T00:00:59.990', 'prob': 0.9872905611991882, 'type': 'p'}\n"
+     ]
+    }
+   ],
+   "source": [
+    "with open(os.path.join(PROJECT_ROOT, \"results/picks.json\")) as fp:\n",
+    "    picks_json = json.load(fp)\n",
+    "print(picks_json[1])\n",
+    "print(picks_json[0])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.10.4 64-bit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.4"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
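The strip("[]").split(",") parsing in the notebook cell above leaves the pick indices and probabilities as lists of string fragments. A small follow-up sketch (assuming the same tab-separated picks.csv layout) converts them to numeric types; the helper name here is illustrative, not part of PhaseNet:

```python
import pandas as pd

def to_numbers(s, cast):
    # "[5999, 9015]" -> [5999, 9015]; an empty "[]" -> []
    parts = [p.strip() for p in s.strip("[]").split(",")]
    return [cast(p) for p in parts if p]

picks_csv = pd.read_csv("results/picks.csv", sep="\t")
picks_csv["p_idx"] = picks_csv["p_idx"].apply(to_numbers, cast=int)
picks_csv["p_prob"] = picks_csv["p_prob"].apply(to_numbers, cast=float)
print(picks_csv[["fname", "p_idx", "p_prob"]].head())
```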
docs/example_fastapi.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

docs/example_gradio.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
docs/test_api.py
ADDED
@@ -0,0 +1,37 @@
+# %%
+from gradio_client import Client
+import obspy
+import numpy as np
+import json
+import pandas as pd
+
+# %%
+
+waveform = obspy.read()
+array = np.array([x.data for x in waveform]).T
+
+# pipeline = PreTrainedPipeline()
+inputs = array.tolist()
+inputs = json.dumps(inputs)
+# picks = pipeline(inputs)
+# print(picks)
+
+# %%
+client = Client("ai4eps/phasenet")
+output, file = client.predict(["test_test.mseed"])
+# %%
+with open(output, "r") as f:
+    picks = json.load(f)["data"]
+
+# %%
+picks = pd.read_csv(file)
+
+
+# %%
+job = client.submit(["test_test.mseed", "test_test.mseed"], api_name="/predict")  # This is not blocking
+
+print(job.status())
+
+# %%
+output, file = job.result()
+
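Since client.submit() above is non-blocking, a caller typically polls the job before reading the result. A minimal sketch using only calls that appear in test_api.py, plus Job.done(), which is assumed here to behave like a standard concurrent.futures.Future:

```python
import time
from gradio_client import Client

client = Client("ai4eps/phasenet")
job = client.submit(["test_test.mseed"], api_name="/predict")  # non-blocking

while not job.done():   # Job is Future-like; done() is an assumption here
    print(job.status())
    time.sleep(1.0)

output, file = job.result()  # same unpacking as in test_api.py above
```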
env.yml
ADDED
@@ -0,0 +1,17 @@
+name: phasenet
+channels:
+  - defaults
+  - conda-forge
+dependencies:
+  - python
+  - numpy
+  - scipy
+  - matplotlib
+  - pandas
+  - scikit-learn
+  - tqdm
+  - obspy
+  - uvicorn
+  - fastapi
+  - tensorflow
+  - keras
mkdocs.yml
ADDED
@@ -0,0 +1,18 @@
+site_name: "PhaseNet"
+site_description: 'PhaseNet: a deep-neural-network-based seismic arrival-time picking method'
+site_author: 'Weiqiang Zhu'
+docs_dir: docs/
+repo_name: 'AI4EPS/PhaseNet'
+repo_url: 'https://github.com/ai4eps/PhaseNet'
+nav:
+  - Overview: README.md
+  - Interactive Example: example_gradio.ipynb
+  - Batch Prediction: example_batch_prediction.ipynb
+theme:
+  name: 'material'
+plugins:
+  - mkdocs-jupyter
+extra:
+  analytics:
+    provider: google
+    property: G-RZQ9LRPL0S
model/190703-214543/checkpoint
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1606ccb25e1533fa0398c5dbce7f3a45ac77f90b78b99f81a044294ba38a2c0c
+size 83

model/190703-214543/config.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed9dfa705053a5025facc9952c7da6abef19ec5f672d9e50386bf3f2d80294f2
+size 345

model/190703-214543/loss.log
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccb6f19117497571e19bec5da6012ac7af91f1bd29e931ffd0b23c6b657bb401
+size 8101

model/190703-214543/model_95.ckpt.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ee2c15dd78fb15de45a55ad64a446f1a0ced152ba4ac5c506d82b9194da85b4
+size 3226256

model/190703-214543/model_95.ckpt.index
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f96b553b76be4ebae9a455eaf8d83cfa8c0e110f06cfba958de2568e5b6b2780
+size 7223

model/190703-214543/model_95.ckpt.meta
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ebd154a5ba0721ba8bbb627ba61b556ee60660eb34bbcd1b1f50396b07cc4ed
+size 2172055
phasenet/__init__.py
ADDED
@@ -0,0 +1 @@
+__version__ = "0.1.0"
phasenet/app.py
ADDED
@@ -0,0 +1,341 @@
+import os
+from collections import defaultdict, namedtuple
+from datetime import datetime, timedelta
+from json import dumps
+from typing import Any, AnyStr, Dict, List, NamedTuple, Union, Optional
+
+import numpy as np
+import requests
+import tensorflow as tf
+from fastapi import FastAPI, WebSocket
+from kafka import KafkaProducer
+from pydantic import BaseModel
+from scipy.interpolate import interp1d
+
+from model import ModelConfig, UNet
+from postprocess import extract_picks
+
+tf.compat.v1.disable_eager_execution()
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
+JSONObject = Dict[AnyStr, Any]
+JSONArray = List[Any]
+JSONStructure = Union[JSONArray, JSONObject]
+
+app = FastAPI()
+X_SHAPE = [3000, 1, 3]
+SAMPLING_RATE = 100
+
+# load model
+model = UNet(mode="pred")
+sess_config = tf.compat.v1.ConfigProto()
+sess_config.gpu_options.allow_growth = True
+
+sess = tf.compat.v1.Session(config=sess_config)
+saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
+init = tf.compat.v1.global_variables_initializer()
+sess.run(init)
+latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
+print(f"restoring model {latest_check_point}")
+saver.restore(sess, latest_check_point)
+
+# GAMMA API Endpoint
+GAMMA_API_URL = "http://gamma-api:8001"
+# GAMMA_API_URL = 'http://localhost:8001'
+# GAMMA_API_URL = "http://gamma.quakeflow.com"
+# GAMMA_API_URL = "http://127.0.0.1:8001"
+
+# Kafka producer
+use_kafka = False
+
+try:
+    print("Connecting to k8s kafka")
+    BROKER_URL = "quakeflow-kafka-headless:9092"
+    # BROKER_URL = "34.83.137.139:9094"
+    producer = KafkaProducer(
+        bootstrap_servers=[BROKER_URL],
+        key_serializer=lambda x: dumps(x).encode("utf-8"),
+        value_serializer=lambda x: dumps(x).encode("utf-8"),
+    )
+    use_kafka = True
+    print("k8s kafka connection success!")
+except BaseException:
+    print("k8s Kafka connection error")
+    try:
+        print("Connecting to local kafka")
+        producer = KafkaProducer(
+            bootstrap_servers=["localhost:9092"],
+            key_serializer=lambda x: dumps(x).encode("utf-8"),
+            value_serializer=lambda x: dumps(x).encode("utf-8"),
+        )
+        use_kafka = True
+        print("local kafka connection success!")
+    except BaseException:
+        print("local Kafka connection error")
+print(f"Kafka status: {use_kafka}")
+
+
+def normalize_batch(data, window=3000):
+    """
+    data: nsta, nt, nch
+    """
+    shift = window // 2
+    nsta, nt, nch = data.shape
+
+    # std in sliding windows
+    data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
+    t = np.arange(0, nt, shift, dtype="int")
+    std = np.zeros([nsta, len(t) + 1, nch])
+    mean = np.zeros([nsta, len(t) + 1, nch])
+    for i in range(1, len(t)):
+        std[:, i, :] = np.std(data_pad[:, i * shift : i * shift + window, :], axis=1)
+        mean[:, i, :] = np.mean(data_pad[:, i * shift : i * shift + window, :], axis=1)
+
+    t = np.append(t, nt)
+    # std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
+    # mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
+    std[:, -1, :], mean[:, -1, :] = std[:, -2, :], mean[:, -2, :]
+    std[:, 0, :], mean[:, 0, :] = std[:, 1, :], mean[:, 1, :]
+    std[std == 0] = 1
+
+    # normalize data with interpolated std
+    t_interp = np.arange(nt, dtype="int")
+    std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
+    mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
+    data = (data - mean_interp) / std_interp
+
+    return data
+
+
+def preprocess(data):
+    raw = data.copy()
+    data = normalize_batch(data)
+    if len(data.shape) == 3:
+        data = data[:, :, np.newaxis, :]
+        raw = raw[:, :, np.newaxis, :]
+    return data, raw
+
+
+def calc_timestamp(timestamp, sec):
+    timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
+    return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+
+
+def format_picks(picks, dt, amplitudes):
+    picks_ = []
+    for pick, amplitude in zip(picks, amplitudes):
+        for idxs, probs, amps in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
+            for idx, prob, amp in zip(idxs, probs, amps):
+                picks_.append(
+                    {
+                        "id": pick.fname,
+                        "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
+                        "prob": prob,
+                        "amp": amp,
+                        "type": "p",
+                    }
+                )
+        for idxs, probs, amps in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
+            for idx, prob, amp in zip(idxs, probs, amps):
+                picks_.append(
+                    {
+                        "id": pick.fname,
+                        "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
+                        "prob": prob,
+                        "amp": amp,
+                        "type": "s",
+                    }
+                )
+    return picks_
+
+
+def format_data(data):
+    # chn2idx = {"ENZ": {"E":0, "N":1, "Z":2},
+    #            "123": {"3":0, "2":1, "1":2},
+    #            "12Z": {"1":0, "2":1, "Z":2}}
+    chn2idx = {"E": 0, "N": 1, "Z": 2, "3": 0, "2": 1, "1": 2}
+    Data = NamedTuple("data", [("id", list), ("timestamp", list), ("vec", list), ("dt", float)])
+
+    # Group by station
+    chn_ = defaultdict(list)
+    t0_ = defaultdict(list)
+    vv_ = defaultdict(list)
+    for i in range(len(data.id)):
+        key = data.id[i][:-1]
+        chn_[key].append(data.id[i][-1])
+        t0_[key].append(datetime.strptime(data.timestamp[i], "%Y-%m-%dT%H:%M:%S.%f").timestamp() * SAMPLING_RATE)
+        vv_[key].append(np.array(data.vec[i]))
+
+    # Merge to Data tuple
+    id_ = []
+    timestamp_ = []
+    vec_ = []
+    for k in chn_:
+        id_.append(k)
+        min_t0 = min(t0_[k])
+        timestamp_.append(datetime.fromtimestamp(min_t0 / SAMPLING_RATE).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3])
+        vec = np.zeros([X_SHAPE[0], X_SHAPE[-1]])
+        for i in range(len(chn_[k])):
+            # vec[int(t0_[k][i]-min_t0):len(vv_[k][i]), chn2idx[chn_[k][i]]] = vv_[k][i][int(t0_[k][i]-min_t0):X_SHAPE[0]] - np.mean(vv_[k][i])
+            shift = int(t0_[k][i] - min_t0)
+            vec[shift : len(vv_[k][i]) + shift, chn2idx[chn_[k][i]]] = vv_[k][i][: X_SHAPE[0] - shift] - np.mean(
+                vv_[k][i][: X_SHAPE[0] - shift]
+            )
+        vec_.append(vec.tolist())
+
+    return Data(id=id_, timestamp=timestamp_, vec=vec_, dt=1 / SAMPLING_RATE)
+    # return {"id": id_, "timestamp": timestamp_, "vec": vec_, "dt": 1 / SAMPLING_RATE}
+
+
+def get_prediction(data, return_preds=False):
+    vec = np.array(data.vec)
+    vec, vec_raw = preprocess(vec)
+
+    feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
+    preds = sess.run(model.preds, feed_dict=feed)
+
+    picks = extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
+
+    picks = [
+        {k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]}
+        for pick in picks
+    ]
+
+    if return_preds:
+        return picks, preds
+
+    return picks
+
+
+class Data(BaseModel):
+    # id: Union[List[str], str]
+    # timestamp: Union[List[str], str]
+    # vec: Union[List[List[List[float]]], List[List[float]]]
+    id: List[str]
+    timestamp: List[Union[str, float, datetime]]
+    vec: Union[List[List[List[float]]], List[List[float]]]
+
+    dt: Optional[float] = 0.01
+    ## gamma
+    stations: Optional[List[Dict[str, Union[float, str]]]] = None
+    config: Optional[Dict[str, Union[List[float], List[int], List[str], float, int, str]]] = None
+
+
+# @app.on_event("startup")
+# def set_default_executor():
+#     from concurrent.futures import ThreadPoolExecutor
+#     import asyncio
+#
+#     loop = asyncio.get_running_loop()
+#     loop.set_default_executor(
+#         ThreadPoolExecutor(max_workers=2)
+#     )
+
+
+@app.post("/predict")
+def predict(data: Data):
+    picks = get_prediction(data)
+
+    return picks
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    while True:
+        data = await websocket.receive_json()
+        # data = json.loads(data)
+        data = Data(**data)
+        picks = get_prediction(data)
+        await websocket.send_json(picks)
+        print("PhaseNet Updating...")
+
+
+@app.post("/predict_prob")
+def predict_prob(data: Data):
+    picks, preds = get_prediction(data, True)
+
+    return picks, preds.tolist()
+
+
+@app.post("/predict_phasenet2gamma")
+def predict_phasenet2gamma(data: Data):
+    picks = get_prediction(data)
+
+    # if use_kafka:
+    #     print("Push picks to kafka...")
+    #     for pick in picks:
+    #         producer.send("phasenet_picks", key=pick["id"], value=pick)
+    try:
+        catalog = requests.post(
+            f"{GAMMA_API_URL}/predict", json={"picks": picks, "stations": data.stations, "config": data.config}
+        )
+        print(catalog.json()["catalog"])
+        return catalog.json()
+    except Exception as error:
+        print(error)
+
+    return {}
+
+
+@app.post("/predict_phasenet2gamma2ui")
+def predict_phasenet2gamma2ui(data: Data):
+    picks = get_prediction(data)
+
+    try:
+        catalog = requests.post(
+            f"{GAMMA_API_URL}/predict", json={"picks": picks, "stations": data.stations, "config": data.config}
+        )
+        print(catalog.json()["catalog"])
+        return catalog.json()
+    except Exception as error:
+        print(error)
+
+    if use_kafka:
+        print("Push picks to kafka...")
+        for pick in picks:
+            producer.send("phasenet_picks", key=pick["id"], value=pick)
+        print("Push waveform to kafka...")
+        for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
+            producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
+
+    return {}
+
+
+@app.post("/predict_stream_phasenet2gamma")
+def predict_stream_phasenet2gamma(data: Data):
+    data = format_data(data)
+    # for i in range(len(data.id)):
+    #     plt.clf()
+    #     plt.subplot(311)
+    #     plt.plot(np.array(data.vec)[i, :, 0])
+    #     plt.subplot(312)
+    #     plt.plot(np.array(data.vec)[i, :, 1])
+    #     plt.subplot(313)
+    #     plt.plot(np.array(data.vec)[i, :, 2])
+    #     plt.savefig(f"{data.id[i]}.png")
+
+    picks = get_prediction(data)
+
+    return_value = {}
+    try:
+        catalog = requests.post(f"{GAMMA_API_URL}/predict_stream", json={"picks": picks})
+        print("GMMA:", catalog.json()["catalog"])
+        return_value = catalog.json()
+    except Exception as error:
+        print(error)
+
+    if use_kafka:
+        print("Push picks to kafka...")
+        for pick in picks:
+            producer.send("phasenet_picks", key=pick["id"], value=pick)
+        print("Push waveform to kafka...")
+        for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
+            producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
+
+    return return_value
+
+
+@app.get("/healthz")
+def healthz():
+    return {"status": "ok"}
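A minimal client sketch against the /predict route defined in app.py above. The URL is an assumption (the Dockerfile's final ENTRYPOINT serves uvicorn on port 7860, although it EXPOSEs 8000); the payload fields mirror the pydantic Data model, with one station of 3000 samples x 3 channels to match X_SHAPE = [3000, 1, 3]:

```python
import numpy as np
import requests

payload = {
    "id": ["CI.BOM..HH"],                          # one id per waveform
    "timestamp": ["2020-10-01T00:00:00.003"],      # window begin time
    "vec": np.random.randn(1, 3000, 3).tolist(),   # placeholder waveform (nsta, nt, nch)
    "dt": 0.01,                                    # 100 Hz sampling
}
# host/port are assumptions; adjust to wherever the app is served
resp = requests.post("http://localhost:7860/predict", json=payload)
print(resp.json())  # picks with station_id, phase_time, phase_score, phase_type
```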
phasenet/data_reader.py
ADDED
|
@@ -0,0 +1,1010 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
|
| 3 |
+
tf.compat.v1.disable_eager_execution()
|
| 4 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
pd.options.mode.chained_assignment = None
|
| 12 |
+
import json
|
| 13 |
+
import random
|
| 14 |
+
from collections import defaultdict
|
| 15 |
+
|
| 16 |
+
# import s3fs
|
| 17 |
+
import h5py
|
| 18 |
+
import obspy
|
| 19 |
+
from scipy.interpolate import interp1d
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def py_func_decorator(output_types=None, output_shapes=None, name=None):
|
| 24 |
+
def decorator(func):
|
| 25 |
+
def call(*args, **kwargs):
|
| 26 |
+
nonlocal output_shapes
|
| 27 |
+
# flat_output_types = nest.flatten(output_types)
|
| 28 |
+
flat_output_types = tf.nest.flatten(output_types)
|
| 29 |
+
# flat_values = tf.py_func(
|
| 30 |
+
flat_values = tf.numpy_function(func, inp=args, Tout=flat_output_types, name=name)
|
| 31 |
+
if output_shapes is not None:
|
| 32 |
+
for v, s in zip(flat_values, output_shapes):
|
| 33 |
+
v.set_shape(s)
|
| 34 |
+
# return nest.pack_sequence_as(output_types, flat_values)
|
| 35 |
+
return tf.nest.pack_sequence_as(output_types, flat_values)
|
| 36 |
+
|
| 37 |
+
return call
|
| 38 |
+
|
| 39 |
+
return decorator
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def dataset_map(iterator, output_types, output_shapes=None, num_parallel_calls=None, name=None, shuffle=False):
|
| 43 |
+
dataset = tf.data.Dataset.range(len(iterator))
|
| 44 |
+
if shuffle:
|
| 45 |
+
dataset = dataset.shuffle(len(iterator), reshuffle_each_iteration=True)
|
| 46 |
+
|
| 47 |
+
@py_func_decorator(output_types, output_shapes, name=name)
|
| 48 |
+
def index_to_entry(idx):
|
| 49 |
+
return iterator[idx]
|
| 50 |
+
|
| 51 |
+
return dataset.map(index_to_entry, num_parallel_calls=num_parallel_calls)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def normalize(data, axis=(0,)):
|
| 55 |
+
"""data shape: (nt, nsta, nch)"""
|
| 56 |
+
data -= np.mean(data, axis=axis, keepdims=True)
|
| 57 |
+
std_data = np.std(data, axis=axis, keepdims=True)
|
| 58 |
+
std_data[std_data == 0] = 1
|
| 59 |
+
data /= std_data
|
| 60 |
+
# data /= (std_data + 1e-12)
|
| 61 |
+
return data
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def normalize_long(data, axis=(0,), window=3000):
|
| 65 |
+
"""
|
| 66 |
+
data: nt, nch
|
| 67 |
+
"""
|
| 68 |
+
nt, nar, nch = data.shape
|
| 69 |
+
if window is None:
|
| 70 |
+
window = nt
|
| 71 |
+
shift = window // 2
|
| 72 |
+
|
| 73 |
+
dtype = data.dtype
|
| 74 |
+
## std in slide windows
|
| 75 |
+
data_pad = np.pad(data, ((window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
|
| 76 |
+
t = np.arange(0, nt, shift, dtype="int")
|
| 77 |
+
std = np.zeros([len(t) + 1, nar, nch])
|
| 78 |
+
mean = np.zeros([len(t) + 1, nar, nch])
|
| 79 |
+
for i in range(1, len(std)):
|
| 80 |
+
std[i, :] = np.std(data_pad[i * shift : i * shift + window, :, :], axis=axis)
|
| 81 |
+
mean[i, :] = np.mean(data_pad[i * shift : i * shift + window, :, :], axis=axis)
|
| 82 |
+
|
| 83 |
+
t = np.append(t, nt)
|
| 84 |
+
# std[-1, :] = np.std(data_pad[-window:, :], axis=0)
|
| 85 |
+
# mean[-1, :] = np.mean(data_pad[-window:, :], axis=0)
|
| 86 |
+
std[-1, ...], mean[-1, ...] = std[-2, ...], mean[-2, ...]
|
| 87 |
+
std[0, ...], mean[0, ...] = std[1, ...], mean[1, ...]
|
| 88 |
+
# std[std == 0] = 1.0
|
| 89 |
+
|
| 90 |
+
## normalize data with interplated std
|
| 91 |
+
t_interp = np.arange(nt, dtype="int")
|
| 92 |
+
std_interp = interp1d(t, std, axis=0, kind="slinear")(t_interp)
|
| 93 |
+
# std_interp = np.exp(interp1d(t, np.log(std), axis=0, kind="slinear")(t_interp))
|
| 94 |
+
mean_interp = interp1d(t, mean, axis=0, kind="slinear")(t_interp)
|
| 95 |
+
tmp = np.sum(std_interp, axis=(0, 1))
|
| 96 |
+
std_interp[std_interp == 0] = 1.0
|
| 97 |
+
data = (data - mean_interp) / std_interp
|
| 98 |
+
# data = (data - mean_interp)/(std_interp + 1e-12)
|
| 99 |
+
|
| 100 |
+
### dropout effect of < 3 channel
|
| 101 |
+
nonzero = np.count_nonzero(tmp)
|
| 102 |
+
if (nonzero < 3) and (nonzero > 0):
|
| 103 |
+
data *= 3.0 / nonzero
|
| 104 |
+
|
| 105 |
+
return data.astype(dtype)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def normalize_batch(data, window=3000):
|
| 109 |
+
"""
|
| 110 |
+
data: nsta, nt, nch
|
| 111 |
+
"""
|
| 112 |
+
nsta, nt, nar, nch = data.shape
|
| 113 |
+
if window is None:
|
| 114 |
+
window = nt
|
| 115 |
+
shift = window // 2
|
| 116 |
+
|
| 117 |
+
## std in slide windows
|
| 118 |
+
data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
|
| 119 |
+
t = np.arange(0, nt, shift, dtype="int")
|
| 120 |
+
std = np.zeros([nsta, len(t) + 1, nar, nch])
|
| 121 |
+
mean = np.zeros([nsta, len(t) + 1, nar, nch])
|
| 122 |
+
for i in range(1, len(t)):
|
| 123 |
+
std[:, i, :, :] = np.std(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
|
| 124 |
+
mean[:, i, :, :] = np.mean(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
|
| 125 |
+
|
| 126 |
+
t = np.append(t, nt)
|
| 127 |
+
# std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
|
| 128 |
+
# mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
|
| 129 |
+
std[:, -1, :, :], mean[:, -1, :, :] = std[:, -2, :, :], mean[:, -2, :, :]
|
| 130 |
+
std[:, 0, :, :], mean[:, 0, :, :] = std[:, 1, :, :], mean[:, 1, :, :]
|
| 131 |
+
# std[std == 0] = 1
|
| 132 |
+
|
| 133 |
+
# ## normalize data with interplated std
|
| 134 |
+
t_interp = np.arange(nt, dtype="int")
|
| 135 |
+
std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
|
| 136 |
+
# std_interp = np.exp(interp1d(t, np.log(std), axis=1, kind="slinear")(t_interp))
|
| 137 |
+
mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
|
| 138 |
+
tmp = np.sum(std_interp, axis=(1, 2))
|
| 139 |
+
std_interp[std_interp == 0] = 1.0
|
| 140 |
+
data = (data - mean_interp) / std_interp
|
| 141 |
+
# data = (data - mean_interp)/(std_interp + 1e-12)
|
| 142 |
+
|
| 143 |
+
### dropout effect of < 3 channel
|
| 144 |
+
nonzero = np.count_nonzero(tmp, axis=-1)
|
| 145 |
+
data[nonzero > 0, ...] *= 3.0 / nonzero[nonzero > 0][:, np.newaxis, np.newaxis, np.newaxis]
|
| 146 |
+
|
| 147 |
+
return data
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class DataConfig:
|
| 151 |
+
seed = 123
|
| 152 |
+
use_seed = True
|
| 153 |
+
n_channel = 3
|
| 154 |
+
n_class = 3
|
| 155 |
+
sampling_rate = 100
|
| 156 |
+
dt = 1.0 / sampling_rate
|
| 157 |
+
X_shape = [3000, 1, n_channel]
|
| 158 |
+
Y_shape = [3000, 1, n_class]
|
| 159 |
+
min_event_gap = 3 * sampling_rate
|
| 160 |
+
label_shape = "gaussian"
|
| 161 |
+
label_width = 30
|
| 162 |
+
dtype = "float32"
|
| 163 |
+
|
| 164 |
+
def __init__(self, **kwargs):
|
| 165 |
+
for k, v in kwargs.items():
|
| 166 |
+
setattr(self, k, v)
|
| 167 |
+
|
| 168 |
+
|
class DataReader:
    def __init__(
        self, format="numpy", config=DataConfig(), response_xml=None, sampling_rate=100, highpass_filter=0, **kwargs
    ):
        self.buffer = {}
        self.n_channel = config.n_channel
        self.n_class = config.n_class
        self.X_shape = config.X_shape
        self.Y_shape = config.Y_shape
        self.dt = config.dt
        self.dtype = config.dtype
        self.label_shape = config.label_shape
        self.label_width = config.label_width
        self.config = config
        self.format = format
        # if "highpass_filter" in kwargs:
        #     self.highpass_filter = kwargs["highpass_filter"]
        self.highpass_filter = highpass_filter
        # self.response_xml = response_xml
        if response_xml is not None:
            self.response = obspy.read_inventory(response_xml)
        else:
            self.response = None
        self.sampling_rate = sampling_rate
        if format in ["numpy", "mseed", "sac"]:
            self.data_dir = kwargs["data_dir"]
            try:
                csv = pd.read_csv(kwargs["data_list"], header=0, sep=r"[,|\s+]", engine="python")
            except Exception:
                csv = pd.read_csv(kwargs["data_list"], header=0, sep="\t")
            self.data_list = csv["fname"]
            self.num_data = len(self.data_list)
        elif format == "hdf5":
            self.h5 = h5py.File(kwargs["hdf5_file"], "r", libver="latest", swmr=True)
            self.h5_data = self.h5[kwargs["hdf5_group"]]
            self.data_list = list(self.h5_data.keys())
            self.num_data = len(self.data_list)
        elif format == "s3":
            self.s3fs = s3fs.S3FileSystem(
                anon=kwargs["anon"],
                key=kwargs["key"],
                secret=kwargs["secret"],
                client_kwargs={"endpoint_url": kwargs["s3_url"]},
                use_ssl=kwargs["use_ssl"],
            )
            self.num_data = 0
        else:
            raise ValueError(f"format {format} is not supported!")

    def __len__(self):
        return self.num_data

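    # A minimal construction sketch (hypothetical paths): the CSV given as
    # `data_list` needs an "fname" column.
    #
    #   reader = DataReader(format="mseed", data_list="fnames.csv", data_dir="waveforms/")
    #   print(len(reader))  # number of files in the data list
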
    def read_numpy(self, fname):
        # try:
        if fname not in self.buffer:
            npz = np.load(fname)
            meta = {}
            if len(npz["data"].shape) == 2:
                meta["data"] = npz["data"][:, np.newaxis, :]
            else:
                meta["data"] = npz["data"]
            if "p_idx" in npz.files:
                if len(npz["p_idx"].shape) == 0:
                    meta["itp"] = [[npz["p_idx"]]]
                else:
                    meta["itp"] = npz["p_idx"]
            if "s_idx" in npz.files:
                if len(npz["s_idx"].shape) == 0:
                    meta["its"] = [[npz["s_idx"]]]
                else:
                    meta["its"] = npz["s_idx"]
            if "itp" in npz.files:
                if len(npz["itp"].shape) == 0:
                    meta["itp"] = [[npz["itp"]]]
                else:
                    meta["itp"] = npz["itp"]
            if "its" in npz.files:
                if len(npz["its"].shape) == 0:
                    meta["its"] = [[npz["its"]]]
                else:
                    meta["its"] = npz["its"]
            if "station_id" in npz.files:
                meta["station_id"] = npz["station_id"]
            if "sta_id" in npz.files:
                meta["station_id"] = npz["sta_id"]
            if "t0" in npz.files:
                meta["t0"] = npz["t0"]
            self.buffer[fname] = meta
        else:
            meta = self.buffer[fname]
        return meta
        # except:
        #     logging.error("Failed reading {}".format(fname))
        #     return None

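    # Expected npz layout, as a sketch of the fields read above: "data" with
    # shape (nt, 3) or (nt, 1, 3); optional P/S pick indices under
    # "p_idx"/"s_idx" or "itp"/"its" (scalar or nested lists); optional
    # "station_id"/"sta_id" and start time "t0".
    #
    #   np.savez("example.npz", data=np.zeros([3000, 3], dtype="float32"),
    #            p_idx=1000, s_idx=1500, t0="1970-01-01T00:00:00.000")
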
    def read_hdf5(self, fname):
        data = self.h5_data[fname][()]
        attrs = self.h5_data[fname].attrs
        meta = {}
        if len(data.shape) == 2:
            meta["data"] = data[:, np.newaxis, :]
        else:
            meta["data"] = data
        if "p_idx" in attrs:
            if len(attrs["p_idx"].shape) == 0:
                meta["itp"] = [[attrs["p_idx"]]]
            else:
                meta["itp"] = attrs["p_idx"]
        if "s_idx" in attrs:
            if len(attrs["s_idx"].shape) == 0:
                meta["its"] = [[attrs["s_idx"]]]
            else:
                meta["its"] = attrs["s_idx"]
        if "itp" in attrs:
            if len(attrs["itp"].shape) == 0:
                meta["itp"] = [[attrs["itp"]]]
            else:
                meta["itp"] = attrs["itp"]
        if "its" in attrs:
            if len(attrs["its"].shape) == 0:
                meta["its"] = [[attrs["its"]]]
            else:
                meta["its"] = attrs["its"]
        if "t0" in attrs:
            meta["t0"] = attrs["t0"]
        return meta

    def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
        with self.s3fs.open(bucket + "/" + fname, "rb") as fp:
            if format == "numpy":
                meta = self.read_numpy(fp)
            elif format == "mseed":
                meta = self.read_mseed(fp)
            else:
                raise ValueError(f"Format {format} not supported")
        return meta

    def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=100, return_single_station=True):
        try:
            stream = obspy.read(fname)
            stream = stream.merge(fill_value="latest")
            if response is not None:
                # response = obspy.read_inventory(response_xml)
                stream = stream.remove_sensitivity(response)
        except Exception as e:
            print(f"Error reading {fname}:\n{e}")
            return {}
        tmp_stream = obspy.Stream()
        for trace in stream:
            if len(trace.data) < 10:
                continue

            ## interpolate to `sampling_rate` (100 Hz by default)
            if abs(trace.stats.sampling_rate - sampling_rate) > 0.1:
                logging.warning(f"Resampling {trace.id} from {trace.stats.sampling_rate} to {sampling_rate} Hz")
                try:
                    trace = trace.interpolate(sampling_rate, method="linear")
                except Exception as e:
                    print(f"Error resampling {trace.id}:\n{e}")

            trace = trace.detrend("demean")

            ## highpass filtering
            if highpass_filter > 0.0:
                trace = trace.filter("highpass", freq=highpass_filter)

            tmp_stream.append(trace)

        if len(tmp_stream) == 0:
            return {}
        stream = tmp_stream

        begin_time = min([st.stats.starttime for st in stream])
        end_time = max([st.stats.endtime for st in stream])
        stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)

        comp = ["3", "2", "1", "E", "N", "U", "V", "Z"]
        order = {key: i for i, key in enumerate(comp)}
        comp2idx = {
            "3": 0,
            "2": 1,
            "1": 2,
            "E": 0,
            "N": 1,
            "Z": 2,
            "U": 0,
            "V": 1,
        }  ## only used when fewer than 3 components are present

        station_ids = defaultdict(list)
        for tr in stream:
            station_ids[tr.id[:-1]].append(tr.id[-1])
            if tr.id[-1] not in comp:
                print(f"Unknown component {tr.id[-1]}")

        station_keys = sorted(list(station_ids.keys()))

        nx = len(station_ids)
        nt = len(stream[0].data)
        data = np.zeros([3, nt, nx], dtype=np.float32)
        for i, sta in enumerate(station_keys):
            for j, c in enumerate(sorted(station_ids[sta], key=lambda x: order[x])):
                if len(station_ids[sta]) != 3:  ## fewer than 3 components
                    j = comp2idx[c]

                if len(stream.select(id=sta + c)) == 0:
                    print(f"Empty trace: {sta+c} {begin_time}")
                    continue

                trace = stream.select(id=sta + c)[0]

                ## integrate acceleration to velocity
                if sta[-1] == "N":
                    trace = trace.integrate().filter("highpass", freq=1.0)

                tmp = trace.data.astype("float32")
                data[j, : len(tmp), i] = tmp[:nt]

        # if return_single_station and (len(station_keys) > 1):
        #     print(f"Warning: {fname} has multiple stations, returning only the first one {station_keys[0]}")
        #     data = data[:, :, 0:1]
        #     station_keys = station_keys[0:1]

        meta = {
            "data": data.transpose([1, 2, 0]),
            "t0": begin_time.datetime.isoformat(timespec="milliseconds"),
            "station_id": station_keys,
        }
        return meta

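    # A small sketch of the channel-ordering convention above: with the
    # `order` mapping, component codes sort into (E/N/Z)-style slot order:
    #
    #   order = {"3": 0, "2": 1, "1": 2, "E": 3, "N": 4, "U": 5, "V": 6, "Z": 7}
    #   sorted(["Z", "E", "N"], key=lambda x: order[x])  # -> ["E", "N", "Z"]
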
    def read_sac(self, fname):
        mseed = obspy.read(fname)
        mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
        mseed = mseed.merge(fill_value=0)
        if self.highpass_filter > 0:
            mseed = mseed.filter("highpass", freq=self.highpass_filter)
        starttime = min([st.stats.starttime for st in mseed])
        endtime = max([st.stats.endtime for st in mseed])
        mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
        if abs(mseed[0].stats.sampling_rate - self.config.sampling_rate) > 1:
            logging.warning(
                f"Sampling rate mismatch in {fname.split('/')[-1]}: {mseed[0].stats.sampling_rate}Hz != {self.config.sampling_rate}Hz"
            )

        order = ["3", "2", "1", "E", "N", "Z"]
        order = {key: i for i, key in enumerate(order)}
        comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}

        t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
        nt = len(mseed[0].data)
        data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
        ids = [x.get_id() for x in mseed]
        for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
            if len(ids) != 3:
                if len(ids) > 3:
                    logging.warning(f"More than 3 channels {ids}!")
                j = comp2idx[id[-1]]
            data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)

        data = data[:, np.newaxis, :]
        meta = {"data": data, "t0": t0}
        return meta

    def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):
        data = []
        station_id = []
        t0 = []
        raw_amp = []

        try:
            mseed = obspy.read(fname)
            read_success = True
        except Exception as e:
            read_success = False
            print(e)

        if read_success:
            try:
                mseed = mseed.merge(fill_value=0)
            except Exception as e:
                print(e)

            for i in range(len(mseed)):
                if mseed[i].stats.sampling_rate != self.config.sampling_rate:
                    logging.warning(
                        f"Resampling {mseed[i].id} from {mseed[i].stats.sampling_rate} to {self.config.sampling_rate} Hz"
                    )
                    try:
                        mseed[i] = mseed[i].interpolate(self.config.sampling_rate, method="linear")
                    except Exception as e:
                        print(e)
                        mseed[i].data = mseed[i].data.astype(float) * 0.0  ## set to zero if resampling fails

            if self.highpass_filter == 0:
                try:
                    mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
                except Exception:
                    logging.error(f"Error: spline detrend failed at file {fname}")
                    mseed = mseed.detrend("demean")
            else:
                mseed = mseed.filter("highpass", freq=self.highpass_filter)

            starttime = min([st.stats.starttime for st in mseed])
            endtime = max([st.stats.endtime for st in mseed])
            mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)

            order = ["3", "2", "1", "E", "N", "Z"]
            order = {key: i for i, key in enumerate(order)}
            comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}

            nsta = len(stations)
            nt = len(mseed[0].data)
            # for i in range(nsta):
            for sta in stations:
                trace_data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
                if amplitude:
                    trace_amp = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
                empty_station = True
                # sta = stations.iloc[i]["station"]
                # comp = stations.iloc[i]["component"].split(",")
                comp = stations[sta]["component"]
                if amplitude:
                    # resp = stations.iloc[i]["response"].split(",")
                    resp = stations[sta]["response"]

                for j, c in enumerate(sorted(comp, key=lambda x: order[x[-1]])):
                    resp_j = resp[j]
                    if len(comp) != 3:  ## fewer than 3 components
                        j = comp2idx[c]

                    if len(mseed.select(id=sta + c)) == 0:
                        print(f"Empty trace: {sta+c} {starttime}")
                        continue
                    else:
                        empty_station = False

                    tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
                    trace_data[: len(tmp), j] = tmp[:nt]
                    if amplitude:
                        # if stations.iloc[i]["unit"] == "m/s**2":
                        if stations[sta]["unit"] == "m/s**2":
                            tmp = mseed.select(id=sta + c)[0]
                            tmp = tmp.integrate()
                            tmp = tmp.filter("highpass", freq=1.0)
                            tmp = tmp.data.astype(self.dtype)
                            trace_amp[: len(tmp), j] = tmp[:nt]
                        # elif stations.iloc[i]["unit"] == "m/s":
                        elif stations[sta]["unit"] == "m/s":
                            tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
                            trace_amp[: len(tmp), j] = tmp[:nt]
                        else:
                            print(f"Error in {sta}\n{stations[sta]['unit']} should be m/s**2 or m/s!")
                    if amplitude and remove_resp:
                        # trace_amp[:, j] /= float(resp[j])
                        trace_amp[:, j] /= float(resp_j)

                if not empty_station:
                    data.append(trace_data)
                    if amplitude:
                        raw_amp.append(trace_amp)
                    station_id.append([sta])
                    t0.append(starttime.datetime.isoformat(timespec="milliseconds"))

        if len(data) > 0:
            data = np.stack(data)
            if len(data.shape) == 3:
                data = data[:, :, np.newaxis, :]
            if amplitude:
                raw_amp = np.stack(raw_amp)
                if len(raw_amp.shape) == 3:
                    raw_amp = raw_amp[:, :, np.newaxis, :]
        else:
            nt = 60 * 60 * self.config.sampling_rate  # assume 1 hour of data
            data = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
            if amplitude:
                raw_amp = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
            t0 = ["1970-01-01T00:00:00.000"]
            station_id = ["None"]

        if amplitude:
            meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1], "raw_amp": raw_amp}
        else:
            meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1]}
        return meta

    def generate_label(self, data, phase_list, mask=None):
        # target = np.zeros(self.Y_shape, dtype=self.dtype)
        target = np.zeros_like(data)

        if self.label_shape == "gaussian":
            label_window = np.exp(
                -((np.arange(-self.label_width // 2, self.label_width // 2 + 1)) ** 2)
                / (2 * (self.label_width / 5) ** 2)
            )
        elif self.label_shape == "triangle":
            label_window = 1 - np.abs(
                2 / self.label_width * (np.arange(-self.label_width // 2, self.label_width // 2 + 1))
            )
        else:
            raise ValueError(f"Label shape {self.label_shape} should be gaussian or triangle")

        for i, phases in enumerate(phase_list):
            for j, idx_list in enumerate(phases):
                for idx in idx_list:
                    if np.isnan(idx):
                        continue
                    idx = int(idx)
                    if (idx - self.label_width // 2 >= 0) and (idx + self.label_width // 2 + 1 <= target.shape[0]):
                        target[idx - self.label_width // 2 : idx + self.label_width // 2 + 1, j, i + 1] = label_window

        target[..., 0] = 1 - np.sum(target[..., 1:], axis=-1)
        if mask is not None:
            target[:, mask == 0, :] = 0

        return target

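    # A quick sketch of the label construction: with label_width = 30, a pick
    # at index k becomes a 31-sample Gaussian bump centered on k in the P
    # (class 1) or S (class 2) channel; class 0 keeps the residual "noise"
    # probability so the three classes sum to 1 at every sample.
    #
    #   w = np.exp(-np.arange(-15, 16) ** 2 / (2 * (30 / 5) ** 2))
    #   # w peaks at 1.0 at the pick and decays to ~0.04 at +/-15 samples
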
    def random_shift(self, sample, itp, its, itp_old=None, its_old=None, shift_range=None):
        # anchor = np.round(1/2 * (min(itp[~np.isnan(itp.astype(float))]) + min(its[~np.isnan(its.astype(float))]))).astype(int)
        flatten = lambda x: np.array([i for trace in x for i in trace], dtype=float)
        shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
        itp_flat = flatten(itp)
        its_flat = flatten(its)
        if (itp_old is None) and (its_old is None):
            hi = np.round(np.median(itp_flat[~np.isnan(itp_flat)])).astype(int)
            lo = -(sample.shape[0] - np.round(np.median(its_flat[~np.isnan(its_flat)])).astype(int))
            if shift_range is None:
                shift = np.random.randint(low=lo, high=hi + 1)
            else:
                shift = np.random.randint(low=max(lo, shift_range[0]), high=min(hi + 1, shift_range[1]))
        else:
            itp_old_flat = flatten(itp_old)
            its_old_flat = flatten(its_old)
            itp_ref = np.round(np.min(itp_flat[~np.isnan(itp_flat)])).astype(int)
            its_ref = np.round(np.max(its_flat[~np.isnan(its_flat)])).astype(int)
            itp_old_ref = np.round(np.min(itp_old_flat[~np.isnan(itp_old_flat)])).astype(int)
            its_old_ref = np.round(np.max(its_old_flat[~np.isnan(its_old_flat)])).astype(int)
            # min_event_gap = np.round(self.min_event_gap*(its_ref-itp_ref)).astype(int)
            # min_event_gap_old = np.round(self.min_event_gap*(its_old_ref-itp_old_ref)).astype(int)
            if shift_range is None:
                hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), itp_ref))
                lo = list(range(-(sample.shape[0] - its_ref), -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
            else:
                lo_ = max(-(sample.shape[0] - its_ref), shift_range[0])
                hi_ = min(itp_ref, shift_range[1])
                hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), hi_))
                lo = list(range(lo_, -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
            if len(hi + lo) > 0:
                shift = np.random.choice(hi + lo)
            else:
                shift = 0

        shifted_sample = np.zeros_like(sample)
        if shift > 0:
            shifted_sample[:-shift, ...] = sample[shift:, ...]
        elif shift < 0:
            shifted_sample[-shift:, ...] = sample[:shift, ...]
        else:
            shifted_sample[...] = sample[...]

        return shifted_sample, shift_pick(itp, shift), shift_pick(its, shift), shift

    def stack_events(self, sample_old, itp_old, its_old, shift_range=None, mask_old=None):
        i = np.random.randint(self.num_data)
        base_name = self.data_list[i]
        if self.format == "numpy":
            meta = self.read_numpy(os.path.join(self.data_dir, base_name))
        elif self.format == "hdf5":
            meta = self.read_hdf5(base_name)
        if meta is None:
            return sample_old, itp_old, its_old, mask_old

        sample = np.copy(meta["data"])
        itp = meta["itp"]
        its = meta["its"]
        if mask_old is not None:
            mask = np.copy(meta["mask"])
        sample = normalize(sample)
        sample, itp, its, shift = self.random_shift(sample, itp, its, itp_old, its_old, shift_range)

        if shift != 0:
            sample_old += sample
            # itp_old = [np.hstack([i, j]) for i,j in zip(itp_old, itp)]
            # its_old = [np.hstack([i, j]) for i,j in zip(its_old, its)]
            itp_old = [i + j for i, j in zip(itp_old, itp)]
            its_old = [i + j for i, j in zip(its_old, its)]
            if mask_old is not None:
                mask_old = mask_old * mask

        return sample_old, itp_old, its_old, mask_old

    def cut_window(self, sample, target, itp, its, select_range):
        shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
        sample = sample[select_range[0] : select_range[1]]
        target = target[select_range[0] : select_range[1]]
        return (sample, target, shift_pick(itp, select_range[0]), shift_pick(its, select_range[0]))

class DataReader_train(DataReader):
    def __init__(self, format="numpy", config=DataConfig(), **kwargs):
        super().__init__(format=format, config=config, **kwargs)

        self.min_event_gap = config.min_event_gap
        self.buffer_channels = {}
        self.shift_range = [-2000 + self.label_width * 2, 1000 - self.label_width * 2]
        self.select_range = [5000, 8000]

    def __getitem__(self, i):
        base_name = self.data_list[i]
        if self.format == "numpy":
            meta = self.read_numpy(os.path.join(self.data_dir, base_name))
        elif self.format == "hdf5":
            meta = self.read_hdf5(base_name)
        if meta is None:
            return (np.zeros(self.X_shape, dtype=self.dtype), np.zeros(self.Y_shape, dtype=self.dtype), base_name)

        sample = np.copy(meta["data"])
        itp_list = meta["itp"]
        its_list = meta["its"]

        sample = normalize(sample)
        if np.random.random() < 0.95:
            sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
            sample, itp_list, its_list, _ = self.stack_events(sample, itp_list, its_list, shift_range=self.shift_range)
            target = self.generate_label(sample, [itp_list, its_list])
            sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
        else:
            ## noise sample: keep a window before the first pick
            assert self.X_shape[0] <= min(min(itp_list))
            sample = sample[: self.X_shape[0], ...]
            target = np.zeros(self.Y_shape).astype(self.dtype)
            itp_list = [[]]
            its_list = [[]]

        sample = normalize(sample)
        return (sample.astype(self.dtype), target.astype(self.dtype), base_name)

    def dataset(self, batch_size, num_parallel_calls=2, shuffle=True, drop_remainder=True):
        dataset = dataset_map(
            self,
            output_types=(self.dtype, self.dtype, "string"),
            output_shapes=(self.X_shape, self.Y_shape, None),
            num_parallel_calls=num_parallel_calls,
            shuffle=shuffle,
        )
        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
        return dataset

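# A minimal training-pipeline sketch (hypothetical paths, using the
# dataset_map helper referenced above):
#
#   reader = DataReader_train(data_list="train.csv", data_dir="npz/")
#   ds = reader.dataset(batch_size=20)  # batches of (sample, target, fname)
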
class DataReader_test(DataReader):
    def __init__(self, format="numpy", config=DataConfig(), **kwargs):
        super().__init__(format=format, config=config, **kwargs)

        self.select_range = [5000, 8000]

    def __getitem__(self, i):
        base_name = self.data_list[i]
        if self.format == "numpy":
            meta = self.read_numpy(os.path.join(self.data_dir, base_name))
        elif self.format == "hdf5":
            meta = self.read_hdf5(base_name)
        if meta is None:
            return (np.zeros(self.X_shape, dtype=self.dtype), np.zeros(self.Y_shape, dtype=self.dtype), base_name)

        sample = np.copy(meta["data"])
        itp_list = meta["itp"]
        its_list = meta["its"]

        # sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
        target = self.generate_label(sample, [itp_list, its_list])
        sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)

        sample = normalize(sample)
        return (sample, target, base_name, itp_list, its_list)

    def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
        dataset = dataset_map(
            self,
            output_types=(self.dtype, self.dtype, "string", "int64", "int64"),
            output_shapes=(self.X_shape, self.Y_shape, None, None, None),
            num_parallel_calls=num_parallel_calls,
            shuffle=shuffle,
        )
        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
        return dataset

class DataReader_pred(DataReader):
    def __init__(self, format="numpy", amplitude=True, config=DataConfig(), **kwargs):
        super().__init__(format=format, config=config, **kwargs)

        self.amplitude = amplitude

    def adjust_missingchannels(self, data):
        tmp = np.max(np.abs(data), axis=0, keepdims=True)
        assert tmp.shape[-1] == data.shape[-1]
        if np.count_nonzero(tmp) > 0:
            data *= data.shape[-1] / np.count_nonzero(tmp)
        return data

    def __getitem__(self, i):
        base_name = self.data_list[i]

        if self.format == "numpy":
            meta = self.read_numpy(os.path.join(self.data_dir, base_name))
        elif (self.format == "mseed") or (self.format == "sac"):
            meta = self.read_mseed(
                os.path.join(self.data_dir, base_name),
                response=self.response,
                sampling_rate=self.sampling_rate,
                highpass_filter=self.highpass_filter,
                return_single_station=True,
            )
        elif self.format == "hdf5":
            meta = self.read_hdf5(base_name)
        else:
            raise ValueError(f"format {self.format} is not supported!")

        if "data" in meta:
            raw_amp = meta["data"].copy()
            sample = normalize_long(meta["data"])
        else:
            raw_amp = np.zeros([3000, 1, 3], dtype=np.float32)
            sample = np.zeros([3000, 1, 3], dtype=np.float32)

        if "t0" in meta:
            t0 = meta["t0"]
        else:
            t0 = "1970-01-01T00:00:00.000"

        if "station_id" in meta:
            station_id = meta["station_id"]
        else:
            # station_id = base_name.split("/")[-1].rstrip("*")
            station_id = os.path.basename(base_name).rstrip("*")

        if np.isnan(sample).any() or np.isinf(sample).any():
            logging.warning(f"Data error: NaN or Inf found in {base_name}")
            sample[np.isnan(sample)] = 0
            sample[np.isinf(sample)] = 0

        # sample = self.adjust_missingchannels(sample)

        if self.amplitude:
            return (sample, raw_amp, base_name, t0, station_id)
        else:
            return (sample, base_name, t0, station_id)

    def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
        if self.amplitude:
            dataset = dataset_map(
                self,
                output_types=(self.dtype, self.dtype, "string", "string", "string"),
                output_shapes=([None, None, 3], [None, None, 3], None, None, None),
                num_parallel_calls=num_parallel_calls,
                shuffle=shuffle,
            )
        else:
            dataset = dataset_map(
                self,
                output_types=(self.dtype, "string", "string", "string"),
                output_shapes=([None, None, 3], None, None, None),
                num_parallel_calls=num_parallel_calls,
                shuffle=shuffle,
            )
        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
        return dataset

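# A minimal prediction-pipeline sketch (hypothetical paths):
#
#   reader = DataReader_pred(format="mseed", amplitude=False,
#                            data_list="fnames.csv", data_dir="waveforms/")
#   ds = reader.dataset(batch_size=1)  # yields (sample, fname, t0, station_id)
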
class DataReader_mseed_array(DataReader):
    def __init__(self, stations, amplitude=True, remove_resp=True, config=DataConfig(), **kwargs):
        super().__init__(format="mseed", config=config, **kwargs)

        # self.stations = pd.read_json(stations)
        with open(stations, "r") as f:
            self.stations = json.load(f)
        print(pd.DataFrame.from_dict(self.stations, orient="index").to_string())

        self.amplitude = amplitude
        self.remove_resp = remove_resp
        self.X_shape = self.get_data_shape()

    def get_data_shape(self):
        fname = os.path.join(self.data_dir, self.data_list[0])
        meta = self.read_mseed_array(fname, self.stations, self.amplitude, self.remove_resp)
        return meta["data"].shape

    def __getitem__(self, i):
        fp = os.path.join(self.data_dir, self.data_list[i])
        # try:
        meta = self.read_mseed_array(fp, self.stations, self.amplitude, self.remove_resp)
        # except Exception as e:
        #     logging.error(f"Failed reading {fp}: {e}")
        #     if self.amplitude:
        #         return (np.zeros(self.X_shape).astype(self.dtype), np.zeros(self.X_shape).astype(self.dtype),
        #                 [self.stations.iloc[i]["station"] for i in range(len(self.stations))], ["0" for i in range(len(self.stations))])
        #     else:
        #         return (np.zeros(self.X_shape).astype(self.dtype), ["" for i in range(len(self.stations))],
        #                 [self.stations.iloc[i]["station"] for i in range(len(self.stations))])

        sample = np.zeros([len(meta["data"]), *self.X_shape[1:]], dtype=self.dtype)
        sample[:, : meta["data"].shape[1], :, :] = normalize_batch(meta["data"])[:, : self.X_shape[1], :, :]
        if np.isnan(sample).any() or np.isinf(sample).any():
            logging.warning(f"Data error: NaN or Inf found in {fp}")
            sample[np.isnan(sample)] = 0
            sample[np.isinf(sample)] = 0
        t0 = meta["t0"]
        base_name = meta["fname"]
        station_id = meta["station_id"]
        # base_name = [self.stations.iloc[i]["station"]+"."+t0[i] for i in range(len(self.stations))]
        # base_name = [self.stations.iloc[i]["station"] for i in range(len(self.stations))]

        if self.amplitude:
            raw_amp = np.zeros([len(meta["raw_amp"]), *self.X_shape[1:]], dtype=self.dtype)
            raw_amp[:, : meta["raw_amp"].shape[1], :, :] = meta["raw_amp"][:, : self.X_shape[1], :, :]
            if np.isnan(raw_amp).any() or np.isinf(raw_amp).any():
                logging.warning(f"Data error: NaN or Inf found in {fp}")
                raw_amp[np.isnan(raw_amp)] = 0
                raw_amp[np.isinf(raw_amp)] = 0
            return (sample, raw_amp, base_name, t0, station_id)
        else:
            return (sample, base_name, t0, station_id)

    def dataset(self, num_parallel_calls=1, shuffle=False):
        if self.amplitude:
            dataset = dataset_map(
                self,
                output_types=(self.dtype, self.dtype, "string", "string", "string"),
                output_shapes=([None, *self.X_shape[1:]], [None, *self.X_shape[1:]], None, None, None),
                num_parallel_calls=num_parallel_calls,
            )
        else:
            dataset = dataset_map(
                self,
                output_types=(self.dtype, "string", "string", "string"),
                output_shapes=([None, *self.X_shape[1:]], None, None, None),
                num_parallel_calls=num_parallel_calls,
            )
        dataset = dataset.prefetch(1)
        # dataset = dataset.prefetch(len(self.stations)*2)
        return dataset

###### test ########


def test_DataReader():
    import os
    import timeit

    import matplotlib.pyplot as plt

    if not os.path.exists("test_figures"):
        os.mkdir("test_figures")

    def plot_sample(sample, fname, label=None):
        plt.clf()
        plt.subplot(211)
        plt.plot(sample[:, 0, -1])
        if label is not None:
            plt.subplot(212)
            plt.plot(label[:, 0, 0])
            plt.plot(label[:, 0, 1])
            plt.plot(label[:, 0, 2])
        plt.savefig(f"test_figures/{fname.decode()}.png")

    def read(data_reader, batch=1):
        start_time = timeit.default_timer()
        if batch is None:
            dataset = data_reader.dataset(shuffle=False)
        else:
            dataset = data_reader.dataset(1, shuffle=False)
        sess = tf.compat.v1.Session()

        print(len(data_reader))
        print("-------", tf.data.Dataset.cardinality(dataset))
        num = 0
        x = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
        while True:
            num += 1
            # print(num)
            try:
                out = sess.run(x)
                if len(out) == 2:
                    sample, fname = out[0], out[1]
                    for i in range(len(sample)):
                        plot_sample(sample[i], fname[i])
                else:
                    sample, label, fname = out[0], out[1], out[2]
                    for i in range(len(sample)):
                        plot_sample(sample[i], fname[i], label[i])
            except tf.errors.OutOfRangeError:
                break
        print("End of dataset")
        print("TensorFlow Dataset:\nexecution time = ", timeit.default_timer() - start_time)

    data_reader = DataReader_train(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
    read(data_reader)

    data_reader = DataReader_train(format="hdf5", hdf5_file="test_data/data.h5", hdf5_group="data")
    read(data_reader)

    data_reader = DataReader_test(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
    read(data_reader)

    data_reader = DataReader_test(format="hdf5", hdf5_file="test_data/data.h5", hdf5_group="data")
    read(data_reader)

    data_reader = DataReader_pred(format="numpy", data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
    read(data_reader)

    data_reader = DataReader_pred(
        format="mseed", data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
    )
    read(data_reader)

    data_reader = DataReader_pred(
        format="mseed", amplitude=True, data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
    )
    read(data_reader)

    data_reader = DataReader_mseed_array(
        data_list="test_data/mseed.csv",
        data_dir="test_data/waveforms/",
        stations="test_data/stations.json",
        remove_resp=False,
    )
    read(data_reader, batch=None)


if __name__ == "__main__":
    test_DataReader()
phasenet/detect_peaks.py
ADDED
@@ -0,0 +1,207 @@
"""Detect peaks in data based on their amplitude and other features."""

from __future__ import division, print_function
import warnings
import numpy as np

__author__ = "Marcos Duarte, https://github.com/demotu"
__version__ = "1.0.6"
__license__ = "MIT"


def detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising',
                 kpsh=False, valley=False, show=False, ax=None, title=True):

    """Detect peaks in data based on their amplitude and other features.

    Parameters
    ----------
    x : 1D array_like
        data.
    mph : {None, number}, optional (default = None)
        detect peaks that are greater than minimum peak height (if parameter
        `valley` is False) or peaks that are smaller than maximum peak height
        (if parameter `valley` is True).
    mpd : positive integer, optional (default = 1)
        detect peaks that are at least separated by minimum peak distance (in
        number of data points).
    threshold : positive number, optional (default = 0)
        detect peaks (valleys) that are greater (smaller) than `threshold`
        in relation to their immediate neighbors.
    edge : {None, 'rising', 'falling', 'both'}, optional (default = 'rising')
        for a flat peak, keep only the rising edge ('rising'), only the
        falling edge ('falling'), both edges ('both'), or don't detect a
        flat peak (None).
    kpsh : bool, optional (default = False)
        keep peaks with same height even if they are closer than `mpd`.
    valley : bool, optional (default = False)
        if True (1), detect valleys (local minima) instead of peaks.
    show : bool, optional (default = False)
        if True (1), plot data in matplotlib figure.
    ax : a matplotlib.axes.Axes instance, optional (default = None).
    title : bool or string, optional (default = True)
        if True, show standard title. If False or empty string, doesn't show
        any title. If string, shows string as title.

    Returns
    -------
    ind : 1D array_like
        indices of the peaks in `x`.
    x[ind] : 1D array_like
        amplitudes of `x` at the detected peaks.

    Notes
    -----
    The detection of valleys instead of peaks is performed internally by simply
    negating the data: `ind_valleys = detect_peaks(-x)`

    The function can handle NaN's

    See this IPython Notebook [1]_.

    References
    ----------
    .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb

    Examples
    --------
    >>> from detect_peaks import detect_peaks
    >>> x = np.random.randn(100)
    >>> x[60:81] = np.nan
    >>> # detect all peaks and plot data
    >>> ind = detect_peaks(x, show=True)
    >>> print(ind)

    >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
    >>> # set minimum peak height = 0 and minimum peak distance = 20
    >>> detect_peaks(x, mph=0, mpd=20, show=True)

    >>> x = [0, 1, 0, 2, 0, 3, 0, 2, 0, 1, 0]
    >>> # set minimum peak distance = 2
    >>> detect_peaks(x, mpd=2, show=True)

    >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
    >>> # detection of valleys instead of peaks
    >>> detect_peaks(x, mph=-1.2, mpd=20, valley=True, show=True)

    >>> x = [0, 1, 1, 0, 1, 1, 0]
    >>> # detect both edges
    >>> detect_peaks(x, edge='both', show=True)

    >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
    >>> # set threshold = 2
    >>> detect_peaks(x, threshold=2, show=True)

    >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
    >>> fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(10, 4))
    >>> detect_peaks(x, show=True, ax=axs[0], threshold=0.5, title=False)
    >>> detect_peaks(x, show=True, ax=axs[1], threshold=1.5, title=False)

    Version history
    ---------------
    '1.0.6':
        Fix issue of when specifying ax object only the first plot was shown
        Add parameter to choose if a title is shown and input a title
    '1.0.5':
        The sign of `mph` is inverted if parameter `valley` is True

    """

    x = np.atleast_1d(x).astype('float64')
    if x.size < 3:
        return np.array([], dtype=int), np.array([], dtype='float64')
    if valley:
        x = -x
        if mph is not None:
            mph = -mph
    # find indices of all peaks
    dx = x[1:] - x[:-1]
    # handle NaN's
    indnan = np.where(np.isnan(x))[0]
    if indnan.size:
        x[indnan] = np.inf
        dx[np.where(np.isnan(dx))[0]] = np.inf
    ine, ire, ife = np.array([[], [], []], dtype=int)
    if not edge:
        ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
    else:
        if edge.lower() in ['rising', 'both']:
            ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
        if edge.lower() in ['falling', 'both']:
            ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
    ind = np.unique(np.hstack((ine, ire, ife)))
    # handle NaN's
    if ind.size and indnan.size:
        # NaN's and values close to NaN's cannot be peaks
        ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan-1, indnan+1))), invert=True)]
    # first and last values of x cannot be peaks
    if ind.size and ind[0] == 0:
        ind = ind[1:]
    if ind.size and ind[-1] == x.size-1:
        ind = ind[:-1]
    # remove peaks < minimum peak height
    if ind.size and mph is not None:
        ind = ind[x[ind] >= mph]
    # remove peaks - neighbors < threshold
    if ind.size and threshold > 0:
        dx = np.min(np.vstack([x[ind]-x[ind-1], x[ind]-x[ind+1]]), axis=0)
        ind = np.delete(ind, np.where(dx < threshold)[0])
    # detect small peaks closer than minimum peak distance
    if ind.size and mpd > 1:
        ind = ind[np.argsort(x[ind])][::-1]  # sort ind by peak height
        idel = np.zeros(ind.size, dtype=bool)
        for i in range(ind.size):
            if not idel[i]:
                # keep peaks with the same height if kpsh is True
                idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
                    & (x[ind[i]] > x[ind] if kpsh else True)
                idel[i] = False  # keep current peak
        # remove the small peaks and sort back the indices by their occurrence
        ind = np.sort(ind[~idel])

    if show:
        if indnan.size:
            x[indnan] = np.nan
        if valley:
            x = -x
            if mph is not None:
                mph = -mph
        _plot(x, mph, mpd, threshold, edge, valley, ax, ind, title)

    return ind, x[ind]


def _plot(x, mph, mpd, threshold, edge, valley, ax, ind, title):
    """Plot results of the detect_peaks function, see its help."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print('matplotlib is not available.')
    else:
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=(8, 4))
            no_ax = True
        else:
            no_ax = False

        ax.plot(x, 'b', lw=1)
        if ind.size:
            label = 'valley' if valley else 'peak'
            label = label + 's' if ind.size > 1 else label
            ax.plot(ind, x[ind], '+', mfc=None, mec='r', mew=2, ms=8,
                    label='%d %s' % (ind.size, label))
            ax.legend(loc='best', framealpha=.5, numpoints=1)
        ax.set_xlim(-.02*x.size, x.size*1.02-1)
        ymin, ymax = x[np.isfinite(x)].min(), x[np.isfinite(x)].max()
        yrange = ymax - ymin if ymax > ymin else 1
        ax.set_ylim(ymin - 0.1*yrange, ymax + 0.1*yrange)
        ax.set_xlabel('Data #', fontsize=14)
        ax.set_ylabel('Amplitude', fontsize=14)
        if title:
            if not isinstance(title, str):
                mode = 'Valley detection' if valley else 'Peak detection'
                title = "%s (mph=%s, mpd=%d, threshold=%s, edge='%s')" % \
                    (mode, str(mph), mpd, str(threshold), edge)
            ax.set_title(title)
        # plt.grid()
        if no_ax:
            plt.show()
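
# A usage sketch in the PhaseNet context (hypothetical values): pick phases
# from a predicted probability trace by thresholding height and spacing.
#
#   prob = np.exp(-np.arange(-50, 50) ** 2 / 200)    # a toy probability trace
#   ind, amps = detect_peaks(prob, mph=0.3, mpd=50)  # indices and peak heights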
phasenet/model.py
ADDED
@@ -0,0 +1,489 @@
import logging
import warnings

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
warnings.filterwarnings('ignore', category=UserWarning)


class ModelConfig:

    batch_size = 20
    depths = 5
    filters_root = 8
    kernel_size = [7, 1]
    pool_size = [4, 1]
    dilation_rate = [1, 1]
    class_weights = [1.0, 1.0, 1.0]
    loss_type = "cross_entropy"
    weight_decay = 0.0
    optimizer = "adam"
    momentum = 0.9
    learning_rate = 0.01
    decay_step = 1e9
    decay_rate = 0.9
    drop_rate = 0.0
    summary = True

    X_shape = [3000, 1, 3]
    n_channel = X_shape[-1]
    Y_shape = [3000, 1, 3]
    n_class = Y_shape[-1]

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    def update_args(self, args):
        for k, v in vars(args).items():
            setattr(self, k, v)

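# A configuration sketch: defaults are class attributes, overridable per run:
#
#   config = ModelConfig(learning_rate=0.001, drop_rate=0.1)
#   print(config.depths, config.learning_rate)  # -> 5 0.001
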
def crop_and_concat(net1, net2):
    """
    Concatenate net2 (center-cropped to the spatial size of net1) onto net1.
    Assumes size(net1) <= size(net2).
    """
    # static-shape version:
    # net1_shape = net1.get_shape().as_list()
    # net2_shape = net2.get_shape().as_list()
    # offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
    # size = [-1, net1_shape[1], net1_shape[2], -1]
    # net2_resize = tf.slice(net2, offsets, size)
    # return tf.concat([net1, net2_resize], 3)

    ## dynamic shape
    chn1 = net1.get_shape().as_list()[-1]
    chn2 = net2.get_shape().as_list()[-1]
    net1_shape = tf.shape(net1)
    net2_shape = tf.shape(net2)
    offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
    size = [-1, net1_shape[1], net1_shape[2], -1]
    net2_resize = tf.slice(net2, offsets, size)

    out = tf.concat([net1, net2_resize], 3)
    out.set_shape([None, None, None, chn1 + chn2])

    return out

    # else:
    #     offsets = [0, (net1_shape[1] - net2_shape[1]) // 2, (net1_shape[2] - net2_shape[2]) // 2, 0]
    #     size = [-1, net2_shape[1], net2_shape[2], -1]
    #     net1_resize = tf.slice(net1, offsets, size)
    #     return tf.concat([net1_resize, net2], 3)


def crop_only(net1, net2):
    """
    Center-crop net2 to the spatial size of net1 without concatenating.
    Assumes size(net1) <= size(net2).
    """
    net1_shape = net1.get_shape().as_list()
    net2_shape = net2.get_shape().as_list()
    offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
    size = [-1, net1_shape[1], net1_shape[2], -1]
    net2_resize = tf.slice(net2, offsets, size)
    # return tf.concat([net1, net2_resize], 3)
    return net2_resize

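# A shape sketch for the skip connection above (hypothetical sizes): if the
# decoder map net2 is [batch, 190, 1, 16] and the encoder map net1 is
# [batch, 188, 1, 16], crop_and_concat center-crops net2 to 188 samples and
# returns a [batch, 188, 1, 32] tensor.
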
class UNet:
    def __init__(self, config=ModelConfig(), input_batch=None, mode='train'):
        self.depths = config.depths
        self.filters_root = config.filters_root
        self.kernel_size = config.kernel_size
        self.dilation_rate = config.dilation_rate
        self.pool_size = config.pool_size
        self.X_shape = config.X_shape
        self.Y_shape = config.Y_shape
        self.n_channel = config.n_channel
        self.n_class = config.n_class
        self.class_weights = config.class_weights
        self.batch_size = config.batch_size
        self.loss_type = config.loss_type
        self.weight_decay = config.weight_decay
        self.optimizer = config.optimizer
        self.learning_rate = config.learning_rate
        self.decay_step = config.decay_step
        self.decay_rate = config.decay_rate
        self.momentum = config.momentum
        self.global_step = tf.compat.v1.get_variable(name="global_step", initializer=0, dtype=tf.int32)
        self.summary_train = []
        self.summary_valid = []

        self.build(input_batch, mode=mode)

    def add_placeholders(self, input_batch=None, mode="train"):
        if input_batch is None:
            # self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.X_shape[-3], self.X_shape[-2], self.X_shape[-1]], name='X')
            # self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.Y_shape[-3], self.Y_shape[-2], self.n_class], name='y')
            self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.X_shape[-1]], name='X')
            self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.n_class], name='y')
        else:
            self.X = input_batch[0]
            if mode in ["train", "valid", "test"]:
                self.Y = input_batch[1]
            self.input_batch = input_batch

        self.is_training = tf.compat.v1.placeholder(dtype=tf.bool, name="is_training")
        # self.keep_prob = tf.compat.v1.placeholder(dtype=tf.float32, name="keep_prob")
        self.drop_rate = tf.compat.v1.placeholder(dtype=tf.float32, name="drop_rate")

    def add_prediction_op(self):
        logging.info("Model: depths {depths}, filters {filters}, "
                     "filter size {kernel_size[0]}x{kernel_size[1]}, "
                     "pool size: {pool_size[0]}x{pool_size[1]}, "
                     "dilation rate: {dilation_rate[0]}x{dilation_rate[1]}".format(
                         depths=self.depths,
                         filters=self.filters_root,
                         kernel_size=self.kernel_size,
                         dilation_rate=self.dilation_rate,
                         pool_size=self.pool_size))

        if self.weight_decay > 0:
            weight_decay = tf.constant(self.weight_decay, dtype=tf.float32, name="weight_constant")
            self.regularizer = tf.keras.regularizers.l2(l=0.5 * (weight_decay))
        else:
            self.regularizer = None

        self.initializer = tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")

        # down-sampling layers
        convs = [None] * self.depths  # store the output of each depth

        with tf.compat.v1.variable_scope("Input"):
            net = self.X
            net = tf.compat.v1.layers.conv2d(net,
                                             filters=self.filters_root,
                                             kernel_size=self.kernel_size,
                                             activation=None,
                                             padding='same',
                                             dilation_rate=self.dilation_rate,
                                             kernel_initializer=self.initializer,
                                             kernel_regularizer=self.regularizer,
                                             name="input_conv")
            net = tf.compat.v1.layers.batch_normalization(net,
                                                          training=self.is_training,
                                                          name="input_bn")
            net = tf.nn.relu(net,
                             name="input_relu")
            # net = tf.nn.dropout(net, self.keep_prob)
            net = tf.compat.v1.layers.dropout(net,
                                              rate=self.drop_rate,
                                              training=self.is_training,
                                              name="input_dropout")

        for depth in range(0, self.depths):
            with tf.compat.v1.variable_scope("DownConv_%d" % depth):
                filters = int(2 ** depth * self.filters_root)

                net = tf.compat.v1.layers.conv2d(net,
                                                 filters=filters,
                                                 kernel_size=self.kernel_size,
                                                 activation=None,
                                                 use_bias=False,
                                                 padding='same',
                                                 dilation_rate=self.dilation_rate,
                                                 kernel_initializer=self.initializer,
                                                 kernel_regularizer=self.regularizer,
                                                 name="down_conv1_{}".format(depth + 1))
                net = tf.compat.v1.layers.batch_normalization(net,
                                                              training=self.is_training,
                                                              name="down_bn1_{}".format(depth + 1))
                net = tf.nn.relu(net,
                                 name="down_relu1_{}".format(depth + 1))
                net = tf.compat.v1.layers.dropout(net,
                                                  rate=self.drop_rate,
                                                  training=self.is_training,
                                                  name="down_dropout1_{}".format(depth + 1))

                convs[depth] = net

                if depth < self.depths - 1:
                    net = tf.compat.v1.layers.conv2d(net,
                                                     filters=filters,
                                                     kernel_size=self.kernel_size,
                                                     strides=self.pool_size,
                                                     activation=None,
                                                     use_bias=False,
                                                     padding='same',
                                                     dilation_rate=self.dilation_rate,
                                                     kernel_initializer=self.initializer,
                                                     kernel_regularizer=self.regularizer,
                                                     name="down_conv3_{}".format(depth + 1))
                    net = tf.compat.v1.layers.batch_normalization(net,
                                                                  training=self.is_training,
                                                                  name="down_bn3_{}".format(depth + 1))
                    net = tf.nn.relu(net,
                                     name="down_relu3_{}".format(depth + 1))
                    net = tf.compat.v1.layers.dropout(net,
                                                      rate=self.drop_rate,
                                                      training=self.is_training,
                                                      name="down_dropout3_{}".format(depth + 1))

        # up-sampling layers
        for depth in range(self.depths - 2, -1, -1):
            with tf.compat.v1.variable_scope("UpConv_%d" % depth):
                filters = int(2 ** depth * self.filters_root)
                net = tf.compat.v1.layers.conv2d_transpose(net,
                                                           filters=filters,
                                                           kernel_size=self.kernel_size,
                                                           strides=self.pool_size,
                                                           activation=None,
                                                           use_bias=False,
                                                           padding="same",
                                                           kernel_initializer=self.initializer,
                                                           kernel_regularizer=self.regularizer,
                                                           name="up_conv0_{}".format(depth + 1))
                net = tf.compat.v1.layers.batch_normalization(net,
                                                              training=self.is_training,
                                                              name="up_bn0_{}".format(depth + 1))
                net = tf.nn.relu(net,
                                 name="up_relu0_{}".format(depth + 1))
                net = tf.compat.v1.layers.dropout(net,
                                                  rate=self.drop_rate,
                                                  training=self.is_training,
                                                  name="up_dropout0_{}".format(depth + 1))

                # skip connection
                net = crop_and_concat(convs[depth], net)
                # net = crop_only(convs[depth], net)

                net = tf.compat.v1.layers.conv2d(net,
                                                 filters=filters,
                                                 kernel_size=self.kernel_size,
                                                 activation=None,
                                                 use_bias=False,
                                                 padding='same',
                                                 dilation_rate=self.dilation_rate,
                                                 kernel_initializer=self.initializer,
                                                 kernel_regularizer=self.regularizer,
                                                 name="up_conv1_{}".format(depth + 1))
                net = tf.compat.v1.layers.batch_normalization(net,
                                                              training=self.is_training,
                                                              name="up_bn1_{}".format(depth + 1))
                net = tf.nn.relu(net,
                                 name="up_relu1_{}".format(depth + 1))
                net = tf.compat.v1.layers.dropout(net,
                                                  rate=self.drop_rate,
                                                  training=self.is_training,
name="up_dropout1_{}".format(depth + 1))
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# Output Map
|
| 281 |
+
with tf.compat.v1.variable_scope("Output"):
|
| 282 |
+
net = tf.compat.v1.layers.conv2d(net,
|
| 283 |
+
filters=self.n_class,
|
| 284 |
+
kernel_size=(1,1),
|
| 285 |
+
activation=None,
|
| 286 |
+
padding='same',
|
| 287 |
+
#dilation_rate=self.dilation_rate,
|
| 288 |
+
kernel_initializer=self.initializer,
|
| 289 |
+
kernel_regularizer=self.regularizer,
|
| 290 |
+
name="output_conv")
|
| 291 |
+
# net = tf.nn.relu(net,
|
| 292 |
+
# name="output_relu")
|
| 293 |
+
# net = tf.compat.v1.layers.dropout(net,
|
| 294 |
+
# rate=self.drop_rate,
|
| 295 |
+
# training=self.is_training,
|
| 296 |
+
# name="output_dropout")
|
| 297 |
+
# net = tf.compat.v1.layers.batch_normalization(net,
|
| 298 |
+
# training=self.is_training,
|
| 299 |
+
# name="output_bn")
|
| 300 |
+
output = net
|
| 301 |
+
|
| 302 |
+
with tf.compat.v1.variable_scope("representation"):
|
| 303 |
+
self.representation = convs[-1]
|
| 304 |
+
|
| 305 |
+
with tf.compat.v1.variable_scope("logits"):
|
| 306 |
+
self.logits = output
|
| 307 |
+
tmp = tf.compat.v1.summary.histogram("logits", self.logits)
|
| 308 |
+
self.summary_train.append(tmp)
|
| 309 |
+
|
| 310 |
+
with tf.compat.v1.variable_scope("preds"):
|
| 311 |
+
self.preds = tf.nn.softmax(output)
|
| 312 |
+
tmp = tf.compat.v1.summary.histogram("preds", self.preds)
|
| 313 |
+
self.summary_train.append(tmp)
|
| 314 |
+
|
| 315 |
+
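    # The encoder halves the time axis by pool_size at each depth while doubling
    # the filter count; the decoder mirrors this with transposed convolutions and
    # crop_and_concat skip connections. A quick shape walkthrough, assuming
    # illustrative values depths=5, filters_root=8, pool factor 4 and a
    # 3000-sample input (assumed numbers, not read from ModelConfig):
    #
    #   import math
    #   depths, filters_root, pool, nt = 5, 8, 4, 3000
    #   for depth in range(depths):
    #       filters = 2 ** depth * filters_root
    #       print(f"depth {depth}: {filters:3d} filters, {nt:4d} time steps")
    #       if depth < depths - 1:
    #           nt = math.ceil(nt / pool)  # 'same' padding with stride=pool
    #   # -> 8@3000, 16@750, 32@188, 64@47, 128@12
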
    def add_loss_op(self):
        if self.loss_type == "cross_entropy":
            with tf.compat.v1.variable_scope("cross_entropy"):
                flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
                flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
                if (np.array(self.class_weights) != 1).any():
                    class_weights = tf.constant(np.array(self.class_weights, dtype=np.float32), name="class_weights")
                    weight_map = tf.multiply(flat_labels, class_weights)
                    weight_map = tf.reduce_sum(input_tensor=weight_map, axis=1)
                    loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
                                                                       labels=flat_labels)
                    weighted_loss = tf.multiply(loss_map, weight_map)
                    loss = tf.reduce_mean(input_tensor=weighted_loss)
                else:
                    loss = tf.reduce_mean(input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
                                                                                               labels=flat_labels))

        elif self.loss_type == "IOU":
            with tf.compat.v1.variable_scope("IOU"):
                eps = 1e-7
                loss = 0
                for i in range(1, self.n_class):
                    intersection = eps + tf.reduce_sum(input_tensor=self.preds[:, :, :, i] * self.Y[:, :, :, i], axis=[1, 2])
                    union = eps + tf.reduce_sum(input_tensor=self.preds[:, :, :, i], axis=[1, 2]) + tf.reduce_sum(input_tensor=self.Y[:, :, :, i], axis=[1, 2])
                    loss += 1 - tf.reduce_mean(input_tensor=intersection / union)
        elif self.loss_type == "mean_squared":
            with tf.compat.v1.variable_scope("mean_squared"):
                flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
                flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
                loss = tf.compat.v1.losses.mean_squared_error(labels=flat_labels, predictions=flat_logits)
        else:
            raise ValueError("Unknown loss function: %s" % self.loss_type)

        tmp = tf.compat.v1.summary.scalar("train_loss", loss)
        self.summary_train.append(tmp)
        tmp = tf.compat.v1.summary.scalar("valid_loss", loss)
        self.summary_valid.append(tmp)

        if self.weight_decay > 0:
            with tf.compat.v1.name_scope('weight_loss'):
                tmp = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
                weight_loss = tf.add_n(tmp, name="weight_loss")
            self.loss = loss + weight_loss
        else:
            self.loss = loss

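    # For one-hot labels, multiplying flat_labels by class_weights and summing over
    # the class axis simply looks up the weight of each sample's true class, e.g.
    # (weights and labels invented for illustration):
    #
    #   import numpy as np
    #   class_weights = np.array([1.0, 2.0, 2.0])          # noise, P, S
    #   flat_labels = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
    #   weight_map = (flat_labels * class_weights).sum(axis=1)
    #   print(weight_map)  # [1. 2. 2.] -> per-sample multipliers for the loss
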
    def add_training_op(self):
        if self.optimizer == "momentum":
            self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
                                                                           global_step=self.global_step,
                                                                           decay_steps=self.decay_step,
                                                                           decay_rate=self.decay_rate,
                                                                           staircase=True)
            optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=self.learning_rate_node,
                                                             momentum=self.momentum)
        elif self.optimizer == "adam":
            self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
                                                                           global_step=self.global_step,
                                                                           decay_steps=self.decay_step,
                                                                           decay_rate=self.decay_rate,
                                                                           staircase=True)
            optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)
        tmp = tf.compat.v1.summary.scalar("learning_rate", self.learning_rate_node)
        self.summary_train.append(tmp)

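    # With staircase=True the learning rate drops by decay_rate once every
    # decay_step steps: lr(step) = learning_rate * decay_rate ** (step // decay_step).
    # A quick check with assumed values learning_rate=0.01, decay_rate=0.9,
    # decay_step=1000:
    #
    #   for step in [0, 999, 1000, 5000]:
    #       lr = 0.01 * 0.9 ** (step // 1000)
    #       print(step, round(lr, 6))  # 0.01, 0.01, 0.009, 0.005905
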
    def add_metrics_op(self):
        with tf.compat.v1.variable_scope("metrics"):
            Y = tf.argmax(input=self.Y, axis=-1)
            # The confusion matrix expects class indices, so the softmax output
            # is reduced with argmax before flattening.
            confusion_matrix = tf.cast(tf.math.confusion_matrix(
                labels=tf.reshape(Y, [-1]),
                predictions=tf.reshape(tf.argmax(input=self.preds, axis=-1), [-1]),
                num_classes=self.n_class, name='confusion_matrix'),
                dtype=tf.float32)

            # with tf.variable_scope("P"):
            c = tf.constant(1e-7, dtype=tf.float32)
            precision_P = (confusion_matrix[1, 1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:, 1]) + c)
            recall_P = (confusion_matrix[1, 1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[1, :]) + c)
            f1_P = 2 * precision_P * recall_P / (precision_P + recall_P)

            tmp1 = tf.compat.v1.summary.scalar("train_precision_p", precision_P)
            tmp2 = tf.compat.v1.summary.scalar("train_recall_p", recall_P)
            tmp3 = tf.compat.v1.summary.scalar("train_f1_p", f1_P)
            self.summary_train.extend([tmp1, tmp2, tmp3])

            tmp1 = tf.compat.v1.summary.scalar("valid_precision_p", precision_P)
            tmp2 = tf.compat.v1.summary.scalar("valid_recall_p", recall_P)
            tmp3 = tf.compat.v1.summary.scalar("valid_f1_p", f1_P)
            self.summary_valid.extend([tmp1, tmp2, tmp3])

            # with tf.variable_scope("S"):
            precision_S = (confusion_matrix[2, 2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:, 2]) + c)
            recall_S = (confusion_matrix[2, 2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[2, :]) + c)
            f1_S = 2 * precision_S * recall_S / (precision_S + recall_S)

            tmp1 = tf.compat.v1.summary.scalar("train_precision_s", precision_S)
            tmp2 = tf.compat.v1.summary.scalar("train_recall_s", recall_S)
            tmp3 = tf.compat.v1.summary.scalar("train_f1_s", f1_S)
            self.summary_train.extend([tmp1, tmp2, tmp3])

            tmp1 = tf.compat.v1.summary.scalar("valid_precision_s", precision_S)
            tmp2 = tf.compat.v1.summary.scalar("valid_recall_s", recall_S)
            tmp3 = tf.compat.v1.summary.scalar("valid_f1_s", f1_S)
            self.summary_valid.extend([tmp1, tmp2, tmp3])

            self.precision = [precision_P, precision_S]
            self.recall = [recall_P, recall_S]
            self.f1 = [f1_P, f1_S]

    def train_on_batch(self, sess, inputs_batch, labels_batch, summary_writer, drop_rate=0.0):
        feed = {self.X: inputs_batch,
                self.Y: labels_batch,
                self.drop_rate: drop_rate,
                self.is_training: True}

        _, step_summary, step, loss = sess.run([self.train_op,
                                                self.summary_train,
                                                self.global_step,
                                                self.loss],
                                               feed_dict=feed)
        summary_writer.add_summary(step_summary, step)
        return loss

    def valid_on_batch(self, sess, inputs_batch, labels_batch, summary_writer):
        feed = {self.X: inputs_batch,
                self.Y: labels_batch,
                self.drop_rate: 0,
                self.is_training: False}

        step_summary, step, loss, preds = sess.run([self.summary_valid,
                                                    self.global_step,
                                                    self.loss,
                                                    self.preds],
                                                   feed_dict=feed)
        summary_writer.add_summary(step_summary, step)
        return loss, preds

    def test_on_batch(self, sess, summary_writer):
        feed = {self.drop_rate: 0,
                self.is_training: False}
        step_summary, step, loss, preds, \
            X_batch, Y_batch, fname_batch, \
            itp_batch, its_batch = sess.run([self.summary_valid,
                                             self.global_step,
                                             self.loss,
                                             self.preds,
                                             self.X,
                                             self.Y,
                                             self.input_batch[2],
                                             self.input_batch[3],
                                             self.input_batch[4]],
                                            feed_dict=feed)
        summary_writer.add_summary(step_summary, step)
        return loss, preds, X_batch, Y_batch, fname_batch, itp_batch, its_batch

    def build(self, input_batch=None, mode='train'):
        self.add_placeholders(input_batch, mode)
        self.add_prediction_op()
        if mode in ["train", "valid", "test"]:
            self.add_loss_op()
            self.add_training_op()
            # self.add_metrics_op()
        self.summary_train = tf.compat.v1.summary.merge(self.summary_train)
        self.summary_valid = tf.compat.v1.summary.merge(self.summary_valid)
        return 0
phasenet/postprocess.py
ADDED
@@ -0,0 +1,377 @@
import json
import logging
import os
from collections import namedtuple
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
from detect_peaks import detect_peaks

# def extract_picks(preds, fnames=None, station_ids=None, t0=None, config=None):

#     if preds.shape[-1] == 4:
#         record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob", "ps_idx", "ps_prob"])
#     else:
#         record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob"])

#     picks = []
#     for i, pred in enumerate(preds):

#         if config is None:
#             mph_p, mph_s, mpd = 0.3, 0.3, 50
#         else:
#             mph_p, mph_s, mpd = config.min_p_prob, config.min_s_prob, config.mpd

#         if (fnames is None):
#             fname = f"{i:04d}"
#         else:
#             if isinstance(fnames[i], str):
#                 fname = fnames[i]
#             else:
#                 fname = fnames[i].decode()

#         if (station_ids is None):
#             station_id = f"{i:04d}"
#         else:
#             if isinstance(station_ids[i], str):
#                 station_id = station_ids[i]
#             else:
#                 station_id = station_ids[i].decode()

#         if (t0 is None):
#             start_time = "1970-01-01T00:00:00.000"
#         else:
#             if isinstance(t0[i], str):
#                 start_time = t0[i]
#             else:
#                 start_time = t0[i].decode()

#         p_idx, p_prob, s_idx, s_prob = [], [], [], []
#         for j in range(pred.shape[1]):
#             p_idx_, p_prob_ = detect_peaks(pred[:,j,1], mph=mph_p, mpd=mpd, show=False)
#             s_idx_, s_prob_ = detect_peaks(pred[:,j,2], mph=mph_s, mpd=mpd, show=False)
#             p_idx.append(list(p_idx_))
#             p_prob.append(list(p_prob_))
#             s_idx.append(list(s_idx_))
#             s_prob.append(list(s_prob_))

#         if pred.shape[-1] == 4:
#             ps_idx, ps_prob = detect_peaks(pred[:,0,3], mph=0.3, mpd=mpd, show=False)
#             picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob), list(ps_idx), list(ps_prob)))
#         else:
#             picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob)))

#     return picks

def extract_picks(
    preds,
    file_names=None,
    begin_times=None,
    station_ids=None,
    dt=0.01,
    phases=["P", "S"],
    config=None,
    waveforms=None,
    use_amplitude=False,
):
    """Extract picks from prediction results.
    Args:
        preds ([type]): [Nb, Nt, Ns, Nc] "batch, time, station, channel"
        file_names ([type], optional): [Nb]. Defaults to None.
        begin_times ([type], optional): [Nb]. Defaults to None.
        station_ids ([type], optional): [Ns]. Defaults to None.
        config ([type], optional): [description]. Defaults to None.

    Returns:
        picks [type]: {file_name, station_id, pick_time, pick_prob, pick_type}
    """

    mph = {}
    if config is None:
        for x in phases:
            mph[x] = 0.3
        mpd = 50
        pre_idx = int(1 / dt)
        post_idx = int(4 / dt)
    else:
        mph["P"] = config.min_p_prob
        mph["S"] = config.min_s_prob
        mph["PS"] = 0.3
        mpd = config.mpd
        pre_idx = int(config.pre_sec / dt)
        post_idx = int(config.post_sec / dt)

    Nb, Nt, Ns, Nc = preds.shape

    if file_names is None:
        file_names = [f"{i:04d}" for i in range(Nb)]
    elif not (isinstance(file_names, np.ndarray) or isinstance(file_names, list)):
        if isinstance(file_names, bytes):
            file_names = file_names.decode()
        file_names = [file_names] * Nb
    else:
        file_names = [x.decode() if isinstance(x, bytes) else x for x in file_names]

    if begin_times is None:
        begin_times = ["1970-01-01T00:00:00.000+00:00"] * Nb
    else:
        begin_times = [x.decode() if isinstance(x, bytes) else x for x in begin_times]

    picks = []
    for i in range(Nb):
        file_name = file_names[i]
        begin_time = datetime.fromisoformat(begin_times[i])

        for j in range(Ns):
            if (station_ids is None) or (len(station_ids[i]) == 0):
                station_id = f"{j:04d}"
            else:
                station_id = station_ids[i][j].decode() if isinstance(station_ids[i][j], bytes) else station_ids[i][j]

            if (waveforms is not None) and use_amplitude:
                amp = np.max(np.abs(waveforms[i, :, j, :]), axis=-1)  ## amplitude over three channels
            for k in range(Nc - 1):  # 0-th channel is noise
                idxs, probs = detect_peaks(preds[i, :, j, k + 1], mph=mph[phases[k]], mpd=mpd, show=False)
                for l, (phase_index, phase_prob) in enumerate(zip(idxs, probs)):
                    pick_time = begin_time + timedelta(seconds=phase_index * dt)
                    pick = {
                        "file_name": file_name,
                        "station_id": station_id,
                        "begin_time": begin_time.isoformat(timespec="milliseconds"),
                        "phase_index": int(phase_index),
                        "phase_time": pick_time.isoformat(timespec="milliseconds"),
                        "phase_score": round(phase_prob, 3),
                        "phase_type": phases[k],
                        "dt": dt,
                    }

                    ## process waveform
                    if waveforms is not None:
                        tmp = np.zeros((pre_idx + post_idx, 3))
                        lo = phase_index - pre_idx
                        hi = phase_index + post_idx
                        insert_idx = 0
                        if lo < 0:
                            insert_idx = -lo  # record the offset before clipping lo to 0
                            lo = 0
                        if hi > Nt:
                            hi = Nt
                        tmp[insert_idx : insert_idx + hi - lo, :] = waveforms[i, lo:hi, j, :]
                        if use_amplitude:
                            next_pick = idxs[l + 1] if l < len(idxs) - 1 else (phase_index + post_idx * 3)
                            pick["phase_amplitude"] = np.max(
                                amp[phase_index : min(phase_index + post_idx * 3, next_pick)]
                            ).item()  ## peak amplitude

                    picks.append(pick)

    return picks

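# A minimal sketch of calling extract_picks on synthetic output, assuming a
# 3-class model (noise, P, S); the shapes and probability spikes are invented,
# and detect_peaks is assumed to report these isolated maxima:
#
#   preds = np.zeros((1, 3000, 1, 3), dtype=np.float32)
#   preds[0, 1000, 0, 1] = 0.9  # synthetic P peak at sample 1000
#   preds[0, 1800, 0, 2] = 0.8  # synthetic S peak at sample 1800
#   picks = extract_picks(preds, begin_times=["2021-01-01T00:00:00.000+00:00"], dt=0.01)
#   for p in picks:
#       print(p["phase_type"], p["phase_index"], p["phase_time"], p["phase_score"])
#   # expect a P pick 10 s and an S pick 18 s after the begin time
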
def extract_amplitude(data, picks, window_p=10, window_s=5, config=None):
    record = namedtuple("amplitude", ["p_amp", "s_amp"])
    dt = 0.01 if config is None else config.dt
    window_p = int(window_p / dt)
    window_s = int(window_s / dt)
    amps = []
    for i, (da, pi) in enumerate(zip(data, picks)):
        p_amp, s_amp = [], []
        for j in range(da.shape[1]):
            amp = np.max(np.abs(da[:, j, :]), axis=-1)
            # amp = np.median(np.abs(da[:,j,:]), axis=-1)
            # amp = np.linalg.norm(da[:,j,:], axis=-1)
            tmp = []
            for k in range(len(pi.p_idx[j]) - 1):
                tmp.append(np.max(amp[pi.p_idx[j][k] : min(pi.p_idx[j][k] + window_p, pi.p_idx[j][k + 1])]))
            if len(pi.p_idx[j]) >= 1:
                tmp.append(np.max(amp[pi.p_idx[j][-1] : pi.p_idx[j][-1] + window_p]))
            p_amp.append(tmp)
            tmp = []
            for k in range(len(pi.s_idx[j]) - 1):
                tmp.append(np.max(amp[pi.s_idx[j][k] : min(pi.s_idx[j][k] + window_s, pi.s_idx[j][k + 1])]))
            if len(pi.s_idx[j]) >= 1:
                tmp.append(np.max(amp[pi.s_idx[j][-1] : pi.s_idx[j][-1] + window_s]))
            s_amp.append(tmp)
        amps.append(record(p_amp, s_amp))
    return amps

def save_picks(picks, output_dir, amps=None, fname=None):
    if fname is None:
        fname = "picks.csv"

    int2s = lambda x: ",".join(["[" + ",".join(map(str, i)) + "]" for i in x])
    flt2s = lambda x: ",".join(["[" + ",".join(map("{:0.3f}".format, i)) + "]" for i in x])
    sci2s = lambda x: ",".join(["[" + ",".join(map("{:0.3e}".format, i)) + "]" for i in x])
    if amps is None:
        if hasattr(picks[0], "ps_idx"):
            with open(os.path.join(output_dir, fname), "w") as fp:
                fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tps_idx\tps_prob\n")
                for pick in picks:
                    fp.write(
                        f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{int2s(pick.ps_idx)}\t{flt2s(pick.ps_prob)}\n"
                    )
        else:
            with open(os.path.join(output_dir, fname), "w") as fp:
                fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\n")
                for pick in picks:
                    fp.write(
                        f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\n"
                    )
    else:
        with open(os.path.join(output_dir, fname), "w") as fp:
            fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tp_amp\ts_amp\n")
            for pick, amp in zip(picks, amps):
                fp.write(
                    f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{sci2s(amp.p_amp)}\t{sci2s(amp.s_amp)}\n"
                )

    return 0

def calc_timestamp(timestamp, sec):
    timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
    return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]

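# A quick sanity check (not from the repo's tests):
#
#   calc_timestamp("2021-01-01T00:00:00.000", 1.23)
#   # -> '2021-01-01T00:00:01.230' (the [:-3] slice trims microseconds to milliseconds)
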
def save_picks_json(picks, output_dir, dt=0.01, amps=None, fname=None):
    if fname is None:
        fname = "picks.json"

    picks_ = []
    if amps is None:
        for pick in picks:
            for idxs, probs in zip(pick.p_idx, pick.p_prob):
                for idx, prob in zip(idxs, probs):
                    picks_.append(
                        {
                            "id": pick.station_id,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob.astype(float),
                            "type": "p",
                        }
                    )
            for idxs, probs in zip(pick.s_idx, pick.s_prob):
                for idx, prob in zip(idxs, probs):
                    picks_.append(
                        {
                            "id": pick.station_id,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob.astype(float),
                            "type": "s",
                        }
                    )
    else:
        for pick, amplitude in zip(picks, amps):
            # Loop variables renamed from "amps" to "amps_" to avoid shadowing the function argument.
            for idxs, probs, amps_ in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
                for idx, prob, amp in zip(idxs, probs, amps_):
                    picks_.append(
                        {
                            "id": pick.station_id,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob.astype(float),
                            "amp": amp.astype(float),
                            "type": "p",
                        }
                    )
            for idxs, probs, amps_ in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
                for idx, prob, amp in zip(idxs, probs, amps_):
                    picks_.append(
                        {
                            "id": pick.station_id,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob.astype(float),
                            "amp": amp.astype(float),
                            "type": "s",
                        }
                    )
    with open(os.path.join(output_dir, fname), "w") as fp:
        json.dump(picks_, fp)

    return 0

def convert_true_picks(fname, itp, its, itps=None):
    true_picks = []
    if itps is None:
        record = namedtuple("phase", ["fname", "p_idx", "s_idx"])
        for i in range(len(fname)):
            true_picks.append(record(fname[i].decode(), itp[i], its[i]))
    else:
        record = namedtuple("phase", ["fname", "p_idx", "s_idx", "ps_idx"])
        for i in range(len(fname)):
            true_picks.append(record(fname[i].decode(), itp[i], its[i], itps[i]))

    return true_picks

def calc_metrics(nTP, nP, nT):
    """
    nTP: number of true positive picks
    nP: number of positive (predicted) picks
    nT: number of true picks
    """
    precision = nTP / nP
    recall = nTP / nT
    f1 = 2 * precision * recall / (precision + recall)
    return [precision, recall, f1]

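# A worked example with invented counts: 80 true positives out of 100 predicted
# picks and 90 true picks gives
#
#   precision, recall, f1 = calc_metrics(nTP=80, nP=100, nT=90)
#   # precision=0.800, recall=0.889, f1=0.842
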
def calc_performance(picks, true_picks, tol=3.0, dt=1.0):
    assert len(picks) == len(true_picks)
    logging.info("Total records: {}".format(len(picks)))

    count = lambda picks: sum([len(x) for x in picks])
    metrics = {}
    for phase in true_picks[0]._fields:
        if phase == "fname":
            continue
        true_positive, positive, true = 0, 0, 0
        residual = []
        for i in range(len(true_picks)):
            true += count(getattr(true_picks[i], phase))
            positive += count(getattr(picks[i], phase))
            # print(i, phase, getattr(picks[i], phase), getattr(true_picks[i], phase))
            diff = dt * (
                np.array(getattr(picks[i], phase))[:, np.newaxis, :]
                - np.array(getattr(true_picks[i], phase))[:, :, np.newaxis]
            )
            residual.extend(list(diff[np.abs(diff) <= tol]))
            true_positive += np.sum(np.abs(diff) <= tol)
        metrics[phase] = calc_metrics(true_positive, positive, true)

        logging.info(f"{phase}-phase:")
        logging.info(f"True={true}, Positive={positive}, True Positive={true_positive}")
        logging.info(f"Precision={metrics[phase][0]:.3f}, Recall={metrics[phase][1]:.3f}, F1={metrics[phase][2]:.3f}")
        logging.info(f"Residual mean={np.mean(residual):.4f}, std={np.std(residual):.4f}")

    return metrics

def save_prob_h5(probs, fnames, output_h5):
    if fnames is None:
        fnames = [f"{i:04d}" for i in range(len(probs))]
    else:
        # str.rstrip(".npz") strips trailing characters from the set {., n, p, z},
        # so the extension is removed with an explicit suffix check instead.
        fnames = [f.decode() if isinstance(f, bytes) else f for f in fnames]
        fnames = [f[: -len(".npz")] if f.endswith(".npz") else f for f in fnames]
    for prob, fname in zip(probs, fnames):
        output_h5.create_dataset(fname, data=prob, dtype="float32")
    return 0


def save_prob(probs, fnames, prob_dir):
    if fnames is None:
        fnames = [f"{i:04d}" for i in range(len(probs))]
    else:
        fnames = [f.decode() if isinstance(f, bytes) else f for f in fnames]
        fnames = [f[: -len(".npz")] if f.endswith(".npz") else f for f in fnames]
    for prob, fname in zip(probs, fnames):
        np.savez(os.path.join(prob_dir, fname + ".npz"), prob=prob)
    return 0
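# To inspect probabilities written by save_prob_h5 later, something like the
# following should work (the file and group names match the usage in predict.py
# below; the dataset name "0000" is only the default used when no file names
# are given):
#
#   import h5py
#   with h5py.File("results/result.h5", "r") as h5:
#       prob = h5["/prob/0000"][:]  # (Nt, Ns, Nc) probability traces
#       print(prob.shape)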
phasenet/predict.py
ADDED
@@ -0,0 +1,262 @@
import argparse
import logging
import multiprocessing
import os
import pickle
import time
from functools import partial

import h5py
import numpy as np
import pandas as pd
import tensorflow as tf
from data_reader import DataReader_mseed_array, DataReader_pred
from postprocess import (
    extract_amplitude,
    extract_picks,
    save_picks,
    save_picks_json,
    save_prob_h5,
)
from tqdm import tqdm
from visulization import plot_waveform

from model import ModelConfig, UNet

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def read_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=20, type=int, help="Batch size")
    parser.add_argument("--model_dir", help="Checkpoint directory (default: None)")
    parser.add_argument("--data_dir", default="", help="Input file directory")
    parser.add_argument("--data_list", default="", help="Input csv file")
    parser.add_argument("--hdf5_file", default="", help="Input hdf5 file")
    parser.add_argument("--hdf5_group", default="data", help="Data group name in hdf5 file")
    parser.add_argument("--result_dir", default="results", help="Output directory")
    parser.add_argument("--result_fname", default="picks", help="Output file")
    parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
    parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
    parser.add_argument("--mpd", default=50, type=float, help="Minimum peak distance")
    parser.add_argument("--amplitude", action="store_true", help="If set, also output amplitude values")
    parser.add_argument("--format", default="numpy", help="Input format")
    parser.add_argument("--s3_url", default="localhost:9000", help="S3 url")
    parser.add_argument("--stations", default="", help="Seismic station info")
    parser.add_argument("--plot_figure", action="store_true", help="If set, plot figures for test")
    parser.add_argument("--save_prob", action="store_true", help="If set, save prediction probabilities")
    parser.add_argument("--pre_sec", default=1, type=float, help="Window length before pick")
    parser.add_argument("--post_sec", default=4, type=float, help="Window length after pick")

    parser.add_argument("--highpass_filter", default=0.0, type=float, help="Highpass filter")
    parser.add_argument("--response_xml", default=None, type=str, help="Response xml file")
    parser.add_argument("--sampling_rate", default=100, type=float, help="Sampling rate")
    args = parser.parse_args()

    return args

def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    if log_dir is None:
        # Note: predict.py's parser defines no --log_dir; main() always passes
        # log_dir=args.result_dir, so this fallback expects a train.py-style args.
        log_dir = os.path.join(args.log_dir, "pred", current_time)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if (args.plot_figure == True) and (figure_dir is None):
        figure_dir = os.path.join(log_dir, "figures")
        if not os.path.exists(figure_dir):
            os.makedirs(figure_dir)
    if (args.save_prob == True) and (prob_dir is None):
        prob_dir = os.path.join(log_dir, "probs")
        if not os.path.exists(prob_dir):
            os.makedirs(prob_dir)
    if args.save_prob:
        h5 = h5py.File(os.path.join(args.result_dir, "result.h5"), "w", libver="latest")
        prob_h5 = h5.create_group("/prob")
    logging.info("Pred log: %s" % log_dir)
    logging.info("Dataset size: {}".format(data_reader.num_data))

    with tf.compat.v1.name_scope("Input_Batch"):
        if args.format == "mseed_array":
            batch_size = 1
        else:
            batch_size = args.batch_size
        dataset = data_reader.dataset(batch_size)
        batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    config = ModelConfig(X_shape=data_reader.X_shape)
    with open(os.path.join(log_dir, "config.log"), "w") as fp:
        fp.write("\n".join("%s: %s" % item for item in vars(config).items()))

    model = UNet(config=config, input_batch=batch, mode="pred")
    # model = UNet(config=config, mode="pred")
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        latest_check_point = tf.train.latest_checkpoint(args.model_dir)
        logging.info(f"restoring model {latest_check_point}")
        saver.restore(sess, latest_check_point)

        picks = []
        amps = [] if args.amplitude else None
        if args.plot_figure:
            multiprocessing.set_start_method("spawn")
            pool = multiprocessing.Pool(multiprocessing.cpu_count())

        for _ in tqdm(range(0, data_reader.num_data, batch_size), desc="Pred"):
            if args.amplitude:
                pred_batch, X_batch, amp_batch, fname_batch, t0_batch, station_batch = sess.run(
                    [model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
                    feed_dict={model.drop_rate: 0, model.is_training: False},
                )
                # X_batch, amp_batch, fname_batch, t0_batch = sess.run([batch[0], batch[1], batch[2], batch[3]])
            else:
                pred_batch, X_batch, fname_batch, t0_batch, station_batch = sess.run(
                    [model.preds, batch[0], batch[1], batch[2], batch[3]],
                    feed_dict={model.drop_rate: 0, model.is_training: False},
                )
                # X_batch, fname_batch, t0_batch = sess.run([model.preds, batch[0], batch[1], batch[2]])
            # pred_batch = []
            # for i in range(0, len(X_batch), 1):
            #     pred_batch.append(sess.run(model.preds, feed_dict={model.X: X_batch[i:i+1], model.drop_rate: 0, model.is_training: False}))
            # pred_batch = np.vstack(pred_batch)

            waveforms = None
            if args.amplitude:
                waveforms = amp_batch

            picks_ = extract_picks(
                preds=pred_batch,
                file_names=fname_batch,
                station_ids=station_batch,
                begin_times=t0_batch,
                config=args,
                waveforms=waveforms,
                use_amplitude=args.amplitude,
                dt=1.0 / args.sampling_rate,
            )

            picks.extend(picks_)

            ## save picks per file
            if len(fname_batch) == 1:
                df = pd.DataFrame(picks_)
                df = df[df["phase_index"] > 10]
                if not os.path.exists(os.path.join(args.result_dir, "picks")):
                    os.makedirs(os.path.join(args.result_dir, "picks"))
                columns = [
                    "station_id",
                    "begin_time",
                    "phase_index",
                    "phase_time",
                    "phase_score",
                    "phase_type",
                    "dt",
                ]
                if "phase_amplitude" in df.columns:  # only present when --amplitude is set
                    columns.insert(-1, "phase_amplitude")
                df = df[columns]
                fname_base = fname_batch[0].decode().split("/")[-1]
                if fname_base.endswith(".mseed"):  # rstrip(".mseed") would also eat trailing e/m/s/d characters
                    fname_base = fname_base[: -len(".mseed")]
                df.to_csv(
                    os.path.join(args.result_dir, "picks", fname_base + ".csv"),
                    index=False,
                )

            if args.plot_figure:
                if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
                    name = fname_batch.decode()
                    if name.endswith(".mseed"):
                        name = name[: -len(".mseed")]
                    fname_batch = [name + "_" + x.decode() for x in station_batch]
                else:
                    fname_batch = [x.decode() if isinstance(x, bytes) else x for x in fname_batch]
                pool.starmap(
                    partial(
                        plot_waveform,
                        figure_dir=figure_dir,
                    ),
                    # zip(X_batch, pred_batch, [x.decode() for x in fname_batch]),
                    zip(X_batch, pred_batch, fname_batch),
                )

            if args.save_prob:
                # save_prob(pred_batch, fname_batch, prob_dir=prob_dir)
                if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
                    name = fname_batch.decode()
                    if name.endswith(".mseed"):
                        name = name[: -len(".mseed")]
                    fname_batch = [name + "_" + x.decode() for x in station_batch]
                else:
                    # the isinstance guard keeps this from failing when plot_figure already decoded the names
                    fname_batch = [x.decode() if isinstance(x, bytes) else x for x in fname_batch]
                save_prob_h5(pred_batch, fname_batch, prob_h5)

        if len(picks) > 0:
            # save_picks(picks, args.result_dir, amps=amps, fname=args.result_fname+".csv")
            # save_picks_json(picks, args.result_dir, dt=data_reader.dt, amps=amps, fname=args.result_fname+".json")
            df = pd.DataFrame(picks)
            # df["fname"] = df["file_name"]
            # df["id"] = df["station_id"]
            # df["timestamp"] = df["phase_time"]
            # df["prob"] = df["phase_prob"]
            # df["type"] = df["phase_type"]

            base_columns = [
                "station_id",
                "begin_time",
                "phase_index",
                "phase_time",
                "phase_score",
                "phase_type",
                "file_name",
            ]
            if args.amplitude:
                base_columns.append("phase_amplitude")
                base_columns.append("phase_amp")
                df["phase_amp"] = df["phase_amplitude"]

            df = df[base_columns]
            df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)

            print(
                f"Done with {len(df[df['phase_type'] == 'P'])} P-picks and {len(df[df['phase_type'] == 'S'])} S-picks"
            )
        else:
            print("Done with 0 P-picks and 0 S-picks")
    return 0

def main(args):
    logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)

    with tf.compat.v1.name_scope("create_inputs"):
        if args.format == "mseed_array":
            data_reader = DataReader_mseed_array(
                data_dir=args.data_dir,
                data_list=args.data_list,
                stations=args.stations,
                amplitude=args.amplitude,
                highpass_filter=args.highpass_filter,
            )
        else:
            data_reader = DataReader_pred(
                format=args.format,
                data_dir=args.data_dir,
                data_list=args.data_list,
                hdf5_file=args.hdf5_file,
                hdf5_group=args.hdf5_group,
                amplitude=args.amplitude,
                highpass_filter=args.highpass_filter,
                response_xml=args.response_xml,
                sampling_rate=args.sampling_rate,
            )

    pred_fn(args, data_reader, log_dir=args.result_dir)

    return


if __name__ == "__main__":
    args = read_args()
    main(args)
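# A sketch of driving the predictor from Python instead of the shell; the paths
# are placeholders and the "mseed" format string assumes the data reader
# supports it:
#
#   import sys
#   sys.argv = [
#       "predict.py",
#       "--model_dir", "model/190703-214543",
#       "--data_dir", "mseed",
#       "--data_list", "mseed.csv",
#       "--format", "mseed",
#   ]
#   main(read_args())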
phasenet/slide_window.py
ADDED
@@ -0,0 +1,88 @@
import os
from collections import defaultdict, namedtuple
from datetime import datetime, timedelta
from json import dumps

import numpy as np
import tensorflow as tf

from model import ModelConfig, UNet
from postprocess import extract_amplitude, extract_picks
import pandas as pd
import obspy


tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))

# load model
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True

sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)

def calc_timestamp(timestamp, sec):
    timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
    return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]


def format_picks(picks, dt):
    # Note: this helper targets the legacy namedtuple pick format (p_idx/p_prob
    # attributes); the current extract_picks in postprocess.py already returns
    # dictionaries, so it is kept only for old-style pick objects.
    picks_ = []
    for pick in picks:
        for idxs, probs in zip(pick.p_idx, pick.p_prob):
            for idx, prob in zip(idxs, probs):
                picks_.append(
                    {
                        "id": pick.fname,
                        "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                        "prob": prob,
                        "type": "p",
                    }
                )
        for idxs, probs in zip(pick.s_idx, pick.s_prob):
            for idx, prob in zip(idxs, probs):
                picks_.append(
                    {
                        "id": pick.fname,
                        "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                        "prob": prob,
                        "type": "s",
                    }
                )
    return picks_

stream = obspy.read()
stream = stream.sort()  ## assume sorting yields E, N, Z channel order
assert len(stream) == 3
data = []
for trace in stream:
    data.append(trace.data)
data = np.array(data).T
assert data.shape[-1] == 3

# data_id = stream[0].get_id()[:-1]
# timestamp = stream[0].stats.starttime.datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]

data = np.stack([data for i in range(10)])  ## assume 10 windows
data = data[:, :, np.newaxis, :]  ## batch, nt, dummy_dim, channel
print(f"{data.shape = }")
data = (data - data.mean(axis=1, keepdims=True)) / data.std(axis=1, keepdims=True)

feed = {model.X: data, model.drop_rate: 0, model.is_training: False}
preds = sess.run(model.preds, feed_dict=feed)

# extract_picks now takes file_names/begin_times keywords and returns dictionaries,
# so the picks can go straight into a DataFrame without format_picks.
picks = extract_picks(preds, file_names=None, station_ids=None, begin_times=None, dt=0.01)

picks = pd.DataFrame(picks)
print(picks)
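# This script stacks ten copies of the same trace as a stand-in for windowing.
# An actual sliding-window split of a long record might look like the sketch
# below (window length and stride are assumptions, e.g. 3000 samples = 30 s at
# 100 Hz):
#
#   def sliding_windows(data, nt=3000, stride=1500):
#       """Split a (npts, 3) array into overlapping (nwin, nt, 1, 3) windows."""
#       windows = []
#       for i0 in range(0, data.shape[0] - nt + 1, stride):
#           windows.append(data[i0 : i0 + nt])
#       return np.array(windows)[:, :, np.newaxis, :]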
phasenet/test_app.py
ADDED
@@ -0,0 +1,47 @@
import requests
import obspy
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

### Start the model server first (the app is served with uvicorn, as in the Dockerfile):
### uvicorn --app-dir phasenet app:app --port 8000


def read_data(mseed):
    data = []
    mseed = mseed.sort()
    for c in ["E", "N", "Z"]:
        data.append(mseed.select(channel="*" + c)[0].data)
    return np.array(data).T


timestamp = lambda x: x.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]

## prepare some test data
mseed = obspy.read()
data = []
for i in range(1):
    data.append(read_data(mseed))
data = {
    "id": ["test01"],
    "timestamp": [timestamp(datetime.now())],
    "vec": np.array(data).tolist(),
    "dt": 0.01
}

## run prediction
print(data["id"])
resp = requests.get("http://localhost:8000/predict", json=data)
picks = resp.json()["picks"]
print(resp.json())


## plot figure
plt.figure()
plt.plot(np.array(data["vec"])[0, :, 1])  # the waveform is stored under "vec", not "data"
ylim = plt.ylim()
plt.plot([picks[0][0][0], picks[0][0][0]], ylim, label="P-phase")
plt.text(picks[0][0][0], ylim[1] * 0.9, f"{picks[0][1][0]:.2f}")
plt.plot([picks[0][2][0], picks[0][2][0]], ylim, label="S-phase")
plt.text(picks[0][2][0], ylim[1] * 0.9, f"{picks[0][3][0]:.2f}")  # S probability, assuming (p_idx, p_prob, s_idx, s_prob) ordering
plt.legend()
plt.savefig("test.png")
phasenet/train.py
ADDED
@@ -0,0 +1,246 @@
import numpy as np
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import argparse, os, time, logging
from tqdm import tqdm
import pandas as pd
import multiprocessing
from functools import partial
import pickle
from model import UNet, ModelConfig
from data_reader import DataReader_train, DataReader_test
from postprocess import extract_picks, save_picks, save_picks_json, extract_amplitude, convert_true_picks, calc_performance
from visulization import plot_waveform
from util import EMA, LMA


def read_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="train", help="train/train_valid/test/debug")
    parser.add_argument("--epochs", default=100, type=int, help="Number of epochs (default: 100)")
    parser.add_argument("--batch_size", default=20, type=int, help="Batch size")
    parser.add_argument("--learning_rate", default=0.01, type=float, help="Learning rate")
    parser.add_argument("--drop_rate", default=0.0, type=float, help="Dropout rate")
    parser.add_argument("--decay_step", default=-1, type=int, help="Decay step")
    parser.add_argument("--decay_rate", default=0.9, type=float, help="Decay rate")
    parser.add_argument("--momentum", default=0.9, type=float, help="Momentum")
    parser.add_argument("--optimizer", default="adam", help="Optimizer: adam, momentum")
    # Note: argparse's type=bool treats any non-empty string (even "False") as True.
    parser.add_argument("--summary", default=True, type=bool, help="Summary")
    parser.add_argument("--class_weights", nargs="+", default=[1, 1, 1], type=float, help="Class weights")
    parser.add_argument("--model_dir", default=None, help="Checkpoint directory (default: None)")
    parser.add_argument("--load_model", action="store_true", help="Load checkpoint")
    parser.add_argument("--log_dir", default="log", help="Log directory (default: log)")
    parser.add_argument("--num_plots", default=10, type=int, help="Number of training results to plot")
    parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
    parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
    parser.add_argument("--format", default="numpy", help="Input data format")
    parser.add_argument("--train_dir", default="./dataset/waveform_train/", help="Input file directory")
    parser.add_argument("--train_list", default="./dataset/waveform.csv", help="Input csv file")
    parser.add_argument("--valid_dir", default=None, help="Input file directory")
    parser.add_argument("--valid_list", default=None, help="Input csv file")
    parser.add_argument("--test_dir", default=None, help="Input file directory")
    parser.add_argument("--test_list", default=None, help="Input csv file")
    parser.add_argument("--result_dir", default="results", help="Result directory")
    parser.add_argument("--plot_figure", action="store_true", help="If set, plot figures for test")
    parser.add_argument("--save_prob", action="store_true", help="If set, save prediction probabilities")
    args = parser.parse_args()

    return args

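# train_fn below smooths the loss with util.EMA; util.py is not shown in this
# section, but from the call sites (EMA(0.9), train_loss(loss_batch),
# train_loss.value) it behaves like this sketch:
#
#   class EMA:
#       def __init__(self, decay):
#           self.decay = decay
#           self.value = None
#       def __call__(self, x):
#           if self.value is None:
#               self.value = x
#           else:
#               self.value = self.decay * self.value + (1 - self.decay) * x
#           return self.value
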
def train_fn(args, data_reader, data_reader_valid=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    log_dir = os.path.join(args.log_dir, current_time)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logging.info("Training log: {}".format(log_dir))
    model_dir = os.path.join(log_dir, 'models')
    os.makedirs(model_dir)

    figure_dir = os.path.join(log_dir, 'figures')
    if not os.path.exists(figure_dir):
        os.makedirs(figure_dir)

    config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
    if args.decay_step == -1:
        args.decay_step = data_reader.num_data // args.batch_size
    config.update_args(args)
    with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
        fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))

    with tf.compat.v1.name_scope('Input_Batch'):
        dataset = data_reader.dataset(args.batch_size, shuffle=True).repeat()
        batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
        if data_reader_valid is not None:
            dataset_valid = data_reader_valid.dataset(args.batch_size, shuffle=False).repeat()
            valid_batch = tf.compat.v1.data.make_one_shot_iterator(dataset_valid).get_next()

    model = UNet(config, input_batch=batch)
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:
        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        if args.model_dir is not None:
            logging.info("restoring models...")
            latest_check_point = tf.train.latest_checkpoint(args.model_dir)
            saver.restore(sess, latest_check_point)

        if args.plot_figure:
            multiprocessing.set_start_method('spawn')
            pool = multiprocessing.Pool(multiprocessing.cpu_count())

        flog = open(os.path.join(log_dir, 'loss.log'), 'w')
        train_loss = EMA(0.9)
        best_valid_loss = np.inf
        for epoch in range(args.epochs):
            progressbar = tqdm(range(0, data_reader.num_data, args.batch_size), desc="{}: epoch {}".format(log_dir.split("/")[-1], epoch))
            for _ in progressbar:
                loss_batch, _, _ = sess.run([model.loss, model.train_op, model.global_step],
                                            feed_dict={model.drop_rate: args.drop_rate, model.is_training: True})
                train_loss(loss_batch)
                progressbar.set_description("{}: epoch {}, loss={:.6f}, mean={:.6f}".format(log_dir.split("/")[-1], epoch, loss_batch, train_loss.value))
            flog.write("epoch: {}, mean loss: {}\n".format(epoch, train_loss.value))

            if data_reader_valid is not None:
                valid_loss = LMA()
                progressbar = tqdm(range(0, data_reader_valid.num_data, args.batch_size), desc="Valid:")
                for _ in progressbar:
                    loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run([model.loss, model.preds, valid_batch[0], valid_batch[1], valid_batch[2]],
                                                                                      feed_dict={model.drop_rate: 0, model.is_training: False})
                    valid_loss(loss_batch)
                    progressbar.set_description("valid, loss={:.6f}, mean={:.6f}".format(loss_batch, valid_loss.value))
                if valid_loss.value < best_valid_loss:
                    best_valid_loss = valid_loss.value
                    saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))
                flog.write("Valid: mean loss: {}\n".format(valid_loss.value))
            else:
                loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run([model.loss, model.preds, batch[0], batch[1], batch[2]],
                                                                                  feed_dict={model.drop_rate: 0, model.is_training: False})
                saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))

            if args.plot_figure:
                pool.starmap(
                    partial(
                        plot_waveform,
                        figure_dir=figure_dir,
                    ),
|
| 136 |
+
zip(X_batch, preds_batch, [x.decode() for x in fname_batch], Y_batch),
|
| 137 |
+
)
|
| 138 |
+
# plot_waveform(X_batch, preds_batch, fname_batch, label=Y_batch, figure_dir=figure_dir)
|
| 139 |
+
flog.flush()
|
| 140 |
+
|
| 141 |
+
flog.close()
|
| 142 |
+
|
| 143 |
+
return 0
|
| 144 |
+
|
| 145 |
+
def test_fn(args, data_reader):
|
| 146 |
+
current_time = time.strftime("%y%m%d-%H%M%S")
|
| 147 |
+
logging.info("{} log: {}".format(args.mode, current_time))
|
| 148 |
+
if args.model_dir is None:
|
| 149 |
+
logging.error(f"model_dir = None!")
|
| 150 |
+
return -1
|
| 151 |
+
if not os.path.exists(args.result_dir):
|
| 152 |
+
os.makedirs(args.result_dir)
|
| 153 |
+
figure_dir=os.path.join(args.result_dir, "figures")
|
| 154 |
+
if not os.path.exists(figure_dir):
|
| 155 |
+
os.makedirs(figure_dir)
|
| 156 |
+
|
| 157 |
+
config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
|
| 158 |
+
config.update_args(args)
|
| 159 |
+
with open(os.path.join(args.result_dir, 'config.log'), 'w') as fp:
|
| 160 |
+
fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
|
| 161 |
+
|
| 162 |
+
with tf.compat.v1.name_scope('Input_Batch'):
|
| 163 |
+
dataset = data_reader.dataset(args.batch_size, shuffle=False)
|
| 164 |
+
batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
|
| 165 |
+
|
| 166 |
+
model = UNet(config, input_batch=batch, mode='test')
|
| 167 |
+
sess_config = tf.compat.v1.ConfigProto()
|
| 168 |
+
sess_config.gpu_options.allow_growth = True
|
| 169 |
+
# sess_config.log_device_placement = False
|
| 170 |
+
|
| 171 |
+
with tf.compat.v1.Session(config=sess_config) as sess:
|
| 172 |
+
|
| 173 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
|
| 174 |
+
init = tf.compat.v1.global_variables_initializer()
|
| 175 |
+
sess.run(init)
|
| 176 |
+
|
| 177 |
+
logging.info("restoring models...")
|
| 178 |
+
latest_check_point = tf.train.latest_checkpoint(args.model_dir)
|
| 179 |
+
if latest_check_point is None:
|
| 180 |
+
logging.error(f"No models found in model_dir: {args.model_dir}")
|
| 181 |
+
return -1
|
| 182 |
+
saver.restore(sess, latest_check_point)
|
| 183 |
+
|
| 184 |
+
flog = open(os.path.join(args.result_dir, 'loss.log'), 'w')
|
| 185 |
+
test_loss = LMA()
|
| 186 |
+
progressbar = tqdm(range(0, data_reader.num_data, args.batch_size), desc=args.mode)
|
| 187 |
+
picks = []
|
| 188 |
+
true_picks = []
|
| 189 |
+
for _ in progressbar:
|
| 190 |
+
loss_batch, preds_batch, X_batch, Y_batch, fname_batch, itp_batch, its_batch \
|
| 191 |
+
= sess.run([model.loss, model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
|
| 192 |
+
feed_dict={model.drop_rate: 0, model.is_training: False})
|
| 193 |
+
|
| 194 |
+
test_loss(loss_batch)
|
| 195 |
+
progressbar.set_description("{}, loss={:.6f}, mean loss={:6f}".format(args.mode, loss_batch, test_loss.value))
|
| 196 |
+
|
| 197 |
+
picks_ = extract_picks(preds_batch, fname_batch)
|
| 198 |
+
picks.extend(picks_)
|
| 199 |
+
true_picks.extend(convert_true_picks(fname_batch, itp_batch, its_batch))
|
| 200 |
+
if args.plot_figure:
|
| 201 |
+
plot_waveform(data_reader.config, X_batch, preds_batch, label=Y_batch, fname=fname_batch,
|
| 202 |
+
itp=itp_batch, its=its_batch, figure_dir=figure_dir)
|
| 203 |
+
|
| 204 |
+
save_picks(picks, args.result_dir)
|
| 205 |
+
metrics = calc_performance(picks, true_picks, tol=3.0, dt=data_reader.config.dt)
|
| 206 |
+
flog.write("mean loss: {}\n".format(test_loss))
|
| 207 |
+
flog.close()
|
| 208 |
+
|
| 209 |
+
return 0
|
| 210 |
+
|
| 211 |
+
def main(args):
|
| 212 |
+
|
| 213 |
+
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
|
| 214 |
+
coord = tf.train.Coordinator()
|
| 215 |
+
|
| 216 |
+
if (args.mode == "train") or (args.mode == "train_valid"):
|
| 217 |
+
with tf.compat.v1.name_scope('create_inputs'):
|
| 218 |
+
data_reader = DataReader_train(format=args.format,
|
| 219 |
+
data_dir=args.train_dir,
|
| 220 |
+
data_list=args.train_list)
|
| 221 |
+
if args.mode == "train_valid":
|
| 222 |
+
data_reader_valid = DataReader_train(format=args.format,
|
| 223 |
+
data_dir=args.valid_dir,
|
| 224 |
+
data_list=args.valid_list)
|
| 225 |
+
logging.info("Dataset size: train {}, valid {}".format(data_reader.num_data, data_reader_valid.num_data))
|
| 226 |
+
else:
|
| 227 |
+
data_reader_valid = None
|
| 228 |
+
logging.info("Dataset size: train {}".format(data_reader.num_data))
|
| 229 |
+
train_fn(args, data_reader, data_reader_valid)
|
| 230 |
+
|
| 231 |
+
elif args.mode == "test":
|
| 232 |
+
with tf.compat.v1.name_scope('create_inputs'):
|
| 233 |
+
data_reader = DataReader_test(format=args.format,
|
| 234 |
+
data_dir=args.test_dir,
|
| 235 |
+
data_list=args.test_list)
|
| 236 |
+
test_fn(args, data_reader)
|
| 237 |
+
|
| 238 |
+
else:
|
| 239 |
+
print("mode should be: train, train_valid, or test")
|
| 240 |
+
|
| 241 |
+
return
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
if __name__ == '__main__':
|
| 245 |
+
args = read_args()
|
| 246 |
+
main(args)
|
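A hypothetical end-to-end invocation (not part of the commit; the dataset paths are placeholders, and `--mode` is assumed to be registered at the unshown top of `read_args`):

import sys
sys.argv = [
    "train.py", "--mode", "train_valid",
    "--train_dir", "./dataset/waveform_train/", "--train_list", "./dataset/waveform.csv",
    "--valid_dir", "./dataset/waveform_valid/", "--valid_list", "./dataset/waveform_valid.csv",
]
main(read_args())
# Checkpoints land in log/<timestamp>/models/ and the loss history in log/<timestamp>/loss.log.
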
phasenet/util.py
ADDED
@@ -0,0 +1,238 @@
from __future__ import division
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import os
from data_reader import DataConfig
from detect_peaks import detect_peaks
import logging

class EMA(object):
    """Exponential moving average: x <- alpha * x + (1 - alpha) * new_value."""
    def __init__(self, alpha):
        self.alpha = alpha
        self.x = 0.
        self.count = 0

    @property
    def value(self):
        return self.x

    def __call__(self, x):
        if self.count == 0:
            self.x = x
        else:
            self.x = self.alpha * self.x + (1 - self.alpha) * x
        self.count += 1
        return self.x

class LMA(object):
    """Linear (cumulative) moving average over all values seen so far."""
    def __init__(self):
        self.x = 0.
        self.count = 0

    @property
    def value(self):
        return self.x

    def __call__(self, x):
        if self.count == 0:
            self.x = x
        else:
            self.x += (x - self.x) / (self.count + 1)
        self.count += 1
        return self.x

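A quick sketch (not part of the commit) contrasting the two averagers: EMA(0.9) weights recent losses most heavily, while LMA converges to the plain running mean.

ema, lma = EMA(0.9), LMA()
for loss in [1.0, 0.5, 0.25]:
    ema(loss)
    lma(loss)
print(ema.value)   # 0.9*(0.9*1.0 + 0.1*0.5) + 0.1*0.25 = 0.88
print(lma.value)   # (1.0 + 0.5 + 0.25) / 3 ≈ 0.583
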
def detect_peaks_thread(i, pred, fname=None, result_dir=None, args=None):
    # Pick P/S arrivals as local maxima of the predicted probability traces;
    # mpd enforces a minimum separation of 0.5 s between picks.
    if args is None:
        itp, prob_p = detect_peaks(pred[i,:,0,1], mph=0.5, mpd=0.5/DataConfig().dt, show=False)
        its, prob_s = detect_peaks(pred[i,:,0,2], mph=0.5, mpd=0.5/DataConfig().dt, show=False)
    else:
        itp, prob_p = detect_peaks(pred[i,:,0,1], mph=args.tp_prob, mpd=0.5/DataConfig().dt, show=False)
        its, prob_s = detect_peaks(pred[i,:,0,2], mph=args.ts_prob, mpd=0.5/DataConfig().dt, show=False)
    if (fname is not None) and (result_dir is not None):
        # np.savez(os.path.join(result_dir, fname[i].decode().split('/')[-1]), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
        try:
            np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
        except FileNotFoundError:
            os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i].decode())), exist_ok=True)
            np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
    return [(itp, prob_p), (its, prob_s)]

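For intuition, a minimal sketch of what one call computes (assumptions: dt = 0.01 s, and the repo-local detect_peaks returns peak indices together with their heights, as the unpacking above implies):

import numpy as np
from detect_peaks import detect_peaks  # repo-local module imported above

prob = np.zeros(3000)
prob[1000] = 0.9                         # synthetic P-probability spike
itp, prob_p = detect_peaks(prob, mph=0.5, mpd=0.5/0.01)  # mpd = 50 samples = 0.5 s
# itp -> [1000], prob_p -> [0.9]; peaks below mph = 0.5 are discarded
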
def plot_result_thread(i, pred, X, Y=None, itp=None, its=None,
                       itp_pred=None, its_pred=None, fname=None, figure_dir=None):
    dt = DataConfig().dt
    t = np.arange(0, pred.shape[1]) * dt
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    plt.figure(i)
    plt.clf()
    # fig_size = plt.gcf().get_size_inches()
    # plt.gcf().set_size_inches(fig_size*[1, 1.2])
    plt.subplot(411)
    plt.plot(t, X[i, :, 0, 0], 'k', label='E', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(X[i, :, 0, 0])
    tmp_max = np.max(X[i, :, 0, 0])
    if (itp is not None) and (its is not None):
        for j in range(len(itp[i])):
            if j == 0:
                plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', label='P', linewidth=0.5)
            else:
                plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
        for j in range(len(its[i])):
            if j == 0:
                plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', label='S', linewidth=0.5)
            else:
                plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.subplot(412)
    plt.plot(t, X[i, :, 0, 1], 'k', label='N', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(X[i, :, 0, 1])
    tmp_max = np.max(X[i, :, 0, 1])
    if (itp is not None) and (its is not None):
        for j in range(len(itp[i])):
            plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
        for j in range(len(its[i])):
            plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.subplot(413)
    plt.plot(t, X[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(X[i, :, 0, 2])
    tmp_max = np.max(X[i, :, 0, 2])
    if (itp is not None) and (its is not None):
        for j in range(len(itp[i])):
            plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
        for j in range(len(its[i])):
            plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.subplot(414)
    if Y is not None:
        plt.plot(t, Y[i, :, 0, 1], 'b', label='P', linewidth=0.5)
        plt.plot(t, Y[i, :, 0, 2], 'r', label='S', linewidth=0.5)
    plt.plot(t, pred[i, :, 0, 1], '--g', label=r'$\hat{P}$', linewidth=0.5)
    plt.plot(t, pred[i, :, 0, 2], '-.m', label=r'$\hat{S}$', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    if (itp_pred is not None) and (its_pred is not None):
        for j in range(len(itp_pred)):
            plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--g', linewidth=0.5)
        for j in range(len(its_pred)):
            plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '-.m', linewidth=0.5)
    plt.ylim([-0.05, 1.05])
    plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.legend(loc='upper right', fontsize='small')
    plt.xlabel('Time (s)')
    plt.ylabel('Probability')

    plt.tight_layout()
    plt.gcf().align_labels()

    # os.path.splitext avoids the rstrip('.npz') pitfall (rstrip strips a character
    # set, not a suffix, so filenames ending in those letters could lose characters).
    out_png = os.path.splitext(fname[i].decode())[0] + '.png'
    try:
        plt.savefig(os.path.join(figure_dir, out_png), bbox_inches='tight')
    except FileNotFoundError:
        os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i].decode())), exist_ok=True)
        plt.savefig(os.path.join(figure_dir, out_png), bbox_inches='tight')
    plt.close(i)
    return 0

def postprocessing_thread(i, pred, X, Y=None, itp=None, its=None, fname=None, result_dir=None, figure_dir=None, args=None):
    (itp_pred, prob_p), (its_pred, prob_s) = detect_peaks_thread(i, pred, fname, result_dir, args)
    if (fname is not None) and (figure_dir is not None):
        plot_result_thread(i, pred, X, Y, itp, its, itp_pred, its_pred, fname, figure_dir)
    return [(itp_pred, prob_p), (its_pred, prob_s)]


def clean_queue(picks):
    # Drop zero padding from each per-sample pick list.
    clean = []
    for i in range(len(picks)):
        tmp = []
        for j in picks[i]:
            if j != 0:
                tmp.append(j)
        clean.append(tmp)
    return clean

def clean_queue_thread(picks):
    tmp = []
    for j in picks:
        if j != 0:
            tmp.append(j)
    return tmp


def metrics(TP, nP, nT):
    '''
    TP: number of true positives
    nP: number of positive (predicted) picks
    nT: number of true picks
    '''
    precision = TP / nP
    recall = TP / nT
    F1 = 2 * precision * recall / (precision + recall)
    return [precision, recall, F1]

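A worked example of these definitions (not part of the commit): 90 true positives out of 100 predicted picks, scored against 120 true picks:

precision, recall, F1 = metrics(TP=90, nP=100, nT=120)
# precision = 90/100 = 0.900
# recall    = 90/120 = 0.750
# F1        = 2*0.900*0.750 / (0.900 + 0.750) ≈ 0.818
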
def correct_picks(picks, true_p, true_s, tol):
    dt = DataConfig().dt
    if len(true_p) != len(true_s):
        print("The lengths of the true P and S pick lists are not the same")
    num = len(true_p)
    TP_p = 0; TP_s = 0; nP_p = 0; nP_s = 0; nT_p = 0; nT_s = 0
    diff_p = []; diff_s = []
    for i in range(num):
        nT_p += len(true_p[i])
        nT_s += len(true_s[i])
        nP_p += len(picks[i][0][0])
        nP_s += len(picks[i][1][0])

        if len(true_p[i]) > 1 or len(true_s[i]) > 1:
            print(i, picks[i], true_p[i], true_s[i])
        # Pairwise residuals (in samples) between every predicted and true pick;
        # a pick within tol seconds of a true pick counts as a true positive.
        tmp_p = np.array(picks[i][0][0]) - np.array(true_p[i])[:, np.newaxis]
        tmp_s = np.array(picks[i][1][0]) - np.array(true_s[i])[:, np.newaxis]
        TP_p += np.sum(np.abs(tmp_p) < tol/dt)
        TP_s += np.sum(np.abs(tmp_s) < tol/dt)
        # Residual histograms are restricted to a +/-0.5 s window.
        diff_p.append(tmp_p[np.abs(tmp_p) < 0.5/dt])
        diff_s.append(tmp_s[np.abs(tmp_s) < 0.5/dt])

    return [TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s]

def calculate_metrics(picks, itp, its, tol=0.1):
    TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s = correct_picks(picks, itp, its, tol)
    precision_p, recall_p, f1_p = metrics(TP_p, nP_p, nT_p)
    precision_s, recall_s, f1_s = metrics(TP_s, nP_s, nT_s)

    logging.info("Total records: {}".format(len(picks)))
    logging.info("P-phase:")
    logging.info("True={}, Predict={}, TruePositive={}".format(nT_p, nP_p, TP_p))
    logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_p, recall_p, f1_p))
    logging.info("S-phase:")
    logging.info("True={}, Predict={}, TruePositive={}".format(nT_s, nP_s, TP_s))
    logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_s, recall_s, f1_s))
    return [precision_p, recall_p, f1_p], [precision_s, recall_s, f1_s]
phasenet/visulization.py
ADDED
@@ -0,0 +1,481 @@
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
import numpy as np
import os


def plot_residual(diff_p, diff_s, diff_ps, tol, dt):
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.07, 0.95]
    plt.figure(figsize=(8, 3))
    plt.subplot(1, 3, 1)
    plt.hist(diff_p, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
    plt.ylabel("Number of picks")
    plt.xlabel("Residual (s)")
    plt.text(text_loc[0], text_loc[1], "(i)", horizontalalignment='left', verticalalignment='top',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.title("P-phase")
    plt.subplot(1, 3, 2)
    plt.hist(diff_s, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
    plt.xlabel("Residual (s)")
    plt.text(text_loc[0], text_loc[1], "(ii)", horizontalalignment='left', verticalalignment='top',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.title("S-phase")
    plt.subplot(1, 3, 3)
    plt.hist(diff_ps, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
    plt.xlabel("Residual (s)")
    plt.text(text_loc[0], text_loc[1], "(iii)", horizontalalignment='left', verticalalignment='top',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.title("PS-phase")
    plt.tight_layout()
    plt.savefig("residuals.png", dpi=300)
    plt.savefig("residuals.pdf")


# def plot_waveform(config, data, pred, label=None,
#                   itp=None, its=None, itps=None,
#                   itp_pred=None, its_pred=None, itps_pred=None,
#                   fname=None, figure_dir="./", epoch=0, max_fig=10):
#
#     dt = config.dt if hasattr(config, "dt") else 1.0
#     t = np.arange(0, pred.shape[1]) * dt
#     box = dict(boxstyle='round', facecolor='white', alpha=1)
#     text_loc = [0.05, 0.77]
#     if fname is None:
#         fname = [f"{epoch:03d}_{i:02d}" for i in range(len(data))]
#     else:
#         fname = [fname[i].decode().rstrip(".npz") for i in range(len(fname))]
#
#     for i in range(min(len(data), max_fig)):
#         plt.figure(i)
#
#         plt.subplot(411)
#         plt.plot(t, data[i, :, 0, 0], 'k', label='E', linewidth=0.5)
#         plt.autoscale(enable=True, axis='x', tight=True)
#         tmp_min = np.min(data[i, :, 0, 0])
#         tmp_max = np.max(data[i, :, 0, 0])
#         if (itp is not None) and (its is not None):
#             for j in range(len(itp[i])):
#                 lb = "P" if j == 0 else ""
#                 plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
#             for j in range(len(its[i])):
#                 lb = "S" if j == 0 else ""
#                 plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
#         if (itps is not None):
#             for j in range(len(itps[i])):
#                 lb = "PS" if j == 0 else ""
#                 plt.plot([itps[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
#         plt.ylabel('Amplitude')
#         plt.legend(loc='upper right', fontsize='small')
#         plt.gca().set_xticklabels([])
#         plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
#                  transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
#
#         plt.subplot(412)
#         plt.plot(t, data[i, :, 0, 1], 'k', label='N', linewidth=0.5)
#         plt.autoscale(enable=True, axis='x', tight=True)
#         tmp_min = np.min(data[i, :, 0, 1])
#         tmp_max = np.max(data[i, :, 0, 1])
#         if (itp is not None) and (its is not None):
#             for j in range(len(itp[i])):
#                 lb = "P" if j == 0 else ""
#                 plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
#             for j in range(len(its[i])):
#                 lb = "S" if j == 0 else ""
#                 plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
#         if (itps is not None):
#             for j in range(len(itps[i])):
#                 lb = "PS" if j == 0 else ""
#                 plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
#         plt.ylabel('Amplitude')
#         plt.legend(loc='upper right', fontsize='small')
#         plt.gca().set_xticklabels([])
#         plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
#                  transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
#
#         plt.subplot(413)
#         plt.plot(t, data[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
#         plt.autoscale(enable=True, axis='x', tight=True)
#         tmp_min = np.min(data[i, :, 0, 2])
#         tmp_max = np.max(data[i, :, 0, 2])
#         if (itp is not None) and (its is not None):
#             for j in range(len(itp[i])):
#                 lb = "P" if j == 0 else ""
#                 plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
#             for j in range(len(its[i])):
#                 lb = "S" if j == 0 else ""
#                 plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
#         if (itps is not None):
#             for j in range(len(itps[i])):
#                 lb = "PS" if j == 0 else ""
#                 plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
#         plt.ylabel('Amplitude')
#         plt.legend(loc='upper right', fontsize='small')
#         plt.gca().set_xticklabels([])
#         plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
#                  transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
#
#         plt.subplot(414)
#         if label is not None:
#             plt.plot(t, label[i, :, 0, 1], 'C0', label='P', linewidth=1)
#             plt.plot(t, label[i, :, 0, 2], 'C1', label='S', linewidth=1)
#             if label.shape[-1] == 4:
#                 plt.plot(t, label[i, :, 0, 3], 'C2', label='PS', linewidth=1)
#         plt.plot(t, pred[i, :, 0, 1], '--C0', label='$\hat{P}$', linewidth=1)
#         plt.plot(t, pred[i, :, 0, 2], '--C1', label='$\hat{S}$', linewidth=1)
#         if pred.shape[-1] == 4:
#             plt.plot(t, pred[i, :, 0, 3], '--C2', label='$\hat{PS}$', linewidth=1)
#         plt.autoscale(enable=True, axis='x', tight=True)
#         if (itp_pred is not None) and (its_pred is not None):
#             for j in range(len(itp_pred)):
#                 plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
#             for j in range(len(its_pred)):
#                 plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
#         if (itps_pred is not None):
#             for j in range(len(itps_pred)):
#                 plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
#         plt.ylim([-0.05, 1.05])
#         plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
#                  transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
#         plt.legend(loc='upper right', fontsize='small', ncol=2)
#         plt.xlabel('Time (s)')
#         plt.ylabel('Probability')
#         plt.tight_layout()
#         plt.gcf().align_labels()
#
#         try:
#             plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
#         except FileNotFoundError:
#             os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
#             plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
#
#         plt.close(i)
#     return 0

def plot_waveform(data, pred, fname, label=None,
                  itp=None, its=None, itps=None,
                  itp_pred=None, its_pred=None, itps_pred=None,
                  figure_dir="./", dt=0.01):
    # Plot one sample: three waveform components plus the predicted
    # (and optionally labeled) phase probabilities.
    t = np.arange(0, pred.shape[0]) * dt
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    plt.figure()

    plt.subplot(411)
    plt.plot(t, data[:, 0, 0], 'k', label='E', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(data[:, 0, 0])
    tmp_max = np.max(data[:, 0, 0])
    if (itp is not None) and (its is not None):
        for j in range(len(itp)):
            lb = "P" if j == 0 else ""
            plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
        for j in range(len(its)):  # fixed: was len(its[i]) with an undefined i
            lb = "S" if j == 0 else ""
            plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
    if (itps is not None):
        for j in range(len(itps)):
            lb = "PS" if j == 0 else ""
            plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)  # fixed: second endpoint was its[j]
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

    plt.subplot(412)
    plt.plot(t, data[:, 0, 1], 'k', label='N', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(data[:, 0, 1])
    tmp_max = np.max(data[:, 0, 1])
    if (itp is not None) and (its is not None):
        for j in range(len(itp)):
            lb = "P" if j == 0 else ""
            plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
        for j in range(len(its)):
            lb = "S" if j == 0 else ""
            plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
    if (itps is not None):
        for j in range(len(itps)):
            lb = "PS" if j == 0 else ""
            plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

    plt.subplot(413)
    plt.plot(t, data[:, 0, 2], 'k', label='Z', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(data[:, 0, 2])
    tmp_max = np.max(data[:, 0, 2])
    if (itp is not None) and (its is not None):
        for j in range(len(itp)):
            lb = "P" if j == 0 else ""
            plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
        for j in range(len(its)):
            lb = "S" if j == 0 else ""
            plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
    if (itps is not None):
        for j in range(len(itps)):
            lb = "PS" if j == 0 else ""
            plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

    plt.subplot(414)
    if label is not None:
        plt.plot(t, label[:, 0, 1], 'C0', label='P', linewidth=1)
        plt.plot(t, label[:, 0, 2], 'C1', label='S', linewidth=1)
        if label.shape[-1] == 4:
            plt.plot(t, label[:, 0, 3], 'C2', label='PS', linewidth=1)
    plt.plot(t, pred[:, 0, 1], '--C0', label=r'$\hat{P}$', linewidth=1)
    plt.plot(t, pred[:, 0, 2], '--C1', label=r'$\hat{S}$', linewidth=1)
    if pred.shape[-1] == 4:
        plt.plot(t, pred[:, 0, 3], '--C2', label=r'$\hat{PS}$', linewidth=1)
    plt.autoscale(enable=True, axis='x', tight=True)
    if (itp_pred is not None) and (its_pred is not None):
        for j in range(len(itp_pred)):
            plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
        for j in range(len(its_pred)):
            plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
    if (itps_pred is not None):
        for j in range(len(itps_pred)):
            plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
    plt.ylim([-0.05, 1.05])
    plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.legend(loc='upper right', fontsize='small', ncol=2)
    plt.xlabel('Time (s)')
    plt.ylabel('Probability')
    plt.tight_layout()
    plt.gcf().align_labels()

    try:
        plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')
    except FileNotFoundError:
        os.makedirs(os.path.dirname(os.path.join(figure_dir, fname)), exist_ok=True)
        plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')

    plt.close()
    return 0
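A minimal smoke test (not part of the commit); the shapes follow the (npts, 1, channels) layout the function indexes:

import numpy as np

npts = 3000
data = np.random.randn(npts, 1, 3)                       # E/N/Z components
pred = np.zeros((npts, 1, 3)); pred[1500, 0, 1] = 0.9    # synthetic P probability
plot_waveform(data, pred, fname="demo", itp_pred=[1500], its_pred=[],
              figure_dir="./figures")                    # writes ./figures/demo.png
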


def plot_array(config, data, pred, label=None,
               itp=None, its=None, itps=None,
               itp_pred=None, its_pred=None, itps_pred=None,
               fname=None, figure_dir="./", epoch=0):

    dt = config.dt if hasattr(config, "dt") else 1.0
    t = np.arange(0, pred.shape[1]) * dt
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.95]
    if fname is None:
        fname = [f"{epoch:03d}_{i:03d}" for i in range(len(data))]
    else:
        # os.path.splitext avoids the rstrip(".npz") pitfall (it strips a character set, not a suffix).
        fname = [os.path.splitext(fname[i].decode())[0] for i in range(len(fname))]

    for i in range(len(data)):
        plt.figure(i, figsize=(10, 5))
        plt.clf()

        plt.subplot(121)
        for j in range(data.shape[-2]):
            plt.plot(t, data[i, :, j, 0]/10 + j, 'k', label='E', linewidth=0.5)
        plt.autoscale(enable=True, axis='x', tight=True)
        tmp_min = np.min(data[i, :, 0, 0])
        tmp_max = np.max(data[i, :, 0, 0])
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        # plt.legend(loc='upper right', fontsize='small')
        # plt.gca().set_xticklabels([])
        plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center', verticalalignment="top",
                 transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)

        plt.subplot(122)
        for j in range(pred.shape[-2]):
            if label is not None:
                plt.plot(t, label[i, :, j, 1]+j, 'C2', label='P', linewidth=0.5)
                plt.plot(t, label[i, :, j, 2]+j, 'C3', label='S', linewidth=0.5)
                # plt.plot(t, label[i, :, j, 0]+j, 'C4', label='N', linewidth=0.5)
            plt.plot(t, pred[i, :, j, 1]+j, 'C0', label=r'$\hat{P}$', linewidth=1)
            plt.plot(t, pred[i, :, j, 2]+j, 'C1', label=r'$\hat{S}$', linewidth=1)
        plt.autoscale(enable=True, axis='x', tight=True)
        if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
            for j in range(len(itp_pred)):
                plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
            for j in range(len(its_pred)):
                plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
            for j in range(len(itps_pred)):
                plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
        # plt.ylim([-0.05, 1.05])
        plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center', verticalalignment="top",
                 transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)
        # plt.legend(loc='upper right', fontsize='small', ncol=2)
        plt.xlabel('Time (s)')
        plt.ylabel('Probability')
        plt.tight_layout()
        plt.gcf().align_labels()

        try:
            plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
        except FileNotFoundError:
            os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
            plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')

        plt.close(i)
    return 0


def plot_spectrogram(config, data, pred, label=None,
                     itp=None, its=None, itps=None,
                     itp_pred=None, its_pred=None, itps_pred=None,
                     time=None, freq=None,
                     fname=None, figure_dir="./", epoch=0):

    # dt = config.dt
    # df = config.df
    # t = np.arange(0, data.shape[1]) * dt
    # f = np.arange(0, data.shape[2]) * df
    t, f = time, freq
    dt = t[1] - t[0]
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.75]
    if fname is None:
        fname = [f"{i:03d}" for i in range(len(data))]
    elif type(fname[0]) is bytes:
        fname = [f.decode() for f in fname]

    numbers = ["(i)", "(ii)", "(iii)", "(iv)"]
    for i in range(len(data)):
        fig = plt.figure(i)
        # gs = fig.add_gridspec(4, 1)

        for j in range(3):
            # fig.add_subplot(gs[j, 0])
            plt.subplot(4, 1, j+1)
            # Real and imaginary spectrogram parts are stored in channels j and j+3.
            plt.pcolormesh(t, f, np.abs(data[i, :, :, j]+1j*data[i, :, :, j+3]).T, vmax=2*np.std(data[i, :, :, j]+1j*data[i, :, :, j+3]), cmap="jet", shading='auto')
            plt.autoscale(enable=True, axis='x', tight=True)
            plt.gca().set_xticklabels([])
            if j == 1:
                plt.ylabel('Frequency (Hz)')
            plt.text(text_loc[0], text_loc[1], numbers[j], horizontalalignment='center',
                     transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

        # fig.add_subplot(gs[-1, 0])
        plt.subplot(4, 1, 4)
        if label is not None:
            plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
            plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
            plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
        plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
        plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
        plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
        plt.plot(t, t*0, 'k', linewidth=1)
        plt.autoscale(enable=True, axis='x', tight=True)
        if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
            for j in range(len(itp_pred)):
                plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], ':C3', linewidth=1)
            for j in range(len(its_pred)):
                plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '-.C6', linewidth=1)
            for j in range(len(itps_pred)):
                plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C8', linewidth=1)
        plt.ylim([-0.05, 1.05])
        plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='center',
                 transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
        plt.legend(loc='upper right', fontsize='small', ncol=1)
        plt.xlabel('Time (s)')
        plt.ylabel('Probability')
        # plt.tight_layout()
        plt.gcf().align_labels()

        try:
            plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
        except FileNotFoundError:
            os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
            plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')

        plt.close(i)
    return 0


def plot_spectrogram_waveform(config, spectrogram, waveform, pred, label=None,
                              itp=None, its=None, itps=None, picks=None,
                              time=None, freq=None,
                              fname=None, figure_dir="./", epoch=0):

    # dt = config.dt
    # df = config.df
    # t = np.arange(0, spectrogram.shape[1]) * dt
    # f = np.arange(0, spectrogram.shape[2]) * df
    t, f = time, freq
    dt = t[1] - t[0]
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.02, 0.90]
    if fname is None:
        fname = [f"{i:03d}" for i in range(len(spectrogram))]
    elif type(fname[0]) is bytes:
        fname = [f.decode() for f in fname]

    numbers = ["(i)", "(ii)", "(iii)", "(iv)", "(v)", "(vi)", "(vii)"]
    for i in range(len(spectrogram)):
        fig = plt.figure(i, figsize=(6.4, 10))
        # gs = fig.add_gridspec(4, 1)

        for j in range(3):
            # fig.add_subplot(gs[j, 0])
            plt.subplot(7, 1, j*2+1)
            plt.plot(waveform[i, :, j], 'k', linewidth=0.5)
            plt.autoscale(enable=True, axis='x', tight=True)
            plt.gca().set_xticklabels([])
            plt.ylabel('')
            plt.text(text_loc[0], text_loc[1], numbers[j*2], horizontalalignment='left', verticalalignment='top',
                     transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

        for j in range(3):
            # fig.add_subplot(gs[j, 0])
            plt.subplot(7, 1, j*2+2)
            plt.pcolormesh(t, f, np.abs(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]).T, vmax=2*np.std(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]), cmap="jet", shading='auto')
            plt.autoscale(enable=True, axis='x', tight=True)
            plt.gca().set_xticklabels([])
            if j == 1:
                plt.ylabel('Frequency (Hz) or Amplitude')
            plt.text(text_loc[0], text_loc[1], numbers[j*2+1], horizontalalignment='left', verticalalignment='top',
                     transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

        # fig.add_subplot(gs[-1, 0])
        plt.subplot(7, 1, 7)
        if label is not None:
            plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
            plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
            plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
        plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
        plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
        plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
        plt.plot(t, t*0, 'k', linewidth=1)
        plt.autoscale(enable=True, axis='x', tight=True)
        plt.ylim([-0.05, 1.05])
        plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='left', verticalalignment='top',
                 transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
        plt.legend(loc='upper right', fontsize='small', ncol=1)
        plt.xlabel('Time (s)')
        plt.ylabel('Probability')
        # plt.tight_layout()
        plt.gcf().align_labels()

        try:
            plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
        except FileNotFoundError:
            os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
            plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')

        plt.close(i)
    return 0
requirements.txt
ADDED
@@ -0,0 +1,7 @@
tensorflow
matplotlib
pandas
tqdm
scipy
obspy
setup.py
ADDED
@@ -0,0 +1,116 @@
import io
import os
import re
import sys
from shutil import rmtree
from typing import Tuple, List

from setuptools import Command, find_packages, setup

# Package meta-data.
name = "PhaseNet"
description = "PhaseNet"
url = ""
email = "[email protected]"
author = "Weiqiang Zhu"
requires_python = ">=3.6.0"
current_dir = os.path.abspath(os.path.dirname(__file__))


def get_version():
    version_file = os.path.join(current_dir, "phasenet", "__init__.py")
    with io.open(version_file, encoding="utf-8") as f:
        return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1)


# What packages are required for this module to be executed?
try:
    with open(os.path.join(current_dir, "requirements.txt"), encoding="utf-8") as f:
        # Filter out blank lines so setuptools does not see empty requirements.
        required = [line for line in f.read().splitlines() if line.strip()]
except FileNotFoundError:
    required = []

# What packages are optional?
extras = {"test": ["pytest"]}

version = get_version()

about = {"__version__": version}


def get_test_requirements():
    requirements = ["pytest"]
    if sys.version_info < (3, 3):
        requirements.append("mock")
    return requirements


def get_long_description():
    # base_dir = os.path.abspath(os.path.dirname(__file__))
    # with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
    #     return f.read()
    return ""


class UploadCommand(Command):
    """Support setup.py upload."""

    description = "Build and publish the package."
    user_options: List[Tuple] = []

    @staticmethod
    def status(s):
        """Print things in bold."""
        print(s)

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        try:
            self.status("Removing previous builds...")
            rmtree(os.path.join(current_dir, "dist"))
        except OSError:
            pass

        self.status("Building Source and Wheel (universal) distribution...")
        os.system(f"{sys.executable} setup.py sdist bdist_wheel --universal")

        self.status("Uploading the package to PyPI via Twine...")
        os.system("twine upload dist/*")

        self.status("Pushing git tags...")
        os.system("git tag v{}".format(about["__version__"]))
        os.system("git push --tags")

        sys.exit()


setup(
    name=name,
    version=version,
    description=description,
    long_description=get_long_description(),
    long_description_content_type="text/markdown",
    author="Weiqiang Zhu",
    author_email="[email protected]",
    license="MIT",  # MIT to match the repository's LICENSE file
    url=url,
    packages=find_packages(exclude=["tests", "docs", "dataset", "model", "log"]),
    install_requires=required,
    extras_require=extras,
    classifiers=[
        "License :: OSI Approved :: MIT License",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "Operating System :: OS Independent",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Topic :: Software Development :: Libraries",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    cmdclass={"upload": UploadCommand},
)
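With this script in place, `pip install -e .` from the repository root installs the package in editable mode, and `python setup.py upload` runs the UploadCommand helper above (it assumes twine is installed and that you have push access to the git remote).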