Skip to content

nhra_gt.visualization.distributional

Distributional and Cross-Sectional Plotting.

Heatmaps, Pareto (tradeoff) charts, distribution plots, bar charts, and CDFs for comparing simulation results.

Classes

Functions

plot_strategy_heatmap(data, config=None, **kwargs)

Shows strategy shares over time for each game as line charts (one panel per game, one line per strategy).

Source code in src/nhra_gt/visualization/distributional.py
def plot_strategy_heatmap(
    data: pd.DataFrame,
    config: PlotConfig | None = None,
    **kwargs,
) -> Figure:
    """
    Shows strategy shares over time for each game (one panel per game).

    Despite the name, each panel is a line chart: one line per strategy,
    plotting the mean ``share`` for each ``year`` (missing combinations are
    drawn as 0).

    Args:
        data: Frame validated by ``StrategyFrequencySchema``; the columns
            ``game``, ``year``, ``strategy`` and ``share`` are read.
        config: Plot styling; a default ``PlotConfig()`` is used when omitted.
        **kwargs: Accepted for API compatibility; currently unused.

    Returns:
        A figure with one vertically stacked panel per game, in sorted game
        order. Only the first panel carries a legend and only the last panel
        carries the "Year" x-axis label.

    Raises:
        ValueError: If ``data`` contains no games to plot.
    """
    # Validation
    StrategyFrequencySchema.validate(data)

    if config is None:
        config = PlotConfig()

    games = sorted(data["game"].unique())
    if not games:
        # Without this guard the loop below never runs and the final
        # ``ax.set_xlabel`` call fails with an UnboundLocalError.
        raise ValueError("plot_strategy_heatmap: 'data' contains no games to plot")

    n_games = len(games)
    n_colors = len(config.color_palette)
    # Figure height scales with the number of panels.
    fig = plt.figure(figsize=(config.default_figsize[0], 2.1 * n_games))

    for i, g in enumerate(games, start=1):
        ax = fig.add_subplot(n_games, 1, i)
        # Boolean-mask selection already yields a new frame; no copy needed.
        sub = data[data["game"] == g]

        pivot = sub.pivot_table(
            index="year", columns="strategy", values="share", aggfunc="mean"
        ).fillna(0)

        for idx, col in enumerate(pivot.columns):
            # Cycle through the palette if there are more strategies than colors.
            color = config.color_palette[idx % n_colors]
            ax.plot(
                pivot.index, pivot[col], label=f"{col}", linewidth=config.linewidth, color=color
            )

        ax.set_ylim(0, 1)
        ax.set_ylabel(g, fontsize=config.fontsize_label)
        ax.grid(True, alpha=config.alpha_grid)
        ax.tick_params(axis="both", labelsize=config.fontsize_tick)

        if i == 1:
            ax.legend(ncol=4, fontsize=config.fontsize_legend, loc="upper right", frameon=False)

    # ``ax`` is the bottom panel here; label the shared time axis once.
    ax.set_xlabel("Year", fontsize=config.fontsize_label)
    return fig

plot_risk_heatmap(data, x_col, y_col, z_col, title, config=None)

Plots a 2D heatmap of system risk/state. Typically used for parameter sweeps (e.g., bed capacity vs. demand).

Source code in src/nhra_gt/visualization/distributional.py
def plot_risk_heatmap(
    data: pd.DataFrame,
    x_col: str,
    y_col: str,
    z_col: str,
    title: str,
    config: PlotConfig | None = None,
) -> Figure:
    """
    Plots a 2D heatmap of system risk/state.
    Typically used for parameter sweeps (e.g. Bed Capacity vs Demand).

    Args:
        data: Frame holding one or more rows per (x, y) cell.
        x_col: Column whose values become the heatmap columns.
        y_col: Column whose values become the heatmap rows.
        z_col: Column holding the cell values; several rows for the same
            cell are averaged (``pivot_table``'s default aggregation). Also
            used, title-cased, as the colorbar label.
        title: Axes title.
        config: Plot styling; a default ``PlotConfig()`` is used when omitted.

    Returns:
        The figure containing the annotated heatmap.
    """
    if config is None:
        config = PlotConfig()

    fig, ax = plt.subplots(figsize=config.default_figsize)

    pivot = data.pivot_table(index=y_col, columns=x_col, values=z_col)

    sns.heatmap(
        pivot,
        ax=ax,
        cmap="YlGnBu",  # Professional gradient
        # The colorbar label must go through ``cbar_kws``. Unknown keywords
        # (the former ``cbar_klabel``) are forwarded to ``Axes.pcolormesh``,
        # which rejects them and makes every call fail.
        cbar_kws={"label": z_col.replace("_", " ").title()},
        annot=True,
        fmt=".2f",
        annot_kws={"size": 8},
    )

    ax.set_title(title, fontsize=config.fontsize_title)
    # seaborn draws the first pivot row at the top; invert so the y values of
    # the sweep increase upwards, like a conventional plot.
    ax.invert_yaxis()
    return fig

