Initialize

[1]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from classy_sz import Class
import os
import time





import shutil

# Fiducial cosmology shared by the cells below.
cosmo_params = {
    'omega_b': 0.02242,
    'omega_cdm': 0.11933,
    'H0': 67.66,  # use H0 because this is what is used by the emulators.
    'tau_reio': 0.0561,
    'ln10^{10}A_s': 3.047,
    'n_s': 0.9665,
}

# Plot styling.
font = {'family': 'STIXGeneral'}
axislabelfontsize = 'large'
# Note: the font family set here is overridden by the "font.family" entry of
# the rcParams update below.
matplotlib.rc('font', **font)

# LaTeX text rendering needs a working LaTeX installation; without one,
# matplotlib fails when the figure is drawn. Enable it only when `latex` is
# found on PATH and otherwise fall back to matplotlib's built-in mathtext.
plt.rcParams.update({
    "text.usetex": shutil.which("latex") is not None,
    "font.family": "sans-serif",
    "font.sans-serif": ["Helvetica"],
})


Compute halo mass function

[13]:
%%time
M = Class()
M.set(cosmo_params)

# Halo-mass-function run: 'T08M200m' is presumably Tinker08 for M200m
# (confirm against the classy_sz docs), tabulated on 100 redshifts between
# z_min and z_max and over masses M_min..M_max.
params = {
    'output': 'dndlnM',
    'mass_function': 'T08M200m',  # mass function
    'ndim_redshifts': 100,        # integration precision settings
    'z_min': 0.,                  # redshift and mass bounds
    'z_max': 3.,
    'M_min': 1e10,
    'M_max': 1e15,
}
M.set(params)
M.initialize_classy_szfast()
CPU times: user 129 ms, sys: 7.23 ms, total: 136 ms
Wall time: 97.3 ms
/Users/boris/pyvenvs/py312-cmbagent/lib/python3.12/site-packages/mcfit/mcfit.py:130: UserWarning: use backend='jax' if desired
  warnings.warn("use backend='jax' if desired")
[14]:
# Sanity check: linear power spectrum at k = 1 for z = 0 and z = 1
# (k units as used by classy_sz's pk_lin — confirm).
for z in (0., 1.):
    print(M.pk_lin(1., z))
86.81763360030757
32.18256614421709
[15]:
# Log-spaced mass grid covering the bounds requested from classy_sz above.
m_arr = np.geomspace(params['M_min'], params['M_max'], num=500)

# Element-wise wrappers so the (presumably scalar-only) classy_sz getters
# accept arrays. np.vectorize is a convenience loop, not a speed-up.
dndlnm, sigma, nu, b1, b2 = map(np.vectorize, (
    M.get_dndlnM_at_z_and_M,
    M.get_sigma_at_z_and_m,
    M.get_nu_at_z_and_m,
    M.get_first_order_bias_at_z_and_nu,
    M.get_second_order_bias_at_z_and_nu,
))
[16]:

label_size = 20
title_size = 25
legend_size = 13
handle_length = 1.5

# (redshift, line style) pairs shared by every panel.
z_styles = [(0., '-'), (1., '--'), (2., '-.')]
mass_label = r'$m\quad[M_\mathrm{sun}/h]$'

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 10), sharex=True)

# Common styling: inward ticks on all four sides and a faint grid.
# The tick-label size goes through tick_params (rather than plt.setp on the
# current labels) so it also applies to ticks created after switching to log scale.
for ax in (ax1, ax2, ax3, ax4):
    ax.tick_params(axis='x', which='both', length=5, direction='in', pad=10, labelsize=label_size)
    ax.tick_params(axis='y', which='both', length=5, direction='in', pad=5, labelsize=label_size, labelrotation=0)
    ax.xaxis.set_ticks_position('both')
    ax.yaxis.set_ticks_position('both')
    ax.grid(visible=True, which='both', alpha=0.1, linestyle='--')

# Top left: halo mass function dn/dln m, in (h/Mpc)^3 = [Mpc/h]^-3.
for z, ls in z_styles:
    ax1.plot(m_arr, dndlnm(z, m_arr), label=r'$z=%.0f$' % z, alpha=1., c='k', ls=ls)
ax1.loglog()
ax1.set_ylim(1e-6, 2e-1)
ax1.set_ylabel(r'$\mathrm{d}n/\mathrm{dln} m\quad[\mathrm{Mpc}/h]^{-3}$', size=title_size, labelpad=8)

