JAX-compatible state definitions for the NHRA simulation.
This module defines the data structures used to represent the state of the system
at various levels of granularity (LHN, Jurisdiction, Global). It uses
flax.struct.dataclass to ensure compatibility with JAX transformations
like jit, vmap, and grad.
Attributes
JurisdictionState = JurisdictionStateJax
module-attribute
Granular state for a single Jurisdiction (State/Territory).
Aggregates its LHNs and holds jurisdiction-level fiscal and political state.
Classes
SystemModeJax
Bases: IntEnum
Enumeration of system-wide operational modes.
Source code in src/nhra_gt/domain/state.py
class SystemModeJax(IntEnum):
    """System-wide operating regime.

    Members are integers (``IntEnum``), so a mode can be carried as a plain
    int inside JAX state (``StateJax.system_mode`` defaults to ``0``, i.e.
    ``NORMAL``).
    """

    NORMAL = 0  # baseline operation
    STRESS = 1
    CRISIS = 2
    RECOVERY = 3
|
EconomicSpineJax
JAX-compatible container for economic indices.
Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class EconomicSpineJax:
    """Per-year economic index series held as JAX arrays.

    ``BaselineProvider.load_spine`` fills all three arrays from the same CSV
    rows, so element ``i`` of each array describes the same year.
    """

    # Year of each row; load_spine casts this to int32.
    years: jnp.ndarray
    # NEP per NWAU for each year. Float array; it is float64 only if JAX x64
    # mode is enabled (otherwise jnp.array yields float32) — TODO confirm.
    nep_per_nwau: jnp.ndarray
    # WPI (health) index for each year; same dtype caveat as above.
    wpi_health_index: jnp.ndarray
|
MetricsJax
Accumulated metrics for policy optimization and objective functions.
Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class MetricsJax:
    """Accumulated metrics for policy optimization and objective functions."""

    cumulative_pressure: float = 0.0
    cumulative_budget_variance: float = 0.0
    max_occupancy: float = 0.0
    # Running minimum; starts at 1.0 (presumably the upper bound of the
    # within4 proportion — TODO confirm).
    min_within4: float = 1.0
    # Leakage Metrics
    cumulative_indexation_loss: float = 0.0
    cumulative_cap_loss: float = 0.0
    cumulative_audit_loss: float = 0.0
    cumulative_adjustment_costs: float = 0.0
    # Stability Metrics
    max_solver_n_equilibria: int = 0
    mean_solver_residual: float = 0.0

    def replace(self, **kwargs: Any) -> MetricsJax:
        """Return a copy of this instance with the given fields replaced.

        Uses the standard-library ``dataclasses.replace`` (struct dataclasses
        are frozen stdlib dataclasses) instead of ``struct.replace``, which is
        not part of the documented ``flax.struct`` API.

        NOTE(review): ``flax.struct.dataclass`` installs its own ``replace`` on
        the decorated class, which may shadow this method — confirm.
        """
        import dataclasses

        return dataclasses.replace(self, **kwargs)