plot_distributions(data, value_col, group_col=None, config=None)

Plots distributions (KDE/Histogram) of a variable, optionally grouped.

Source code in src/nhra_gt/visualization/distributional.py
def plot_distributions(
    data: pd.DataFrame,
    value_col: str,
    group_col: str | None = None,
    config: PlotConfig | None = None,
) -> Figure:
    """
    Draw the distribution of ``value_col``.

    When ``group_col`` is given (and non-empty), one filled KDE per group is
    drawn using the configured palette. Otherwise a single histogram with a
    KDE overlay is drawn in the primary color.

    Args:
        data: Source frame.
        value_col: Column whose distribution is plotted.
        group_col: Optional column used as the ``hue`` for per-group KDEs.
        config: Plot styling; a default ``PlotConfig()`` is used when omitted.

    Returns:
        The figure holding the plot.
    """
    cfg = PlotConfig() if config is None else config

    figure = plt.figure(figsize=cfg.default_figsize)
    axis = figure.gca()

    if not group_col:
        # Single population: histogram plus density estimate.
        sns.histplot(data=data, x=value_col, kde=True, color=cfg.primary_color, ax=axis)
    else:
        # Grouped: overlaid, filled density curves, one per group.
        sns.kdeplot(
            data=data,
            x=value_col,
            hue=group_col,
            fill=True,
            palette=cfg.color_palette,
            ax=axis,
        )

    axis.set_xlabel(value_col, fontsize=cfg.fontsize_label)
    axis.set_title(f"Distribution: {value_col}", fontsize=cfg.fontsize_title)
    axis.grid(True, alpha=cfg.alpha_grid)

    return figure

plot_pareto(data, x_col, y_col, label_col=None, config=None)

Plots a tradeoff scatter plot for inspecting a Pareto frontier (the frontier itself is not computed or highlighted).

Source code in src/nhra_gt/visualization/distributional.py
def plot_pareto(
    data: pd.DataFrame,
    x_col: str,
    y_col: str,
    label_col: str | None = None,
    config: PlotConfig | None = None,
) -> Figure:
    """
    Plots a Pareto frontier (tradeoff scatter plot).

    Every row of ``data`` is drawn as one point. The frontier itself is not
    computed; the plot is meant for visually inspecting the tradeoff.

    Args:
        data: Source frame.
        x_col: Column plotted on the x axis.
        y_col: Column plotted on the y axis.
        label_col: Optional column whose values annotate each point.
        config: Plot styling; a default ``PlotConfig()`` is used when omitted.

    Returns:
        The figure holding the scatter plot.
    """
    if config is None:
        config = PlotConfig()

    fig = plt.figure(figsize=config.default_figsize)
    ax = fig.gca()

    ax.scatter(data[x_col], data[y_col], color=config.primary_color, alpha=0.7)

    if label_col:
        # Walk the three columns in lockstep instead of using ``iterrows``.
        # ``iterrows`` packs each row into a single Series, which upcasts
        # all-numeric rows (an int label beside float coordinates was
        # annotated as "3.0"). It is also much slower.
        for label, x, y in zip(data[label_col], data[x_col], data[y_col]):
            ax.annotate(str(label), (x, y), fontsize=8, alpha=0.8)

    ax.set_xlabel(x_col, fontsize=config.fontsize_label)
    ax.set_ylabel(y_col, fontsize=config.fontsize_label)
    ax.set_title(f"Tradeoff: {x_col} vs {y_col}", fontsize=config.fontsize_title)
    ax.grid(True, alpha=config.alpha_grid)

    return fig

plot_stacked_bar(data, title, xlabel, config=None, **kwargs)

Plots a stacked horizontal bar chart.

