Skip to content

nhra_gt.helpers

Helper functions for simulation analysis and legacy compatibility.

This module contains utility functions for running sensitivity analyses, scenario summaries, and risk calculations, wrapping the core JAX engine.

Classes

Functions

relative_risk(pressure, offload_min, params=None)

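Simple monotone risk proxy used by legacy tests: exp(0.9 · max(0, pressure − 1) + 0.15 · max(0, offload_min) / 60). The params argument is accepted but ignored.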

Source code in src/nhra_gt/helpers.py
def relative_risk(pressure: float, offload_min: float, params: Params | None = None) -> float:
    """Return a monotone relative-risk proxy (kept for legacy tests).

    The score is ``exp(0.9 * excess_pressure + 0.15 * offload_hours)`` where
    ``excess_pressure`` is the amount by which ``pressure`` exceeds 1.0 and
    ``offload_hours`` is ``offload_min`` divided by 60. Both terms are floored
    at 0.0 exactly as ``max(0.0, x)`` would floor them.

    Args:
        pressure: Pressure value; only the portion above 1.0 contributes.
        offload_min: Offload duration in minutes; non-positive values add nothing.
        params: Accepted for call-site compatibility and otherwise ignored.

    Returns:
        The proxy as a Python ``float``; 1.0 when neither term is positive.
    """
    del params  # unused; kept so legacy callers can still pass it

    raw_pressure = float(pressure) - 1.0
    # Same selection rule as max(0.0, x): x is kept only when strictly > 0.0.
    excess_pressure = raw_pressure if raw_pressure > 0.0 else 0.0

    raw_offload = float(offload_min)
    offload_hours = (raw_offload if raw_offload > 0.0 else 0.0) / 60.0

    exponent = 0.9 * excess_pressure + 0.15 * offload_hours
    return float(np.exp(exponent))

run_hybrid(years, params, seed=123, n_mc=300, recorder=None, overrides=None)

Wrapper for the modern JAX engine that accepts Pydantic Params.

Source code in src/nhra_gt/helpers.py
def run_hybrid(
    years: list[int],
    params: Params,
    seed: int = 123,
    n_mc: int = 300,
    recorder: Any | None = None,
    overrides: dict[str, Any] | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Run the modern JAX engine starting from Pydantic ``Params``.

    ``params`` is converted with ``to_params_jax()``. All other arguments are
    forwarded unchanged to ``run_hybrid_modern`` as keyword arguments.

    Returns:
        Whatever pair of DataFrames ``run_hybrid_modern`` returns.
    """
    engine_options: dict[str, Any] = {
        "seed": seed,
        "n_mc": n_mc,
        "recorder": recorder,
        "overrides": overrides,
    }
    return run_hybrid_modern(years=years, p=params.to_params_jax(), **engine_options)

scenario_summary(years, params, scenarios, seed=123, n_mc=100)

Runs a batch of scenarios defined by intervention names.

Source code in src/nhra_gt/helpers.py
def scenario_summary(
    years: list[int],
    params: Params,
    scenarios: dict[str, list[str]],
    seed: int = 123,
    n_mc: int = 100,
) -> pd.DataFrame:
    """Run a batch of named scenarios and summarise each at its final year.

    Each scenario goes through the same steps:

    1. Start from a fresh ``params.to_params_jax()``.
    2. Apply the scenario's interventions in order with ``apply_intervention``.
    3. Simulate with ``run_hybrid_modern``, using the same ``seed`` and ``n_mc``
       for every scenario.

    The row with the largest ``year`` in the aggregate frame is then summarised.

    Args:
        years: Simulation years passed to the engine.
        params: Baseline parameters shared by all scenarios.
        scenarios: Mapping of scenario name to an ordered list of intervention names.
        seed: Engine seed, identical for every scenario.
        n_mc: Number of Monte Carlo replicates per scenario.

    Returns:
        One row per scenario with columns ``scenario``, ``rr_mean``,
        ``pressure_mean`` and ``within4_mean``. An empty ``scenarios`` mapping
        yields an empty frame that still has these columns.
    """
    columns = ["scenario", "rr_mean", "pressure_mean", "within4_mean"]
    rows: list[dict[str, Any]] = []
    for name, interventions in scenarios.items():
        # Rebuilt per scenario so interventions never leak between scenarios.
        p_jax = params.to_params_jax()
        for iv in interventions:
            p_jax = apply_intervention(p_jax, iv)
        agg, _ = run_hybrid_modern(years=years, p=p_jax, seed=seed, n_mc=n_mc)
        last = agg.sort_values("year").iloc[-1]
        rows.append(
            {
                "scenario": name,
                # NOTE(review): rr_mean falls back to pressure_mean here, while
                # one_way_sensitivity / probabilistic_sensitivity fall back to
                # 0.0; preserved for compatibility, confirm this is intended.
                "rr_mean": float(last.get("rr_mean", last.get("pressure_mean", 0.0))),
                "pressure_mean": float(last.get("pressure_mean", 0.0)),
                "within4_mean": float(last.get("within4_mean", 0.0)),
            }
        )
    # Explicit columns keep the schema stable even when no scenario was run.
    return pd.DataFrame(rows, columns=columns)

one_way_sensitivity(years, params, grid, seed=123, n_mc=50)

Performs one-way sensitivity analysis over a grid of parameter values.

Source code in src/nhra_gt/helpers.py
def one_way_sensitivity(
    years: list[int],
    params: Params,
    grid: dict[str, list[float]],
    seed: int = 123,
    n_mc: int = 50,
) -> pd.DataFrame:
    """Run a one-way sensitivity analysis over a grid of parameter values.

    Each ``(param_name, value)`` pair in ``grid`` is handled separately. A copy
    of ``params`` with only that field overridden is simulated via
    ``run_hybrid``, and the final-year ``rr_mean`` is recorded (0.0 if the
    column is absent). All runs share ``seed`` and ``n_mc``.

    Args:
        years: Simulation years passed to the engine.
        params: Baseline parameters; never mutated.
        grid: Mapping of ``Params`` field name to the values to sweep.
        seed: Engine seed, identical for every run.
        n_mc: Number of Monte Carlo replicates per run.

    Returns:
        Columns ``param``, ``value`` and ``rr_end``, one row per grid point.
        An empty grid yields an empty frame that still has these columns.

    Raises:
        ValueError: If a key of ``grid`` is not a field of ``Params``.
            ``model_copy(update=...)`` does not validate its input, so without
            this check a misspelled name would be accepted silently. That would
            presumably produce a flat, misleading sweep.
    """
    # Fail fast, before any (expensive) simulation runs.
    unknown = [name for name in grid if name not in Params.model_fields]
    if unknown:
        raise ValueError(f"Unknown parameter(s) in sensitivity grid: {unknown}")

    rows: list[dict[str, Any]] = []
    for param_name, values in grid.items():
        for v in values:
            p2 = Params(**params.model_dump())
            p2 = p2.model_copy(update={param_name: v})
            agg, _ = run_hybrid(years, p2, seed=seed, n_mc=n_mc)
            rr_end = float(agg.sort_values("year").iloc[-1].get("rr_mean", 0.0))
            rows.append({"param": param_name, "value": float(v), "rr_end": rr_end})
    return pd.DataFrame(rows, columns=["param", "value", "rr_end"])

probabilistic_sensitivity(years, params, interventions, seed=123, n_param=50, n_mc=50)

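Performs probabilistic sensitivity analysis (PSA). In each of n_param draws, noise_sd is sampled uniformly between 0.01 and 0.06 and the interventions are applied. The final-year rr_mean and pressure_mean of that draw are then recorded.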

Source code in src/nhra_gt/helpers.py
def probabilistic_sensitivity(
    years: list[int],
    params: Params,
    interventions: list[str],
    seed: int = 123,
    n_param: int = 50,
    n_mc: int = 50,
    noise_range: tuple[float, float] = (0.01, 0.06),
) -> list[dict[str, Any]]:
    """Run a probabilistic sensitivity analysis (PSA) over ``noise_sd``.

    Each of the ``n_param`` draws does the following:

    1. Sample ``noise_sd`` from ``rng.uniform(*noise_range)``.
    2. Convert the updated parameters with ``to_params_jax()`` and apply
       ``interventions`` in order.
    3. Simulate with ``run_hybrid_modern`` using a per-draw seed taken from the
       same generator.

    A single ``np.random.default_rng(seed)`` drives both the parameter samples
    and the per-draw engine seeds, so results are reproducible for a given
    ``seed``.

    Args:
        years: Simulation years passed to the engine.
        params: Baseline parameters; never mutated.
        interventions: Intervention names applied to every draw.
        seed: Seed for the PSA's own random generator.
        n_param: Number of parameter draws (values <= 0 yield no draws).
        n_mc: Number of Monte Carlo replicates per draw.
        noise_range: ``(low, high)`` bounds for the uniform ``noise_sd`` draw.
            The default reproduces the historical ``(0.01, 0.06)`` range and
            the exact same random stream.

    Returns:
        One dict per draw with keys ``noise_sd``, ``rr_end`` and ``pressure_end``.

    Raises:
        ValueError: If ``noise_range`` is not a pair with ``low <= high``.
    """
    low, high = map(float, noise_range)
    if not low <= high:  # also rejects NaN bounds
        raise ValueError(f"noise_range must satisfy low <= high, got {noise_range!r}")

    rng = np.random.default_rng(seed)
    out: list[dict[str, Any]] = []

    for _i in range(int(n_param)):
        sampled = Params(**params.model_dump())
        sampled = sampled.model_copy(update={"noise_sd": float(rng.uniform(low, high))})
        p_jax = sampled.to_params_jax()
        for iv in interventions:
            p_jax = apply_intervention(p_jax, iv)

        agg, _ = run_hybrid_modern(
            years=years,
            p=p_jax,
            # Upper bound kept as-is (exclusive) to preserve the historical seed stream.
            seed=int(rng.integers(0, 2**31 - 1)),
            n_mc=n_mc,
        )
        last = agg.sort_values("year").iloc[-1]
        out.append(
            {
                "noise_sd": float(sampled.noise_sd),
                "rr_end": float(last.get("rr_mean", 0.0)),
                "pressure_end": float(last.get("pressure_mean", 0.0)),
            }
        )

    return out