JAX Backend for High-Performance Computing
This tutorial demonstrates how to use the JAX backend for high-performance computing.
# Tutorial script: select the JAX backend, evaluate a Bass diffusion curve
# for one fixed parameter set, and plot the result.

# Plotting/array imports do not depend on the selected backend, so they are
# grouped at the top.
import matplotlib.pyplot as plt
import numpy as np

from innovate.backend import use_backend

# Switch to the JAX backend
use_backend("jax")

# Imported after the backend switch, exactly as in the original walkthrough.
# NOTE(review): whether innovate requires this ordering is not visible here — confirm.
from innovate.diffuse.bass import BassModel  # noqa: E402

# Now, all models will use the JAX backend for their computations.
# For example, let's use the BassModel.

# Single source of truth for the parameters: the same dict is assigned to the
# model and used to build the legend text, so the two cannot drift apart.
# NOTE(review): in the standard Bass model p/q/m are the innovation
# coefficient, imitation coefficient and market potential — confirm against
# the BassModel documentation.
params = {
    "p": 0.03,
    "q": 0.38,
    "m": 1000,
}

# Initialize the model and set its parameters directly (no fitting step).
model = BassModel()
model.params_ = params

# Generate the time points: 100 evenly spaced values from 0 to 20 inclusive.
t = np.linspace(0, 20, 100)

# Predict the diffusion at those time points.
y = model.predict(t)

# Legend label derived from the parameters; renders as "p=0.03, q=0.38, m=1000".
label = ", ".join(f"{name}={value}" for name, value in params.items())

# Plot the results
plt.figure(figsize=(10, 6))
plt.plot(t, y, label=label)
plt.title("Bass Diffusion Model with JAX Backend")
plt.xlabel("Time")
plt.ylabel("Cumulative Adopters")
plt.grid(True, linestyle="--", alpha=0.6)
plt.legend()
plt.show()