# Top right: sigma(m).
for z, ls in z_styles:
    ax2.plot(m_arr, sigma(z, m_arr), label=r'$z=%.0f$' % z, alpha=1., c='k', ls=ls)
ax2.set_xscale('log')
ax2.set_ylabel(r'$\sigma(m)$', size=title_size, labelpad=8)

# Bottom left: nu(m).
for z, ls in z_styles:
    ax3.plot(m_arr, nu(z, m_arr), label=r'$z=%.0f$' % z, alpha=1., c='k', ls=ls)
ax3.loglog()
ax3.set_ylabel(r'$\nu$', size=title_size, labelpad=8)

# Bottom right: b1 (default width) and |b2| (thin line); nu is evaluated once per z.
for z, ls in z_styles:
    nu_z = nu(z, m_arr)
    ax4.plot(m_arr, b1(z, nu_z), label=r'$z=%.0f$' % z, alpha=1., c='k', ls=ls)
    ax4.plot(m_arr, np.abs(b2(z, nu_z)), alpha=1., c='k', ls=ls, lw=0.5)
ax4.set_xscale('log')
ax4.set_yscale('log')
ax4.set_ylabel(r'$b^{(1)},\ |b^{(2)}|$', size=title_size, labelpad=8)

for ax in (ax3, ax4):
    ax.set_xlabel(mass_label, size=title_size)
for ax, loc in ((ax1, 1), (ax2, 1), (ax3, 2), (ax4, 2)):
    ax.set_xlim(params['M_min'], params['M_max'])
    ax.legend(loc=loc, frameon=False, framealpha=1, fontsize=20)

fig.tight_layout()
../_images/notebooks_classy_szfast_hmf_and_sigma_6_0.png

Compare Tinker08 and Tinker10

[9]:
%%time
MT10 = Class()
MT10.set(cosmo_params)

# Same mass/redshift bounds as the T08 run, but with the 'T10M200m' mass
# function and 80 redshift samples; the skip_* flags presumably switch off
# the CMB and non-linear P(k) computations (confirm in classy_sz docs).
params = {
    'output': 'dndlnM',
    'mass_function': 'T10M200m',  # mass function
    'ndim_redshifts': 80,         # integration precision settings
    'z_min': 0.,                  # redshift and mass bounds
    'z_max': 3.,
    'M_min': 1e10,
    'M_max': 1e15,
    'skip_cmb': 1,
    'skip_pknl': 1,
}
MT10.set(params)
MT10.initialize_classy_szfast()
CPU times: user 118 ms, sys: 5.57 ms, total: 124 ms
Wall time: 95 ms
/Users/boris/pyvenvs/py312-cmbagent/lib/python3.12/site-packages/mcfit/mcfit.py:130: UserWarning: use backend='jax' if desired
  warnings.warn("use backend='jax' if desired")
[10]:
# Same sanity check for the T10 instance; values should agree with the T08 run.
for z in (0., 1.):
    print(MT10.pk_lin(1., z))
86.81763360030749
32.18256626141982
[11]:
# Log-spaced mass grid for the T10 comparison.
m_arr = np.geomspace(params['M_min'], params['M_max'], num=500)

# Element-wise wrappers around the T10 instance's getters.
dndlnmT10, sigmaT10, nuT10, b1T10, b2T10 = map(np.vectorize, (
    MT10.get_dndlnM_at_z_and_M,
    MT10.get_sigma_at_z_and_m,
    MT10.get_nu_at_z_and_m,
    MT10.get_first_order_bias_at_z_and_nu,
    MT10.get_second_order_bias_at_z_and_nu,
))
[12]:

label_size = 20
title_size = 25
legend_size = 13
handle_length = 1.5

# (redshift, line style) pairs shared by every panel.
z_styles = [(0., '-'), (1., '--'), (2., '-.')]
mass_label = r'$m\quad[M_\mathrm{sun}/h]$'

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 10), sharex=True)

# Common styling: inward ticks on all four sides and a faint grid.
# The tick-label size goes through tick_params so it also applies to ticks
# created after switching to log scale.
for ax in (ax1, ax2, ax3, ax4):
    ax.tick_params(axis='x', which='both', length=5, direction='in', pad=10, labelsize=label_size)
    ax.tick_params(axis='y', which='both', length=5, direction='in', pad=5, labelsize=label_size, labelrotation=0)
    ax.xaxis.set_ticks_position('both')
    ax.yaxis.set_ticks_position('both')
    ax.grid(visible=True, which='both', alpha=0.1, linestyle='--')

