from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
from typing import Never
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from .types import ConfigurationDataType, ECGDataType, RhythmStripsConfig
from .utils.attention import (
AbstractAttentionMap,
BackgroundAttentionMap,
IntervalAttentionMap,
LineColorAttentionMap,
attention_map_from_indices_annotations,
attention_map_from_time_annotations,
)
from .utils.data import (
_apply_configuration,
_numpy_to_dataframe,
_resolve_configuration,
_validate_input_lead_names,
)
from .utils.plot import (
LEFT_MARGIN_MM,
MM_PER_INCH,
RIGHT_MARGIN_MM,
_adjust_row_distance,
_compute_figure_size,
_compute_row_offsets,
_nice_tick_step,
_plot_attention_color_scale,
_plot_grid,
_plot_row,
_print_information,
_RenderContext,
_validate_time_axis_config,
)
# Public API of this module. The attention-map names are re-exported from
# ``.utils.attention`` (see the import block above); ``ECGPlotter`` and
# ``ECGStats`` are defined below.
# NOTE(review): "ECGInformation" is listed here but is neither imported in the
# import block above nor defined in the visible part of this module (``plot``
# only uses it in an annotation, which ``from __future__ import annotations``
# leaves unevaluated). Unless it is defined further down, ``from <module> import *``
# raises AttributeError — confirm where it lives (presumably ``.types``) and import it.
__all__ = [
    "AbstractAttentionMap",
    "BackgroundAttentionMap",
    "ECGInformation",
    "ECGPlotter",
    "ECGStats",
    "IntervalAttentionMap",
    "LineColorAttentionMap",
    "attention_map_from_indices_annotations",
    "attention_map_from_time_annotations",
]
[docs]
@dataclass
class ECGStats:
    """Computed ECG diagnostic statistics to be printed on the plot.

    All fields default to ``None`` (not shown). Any field that is set will be
    displayed in the top-right corner when ``print_information=True``, arranged
    in columns of three rows.

    Instances are handed to :meth:`ECGPlotter.plot` through its ``stats``
    argument. Because this is a plain ``@dataclass``, the field order below is
    also the positional order accepted by the generated ``__init__``; prefer
    keyword arguments when constructing it.

    Parameters
    ----------
    bpm : float, optional
        Heart rate in beats per minute.
    snr : float, optional
        Signal-to-noise ratio in dB.
    rr_interval_ms : float, optional
        Mean RR interval (beat-to-beat) in milliseconds.
    hrv_ms : float, optional
        Heart-rate variability — statistical spread of RR intervals (ms).
    pr_interval_ms : float, optional
        PR interval in milliseconds.
    qrs_duration_ms : float, optional
        QRS complex duration in milliseconds.
    qt_interval_ms : float, optional
        QT interval in milliseconds.
    qtc_interval_ms : float, optional
        Corrected QT interval (QTc) in milliseconds.
    p_axis_deg : float, optional
        P-wave axis in degrees.
    qrs_axis_deg : float, optional
        QRS axis in degrees.
    t_axis_deg : float, optional
        T-wave axis in degrees.
    """

    # Rate and signal quality. ``None`` means "not shown".
    bpm: float | None = None
    snr: float | None = None
    # Intervals and durations, all in milliseconds.
    rr_interval_ms: float | None = None
    hrv_ms: float | None = None
    pr_interval_ms: float | None = None
    qrs_duration_ms: float | None = None
    qt_interval_ms: float | None = None
    qtc_interval_ms: float | None = None
    # Electrical axes, all in degrees.
    p_axis_deg: float | None = None
    qrs_axis_deg: float | None = None
    t_axis_deg: float | None = None
[docs]
class _InvalidPlotterArgumentError(ValueError, AssertionError):
    """Invalid :class:`ECGPlotter` constructor argument.

    Inherits from :class:`ValueError`, the conventional type for a bad argument
    value. It also inherits from :class:`AssertionError`, which is what the
    constructor raised when it validated its arguments with ``assert``
    statements, so existing ``except AssertionError`` handlers keep working.
    """


