# Source code for pmecg.plot

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal

if TYPE_CHECKING:
    from typing import Never

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure

from .types import ConfigurationDataType, ECGDataType, RhythmStripsConfig
from .utils.attention import (
    AbstractAttentionMap,
    BackgroundAttentionMap,
    IntervalAttentionMap,
    LineColorAttentionMap,
    attention_map_from_indices_annotations,
    attention_map_from_time_annotations,
)
from .utils.data import (
    _apply_configuration,
    _numpy_to_dataframe,
    _resolve_configuration,
    _validate_input_lead_names,
)
from .utils.plot import (
    LEFT_MARGIN_MM,
    MM_PER_INCH,
    RIGHT_MARGIN_MM,
    _adjust_row_distance,
    _compute_figure_size,
    _compute_row_offsets,
    _nice_tick_step,
    _plot_attention_color_scale,
    _plot_grid,
    _plot_row,
    _print_information,
    _RenderContext,
    _validate_time_axis_config,
)

# Public API of this module: the plotter and its metadata containers defined in
# this file, plus the attention-map classes and helpers imported from
# ``.utils.attention`` and re-exported so users can import them from here.
__all__ = [
    "AbstractAttentionMap",
    "BackgroundAttentionMap",
    "ECGInformation",
    "ECGPlotter",
    "ECGStats",
    "IntervalAttentionMap",
    "LineColorAttentionMap",
    "attention_map_from_indices_annotations",
    "attention_map_from_time_annotations",
]


@dataclass
class ECGStats:
    """Diagnostic measurements that can be annotated on an ECG plot.

    Every field is optional and defaults to ``None``, which means "do not
    show". When the plotter was created with ``print_information=True``, each
    field that holds a value is printed in the top-right corner of the figure,
    laid out in columns of three rows.

    Parameters
    ----------
    bpm : float, optional
        Heart rate, in beats per minute.
    snr : float, optional
        Signal-to-noise ratio, in dB.
    rr_interval_ms : float, optional
        Mean RR (beat-to-beat) interval, in milliseconds.
    hrv_ms : float, optional
        Heart-rate variability, i.e. the statistical spread of the RR
        intervals, in milliseconds.
    pr_interval_ms : float, optional
        PR interval, in milliseconds.
    qrs_duration_ms : float, optional
        Duration of the QRS complex, in milliseconds.
    qt_interval_ms : float, optional
        QT interval, in milliseconds.
    qtc_interval_ms : float, optional
        Corrected QT interval (QTc), in milliseconds.
    p_axis_deg : float, optional
        P-wave axis, in degrees.
    qrs_axis_deg : float, optional
        QRS axis, in degrees.
    t_axis_deg : float, optional
        T-wave axis, in degrees.
    """

    # Rate and signal quality.
    bpm: float | None = None
    snr: float | None = None

    # Beat-to-beat timing (milliseconds).
    rr_interval_ms: float | None = None
    hrv_ms: float | None = None

    # Interval and duration measurements (milliseconds).
    pr_interval_ms: float | None = None
    qrs_duration_ms: float | None = None
    qt_interval_ms: float | None = None
    qtc_interval_ms: float | None = None

    # Axes (degrees).
    p_axis_deg: float | None = None
    qrs_axis_deg: float | None = None
    t_axis_deg: float | None = None
@dataclass
class ECGInformation:
    """Patient and recording metadata that can be printed on an ECG plot.

    All fields are optional and default to ``None``.

    Parameters
    ----------
    hospital : str, optional
        Hospital or clinic where the ECG was recorded.
    patient_name : str, optional
        The patient's name.
    age : int, optional
        The patient's age, in years.
    sex : str, optional
        The patient's sex (e.g. "Male", "Female").
    date : str, optional
        Recording date, in any human-readable format (e.g. "2024-01-15").
    machine_model : str, optional
        Model of the ECG machine; printed in the bottom-right corner.
    filter : str, optional
        Description of the filter(s) applied to the signal
        (e.g. "0.05-150 Hz").
    """

    # Patient / site identification.
    hospital: str | None = None
    patient_name: str | None = None
    age: int | None = None
    sex: str | None = None

    # Recording details.
    date: str | None = None
    machine_model: str | None = None
    # NOTE: this field name shadows the ``filter`` builtin inside the class body only.
    filter: str | None = None
