Distributional and Cross-Sectional Plotting.
Heatmaps, Pareto charts, and CDFs for comparing simulation results.
Classes
Functions
plot_strategy_heatmap(data, config=None, **kwargs)
Shows strategy shares over time for each game (one panel per game).
Source code in src/nhra_gt/visualization/distributional.py
def plot_strategy_heatmap(
    data: pd.DataFrame,
    config: PlotConfig | None = None,
    **kwargs,
) -> Figure:
    """
    Shows strategy shares over time for each game (one panel per game).

    Each panel is a line chart: one line per strategy, with ``year`` on the
    x-axis and the mean ``share`` on the y-axis (clamped to ``[0, 1]``).

    Args:
        data: Long-format frame with ``game``, ``year``, ``strategy`` and
            ``share`` columns; validated against ``StrategyFrequencySchema``.
        config: Plot styling. A default ``PlotConfig`` is used when omitted.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        The figure holding one vertically stacked panel per game.
    """
    # Validation
    StrategyFrequencySchema.validate(data)
    if config is None:
        config = PlotConfig()

    games = sorted(data["game"].unique())
    palette = config.color_palette
    # Height grows with the number of panels; width comes from the config.
    fig = plt.figure(figsize=(config.default_figsize[0], 2.1 * len(games)))

    # One colour per strategy, shared by every panel. Colours are handed out
    # in order of first appearance, so the legend drawn on the first panel
    # stays valid for the same strategies when they recur in later panels,
    # even if a later game has a different set of strategies.
    strategy_colors: dict = {}
    ax = None
    for panel, game in enumerate(games, start=1):
        ax = fig.add_subplot(len(games), 1, panel)
        shares = (
            data[data["game"] == game]
            .pivot_table(index="year", columns="strategy", values="share", aggfunc="mean")
            .fillna(0)  # a strategy absent in a given year counts as a zero share
        )
        for strategy in shares.columns:
            if strategy not in strategy_colors:
                strategy_colors[strategy] = palette[len(strategy_colors) % len(palette)]
            ax.plot(
                shares.index,
                shares[strategy],
                label=f"{strategy}",
                linewidth=config.linewidth,
                color=strategy_colors[strategy],
            )
        ax.set_ylim(0, 1)
        ax.set_ylabel(game, fontsize=config.fontsize_label)
        ax.grid(True, alpha=config.alpha_grid)
        ax.tick_params(axis="both", labelsize=config.fontsize_tick)
        if panel == 1:
            # A single legend on the top panel serves the whole figure.
            ax.legend(ncol=4, fontsize=config.fontsize_legend, loc="upper right", frameon=False)

    # Only the bottom panel carries the x-axis label; skip when there are no panels.
    if ax is not None:
        ax.set_xlabel("Year", fontsize=config.fontsize_label)
    return fig
|
plot_risk_heatmap(data, x_col, y_col, z_col, title, config=None)
Plots a 2D heatmap of system risk/state.
Typically used for parameter sweeps (e.g. Bed Capacity vs Demand).
Source code in src/nhra_gt/visualization/distributional.py
def plot_risk_heatmap(
    data: pd.DataFrame,
    x_col: str,
    y_col: str,
    z_col: str,
    title: str,
    config: PlotConfig | None = None,
) -> Figure:
    """
    Plots a 2D heatmap of system risk/state.
    Typically used for parameter sweeps (e.g. Bed Capacity vs Demand).

    Args:
        data: Frame containing the three columns named below.
        x_col: Column whose values become the heatmap columns.
        y_col: Column whose values become the heatmap rows.
        z_col: Column providing the cell values; also used, prettified, as
            the colour-bar label.
        title: Axes title.
        config: Plot styling. A default ``PlotConfig`` is used when omitted.

    Returns:
        The figure containing the annotated heatmap.
    """
    if config is None:
        config = PlotConfig()
    fig, ax = plt.subplots(figsize=config.default_figsize)
    # pivot_table aggregates duplicate (x, y) pairs with its default (mean).
    pivot = data.pivot_table(index=y_col, columns=x_col, values=z_col)
    sns.heatmap(
        pivot,
        ax=ax,
        cmap="YlGnBu",  # Professional gradient
        # The colour-bar label must be passed through cbar_kws; heatmap has no
        # dedicated label keyword and forwards unknown kwargs to the mesh artist.
        cbar_kws={"label": z_col.replace("_", " ").title()},
        annot=True,
        fmt=".2f",
        annot_kws={"size": 8},
    )
    ax.set_title(title, fontsize=config.fontsize_title)
    ax.invert_yaxis()  # Standard orientation for sweeps
    return fig