|
OperationalParamsJax
Hidden coefficients for system dynamics (JAX version).
Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class OperationalParamsJax:
    """Tunable coefficients for the operational (system-dynamics) model.

    Every field is a scalar with a default; fields are grouped below by name
    prefix. Field order is significant (dataclass positional order).
    """

    # --- wf_*: drain bounds and impact weight ---
    wf_drain_max: float = 0.2
    wf_drain_min: float = 0.1
    wf_comp_drain: float = 0.1
    wf_impact_weight: float = 0.5
    # --- coordination vs. fragmentation multipliers (aged_*, ndis_*, disc_*) ---
    aged_coord_effect: float = 0.95
    aged_frag_effect: float = 1.02
    ndis_coord_effect: float = 0.96
    ndis_frag_effect: float = 1.03
    disc_coord_effect: float = 0.98
    disc_frag_effect: float = 1.01
    # --- discharge, efficiency-gap and occupancy dynamics ---
    discharge_update_speed: float = 0.1
    eff_gap_decay: float = 0.93
    occ_demand_slope: float = 0.015
    occ_discharge_slope: float = 0.035
    # --- offload_*: offload response to occupancy ---
    offload_occ_slope: float = 8.0
    offload_occ_base: float = 0.88
    # --- pressure_*: pressure index composition ---
    pressure_base: float = 0.8
    pressure_wait_weight: float = 0.2
    pressure_occ_weight: float = 0.5
    pressure_occ_base: float = 0.8
    pressure_occ_scale: float = 0.1
    # --- within4_*: curve parameters and bounds ---
    within4_intercept: float = 1.02
    within4_slope: float = 0.45
    within4_scale: float = 0.20
    within4_min: float = 0.05
    within4_max: float = 0.85
    # --- wf_*: pool drain / recovery and pool bounds ---
    wf_drain_base: float = 0.02
    wf_drain_intensity: float = 0.06
    wf_recovery_rate: float = 0.1
    wf_pool_min: float = 0.5
    wf_pool_max: float = 1.5
    # --- reneg_*: renegotiation trigger and share increments ---
    reneg_occ_threshold: float = 0.95
    reneg_share_inc_high: float = 0.06
    reneg_share_inc_low: float = 0.03
    # --- init_*: initial conditions ---
    # NOTE(review): these int fields are ordinary PyTree leaves (not
    # pytree_node=False), so they become traced values under jit; if any is
    # used as an array shape it would need to be static — confirm usage.
    init_n_lhns: int = 5
    init_efficiency_gap: float = 0.10
    init_agreement_clock: int = 5
    init_lhn_nwau_base: float = 100.0
    # --- jurisdiction_*: jurisdiction-level behaviour ---
    jurisdiction_pressure_threshold: float = 1.1
    jurisdiction_discharge_target: float = 0.9
    jurisdiction_noise_scale: float = 0.8
    jurisdiction_noise_base: float = 0.03
    # --- decision threshold, venue-shift revenue scale, renegotiated-share clip ---
    decision_threshold: float = 0.5
    venue_shift_revenue_scale: float = 100.0
    reneg_share_clip_min: float = 0.40
    reneg_share_clip_max: float = 0.70
    # --- time-unit conversion constants ---
    minutes_per_hour: float = 60.0
    hours_per_day: float = 24.0
    # --- *_clip_*: clipping bounds, plus capacity scalar ---
    discharge_clip_min: float = 0.75
    discharge_clip_max: float = 1.50
    occ_clip_min: float = 0.78
    occ_clip_max: float = 0.98
    offload_clip_min: float = 5.0
    offload_clip_max: float = 120.0
    capacity_scalar: float = 10.0
    # --- auditor_*: auditor agent dynamics ---
    auditor_suspicion_increment: float = 0.03
    auditor_suspicion_decay: float = 0.95
    auditor_pressure_base: float = 0.25
    # --- mode_*: system-mode transition thresholds ---
    mode_stress_threshold: float = 1.25
    mode_crisis_threshold: float = 1.5
    mode_normal_recovery_threshold: float = 1.05
    mode_recovery_trigger_threshold: float = 1.3
    mode_normal_final_threshold: float = 1.1
    mode_crisis_relapse_threshold: float = 1.4
    # --- queuing_*: queuing-choice model ---
    queuing_outside_utility: float = -100.0
    queuing_init_prob: float = 0.5
    # --- demand_*: demand shifts and scale ---
    demand_shift_slope: float = 1.04
    demand_shift_base: float = 0.35
    demand_invest_base: float = 0.96
    demand_scale: float = 2.0
    # --- wait_time_*: wait-time bounds ---
    wait_time_cap: float = 1440.0
    wait_time_min: float = 5.0
