8. PySCF(AD) Interface
PLEASE NOTE: most classes and functions in this module are deprecated and will soon be removed or updated.
8.1. The generate_network_eval_xc Function
- xcquinox.pyscf.generate_network_eval_xc(mf, dm, network)
Generates a function with which to overwrite eval_xc on the mf object, for use in training with pyscfad’s SCF cycle.
- Parameters:
mf (Pyscfad calculation kernel object) – Pyscfad calculation kernel object
dm (jax.Array) – Initial density matrix to use in the cycle
network (xcquinox.xc.eXC) – The network to use in evaluating the SCF cycle
- Returns:
A function eval_xc that uses an xcquinox network as the pyscfad kernel calculation driver.
- Return type:
function
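The factory pattern described above can be sketched as a closure that captures the network and returns a function with the eval_xc signature documented below. This is a minimal sketch, not xcquinox's implementation: `make_eval_xc` and `toy_network` are hypothetical names, `toy_network` stands in for an `xcquinox.xc.eXC` network, and a `SimpleNamespace` stands in for the pyscfad kernel object `mf`.

```python
from types import SimpleNamespace

import jax.numpy as jnp


def make_eval_xc(network):
    """Return an eval_xc closure that evaluates `network` (hypothetical sketch)."""

    def eval_xc(xc_code, rho, ao, gw, coords, spin=0, relativity=0,
                deriv=1, omega=None, verbose=None):
        # xc_code is accepted only for signature compatibility and is ignored:
        # the captured network, not libxc, drives the calculation.
        exc = network(rho)
        # This sketch returns no derivatives; see the return description below.
        return exc, None, None, None

    return eval_xc


def toy_network(rho):
    # Hypothetical stand-in for a network: an LDA-exchange-like energy
    # density per particle, -c * rho**(1/3), computed from rho[0] (the density row).
    return -0.7386 * jnp.cbrt(rho[0])


mf = SimpleNamespace()  # stand-in for a pyscfad kernel object
mf.eval_xc = make_eval_xc(toy_network)
```

Because the closure captures `network`, each call to the factory yields an eval_xc bound to one specific network, which is what allows it to be assigned onto `mf` without changing the calling convention the SCF driver expects.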
The returned function:

eval_xc(xc_code, rho, ao, gw, coords, spin=0, relativity=0, deriv=1, omega=None, verbose=None)
The function to use as the driver for a pyscf(ad) calculation, using an xcquinox network.
This overwrites mf.eval_xc with a custom function, evaluating:
    Exc_exc, vs = jax.value_and_grad(EXC_exc_vs, has_aux=True)(
        jnp.concatenate([
            jnp.expand_dims(rho0_a, -1),
            jnp.expand_dims(rho0_b, -1),
            jnp.expand_dims(gamma_a, -1),
            jnp.expand_dims(gamma_ab, -1),
            jnp.expand_dims(gamma_b, -1),
            jnp.expand_dims(jnp.zeros_like(rho0_a), -1),  # Dummy for laplacian
            jnp.expand_dims(jnp.zeros_like(rho0_a), -1),  # Dummy for laplacian
            jnp.expand_dims(tau_a, -1),
            jnp.expand_dims(tau_b, -1),
            jnp.expand_dims(non_loc_a, -1),
            jnp.expand_dims(non_loc_b, -1),
        ], axis=-1)
    )
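The call above relies on `jax.value_and_grad(..., has_aux=True)`: the differentiated function must return a pair `(scalar, aux)`, and the result unpacks as `((value, aux), grad)`, so `Exc_exc` is itself a tuple and `vs` holds one derivative column per feature. The sketch below illustrates this on an 11-column feature array laid out as above. The energy function `exc_exc_vs` is a hypothetical toy stand-in for the network (it uses only the two density columns), and passing the grid weights as a second argument is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp

# Column order: rho_a, rho_b, gamma_a, gamma_ab, gamma_b,
# lapl_a, lapl_b, tau_a, tau_b, nl_a, nl_b
N_FEATURES = 11


def exc_exc_vs(features, weights):
    # Hypothetical toy functional standing in for the network:
    # LDA-exchange-like energy density from the total density rho_a + rho_b.
    rho = features[:, 0] + features[:, 1]
    exc = -0.7386 * jnp.cbrt(rho)             # energy per particle on each grid point
    exc_total = jnp.sum(weights * rho * exc)  # integrated XC energy (scalar)
    return exc_total, exc                     # (value to differentiate, auxiliary output)


n_grid = 5
features = jnp.concatenate(
    [jnp.full((n_grid, 1), 0.5)] * 2 + [jnp.zeros((n_grid, 1))] * (N_FEATURES - 2),
    axis=-1,
)
weights = jnp.full((n_grid,), 0.1)

# Differentiates with respect to the first argument (features) only.
(exc_total, exc), vs = jax.value_and_grad(exc_exc_vs, has_aux=True)(features, weights)
```

Here `vs` has shape (5, 11); columns 2 through 10 are zero only because the toy function ignores those features.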
- param xc_code:
The XC functional code string in libxc format. It is ignored, because the network is the calculation driver.
- type xc_code:
str
- param rho:
The […, *, N] arrays (the leading … axis is present only when spin polarized), where N is the number of grid points. rho (*, N) is ordered as (rho, grad_x, grad_y, grad_z, laplacian, tau); rho (2, *, N) is [(rho_up, grad_x_up, grad_y_up, grad_z_up, laplacian_up, tau_up), (rho_down, grad_x_down, grad_y_down, grad_z_down, laplacian_down, tau_down)].
PySCFAD does not yet support spin-polarized grid calculations, so rho will be unpolarized.
- type rho:
jax.Array
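For an unpolarized rho of shape (6, N) ordered as described, the spin-channel inputs can be derived by splitting the density equally between the two spins. The sketch below assumes that equal-split convention (rho_a = rho_b = rho/2, gamma_a = gamma_ab = gamma_b = |∇rho|²/4, tau_a = tau_b = tau/2); the helper `split_unpolarized` is hypothetical and not part of xcquinox.

```python
import jax.numpy as jnp


def split_unpolarized(rho):
    """Split an unpolarized (6, N) rho array into equal spin-channel quantities.

    Rows of rho: (rho, grad_x, grad_y, grad_z, laplacian, tau).
    """
    rho0 = rho[0]
    grad = rho[1:4]                       # (3, N) density gradient
    gamma = jnp.sum(grad * grad, axis=0)  # |grad rho|^2 on each grid point
    tau = rho[5]
    rho0_a = rho0_b = 0.5 * rho0
    # grad rho_a = grad rho_b = grad rho / 2, so every contracted gradient is gamma / 4.
    gamma_a = gamma_ab = gamma_b = 0.25 * gamma
    tau_a = tau_b = 0.5 * tau
    return rho0_a, rho0_b, gamma_a, gamma_ab, gamma_b, tau_a, tau_b
```

The laplacian row is not used here, matching the zero-valued laplacian placeholders in the feature concatenation.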
- param ao:
The atomic orbitals on the grid to use in the network calculation. Explicitly specified because the block loops break the grid into chunks when memory is low.
- type ao:
jax.Array
- param gw:
The grid weights to use in the network calculation. Explicitly specified because the block loops break the grid into chunks when memory is low.
- type gw:
jax.Array
- param coords:
The grid coordinates to use in the network calculation. Explicitly specified because the block loops break the grid into chunks when memory is low.
- type coords:
jax.Array
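Because the grid may be processed in blocks when memory is limited, a single call may see only one block's ao, gw, and coords. The sketch below illustrates why that is consistent for integrated quantities: evaluating rho(r) = Σ_μν φ_μ(r) D_μν φ_ν(r) block by block and summing the weighted integrals matches a single pass over the whole grid. The arrays are random placeholders, and `block_size`, `density_on_grid`, and `integrate_blockwise` are illustrative names, not pyscfad settings or functions.

```python
import jax
import jax.numpy as jnp


def density_on_grid(ao, dm):
    # ao: (n_grid, n_ao) orbital values; dm: (n_ao, n_ao) density matrix.
    # rho(r) = sum_{mu,nu} ao[r, mu] * dm[mu, nu] * ao[r, nu]
    return jnp.einsum("gm,mn,gn->g", ao, dm, ao)


def integrate_blockwise(ao, gw, dm, block_size):
    # Accumulate the weighted integral of rho one grid block at a time.
    total = 0.0
    for start in range(0, ao.shape[0], block_size):
        stop = start + block_size
        rho_block = density_on_grid(ao[start:stop], dm)
        total = total + jnp.sum(gw[start:stop] * rho_block)
    return total


key_ao, key_dm, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
n_grid, n_ao = 100, 4
ao = jax.random.normal(key_ao, (n_grid, n_ao))
c = jax.random.normal(key_dm, (n_ao, n_ao))
dm = c @ c.T                                   # symmetric placeholder density matrix
gw = jax.random.uniform(key_w, (n_grid,))      # placeholder grid weights

full = jnp.sum(gw * density_on_grid(ao, dm))
blocked = integrate_blockwise(ao, gw, dm, block_size=32)
```

The last block is shorter than `block_size`; slicing past the end of the array simply returns the remaining rows.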
- param spin:
The spin of the calculation, an integer; the calculation is spin polarized if non-zero. Defaults to zero.
- type spin:
int
- param relativity:
Unused right now, defaults to zero
- type relativity:
int
- param deriv:
Unused here, defaults to 1
- type deriv:
int
- param omega:
Range-separation parameter for range-separated hybrid functionals, unused here, defaults to None
- type omega:
float
- param verbose:
Unused here, defaults to None
- type verbose:
int
- return:
ex, vxc, fxc, kxc, where:
ex -> exc, the XC energy density on the grid
vxc -> (vrho, vsigma, vlapl, vtau), the gradients of exc with respect to the given quantities. Only vrho and vtau are populated; vsigma = vlapl = fxc = kxc = None, with
vrho = vs[:, 0] + vs[:, 1]
vtau = vs[:, 7] + vs[:, 8]
- rtype:
tuple
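The return convention can be sketched as follows, assuming `vs` is the (N, 11) gradient array whose columns follow the feature concatenation (rho_a, rho_b, gamma_a, gamma_ab, gamma_b, lapl_a, lapl_b, tau_a, tau_b, nl_a, nl_b), so the spin-up and spin-down derivatives are summed for vrho and vtau. The function name `pack_eval_xc_output` is hypothetical.

```python
import jax.numpy as jnp


def pack_eval_xc_output(exc, vs):
    """Assemble (exc, vxc, fxc, kxc) from an energy density and an (N, 11) gradient."""
    vrho = vs[:, 0] + vs[:, 1]      # d/d rho_a + d/d rho_b
    vtau = vs[:, 7] + vs[:, 8]      # d/d tau_a + d/d tau_b
    vxc = (vrho, None, None, vtau)  # vsigma and vlapl are not returned
    return exc, vxc, None, None     # fxc and kxc are not computed
```

For example, with `vs = jnp.arange(22.0).reshape(2, 11)`, vrho is `[1, 23]` and vtau is `[15, 37]`.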