|
plot_distributions(data, value_col, group_col=None, config=None)
Plots distributions (KDE/Histogram) of a variable, optionally grouped.
Source code in src/nhra_gt/visualization/distributional.py
def plot_distributions(
    data: pd.DataFrame,
    value_col: str,
    group_col: str | None = None,
    config: PlotConfig | None = None,
) -> Figure:
    """
    Plots distributions (KDE/Histogram) of a variable, optionally grouped.

    Args:
        data: Frame holding ``value_col`` (and ``group_col`` when given).
        value_col: Column whose distribution is drawn.
        group_col: Optional grouping column. When set, one filled KDE per
            group is drawn; otherwise a single histogram with a KDE overlay.
        config: Plot styling. A default ``PlotConfig`` is used when omitted.

    Returns:
        The figure containing the distribution plot.
    """
    if config is None:
        config = PlotConfig()

    fig = plt.figure(figsize=config.default_figsize)
    ax = fig.gca()

    if not group_col:
        # Ungrouped: histogram plus KDE overlay in the primary colour.
        sns.histplot(
            data=data,
            x=value_col,
            kde=True,
            color=config.primary_color,
            ax=ax,
        )
    else:
        # Grouped: one filled density curve per level of group_col.
        sns.kdeplot(
            data=data,
            x=value_col,
            hue=group_col,
            fill=True,
            palette=config.color_palette,
            ax=ax,
        )

    ax.set_title(f"Distribution: {value_col}", fontsize=config.fontsize_title)
    ax.set_xlabel(value_col, fontsize=config.fontsize_label)
    ax.grid(True, alpha=config.alpha_grid)
    return fig
|
plot_pareto(data, x_col, y_col, label_col=None, config=None)
Plots a Pareto frontier (tradeoff scatter plot).
Source code in src/nhra_gt/visualization/distributional.py
def plot_pareto(
    data: pd.DataFrame,
    x_col: str,
    y_col: str,
    label_col: str | None = None,
    config: PlotConfig | None = None,
) -> Figure:
    """
    Plots a Pareto frontier (tradeoff scatter plot).

    Every row of ``data`` is drawn as a point; no frontier is computed or
    highlighted by this function.

    Args:
        data: Frame holding the columns named below.
        x_col: Column plotted on the x-axis.
        y_col: Column plotted on the y-axis.
        label_col: Optional column whose values annotate each point.
        config: Plot styling. A default ``PlotConfig`` is used when omitted.

    Returns:
        The figure containing the scatter plot.
    """
    if config is None:
        config = PlotConfig()
    fig = plt.figure(figsize=config.default_figsize)
    ax = fig.gca()
    ax.scatter(data[x_col], data[y_col], color=config.primary_color, alpha=0.7)
    if label_col:
        # Walk the three columns in lockstep instead of using iterrows():
        # iterrows() builds a Series per row, which upcasts mixed numeric
        # columns to a common dtype (an integer label would render as "1.0")
        # and is far slower on large frames.
        for label, x, y in zip(data[label_col], data[x_col], data[y_col]):
            ax.annotate(label, (x, y), fontsize=8, alpha=0.8)
    ax.set_xlabel(x_col, fontsize=config.fontsize_label)
    ax.set_ylabel(y_col, fontsize=config.fontsize_label)
    ax.set_title(f"Tradeoff: {x_col} vs {y_col}", fontsize=config.fontsize_title)
    ax.grid(True, alpha=config.alpha_grid)
    return fig
|
plot_stacked_bar(data, title, xlabel, config=None, **kwargs)
Plots a stacked horizontal bar chart.
Source code in src/nhra_gt/visualization/distributional.py
def plot_stacked_bar(
    data: pd.DataFrame,
    title: str,
    xlabel: str,
    config: PlotConfig | None = None,
    **kwargs,
) -> Figure:
    """
    Plots a stacked horizontal bar chart.

    Args:
        data: Frame handed directly to ``DataFrame.plot``.
        title: Axes title.
        xlabel: Label for the x-axis.
        config: Plot styling. A default ``PlotConfig`` is used when omitted.
        **kwargs: Extra keyword arguments forwarded to ``DataFrame.plot``.

    Returns:
        The figure containing the stacked bars.
    """
    if config is None:
        config = PlotConfig()

    fig, ax = plt.subplots(figsize=config.default_figsize)

    # Fixed options for the bar chart; caller-supplied kwargs are appended.
    bar_options = {
        "kind": "barh",
        "stacked": True,
        "ax": ax,
        "color": config.color_palette,
    }
    data.plot(**bar_options, **kwargs)

    ax.set_title(title, fontsize=config.fontsize_title)
    ax.set_xlabel(xlabel, fontsize=config.fontsize_label)
    ax.tick_params(axis="both", labelsize=config.fontsize_tick)
    # Vertical grid lines only, to aid reading bar lengths.
    ax.grid(True, axis="x", alpha=config.alpha_grid)
    return fig