|
BehavioralParamsJax
Hidden coefficients for subgame payoffs (JAX version).
Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class BehavioralParamsJax:
    """Tunable payoff coefficients for the strategic subgames.

    Fields are grouped by subgame prefix (def_, barg_, shift_, reneg_, disc_,
    gov_, aged_/ndis_, coding_, comp_, venue_). Field order is significant
    (dataclass positional order), so groups appear exactly where declared.
    """

    # --- Definition game: benefit / cost terms ---
    def_realism_benefit_base: float = 0.5
    def_realism_eg_weight: float = 0.8
    def_realism_pr_weight: float = 0.4
    def_realism_cost_base: float = 0.25
    def_realism_ps_weight: float = 0.35
    def_strict_benefit_base: float = 0.35
    def_strict_ps_weight: float = 0.45
    def_strict_cost_base: float = 0.30
    def_strict_pr_weight: float = 0.50
    # --- Bargaining game: gain / cost terms ---
    barg_converge_gain_base: float = 0.45
    barg_converge_pr_weight: float = 0.25
    barg_converge_pc_weight: float = 0.20
    barg_conflict_cost_base: float = 0.55
    barg_conflict_pr_weight: float = 0.90
    barg_narrative_gain_base: float = 0.25
    barg_narrative_ps_weight: float = 0.50
    # --- Cost-shifting game: gain / cost terms ---
    shift_coop_gain_base: float = 0.55
    shift_coop_eg_weight: float = 0.45
    shift_gain_base: float = 0.35
    shift_eg_weight: float = 0.75
    shift_csi_weight: float = 1.0
    shift_pr_cost_weight: float = 0.65
    # --- Renegotiation game ---
    reneg_cth_fallout_weight: float = 0.8
    reneg_state_failure_weight: float = 0.6
    reneg_concede_agree_cost: float = 0.1
    reneg_concede_holdup_cost: float = 0.3
    reneg_concede_agree_gain: float = 0.2
    reneg_concede_holdup_gain: float = 0.5
    # --- Definition game: row / column payoff adjustments ---
    def_row_realism_offset: float = 0.15
    def_row_strict_offset: float = 0.45
    def_col_realism_offset: float = 0.15
    def_col_realism_cost: float = 0.20
    def_col_strict_cost_1: float = 0.35
    def_col_strict_cost_2: float = 0.55
    # --- Bargaining game: row / column penalties ---
    barg_row_converge_ps_penalty: float = 0.10
    barg_row_converge_base_penalty: float = 0.25
    barg_row_converge_pr_penalty: float = 0.15
    barg_row_narrative_pr_penalty: float = 0.10
    barg_col_converge_ps_penalty: float = 0.05
    barg_col_converge_base_penalty: float = 0.30
    barg_col_converge_pr_penalty: float = 0.20
    barg_col_narrative_penalty: float = 0.20
    # --- Cost-shifting game: row penalties ---
    shift_row_coop_penalty: float = 0.25
    shift_row_shift_pr_penalty: float = 0.35
    shift_row_shift_base_penalty: float = 0.60
    shift_row_shift_pr_heavy_penalty: float = 1.00
    # --- disc_* game: benefit / cost terms, then row / column penalties ---
    disc_benefit_base: float = 0.70
    disc_benefit_slope: float = 0.80
    disc_cost_base: float = 0.30
    disc_cost_slope: float = 0.10
    disc_pr_penalty_weight: float = 0.45
    disc_row_coop_penalty: float = 0.40
    disc_row_frag_base_penalty: float = 0.25
    disc_row_frag_heavy_penalty: float = 0.70
    disc_row_frag_pr_penalty: float = 1.10
    disc_col_coop_penalty: float = 0.35
    disc_col_frag_base_penalty: float = 0.25
    disc_col_frag_heavy_penalty: float = 0.70
    disc_col_frag_pr_penalty: float = 1.00
    # --- gov_* game: gain / cost / risk terms, then row / column adjustments ---
    gov_safety_gain_base: float = 0.55
    gov_safety_gain_slope: float = 0.35
    gov_int_cost_base: float = 0.20
    gov_int_cost_slope: float = 0.35
    gov_frag_risk_base: float = 0.40
    gov_frag_risk_slope: float = 0.60
    gov_row_safety_penalty: float = 0.25
    gov_row_frag_bonus: float = 0.10
    gov_row_frag_penalty: float = 0.45
    gov_col_safety_penalty_1: float = 0.10
    gov_col_safety_penalty_2: float = 0.20
    gov_col_frag_penalty_1: float = 0.35
    gov_col_frag_penalty_2: float = 0.55
    # --- aged_* / ndis_*: coordination benefit and fragmentation cost ---
    aged_coord_benefit_base: float = 0.6
    aged_coord_benefit_slope: float = 0.4
    aged_frag_cost_weight: float = 0.5
    ndis_coord_benefit_base: float = 0.5
    ndis_coord_benefit_slope: float = 0.5
    ndis_frag_cost_weight: float = 0.6
    # --- coding_*: coding game ---
    coding_upcode_gain_base: float = 0.3
    coding_upcode_gain_slope: float = 0.7
    coding_penalty_weight: float = 0.8
    coding_audit_cost: float = 0.2
    coding_recovery_weight: float = 0.4
    # --- comp_*: leakage / admin terms, then row / column adjustments ---
    comp_leakage_base: float = 0.40
    comp_leakage_slope: float = 0.70
    comp_admin_base: float = 0.18
    comp_admin_slope: float = 0.45
    comp_row_tight_bonus: float = 0.15
    comp_row_light_penalty: float = 0.80
    comp_col_tight_penalty: float = 0.10
    comp_col_tight_ai_penalty: float = 0.35
    comp_col_light_base_bonus: float = 0.20
    # --- venue_*: venue-shift game ---
    venue_shift_gain_base: float = 0.25
    venue_shift_gain_eg_weight: float = 0.5
    venue_shift_gain_pr_weight: float = 0.3
    venue_strict_penalty: float = 0.45
    venue_enforce_cost: float = 0.15
    venue_col_strict_penalty_weight: float = 0.15
    venue_col_strict_base_bonus: float = 0.10
    # --- comp_*: capture / cost terms ---
    comp_capture_base: float = 0.4
    comp_capture_slope: float = 0.6
    comp_cost_base: float = 0.3
    comp_cost_slope: float = 0.2
