#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ==========================================================================
# ____ __ _ _____ ____ ____
# | _ \ ___ ___ _ __ / _| __ _| | _____ | ____/ ___/ ___|
# | | | |/ _ \/ _ \ '_ \| |_ / _` | |/ / _ \ | _|| | | | _
# | |_| | __/ __/ |_) | _| (_| | < __/ | |__| |__| |_| |
# |____/ \___|\___| .__/|_| \__,_|_|\_\___| |_____\____\____|
# |_|
#
# --- Deepfake ECG Generator ---
# https://github.com/vlbthambawita/deepfake-ecg
# ==========================================================================
#
# DeepfakeECG GUI Application
# Copyright (C) 2023-2025 by Vajira Thambawita
# Copyright (C) 2025 by Thomas Dreibholz
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Contact:
# * Vajira Thambawita
# * Thomas Dreibholz
import datetime
import deepfakeecg
import ecg_plot
import getopt
import gradio
import io
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker
import neurokit2
import numpy
import pathlib
import random
import sys
import tempfile
import threading
import torch
import typing
import version
import PIL
import PIL.Image
from typing import Any, Final
# ###### Print log message ##################################################
def log(logstring : str) -> None:
    """Print a log message to stdout.

    The message is prefixed with the current local time
    (format %Y-%m-%dT%H:%M:%S.%f) and wrapped in ANSI escape codes
    that colour it blue.

    :param logstring: The message text.
    """
    timeStamp = datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f')
    print(''.join(('\x1b[34m', timeStamp, ': ', logstring, '\x1b[0m')))
# ###### DeepFakeECG Plus Session (session with web browser) ################
class Session:
    """State of one browser session of the DeepfakeECG GUI.

    Each session owns a temporary directory for its download files. The
    directory is created inside the application-wide temporary directory,
    i.e. the module-level TempDirectory.
    """

    # ###### Constructor #####################################################
    def __init__(self) -> None:
        self.Lock = threading.Lock()      # not referenced by the visible code
        self.Counter : int = 0            # not referenced by the visible code
        # Index (into Results) of the ECG currently selected in the gallery:
        self.Selected : int = 0
        # ECGs from the last generation run:
        self.Results : list[Any] = [ ]
        # Current analysis figure (closed before being replaced):
        self.Analysis : matplotlib.figure.Figure | None = None
        # deepfakeecg.DATA_ECG12 or deepfakeecg.DATA_ECG8 of Results:
        self.Type : int | None = None
        self.TempDirectory : tempfile.TemporaryDirectory = \
            tempfile.TemporaryDirectory(dir = TempDirectory.name)
        log(f'Prepared temporary directory {self.TempDirectory.name}')

    # ###### Destructor ######################################################
    def __del__(self) -> None:
        # __del__ also runs when __init__ failed before the temporary
        # directory was created (e.g. the application-wide TempDirectory is
        # not set up). Do not turn that into a second, confusing error.
        tempDirectory = getattr(self, 'TempDirectory', None)
        if tempDirectory is None:
            return
        log(f'Cleaning up temporary directory {tempDirectory.name}')
        tempDirectory.cleanup()
# Application-wide temporary directory. It is only declared here; the main
# program block assigns it before launching the GUI. Session() creates its
# per-session directories inside it.
TempDirectory : tempfile.TemporaryDirectory[Any]
# Active browser sessions, keyed by gradio's request.session_hash:
Sessions : dict[str, Session] = dict()
# ###### Initialize a new session ###########################################
def initializeSession(request: gradio.Request) -> None:
    """Create and register the state for a newly loaded browser session.

    :param request: gradio request; its session_hash identifies the session.
    """
    sessionHash = request.session_hash
    Sessions[sessionHash] = Session()
    log(f'Session "{sessionHash}" initialized => {len(Sessions)} active sessions')
