Skip to content

nhra_gt.engine_jax

Compatibility layer for historical engine_jax imports.

The core JAX implementation lives in nhra_gt.engine. Older code and tests expect nhra_gt.engine_jax to exist and to expose *_jax entry points.

Classes

ParamsJax

Bases: ParamsGenerated

JAX-compatible simulation parameters.

This class extends the auto-generated scalar parameters with runtime objects and JAX-native rules.

Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class ParamsJax(ParamsGenerated):
    """JAX-compatible simulation parameters.

    This class extends the auto-generated scalar parameters with runtime
    objects and JAX-native rules.

    Extra fields on top of ``ParamsGenerated``:
        ops / behavior / policy: Grouped coefficient PyTrees.
        cap_rule / audit_rule / eligibility_rule / reconciliation_rule:
            Optional pluggable rule objects; ``None`` when unset.
        spine: Optional economic-spine arrays.
        economic_spine: Static (non-PyTree) string alias/placeholder.
    """

    # Grouped coefficients
    ops: OperationalParamsJax = struct.field(default_factory=OperationalParamsJax)
    behavior: BehavioralParamsJax = struct.field(default_factory=BehavioralParamsJax)
    policy: PolicyParamsJax = struct.field(default_factory=PolicyParamsJax)

    # Modular Rules (JAX-compatible PyTrees)
    cap_rule: Any = struct.field(default_factory=lambda: None)
    audit_rule: Any = struct.field(default_factory=lambda: None)
    eligibility_rule: Any = struct.field(default_factory=lambda: None)
    reconciliation_rule: Any = struct.field(default_factory=lambda: None)

    # Economic Spine (optional JAX arrays)
    spine: EconomicSpineJax | None = None
    # pytree_node=False: treated as static metadata, not traced by JAX.
    economic_spine: str | None = struct.field(default=None, pytree_node=False)  # alias/placeholder

    def replace(self, **kwargs: Any) -> ParamsJax:
        """Return a copy of this instance with the given fields replaced.

        Uses the standard-library ``dataclasses.replace``, which works for any
        (flax) struct dataclass, instead of a module-level ``struct`` helper.

        Args:
            **kwargs: Field names mapped to their new values.

        Returns:
            A new ``ParamsJax``; ``self`` is not modified.
        """
        import dataclasses

        return dataclasses.replace(self, **kwargs)

    @classmethod
    def from_yaml(cls, path: Path | str) -> ParamsJax:
        """Load parameters from a YAML file, validating them via Pydantic.

        Top-level mappings other than the ``ops``/``behavior``/``policy``
        groups are treated as legacy groupings. Their keys are flattened into
        the top level, and later keys overwrite earlier ones on collision.

        Args:
            path: Path to the YAML file.

        Returns:
            The ``ParamsJax`` produced by ``Params.to_params_jax()``.

        Raises:
            TypeError: If the YAML document's top level is not a mapping.
        """
        import yaml

        from .params import Params

        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # An empty YAML file loads as None; treat it as "no overrides".
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a YAML mapping at the top level of {path!s}, "
                f"got {type(data).__name__}"
            )

        # 1. Flatten the legacy groups if they exist
        flat_data: dict[str, Any] = {}
        for k, v in data.items():
            if isinstance(v, dict) and k not in {"ops", "behavior", "policy"}:
                flat_data.update(v)
            else:
                flat_data[k] = v

        # 2. Use Pydantic to validate and handle nesting
        p_pydantic = Params(**flat_data)
        return p_pydantic.to_params_jax()

Functions

replace(**kwargs)

Flax-compatible field replacement.

Source code in src/nhra_gt/domain/state.py
def replace(self, **kwargs: Any) -> ParamsJax:
    """Return a copy of this instance with the given fields replaced.

    Uses the standard-library ``dataclasses.replace``, which works for any
    (flax) struct dataclass, instead of a module-level ``struct`` helper.

    Args:
        **kwargs: Field names mapped to their new values.

    Returns:
        A new instance; ``self`` is not modified.
    """
    import dataclasses

    return dataclasses.replace(self, **kwargs)
from_yaml(path) classmethod

Loads parameters from a YAML file via Pydantic for validation.

Source code in src/nhra_gt/domain/state.py
@classmethod
def from_yaml(cls, path: Path | str) -> ParamsJax:
    """Load parameters from a YAML file, validating them via Pydantic.

    Top-level mappings other than the ``ops``/``behavior``/``policy`` groups
    are treated as legacy groupings. Their keys are flattened into the top
    level, and later keys overwrite earlier ones on collision.

    Args:
        path: Path to the YAML file.

    Returns:
        The ``ParamsJax`` produced by ``Params.to_params_jax()``.

    Raises:
        TypeError: If the YAML document's top level is not a mapping.
    """
    import yaml

    from .params import Params

    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # An empty YAML file loads as None; treat it as "no overrides".
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise TypeError(
            f"Expected a YAML mapping at the top level of {path!s}, "
            f"got {type(data).__name__}"
        )

    # 1. Flatten the legacy groups if they exist
    flat_data: dict[str, Any] = {}
    for k, v in data.items():
        if isinstance(v, dict) and k not in {"ops", "behavior", "policy"}:
            flat_data.update(v)
        else:
            flat_data[k] = v

    # 2. Use Pydantic to validate and handle nesting
    p_pydantic = Params(**flat_data)
    return p_pydantic.to_params_jax()

StateJax

JAX-compatible simulation state (Global Orchestrator).

The root PyTree for the entire simulation state. It contains both global aggregates and hierarchical jurisdictional/LHN states.

Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class StateJax:
    """
    JAX-compatible simulation state (Global Orchestrator).

    The root PyTree for the entire simulation state. It contains both global
    aggregates and hierarchical jurisdictional/LHN states.
    """

    # Core operational aggregates (no defaults: must be supplied)
    year: Any
    month: Any
    pressure: Any
    occupancy: Any
    offload_min: Any
    within4: Any

    # Fiscal / bargaining state
    effective_cth_share: Any = 0.38
    efficiency_gap: Any = 0.10
    discharge_delay: Any = 1.0
    political_capital: Any = 1.0
    equity_index: Any = 1.0
    reconciliation_balance: Any = 0.0
    bailout_expectation: Any = 0.0
    total_block_revenue: Any = 0.0

    # Orchestrator state
    system_mode: Any = 0  # Mapped from SystemModeJax
    agreement_clock: Any = 5
    workforce_pool: Any = 1.0
    target_capacity: Any = 1.0
    current_capacity: Any = 1.0
    coding_intensity: Any = 1.0
    reputation_score: Any = 1.0
    jurisdiction_id: Any = 0

    # Per-LHN (flat) state used by tests + simple vmaps
    lhn_pressure: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(5))
    lhn_nwau: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(5))

    # Hierarchical Entities (optional richer representation)
    jurisdictions: JurisdictionState | None = None

    # Auditor Agent state
    auditor_suspicion: float = 0.0
    audit_pressure_active: float = 0.0
    adjustment_costs: float = 0.0

    # Lags & Measurement
    # Buffers store up to 12 months of history
    lag_buffer_pressure: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_occupancy: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_within4: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_nwau: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_efficiency_gap: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_coding: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))

    # Reported values (lagged) available to agents
    reported_pressure: float = 1.0
    reported_occupancy: float = 0.88
    reported_within4: float = 0.53
    reported_nwau: float = 0.0
    reported_efficiency_gap: float = 0.10
    reported_coding_intensity: float = 1.0

    # Stability Telemetry
    solver_n_equilibria: int = 1
    solver_residual: float = 0.0

    # Patient Choice
    prob_ed: float = 0.5

    # ... (other orchestrator state)
    metrics: MetricsJax = struct.field(default_factory=MetricsJax)

    def replace(self, **kwargs: Any) -> StateJax:
        """Return a copy of this state with the given fields replaced.

        Uses the standard-library ``dataclasses.replace``, which works for any
        (flax) struct dataclass, instead of a module-level ``struct`` helper.

        Args:
            **kwargs: Field names mapped to their new values.

        Returns:
            A new ``StateJax``; ``self`` is not modified.
        """
        import dataclasses

        return dataclasses.replace(self, **kwargs)

Functions

replace(**kwargs)

Flax-compatible field replacement.

Source code in src/nhra_gt/domain/state.py
def replace(self, **kwargs: Any) -> StateJax:
    """Return a copy of this state with the given fields replaced.

    Uses the standard-library ``dataclasses.replace``, which works for any
    (flax) struct dataclass, instead of a module-level ``struct`` helper.

    Args:
        **kwargs: Field names mapped to their new values.

    Returns:
        A new instance; ``self`` is not modified.
    """
    import dataclasses

    return dataclasses.replace(self, **kwargs)

Functions

lhn_step_jax(lhn, p, strategies, demand, month_growth_factor, offload_noise, discharge_delay_target, workforce_availability)

Operational step for a single LHN agent.

Handles localized demand, capacity constraints, discharge delays, and the effect of workforce availability for a Local Health Network (LHN).

Returns:

Type Description
LhnState

Updated LhnState.

Source code in src/nhra_gt/engine.py
@beartype
def lhn_step_jax(
    lhn: LhnState,
    p: ParamsJax,
    strategies: Any,
    demand: Any,
    month_growth_factor: float,
    offload_noise: Any,
    discharge_delay_target: Any,
    workforce_availability: Any,
) -> LhnState:
    """
    Operational step for a single LHN agent.

    Handles localized demand, capacity constraints, discharge delays, and the
    effect of workforce availability for a Local Health Network (LHN).

    Workforce *drain* is not applied per LHN. The shared workforce pool is
    updated by the orchestrator (``step_jax``). Here only the availability
    shortfall matters, via ``wf_impact``.

    Args:
        lhn: Current state of this LHN.
        p: Simulation parameters; ``p.ops`` supplies the operational coefficients.
        strategies: Strategy vector, padded via ``_pad_strategies``. Indices
            4, 5 and 6 select discharge, aged-care and NDIS coordination.
        demand: Demand signal for this LHN; 1.0 is the neutral level for the
            occupancy slope.
        month_growth_factor: Exponent that time-scales the coordination
            effects (presumably 1/12 for a monthly step; confirm with caller).
        offload_noise: Additive noise on offload minutes.
        discharge_delay_target: Level the discharge delay relaxes towards.
        workforce_availability: Availability level; any shortfall below 1.0
            scales discharge delay by ``exp(wf_impact_weight * shortfall)``.

    Returns:
        Updated LhnState.
    """
    strategies = _pad_strategies(strategies)
    demand = jnp.asarray(demand)
    offload_noise = jnp.asarray(offload_noise)
    discharge_delay_target = jnp.asarray(discharge_delay_target)
    workforce_availability = jnp.asarray(workforce_availability)

    # Workforce shortfall (availability below 1.0) inflates discharge delay.
    wf_impact = jnp.exp(p.ops.wf_impact_weight * jnp.maximum(0.0, 1.0 - workforce_availability))

    # Coordination vs. fragmentation effects on discharge. Aged-care and NDIS
    # fragmentation penalties are scaled by the system fragmentation index.
    aged_val, ndis_val, disc_val = strategies[5], strategies[6], strategies[4]
    aged_effect = aged_val * p.ops.aged_coord_effect + (1.0 - aged_val) * (
        p.ops.aged_frag_effect * p.fragmentation_index
    )
    ndis_effect = ndis_val * p.ops.ndis_coord_effect + (1.0 - ndis_val) * (
        p.ops.ndis_frag_effect * p.fragmentation_index
    )
    disc_effect = disc_val * p.ops.disc_coord_effect + (1.0 - disc_val) * p.ops.disc_frag_effect

    discharge = (
        lhn.discharge_delay
        * ((aged_effect * ndis_effect * disc_effect) ** month_growth_factor)
        * wf_impact
    )
    # Partial relaxation towards the target, then clip to the allowed band.
    discharge = jnp.clip(
        discharge + p.ops.discharge_update_speed * (discharge_delay_target - discharge),
        p.ops.discharge_clip_min,
        p.ops.discharge_clip_max,
    )

    # Capacity moves a fraction of the way towards target, at different rates
    # for expansion and contraction.
    is_expanding = lhn.target_capacity > lhn.current_capacity
    active_lag = jnp.where(is_expanding, p.expansion_lag, p.contraction_lag)
    capacity = lhn.current_capacity + active_lag * (lhn.target_capacity - lhn.current_capacity)

    # Queue wait; the service rate is the inverse of discharge delay (guarded against /0).
    wait_min = mm_s_queue_wait_jax(
        demand, 1.0 / jnp.maximum(1e-9, discharge), jnp.array(capacity * p.ops.capacity_scalar), p
    )
    occ = jnp.clip(
        lhn.occupancy
        + p.ops.occ_demand_slope * (demand - 1.0)
        + p.ops.occ_discharge_slope * (discharge - 1.0),
        p.ops.occ_clip_min,
        p.ops.occ_clip_max,
    )
    off = jnp.clip(
        lhn.offload_min + p.ops.offload_occ_slope * (occ - p.ops.offload_occ_base) + offload_noise,
        p.ops.offload_clip_min,
        p.ops.offload_clip_max,
    )
    # Pressure index combines queue wait (converted to hours) and normalised occupancy.
    pidx = (
        p.ops.pressure_base
        + p.ops.pressure_wait_weight * (wait_min / 60.0)
        + p.ops.pressure_occ_weight * (occ - p.ops.pressure_occ_base) / p.ops.pressure_occ_scale
    )

    return lhn.replace(
        pressure=pidx,
        occupancy=occ,
        offload_min=off,
        within4=within4_from_pressure_jax(pidx, p),
        discharge_delay=discharge,
        current_capacity=capacity,
        nwau_actual=occ * 100.0,
        # Quadratic cost of moving capacity this step.
        adjustment_costs=p.adjustment_cost_beta * jnp.square(capacity - lhn.current_capacity),
    )

run_simulation_jax(init_state, params, strategies, prng_key, num_steps)

Runs a multi-step JAX simulation using lax.scan.

Parameters:

Name Type Description Default
init_state StateJax

The starting state of the simulation.

required
params ParamsJax

Simulation parameters.

required
strategies Any

Either a single strategy vector (applied to all steps) or a sequence of strategy vectors, one per step (leading axis of length num_steps).

required
prng_key Any

JAX random number generator key.

required
num_steps int

Number of months to simulate.

required

Returns:

Type Description
tuple[StateJax, StateJax]

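A tuple containing (final_state, trajectory_of_states), where the trajectory is a StateJax whose leaves are stacked along a leading time axis of length num_steps.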

Source code in src/nhra_gt/engine.py
@beartype
def run_simulation_jax(
    init_state: StateJax,
    params: ParamsJax,
    strategies: Any,
    prng_key: Any,
    num_steps: int,
) -> tuple[StateJax, StateJax]:
    """Runs a multi-step JAX simulation using lax.scan.

    Args:
        init_state: The starting state of the simulation.
        params: Simulation parameters.
        strategies: Either a single (1-D) strategy vector, applied to all
            steps, or a sequence of strategy vectors with one row per step
            (leading axis of length ``num_steps``).
        prng_key: JAX random number generator key.
        num_steps: Number of months to simulate.

    Returns:
        A tuple containing (final_state, trajectory_of_states). The
        trajectory is a StateJax whose leaves carry a leading time axis of
        length ``num_steps`` (``lax.scan`` stacks the per-step outputs).
    """
    strategies = jnp.asarray(_pad_strategies(strategies))
    # A single strategy vector must be repeated for every step. Otherwise
    # lax.scan would iterate over its elements, handing scalars to step_jax,
    # and its length would not match the num_steps keys.
    if strategies.ndim == 1:
        strategies = jnp.broadcast_to(strategies, (num_steps, strategies.shape[0]))

    def body_func(carry, input_tuple):
        strat, key = input_tuple
        next_s = step_jax(carry, params, strat, key)
        return next_s, next_s

    # One independent key per step.
    keys = jax.random.split(prng_key, num_steps)
    return lax.scan(body_func, init_state, (strategies, keys))

step_jax(s, p, strategies, prng_key)

Performs a single JAX-accelerated simulation step (one month).

This function handles the monthly transition of system state, including demand realization, jurisdictional allocation, funding calculation, and performance metric updates.

Parameters:

Name Type Description Default
s StateJax

Current simulation state.

required
p ParamsJax

Global simulation parameters.

required
strategies Any

Strategy vector for the current step.

required
prng_key Any

JAX random number generator key.

required

Returns:

Type Description
StateJax

The updated simulation state for the next step.

Source code in src/nhra_gt/engine.py
@beartype
def step_jax(s: StateJax, p: ParamsJax, strategies: Any, prng_key: Any) -> StateJax:
    """Performs a single JAX-accelerated simulation step (one month).

    This function handles the monthly transition of system state, including
    demand realization, jurisdictional allocation, funding calculation,
    and performance metric updates.

    When ``s.jurisdictions`` is ``None``, the step only advances the calendar:
    the month increments, and December rolls over to January of the next year.
    Every other field, including ``agreement_clock``, is returned unchanged.
    This is a Python-level ``is None`` check, so it is resolved at trace time
    from the state's PyTree structure.

    Otherwise the step:
      1. draws macro demand via ``demand_step_jax``;
      2. copies the global scalar state into every LHN and runs
         ``jurisdiction_step_jax`` over all jurisdictions with ``jax.vmap``;
      3. averages LHN/jurisdiction results back into the global fields;
      4. updates auditor suspicion and the shared workforce pool;
      5. advances the calendar and agreement clock, renegotiating the
         Commonwealth share in December when the clock reads 0;
      6. applies the optional cap rule and rolls the lag/reporting buffers.

    Args:
        s: Current simulation state.
        p: Global simulation parameters.
        strategies: Strategy vector for the current step (padded via
            ``_pad_strategies``).
        prng_key: JAX random number generator key.

    Returns:
        The updated simulation state for the next step.
    """
    strategies = _pad_strategies(strategies)
    # Month growth factor: one step is 1/12 of a year.
    mgf = 1.0 / 12.0
    # Separate sub-keys for macro demand noise and per-jurisdiction noise.
    k_dem, k_jur = jax.random.split(prng_key)
    # Workforce pool at the *start* of the step; it is passed to every
    # jurisdiction, and the pool itself is updated in section 4 below.
    wf_pool = jnp.asarray(s.workforce_pool)

    # Flat state without hierarchical entities: only roll the calendar.
    if s.jurisdictions is None:
        next_m = jnp.where(s.month == 12, 1, s.month + 1)
        next_y = jnp.where(s.month == 12, s.year + 1, s.year)
        return s.replace(year=next_y, month=next_m)

    # 1. Macro demand
    demand_macro, prob_ed = demand_step_jax(
        s, p, strategies, jax.random.normal(k_dem) * jnp.asarray(p.noise_sd)
    )

    # 2. Vectorized Jurisdiction steps
    n_jur = s.jurisdictions.id.shape[0]
    # One PRNG key per jurisdiction for the vmapped step.
    keys = jax.random.split(k_jur, n_jur)

    # Sync global scalar controls into hierarchical state so tests that mutate
    # top-level fields (e.g. capacity, workforce) affect LHN dynamics.
    # NOTE(review): every LHN is reset to the same global value each step, so
    # any per-LHN divergence in these fields from the previous step is not
    # carried forward.
    lhn_states_in = s.jurisdictions.lhn_states.replace(
        pressure=jnp.full_like(s.jurisdictions.lhn_states.pressure, jnp.asarray(s.pressure)),
        occupancy=jnp.full_like(s.jurisdictions.lhn_states.occupancy, jnp.asarray(s.occupancy)),
        within4=jnp.full_like(s.jurisdictions.lhn_states.within4, jnp.asarray(s.within4)),
        offload_min=jnp.full_like(
            s.jurisdictions.lhn_states.offload_min, jnp.asarray(s.offload_min)
        ),
        discharge_delay=jnp.full_like(
            s.jurisdictions.lhn_states.discharge_delay, jnp.asarray(s.discharge_delay)
        ),
        target_capacity=jnp.full_like(
            s.jurisdictions.lhn_states.target_capacity, jnp.asarray(s.target_capacity)
        ),
        current_capacity=jnp.full_like(
            s.jurisdictions.lhn_states.current_capacity, jnp.asarray(s.current_capacity)
        ),
    )
    jurisdictions_in = s.jurisdictions.replace(lhn_states=lhn_states_in)

    # vmap over the leading (jurisdiction) axis of both the state and the keys.
    new_jurisdictions = jax.vmap(
        lambda j, k: jurisdiction_step_jax(j, p, strategies, demand_macro, mgf, k, wf_pool)
    )(jurisdictions_in, keys)

    # Block revenue is overwritten (not accumulated) in every jurisdiction
    # from strategies[10] scaled by venue_shift_revenue_scale.
    venue_shift = strategies[10]
    new_jurisdictions = new_jurisdictions.replace(
        total_block_revenue=jnp.full_like(
            new_jurisdictions.total_block_revenue, venue_shift * p.ops.venue_shift_revenue_scale
        )
    )

    # 3. Global Aggregation
    # Unweighted means over all jurisdictions x LHNs.
    avg_pidx = jnp.mean(new_jurisdictions.lhn_states.pressure)
    avg_occ = jnp.mean(new_jurisdictions.lhn_states.occupancy)
    avg_w4 = jnp.mean(new_jurisdictions.lhn_states.within4)
    avg_target_capacity = jnp.mean(new_jurisdictions.lhn_states.target_capacity)
    avg_current_capacity = jnp.mean(new_jurisdictions.lhn_states.current_capacity)
    adjustment_costs = jnp.mean(new_jurisdictions.lhn_states.adjustment_costs)
    # Capacity adjustment costs are charged against the reconciliation balance.
    next_reconciliation_balance = s.reconciliation_balance - adjustment_costs

    # Auditor: suspicion rises with gaming and decays otherwise.
    # The rise is proportional to coding intensity above 1.0, weighted by the
    # coding strategy (strategies[7]); it applies only when that strategy
    # exceeds decision_threshold.
    coding_strategy = strategies[7]
    coding_signal = jnp.maximum(0.0, jnp.asarray(s.coding_intensity) - 1.0) * coding_strategy
    next_suspicion = jnp.where(
        coding_strategy > p.ops.decision_threshold,
        jnp.clip(s.auditor_suspicion + p.ops.auditor_suspicion_increment * coding_signal, 0.0, 1.0),
        jnp.clip(s.auditor_suspicion * p.ops.auditor_suspicion_decay, 0.0, 1.0),
    )
    next_audit_pressure_active = jnp.clip(
        p.ops.auditor_pressure_base + next_suspicion * p.audit_pressure, 0.0, 2.0
    )

    # 4. Workforce Update
    # Drain scales with total LHN occupancy and workforce intensity
    # (strategies[8]); recovery is a fixed per-step inflow; result is clipped.
    wf_intensity = strategies[8]
    wf_drain = (
        jnp.sum(
            new_jurisdictions.lhn_states.occupancy
            * (p.ops.wf_drain_base + p.ops.wf_drain_intensity * wf_intensity)
        )
        * mgf
    )
    new_wf_pool = jnp.clip(
        s.workforce_pool - wf_drain + p.ops.wf_recovery_rate * mgf,
        p.ops.wf_pool_min,
        p.ops.wf_pool_max,
    )

    # 5. Roll time and buffers
    next_m = jnp.where(s.month == 12, 1, s.month + 1)
    next_y = jnp.where(s.month == 12, s.year + 1, s.year)
    # The agreement clock ticks once per year (in December) and resets to 5
    # after reaching 0.
    next_agreement_clock = jnp.where(
        s.month == 12,
        jnp.where(s.agreement_clock <= 0, 5, s.agreement_clock - 1),
        s.agreement_clock,
    )

    def _renegotiate(jurs: JurisdictionState) -> JurisdictionState:
        """Solve the renegotiation game and reset the Commonwealth share.

        The new share is the nominal target plus an increase: the full base
        increase if the Commonwealth concedes AND the state holds up, half of
        it if only one of them does, and none otherwise. The base increase is
        larger when global occupancy exceeds reneg_occ_threshold. The result
        is clipped and broadcast to all jurisdictions.
        """
        # Local imports (presumably to avoid an import cycle; confirm).
        from nhra_gt.solvers_jax import discrete_nash_jax, stackelberg_jax
        from nhra_gt.subgames.games_jax import GameParamsJax, renegotiation_game_jax

        # Aggregate params for game
        gp = GameParamsJax(
            pressure=avg_pidx,
            efficiency_gap=jnp.mean(jurs.efficiency_gap),
            discharge_delay=jnp.mean(jurs.lhn_states.discharge_delay),
            political_salience=p.political_salience,
            audit_pressure=p.audit_pressure,
            cost_shifting_intensity=p.cost_shifting_intensity,
            political_capital=jnp.mean(jurs.political_capital),
            behavior=p.behavior,
        )

        u_row, u_col = renegotiation_game_jax(gp)

        # Use sequential solver if configured
        def solve_nash() -> tuple[Float[Array, "N_ACTIONS"], Float[Array, "N_ACTIONS"]]:
            return discrete_nash_jax(u_row, u_col)

        def solve_stackelberg() -> tuple[Float[Array, "N_ACTIONS"], Float[Array, "N_ACTIONS"]]:
            # Assume Commonwealth (Row) is Leader
            return stackelberg_jax(u_row, u_col)

        p_row, q_col = lax.cond(p.use_sequential_bargaining, solve_stackelberg, solve_nash)

        # Action 0 of the row player = concede; action 1 of the column player = hold-up.
        cth_concede = p_row[0] > p.ops.decision_threshold
        state_hold_up = q_col[1] > p.ops.decision_threshold

        base_increase = jnp.where(
            jnp.asarray(s.occupancy) > p.ops.reneg_occ_threshold,
            p.ops.reneg_share_inc_high,
            p.ops.reneg_share_inc_low,
        )
        increase = jnp.where(
            cth_concede & state_hold_up,
            base_increase,
            jnp.where(cth_concede | state_hold_up, 0.5 * base_increase, 0.0),
        )

        next_share = jnp.clip(
            p.nominal_cth_share_target + increase,
            p.ops.reneg_share_clip_min,
            p.ops.reneg_share_clip_max,
        )
        next_share_batched = jnp.full_like(jurs.effective_cth_share, next_share)
        return jurs.replace(effective_cth_share=next_share_batched)

    # Uses the pre-update clock, so renegotiation happens in the same
    # December in which the clock is reset to 5.
    do_renegotiate = (s.month == 12) & (s.agreement_clock == 0)
    new_jurisdictions = lax.cond(do_renegotiate, _renegotiate, lambda j: j, new_jurisdictions)

    # Apply cap rule (hard vs soft) to effective Commonwealth share.
    # NOTE(review): cap_factor multiplies the share left by this step's
    # jurisdiction update. If jurisdiction_step_jax carries effective_cth_share
    # forward unchanged, a factor below 1 compounds month over month; confirm
    # this is intended.
    nwau_growth = jnp.maximum(0.0, avg_occ - jnp.asarray(p.occupancy_base))
    cap_rule = getattr(p, "cap_rule", None)
    cap_factor = cap_rule.apply(nwau_growth) if cap_rule is not None else 1.0
    new_jurisdictions = new_jurisdictions.replace(
        effective_cth_share=new_jurisdictions.effective_cth_share * cap_factor
    )
    eff_share = jnp.mean(new_jurisdictions.effective_cth_share)

    # Returns six updated history buffers followed by six lagged "reported"
    # values that agents observe.
    (nb_p, nb_o, nb_w, nb_n, nb_e, nb_c, rp, ro, rw, rn, re, rc) = update_lag_buffers(
        s,
        p,
        avg_pidx,
        avg_occ,
        avg_w4,
        jnp.sum(new_jurisdictions.lhn_states.nwau_actual),
        jnp.mean(new_jurisdictions.efficiency_gap),
        jnp.asarray(s.coding_intensity),
    )

    # Fields not listed here (e.g. coding_intensity, political_capital) carry
    # over unchanged from `s`. lhn_pressure/lhn_nwau expose only the first
    # jurisdiction's LHN vectors.
    return s.replace(
        year=next_y,
        month=next_m,
        agreement_clock=next_agreement_clock,
        target_capacity=avg_target_capacity,
        current_capacity=avg_current_capacity,
        reconciliation_balance=next_reconciliation_balance,
        adjustment_costs=adjustment_costs,
        pressure=avg_pidx,
        occupancy=avg_occ,
        within4=avg_w4,
        effective_cth_share=eff_share,
        efficiency_gap=jnp.mean(new_jurisdictions.efficiency_gap),
        discharge_delay=jnp.mean(new_jurisdictions.lhn_states.discharge_delay),
        total_block_revenue=jnp.mean(new_jurisdictions.total_block_revenue),
        lhn_pressure=new_jurisdictions.lhn_states.pressure[0],
        lhn_nwau=new_jurisdictions.lhn_states.nwau_actual[0],
        jurisdictions=new_jurisdictions,
        auditor_suspicion=next_suspicion,
        audit_pressure_active=next_audit_pressure_active,
        metrics=s.metrics.replace(
            cumulative_adjustment_costs=s.metrics.cumulative_adjustment_costs + adjustment_costs
        ),
        workforce_pool=new_wf_pool,
        prob_ed=prob_ed,
        lag_buffer_pressure=nb_p,
        lag_buffer_occupancy=nb_o,
        lag_buffer_within4=nb_w,
        lag_buffer_nwau=nb_n,
        lag_buffer_efficiency_gap=nb_e,
        lag_buffer_coding=nb_c,
        reported_pressure=rp,
        reported_occupancy=ro,
        reported_within4=rw,
        reported_nwau=rn,
        reported_efficiency_gap=re,
        reported_coding_intensity=rc,
        system_mode=update_system_mode_jax(s, p, avg_pidx),
    )