|
PolicyParamsJax
Hidden coefficients for rule logic (JAX version).
Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class PolicyParamsJax:
    """Tunable coefficients used by the policy-rule logic.

    Grouped by prefix; field order is significant (dataclass positional order).
    """

    # --- cap_* / audit_* ---
    cap_soft_multiplier: float = 0.5
    audit_prop_multiplier: float = 2.0
    audit_threshold_penalty_high: float = 3.0
    audit_threshold_penalty_low: float = 0.1
    # --- eligibility_* ---
    eligibility_venue_shift_impact: float = 0.10
    eligibility_abf_share_min: float = 0.5
    # --- recon_* (reconciliation) ---
    recon_bailout_increment: float = 0.05
    recon_safety_net_generosity: float = 1.5
    # --- iv_*: each group is a multiplier/increment with [min, max] bounds ---
    iv_pooled_csi_mult: float = 0.75
    iv_pooled_csi_min: float = 0.05
    iv_pooled_csi_max: float = 0.60
    iv_ucc_frag_mult: float = 0.80
    iv_ucc_frag_min: float = 0.60
    iv_ucc_frag_max: float = 1.50
    iv_nep_realism_inc: float = 0.03
    iv_nep_realism_regional_inc: float = 0.04
    iv_nep_realism_remote_inc: float = 0.05
    iv_nep_realism_min: float = 0.6
    iv_nep_realism_max: float = 1.0
    iv_aged_ndis_delay_mult: float = 0.90
    iv_aged_ndis_delay_min: float = 0.6
    iv_aged_ndis_delay_max: float = 1.4
    iv_cap_cumulative_growth: float = 0.070
    iv_audit_relief_audit_mult: float = 0.70
    iv_audit_relief_audit_min: float = 0.05
    iv_audit_relief_audit_max: float = 1.0
    iv_audit_relief_burden_mult: float = 0.8
    iv_audit_relief_burden_min: float = 0.05
    iv_audit_relief_burden_max: float = 0.6