# ###### Clean up a session #################################################
def cleanUpSession(request: gradio.Request) -> None:
    """Release the state of a browser session that was closed or refreshed.

    The session's analysis figure (if any) is closed, and the session is
    removed from Sessions. Unknown sessions are ignored; the clean-up is
    logged in either case.

    :param request: gradio request; its session_hash identifies the session.
    """
    sessionHash = request.session_hash
    session     = Sessions.get(sessionHash)
    if session is not None:
        if session.Analysis:
            matplotlib.pyplot.close(session.Analysis)
        del Sessions[sessionHash]
    log(f'Session "{sessionHash}" cleaned up => {len(Sessions)} active sessions')
# ###### Generate ECGs ######################################################
def _renderECGImage(ecg : torch.Tensor,
                    ecgType : int,
                    info : str) -> PIL.Image.Image:
    """Plot one generated ECG with ecg_plot and return the plot as an image.

    :param ecg:     One entry of the list returned by
                    deepfakeecg.generateDeepfakeECGs(...,
                    outputFormat = deepfakeecg.OUTPUT_TENSOR).
    :param ecgType: deepfakeecg.DATA_ECG12 or deepfakeecg.DATA_ECG8.
    :param info:    Scale information appended to the plot title.
    :returns:       The plot, rendered into an in-memory WebP buffer.
    """
    # 1. Convert to NumPy
    # 2. Remove the Timestamp column (0)
    # 3. Convert from µV to mV
    leads = ecg.t().detach().cpu().numpy()[1:] / 1000

    if ecgType == deepfakeecg.DATA_ECG12:
        # ------ ECG-12 ------------------------------------------------------
        ecg_plot.plot(leads,
                      title = 'ECG-12 – ' + info,
                      sample_rate = deepfakeecg.ECG_SAMPLING_RATE,
                      lead_index = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'III', 'aVR', 'aVL', 'aVF' ],
                      lead_order = [0, 1, 8, 9, 10, 11, 2, 3, 4, 5, 6, 7],
                      show_grid = True)
    else:
        # ------ ECG-8 -------------------------------------------------------
        ecg_plot.plot(leads,
                      title = 'ECG-8 – ' + info,
                      sample_rate = deepfakeecg.ECG_SAMPLING_RATE,
                      lead_index = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6' ],
                      lead_order = [0, 1, 2, 3, 4, 5, 6, 7],
                      show_grid = True)

    # ====== Generate WebP output ============================================
    # savefig()/close() act on the current pyplot figure, i.e. the one just
    # drawn by ecg_plot.plot(); closing it keeps pyplot from piling up figures.
    imageBuffer = io.BytesIO()
    plt.savefig(imageBuffer, format = 'webp')
    plt.close()
    return PIL.Image.open(imageBuffer)


