import jax
import equinox as eqx
import numpy as np
import jax.numpy as jnp
from xcquinox.utils import lda_x, pw92c_unpolarized
from functools import partial
[docs]
def generate_network_eval_xc(mf, dm, network):
'''
Generates a function to overwrite eval_xc with on the mf object, for use in training with pyscfad's SCF cycle
:param mf: Pyscfad calculation kernel object
:type mf: Pyscfad calculation kernel object
:param dm: Initial density matrix to use in the cycle
:type dm: jax.Array
:param network: The network to use in evaluating the SCF cycle
:type network: xcquinox.xc.eXC
:return: A function `eval_xc` that uses an xcquinox network as the pyscfad kernel calculation driver.
:rtype: function
The returned function:
eval_xc(xc_code, rho, ao, gw, coords, spin=0, relativity=0, deriv=1, omega=None, verbose=None)
The function to use as driver for a pyscf(ad) calculation, using an xcquinox network.
This overwrites mf.eval_xc with a custom function, evaluating:
Exc_exc, vs = jax.value_and_grad(EXC_exc_vs, has_aux=True)(jnp.concatenate([jnp.expand_dims(rho0_a,-1),
jnp.expand_dims(rho0_b,-1),
jnp.expand_dims(gamma_a,-1),
jnp.expand_dims(gamma_ab,-1),
jnp.expand_dims(gamma_b,-1),
jnp.expand_dims(jnp.zeros_like(rho0_a),-1), #Dummy for laplacian
jnp.expand_dims(jnp.zeros_like(rho0_a),-1), #Dummy for laplacian
jnp.expand_dims(tau_a,-1),
jnp.expand_dims(tau_b,-1),
jnp.expand_dims(non_loc_a,-1),
jnp.expand_dims(non_loc_b,-1)],axis=-1))
:param xc_code: The XC functional code string in libxc format, but it is ignored as the network is the calculation driver
:type xc_code: str
:param rho: The [..., *, N] arrays (... for spin polarized), N is the number of grid points.
rho (*,N) ordered as (rho, grad_x, grad_y, grad_z, laplacian, tau)
rho (2,*,N) is [(rho_up, grad_x_up, grad_y_up, grad_z_up, laplacian_up, tau_up),
(rho_down, grad_x_down, grad_y_down, grad_z_down, laplacian_down, tau_down)]
PySCFAD doesn't do spin-polarized grid calculations yet, so this will be unpolarized.
:type rho: jax.Array
:param ao: The atomic orbitals on the grid to use in the network calculation. Explcitly specified as the block loops break down the grid if memory is too low
:type ao: jax.Array
:param ao: The grid weights to use in the network calculation. Explcitly specified as the block loops break down the grid if memory is too low
:type ao: jax.Array
:param ao: The grid coordinates to use in the network calculation. Explcitly specified as the block loops break down the grid if memory is too low
:type ao: jax.Array
:param spin: The spin of the calculation, integer valued, polarized if non-zero, defaults to zero
:type spin: int
:param relativity: Integer, unused right now, defaults to zero
:type relativity: int
:param deriv: Unused here, defaults to 1
:type deriv: int
:param omega: Hybrid mixing term, unused here, defaults to None
:type omega: float
:param verbose: Unused here, defaults to None
:type verbose: int
:return: ex, vxc, fxc, kxc
where: ex -> exc, XC energy density on the grid
vxc -> (vrho, vsigma, vlapl, vtau), gradients of the exc w.r.t. the quantities given.
Only vrho and vtau are used, vsigma=vlapl=fxc=kxc=None.
vrho = vs[:, 0]+vs[:, 1]
vtau = vs[:, 7]+vs[:, 8]
:rtype: tuple
'''
def eval_xc(xc_code, rho, ao, gw, coords, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
'''
The function to use as driver for a pyscf(ad) calculation, using an xcquinox network.
This overwrites mf.eval_xc with a custom function, evaluating:
Exc_exc, vs = jax.value_and_grad(EXC_exc_vs, has_aux=True)(jnp.concatenate([jnp.expand_dims(rho0_a,-1),
jnp.expand_dims(rho0_b,-1),
jnp.expand_dims(gamma_a,-1),
jnp.expand_dims(gamma_ab,-1),
jnp.expand_dims(gamma_b,-1),
jnp.expand_dims(jnp.zeros_like(rho0_a),-1), #Dummy for laplacian
jnp.expand_dims(jnp.zeros_like(rho0_a),-1), #Dummy for laplacian
jnp.expand_dims(tau_a,-1),
jnp.expand_dims(tau_b,-1),
jnp.expand_dims(non_loc_a,-1),
jnp.expand_dims(non_loc_b,-1)],axis=-1))
:param xc_code: The XC functional code string in libxc format, but it is ignored as the network is the calculation driver
:type xc_code: str
:param rho: The [..., *, N] arrays (... for spin polarized), N is the number of grid points.
rho (*,N) ordered as (rho, grad_x, grad_y, grad_z, laplacian, tau)
rho (2,*,N) is [(rho_up, grad_x_up, grad_y_up, grad_z_up, laplacian_up, tau_up),
(rho_down, grad_x_down, grad_y_down, grad_z_down, laplacian_down, tau_down)]
PySCFAD doesn't do spin-polarized grid calculations yet, so this will be unpolarized.
:type rho: jax.Array
:param ao: The atomic orbitals on the grid to use in the network calculation. Explcitly specified as the block loops break down the grid if memory is too low
:type ao: jax.Array
:param ao: The grid weights to use in the network calculation. Explcitly specified as the block loops break down the grid if memory is too low
:type ao: jax.Array
:param ao: The grid coordinates to use in the network calculation. Explcitly specified as the block loops break down the grid if memory is too low
:type ao: jax.Array
:param spin: The spin of the calculation, integer valued, polarized if non-zero, defaults to zero
:type spin: int
:param relativity: Integer, unused right now, defaults to zero
:type relativity: int
:param deriv: Unused here, defaults to 1
:type deriv: int
:param omega: Hybrid mixing term, unused here, defaults to None
:type omega: float
:param verbose: Unused here, defaults to None
:type verbose: int
:return: ex, vxc, fxc, kxc
where: ex -> exc, XC energy density on the grid
vxc -> (vrho, vsigma, vlapl, vtau), gradients of the exc w.r.t. the quantities given.
Only vrho and vtau are used, vsigma=vlapl=fxc=kxc=None.
vrho = vs[:, 0]+vs[:, 1]
vtau = vs[:, 7]+vs[:, 8]
:rtype: tuple
'''
# print('custom eval_xc; input rho shape: ', rho.shape)
if len(rho.shape) == 2:
# not spin-polarized
rho0 = rho[0] # density
drho = rho[1:4] # grad_x, grad_y, grad_z
# laplacian next
# tau = 0.5*(rho[1] + rho[2] + rho[3])
tau = rho[-1] # tau
non_loc = jnp.zeros_like(tau)
# decompose into spin channels
rho0_a = rho0_b = rho0*0.5
gamma_a = gamma_b = gamma_ab = jnp.einsum('ij,ij->j', drho[:], drho[:])*0.25
tau_a = tau_b = tau*0.5
non_loc_a = non_loc_b = non_loc*0.5
if network.verbose:
print(
f'decomposed shapes:\nrho0={rho0.shape}\ndrho={drho.shape}\ntau={tau.shape}\nnon_loc={non_loc.shape}')
print(
f'decomposed shapes:\ngamma_a={gamma_a.shape}\ngamma_b={gamma_b.shape}\ngamma_ab={gamma_ab.shape}')
else:
# spin-polarized density
rho0_a = rho[0, 0]
rho0_b = rho[1, 0]
drho_a = rho[0, 1:4]
drho_b = rho[1, 1:4]
# jnp.einsumed density gradient
gamma_a, gamma_b = jnp.einsum('ij,ij->j', drho_a, drho_a), jnp.einsum('ij,ij->j', drho_b, drho_b)
gamma_ab = jnp.einsum('ij,ij->j', drho_a, drho_b)
# Kinetic energy density
tau_a = rho[0, -1]
tau_b = rho[1, -1]
non_loc_a, non_loc_b = jnp.zeros_like(tau_a), jnp.zeros_like(tau_b)
if network.verbose:
print(
f'decomposed shapes:\nrho0(a,b)={rho0_a.shape},{rho0_b.shape}\ndrho(a,b)={drho_a.shape},{drho_b.shape}\ntau(a,b)={tau_a.shape},{tau_b.shape}\nnon_loc(a,b)={non_loc_a.shape},{non_loc_b.shape}')
print(
f'decomposed shapes:\ngamma_a={gamma_a.shape}\ngamma_b={gamma_b.shape}\ngamma_ab={gamma_ab.shape}')
# xc-energy per unit particle
# print(f'EVALUATING GRID MODELS; OPTIONAL PARAMETERS:')
# try:
# print(f'gw.shape={gw.shape}, coor.shape={coor.shape}')
# except:
# print('no externally supplied gw or coor')
# print('eval_xc eval_grid_models call')
def EXC_exc_vs(x):
exc = network.eval_grid_models(x, mf=mf, dm=dm, ao=ao, gw=gw, coor=coords)
Exc = jnp.sum(((rho0_a + rho0_b)*exc[:, 0])*gw)
return Exc, exc
if network.verbose:
print(f'eval_xc -> Exc_exc and potentials on grid via autodiff')
v_and_g_inp = jnp.concatenate([jnp.expand_dims(rho0_a, -1),
jnp.expand_dims(rho0_b, -1),
jnp.expand_dims(gamma_a, -1),
jnp.expand_dims(gamma_ab, -1),
jnp.expand_dims(gamma_b, -1),
jnp.expand_dims(jnp.zeros_like(rho0_a), -1), # Dummy for laplacian
jnp.expand_dims(jnp.zeros_like(rho0_a), -1), # Dummy for laplacian
jnp.expand_dims(tau_a, -1),
jnp.expand_dims(tau_b, -1),
jnp.expand_dims(non_loc_a, -1),
jnp.expand_dims(non_loc_b, -1)], axis=-1)
print(f'v_and_g_inp.shape={v_and_g_inp.shape}')
Exc_exc, vs = jax.value_and_grad(EXC_exc_vs, has_aux=True)(v_and_g_inp)
print(f'Exc_exc and vs returned: Exc = {Exc_exc[0]}, exc.shape={Exc_exc[1].shape}, vs.shape={vs.shape}')
Exc, exc = Exc_exc
print(f'eval_xc Exc = {Exc}')
if jnp.sum(jnp.isnan(exc[:, 0])):
print('NaNs detected in exc. Number of NaNs: {}'.format(jnp.sum(jnp.isnan(exc[:, 0]))))
raise
else:
exc = exc[:, 0]
# print('ao shape: ', ao.shape)
# print('exc from network evaluation on grid models shape: ', exc.shape)
# print('vs from network evaluation on grid models shape: ', vs.shape)
# print('Exc from network evaluation on grid models shape: ', Exc)
def vgf(x): return network(x, ao, gw, mf=mf, coor=coords)
mf.converged = True
mf.network = network
mf.network_eval = vgf
# vrho; d Exc/d rho, separate spin channels
vrho = vs[:, 0]+vs[:, 1]
# vtau; d Exc/d tau, separate spin channels
vtau = vs[:, 7]+vs[:, 8]
vgamma = jnp.zeros_like(vrho)
vlapl = None
fxc = None # second order functional derivative
kxc = None # third order functional derivative
if network.verbose:
print(f'shapes: vrho={vrho.shape}, vgamma={vgamma.shape}')
return exc, (vrho, vgamma, vlapl, vtau), fxc, kxc
return eval_xc
# updated versions of this
# GGA
def custom_pbe_Fx(rho, sigma, XNET=None):
# this will be a call to the Fx neural network we want
# print('DEBUG custom_pbe_Fx, rho/sigma shapes: ', rho.shape, sigma.shape)
# print('DEBUG custom_pbe_Fx: rho: ', rho)
# print('DEBUG custom_pbe_Fx: sigma: ', sigma)
Fx = XNET([rho, sigma])
return Fx
def custom_pbe_Fc(rho, sigma, CNET=None): # Assumes zeta = 0
# this will be a call to the Fc neural network we want
Fc = CNET([rho, sigma])
return Fc
def custom_pbe_e(rho, sigma, XNET=None, CNET=None):
Fx = custom_pbe_Fx(rho, sigma, XNET=XNET)
Fc = custom_pbe_Fc(rho, sigma, CNET=CNET)
exc = lda_x(rho)*Fx + pw92c_unpolarized(rho)*Fc
return exc
def custom_pbe_epsilon(rho, sigma, XNET=None, CNET=None):
return rho*custom_pbe_e(rho, sigma, XNET=XNET, CNET=CNET)
def derivable_custom_pbe_e(rhosigma, XNET=None, CNET=None):
rho, sigma = rhosigma
# print('DEBUG derivable_custom_pbe_e: rhosigma len/shapes: ', len(rhosigma), rhosigma)
# print('DEBUG derivable_custom_pbe_e: rho/sigma shapes: ', rho.shape, sigma.shape)
# print('DEBUG derivable_custom_pbe_e: rho: ', rho)
# print('DEBUG derivable_custom_pbe_e: sigma: ', sigma)
return custom_pbe_e(rho, sigma, XNET=XNET, CNET=CNET)
def derivable_custom_pbe_epsilon(rhosigma, XNET=None, CNET=None):
rho = rhosigma[0]
sigma = rhosigma[1]
result = custom_pbe_epsilon(rho, sigma, XNET=XNET, CNET=CNET)
return result[0]
def eval_xc_gga_j(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None,
XNET=None, CNET=None):
# we only expect there to be a rho0 array, but I unpack it as (rho, deriv) here to be in line with the
# pyscf example -- the size of the 'rho' array depends on the xc type (LDA, GGA, etc.)
# so since LDA calculation, check for size first.
rho0, dx, dy, dz = rho[:4]
rho0 = jnp.array(rho0)
sigma = jnp.array(dx**2+dy**2+dz**2)
# print('DEBUG eval_xc_gga_j: rho0/sigma shapes: ', rho0.shape, sigma.shape)
rhosig = (rho0, sigma)
# calculate the "custom" energy with rho -- THIS IS e
# cast back to np.array since that's what pyscf works with
# pass as tuple -- (rho, sigma)
derivable_net_e = partial(derivable_custom_pbe_e, XNET=XNET, CNET=CNET)
derivable_net_epsilon = partial(derivable_custom_pbe_epsilon, XNET=XNET, CNET=CNET)
exc = np.array(jax.vmap(derivable_net_e)(rhosig))
# first order derivatives w.r.t. rho and sigma
vrho_f = eqx.filter_grad(derivable_net_epsilon)
vrhosigma = np.array(jax.vmap(vrho_f)(rhosig))
# print('vrhosigma shape:', vrhosigma.shape)
vxc = (vrhosigma[0], vrhosigma[1], None, None)
# v2_f = eqx.filter_hessian(derivable_custom_pbe_epsilon)
v2_f = jax.hessian(derivable_net_epsilon)
# v2_f = jax.hessian(custom_pbe_epsilon, argnums=[0, 1])
v2 = np.array(jax.vmap(v2_f)(rhosig))
# print('v2 shape', v2.shape)
v2rho2 = v2[0][0]
v2rhosigma = v2[0][1]
v2sigma2 = v2[1][1]
v2lapl2 = None
vtau2 = None
v2rholapl = None
v2rhotau = None
v2lapltau = None
v2sigmalapl = None
v2sigmatau = None
# 2nd order functional derivative
fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
# 3rd order
kxc = None
return exc, vxc, fxc, kxc
def eval_xc_gga_j2(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None,
xcmodel=None):
# we only expect there to be a rho0 array, but I unpack it as (rho, deriv) here to be in line with the
# pyscf example -- the size of the 'rho' array depends on the xc type (LDA, GGA, etc.)
# so since LDA calculation, check for size first.
try:
rho0, dx, dy, dz = rho[:4]
sigma = jnp.array(dx**2+dy**2+dz**2)
except:
rho0, drho = rho[:4]
sigma = jnp.array(drho**2)
rho0 = jnp.array(rho0)
# sigma = jnp.array(dx**2+dy**2+dz**2)
# print('DEBUG eval_xc_gga_j: rho0/sigma shapes: ', rho0.shape, sigma.shape)
# rhosig = (rho0, sigma)
rhosig = jnp.stack([rho0, sigma], axis=1)
print(rhosig.shape)
# calculate the "custom" energy with rho -- THIS IS e
# cast back to np.array since that's what pyscf works with
# pass as tuple -- (rho, sigma)
exc = jax.vmap(xcmodel)(rhosig)
exc = jnp.array(exc)/rho0
# exc = jnp.array(jax.vmap(xcmodel)( rhosig ) )/rho0
# print('exc shape = {}'.format(exc.shape))
# first order derivatives w.r.t. rho and sigma
vrho_f = eqx.filter_grad(xcmodel)
vrhosigma = jnp.array(jax.vmap(vrho_f)(rhosig))
# print('vrhosigma shape:', vrhosigma.shape)
vxc = (vrhosigma[:, 0], vrhosigma[:, 1], None, None)
# v2_f = eqx.filter_hessian(derivable_custom_pbe_epsilon)
v2_f = jax.hessian(xcmodel)
# v2_f = jax.hessian(custom_pbe_epsilon, argnums=[0, 1])
v2 = jnp.array(jax.vmap(v2_f)(rhosig))
print('v2 shape', v2.shape)
v2rho2 = v2[:, 0, 0]
v2rhosigma = v2[:, 0, 1]
v2sigma2 = v2[:, 1, 1]
v2lapl2 = None
vtau2 = None
v2rholapl = None
v2rhotau = None
v2lapltau = None
v2sigmalapl = None
v2sigmatau = None
# 2nd order functional derivative
fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
# 3rd order
kxc = None
return exc, vxc, fxc, kxc
def eval_xc_gga_pol(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None,
xcmodel=None):
# we only expect there to be a rho0 array, but I unpack it as (rho, deriv) here to be in line with the
# pyscf example -- the size of the 'rho' array depends on the xc type (LDA, GGA, etc.)
# so since LDA calculation, check for size first.
try:
rhoshape = len(rho.shape)
pol = 3
except:
rhoshape = len(rho)
pol = 2
# if len of shape == 3, spin polarized so compress to unpolarized for calculation
if rhoshape != pol:
# SPIN-UNPOLARIZED, ALL ARRAYS PASSED AS IS TO LIBXC
try:
# print("unpacking rho[:4] into rho0, dx, dy, dz")
rho0, dx, dy, dz = rho[:4]
sigma = jnp.array(dx**2+dy**2+dz**2)
except:
print("Unpacking failed...")
rho0, drho = rho[:4]
sigma = jnp.array(drho**2)
rho0 = jnp.array(rho0)
rhosig = jnp.stack([rho0, sigma], axis=1)
# print('rho/sig/rhosig shapes: ', rho0.shape, sigma.shape, rhosig.shape)
# calculate the "custom" energy with rho -- THIS IS e
# cast back to np.array since that's what pyscf works with
# pass as tuple -- (rho, sigma)
exc = jax.vmap(xcmodel)(rhosig)
exc = jnp.array(exc)/rho0
vrho_f = eqx.filter_grad(xcmodel)
vrhosigma = jnp.array(jax.vmap(vrho_f)(rhosig))
# vxc = vrho and vsigma, unpolarized, followed by nothing higher order in GGA
vxc = (vrhosigma[:, 0], vrhosigma[:, 1], None, None)
v2_f = jax.hessian(xcmodel)
v2 = jnp.array(jax.vmap(v2_f)(rhosig))
# print('v2 shape', v2.shape)
v2rho2 = v2[:, 0, 0]
v2rhosigma = v2[:, 0, 1]
v2sigma2 = v2[:, 1, 1]
v2lapl2 = None
vtau2 = None
v2rholapl = None
v2rhotau = None
v2lapltau = None
v2sigmalapl = None
v2sigmatau = None
# 2nd order functional derivative
fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
# 3rd order
kxc = None
else:
# SPIN POLARIZED; RESULT ARRAYS MUST BE RETURNED SPIN POLARIZED
# THIS IS HACKY -- THE NETWORK IS NOT ARCHITECTED TO ACCEPT ALL THE POLARIZED PARAMETERS, SO THE GRADIENTS ARE JUST DUPLICATED IN THE RETURN;
# GENERATE A FUNCTION THAT COMBINES THEN CALLS
def make_epsilon_function(model):
# importantly, do not place the vmap here
def get_epsilon(arr):
rhou, rhod, sigma1, sigma2, sigma3 = arr
rho0 = jnp.array(rhou+rhod)
# sum the sigma contributions
sumsigma = sigma1+sigma2+sigma3
rhosig = jnp.stack([rho0, sumsigma])
# calculate the "custom" energy with rho -- THIS IS e
# cast back to np.array since that's what pyscf works with
# pass as tuple -- (rho, sigma)
exc = model(rhosig)
return exc
return get_epsilon
# model_epsilon = partial(get_epsilon, model=xcmodel)
model_epsilon = make_epsilon_function(model=xcmodel)
rho_u, rho_d = rho
# print('rho_u, rho_d shapes:', rho_u.shape, rho_d.shape)
rho0u, dxu, dyu, dzu = rho_u[:4]
rho0d, dxd, dyd, dzd = rho_d[:4]
# up-up
dxu2 = dxu*dxu
dyu2 = dyu*dyu
dzu2 = dzu*dzu
# up-down
dxud = dxu*dxd
dyud = dyu*dyd
dzud = dzu*dzd
# down-down
dxd2 = dxd*dxd
dyd2 = dyd*dyd
dzd2 = dzd*dzd
sigma1 = dxu2+dyu2+dzu2
sigma2 = dxud+dyud+dzud
sigma3 = dxd2+dyd2+dzd2
rho0 = jnp.array(rho0u+rho0d)
# print('rho0 shape', rho0.shape)
# print('sigma1/2/3 shapes', sigma1.shape, sigma2.shape, sigma3.shape)
sumsigma = sigma1+sigma2+sigma3
# print('sumsigma shape', sumsigma.shape)
# sum the sigma contributions
rhosig = jnp.stack([rho0, sigma1+sigma2+sigma3], axis=1)
# calculate the "custom" energy with rho -- THIS IS e
# cast back to np.array since that's what pyscf works with
# pass as tuple -- (rho, sigma)
# epsilon here
input_arr = jnp.stack([rho0u, rho0d, sigma1, sigma2, sigma3], axis=1)
exc = jax.vmap(model_epsilon)(input_arr)
# print('epsilon shape', exc.shape)
# e here
exc = jnp.array(exc)/rho0
# exc = exc[jnp.newaxis, :]
# print('exc shape', exc.shape)
v1_f = jax.grad(model_epsilon)
v1 = jax.vmap(v1_f)(input_arr)
# vrho = vrho_up, vrho_down
vrho = jnp.vstack((v1[:, 0], v1[:, 1]))
# vsigma = vsigma1, vsigma2, vsigma3
vsigma = jnp.vstack((v1[:, 2], v1[:, 3], v1[:, 4]))
vxc = (vrho, vsigma)
# print('vrho shape', vrho.shape)
# print('vsigma shape', vsigma.shape)
v2_f = jax.hessian(model_epsilon)
v2 = jax.vmap(v2_f)(input_arr)
# print('v2 shape', v2.shape)
# v2rho2 = (v2rhou2, v2rhoud, v2rhod2)
v2rho2 = jnp.vstack((v2[:, 0, 0], v2[:, 0, 1], v2[:, 1, 1]))
# v2rhosigma is six-part = (u,1),(u,2),(u,3),(d,1),(d,2),(d,3)
v2rhosigma = jnp.vstack((v2[:, 0, 2], v2[:, 0, 3], v2[:, 0, 4], v2[:, 1, 2], v2[:, 1, 3], v2[:, 1, 4]))
# v2sigma2 is also six-part
v2sigma2 = jnp.vstack((v2[:, 2, 2], v2[:, 2, 3], v2[:, 2, 4], v2[:, 3, 3], v2[:, 3, 4], v2[:, 4, 4]))
# print('v2rho2 shape', v2rho2.shape)
# print('v2rhosigma shape', v2rhosigma.shape)
# print('v2sigma2 shape', v2sigma2.shape)
v2lapl2 = None
vtau2 = None
v2rholapl = None
v2rhotau = None
v2lapltau = None
v2sigmalapl = None
v2sigmatau = None
# 2nd order functional derivative
fxc = (v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau, v2lapltau, v2sigmalapl, v2sigmatau)
# 3rd order
kxc = None
TRANSPOSE = True
if TRANSPOSE:
vxc = [i.T for i in vxc]
fxc = [i.T for i in fxc if type(i) == type(jnp.array([1]))]
return exc, vxc, fxc, kxc