|
ParamsJax
Bases: ParamsGenerated
JAX-compatible simulation parameters.
This class extends the auto-generated scalar parameters with runtime
objects and JAX-native rules.
Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class ParamsJax(ParamsGenerated):
    """JAX-compatible simulation parameters.

    Extends the auto-generated scalar parameters (``ParamsGenerated``) with
    grouped coefficient structs, modular rule objects and an optional
    economic spine.
    """

    # Grouped coefficients
    ops: OperationalParamsJax = struct.field(default_factory=OperationalParamsJax)
    behavior: BehavioralParamsJax = struct.field(default_factory=BehavioralParamsJax)
    policy: PolicyParamsJax = struct.field(default_factory=PolicyParamsJax)
    # Modular Rules (JAX-compatible PyTrees); each defaults to None.
    cap_rule: Any = struct.field(default_factory=lambda: None)
    audit_rule: Any = struct.field(default_factory=lambda: None)
    eligibility_rule: Any = struct.field(default_factory=lambda: None)
    reconciliation_rule: Any = struct.field(default_factory=lambda: None)
    # Economic Spine (optional JAX arrays)
    spine: EconomicSpineJax | None = None
    # Static (non-PyTree) metadata field; alias/placeholder.
    economic_spine: str | None = struct.field(default=None, pytree_node=False)

    def replace(self, **kwargs: Any) -> ParamsJax:
        """Return a copy with the given fields replaced (Flax-compatible).

        Uses the standard-library ``dataclasses.replace`` (struct dataclasses
        are frozen stdlib dataclasses) instead of ``struct.replace``, which is
        not part of the documented ``flax.struct`` API.

        NOTE(review): ``flax.struct.dataclass`` installs its own ``replace`` on
        the decorated class, which may shadow this method — confirm.
        """
        import dataclasses

        return dataclasses.replace(self, **kwargs)

    @classmethod
    def from_yaml(cls, path: Path | str) -> ParamsJax:
        """Load parameters from a YAML file, validated through Pydantic ``Params``.

        Top-level mappings other than ``ops``, ``behavior`` and ``policy`` are
        treated as legacy groups. Their entries are flattened into one
        namespace. The three named groups are passed through unchanged so that
        ``Params`` can handle their nesting. When flattening produces a
        duplicate key, the later entry wins (as before), and a warning is logged.

        Args:
            path: Path to the YAML configuration file (read as UTF-8).

        Returns:
            The result of ``Params(...).to_params_jax()``.

        Raises:
            TypeError: If the YAML document is neither empty nor a mapping.
        """
        import yaml

        from .params import Params

        nested_groups = frozenset({"ops", "behavior", "policy"})

        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # An empty YAML document parses to None; treat it as "no overrides"
        # instead of crashing on None.items().
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a mapping at the top level of {path!s}, "
                f"got {type(data).__name__}"
            )

        # 1. Flatten the legacy groups if they exist
        flat_data: dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, dict) and key not in nested_groups:
                entries = value
            else:
                entries = {key: value}
            # Overwrites are kept (later wins) but reported, since a silent
            # clash between groups hides configuration mistakes.
            for clashing in entries.keys() & flat_data.keys():
                logger.warning(
                    "Config key %r in %s is overridden by entry %r", clashing, path, key
                )
            flat_data.update(entries)

        # 2. Use Pydantic to validate and handle nesting
        return Params(**flat_data).to_params_jax()
|
Functions
replace(**kwargs)
Flax-compatible field replacement.
Source code in src/nhra_gt/domain/state.py
def replace(self, **kwargs: Any) -> ParamsJax:
    """Return a copy with the given fields replaced (Flax-compatible).

    Uses the standard-library ``dataclasses.replace`` (struct dataclasses are
    frozen stdlib dataclasses) instead of ``struct.replace``, which is not
    part of the documented ``flax.struct`` API.

    NOTE(review): ``flax.struct.dataclass`` installs its own ``replace`` on
    the decorated class, which may shadow this method — confirm.
    """
    import dataclasses

    return dataclasses.replace(self, **kwargs)
|
from_yaml(path)
classmethod
Loads parameters from a YAML file via Pydantic for validation.
Source code in src/nhra_gt/domain/state.py
@classmethod
def from_yaml(cls, path: Path | str) -> ParamsJax:
    """Load parameters from a YAML file, validated through Pydantic ``Params``.

    Top-level mappings other than ``ops``, ``behavior`` and ``policy`` are
    treated as legacy groups. Their entries are flattened into one namespace.
    The three named groups are passed through unchanged so that ``Params`` can
    handle their nesting. When flattening produces a duplicate key, the later
    entry wins (as before), and a warning is logged.

    Args:
        path: Path to the YAML configuration file (read as UTF-8).

    Returns:
        The result of ``Params(...).to_params_jax()``.

    Raises:
        TypeError: If the YAML document is neither empty nor a mapping.
    """
    import yaml

    from .params import Params

    nested_groups = frozenset({"ops", "behavior", "policy"})

    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # An empty YAML document parses to None; treat it as "no overrides"
    # instead of crashing on None.items().
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {path!s}, "
            f"got {type(data).__name__}"
        )

    # 1. Flatten the legacy groups if they exist
    flat_data: dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, dict) and key not in nested_groups:
            entries = value
        else:
            entries = {key: value}
        # Overwrites are kept (later wins) but reported, since a silent clash
        # between groups hides configuration mistakes.
        for clashing in entries.keys() & flat_data.keys():
            logger.warning(
                "Config key %r in %s is overridden by entry %r", clashing, path, key
            )
        flat_data.update(entries)

    # 2. Use Pydantic to validate and handle nesting
    return Params(**flat_data).to_params_jax()