def predict(numberOfECGs: int = 1,
            # ecgLengthInSeconds: int = 10,
            ecgTypeString: str = 'ECG-12',
            generatorModel: str = 'Default',
            request: gradio.Request = None) -> tuple[list[tuple[PIL.Image.Image,str]],matplotlib.figure.Figure]:
    """Generate deepfake ECGs for the calling browser session.

    The generated ECGs and their type are stored in the session for later
    analysis and downloads. The analysis of the first ECG is prepared as well.

    :param numberOfECGs:   Number of ECGs to generate.
    :param ecgTypeString:  'ECG-12' or 'ECG-8'; any other value falls back to
                           ECG-12 with a warning on stderr.
    :param generatorModel: Generator model name; not used by this function.
    :param request:        gradio request identifying the session.
    :returns: (gallery entries as (image, label) tuples, analysis figure of
              the first ECG, or None if nothing was generated).
    """
    ecgLengthInSeconds = 10
    log(f'Session "{request.session_hash}": Generate ECGs!')
    session = Sessions[request.session_hash]

    # ====== Set ECG type ====================================================
    ecgType = deepfakeecg.DATA_ECG12
    if ecgTypeString == 'ECG-8':
        ecgType = deepfakeecg.DATA_ECG8
    elif ecgTypeString != 'ECG-12':
        sys.stderr.write(f'WARNING: Invalid ecgTypeString {ecgTypeString}, using ECG-12!\n')

    # ====== Raise Locator.MAXTICKS, if necessary ============================
    matplotlib.ticker.Locator.MAXTICKS = \
        max(1000, ecgLengthInSeconds * deepfakeecg.ECG_SAMPLING_RATE)

    # ====== Generate the ECGs ===============================================
    # The slider value may arrive as a float; the generator needs a count.
    session.Results = \
        deepfakeecg.generateDeepfakeECGs(int(numberOfECGs),
                                         ecgType = ecgType,
                                         ecgLengthInSeconds = ecgLengthInSeconds,
                                         ecgScaleFactor = deepfakeecg.ECG_DEFAULT_SCALE_FACTOR,
                                         outputFormat = deepfakeecg.OUTPUT_TENSOR,
                                         showProgress = False,
                                         runOnDevice = runOnDevice)
    session.Type = ecgType

    # ====== Create a list of image/label tuples for gradio.Gallery ==========
    info : Final[str] = '25 mm/sec, 1 mV/10 mm'
    plotList : list[tuple[PIL.Image.Image,str]] = [
        (_renderECGImage(result, ecgType, info), f'ECG Number {ecgNumber}')
        for ecgNumber, result in enumerate(session.Results, start = 1)
    ]

    # ====== Prepare analysis results for first ECG ==========================
    # Close the previous analysis figure before replacing it (as analyze()
    # does); otherwise every regeneration leaves one more figure open.
    if session.Analysis is not None:
        plt.close(session.Analysis)
    session.Analysis = \
        plotAnalysis(session.Results[0]) if session.Results else None
    return (plotList, session.Analysis)
# ###### Plot the analysis ##################################################
def plotAnalysis(data : torch.Tensor) -> matplotlib.figure.Figure:
    """Analyse lead I of one generated ECG with NeuroKit2 and plot the result.

    :param data: One ECG as stored in Session.Results.
    :returns:    The current pyplot figure after neurokit2.ecg_plot(), resized
                 to 508 mm x 122 mm.
    """
    mmPerInch = 25.4

    # Transpose, drop the timestamp row (0), convert from µV to mV.
    # After that, row 0 is lead I (cf. the lead_index lists used for plotting).
    leads = data.t().detach().cpu().numpy()[1:] / 1000
    signals, info = neurokit2.ecg_process(leads[0], sampling_rate = deepfakeecg.ECG_SAMPLING_RATE)
    neurokit2.ecg_plot(signals, info)

    # Wide strip of 508 mm x 122 mm (this is not DIN A4):
    figure = matplotlib.pyplot.gcf()
    figure.set_size_inches(508 / mmPerInch, 122 / mmPerInch, forward=True)
    return figure
# ###### Generic download ###################################################
def download(request: gradio.Request,
             outputFormat: int) -> pathlib.Path | None:
    """Write the selected ECG of the calling session to a download file.

    The file is written into the session's temporary directory and named
    ECG-<number>.csv or ECG-<number>.pdf, where <number> is 1-based.

    :param request:      gradio request identifying the session.
    :param outputFormat: deepfakeecg.OUTPUT_CSV, deepfakeecg.OUTPUT_PDF or
                         deepfakeecg.OUTPUT_PDF_ANALYSIS.
    :returns: Path of the written file; None for any other outputFormat.
    """
    wantCSV = (outputFormat == deepfakeecg.OUTPUT_CSV)
    wantPDF = ( (outputFormat == deepfakeecg.OUTPUT_PDF) or
                (outputFormat == deepfakeecg.OUTPUT_PDF_ANALYSIS) )
    if not (wantCSV or wantPDF):
        return None

    session   = Sessions[request.session_hash]
    ecgResult = session.Results[session.Selected]
    ecgType   = session.Type
    ecgNumber = session.Selected + 1
    directory = pathlib.Path(session.TempDirectory.name)

    if wantCSV:
        fileName = directory / f'ECG-{ecgNumber}.csv'
        deepfakeecg.dataToCSV(ecgResult, ecgType, fileName)
        log(f'Session "{request.session_hash}": Download CSV file {fileName}')
    else:
        fileName = directory / f'ECG-{ecgNumber}.pdf'
        if ecgType == deepfakeecg.DATA_ECG12:
            outputLeads = [ 'I', 'II', 'III', 'aVL', 'aVR', 'aVF', 'V1', 'V2', 'V3', 'V4' , 'V5' , 'V6' ]
        else:
            outputLeads = [ 'I', 'II', 'V1', 'V2', 'V3', 'V4' , 'V5' , 'V6' ]
        deepfakeecg.dataToPDF(ecgResult, ecgType, outputLeads, fileName, outputFormat,
                              ecgNumber)
        log(f'Session "{request.session_hash}": Download PDF file {fileName}')
    return fileName
