Skip to content

nhra_gt.plotting

Functions

plot_strategy_heatmap(data, config=None, **kwargs)

Shows strategy shares over time for each game as line plots, one line per strategy and one panel per game (despite the function name, no heatmap is drawn).

Source code in src/nhra_gt/visualization/distributional.py
def plot_strategy_heatmap(
    data: pd.DataFrame,
    config: PlotConfig | None = None,
    **kwargs,
) -> Figure:
    """
    Shows strategy shares over time for each game (one panel per game).

    Despite the name, each panel is drawn as line plots: one line per
    strategy, showing the mean share per year, with the y-axis fixed to
    [0, 1].

    Args:
        data: DataFrame with 'game', 'year', 'strategy' and 'share' columns.
            It is validated against StrategyFrequencySchema before plotting.
        config: PlotConfig object for styling. A default PlotConfig is used
            when None.
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        A Figure with one vertically stacked panel per game.

    Raises:
        ValueError: If ``data`` contains no games to plot.
    """
    # Validation
    StrategyFrequencySchema.validate(data)

    if config is None:
        config = PlotConfig()

    games = sorted(data["game"].unique())
    if not games:
        # Without this guard the figure would get zero height and the final
        # x-label call would fail on an unbound ``ax``.
        raise ValueError("No games found in data; nothing to plot.")

    # Height grows with the number of panels; width comes from the config.
    figsize = (config.default_figsize[0], 2.1 * len(games))
    fig = plt.figure(figsize=figsize)

    ax = None
    for i, g in enumerate(games, start=1):
        ax = fig.add_subplot(len(games), 1, i)
        sub = data[data["game"] == g]

        # Rows = years, columns = strategies; (year, strategy) pairs absent
        # from the data are plotted as a share of 0.
        pivot = sub.pivot_table(
            index="year", columns="strategy", values="share", aggfunc="mean"
        ).fillna(0)

        for idx, col in enumerate(pivot.columns):
            # Cycle through the palette when there are more strategies than colors.
            color = config.color_palette[idx % len(config.color_palette)]
            ax.plot(
                pivot.index, pivot[col], label=f"{col}", linewidth=config.linewidth, color=color
            )

        ax.set_ylim(0, 1)
        ax.set_ylabel(g, fontsize=config.fontsize_label)
        ax.grid(True, alpha=config.alpha_grid)
        ax.tick_params(axis="both", labelsize=config.fontsize_tick)

        # A single legend on the first panel only.
        if i == 1:
            ax.legend(ncol=4, fontsize=config.fontsize_legend, loc="upper right", frameon=False)

    # Only the bottom-most panel carries the x-axis label.
    ax.set_xlabel("Year", fontsize=config.fontsize_label)
    return fig

tornado_from_rankcorr(data, outcome_col, params, config=None, topk=10, path=None)

Rank-correlation tornado using Spearman rho.

Source code in src/nhra_gt/visualization/sensitivity.py
def plot_rank_tornado(
    data: pd.DataFrame,
    outcome_col: str,
    params: list[str],
    config: PlotConfig | None = None,
    topk: int = 10,
    path: str | Path | None = None,
) -> Figure:
    """Rank-correlation tornado using Spearman rho.

    Computes the Spearman rank correlation between each parameter column and
    the outcome column, keeps the ``topk`` parameters with the largest
    absolute correlation, and draws them as horizontal bars with the
    strongest at the top.

    Args:
        data: DataFrame holding the parameter and outcome columns. It is
            validated against RankCorrelationSchema before plotting.
        outcome_col: Column name of the outcome to correlate against.
        params: Column names of the parameters to rank.
        config: PlotConfig object for styling. A default PlotConfig is used
            when None.
        topk: Number of highest-ranked parameters to keep.
        path: If given, the figure is also passed to ``save_figure``.

    Returns:
        The Figure containing the tornado chart.

    Raises:
        ValueError: If ``outcome_col`` or any entry of ``params`` is not a
            column of ``data``.
    """
    # Validation
    RankCorrelationSchema.validate(data)
    for col in [outcome_col, *params]:
        if col not in data.columns:
            raise ValueError(f"Column '{col}' not found in data.")

    if config is None:
        config = PlotConfig()

    # Off-diagonal entry of the 2x2 correlation matrix is rho(param, outcome).
    rows = [
        (p, float(data[[p, outcome_col]].corr(method="spearman").iloc[0, 1]))
        for p in params
    ]

    def _magnitude(row: tuple[str, float]) -> float:
        # An undefined correlation is NaN, and every comparison against NaN
        # is False, which would leave the sort order (and therefore the
        # top-k selection) arbitrary. Rank such entries below every real
        # correlation instead.
        rho = row[1]
        return -1.0 if pd.isna(rho) else abs(rho)

    rows.sort(key=_magnitude, reverse=True)
    rows = rows[:topk]
    # The first bar passed to barh is drawn at the bottom, so reverse to put
    # the strongest correlation at the top.
    labels = [r[0] for r in rows][::-1]
    vals = [r[1] for r in rows][::-1]

    # Dynamic height
    height = 0.45 * len(labels) + 1.6
    fig = plt.figure(figsize=(config.default_figsize[0], height))
    ax = fig.gca()

    ax.barh(labels, vals, color=config.primary_color, alpha=0.8)
    ax.axvline(0, color="black", linewidth=1)
    ax.set_xlabel("Spearman rank correlation", fontsize=config.fontsize_label)
    ax.set_title(f"Sensitivity (tornado): {outcome_col}", fontsize=config.fontsize_title)
    ax.grid(True, axis="x", alpha=config.alpha_grid)
    ax.tick_params(axis="both", labelsize=config.fontsize_tick)

    if path:
        save_figure(fig, path, config)

    return fig