|
plot_comparison_bar(data, x_col, y_col, title, ylabel, config=None, **kwargs)
Plots a simple bar chart for scenario comparison.
Source code in src/nhra_gt/visualization/distributional.py
def plot_comparison_bar(
    data: pd.DataFrame,
    x_col: str,
    y_col: str,
    title: str,
    ylabel: str,
    config: PlotConfig | None = None,
    **kwargs,
) -> Figure:
    """
    Plots a simple bar chart for scenario comparison.

    Args:
        data: Frame holding the columns named below.
        x_col: Categorical column for the x-axis; also drives bar colour and,
            prettified, the x-axis label.
        y_col: Column providing the bar heights.
        title: Axes title.
        ylabel: Label for the y-axis.
        config: Plot styling. A default ``PlotConfig`` is used when omitted.
        **kwargs: Extra keyword arguments forwarded to ``sns.barplot``.

    Returns:
        The figure containing the bar chart.
    """
    if config is None:
        config = PlotConfig()
    fig, ax = plt.subplots(figsize=config.default_figsize)
    sns.barplot(
        data=data,
        x=x_col,
        y=y_col,
        ax=ax,
        palette=config.color_palette,
        # hue mirrors x so each bar takes its own palette colour; the
        # resulting legend would be redundant, so it is suppressed.
        hue=x_col,
        legend=False,
        **kwargs,
    )
    ax.set_title(title, fontsize=config.fontsize_title)
    ax.set_ylabel(ylabel, fontsize=config.fontsize_label)
    ax.set_xlabel(x_col.replace("_", " ").title(), fontsize=config.fontsize_label)
    ax.tick_params(axis="both", labelsize=config.fontsize_tick)
    # Act on this function's own axes/figure rather than pyplot's implicit
    # "current" ones, so the result cannot depend on global pyplot state.
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()
    return fig
|
plot_cdf(data, value_col=None, title='Cumulative Distribution Function', config=None)
Plots a Cumulative Distribution Function (CDF).
Source code in src/nhra_gt/visualization/distributional.py
def plot_cdf(
    data: pd.Series | pd.DataFrame,
    value_col: str | None = None,
    title: str = "Cumulative Distribution Function",
    config: PlotConfig | None = None,
) -> Figure:
    """
    Plots a Cumulative Distribution Function (CDF).

    Args:
        data: The sample, either as a Series or as a DataFrame from which
            ``value_col`` is taken. Missing values are dropped.
        value_col: Column to use for DataFrame input; also the x-axis label
            when given (otherwise the label is "Value").
        title: Axes title.
        config: Plot styling. A default ``PlotConfig`` is used when omitted.

    Returns:
        The figure containing the CDF curve.

    Raises:
        ValueError: If ``data`` is a DataFrame and ``value_col`` is ``None``.
    """
    if config is None:
        config = PlotConfig()

    # Resolve the raw sample first, then clean and sort it in one place.
    if not isinstance(data, pd.DataFrame):
        sample = data
    elif value_col is None:
        raise ValueError("value_col must be provided for DataFrame input")
    else:
        sample = data[value_col]
    ordered = sample.dropna().sort_values().reset_index(drop=True)

    # After the index reset, position k (0-based) maps to (k + 1) / n.
    cumulative = (ordered.index + 1) / len(ordered)
    x_label = value_col if value_col else "Value"

    fig, ax = plt.subplots(figsize=config.default_figsize)
    ax.plot(ordered, cumulative, color=config.primary_color, linewidth=config.linewidth)
    ax.set_title(title, fontsize=config.fontsize_title)
    ax.set_xlabel(x_label, fontsize=config.fontsize_label)
    ax.set_ylabel("CDF", fontsize=config.fontsize_label)
    ax.grid(True, alpha=config.alpha_grid)
    return fig
|