|
BaselineProvider
Manages loading of the automated data spine and baseline parameters.
Provides a centralized interface for synchronizing empirical data into
the JAX simulation environment.
Source code in src/nhra_gt/domain/state.py
class BaselineProvider:
    """
    Manages loading of the automated data spine and baseline parameters.

    Provides a centralized interface for synchronizing empirical data into
    the JAX simulation environment.
    """

    @staticmethod
    def load_spine(
        path: Path | str = "data/calibration/historical_normalized.csv",
    ) -> EconomicSpineJax:
        """Load the economic spine (NEP, WPI) from a CSV file.

        Reads with Polars when the module-level ``pl`` is available, otherwise
        with Pandas. Both readers share one validation and conversion step,
        so the two paths cannot drift apart.

        Args:
            path: CSV with at least ``year``, ``nep_per_nwau`` and
                ``wpi_health_index`` columns.

        Returns:
            An ``EconomicSpineJax`` with one array entry per CSV row, in file
            order (``years`` cast to int32).

        Raises:
            ValueError: If any required column is missing.
        """
        required = {"year", "nep_per_nwau", "wpi_health_index"}

        if pl is None:
            import pandas as pd

            frame = pd.read_csv(path)
        else:
            frame = pl.read_csv(path)

        missing = required - set(frame.columns)
        if missing:
            raise ValueError(f"Spine missing required columns: {sorted(missing)}")

        # pandas and polars Series both expose to_numpy().
        return EconomicSpineJax(
            years=jnp.array(frame["year"].to_numpy().astype(jnp.int32)),
            nep_per_nwau=jnp.array(frame["nep_per_nwau"].to_numpy()),
            wpi_health_index=jnp.array(frame["wpi_health_index"].to_numpy()),
        )

    @classmethod
    def get_baseline(
        cls,
        config_path: Path | str = "configs/defaults.yaml",
        *,
        spine_path: Path | str | None = "data/calibration/historical_normalized.csv",
        start_year: int = 2025,
    ) -> tuple[ParamsJax, StateJax]:
        """Build baseline parameters and initial state for a new simulation run.

        Args:
            config_path: YAML parameter file passed to ``ParamsJax.from_yaml``.
            spine_path: CSV for the economic spine. If it is ``None`` or the
                file does not exist, the spine is left as configured. A file
                that fails column validation is logged and skipped.
            start_year: Year passed to ``baseline_state_jax`` (previously
                hard-coded as 2025, still the default).

        Returns:
            ``(params, state)``.
        """
        from nhra_gt.engine_jax import baseline_state_jax

        params = ParamsJax.from_yaml(config_path)

        if spine_path is not None and Path(spine_path).exists():
            try:
                spine = cls.load_spine(spine_path)
            except ValueError as exc:
                logger.warning("Skipping spine load: %s", exc)
            else:
                params = params.replace(spine=spine)

        state = baseline_state_jax(start_year, params)
        return params, state
|
Functions
load_spine(path='data/calibration/historical_normalized.csv')
staticmethod
Loads the economic spine (NEP, WPI) from a CSV file.
Uses Polars if available, otherwise falls back to Pandas.
Source code in src/nhra_gt/domain/state.py
@staticmethod
def load_spine(
    path: Path | str = "data/calibration/historical_normalized.csv",
) -> EconomicSpineJax:
    """Load the economic spine (NEP, WPI) from a CSV file.

    Reads with Polars when the module-level ``pl`` is available, otherwise
    with Pandas. Both readers share one validation and conversion step, so
    the two paths cannot drift apart.

    Args:
        path: CSV with at least ``year``, ``nep_per_nwau`` and
            ``wpi_health_index`` columns.

    Returns:
        An ``EconomicSpineJax`` with one array entry per CSV row, in file
        order (``years`` cast to int32).

    Raises:
        ValueError: If any required column is missing.
    """
    required = {"year", "nep_per_nwau", "wpi_health_index"}

    if pl is None:
        import pandas as pd

        frame = pd.read_csv(path)
    else:
        frame = pl.read_csv(path)

    missing = required - set(frame.columns)
    if missing:
        raise ValueError(f"Spine missing required columns: {sorted(missing)}")

    # pandas and polars Series both expose to_numpy().
    return EconomicSpineJax(
        years=jnp.array(frame["year"].to_numpy().astype(jnp.int32)),
        nep_per_nwau=jnp.array(frame["nep_per_nwau"].to_numpy()),
        wpi_health_index=jnp.array(frame["wpi_health_index"].to_numpy()),
    )