# Top left: T08 (black) vs T10 (red) mass functions. Each T10 curve uses the
# same line style as the T08 curve at that redshift, so redshift is encoded by
# line style and the fit by colour.
for z, ls in z_styles:
    ax1.plot(m_arr, dndlnm(z, m_arr), label=r'$z=%.0f$' % z, alpha=1., c='k', ls=ls)
    ax1.plot(m_arr, dndlnmT10(z, m_arr), label=r'$\mathrm{T10}$ $z=%.0f$' % z, alpha=1., c='r', ls=ls)
ax1.loglog()
ax1.set_ylim(1e-6, 2e-1)
ax1.set_ylabel(r'$\mathrm{d}n/\mathrm{dln} m\quad[\mathrm{Mpc}/h]^{-3}$', size=title_size, labelpad=8)

# Remaining panels show the T08 instance's sigma, nu and biases.
for z, ls in z_styles:
    ax2.plot(m_arr, sigma(z, m_arr), label=r'$z=%.0f$' % z, alpha=1., c='k', ls=ls)
ax2.set_xscale('log')
ax2.set_ylabel(r'$\sigma(m)$', size=title_size, labelpad=8)

for z, ls in z_styles:
    ax3.plot(m_arr, nu(z, m_arr), label=r'$z=%.0f$' % z, alpha=1., c='k', ls=ls)
ax3.loglog()
ax3.set_ylabel(r'$\nu$', size=title_size, labelpad=8)

# b1 (default width) and |b2| (thin line); nu is evaluated once per z.
for z, ls in z_styles:
    nu_z = nu(z, m_arr)
    ax4.plot(m_arr, b1(z, nu_z), label=r'$z=%.0f$' % z, alpha=1., c='k', ls=ls)
    ax4.plot(m_arr, np.abs(b2(z, nu_z)), alpha=1., c='k', ls=ls, lw=0.5)
ax4.set_xscale('log')
ax4.set_yscale('log')
ax4.set_ylabel(r'$b^{(1)},\ |b^{(2)}|$', size=title_size, labelpad=8)

for ax in (ax3, ax4):
    ax.set_xlabel(mass_label, size=title_size)
for ax, loc in ((ax1, 1), (ax2, 1), (ax3, 2), (ax4, 2)):
    ax.set_xlim(params['M_min'], params['M_max'])
    ax.legend(loc=loc, frameon=False, framealpha=1, fontsize=20)

fig.tight_layout()
../_images/notebooks_classy_szfast_hmf_and_sigma_11_0.png

Compute with JAX

[24]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import os
# Force JAX onto the CPU backend. This environment variable presumably has to be
# set before jax is first imported in the kernel. NOTE(review): classy_sz was
# already imported in the first cell and may import jax itself; if so, this
# assignment comes too late to take effect — confirm, or move it to the very
# first cell.
os.environ["JAX_PLATFORM_NAME"] = "cpu"
from classy_sz import Class as Class_sz

