Skip to content

API Reference

This section provides auto-generated documentation for the nhra_gt Python package.

If you want comprehensive coverage, start with the Full API Index which is generated from the package module tree.


Core Modules

Engine

nhra_gt.engine.Params

Bases: BaseModel

Pydantic Params wrapper for validation and tooling.

Source code in src/nhra_gt/domain/params.py
class Params(BaseModel):
    """Pydantic Params wrapper for validation and tooling.

    Parameters are split between three nested groups (``ops``, ``behavior``,
    ``policy``) and a set of top-level fields. Numeric fields carry
    ``ge``/``le`` bounds, and ``validate_assignment=True`` makes attribute
    assignment re-run validation.
    """

    # Operational, Behavioral & Policy Groupings
    ops: OperationalParams = Field(default_factory=OperationalParams)
    behavior: BehavioralParams = Field(default_factory=BehavioralParams)
    policy: PolicyParams = Field(default_factory=PolicyParams)

    rurality_weight: float = Field(default=0.35, ge=0.0, le=1.0)
    remote_weight: float = Field(default=0.07, ge=0.0, le=1.0)
    nominal_cth_share_target: float = Field(default=0.45, ge=0.0, le=1.0)
    effective_cth_share_base: float = Field(default=0.38, ge=0.0, le=1.0)

    nep_annual_growth: float = Field(default=0.03, ge=-1.0, le=1.0)
    input_cost_annual_growth: float = Field(default=0.04, ge=-1.0, le=1.0)
    demand_base: float = Field(default=0.85, ge=0.0, le=10.0)
    avoidable_ed_share: float = Field(default=0.18, ge=0.0, le=1.0)

    discharge_delay_base: float = Field(default=1.0, ge=0.0, le=10.0)
    bed_capacity_index: float = Field(default=1.0, ge=0.0, le=10.0)
    capacity_lag: float = Field(default=0.15, ge=0.0, le=10.0)
    expansion_lag: float = Field(default=0.10, ge=0.0, le=10.0)
    contraction_lag: float = Field(default=0.20, ge=0.0, le=10.0)
    adjustment_cost_beta: float = Field(default=5.0, ge=0.0, le=1e6)

    cost_shifting_intensity: float = Field(default=0.35, ge=0.0, le=10.0)
    fragmentation_index: float = Field(default=1.0, ge=0.0, le=10.0)
    audit_pressure: float = Field(default=0.50, ge=0.0, le=10.0)
    admin_burden_weight: float = Field(default=0.25, ge=0.0, le=10.0)
    cannibalization_beta: float = Field(default=0.10, ge=0.0, le=10.0)

    block_funding_base: float = Field(default=0.15, ge=0.0, le=1.0)
    shifting_friction: float = Field(default=0.05, ge=0.0, le=10.0)

    signal_lag_months: int = Field(default=1, ge=0, le=24)
    claims_lag_months: int = Field(default=3, ge=0, le=24)

    occupancy_base: float = Field(default=0.88, ge=0.0, le=10.0)
    offload_base_min: float = Field(default=18.0, ge=0.0, le=1e6)
    within4_base: float = Field(default=0.53, ge=0.0, le=1.0)

    rr_beta_pressure: float = Field(default=0.35, ge=0.0, le=10.0)
    rr_beta_offload: float = Field(default=0.015, ge=0.0, le=10.0)
    offload_threshold_min: float = Field(default=20.0, ge=0.0, le=1e6)

    tau: float = Field(default=0.25, ge=0.0, le=10.0)
    bargaining_cost: float = Field(default=0.12, ge=0.0, le=10.0)
    political_salience: float = Field(default=0.30, ge=0.0, le=10.0)

    gp_out_of_pocket: float = Field(default=40.0, ge=0.0, le=1e6)
    gp_wait_time_min: float = Field(default=15.0, ge=0.0, le=1e6)
    patient_time_value_hour: float = Field(default=25.0, ge=0.0, le=1e6)

    cap_growth: float = Field(default=0.065, ge=0.0, le=10.0)
    has_cumulative_cap: bool = False

    cap_rule_type: int = Field(default=0, ge=0, le=1)
    audit_rule_type: int = Field(default=0, ge=0, le=1)
    orchestration_mode: int = Field(default=0, ge=0, le=10)
    equilibrium_selection_rule: str = "nash"
    isolated_game: str | None = None
    use_stage_game_equilibria: bool = True

    use_equilibrium_bargaining: bool = False
    use_quantal_response: bool = False
    qre_lambda: float = Field(default=4.0, ge=0.0, le=1e6)
    use_burden_feedback: bool = False
    burden_to_throughput_beta: float = Field(default=0.06, ge=0.0, le=10.0)
    noise_sd: float = Field(default=0.03, ge=0.0, le=10.0)

    # validate_assignment=True: attribute assignment re-validates (relied on by `replace`).
    model_config = ConfigDict(validate_assignment=True)

    @classmethod
    def from_flat_dict(cls, data: dict[str, Any]) -> Params:
        """Create a ``Params`` object from a potentially flat dictionary.

        Keys naming a field of ``OperationalParams``, ``BehavioralParams`` or
        ``PolicyParams`` are moved into the ``ops``, ``behavior`` or ``policy``
        group respectively. If ``data`` already holds a group (as a dict or as
        a model instance), the flat keys are merged into it and take
        precedence over its values. A key that names a field in more than one
        group is assigned to the first of ops, behavior, policy.

        The input mapping, and any nested dicts inside it, are not modified.

        Args:
            data: Field names mapped to values; nested group entries and flat
                group-field keys may be mixed freely.

        Returns:
            A validated ``Params`` instance.
        """
        # Work on a shallow copy so the caller's mapping is left untouched.
        remaining = dict(data)
        group_fields = (
            ("ops", OperationalParams.model_fields.keys()),
            ("behavior", BehavioralParams.model_fields.keys()),
            ("policy", PolicyParams.model_fields.keys()),
        )

        # Extract every group's flat keys before touching any group entries.
        extracted: dict[str, dict[str, Any]] = {}
        for group, fields in group_fields:
            extracted[group] = {k: remaining.pop(k) for k in list(remaining) if k in fields}

        for group, flat_values in extracted.items():
            nested = remaining.get(group)
            if isinstance(nested, dict):
                # Build a fresh dict instead of updating the caller's nested dict.
                remaining[group] = {**nested, **flat_values}
            elif flat_values:
                if hasattr(nested, "model_dump"):
                    # Group supplied as a model instance: keep its values and
                    # overlay the flat keys rather than discarding it.
                    remaining[group] = {**nested.model_dump(), **flat_values}
                else:
                    remaining[group] = flat_values

        return cls(**remaining)

    def flatten(self) -> dict[str, Any]:
        """Return a flat dictionary representation of all parameters.

        Top-level fields come first, followed by the ``ops``, ``behavior`` and
        ``policy`` group fields. If a name occurs in more than one place, the
        later source (ops, then behavior, then policy) wins.
        """
        data = self.model_dump()
        ops = data.pop("ops", {})
        behavior = data.pop("behavior", {})
        policy = data.pop("policy", {})
        return {**data, **ops, **behavior, **policy}

    def to_params_jax(self) -> ParamsJax:
        """Convert these Pydantic Params to the JAX-native ``ParamsJax``.

        The three nested groups become their ``*ParamsJax`` counterparts from
        ``.state``. All remaining top-level fields are forwarded to
        ``ParamsJax`` as keyword arguments.
        """
        # Local import, as in the original module layout.
        from .state import BehavioralParamsJax, OperationalParamsJax, PolicyParamsJax

        data = self.model_dump()
        # Handle nested groups
        ops_data = data.pop("ops", {})
        behavior_data = data.pop("behavior", {})
        policy_data = data.pop("policy", {})

        return ParamsJax(
            ops=OperationalParamsJax(**ops_data),
            behavior=BehavioralParamsJax(**behavior_data),
            policy=PolicyParamsJax(**policy_data),
            **data,
        )

    def replace(self, **kwargs: Any) -> Params:
        """Return a copy with the given fields replaced, validating each new value.

        ``model_copy(update=...)`` does not validate its update values. Each
        replacement is therefore applied by attribute assignment on a copy,
        where ``validate_assignment=True`` enforces the field constraints.
        Pydantic also rejects names that are not declared fields. The original
        object is not modified.

        Raises:
            ValueError: if a name is not a field or a value fails validation
                (pydantic's ``ValidationError`` is a ``ValueError`` subclass).
        """
        updated = self.model_copy()
        for name, value in kwargs.items():
            setattr(updated, name, value)
        return updated

Functions

from_flat_dict(data) classmethod

Creates a Params object from a potentially flat dictionary.

Source code in src/nhra_gt/domain/params.py
@classmethod
def from_flat_dict(cls, data: dict[str, Any]) -> Params:
    """Create a ``Params`` object from a potentially flat dictionary.

    Keys naming a field of ``OperationalParams``, ``BehavioralParams`` or
    ``PolicyParams`` are moved into the ``ops``, ``behavior`` or ``policy``
    group respectively. If ``data`` already holds a group (as a dict or as a
    model instance), the flat keys are merged into it and take precedence
    over its values. A key that names a field in more than one group is
    assigned to the first of ops, behavior, policy.

    The input mapping, and any nested dicts inside it, are not modified.

    Args:
        data: Field names mapped to values; nested group entries and flat
            group-field keys may be mixed freely.

    Returns:
        A validated ``Params`` instance.
    """
    # Work on a shallow copy so the caller's mapping is left untouched.
    remaining = dict(data)
    group_fields = (
        ("ops", OperationalParams.model_fields.keys()),
        ("behavior", BehavioralParams.model_fields.keys()),
        ("policy", PolicyParams.model_fields.keys()),
    )

    # Extract every group's flat keys before touching any group entries.
    extracted: dict[str, dict[str, Any]] = {}
    for group, fields in group_fields:
        extracted[group] = {k: remaining.pop(k) for k in list(remaining) if k in fields}

    for group, flat_values in extracted.items():
        nested = remaining.get(group)
        if isinstance(nested, dict):
            # Build a fresh dict instead of updating the caller's nested dict.
            remaining[group] = {**nested, **flat_values}
        elif flat_values:
            if hasattr(nested, "model_dump"):
                # Group supplied as a model instance: keep its values and
                # overlay the flat keys rather than discarding it.
                remaining[group] = {**nested.model_dump(), **flat_values}
            else:
                remaining[group] = flat_values

    return cls(**remaining)

flatten()

Returns a flat dictionary representation of all parameters.

Source code in src/nhra_gt/domain/params.py
def flatten(self) -> dict[str, Any]:
    """Collapse the nested parameter groups into one flat mapping.

    Top-level fields keep their positions; fields from ``ops``, ``behavior``
    and ``policy`` are then merged in that order, so on a name clash the
    later group's value wins.
    """
    flat = self.model_dump()
    groups = [flat.pop(group_name, {}) for group_name in ("ops", "behavior", "policy")]
    for group in groups:
        flat.update(group)
    return flat

to_params_jax()

Converts Pydantic Params to JAX-native ParamsJax.

Source code in src/nhra_gt/domain/params.py
def to_params_jax(self) -> ParamsJax:
    """Build the JAX-native ``ParamsJax`` mirror of these parameters.

    Each nested group is rebuilt as its ``*ParamsJax`` counterpart from
    ``.state``; every other top-level field is passed straight through to
    ``ParamsJax`` as a keyword argument.
    """
    from .state import BehavioralParamsJax, OperationalParamsJax, PolicyParamsJax

    top_level = self.model_dump()
    jax_groups = {
        "ops": OperationalParamsJax(**top_level.pop("ops", {})),
        "behavior": BehavioralParamsJax(**top_level.pop("behavior", {})),
        "policy": PolicyParamsJax(**top_level.pop("policy", {})),
    }
    return ParamsJax(**jax_groups, **top_level)

replace(**kwargs)

Pydantic-compatible field replacement.

Source code in src/nhra_gt/domain/params.py
def replace(self, **kwargs: Any) -> Params:
    """Return a copy with the given fields replaced, validating each new value.

    ``model_copy(update=...)`` does not validate its update values. Each
    replacement is therefore applied by attribute assignment on a copy, where
    the model's ``validate_assignment=True`` config enforces the field
    constraints. Pydantic also rejects names that are not declared fields.
    The original object is not modified.

    Raises:
        ValueError: if a name is not a field or a value fails validation
            (pydantic's ``ValidationError`` is a ``ValueError`` subclass).
    """
    updated = self.model_copy()
    for name, value in kwargs.items():
        setattr(updated, name, value)
    return updated

nhra_gt.engine.run_simulation(*, years=10, n_samples=1, params=None, seed=0, start_year=2025, strategies=None)

Run a baseline simulation with optional Monte Carlo sampling.

This is a convenience wrapper around the JAX core (run_simulation_jax) for documentation examples and quick interactive use.

Returns a dict of numpy arrays. For n_samples == 1, arrays are shaped [num_steps]. For n_samples > 1, arrays are shaped [n_samples, num_steps].

Source code in src/nhra_gt/engine.py
@beartype
def run_simulation(
    *,
    years: int = 10,
    n_samples: int = 1,
    params: ParamsJax | None = None,
    seed: int = 0,
    start_year: int = 2025,
    strategies: Any | None = None,
) -> dict[str, np.ndarray]:
    """Run a baseline simulation with optional Monte Carlo sampling.

    This is a convenience wrapper around the JAX core (`run_simulation_jax`) for
    documentation examples and quick interactive use.

    Args:
        years: Simulated horizon in years; ``years * 12`` steps are run.
            Must be positive.
        n_samples: Number of Monte Carlo runs. Must be positive. Runs are
            vectorised with ``jax.vmap`` when greater than one.
        params: Model parameters; defaults to ``ParamsJax()``.
        seed: Seed for the root ``jax.random.PRNGKey``. The root key is split
            into one sub-key per sample.
        start_year: Passed to ``baseline_state`` to build the initial state.
        strategies: Optional strategy array. The accepted forms are:

            * a 1-D array, passed through ``_pad_strategies(..., width=13)``
              and repeated for every step;
            * a 2-D array with a single row, repeated for every step;
            * a 2-D array with one row per step.

            Two-dimensional inputs are also passed through
            ``_pad_strategies``. ``None`` means a ``(num_steps, 13)`` array of
            zeros.

    Returns:
        A dict of numpy arrays keyed by trajectory field. For `n_samples == 1`
        the leading axis is `[num_steps]`; for `n_samples > 1` it is
        `[n_samples, num_steps]`. NOTE(review): per-LHN fields (``lhn_*``) may
        carry an extra trailing axis; confirm against the trajectory type.

    Raises:
        ValueError: if ``years`` or ``n_samples`` is not positive, or if
            ``strategies`` has an unsupported shape.
    """
    # Validate cheap scalar arguments before constructing any model state.
    if years <= 0:
        raise ValueError("years must be positive")
    if n_samples <= 0:
        raise ValueError("n_samples must be positive")

    if params is None:
        params = ParamsJax()

    num_steps = int(years) * 12
    init_state = baseline_state(start_year=start_year, p=params)

    if strategies is None:
        strategies_arr = jnp.zeros((num_steps, 13))
    else:
        strategies_arr = jnp.asarray(strategies)
        if strategies_arr.ndim == 1:
            # One profile shared by every step.
            strategies_arr = jnp.tile(_pad_strategies(strategies_arr, width=13), (num_steps, 1))
        elif strategies_arr.ndim == 2:
            if int(strategies_arr.shape[0]) == 1:
                # A single-row 2-D profile is broadcast across all steps.
                strategies_arr = jnp.tile(strategies_arr, (num_steps, 1))
            elif int(strategies_arr.shape[0]) != num_steps:
                raise ValueError(
                    f"strategies must have shape ({num_steps}, 13) or (13,), got {strategies_arr.shape}"
                )
            strategies_arr = _pad_strategies(strategies_arr, width=13)
        else:
            raise ValueError("strategies must be 1D (13,) or 2D (num_steps, 13)")

    keys = jax.random.split(jax.random.PRNGKey(seed), int(n_samples))

    def _one_run(key):
        _, traj = run_simulation_jax(init_state, params, strategies_arr, key, num_steps)
        return traj

    # vmap adds a leading [n_samples] axis; a single run is left unbatched.
    traj = jax.vmap(_one_run)(keys) if n_samples > 1 else _one_run(keys[0])
    traj_host = jax.device_get(traj)

    output_fields = (
        "year",
        "month",
        "pressure",
        "occupancy",
        "within4",
        "effective_cth_share",
        "efficiency_gap",
        "reported_pressure",
        "reported_occupancy",
        "reported_within4",
        "prob_ed",
        "lhn_pressure",
        "lhn_nwau",
    )
    return {name: np.asarray(getattr(traj_host, name)) for name in output_fields}

Agents

nhra_gt.agents.base.HeuristicAgent

Bases: Agent

Orchestrator that delegates to distinct Commonwealth, State, and LHN agents.


Subgames

nhra_gt.subgames.games.GameParams dataclass

Inputs used to parameterise stage games (dimensionless indices).

Source code in src/nhra_gt/subgames/games.py
@dataclass(frozen=True)
class GameParams:
    """Immutable bundle of dimensionless indices used to build the stage games.

    Each stage-game builder reads only the inputs it needs from this bundle.
    """

    pressure: float
    efficiency_gap: float
    discharge_delay: float
    political_salience: float
    audit_pressure: float
    cost_shifting_intensity: float
    political_capital: float
    cannibalization_beta: float = field(default=0.1)
    # Supplies the payoff coefficients (def_*, barg_*, shift_*, disc_*) read by the game builders.
    behavior: BehavioralParams = field(default_factory=BehavioralParams)

nhra_gt.subgames.games.definition_game(gp)

Definition game: 'R' realism vs 'E' strict efficient-price framing.

Source code in src/nhra_gt/subgames/games.py
def definition_game(gp: GameParams) -> TwoPlayerGame:
    """Build the definition game: realism ``'R'`` vs strict efficient-price framing ``'E'``.

    Both players choose from the same two framings. Both payoff matrices are
    indexed ``[row action, column action]`` over ``("R", "E")``.
    """
    b = gp.behavior
    pressure = gp.pressure
    salience = gp.political_salience

    # Realism has fiscal and political costs but stabilises pressure.
    realism_gain = (
        b.def_realism_benefit_base
        + b.def_realism_eg_weight * gp.efficiency_gap
        + b.def_realism_pr_weight * (pressure - 1.0)
    )
    realism_loss = b.def_realism_cost_base + b.def_realism_ps_weight * salience

    strict_gain = b.def_strict_benefit_base + b.def_strict_ps_weight * salience
    strict_loss = b.def_strict_cost_base + b.def_strict_pr_weight * pressure

    # Row player's payoffs, one list per row action.
    row_if_realism = [
        1.0 + realism_gain - realism_loss,
        1.0 - b.def_row_realism_offset - realism_loss,
    ]
    row_if_strict = [
        1.0 + strict_gain - strict_loss,
        1.0 - b.def_row_strict_offset - strict_loss,
    ]

    # Column player's payoffs, one list per row action.
    col_if_row_realism = [
        1.0 + realism_gain - b.def_col_realism_offset,
        1.0 - b.def_col_realism_cost,
    ]
    col_if_row_strict = [
        1.0 - b.def_col_strict_cost_1,
        1.0 - b.def_col_strict_cost_2,
    ]

    return TwoPlayerGame(
        u_row=np.array([row_if_realism, row_if_strict], dtype=float),
        u_col=np.array([col_if_row_realism, col_if_row_strict], dtype=float),
        row_actions=("R", "E"),
        col_actions=("R", "E"),
    )

nhra_gt.subgames.games.bargaining_game(gp)

Bargaining game: 'A' agree to converge vs 'D' defer/escalate.

Source code in src/nhra_gt/subgames/games.py
def bargaining_game(gp: GameParams) -> TwoPlayerGame:
    """Build the bargaining game: agree to converge ``'A'`` vs defer/escalate ``'D'``.

    Both payoff matrices are indexed ``[row action, column action]`` over
    ``("A", "D")``.
    """
    b = gp.behavior
    pressure = gp.pressure
    salience = gp.political_salience

    # Political capital boosts the effectiveness of agreement
    agreement_gain = (
        b.barg_converge_gain_base
        + b.barg_converge_pr_weight * (pressure - 1.0)
        + b.barg_converge_pc_weight * gp.political_capital
    )
    standoff_cost = b.barg_conflict_cost_base + b.barg_conflict_pr_weight * pressure
    narrative_gain = b.barg_narrative_gain_base + b.barg_narrative_ps_weight * salience

    row_payoffs = np.array(
        [
            [
                1.0 + agreement_gain - b.barg_row_converge_ps_penalty * salience,
                1.0 - b.barg_row_converge_base_penalty - b.barg_row_converge_pr_penalty * pressure,
            ],
            [
                1.0 + narrative_gain - b.barg_row_narrative_pr_penalty * pressure,
                1.0 - standoff_cost,
            ],
        ],
        dtype=float,
    )
    col_payoffs = np.array(
        [
            [
                1.0 + agreement_gain - b.barg_col_converge_ps_penalty * salience,
                1.0 - b.barg_col_converge_base_penalty - b.barg_col_converge_pr_penalty * pressure,
            ],
            [
                1.0 - b.barg_col_narrative_penalty,
                1.0 - standoff_cost,
            ],
        ],
        dtype=float,
    )

    actions = ("A", "D")
    return TwoPlayerGame(
        u_row=row_payoffs,
        u_col=col_payoffs,
        row_actions=actions,
        col_actions=actions,
    )

nhra_gt.subgames.games.cost_shifting_game(gp)

Cost shifting game: invest upstream 'I' vs shift downstream 'S'.

Source code in src/nhra_gt/subgames/games.py
def cost_shifting_game(gp: GameParams) -> TwoPlayerGame:
    """Build the cost-shifting game: invest upstream ``'I'`` vs shift downstream ``'S'``.

    The game is symmetric: the column player's payoff matrix is the transpose
    of the row player's. Matrices are indexed ``[row action, column action]``
    over ``("I", "S")``.
    """
    b = gp.behavior
    pressure = gp.pressure
    gap = gp.efficiency_gap

    cooperation_gain = b.shift_coop_gain_base + b.shift_coop_eg_weight * (1.0 - gap)
    shifting_gain = (
        b.shift_gain_base
        + b.shift_eg_weight * gap
        + b.shift_csi_weight * gp.cost_shifting_intensity
    )
    pressure_cost = b.shift_pr_cost_weight * pressure

    invest_row = [
        1.0 + cooperation_gain - pressure_cost,  # row invests, column invests
        1.0 - b.shift_row_coop_penalty - pressure_cost,  # row invests, column shifts
    ]
    shift_row = [
        1.0 + shifting_gain - b.shift_row_shift_pr_penalty * pressure,  # row shifts, column invests
        1.0 - b.shift_row_shift_base_penalty - b.shift_row_shift_pr_heavy_penalty * pressure,  # both shift
    ]
    u_row = np.array([invest_row, shift_row], dtype=float)

    # .copy() yields an independent C-contiguous array rather than a view of u_row.
    u_col = u_row.T.copy()

    return TwoPlayerGame(u_row=u_row, u_col=u_col, row_actions=("I", "S"), col_actions=("I", "S"))

nhra_gt.subgames.games.discharge_coordination_game(gp)

Discharge coordination: coordinate 'C' vs fragment 'F'.

Source code in src/nhra_gt/subgames/games.py
def discharge_coordination_game(gp: GameParams) -> TwoPlayerGame:
    """Build the discharge-coordination game: coordinate ``'C'`` vs fragment ``'F'``.

    Only the part of ``discharge_delay`` above 1.0 enters the payoffs. Both
    matrices are indexed ``[row action, column action]`` over ``("C", "F")``.
    """
    b = gp.behavior
    pressure = gp.pressure
    excess_delay = max(0.0, gp.discharge_delay - 1.0)

    coordination_benefit = b.disc_benefit_base + b.disc_benefit_slope * excess_delay
    # The disc_cost_slope multiplier falls from 1 to 0 as excess delay goes from 0 to 1.
    coordination_cost = b.disc_cost_base + b.disc_cost_slope * (1.0 - min(1.0, excess_delay))
    pressure_penalty = b.disc_pr_penalty_weight * pressure
    both_coordinate = 1.0 + coordination_benefit - coordination_cost - pressure_penalty

    def payoff_matrix(
        coop_penalty: float,
        frag_base_penalty: float,
        frag_heavy_penalty: float,
        frag_pr_penalty: float,
    ) -> np.ndarray:
        """Lay out one player's 2x2 payoffs from that player's penalty coefficients."""
        return np.array(
            [
                [both_coordinate, 1.0 - coop_penalty - pressure_penalty],
                [
                    1.0 - frag_base_penalty - pressure_penalty,
                    1.0 - frag_heavy_penalty - frag_pr_penalty * pressure,
                ],
            ],
            dtype=float,
        )

    u_row = payoff_matrix(
        b.disc_row_coop_penalty,
        b.disc_row_frag_base_penalty,
        b.disc_row_frag_heavy_penalty,
        b.disc_row_frag_pr_penalty,
    )
    # NOTE(review): the column player's penalties occupy the same cells as the
    # row player's (e.g. disc_col_coop_penalty sits at row C / column F) rather
    # than the transposed cells used in cost_shifting_game; confirm intended.
    u_col = payoff_matrix(
        b.disc_col_coop_penalty,
        b.disc_col_frag_base_penalty,
        b.disc_col_frag_heavy_penalty,
        b.disc_col_frag_pr_penalty,
    )

    return TwoPlayerGame(u_row=u_row, u_col=u_col, row_actions=("C", "F"), col_actions=("C", "F"))

Nash Solver

nhra_gt.subgames.nash.TwoPlayerGame dataclass

Represents a 2-player normal-form game.

Includes payoff matrices and action labels for both players.

Source code in src/nhra_gt/subgames/nash.py
@dataclass(frozen=True)
class TwoPlayerGame:
    """
    Represents a 2-player normal-form game.

    Includes payoff matrices and action labels for both players. As used by
    the stage-game builders, ``u_row[i, j]`` and ``u_col[i, j]`` are the row
    and column players' payoffs when the row player picks ``row_actions[i]``
    and the column player picks ``col_actions[j]``.
    """

    u_row: np.ndarray[Any, Any]  # shape (n,m)
    u_col: np.ndarray[Any, Any]  # shape (n,m)
    row_actions: tuple[str, ...]
    col_actions: tuple[str, ...]

    def __eq__(self, other: object) -> bool:
        """Compare two games field by field, using ``np.array_equal`` for payoffs.

        The dataclass-generated ``__eq__`` compares the ndarray fields with
        ``==``. Its element-wise result has no single truth value, so
        comparing two distinct games raised ``ValueError``.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            self.row_actions == other.row_actions
            and self.col_actions == other.col_actions
            and np.array_equal(self.u_row, other.u_row)
            and np.array_equal(self.u_col, other.u_col)
        )

nhra_gt.subgames.nash.solve_all_equilibria(game)

Backwards-compatible alias for all_nash.

Source code in src/nhra_gt/subgames/nash.py
def solve_all_equilibria(game: TwoPlayerGame) -> list[NashEquilibrium]:
    """Return the Nash equilibria of ``game`` as computed by `all_nash`.

    Kept only for backwards compatibility; it delegates unchanged to `all_nash`.
    """
    equilibria = all_nash(game)
    return equilibria

Interfaces & Protocols

nhra_gt.interfaces.protocols.Strategy

Bases: Protocol

Protocol for a game-theory strategy (e.g., mixed or pure).

Source code in src/nhra_gt/interfaces/protocols.py
@runtime_checkable
class Strategy(Protocol):
    """Structural interface for a game-theory strategy, pure or mixed.

    Any object that defines ``sample()`` and ``probability(action)`` satisfies
    this protocol. Because it is ``runtime_checkable``, ``isinstance`` checks
    test for the presence of those methods.
    """

    def probability(self, action: Any) -> float:  # pragma: no cover
        """Return the probability this strategy assigns to ``action``."""

    def sample(self) -> Any:  # pragma: no cover
        """Draw a single action from this strategy."""

Functions

probability(action)

Get the probability of a specific action.

Source code in src/nhra_gt/interfaces/protocols.py
def probability(self, action: Any) -> float:  # pragma: no cover
    """Return the probability the strategy assigns to ``action`` (protocol stub; body is empty)."""

sample()

Sample an action from the strategy.

Source code in src/nhra_gt/interfaces/protocols.py
def sample(self) -> Any:  # pragma: no cover
    """Draw one action from the strategy (protocol stub; body is empty)."""

nhra_gt.interfaces.protocols.NormalFormGame

Bases: Protocol

Protocol for a normal-form game container.

Source code in src/nhra_gt/interfaces/protocols.py
@runtime_checkable
class NormalFormGame(Protocol):
    """Structural interface for a normal-form game container.

    Implementations report how many players there are and map a joint
    action profile to a vector of payoffs.
    """

    def payoffs(self, actions: IntArray) -> FloatArray:  # pragma: no cover
        """
        Return every player's payoff for one joint action profile.

        Args:
            actions: The action chosen by each player, one entry per player.

        Returns:
            The payoff received by each player, one entry per player.
        """

    @property
    def num_players(self) -> int:  # pragma: no cover
        """How many players take part in the game."""

Attributes

num_players property

Number of players in the game.

Functions

payoffs(actions)

Calculate payoffs for all players given an action profile.

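Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `actions` | `IntArray` | An array of actions, one for each player. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `FloatArray` | An array of payoffs, one for each player. |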

Source code in src/nhra_gt/interfaces/protocols.py
def payoffs(self, actions: IntArray) -> FloatArray:  # pragma: no cover
    """
    Return every player's payoff for one joint action profile.

    Args:
        actions: The action chosen by each player, one entry per player.

    Returns:
        The payoff received by each player, one entry per player.
    """

nhra_gt.interfaces.protocols.ExtensiveFormGame

Bases: Protocol

Protocol for an extensive-form (tree) game.

Source code in src/nhra_gt/interfaces/protocols.py
@runtime_checkable
class ExtensiveFormGame(Protocol):
    """Structural interface for an extensive-form (game-tree) game.

    The protocol treats ``state`` as opaque; implementations choose its type.
    """

    def get_legal_actions(self, state: Any) -> list[Any]:  # pragma: no cover
        """Return the actions available at ``state``."""

    def get_payoffs(self, state: Any) -> FloatArray:  # pragma: no cover
        """Return the payoffs at ``state`` (presumably meaningful at terminal states; confirm)."""

    def is_terminal(self, state: Any) -> bool:  # pragma: no cover
        """Return whether ``state`` ends the game."""

See Also