Source code in src/nhra_gt/visualization/distributional.py
def plot_stacked_bar(
    data: pd.DataFrame,
    title: str,
    xlabel: str,
    config: PlotConfig | None = None,
    **kwargs,
) -> Figure:
    """
    Render ``data`` as a stacked horizontal bar chart.

    Each row of ``data`` becomes one bar; each column becomes one stacked
    segment colored from the configured palette.

    Args:
        data: Wide frame; the index labels the bars and the columns the segments.
        title: Axes title.
        xlabel: Label for the value (x) axis.
        config: Plot styling; a default ``PlotConfig()`` is used when omitted.
        **kwargs: Forwarded unchanged to pandas' ``barh`` plotting call.

    Returns:
        The figure holding the chart.
    """
    cfg = config if config is not None else PlotConfig()

    figure, axis = plt.subplots(figsize=cfg.default_figsize)
    data.plot.barh(stacked=True, ax=axis, color=cfg.color_palette, **kwargs)

    # Cosmetics: titles, tick sizing, and vertical grid lines only.
    axis.set_title(title, fontsize=cfg.fontsize_title)
    axis.set_xlabel(xlabel, fontsize=cfg.fontsize_label)
    axis.tick_params(axis="both", labelsize=cfg.fontsize_tick)
    axis.grid(True, axis="x", alpha=cfg.alpha_grid)

    return figure

plot_comparison_bar(data, x_col, y_col, title, ylabel, config=None, **kwargs)

Plots a simple bar chart for scenario comparison.

Source code in src/nhra_gt/visualization/distributional.py
def plot_comparison_bar(
    data: pd.DataFrame,
    x_col: str,
    y_col: str,
    title: str,
    ylabel: str,
    config: PlotConfig | None = None,
    **kwargs,
) -> Figure:
    """
    Plots a simple bar chart for scenario comparison.

    Each category in ``x_col`` gets its own palette color (it is also used as
    the ``hue``, with the legend suppressed).

    Args:
        data: Source frame.
        x_col: Categorical column along the x axis; title-cased as the x label.
        y_col: Column giving bar heights.
        title: Axes title.
        ylabel: Label for the y axis.
        config: Plot styling; a default ``PlotConfig()`` is used when omitted.
        **kwargs: Forwarded unchanged to ``sns.barplot``.

    Returns:
        The figure holding the chart.
    """
    if config is None:
        config = PlotConfig()

    fig, ax = plt.subplots(figsize=config.default_figsize)
    sns.barplot(
        data=data,
        x=x_col,
        y=y_col,
        ax=ax,
        palette=config.color_palette,
        hue=x_col,
        legend=False,
        **kwargs,
    )

    ax.set_title(title, fontsize=config.fontsize_title)
    ax.set_ylabel(ylabel, fontsize=config.fontsize_label)
    ax.set_xlabel(x_col.replace("_", " ").title(), fontsize=config.fontsize_label)
    ax.tick_params(axis="both", labelsize=config.fontsize_tick)
    # Act on the axes/figure this function owns rather than pyplot's implicit
    # "current" ones (plt.xticks / plt.tight_layout). ``labelrotation`` also
    # persists if the tick labels are regenerated later.
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()

    return fig

plot_cdf(data, value_col=None, title='Cumulative Distribution Function', config=None)

Plots a Cumulative Distribution Function (CDF).

Source code in src/nhra_gt/visualization/distributional.py
def plot_cdf(
    data: pd.Series | pd.DataFrame,
    value_col: str | None = None,
    title: str = "Cumulative Distribution Function",
    config: PlotConfig | None = None,
) -> Figure:
    """
    Draw the empirical cumulative distribution of a set of values.

    Missing values are dropped. The remaining values are sorted, and the
    k-th smallest (1-based) is plotted against k / n.

    Args:
        data: A Series of values, or a DataFrame together with ``value_col``.
        value_col: Column to read when ``data`` is a DataFrame. When set, it
            is also used as the x-axis label; otherwise the label is "Value".
        title: Axes title.
        config: Plot styling; a default ``PlotConfig()`` is used when omitted.

    Returns:
        The figure holding the CDF curve.

    Raises:
        ValueError: If ``data`` is a DataFrame and ``value_col`` is None.
    """
    cfg = PlotConfig() if config is None else config

    if isinstance(data, pd.DataFrame):
        if value_col is None:
            raise ValueError("value_col must be provided for DataFrame input")
        raw = data[value_col]
    else:
        raw = data

    # Re-index 0..n-1 after sorting so the position gives the rank.
    ordered = raw.dropna().sort_values().reset_index(drop=True)
    n_points = len(ordered)
    cumulative = (ordered.index + 1) / n_points

    figure, axis = plt.subplots(figsize=cfg.default_figsize)
    axis.plot(ordered, cumulative, color=cfg.primary_color, linewidth=cfg.linewidth)

    x_label = value_col or "Value"
    axis.set_title(title, fontsize=cfg.fontsize_title)
    axis.set_xlabel(x_label, fontsize=cfg.fontsize_label)
    axis.set_ylabel("CDF", fontsize=cfg.fontsize_label)
    axis.grid(True, alpha=cfg.alpha_grid)

    return figure