Initialize
[1]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from classy_sz import Class as Class_sz
[3]:
# Cosmological parameters passed to classy_sz, using CLASS input names.
# H0 is given directly because it is what the emulators use, and it avoids any
# ambiguity when comparing with CAMB.
cosmo_params = dict([
    ('omega_b', 0.02242),
    ('omega_cdm', 0.11933),
    ('H0', 67.66),
    ('tau_reio', 0.0561),
    ('ln10^{10}A_s', 3.047),
    ('n_s', 0.9665),
])
Compute H(z)
[4]:
%%time
# Create a classy_sz instance and configure it with the cosmology above plus run
# settings. The return value of the last set() call (True) is what the cell displays.
classy_sz = Class_sz()
classy_sz.set(cosmo_params)
classy_sz.set({
'output':' ',  # NOTE(review): presumably requests no CLASS spectra, only background quantities; confirm
'skip_hubble':0,  # NOTE(review): presumably 0 = do not skip the H(z) computation; confirm in classy_sz docs
})
CPU times: user 304 μs, sys: 1.11 ms, total: 1.41 ms
Wall time: 1.64 ms
[4]:
True
[5]:
%%time
classy_sz.compute_class_szfast()
CPU times: user 10.3 ms, sys: 16.5 ms, total: 26.9 ms
Wall time: 32.8 ms
[6]:
classy_sz.get_H(0.)
[6]:
np.float64(0.00033359796437495265)
[8]:
classy_sz.Hubble(1.)
[8]:
np.float64(0.0004023684284735257)
[9]:
classy_sz.get_hubble_at_z(1,params_values_dict = cosmo_params)
[9]:
array(0.00040237)
Convert to usual units (km/s/Mpc)
[15]:
# Speed of light in km/s. Hubble(z) and get_hubble_at_z(z, ...) return H(z)/c in
# 1/Mpc, so multiplying by c in km/s gives H(z) in the usual km/s/Mpc.
conv_fac = 299792.458
H0_from_Hubble = classy_sz.Hubble(0.) * conv_fac
H0_from_get_hubble = classy_sz.get_hubble_at_z(0, params_values_dict=cosmo_params) * conv_fac
print(H0_from_Hubble)
print(H0_from_get_hubble)
67.66687000949837
67.66687000949837
Plot H(z)
[8]:
# Plot the emulated H(z)/c over 0 <= z <= 20 (linear z axis, log H axis).
label_size = 15
title_size = 20
legend_size = 13
handle_length = 1.5
fig, ax = plt.subplots(1, 1, figsize=(18, 5))
# Inward ticks on all four sides. The label font size is set through tick_params
# so it persists when matplotlib regenerates the tick labels (e.g. after
# set_yscale('log') below); plt.setp on get_*ticklabels() only restyles the
# label objects that exist at the time of the call.
ax.tick_params(axis='x', which='both', length=5, direction='in', pad=10,
               labelsize=label_size)
ax.tick_params(axis='y', which='both', length=5, direction='in', pad=5,
               labelsize=label_size, labelrotation=0)
ax.xaxis.set_ticks_position('both')
ax.yaxis.set_ticks_position('both')
ax.grid(visible=True, which="both", alpha=0.2, linestyle='--')
z = np.linspace(0., 20, 100)
# get_H(z) returns H(z)/c in units of h/Mpc: get_H(0) equals 100/c, and
# get_H(z) * h() equals Hubble(z) (as checked in the timing cells below).
ax.plot(z, classy_sz.get_H(z), ls='-', label='class_szfast')
ax.set_ylabel(r"$H(z)/c\ [h\,\mathrm{Mpc}^{-1}]$", size=title_size)
ax.set_xlabel(r"$z$", size=title_size)
ax.set_xscale('linear')
ax.set_yscale('log')
ax.set_xlim(0, 20)
ax.legend(fontsize=legend_size, handlelength=handle_length)
plt.show()
[8]:
<matplotlib.legend.Legend at 0x146b8e000>
Timing the computation of H(z)
[10]:
%timeit -n 40 classy_sz.compute_class_szfast()
31.4 ms ± 4.76 ms per loop (mean ± std. dev. of 7 runs, 40 loops each)
[11]:
z = np.linspace(0.,20,1000)
%timeit -n 40 classy_sz.get_H(z)
8.61 ms ± 267 µs per loop (mean ± std. dev. of 7 runs, 40 loops each)
[12]:
classy_sz.get_H(z)[0]*classy_sz.h()
[12]:
0.00022571238269609295
[13]:
z = np.linspace(0.,20,1000)
%timeit -n 40 classy_sz.Hubble(z)
6.92 ms ± 114 µs per loop (mean ± std. dev. of 7 runs, 40 loops each)
[14]:
classy_sz.Hubble(z)[0]
[14]:
0.00022571238269609295
[15]:
# let's time it
z = np.linspace(0.,20,1000)
%timeit -n 40 classy_sz.get_hubble_at_z(z,params_values_dict = cosmo_params)
18.8 ms ± 6.96 ms per loop (mean ± std. dev. of 7 runs, 40 loops each)
[16]:
classy_sz.get_hubble_at_z(z,params_values_dict = cosmo_params)[0]
[16]:
0.00022571238269609295
We can now update the cosmological parameters and recompute.
A practical strategy is to initialize classy_sz once and then recompute for each new set of parameters, as done here.
[17]:
cosmo_params.update({'H0':72.})
[18]:
classy_sz.get_hubble_at_z(z,params_values_dict = cosmo_params)[0]
[18]:
0.00024019108300278296
[19]:
%timeit -n 40 classy_sz.get_hubble_at_z(z,params_values_dict = cosmo_params)[0]
15.6 ms ± 3.89 ms per loop (mean ± std. dev. of 7 runs, 40 loops each)
Comparison with the full CLASS computation (without emulators)
[20]:
from classy_sz import Class as Class_sz
import numpy as np
# Same six LCDM parameters as at the top of the notebook, plus one massive
# neutrino species with 3 degenerate states, for the full CLASS run below.
cosmo_params = {
    'omega_b': 0.02242,
    'omega_cdm': 0.11933,
    # H0 is what the emulators use; using it here too avoids any ambiguity
    # when comparing with CAMB.
    'H0': 67.66,
    'tau_reio': 0.0561,
    'ln10^{10}A_s': 3.047,
    'n_s': 0.9665,
    # Massless relics. This is not CLASS's default N_ur: it is the value the
    # CLASS v3 documentation recommends together with 3 degenerate massive
    # neutrinos to obtain Neff = 3.044.
    'N_ur': 0.00441,
    'N_ncdm': 1,     # one non-cold-dark-matter (massive neutrino) species
    'm_ncdm': 0.02,  # mass per degenerate state (eV in CLASS convention): 3 x 0.02 = 0.06 eV total
    'deg_ncdm': 3,   # degeneracy of that species
}
[21]:
%%time
# Full CLASS run (no emulators): fresh instance, same cosmology plus neutrinos,
# and compute() instead of compute_class_szfast().
classy_sz = Class_sz()
classy_sz.set(cosmo_params)
classy_sz.set({
'output':' ',
'skip_background_and_thermo':0,  # NOTE(review): presumably 0 = run the background/thermodynamics modules; confirm
})
classy_sz.compute()
CPU times: user 58 ms, sys: 5.67 ms, total: 63.6 ms
Wall time: 70.8 ms
[22]:
classy_sz.Hubble(1.)
[22]:
0.00040234621947652185
[23]:
z = np.linspace(0.,20,1000)
%timeit -n 40 classy_sz.Hubble(z)
461 µs ± 134 µs per loop (mean ± std. dev. of 7 runs, 40 loops each)
[24]:
classy_sz.Hubble(z)[0]
[24]:
0.00022568946681106967
Compute with JAX
[5]:
from classy_sz import Class as Class_sz
import jax.numpy as jnp
# Six LCDM parameters for the JAX-enabled run (same values as at the top).
# H0 is used because it is what the emulators use, which avoids any ambiguity
# when comparing with CAMB. 'ln10^{10}A_s' is not a valid Python identifier,
# so it is passed by dict unpacking; keyword order is preserved.
cosmo_params = dict(
    omega_b=0.02242,
    omega_cdm=0.11933,
    H0=67.66,
    tau_reio=0.0561,
    **{'ln10^{10}A_s': 3.047},
    n_s=0.9665,
)
[6]:
%%time
# Instance configured with 'jax': 1. Later cells show that its H(z) results are
# JAX Arrays and can go through jax.jit / jax.grad / jax.jacfwd.
classy_sz = Class_sz()
classy_sz.set(cosmo_params)
classy_sz.set({
'output':' ',
'jax' : 1,
})
CPU times: user 1.45 ms, sys: 3.3 ms, total: 4.75 ms
Wall time: 4.99 ms
[6]:
True
[7]:
%%time
classy_sz.compute_class_szfast()
CPU times: user 444 ms, sys: 47.4 ms, total: 491 ms
Wall time: 495 ms
[8]:
z = 1.
classy_sz.get_hubble_at_z(z,params_values_dict = cosmo_params)
[8]:
Array(0.00040237, dtype=float64)
[13]:
cosmo_params.update({'H0':72.})
"%.25f"%classy_sz.get_hubble_at_z(z,params_values_dict = cosmo_params)
[13]:
'0.0004106620973760160566400'
[14]:
cosmo_params.update({'H0':70.})
"%.25f"%classy_sz.get_hubble_at_z(z,params_values_dict = cosmo_params)
[14]:
'0.0004067974877159367843964'
[15]:
"%.25f"%classy_sz.Hubble(1.)[0]
[15]:
'0.0004067974877159367843964'
[16]:
z = jnp.linspace(1.,20,1000)
"%.25f"%classy_sz.Hubble(z)[0]
[16]:
'0.0004067974877159367843964'
[17]:
# Updating cosmo_params by itself does not change what Hubble(z) returns: the
# output below is identical to the H0=70 value above. Hubble(z) appears to reflect
# the last cosmology actually evaluated (here get_hubble_at_z with H0=70);
# call get_hubble_at_z(z, params_values_dict=cosmo_params) to use the new H0.
cosmo_params.update({'H0':67.})
z = jnp.linspace(1.,20,1000)
"%.25f"%classy_sz.Hubble(z)[0]
[17]:
'0.0004067974877159367843964'
JAX compatibility
[18]:
import jax
z = jnp.linspace(1., 20, 1000)
hubble_values = classy_sz.Hubble(z)
# Check whether the eager call already returns a JAX array (jnp.ndarray is jax.Array).
is_jax_array = isinstance(hubble_values, jnp.ndarray)
# Check that Hubble can be traced and compiled by jax.jit, and that the compiled
# result agrees with the eager one: running without error alone does not prove
# that the values are the same.
try:
    jitted_hubble = jax.jit(classy_sz.Hubble)(z)
    supports_jit = True
    jit_matches_eager = bool(jnp.allclose(jitted_hubble, hubble_values))
except Exception as e:
    supports_jit = False
    jit_matches_eager = False
    print("Error with jax.jit:", e)
print("Is Hubble(z) a JAX array?", is_jax_array)
print("Does Hubble(z) support JAX jit?", supports_jit)
print("Does jit(Hubble)(z) match Hubble(z)?", jit_matches_eager)
Is Hubble(z) a JAX array? True
Does Hubble(z) support JAX jit? True
[19]:
hubble_values = classy_sz.get_hubble_at_z(z,params_values_dict = cosmo_params)
is_jax_array = isinstance(hubble_values, jnp.ndarray)
print("Is get_hubble_at_z a JAX array?", is_jax_array)
Is get_hubble_at_z a JAX array? True
Gradients
1 dimension
At one redshift
[9]:
z = 1.
def Hubble(H0):
    """Return H(z)/c (in 1/Mpc) at the redshift ``z`` set just above, as a function of H0.

    All other cosmological parameters come from the global ``cosmo_params``.
    A new dict is built on every call instead of mutating ``cosmo_params``.
    Under jax.grad / jacfwd, ``H0`` is a tracer, and writing it into a global
    would leak the tracer out of the transformation. It would also silently
    change ``cosmo_params`` for every later cell.

    Args:
        H0: Hubble constant in km/s/Mpc (Python float or JAX scalar).

    Returns:
        The JAX array returned by ``classy_sz.get_hubble_at_z``.
    """
    params = {**cosmo_params, 'H0': H0}
    hz = classy_sz.get_hubble_at_z(z, params_values_dict=params)
    return hz
[10]:
Hubble(72)
[10]:
Array(0.00041066, dtype=float64)
[13]:
from jax import grad
# Reverse-mode derivative d[H(z)/c]/dH0 of Hubble(H0), at the fixed z defined above.
dHubble = grad(Hubble)
[14]:
dHubble(76.)
[14]:
Array(2.02039159e-06, dtype=float64, weak_type=True)
[15]:
dHubble(76.)
[15]:
Array(2.02039159e-06, dtype=float64, weak_type=True)
[16]:
import jax
jax.__version__
[16]:
'0.4.38'
[17]:
import numpy as np
np.__version__
[17]:
'2.0.2'
Comparison with numerical derivative
[42]:
h = 1e-6
(Hubble(76.+h)-Hubble(76))/h
[42]:
Array(2.02039161e-06, dtype=float64)
[43]:
import numpy as np
# Forward-difference estimates of dH/dH0 at H0 = 76 for step sizes from 1e-10 to 1e-1.
h = np.geomspace(1e-10, 1e-1, 50)
# The reference value does not depend on the step, so evaluate it once
# instead of once per step inside the loop.
H_ref = Hubble(76.)
dH = [(Hubble(76. + hp) - H_ref) / hp for hp in h]
# JAX derivative at the same point, used as the reference line in the plot below.
jaxdH = dHubble(76.)
[44]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 6))
plt.plot(h, dH, label='Numerical', linestyle='-', marker='o') # Line plot for 'numerical'
plt.axhline(y=jaxdH, label='JAX', color='r', linestyle='--') # Horizontal line for 'jax'
# Logarithmic scale for x-axis
plt.xscale('log')
# Axis labels and title
plt.xlabel('h (step size)')
plt.ylabel('dH (derivative)')
plt.title('Comparison of Numerical and JAX Derivatives')
# Thin grid
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
# Legend
plt.legend()
# Display plot
plt.show()
On a redshift grid
[45]:
from jax import jacrev, jacfwd
z = jnp.linspace(1., 20, 10)
def Hubble(H0):
    """Return H(z)/c (in 1/Mpc) on the redshift grid ``z`` set just above, as a function of H0.

    All other cosmological parameters come from the global ``cosmo_params``.
    A new dict is built on every call instead of mutating ``cosmo_params``,
    so that the tracer passed in by jacrev/jacfwd never leaks into global state.

    Args:
        H0: Hubble constant in km/s/Mpc (Python float or JAX scalar).

    Returns:
        The JAX array returned by ``classy_sz.get_hubble_at_z`` (one value per redshift).
    """
    params = {**cosmo_params, 'H0': H0}
    hz = classy_sz.get_hubble_at_z(z, params_values_dict=params)
    return hz
[46]:
dHubble = jacrev(Hubble)
[47]:
%timeit -n 100 dHubble(72.)
24.8 ms ± 1.37 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
[48]:
dHubble = jacfwd(Hubble)
[49]:
%timeit -n 100 dHubble(72.)
17.2 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
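Forward mode (`jacfwd`) is faster here. The function has a single input and several outputs, which is the case forward-mode differentiation is suited to.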
More than 1 dimension
At one redshift
[50]:
z = 1.
def Hubble(H0, omega_cdm):
    """Return H(z)/c (in 1/Mpc) at the redshift ``z`` set just above, as a function of H0 and omega_cdm.

    All other cosmological parameters come from the global ``cosmo_params``.
    A new dict is built on every call instead of mutating ``cosmo_params``,
    so that tracers passed in by jacfwd never leak into global state.

    Args:
        H0: Hubble constant in km/s/Mpc.
        omega_cdm: value for the CLASS ``omega_cdm`` parameter.

    Returns:
        The JAX array returned by ``classy_sz.get_hubble_at_z``.
    """
    params = {**cosmo_params, 'H0': H0, 'omega_cdm': omega_cdm}
    hz = classy_sz.get_hubble_at_z(z, params_values_dict=params)
    return hz
[51]:
Hubble(72,0.11833)
[51]:
Array(0.00040971, dtype=float64)
[52]:
from jax import jacfwd
# Get the derivative of Hubble with respect to both parameters
dHubble = jacfwd(Hubble,argnums=(0,1))
[53]:
dHubble(76.,0.11933)
[53]:
(Array(2.02039159e-06, dtype=float64), Array(0.00092981, dtype=float64))
On a redshift grid
[54]:
z = jnp.linspace(1., 20, 10)
def Hubble(H0, omega_cdm):
    """Return H(z)/c (in 1/Mpc) on the redshift grid ``z`` set just above, as a function of H0 and omega_cdm.

    All other cosmological parameters come from the global ``cosmo_params``.
    A new dict is built on every call instead of mutating ``cosmo_params``,
    so that tracers passed in by jacfwd never leak into global state.

    Args:
        H0: Hubble constant in km/s/Mpc.
        omega_cdm: value for the CLASS ``omega_cdm`` parameter.

    Returns:
        The JAX array returned by ``classy_sz.get_hubble_at_z`` (one value per redshift).
    """
    params = {**cosmo_params, 'H0': H0, 'omega_cdm': omega_cdm}
    hz = classy_sz.get_hubble_at_z(z, params_values_dict=params)
    return hz
[55]:
Hubble(72,0.11833)
[55]:
Array([0.00040971, 0.00106581, 0.00195863, 0.00302666, 0.00424306,
0.00559103, 0.00705868, 0.00863741, 0.0103195 , 0.01209945], dtype=float64)
[56]:
dHubble = jacfwd(Hubble,argnums=(0,1))
[57]:
dHubble(76.,0.11933)
[57]:
(Array([2.02039159e-06, 7.88481298e-07, 4.29706169e-07, 2.77572060e-07,
1.97199735e-07, 1.47912808e-07, 1.17819987e-07, 9.30194044e-08,
7.54919311e-08, 5.92952867e-08], dtype=float64),
Array([0.00092981, 0.00355171, 0.00678657, 0.01058068, 0.01486751,
0.01960805, 0.0247732 , 0.03029529, 0.0361957 , 0.04240309], dtype=float64))