|
get_baseline(config_path='configs/defaults.yaml')
classmethod
Retrieves baseline parameters and state for a new simulation run.
Source code in src/nhra_gt/domain/state.py
@classmethod
def get_baseline(
    cls,
    config_path: Path | str = "configs/defaults.yaml",
    *,
    spine_path: Path | str | None = "data/calibration/historical_normalized.csv",
    start_year: int = 2025,
) -> tuple[ParamsJax, StateJax]:
    """Build baseline parameters and initial state for a new simulation run.

    Args:
        config_path: YAML parameter file passed to ``ParamsJax.from_yaml``.
        spine_path: CSV for the economic spine. If it is ``None`` or the file
            does not exist, the spine is left as configured. A file that fails
            column validation is logged and skipped.
        start_year: Year passed to ``baseline_state_jax`` (previously
            hard-coded as 2025, still the default).

    Returns:
        ``(params, state)``.
    """
    from nhra_gt.engine_jax import baseline_state_jax

    params = ParamsJax.from_yaml(config_path)

    if spine_path is not None and Path(spine_path).exists():
        try:
            spine = cls.load_spine(spine_path)
        except ValueError as exc:
            logger.warning("Skipping spine load: %s", exc)
        else:
            params = params.replace(spine=spine)

    state = baseline_state_jax(start_year, params)
    return params, state
|
LhnState
Granular state for a single Local Hospital Network (LHN).
Represents the operational and strategic status of a hospital cluster,
including its pressure, occupancy, and internal choices.
Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class LhnState:
    """
    Granular state for a single Local Hospital Network (LHN).

    Represents the operational and strategic status of a hospital cluster,
    including its pressure, occupancy, and internal choices.
    """

    id: int
    pressure: float = 1.0
    occupancy: float = 0.88
    within4: float = 0.53
    offload_min: float = 18.0
    nwau_actual: float = 100.0
    nwau_reported: float = 100.0
    coding_intensity: float = 1.0
    target_capacity: float = 1.0
    current_capacity: float = 1.0
    discharge_delay: float = 1.0
    adjustment_costs: float = 0.0

    def replace(self, **kwargs: Any) -> LhnState:
        """Return a copy with the given fields replaced (Flax-compatible).

        Uses the standard-library ``dataclasses.replace`` (struct dataclasses
        are frozen stdlib dataclasses) instead of ``struct.replace``, which is
        not part of the documented ``flax.struct`` API.

        NOTE(review): ``flax.struct.dataclass`` installs its own ``replace`` on
        the decorated class, which may shadow this method — confirm.
        """
        import dataclasses

        return dataclasses.replace(self, **kwargs)
|
Functions
replace(**kwargs)
Flax-compatible field replacement.
Source code in src/nhra_gt/domain/state.py
def replace(self, **kwargs: Any) -> LhnState:
    """Return a copy with the given fields replaced (Flax-compatible).

    Uses the standard-library ``dataclasses.replace`` (struct dataclasses are
    frozen stdlib dataclasses) instead of ``struct.replace``, which is not
    part of the documented ``flax.struct`` API.

    NOTE(review): ``flax.struct.dataclass`` installs its own ``replace`` on
    the decorated class, which may shadow this method — confirm.
    """
    import dataclasses

    return dataclasses.replace(self, **kwargs)
|
JurisdictionStateJax
JAX-native jurisdiction state.
Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class JurisdictionStateJax:
    """Per-jurisdiction state, expressed as a JAX PyTree.

    Fields are annotated ``Any`` so they can hold either Python scalars or
    JAX arrays (presumably batched when vectorized — TODO confirm). Defaults
    match the corresponding fiscal defaults on ``StateJax``.
    """

    # Jurisdiction identifier (required, no default).
    id: Any
    # Fiscal / political state.
    reconciliation_balance: Any = 0.0
    bailout_expectation: Any = 0.0
    political_capital: Any = 1.0
    effective_cth_share: Any = 0.38
    efficiency_gap: Any = 0.10
    equity_index: Any = 1.0
    total_block_revenue: Any = 0.0
    # Vectorized per-LHN state; None when not populated.
    lhn_states: Any = None