plot_trajectory(data, y_col, ylabel, config=None, q_low_col=None, q_high_col=None, **kwargs)

Plots a time-series trajectory with optional quantile ribbons.

Parameters:

Name Type Description Default
data DataFrame

DataFrame containing 'year' and the target columns.

required
y_col str

Column name for the primary metric.

required
ylabel str

Label for the y-axis.

required
config PlotConfig | None

PlotConfig object for styling.

None
q_low_col str | None

Optional column name for the lower quantile ribbon.

None
q_high_col str | None

Optional column name for the upper quantile ribbon.

None
**kwargs Any

Additional parameters passed to ax.plot.

{}

Returns:

Type Description
Figure

A matplotlib Figure object.

Source code in src/nhra_gt/visualization/trajectories.py
def plot_trajectory(
    data: pd.DataFrame,
    y_col: str,
    ylabel: str,
    config: PlotConfig | None = None,
    q_low_col: str | None = None,
    q_high_col: str | None = None,
    **kwargs: Any,
) -> Figure:
    """
    Plots a time-series trajectory with optional quantile ribbons.

    Args:
        data: DataFrame containing 'year' and the target columns. It is
            validated against TrajectorySchema before plotting.
        y_col: Column name for the primary metric.
        ylabel: Label for the y-axis.
        config: PlotConfig object for styling. A default PlotConfig is used
            when None.
        q_low_col: Optional column name for the lower quantile ribbon.
        q_high_col: Optional column name for the upper quantile ribbon.
            The ribbon is drawn only when both quantile columns are given
            and present in ``data``; otherwise it is skipped silently.
        **kwargs: Additional parameters passed to ax.plot. ``linewidth`` and
            ``color`` given here override the values from ``config``.

    Returns:
        A matplotlib Figure object.

    Raises:
        ValueError: If ``y_col`` is not a column of ``data``.
    """
    # Validation
    TrajectorySchema.validate(data)
    if y_col not in data.columns:
        raise ValueError(f"y_col '{y_col}' not found in data.")

    if config is None:
        config = PlotConfig()

    fig = plt.figure(figsize=config.default_figsize)
    ax = fig.gca()

    # Data extraction; non-numeric metric values become NaN.
    x = data["year"].to_numpy(dtype=float)
    y = pd.to_numeric(data[y_col], errors="coerce").to_numpy(dtype=float)

    # Main plot. Config styling acts as defaults that caller kwargs may
    # override; passing both explicitly would raise a duplicate-keyword
    # TypeError whenever the caller supplies linewidth= or color=.
    plot_kwargs = {
        "linewidth": config.linewidth,
        "color": config.primary_color,
        **kwargs,
    }
    ax.plot(x, y, **plot_kwargs)

    # Quantile ribbon
    if q_low_col and q_high_col and q_low_col in data.columns and q_high_col in data.columns:
        q_low = pd.to_numeric(data[q_low_col], errors="coerce").to_numpy(dtype=float)
        q_high = pd.to_numeric(data[q_high_col], errors="coerce").to_numpy(dtype=float)
        ax.fill_between(x, q_low, q_high, color=config.primary_color, alpha=config.alpha_ribbon)

    # Labels and grid
    ax.set_xlabel("Year", fontsize=config.fontsize_label)
    ax.set_ylabel(ylabel, fontsize=config.fontsize_label)
    ax.tick_params(axis="both", labelsize=config.fontsize_tick)
    ax.grid(True, alpha=config.alpha_grid)

    return fig