import jax.numpy as jnp
import jax
import mcfit
# TophatVar provides the top-hat-filtered variance sigma^2(R) used by get_hmf_grid.
from mcfit import TophatVar
import jax.scipy as jscipy
# Record versions / device so timings below can be interpreted.
print("Jax environment info:")
jax.print_environment_info()
Jax environment info:
jax:    0.5.0
jaxlib: 0.5.0
numpy:  1.26.4
python: 3.12.8 (main, Dec  3 2024, 18:42:41) [Clang 16.0.0 (clang-1600.0.26.4)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Boriss-MacBook-Pro-3.local', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

[25]:
%%time
cosmo_params = {
    'omega_b': 0.02242,
    'omega_cdm': 0.11933,
    'H0': 67.66,  # use H0 because this is what is used by the emulators.
    'tau_reio': 0.0561,
    'ln10^{10}A_s': 3.047,
    'n_s': 0.9665,
}

# Initialize the JAX-enabled emulator instance used by the functions below.
classy_sz = Class_sz()
classy_sz.set(cosmo_params)
classy_sz.set({
    'output': 'mPk',
    'jax': 1,
    # Optional: 'ndim_redshifts': 50, 'skip_background_and_thermo': 0.
    # To compare with the full background solution, also set neutrinos:
    # 'N_ncdm': 1, 'm_ncdm': 0.02, 'deg_ncdm': 3.
    'M_min': 5e13,
    'M_max': 5e15,
})
# compute_class_szfast() is deprecated and emits a DeprecationWarning that
# points to initialize_classy_szfast(), which is what the other cells use.
classy_sz.initialize_classy_szfast()
/Users/boris/venvdir/class_sz_312_brew/lib/python3.12/site-packages/IPython/core/magics/execution.py:1355: DeprecationWarning: compute_class_szfast is deprecated. Use initialize_classy_szfast instead.
  out = eval(code_2, glob, local_ns)
CPU times: user 1.53 s, sys: 87.7 ms, total: 1.62 s
Wall time: 300 ms
[26]:
def MF_T08(sigmas, z, delta_mean):
    """Tinker et al. (2008) multiplicity function, scaled by 1/2.

    Returns ``0.5 * f(sigma)`` with
    ``f(sigma) = A * ((sigma / b)**(-a) + 1) * exp(-c / sigma**2)``.
    A, a, b and c are linearly interpolated in log10(Delta) from the tabulated
    values below and evolved with redshift as
    A ~ (1+z)**-0.14, a ~ (1+z)**-0.06, b ~ (1+z)**-alpha, where
    log10(alpha) = -(0.75 / log10(Delta / 75))**1.2.
    The 0.5 factor matches get_hmf_grid, which multiplies this by
    d ln nu / d ln R with nu = (delta_c / sigma)**2.

    Parameters
    ----------
    sigmas : array
        sigma values, broadcastable against shape (nz, 1);
        get_hmf_grid passes shape (nz, nR).
    z : float or array
        Redshift(s). The per-redshift parameters are indexed with
        ``[:, None]``, so z and/or delta_mean must be 1-D of length nz.
    delta_mean : float or array
        Overdensity with respect to the mean density (linear, not log).

    Returns
    -------
    array
        0.5 * f(sigma), with sigmas' shape broadcast against (nz, 1).
    """
    log_delta = jnp.log10(delta_mean)

    # Tabulated fit parameters as a function of Delta (mean-density overdensity).
    log_delta_nodes = jnp.log10(jnp.array([200, 300, 400, 600, 800, 1200, 1600, 2400, 3200]))
    A_nodes = jnp.array([0.186, 0.200, 0.212, 0.218, 0.248, 0.255, 0.260, 0.260, 0.260])
    a_nodes = jnp.array([1.47, 1.52, 1.56, 1.61, 1.87, 2.13, 2.30, 2.53, 2.66])
    b_nodes = jnp.array([2.57, 2.25, 2.05, 1.87, 1.59, 1.51, 1.46, 1.44, 1.41])
    c_nodes = jnp.array([1.19, 1.27, 1.34, 1.45, 1.58, 1.80, 1.97, 2.24, 2.44])

    def at_delta(nodes):
        # Linear interpolation in log10(Delta).
        return jnp.interp(log_delta, log_delta_nodes, nodes)

    one_plus_z = 1 + z
    delta_lin = jnp.power(10, log_delta)
    alpha = jnp.power(10, -jnp.power(0.75 / jnp.log10(delta_lin / 75), 1.2))

    amplitude = at_delta(A_nodes) * one_plus_z ** -0.14
    slope = at_delta(a_nodes) * one_plus_z ** -0.06
    scale = at_delta(b_nodes) * one_plus_z ** -alpha
    cutoff = at_delta(c_nodes)

    # Promote per-redshift parameters to column vectors to broadcast over mass.
    amplitude, slope, scale, cutoff = (p[:, None] for p in (amplitude, slope, scale, cutoff))

    return 0.5 * amplitude * (jnp.power(sigmas / scale, -slope) + 1) * jnp.exp(-cutoff / sigmas**2)
[27]:
%%time
def get_hmf_grid(delta = 500, delta_def = 'critical', params_values_dict = None):
    """Tabulate the Tinker08 halo mass function on classy_sz's redshift grid.

    Builds sigma^2(R) and its R-derivative from the emulated linear P(k) at
    every redshift of ``classy_sz.z_grid()`` using mcfit's TophatVar. It
    converts R to mass, evaluates MF_T08, and assembles dn/dln m.

    Parameters
    ----------
    delta : float
        Overdensity of the halo definition.
    delta_def : {'critical', 'mean'}
        Whether ``delta`` is relative to the critical or the mean density;
        'critical' is converted to a mean-density overdensity per redshift.
    params_values_dict : dict or None
        Cosmological parameters forwarded to the classy_sz emulator calls.

    Returns
    -------
    lnx_grid : array, shape (nz,)
        ln(1 + z) on classy_sz.z_grid().
    lnm_grid : array, shape (nR,)
        ln m in h-units.
    dndlnm_grid : array, shape (nz, nR)
        dn/dln m in units of h^3 Mpc^-3.

    Raises
    ------
    ValueError
        If ``delta_def`` is neither 'critical' nor 'mean'. Previously this
        only printed a message and then failed on an undefined ``delta_mean``.
    """
    if delta_def not in ('critical', 'mean'):
        raise ValueError("delta_def must be 'critical' or 'mean', got %r" % (delta_def,))

    rparams = classy_sz.get_all_relevant_params(params_values_dict = params_values_dict)

    # Wavenumbers of the emulated linear P(k); one call at any redshift gives them.
    _, ks = classy_sz.get_pkl_at_z(1., params_values_dict = params_values_dict)

    def linear_pk_at(zp):
        pks, _ = classy_sz.get_pkl_at_z(zp, params_values_dict = params_values_dict)
        return pks.flatten()

    z_grid = classy_sz.z_grid()
    # P has one column per redshift: shape (nk, nz).
    P = jax.vmap(linear_pk_at)(z_grid).T

    def tophat_var(pks):
        _, var_z = TophatVar(ks, lowring=True, backend='jax')(pks, extrap=True)
        return var_z

    def tophat_dvar(pks):
        # Same convention as cosmocnc:
        # TophatVar(self.k, lowring=True, deriv=1)(self.pk*self.k, extrap=True)
        _, dvar_z = TophatVar(ks, lowring=True, backend='jax', deriv=1)(pks*ks, extrap=True)
        return dvar_z

    # Map over the redshift columns of P -> arrays of shape (nz, nR).
    var = jax.vmap(tophat_var, in_axes=1)(P)
    dvar = jax.vmap(tophat_dvar, in_axes=1)(P)

    # The R grid depends only on ks, so take it from a single redshift column.
    R, _ = TophatVar(ks, lowring=True, backend='jax')(P[:, 0], extrap=True)
    R = R.flatten()
    lnx_grid = jnp.log(1+z_grid)

    # Mass enclosed in a top-hat of radius R*h (in h-units).
    Rh = R*rparams['h']
    lnm_grid = jnp.log(4*jnp.pi*rparams['Omega0_cb']*rparams['Rho_crit_0']*Rh**3/3.)

    if delta_def == 'critical':
        delta_mean = classy_sz.get_delta_mean_from_delta_crit_at_z(delta, z_grid, params_values_dict = params_values_dict)
    else:
        delta_mean = jnp.full_like(z_grid, delta)

    sigmas = jnp.sqrt(var)
    hmf = MF_T08(sigmas, z_grid, delta_mean)

    # Treating dvar as d sigma^2 / dR (cosmocnc convention above):
    # d ln sigma^2 / d ln R = R * dvar / sigma^2. With nu = (delta_c/sigma)^2,
    # d ln nu / d ln R is its negative. A constant, non-Omega_m-dependent
    # delta_c = (3/20)(12 pi)^(2/3) = 1.686... is what nu refers to here; it
    # cancels in the logarithmic derivative.
    dlnsigma2_dlnR = dvar*R/var
    dlnnu_dlnRh = -dlnsigma2_dlnR

    # dn/dln m = (1/3) * [3 / (4 pi Rh^3)] * d ln nu / d ln R * hmf,
    # where 1/3 = d ln R / d ln m.
    # Returned in units of h^3 Mpc^-3.
    dndlnm_grid = 1./3.*3./(4.*jnp.pi*Rh**3)*dlnnu_dlnRh*hmf
    return lnx_grid, lnm_grid, dndlnm_grid
CPU times: user 11 μs, sys: 6 μs, total: 17 μs
Wall time: 24.8 μs
[28]:
# Benchmark one evaluation of the full (z, m) HMF grid; get_hmf_grid is not wrapped in jax.jit here.
%timeit -n 40 -r 10 get_hmf_grid(params_values_dict = cosmo_params)
55.6 ms ± 1.83 ms per loop (mean ± std. dev. of 10 runs, 40 loops each)
[29]:
def get_hmf_at_z_and_m(z, m, params_values_dict = None, delta = 200, delta_def = 'mean'):
    """Interpolate dn/dln m (h^3 Mpc^-3) at redshift ``z`` and mass(es) ``m``.

    The grid from get_hmf_grid is interpolated in (ln(1+z), ln m) on
    log(dn/dln m), so the many decades spanned by the mass function are
    handled smoothly.

    Parameters
    ----------
    z : float
        Redshift; must lie within classy_sz.z_grid().
    m : float or array
        Mass(es) in h-units, within the tabulated mass range.
    params_values_dict : dict or None
        Cosmological parameters forwarded to get_hmf_grid.
    delta, delta_def : float, {'critical', 'mean'}
        Halo definition; defaults (200, 'mean') reproduce the previous
        hard-coded behaviour.

    Notes
    -----
    The full grid is recomputed on every call.
    """
    lnx, lnm, dndlnm = get_hmf_grid(delta = delta, delta_def = delta_def, params_values_dict = params_values_dict)
    hmf_interp = jscipy.interpolate.RegularGridInterpolator((lnx, lnm), jnp.log(dndlnm))
    lnxp = jnp.log(1.+z)
    lnmp = jnp.log(m)
    return jnp.exp(hmf_interp((lnxp, lnmp)))
[30]:
# 200 log-spaced halo masses between 1e10 and 1e15 for the JAX HMF evaluation.
m = jnp.geomspace(1e10, 1e15, num=200)

[31]:
# Plot the JAX HMF at a few redshifts. The result is stored under a new name:
# reusing `dndlnm` would overwrite the vectorised classy_sz getter defined
# earlier and break re-runs of the plotting cells above.
fig, ax = plt.subplots()
for z in [0, 1, 2]:
    dndlnm_jax = get_hmf_at_z_and_m(z, m, params_values_dict = cosmo_params)
    ax.plot(m, dndlnm_jax, label=r'$z=%d$' % z)
ax.loglog()
ax.grid(which='both', alpha=0.1)
ax.set_ylim(1e-6, 2e-1)
ax.set_xlim(1e10, 1e15)
ax.set_xlabel(r'$m\quad[M_\mathrm{sun}/h]$')
ax.set_ylabel(r'$\mathrm{d}n/\mathrm{dln} m\quad[\mathrm{Mpc}/h]^{-3}$')
ax.legend(frameon=False);
[31]:
(10000000000.0, 1000000000000000.0)
../_images/notebooks_classy_szfast_hmf_and_sigma_20_1.png

Gradient

[18]:
z = 0.1
Omega_c = 0.26
m = jnp.geomspace(1e10,1e15,200)

def dndlnm_allms(Omega_c):
    """dn/dln m at the global redshift ``z`` for every mass in the global ``m``.

    ``Omega_c`` is converted to omega_cdm = Omega_c * h^2 using the H0 in
    ``cosmo_params``. A fresh parameter dict is built instead of updating
    ``cosmo_params`` in place. Under jax.jacfwd, ``Omega_c`` is a tracer, and
    writing it into the shared global dict would leak the tracer into every
    later use of ``cosmo_params``.
    """
    omega_cdm = Omega_c*(cosmo_params['H0']/100.)**2
    params_values = {**cosmo_params, 'omega_cdm': omega_cdm}
    return get_hmf_at_z_and_m(z, m, params_values)
[19]:
from jax import jacfwd
[20]:
%%time
# Forward-mode derivative of dndlnm_allms with respect to its only argument, Omega_c.
ddndlnm_allms = jacfwd(dndlnm_allms, argnums=0)
CPU times: user 43 µs, sys: 3 µs, total: 46 µs
Wall time: 52.2 µs
[21]:
# Time one forward-mode derivative; each call re-runs get_hmf_grid through dndlnm_allms.
%timeit -n 10 -r 5 ddndlnm_allms(Omega_c)
152 ms ± 79.7 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)
[22]:
# |d(dn/dln m)/d Omega_c| at the redshift z and masses m defined above.
# The curve is a function of halo mass m (the old label said "(k)").
fig, ax = plt.subplots()
ax.loglog(m, jnp.abs(ddndlnm_allms(Omega_c)))
ax.grid(which='both', alpha=0.1)
ax.set_xlabel(r'$m\quad[M_\mathrm{sun}/h]$')
ax.set_ylabel(r'$|\partial (\mathrm{d}n/\mathrm{d}\ln m)(m)/\partial \Omega_c|$');
[22]:
Text(0, 0.5, '$|\\partial \\mathrm{dndlnm}(k)/\\partial \\Omega_c|$')
../_images/notebooks_classy_szfast_hmf_and_sigma_26_1.png