nhra_gt.optimization_jax
Policy Optimization and Parameter Search.
Utilities for finding optimal policy levers using JAX-accelerated objective functions.
Functions
`optimize_policy_jax(init_state, base_params, strategies, prng_key, num_steps, param_to_optimize, bounds, objective_fn)`
Optimizes a single field of `Params` (named by `param_to_optimize`) within `bounds` to minimize `objective_fn`.
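If the objective happens to be differentiable in the chosen parameter, one common way to minimize it over a bounded interval in JAX is projected gradient descent: take a step along `jax.grad` of the objective, then clip the result back into the bounds. The sketch below illustrates that technique under this assumption only; it is not the implementation of `optimize_policy_jax`, and `projected_gd`, its learning rate, and the toy objectives are hypothetical.

```python
import jax
import jax.numpy as jnp


def projected_gd(objective, x0, bounds, lr=0.1, num_iters=200):
    """Minimize a scalar `objective` over [lo, hi] by gradient steps plus clipping."""
    lo, hi = bounds
    grad_fn = jax.grad(objective)

    def body(_, x):
        # Take a gradient step, then project back into the feasible interval.
        return jnp.clip(x - lr * grad_fn(x), lo, hi)

    # Start from a feasible float32 point so the loop carry keeps one dtype.
    x_init = jnp.clip(jnp.asarray(x0, dtype=jnp.float32), lo, hi)
    return jax.lax.fori_loop(0, num_iters, body, x_init)


# Hypothetical objectives: one with an interior minimum, one whose minimum lies above hi.
x_interior = projected_gd(lambda x: (x - 0.3) ** 2, x0=0.9, bounds=(0.0, 1.0))
x_clipped = projected_gd(lambda x: (x - 1.5) ** 2, x0=0.5, bounds=(0.0, 1.0))
print(float(x_interior), float(x_clipped))
```

For `(x - 1.5) ** 2` the unconstrained minimum lies outside `(0.0, 1.0)`, so the clip holds the iterate at the upper bound. Objectives that are piecewise constant in the parameter (for example, ones driven by discrete strategy choices) give zero gradients, and a derivative-free search is the safer choice there.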
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `init_state` | `StateJax` | Starting state. | required |
| `base_params` | `Params` | Base parameters. | required |
| `strategies` | `ndarray` | Pre-defined strategies for the simulation. | required |
| `prng_key` | `Any` | Random key. | required |
| `num_steps` | `int` | Simulation length. | required |
| `param_to_optimize` | `str` | Name of field in `Params` to vary. | required |
| `bounds` | `tuple[float, float]` | `(min, max)` for the parameter. | required |
| `objective_fn` | `Callable[[StateJax, PyTree], float]` | Function to minimize. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | A dictionary with the optimized parameter value and result metadata. |
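As a self-contained illustration of how a string `param_to_optimize`, a `bounds` pair, and an `objective_fn(state, params)` callable can fit together, the sketch below runs a vectorized grid search over a toy simulation. Everything in it is a hypothetical stand-in: `ToyParams`, `ToyState`, `run_sim`, `toy_optimize`, the field names, and the result-dict keys are not part of `nhra_gt`, and passing the parameters as the objective's `PyTree` argument is an assumption. `strategies` and `prng_key` are omitted because the toy dynamics are deterministic.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyParams(NamedTuple):
    # Hypothetical stand-in for Params; field names are illustrative only.
    levy_rate: float
    target_rate: float


class ToyState(NamedTuple):
    # Hypothetical stand-in for StateJax: a single accumulated cost.
    cost: jnp.ndarray


def run_sim(init_state: ToyState, params: ToyParams, num_steps: int) -> ToyState:
    # Each step adds a penalty for deviating from target_rate (toy dynamics).
    def step(state, _):
        penalty = (params.levy_rate - params.target_rate) ** 2
        return ToyState(cost=state.cost + penalty), None

    final_state, _ = jax.lax.scan(step, init_state, None, length=num_steps)
    return final_state


def toy_optimize(
    init_state: ToyState,
    base_params: ToyParams,
    num_steps: int,
    param_to_optimize: str,
    bounds: tuple[float, float],
    objective_fn: Callable[[ToyState, Any], jnp.ndarray],
    num_points: int = 101,
) -> dict[str, Any]:
    def objective_of(x):
        # Swap the named field into a copy of the base parameters.
        params = base_params._replace(**{param_to_optimize: x})
        return objective_fn(run_sim(init_state, params, num_steps), params)

    # Evaluate every candidate in one batched call, then keep the best one.
    candidates = jnp.linspace(bounds[0], bounds[1], num_points)
    values = jax.vmap(objective_of)(candidates)
    best = jnp.argmin(values)
    # Hypothetical result keys; the real function's keys are not documented here.
    return {
        "param": param_to_optimize,
        "value": float(candidates[best]),
        "objective": float(values[best]),
    }


result = toy_optimize(
    init_state=ToyState(cost=jnp.zeros(())),
    base_params=ToyParams(levy_rate=0.0, target_rate=0.4),
    num_steps=10,
    param_to_optimize="levy_rate",
    bounds=(0.0, 1.0),
    objective_fn=lambda state, params: state.cost,
)
print(result)
```

A grid search needs no gradients, so it also tolerates objectives that are flat or discontinuous in the parameter; its resolution is `(max - min) / (num_points - 1)` and its cost grows linearly with `num_points`.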