# ###### Download CSV #######################################################
def downloadCSV(request: gradio.Request) -> pathlib.Path | None:
    """gradio handler: write the selected ECG as CSV and return the file path."""
    return download(request = request, outputFormat = deepfakeecg.OUTPUT_CSV)
# ###### Download PDF #######################################################
def downloadPDF(request: gradio.Request) -> pathlib.Path | None:
    """gradio handler: write the selected ECG as PDF and return the file path."""
    return download(request = request, outputFormat = deepfakeecg.OUTPUT_PDF)
# ###### Download PDF with analysis ########################################
def downloadPDFwithAnalysis(request: gradio.Request) -> pathlib.Path | None:
    """gradio handler: write the selected ECG as PDF in the
    deepfakeecg.OUTPUT_PDF_ANALYSIS format and return the file path."""
    return download(request = request, outputFormat = deepfakeecg.OUTPUT_PDF_ANALYSIS)
# ###### Analyze the selected ECG ###########################################
def analyze(event: gradio.SelectData,
            request: gradio.Request) -> matplotlib.figure.Figure:
    """gradio handler for a gallery selection: analyse the chosen ECG.

    Stores the selection in the session (used by the downloads), closes the
    previous analysis figure and returns the new one.

    :param event:   gradio selection event; event.index is the gallery index.
    :param request: gradio request identifying the session.
    """
    session  = Sessions[request.session_hash]
    selected = session.Selected = event.index
    log(f'Session "{request.session_hash}": Analyze ECG #{selected + 1}!')
    ecg = session.Results[selected]

    previousFigure = session.Analysis
    if previousFigure:
        matplotlib.pyplot.close(previousFigure)

    session.Analysis = plotAnalysis(ecg)
    return session.Analysis
# ###### Print usage and exit ###############################################
def usage(exitCode : int = 0) -> typing.NoReturn:
    """Print the command-line synopsis and terminate the program.

    The function never returns, so its result type is NoReturn rather than
    str. The synopsis goes to stdout for a regular request (exitCode 0) and
    to stderr when it reports a usage error.

    :param exitCode: Exit status passed to sys.exit().
    """
    stream = sys.stdout if exitCode == 0 else sys.stderr
    stream.write('Usage: ' + sys.argv[0] + ' [-d|--device cpu|cuda] [-v|--version]\n')
    sys.exit(exitCode)
# ###### Main program #######################################################
# ====== Initialise =========================================================
# Default compute device; the -d/--device option may override it.
runOnDevice: str
if torch.cuda.is_available():
    runOnDevice = 'cuda'
else:
    runOnDevice = 'cpu'