class _PlotterArgumentError(ValueError, AssertionError):
    """Invalid argument passed to :class:`ECGPlotter`.

    Inherits from :class:`ValueError` (the conventional type for a bad argument)
    and from :class:`AssertionError` (what the former ``assert``-based checks
    raised), so ``except`` clauses written for either type keep working.
    """


def _is_positive_finite_real(value: object) -> bool:
    """Return whether *value* is a finite real number greater than zero.

    Python ``int``/``float`` and NumPy integer/floating scalars are accepted
    (``np.int64`` is not an ``int`` subclass, so a plain ``isinstance`` check
    against ``(int, float)`` would reject it). ``bool`` counts as ``int``.
    """
    if isinstance(value, (int, np.integer)):
        return bool(value > 0)
    if isinstance(value, (float, np.floating)):
        return bool(np.isfinite(value)) and bool(value > 0)
    return False


class ECGPlotter:
    """Generate paper-like ECG plots from signal data.

    Instantiate with visual parameters (speed, voltage scale, grid style, etc.), then call :meth:`plot`
    to render one or more ECGs using the same configuration.

    Parameters
    ----------
    grid_mode : {'cm', None}, optional
        Grid style to overlay on the plot. ``'cm'`` draws lines every 0.1 cm (= 1 mm), with every 5th line
        slightly thicker. Pass ``None`` to disable the grid. By default ``'cm'``.
    speed : float, optional
        The speed of the plot in mm/s, by default 25.0
    voltage : float, optional
        The space (in mm) corresponding to 1 mV, by default 10.0
    row_distance : float, optional
        Distance between the zero-lines of consecutive rows, expressed in mV, by default 3.0. At plot time
        it is adjusted so that ``row_distance * voltage`` is a multiple of 5 mm.
    line_width : float, optional
        Thickness of the ECG signal lines (and calibration pulse) in points, by default 0.5
    grid_color : str, optional
        Color of the grid lines. Any matplotlib color string is accepted (e.g. '#f4aaaa', 'lightgray',
        'gray'). By default '#f4aaaa' (light ECG-paper red).
    print_information : bool, optional
        Whether to print diagnostic parameters (speed, voltage, sampling frequency) and any extra metadata
        in the corners of the figure, by default False.
    print_available_leads : bool, optional
        Whether to include the list of available leads in the diagnostic line. Only has an effect when
        ``print_information=True``. By default False.
    show_time_axis : bool, optional
        Whether to show the time axis (x-axis ticks and spine) at the bottom of the figure, by default False.
    show_calibration : bool, optional
        Whether to show the calibration pulse in the left margin of each row, by default True.
    show_leads_labels : bool, optional
        Whether to print lead names onto the plot, by default True.
    show_separators : bool, optional
        Whether to draw short vertical tick marks at the boundary between adjacent lead columns within each
        row, by default True.
    disconnect_segments : bool, optional
        If True, the last sample of each segment is set to NaN so that adjacent segments are not visually
        connected in the plot. By default True.
    show_dpi : int, optional
        DPI applied to the figure before displaying it when ``show=True``. Has no effect when
        ``show=False``. By default 300.

    Raises
    ------
    ValueError
        If ``grid_mode`` is not ``None``/``'cm'``; ``speed``, ``voltage``, ``row_distance`` or
        ``line_width`` is not a positive finite number (Python or NumPy scalar); ``grid_color`` is not a
        non-empty string; or ``show_dpi`` is not a positive integer. The exception is also an
        :class:`AssertionError` for backward compatibility.
    """

    def __init__(
        self,
        grid_mode: Literal["cm"] | None = "cm",
        speed: float = 25.0,
        voltage: float = 10.0,
        row_distance: float = 3.0,
        line_width: float = 0.5,
        grid_color: str = "#f4aaaa",
        print_information: bool = False,
        print_available_leads: bool = False,
        show_time_axis: bool = False,
        show_calibration: bool = True,
        show_leads_labels: bool = True,
        show_separators: bool = True,
        disconnect_segments: bool = True,
        show_dpi: int = 300,
    ):
        # Explicit raises instead of ``assert``: assertions are stripped under ``python -O``,
        # which silently disabled all of this validation.
        if grid_mode not in (None, "cm"):
            raise _PlotterArgumentError("grid_mode must be None or 'cm'")
        for arg_name, arg_value in (
            ("speed", speed),
            ("voltage", voltage),
            ("row_distance", row_distance),
            ("line_width", line_width),
        ):
            if not _is_positive_finite_real(arg_value):
                raise _PlotterArgumentError(f"{arg_name} must be a positive number")
        if not (isinstance(grid_color, str) and len(grid_color) > 0):
            raise _PlotterArgumentError("grid_color must be a non-empty string")
        if not (isinstance(show_dpi, (int, np.integer)) and show_dpi > 0):
            raise _PlotterArgumentError("show_dpi must be a positive integer")

        self.grid_mode = grid_mode
        self.speed = speed
        self.voltage = voltage
        self.row_distance = row_distance
        self.line_width = line_width
        self.grid_color = grid_color
        self.print_information = print_information
        self.print_available_leads = print_available_leads
        self.show_time_axis = show_time_axis
        self.show_calibration = show_calibration
        self.show_leads_labels = show_leads_labels
        self.show_separators = show_separators
        self.disconnect_segments = disconnect_segments
        self.show_dpi = show_dpi

    def plot(
        self,
        ecg_data: ECGDataType,
        configuration: ConfigurationDataType | None = None,
        sampling_frequency: float = 500.0,
        show: bool = True,
        information: ECGInformation | None = None,
        stats: ECGStats | None = None,
        attention_map: AbstractAttentionMap | None = None,
        rhythm_strips: RhythmStripsConfig | None = None,
    ) -> Figure:
        """Plot the ECG in ``ecg_data`` using the plotting configuration specified in ``configuration``.

        Parameters
        ----------
        ecg_data : ECGDataType
            ECG signal data to plot.
        configuration : ConfigurationDataType | None, optional
            The plotting configuration to be used. By default None, meaning that each lead is plotted on its
            own row for its entire duration.
        sampling_frequency : float, optional
            The sampling frequency of the ECG data in Hz, by default 500.0. Must be positive and finite.
        show : bool, optional
            Whether to show the plot, by default True
        information : ECGInformation | None, optional
            Patient and recording metadata. When ``self.print_information`` is True, the hospital, patient
            name and date are printed above the first ECG row, and the machine model is printed in the
            bottom-right corner.
        stats : ECGStats | None, optional
            Computed ECG statistics. When ``self.print_information`` is True, any non-None field is printed
            in the top-right corner, arranged in columns of up to three rows.
        attention_map : AbstractAttentionMap | None, optional
            Optional attention overlay. Pass an instance of :class:`~pmecg.BackgroundAttentionMap`,
            :class:`~pmecg.IntervalAttentionMap`, or :class:`~pmecg.LineColorAttentionMap`, where you specify
            the attention data and the style settings. When an attention map requests a color scale,
            ``plot()`` expands the right margin automatically to preserve the ECG plotting area. You can
            disable this by setting ``show_colormap=False`` in the AttentionMap initialization.
        rhythm_strips : RhythmStripsConfig | None, optional
            Optional rhythm strips appended after the configuration rows. Every lead present in
            ``rhythm_strips.ecg_data`` is plotted as a full-width row showing the entire recording. When
            ``rhythm_strips.speed`` differs from the plotter's speed, those rows use the specified paper
            speed and the figure width is expanded if needed. By default ``None``.

        Returns
        -------
        matplotlib.figure.Figure
            The matplotlib figure object containing the plot

        Raises
        ------
        ValueError
            If ``sampling_frequency`` is not a positive finite number; ``ecg_data`` (or
            ``rhythm_strips.ecg_data``) is neither a tuple nor a DataFrame; ``configuration`` is empty; or
            the rhythm-strip data has no leads or no samples. Errors raised by the lead-name and
            configuration helpers propagate unchanged.
        TypeError
            If ``rhythm_strips`` is not a :class:`RhythmStripsConfig` instance.

        Warns
        -----
        UserWarning
            If a rhythm strip's attention values do not match its length; that overlay is skipped.
        """
        # time_to_inches divides by the sampling frequency: 0 would surface later as a ZeroDivisionError,
        # and NaN/inf/negative values would propagate into the figure geometry.
        if not (sampling_frequency > 0 and np.isfinite(sampling_frequency)):
            raise ValueError(f"sampling_frequency must be a positive, finite number, got {sampling_frequency!r}")

        if isinstance(ecg_data, tuple):
            df_data = _numpy_to_dataframe(ecg_data[0], ecg_data[1])
        elif isinstance(ecg_data, pd.DataFrame):
            df_data = ecg_data
        else:
            raise ValueError(
                "ecg_data must be a tuple of (list of numpy arrays, list of lead names), "
                "a tuple of (numpy array, list of lead names), or a pandas DataFrame"
            )

        _validate_input_lead_names(list(df_data.columns))
        if configuration is not None and len(configuration) == 0:
            raise ValueError("configuration must not be empty; pass None to use the default single-lead-per-row layout")
        resolved_configuration = _resolve_configuration(configuration, list(df_data.columns))

        prepared_attention = attention_map
        if prepared_attention is not None:
            prepared_attention.prepare(list(df_data.columns), df_data.shape[0], resolved_configuration)
        reserves_attention_margin = prepared_attention is not None and prepared_attention.shows_color_scale
        shows_attention_color_scale = reserves_attention_margin and prepared_attention.show_colormap

        # Apply the layout configuration → one (signal, leads, offsets, segments) 4-tuple per row
        config_rows = _apply_configuration(df_data, resolved_configuration, self.disconnect_segments)
        n_config_rows = len(config_rows)
        seq_len = max(len(row[0]) for row in config_rows) if config_rows else df_data.shape[0]

        rhythm_strip_rows, rhythm_strip_speed, rhythm_strip_tti, rhythm_strip_df = self._build_rhythm_strip_rows(
            rhythm_strips, sampling_frequency
        )

        if self.show_time_axis:
            _validate_time_axis_config([row[3] for row in config_rows], rhythm_strip_speed, self.speed)

        all_rows = list(config_rows) + rhythm_strip_rows
        n_rows = len(all_rows)

        # Ensure row_distance * voltage is a multiple of 5mm
        adjusted_row_distance = _adjust_row_distance(self.row_distance, self.voltage)

        # Conversion factors and per-call render context
        ctx = _RenderContext(
            mv_to_inches=self.voltage / MM_PER_INCH,
            time_to_inches=self.speed / (sampling_frequency * MM_PER_INCH),
            row_distance_inches=adjusted_row_distance * self.voltage / MM_PER_INCH,
            line_width=self.line_width,
            grid_color=self.grid_color,
            speed=self.speed,
            voltage=self.voltage,
            show_calibration=self.show_calibration,
            show_leads_labels=self.show_leads_labels,
            show_separators=self.show_separators,
        )

        # Figure dimensions (the right margin is doubled to make room for an attention color scale)
        right_mm = RIGHT_MARGIN_MM * 2.0 if reserves_attention_margin else RIGHT_MARGIN_MM
        width_inches, height_inches = _compute_figure_size(
            n_rows,
            seq_len,
            sampling_frequency,
            self.speed,
            self.voltage,
            adjusted_row_distance,
            print_information=self.print_information,
            right_margin_mm=right_mm,
            rhythm_strip_seq_len=rhythm_strip_df.shape[0] if rhythm_strip_df is not None else None,
            rhythm_strip_speed=rhythm_strip_speed,
        )

        # Pre-compute the zero-line y position (in inches) for every row
        y_offsets = _compute_row_offsets(
            n_rows,
            height_inches,
            ctx.row_distance_inches,
            self.print_information,
        )

        # Create figure with exact physical dimensions
        fig, ax = plt.subplots(1, 1, figsize=(width_inches, height_inches))
        ax.set_xlim(0, width_inches)
        ax.set_ylim(0, height_inches)
        ax.set_aspect("equal")

        if self.grid_mode is not None:
            _plot_grid(ax, self.grid_mode, width_inches, height_inches, ctx)

        for i, (row_signal, row_leads, row_offsets, _row_segs) in enumerate(all_rows):
            is_rhythm_strip = i >= n_config_rows
            row_attention = None
            row_attention_map = None
            if not is_rhythm_strip and prepared_attention is not None:
                row_attention = prepared_attention.row_attentions[i]
                row_attention_map = prepared_attention
            elif is_rhythm_strip and prepared_attention is not None:
                rhythm_strip_lead = row_leads[0]
                rhythm_strip_attn = prepared_attention.rhythm_strip_attentions.get(rhythm_strip_lead)
                if rhythm_strip_attn is not None:
                    if len(rhythm_strip_attn) == len(row_signal):
                        row_attention = rhythm_strip_attn
                        row_attention_map = prepared_attention
                    else:
                        # stacklevel=2 attributes the warning to the caller of plot().
                        warnings.warn(
                            f"Rhythm strip {rhythm_strip_lead!r} attention length ({len(rhythm_strip_attn)}) "
                            f"does not match rhythm strip ECG length ({len(row_signal)}); overlay skipped.",
                            UserWarning,
                            stacklevel=2,
                        )
            _plot_row(
                ax,
                (row_signal, row_leads),
                ctx,
                y_offsets[i],
                attention_values=row_attention,
                attention_map=row_attention_map,
                # Rhythm strips at a different paper speed need their own sample-to-inch factor.
                time_to_inches=rhythm_strip_tti if is_rhythm_strip else None,
                segment_offsets=row_offsets,
            )

        first_row_top_inches = y_offsets[0] + ctx.row_distance_inches / 2.0
        last_row_zero_inches = y_offsets[-1]
        last_row_bottom_inches = last_row_zero_inches - ctx.row_distance_inches / 2.0
        if shows_attention_color_scale and prepared_attention is not None:
            _plot_attention_color_scale(
                ax,
                prepared_attention,
                width_inches,
                RIGHT_MARGIN_MM * 2.0,
                first_row_top_inches,
                last_row_bottom_inches,
            )

        # --- Time axis ---
        if self.show_time_axis:
            # When rhythm strips run at the same speed, the figure may be wider than seq_len alone.
            axis_seq_len = seq_len
            if rhythm_strip_df is not None and rhythm_strip_speed is None:
                axis_seq_len = max(seq_len, rhythm_strip_df.shape[0])
            self._draw_time_axis(ax, axis_seq_len, sampling_frequency)
        else:
            ax.xaxis.set_visible(False)
            ax.spines["bottom"].set_visible(False)

        # Remove the box: keep only the bottom spine (when visible)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)
        ax.yaxis.set_visible(False)

        if self.print_information:
            original_leads = list(df_data.columns)
            _print_information(
                ax,
                ctx,
                width_inches,
                sampling_frequency,
                original_leads,
                first_row_top_inches,
                last_row_zero_inches,
                information=information,
                stats=stats,
                rhythm_strip_speed=rhythm_strip_speed,
                print_available_leads=self.print_available_leads,
            )

        if show:
            fig.set_dpi(self.show_dpi)
            plt.show()

        return fig

    def _build_rhythm_strip_rows(
        self,
        rhythm_strips: RhythmStripsConfig | None,
        sampling_frequency: float,
    ) -> tuple[list[tuple], float | None, float | None, pd.DataFrame | None]:
        """Validate ``rhythm_strips`` and turn each of its leads into one full-width plot row.

        Returns
        -------
        rows : list of tuple
            One ``(signal, [lead], [0], [])`` row per lead, in column order, matching the
            ``(signal, leads, offsets, segments)`` layout of the configuration rows.
        speed : float | None
            The rhythm-strip paper speed when it differs from ``self.speed``, otherwise ``None``.
        time_to_inches : float | None
            Sample-to-inch factor for that speed, otherwise ``None`` (the shared render context is used).
        df : pandas.DataFrame | None
            The rhythm-strip data as a DataFrame, or ``None`` when ``rhythm_strips`` is ``None``.

        Raises
        ------
        TypeError
            If ``rhythm_strips`` is not a :class:`RhythmStripsConfig`.
        ValueError
            If its ``ecg_data`` is neither a tuple nor a DataFrame, or has zero leads or zero samples.
        """
        rows: list[tuple[np.ndarray, list[str], list[int], list[Never]]] = []
        if rhythm_strips is None:
            return rows, None, None, None

        if not isinstance(rhythm_strips, RhythmStripsConfig):
            raise TypeError(f"rhythm_strips must be a RhythmStripsConfig instance, got {type(rhythm_strips).__name__}")
        raw = rhythm_strips.ecg_data
        if isinstance(raw, tuple):
            strip_df = _numpy_to_dataframe(raw[0], raw[1])
        elif isinstance(raw, pd.DataFrame):
            strip_df = raw
        else:
            raise ValueError("RhythmStripsConfig.ecg_data must be a tuple or DataFrame")
        if strip_df.shape[1] == 0:
            raise ValueError("RhythmStripsConfig.ecg_data must contain at least one lead (got zero columns)")
        if strip_df.shape[0] == 0:
            raise ValueError("RhythmStripsConfig.ecg_data must contain at least one sample (got zero rows)")
        _validate_input_lead_names(list(strip_df.columns))

        for lead_name in strip_df.columns:
            # Copy so later per-row processing cannot write back into the caller's DataFrame.
            rows.append((strip_df[lead_name].values.copy(), [lead_name], [0], []))

        strip_speed: float | None = None
        strip_tti: float | None = None
        if rhythm_strips.speed is not None and abs(rhythm_strips.speed - self.speed) > 1e-9:
            strip_speed = rhythm_strips.speed
            strip_tti = rhythm_strips.speed / (sampling_frequency * MM_PER_INCH)
        return rows, strip_speed, strip_tti, strip_df

    def _draw_time_axis(self, ax, axis_seq_len: int, sampling_frequency: float) -> None:
        """Label the x-axis in seconds for ``axis_seq_len`` samples drawn at the plotter's paper speed."""
        left_margin_inches = LEFT_MARGIN_MM / MM_PER_INCH
        total_time_s = axis_seq_len / sampling_frequency
        tick_step_s = _nice_tick_step(total_time_s)

        # Count the whole steps that fit in the recording. The previous
        # np.arange(0, total + step / 2, step) overshot by one tick whenever the remainder exceeded half a
        # step (10.6 s with 1 s steps produced a tick at 11 s, past the end of the signal). The small
        # tolerance keeps the final tick when total_time_s is an exact multiple of the step but the float
        # division lands just below the integer.
        n_ticks = int(np.floor(total_time_s / tick_step_s + 1e-9)) + 1
        tick_times_s = np.arange(n_ticks) * tick_step_s

        # Convert time values (seconds) → x position in inches
        tick_positions_inches = tick_times_s * (self.speed / MM_PER_INCH) + left_margin_inches
        ax.set_xticks(tick_positions_inches)
        # Round away float noise (0.6000000000000001 → "0.6") and format with ":g". The former ".2g" kept
        # only two significant digits, rendering 10.5 s as "10" (duplicate labels) and 100 s as "1e+02".
        ax.set_xticklabels(
            [f"{round(float(t), 6):g} s" for t in tick_times_s],
            fontsize=7,
            fontfamily="monospace",
        )
        ax.spines["bottom"].set_position(("axes", 0))