def _require(condition: bool, message: str) -> None:
    """Raise :class:`_InvalidPlotterArgumentError` with ``message`` unless ``condition`` is true.

    Unlike an ``assert`` statement, this check still runs when Python is started
    with ``-O``.
    """
    if not condition:
        raise _InvalidPlotterArgumentError(message)


class ECGPlotter:
    """Generate paper-like ECG plots from signal data.

    Instantiate with visual parameters (speed, voltage scale, grid style, etc.),
    then call :meth:`plot` to render one or more ECGs using the same configuration.

    Parameters
    ----------
    grid_mode : {'cm', None}, optional
        Grid style to overlay on the plot. ``'cm'`` draws lines every 0.1 cm (= 1 mm),
        with every 5th line slightly thicker. Pass ``None`` to disable the grid.
        By default ``'cm'``.
    speed : float, optional
        The speed of the plot in mm/s, by default 25.0
    voltage : float, optional
        The space (in mm) corresponding to 1 mV, by default 10.0
    row_distance : float, optional
        Distance between the zero-lines of consecutive rows, expressed in mV, by default 3.0
    line_width : float, optional
        Thickness of the ECG signal lines (and calibration pulse) in points, by default 0.5
    grid_color : str, optional
        Color of the grid lines. Any matplotlib color string is accepted (e.g. '#f4aaaa',
        'lightgray', 'gray'). By default '#f4aaaa' (light ECG-paper red).
    print_information : bool, optional
        Whether to print diagnostic parameters (speed, voltage, sampling frequency)
        and any extra metadata in the corners of the figure, by default False.
    print_available_leads : bool, optional
        Whether to include the list of available leads in the diagnostic line.
        Only has an effect when ``print_information=True``. By default False.
    show_time_axis : bool, optional
        Whether to show the time axis (x-axis ticks and spine) at the bottom of the
        figure, by default False.
    show_calibration : bool, optional
        Whether to show the calibration pulse in the left margin of each row, by default True.
    show_leads_labels : bool, optional
        Whether to print lead names onto the plot, by default True.
    show_separators : bool, optional
        Whether to draw short vertical tick marks at the boundary between adjacent
        lead columns within each row, by default True.
    disconnect_segments : bool, optional
        If True, the last sample of each segment is set to NaN so that adjacent
        segments are not visually connected in the plot. By default True.
    show_dpi : int, optional
        DPI applied to the figure before displaying it when ``show=True``.
        Has no effect when ``show=False``. By default 300.

    Raises
    ------
    ValueError
        If any constructor argument is invalid (wrong type, non-positive number,
        empty color, unknown grid mode). The raised error is also an
        ``AssertionError`` for backward compatibility with the former
        ``assert``-based checks.
    """

    def __init__(
        self,
        grid_mode: Literal["cm"] | None = "cm",
        speed: float = 25.0,
        voltage: float = 10.0,
        row_distance: float = 3.0,
        line_width: float = 0.5,
        grid_color: str = "#f4aaaa",
        print_information: bool = False,
        print_available_leads: bool = False,
        show_time_axis: bool = False,
        show_calibration: bool = True,
        show_leads_labels: bool = True,
        show_separators: bool = True,
        disconnect_segments: bool = True,
        show_dpi: int = 300,
    ):
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        _require(grid_mode in (None, "cm"), "grid_mode must be None or 'cm'")
        _require(isinstance(speed, (int, float)) and speed > 0, "speed must be a positive number")
        _require(isinstance(voltage, (int, float)) and voltage > 0, "voltage must be a positive number")
        _require(
            isinstance(row_distance, (int, float)) and row_distance > 0,
            "row_distance must be a positive number",
        )
        _require(isinstance(line_width, (int, float)) and line_width > 0, "line_width must be a positive number")
        _require(isinstance(grid_color, str) and len(grid_color) > 0, "grid_color must be a non-empty string")
        _require(isinstance(show_dpi, int) and show_dpi > 0, "show_dpi must be a positive integer")
        self.grid_mode = grid_mode
        self.speed = speed
        self.voltage = voltage
        self.row_distance = row_distance
        self.line_width = line_width
        self.grid_color = grid_color
        self.print_information = print_information
        self.print_available_leads = print_available_leads
        self.show_time_axis = show_time_axis
        self.show_calibration = show_calibration
        self.show_leads_labels = show_leads_labels
        self.show_separators = show_separators
        self.disconnect_segments = disconnect_segments
        self.show_dpi = show_dpi

    def plot(
        self,
        ecg_data: ECGDataType,
        configuration: ConfigurationDataType | None = None,
        sampling_frequency: float = 500.0,
        show: bool = True,
        information: ECGInformation | None = None,
        stats: ECGStats | None = None,
        attention_map: AbstractAttentionMap | None = None,
        rhythm_strips: RhythmStripsConfig | None = None,
    ) -> Figure:
        """Plot the ECG in ``ecg_data`` using the plotting configuration specified in ``configuration``.

        Parameters
        ----------
        ecg_data : ECGDataType
            ECG signal data to plot: a tuple ``(signals, lead_names)`` or a pandas DataFrame
            with one column per lead.
        configuration : ConfigurationDataType | None, optional
            The plotting configuration to be used. By default None, meaning that each lead is plotted
            on its own row for its entire duration. An empty configuration is rejected.
        sampling_frequency : float, optional
            The sampling frequency of the ECG data in Hz, by default 500.0
        show : bool, optional
            Whether to show the plot, by default True
        information : ECGInformation | None, optional
            Patient and recording metadata. When ``self.print_information`` is True, the
            hospital, patient name and date are printed above the first ECG row, and
            the machine model is printed in the bottom-right corner.
        stats : ECGStats | None, optional
            Computed ECG statistics. When ``self.print_information`` is True, any
            non-None field is printed in the top-right corner, arranged in columns
            of up to three rows.
        attention_map : AbstractAttentionMap | None, optional
            Optional attention overlay. Pass an instance of
            :class:`~pmecg.BackgroundAttentionMap`, :class:`~pmecg.IntervalAttentionMap`, or
            :class:`~pmecg.LineColorAttentionMap`, where you specify the attention data and
            the style settings.
            When an attention map requests a color scale, ``plot()`` expands the right margin
            automatically to preserve the ECG plotting area. You can disable this by setting
            ``show_colormap=False`` in the AttentionMap initialization.
            Note that ``prepare`` is called on the object you pass in, so the same
            instance is (re)prepared by every ``plot`` call.
        rhythm_strips : RhythmStripsConfig | None, optional
            Optional rhythm strips appended after the configuration rows. Every lead
            present in ``rhythm_strips.ecg_data`` is plotted as a full-width row
            showing the entire recording. When ``rhythm_strips.speed`` differs from
            the plotter's speed, those rows use the specified paper speed and the
            figure width is expanded if needed. By default ``None``.

        Returns
        -------
        matplotlib.figure.Figure
            The matplotlib figure object containing the plot. It is created through
            pyplot, so callers rendering many ECGs with ``show=False`` should close it
            (``plt.close(fig)``) when done.

        Raises
        ------
        ValueError
            If ``ecg_data`` is neither a tuple nor a DataFrame, ``configuration`` is empty,
            the rhythm-strip data is malformed or empty, or the layout contains no rows
            at all. The validation helpers called here may raise their own errors too.
        TypeError
            If ``rhythm_strips`` is not a :class:`RhythmStripsConfig`.
        """
        df_data = self._to_dataframe(ecg_data)
        _validate_input_lead_names(list(df_data.columns))
        if configuration is not None and len(configuration) == 0:
            raise ValueError("configuration must not be empty; pass None to use the default single-lead-per-row layout")
        resolved_configuration = _resolve_configuration(configuration, list(df_data.columns))

        # ``prepare`` is invoked on the caller's own object; afterwards its
        # ``row_attentions`` / ``rhythm_strip_attentions`` are read per row.
        prepared_attention = attention_map
        if prepared_attention is not None:
            prepared_attention.prepare(list(df_data.columns), df_data.shape[0], resolved_configuration)
        reserves_attention_margin = prepared_attention is not None and prepared_attention.shows_color_scale
        shows_attention_color_scale = reserves_attention_margin and prepared_attention.show_colormap

        # Apply the layout configuration → one (signal, leads, offsets, segments) 4-tuple per row
        config_rows = _apply_configuration(df_data, resolved_configuration, self.disconnect_segments)
        n_config_rows = len(config_rows)
        seq_len = max(len(row[0]) for row in config_rows) if config_rows else df_data.shape[0]

        rhythm_strip_rows, rhythm_strip_speed, rhythm_strip_tti, rhythm_strip_df = self._prepare_rhythm_strips(
            rhythm_strips, sampling_frequency
        )
        if self.show_time_axis:
            _validate_time_axis_config([row[3] for row in config_rows], rhythm_strip_speed, self.speed)

        all_rows = list(config_rows) + rhythm_strip_rows
        n_rows = len(all_rows)
        # Fail before creating a figure: the first/last-row geometry below needs at least one row.
        if n_rows == 0:
            raise ValueError("nothing to plot: the configuration produced no rows and no rhythm strips were given")

        # Ensure row_distance * voltage is a multiple of 5mm
        adjusted_row_distance = _adjust_row_distance(self.row_distance, self.voltage)
        # Conversion factors and per-call render context
        ctx = _RenderContext(
            mv_to_inches=self.voltage / MM_PER_INCH,
            time_to_inches=self.speed / (sampling_frequency * MM_PER_INCH),
            row_distance_inches=adjusted_row_distance * self.voltage / MM_PER_INCH,
            line_width=self.line_width,
            grid_color=self.grid_color,
            speed=self.speed,
            voltage=self.voltage,
            show_calibration=self.show_calibration,
            show_leads_labels=self.show_leads_labels,
            show_separators=self.show_separators,
        )

        # Figure dimensions; a color scale doubles the right margin.
        right_mm = RIGHT_MARGIN_MM * 2.0 if reserves_attention_margin else RIGHT_MARGIN_MM
        width_inches, height_inches = _compute_figure_size(
            n_rows,
            seq_len,
            sampling_frequency,
            self.speed,
            self.voltage,
            adjusted_row_distance,
            print_information=self.print_information,
            right_margin_mm=right_mm,
            rhythm_strip_seq_len=rhythm_strip_df.shape[0] if rhythm_strip_df is not None else None,
            rhythm_strip_speed=rhythm_strip_speed,
        )
        # Pre-compute the zero-line y position (in inches) for every row
        y_offsets = _compute_row_offsets(
            n_rows,
            height_inches,
            ctx.row_distance_inches,
            self.print_information,
        )

        # Create figure with exact physical dimensions
        fig, ax = plt.subplots(1, 1, figsize=(width_inches, height_inches))
        try:
            ax.set_xlim(0, width_inches)
            ax.set_ylim(0, height_inches)
            ax.set_aspect("equal")
            if self.grid_mode is not None:
                _plot_grid(ax, self.grid_mode, width_inches, height_inches, ctx)

            for row_index, (row_signal, row_leads, row_offsets, _row_segments) in enumerate(all_rows):
                is_rhythm_strip = row_index >= n_config_rows
                row_attention, row_attention_map = self._select_row_attention(
                    prepared_attention, row_index, is_rhythm_strip, row_leads, row_signal
                )
                _plot_row(
                    ax,
                    (row_signal, row_leads),
                    ctx,
                    y_offsets[row_index],
                    attention_values=row_attention,
                    attention_map=row_attention_map,
                    time_to_inches=rhythm_strip_tti if is_rhythm_strip else None,
                    segment_offsets=row_offsets,
                )

            first_row_top_inches = y_offsets[0] + ctx.row_distance_inches / 2.0
            last_row_zero_inches = y_offsets[-1]
            last_row_bottom_inches = last_row_zero_inches - ctx.row_distance_inches / 2.0
            if shows_attention_color_scale and prepared_attention is not None:
                _plot_attention_color_scale(
                    ax,
                    prepared_attention,
                    width_inches,
                    right_mm,  # same margin that was reserved when sizing the figure
                    first_row_top_inches,
                    last_row_bottom_inches,
                )

            self._draw_time_axis(ax, seq_len, rhythm_strip_df, rhythm_strip_speed, sampling_frequency)
            # Remove the box: keep only the bottom spine (when the time axis is visible)
            for side in ("top", "right", "left"):
                ax.spines[side].set_visible(False)
            ax.yaxis.set_visible(False)

            if self.print_information:
                original_leads = list(df_data.columns)
                _print_information(
                    ax,
                    ctx,
                    width_inches,
                    sampling_frequency,
                    original_leads,
                    first_row_top_inches,
                    last_row_zero_inches,
                    information=information,
                    stats=stats,
                    rhythm_strip_speed=rhythm_strip_speed,
                    print_available_leads=self.print_available_leads,
                )
        except BaseException:
            # Don't leave a half-drawn figure registered in pyplot's global state.
            plt.close(fig)
            raise

        if show:
            fig.set_dpi(self.show_dpi)
            plt.show()
        return fig

    @staticmethod
    def _to_dataframe(ecg_data: ECGDataType) -> pd.DataFrame:
        """Convert the accepted ``ecg_data`` forms into a per-lead DataFrame (DataFrames pass through)."""
        if isinstance(ecg_data, tuple):
            return _numpy_to_dataframe(ecg_data[0], ecg_data[1])
        if isinstance(ecg_data, pd.DataFrame):
            return ecg_data
        raise ValueError(
            "ecg_data must be a tuple of (list of numpy arrays, list of lead names), "
            "a tuple of (numpy array, list of lead names), or a pandas DataFrame"
        )

    def _prepare_rhythm_strips(
        self,
        rhythm_strips: RhythmStripsConfig | None,
        sampling_frequency: float,
    ) -> tuple[
        list[tuple[np.ndarray, list[str], list[int], list[Never]]],
        float | None,
        float | None,
        pd.DataFrame | None,
    ]:
        """Validate ``rhythm_strips`` and build one full-length row per lead.

        Returns ``(rows, speed, time_to_inches, dataframe)``. ``speed`` and
        ``time_to_inches`` are ``None`` unless the strips use a paper speed that
        differs from ``self.speed``. Everything is empty/``None`` when
        ``rhythm_strips`` is ``None``.
        """
        if rhythm_strips is None:
            return [], None, None, None
        if not isinstance(rhythm_strips, RhythmStripsConfig):
            raise TypeError(f"rhythm_strips must be a RhythmStripsConfig instance, got {type(rhythm_strips).__name__}")

        raw = rhythm_strips.ecg_data
        if isinstance(raw, tuple):
            strip_df = _numpy_to_dataframe(raw[0], raw[1])
        elif isinstance(raw, pd.DataFrame):
            strip_df = raw
        else:
            raise ValueError("RhythmStripsConfig.ecg_data must be a tuple or DataFrame")
        if strip_df.shape[1] == 0:
            raise ValueError("RhythmStripsConfig.ecg_data must contain at least one lead (got zero columns)")
        if strip_df.shape[0] == 0:
            raise ValueError("RhythmStripsConfig.ecg_data must contain at least one sample (got zero rows)")
        _validate_input_lead_names(list(strip_df.columns))

        # Copy each lead so later per-row processing never writes into the caller's data.
        rows = [(strip_df[lead_name].values.copy(), [lead_name], [0], []) for lead_name in strip_df.columns]

        strip_speed: float | None = None
        strip_tti: float | None = None  # time_to_inches for rhythm strip rows
        if rhythm_strips.speed is not None and abs(rhythm_strips.speed - self.speed) > 1e-9:
            strip_speed = rhythm_strips.speed
            strip_tti = rhythm_strips.speed / (sampling_frequency * MM_PER_INCH)
        return rows, strip_speed, strip_tti, strip_df

    @staticmethod
    def _select_row_attention(prepared_attention, row_index, is_rhythm_strip, row_leads, row_signal):
        """Return ``(attention_values, attention_map)`` for one plotted row, or ``(None, None)``.

        Configuration rows use ``prepared_attention.row_attentions[row_index]``.
        Rhythm strips look their lead up in ``rhythm_strip_attentions``; a
        length mismatch with the strip signal emits a ``UserWarning`` and skips
        the overlay.
        """
        if prepared_attention is None:
            return None, None
        if not is_rhythm_strip:
            return prepared_attention.row_attentions[row_index], prepared_attention

        rhythm_strip_lead = row_leads[0]
        rhythm_strip_attn = prepared_attention.rhythm_strip_attentions.get(rhythm_strip_lead)
        if rhythm_strip_attn is None:
            return None, None
        if len(rhythm_strip_attn) == len(row_signal):
            return rhythm_strip_attn, prepared_attention
        warnings.warn(
            f"Rhythm strip {rhythm_strip_lead!r} attention length ({len(rhythm_strip_attn)}) "
            f"does not match rhythm strip ECG length ({len(row_signal)}); overlay skipped.",
            UserWarning,
            # 3 = this helper -> plot() -> the user's call site.
            stacklevel=3,
        )
        return None, None

    def _draw_time_axis(self, ax, seq_len, rhythm_strip_df, rhythm_strip_speed, sampling_frequency) -> None:
        """Draw second-based x ticks at the bottom, or hide the x axis when ``show_time_axis`` is False."""
        if not self.show_time_axis:
            ax.xaxis.set_visible(False)
            ax.spines["bottom"].set_visible(False)
            return

        # When rhythm strips run at the same speed, the figure may be wider than seq_len alone.
        axis_seq_len = seq_len
        if rhythm_strip_df is not None and rhythm_strip_speed is None:
            axis_seq_len = max(seq_len, rhythm_strip_df.shape[0])

        total_time_s = axis_seq_len / sampling_frequency
        tick_step_s = _nice_tick_step(total_time_s)
        tick_times_s = np.arange(0, total_time_s + tick_step_s / 2, tick_step_s)
        # Convert time values (seconds) → x position in inches, offset by the left margin
        left_margin_inches = LEFT_MARGIN_MM / MM_PER_INCH
        tick_positions_inches = tick_times_s * (self.speed / MM_PER_INCH) + left_margin_inches
        ax.set_xticks(tick_positions_inches)
        ax.set_xticklabels(
            [self._format_tick_seconds(t) for t in tick_times_s],
            fontsize=7,
            fontfamily="monospace",
        )
        ax.spines["bottom"].set_position(("axes", 0))

    @staticmethod
    def _format_tick_seconds(seconds: float) -> str:
        """Format a time-axis tick value (in seconds) as a label such as ``"12.5 s"``.

        The value is rounded to 6 decimals to absorb ``np.arange`` drift
        (``0.6000000000000001`` -> ``"0.6 s"``) and printed positionally with
        trailing zeros trimmed. Ticks past 10 s keep their fractional part
        (``"12.5 s"``, not ``"12 s"``), and ticks from 100 s up are not printed
        in scientific notation (``"120 s"``, not ``"1.2e+02 s"``).
        """
        return f"{np.format_float_positional(round(float(seconds), 6), trim='-')} s"