# Custom style sheet for the GUI; it is passed as gradio.Blocks(css = css)
# when the GUI is created. NOTE(review): the program-header/-logo/-title
# classes are not referenced by any markup visible in this file -- presumably
# they style the header HTML; verify.
css = r"""
div {
background-image: url("https://www.nntb.no/~dreibh/graphics/backgrounds/background-essen.png");
}
/* ###### General Settings ############################################## */
html, body {
height: 100%;
margin: 0;
padding: 0;
font-family: sans-serif;
font-size: small;
background-color: #E3E3E3; /* Simula background colour: #E3E3E3 */
background-image: url("https://www.nntb.no/~dreibh/graphics/backgrounds/background-wiehl.png");
}
/* ###### Header ######################################################## */
div.program-header {
background-image: none;
background-color: #F15D22; /* Simula header colour: #F15D22 */
height: 7.5vh;
display: flex;
justify-content: space-between;
}
div.program-logo-left {
width: 12.5vw;
float: left;
display: flex;
padding: 0% 1%;
align-items: center;
background: white;
}
div.program-logo-right {
width: 12.5vw;
float: right;
display: flex;
padding: 0% 1%;
align-items: center;
background: white;
}
div.program-title {
display: flex;
align-items: center;
padding: 0% 1%;
background-image: none;
background-color: #F15D22; /* Simula header colour: #F15D22 */
font-family: "Open Sans", sans-serif;
font-size: 4vh;
font-weight: bold;
}
img.program-logo-image {
min-height: 4vh;
max-height: 4vh;
margin-left: auto;
margin-right: auto;
}
"""
# ====== Check arguments ====================================================
# -v/--version is only recorded while parsing. The version output happens
# after all options are processed, so that "Device:" also reflects a
# -d/--device option given after -v.
showVersion : bool = False
try:
    options, args = getopt.gnu_getopt(
        sys.argv[1:],
        'd:v',
        [
            'device=',
            'version'
        ])
    for option, optarg in options:
        if option in ( '-d', '--device' ):
            runOnDevice = optarg
        elif option in ( '-v', '--version' ):
            showVersion = True
        else:
            sys.stderr.write('ERROR: Invalid option ' + option + '!\n')
            sys.exit(1)
except getopt.GetoptError as error:
    sys.stderr.write('ERROR: ' + str(error) + '\n')
    usage(1)
if len(args) > 0:
    usage(1)

# ====== Print version information, if requested ============================
if showVersion:
    # torch.version.cuda is None for PyTorch builds without CUDA support;
    # concatenating it directly would raise a TypeError.
    cudaVersion = torch.version.cuda if torch.version.cuda is not None else 'n/a'
    sys.stdout.write('PyTorch version: ' + torch.__version__ + '\n')
    sys.stdout.write('CUDA version: ' + cudaVersion + '\n')
    sys.stdout.write('CUDA available: ' + ('yes' if torch.cuda.is_available() else 'no') + '\n')
    sys.stdout.write('Device: ' + runOnDevice + '\n')
    # Printing the version is a successful run:
    sys.exit(0)