|
StateJax
JAX-compatible simulation state (Global Orchestrator).
The root PyTree for the entire simulation state. It contains both global
aggregates and hierarchical jurisdictional/LHN states.
Source code in src/nhra_gt/domain/state.py
@struct.dataclass
class StateJax:
    """
    JAX-compatible simulation state (Global Orchestrator).

    The root PyTree for the entire simulation state. It contains both global
    aggregates and hierarchical jurisdictional/LHN states.
    """

    year: Any
    month: Any
    pressure: Any
    occupancy: Any
    offload_min: Any
    within4: Any
    # Fiscal / bargaining state
    effective_cth_share: Any = 0.38
    efficiency_gap: Any = 0.10
    discharge_delay: Any = 1.0
    political_capital: Any = 1.0
    equity_index: Any = 1.0
    reconciliation_balance: Any = 0.0
    bailout_expectation: Any = 0.0
    total_block_revenue: Any = 0.0
    # Orchestrator state
    system_mode: Any = 0  # Mapped from SystemModeJax (0 == NORMAL)
    agreement_clock: Any = 5
    workforce_pool: Any = 1.0
    target_capacity: Any = 1.0
    current_capacity: Any = 1.0
    coding_intensity: Any = 1.0
    reputation_score: Any = 1.0
    jurisdiction_id: Any = 0
    # Per-LHN (flat) state used by tests + simple vmaps.
    # NOTE(review): length 5 is hard-coded here; it matches the default
    # OperationalParamsJax.init_n_lhns but will not follow a changed value.
    lhn_pressure: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(5))
    lhn_nwau: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(5))
    # Hierarchical Entities (optional richer representation)
    jurisdictions: JurisdictionState | None = None
    # Auditor Agent state
    auditor_suspicion: float = 0.0
    audit_pressure_active: float = 0.0
    adjustment_costs: float = 0.0
    # Lags & Measurement: each buffer holds 12 entries (up to 12 months of history)
    lag_buffer_pressure: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_occupancy: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_within4: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_nwau: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_efficiency_gap: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    lag_buffer_coding: jnp.ndarray = struct.field(default_factory=lambda: jnp.zeros(12))
    # Reported values (lagged) available to agents
    reported_pressure: float = 1.0
    reported_occupancy: float = 0.88
    reported_within4: float = 0.53
    reported_nwau: float = 0.0
    reported_efficiency_gap: float = 0.10
    reported_coding_intensity: float = 1.0
    # Stability Telemetry
    solver_n_equilibria: int = 1
    solver_residual: float = 0.0
    # Patient Choice
    prob_ed: float = 0.5
    # Accumulated metrics
    metrics: MetricsJax = struct.field(default_factory=MetricsJax)

    def replace(self, **kwargs: Any) -> StateJax:
        """Return a copy with the given fields replaced (Flax-compatible).

        Uses the standard-library ``dataclasses.replace`` (struct dataclasses
        are frozen stdlib dataclasses) instead of ``struct.replace``, which is
        not part of the documented ``flax.struct`` API.

        NOTE(review): ``flax.struct.dataclass`` installs its own ``replace`` on
        the decorated class, which may shadow this method — confirm.
        """
        import dataclasses

        return dataclasses.replace(self, **kwargs)
|
Functions
replace(**kwargs)
Flax-compatible field replacement.
Source code in src/nhra_gt/domain/state.py
def replace(self, **kwargs: Any) -> StateJax:
    """Return a copy with the given fields replaced (Flax-compatible).

    Uses the standard-library ``dataclasses.replace`` (struct dataclasses are
    frozen stdlib dataclasses) instead of ``struct.replace``, which is not
    part of the documented ``flax.struct`` API.

    NOTE(review): ``flax.struct.dataclass`` installs its own ``replace`` on
    the decorated class, which may shadow this method — confirm.
    """
    import dataclasses

    return dataclasses.replace(self, **kwargs)
|