JAX Backend for High-Performance Computing

This tutorial shows how to switch the library to the JAX backend for high-performance computing and then run a model (here, the Bass diffusion model) on that backend.

from innovate.backend import use_backend

# Switch to the JAX backend (called here before any model is created).
use_backend("jax")

# From this point on, all models use the JAX backend for their computations.
# The rest of this tutorial uses BassModel as the example:

import numpy as np
import matplotlib.pyplot as plt
from innovate.diffuse.bass import BassModel

# Define the Bass parameters once so the model and the plot legend cannot
# drift apart when the values are edited.
# NOTE(review): in the standard Bass formulation p is the innovation
# coefficient, q the imitation coefficient and m the market potential --
# confirm against the BassModel documentation.
params = {"p": 0.03, "q": 0.38, "m": 1000}

# Initialize the model
model = BassModel()

# Set the parameters directly (no fitting step in this example); pass a copy
# so the `params` dict used for the legend stays untouched by the model.
model.params_ = dict(params)

# Generate the time points: 100 evenly spaced values from 0 to 20
t = np.linspace(0, 20, 100)

# Predict the diffusion at those time points
y = model.predict(t)

# Build the legend text from the parameters actually used,
# e.g. "p=0.03, q=0.38, m=1000"
label = ", ".join(f"{name}={value}" for name, value in params.items())

# Plot the results
plt.figure(figsize=(10, 6))
plt.plot(t, y, label=label)
plt.title("Bass Diffusion Model with JAX Backend")
plt.xlabel("Time")
plt.ylabel("Cumulative Adopters")
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend()
plt.show()