# ====== Create GUI =========================================================
with gradio.Blocks(css = css,
                   theme = gradio.themes.Glass(secondary_hue=gradio.themes.colors.blue),
                   analytics_enabled = False,
                   fill_height = True,
                   fill_width = True) as gui:

    # ====== Session handling ================================================
    # Session initialization, to be called when page is loaded.
    # The returned event is kept so that the initial ECG generation (see
    # "Run on startup" below) can be chained to it. Registered as two
    # independent load events, predict() could start before the session
    # exists in Sessions and fail with a KeyError.
    sessionInitialized = gui.load(initializeSession)

    # ====== Header ==========================================================
    with gradio.Row(height = '10vh', min_height = '10vh', max_height = '10vh'):
        # NOTE(review): the header markup is empty here -- presumably it should
        # hold the program title/logos styled by the css; verify.
        big_block = gradio.HTML("""
""")

    # gradio.Markdown('## Settings')
    with gradio.Row(height = '10vh', min_height = '10vh', max_height = '10vh'):
        sliderNumberOfECGs = gradio.Slider(1, 100, label="Number of ECGs", step = 1, value = 4, interactive = True)
        # sliderLengthInSeconds = gradio.Slider(5, 60, label="Length (s)", step = 5, value = 10, interactive = True)
        dropdownType = gradio.Dropdown( [ 'ECG-12', 'ECG-8' ], label = 'ECG Type', interactive = True)
        dropdownGeneratorModel = gradio.Dropdown( [ 'Default' ], label = 'Generator Model', interactive = True)
        with gradio.Column():
            buttonGenerate = gradio.Button("Generate ECGs!")
            # buttonAnalyze = gradio.Button("Analyze this ECG!")
            with gradio.Row():
                buttonCSV = gradio.DownloadButton("Download CSV")
                buttonCSV_hidden = gradio.DownloadButton(visible=False, elem_id="download_csv_hidden")
                buttonPDF = gradio.DownloadButton("Download ECG PDF")
                buttonPDF_hidden = gradio.DownloadButton(visible=False, elem_id="download_pdf_hidden")
                buttonPDFwAnalysis = gradio.DownloadButton("Download ECG+Analysis PDF")
                buttonPDFwAnalysis_hidden = gradio.DownloadButton(visible=False, elem_id="download_pdfwanalysis_hidden")

    # gradio.Markdown('## Output')
    with gradio.Row(): # height = '24vh', min_height = '24vh', max_height = '24vh'):
        outputGallery = gradio.Gallery(label = 'Generated ECGs',
                                       columns = 8,
                                       # rows = 1,
                                       height = 'auto',
                                       object_fit = 'contain',
                                       show_label = True,
                                       allow_preview = True,
                                       preview = False
                                       )
    with gradio.Row(): # height = '24vh', min_height = '24vh', max_height = '24vh'):
        analysisOutput = gradio.Plot(label = 'Analysis')

    # ====== Add click event handling for "Generate" button ==================
    buttonGenerate.click(predict,
                         inputs = [ sliderNumberOfECGs,
                                    # sliderLengthInSeconds,
                                    dropdownType,
                                    dropdownGeneratorModel ],
                         outputs = [ outputGallery, analysisOutput ]
                         )

    # ====== Add click event handling for "Analyze" button ===================
    outputGallery.select(analyze,
                         inputs = [ ],
                         outputs = [ analysisOutput ]
                         )

    # ====== Add click event handling for download buttons ===================
    # Using hidden button and JavaScript, to generate download file on-the-fly:
    # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
    buttonCSV.click(fn = downloadCSV,
                    inputs = None,
                    outputs = [ buttonCSV_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_csv_hidden').click()")
    buttonPDF.click(fn = downloadPDF,
                    inputs = None,
                    outputs = [ buttonPDF_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_pdf_hidden').click()")
    buttonPDFwAnalysis.click(fn = downloadPDFwithAnalysis,
                             inputs = None,
                             outputs = [ buttonPDFwAnalysis_hidden ]).then(
        fn = None, inputs = None, outputs = None,
        js = "() => document.querySelector('#download_pdfwanalysis_hidden').click()")

    # ====== Run on startup ==================================================
    # Generate the initial ECGs only after initializeSession() has completed.
    sessionInitialized.then(predict,
                            inputs = [ sliderNumberOfECGs,
                                       # sliderLengthInSeconds,
                                       dropdownType,
                                       dropdownGeneratorModel ],
                            outputs = [ outputGallery, analysisOutput ]
                            )

    # ====== Session handling ================================================
    # Session clean-up, to be called when page is closed/refreshed.
    # This must be the last statement of gr.Blocks(), due to bug #12159!
    # -> https://github.com/gradio-app/gradio/issues/12159
    gui.unload(cleanUpSession)
# ====== Run the GUI ========================================================
if __name__ == "__main__":
    # ------ Prepare temporary directory -------------------------------------
    # Module-level global; Session() creates per-session directories in it.
    TempDirectory = tempfile.TemporaryDirectory(prefix = 'DeepFakeECGPlus-')
    log(f'Prepared temporary directory {TempDirectory.name}')
    try:
        # ------ Run the GUI, with downloads from temporary directory allowed -
        gui.launch(allowed_paths = [ TempDirectory.name ], debug = True)
    finally:
        # ------ Clean up ----------------------------------------------------
        # Also executed when gui.launch() raises, so the temporary directory
        # is always removed and the clean-up is logged.
        log(f'Cleaning up temporary directory {TempDirectory.name}')
        TempDirectory.cleanup()
    log('Done!')