8. PySCF(AD) Interface

PLEASE NOTE. Most classes and functions in this module are deprecated and will be removed or updated soon.

8.1. The generate_network_eval_xc Function

xcquinox.pyscf.generate_network_eval_xc(mf, dm, network)

Generates a function that overwrites eval_xc on the mf object, for use in training with pyscfad’s SCF cycle

Parameters:
  • mf (pyscfad calculation kernel object) – The pyscfad calculation kernel object whose eval_xc is overwritten

  • dm (jax.Array) – Initial density matrix to use in the cycle

  • network (xcquinox.xc.eXC) – The network to use in evaluating the SCF cycle

Returns:

A function eval_xc that uses an xcquinox network as the pyscfad kernel calculation driver.

Return type:

function

The returned function:

eval_xc(xc_code, rho, ao, gw, coords, spin=0, relativity=0, deriv=1, omega=None, verbose=None)

The function to use as the driver for a pyscf(ad) calculation, using an xcquinox network.

This overwrites mf.eval_xc with a custom function, evaluating:

    Exc_exc, vs = jax.value_and_grad(EXC_exc_vs, has_aux=True)(
        jnp.concatenate([jnp.expand_dims(rho0_a, -1),
                         jnp.expand_dims(rho0_b, -1),
                         jnp.expand_dims(gamma_a, -1),
                         jnp.expand_dims(gamma_ab, -1),
                         jnp.expand_dims(gamma_b, -1),
                         jnp.expand_dims(jnp.zeros_like(rho0_a), -1),  # Dummy for laplacian
                         jnp.expand_dims(jnp.zeros_like(rho0_a), -1),  # Dummy for laplacian
                         jnp.expand_dims(tau_a, -1),
                         jnp.expand_dims(tau_b, -1),
                         jnp.expand_dims(non_loc_a, -1),
                         jnp.expand_dims(non_loc_b, -1)], axis=-1))
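The call pattern above can be illustrated with a minimal, self-contained sketch. The function toy_EXC_exc below is a hypothetical stand-in for the network-backed EXC_exc_vs (an LDA-exchange-like toy, not xcquinox code), and the weights argument is an assumption made only so the toy energy is a scalar. It shows how jax.value_and_grad with has_aux=True returns the pair (energy, energy density) together with the gradient with respect to the 11-column input.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the network-backed EXC_exc_vs; not xcquinox code.
def toy_EXC_exc(features, weights):
    """Return (total energy, per-point energy density) for an LDA-like toy model."""
    rho = features[:, 0] + features[:, 1]        # total density from columns 0 and 1
    c_x = -0.75 * (3.0 / jnp.pi) ** (1.0 / 3.0)  # LDA exchange prefactor
    exc = c_x * rho ** (1.0 / 3.0)               # energy density per particle
    Exc = jnp.sum(weights * rho * exc)           # scalar, so it can be differentiated
    return Exc, exc                              # exc is returned as auxiliary output


n_grid = 4
rho0_a = jnp.linspace(0.1, 0.4, n_grid)
rho0_b = rho0_a
zeros = jnp.zeros_like(rho0_a)

# Same 11-column layout as above; only the two density columns are non-zero here.
columns = [rho0_a, rho0_b] + [zeros] * 9
features = jnp.stack(columns, axis=-1)           # shape (n_grid, 11)
weights = jnp.ones(n_grid)                       # placeholder grid weights

# With has_aux=True the result is ((value, aux), grad); grad is taken w.r.t. features.
(Exc, exc), vs = jax.value_and_grad(toy_EXC_exc, has_aux=True)(features, weights)
```

Here vs has the same shape as the input, one gradient column per input feature, which is why the return value can later be assembled by indexing columns of vs.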

param xc_code:

The XC functional code string in libxc format, but it is ignored as the network is the calculation driver

type xc_code:

str

param rho:

The […, *, N] arrays (… for spin polarized), where N is the number of grid points.

rho (*, N) is ordered as (rho, grad_x, grad_y, grad_z, laplacian, tau)

rho (2, *, N) is [(rho_up, grad_x_up, grad_y_up, grad_z_up, laplacian_up, tau_up), (rho_down, grad_x_down, grad_y_down, grad_z_down, laplacian_down, tau_down)]

PySCFAD doesn’t do spin-polarized grid calculations yet, so this will be unpolarized.

type rho:

jax.Array
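As a rough illustration of the unpolarized rho layout described above, the sketch below unpacks a (6, N) array into the density, its gradient components, the laplacian, and tau. The equal split into the spin channels (rho0_a, gamma_ab, tau_a, and so on) is an assumption made for illustration and is not confirmed by this documentation. NumPy is used for brevity; the same indexing applies to a jax.Array.

```python
import numpy as np

n_grid = 5
rng = np.random.default_rng(0)
rho = rng.random((6, n_grid))      # placeholder data in the unpolarized (6, N) layout

# Unpack along the first axis in the documented order.
rho0, grad_x, grad_y, grad_z, laplacian, tau = rho

# Squared norm of the density gradient at every grid point.
gamma = grad_x**2 + grad_y**2 + grad_z**2

# Assumption: an unpolarized density is split equally between the spin channels.
rho0_a = rho0_b = 0.5 * rho0
gamma_a = gamma_b = gamma_ab = 0.25 * gamma
tau_a = tau_b = 0.5 * tau
```

Under this assumed split, the spin-resolved quantities recombine to the totals: rho0_a + rho0_b gives rho0, and gamma_a + 2 * gamma_ab + gamma_b gives gamma.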

param ao:

The atomic orbitals on the grid to use in the network calculation. Specified explicitly because the block loops split the grid into smaller chunks when memory is too low

type ao:

jax.Array

param gw:

The grid weights to use in the network calculation. Specified explicitly because the block loops split the grid into smaller chunks when memory is too low

type gw:

jax.Array

param coords:

The grid coordinates to use in the network calculation. Specified explicitly because the block loops split the grid into smaller chunks when memory is too low

type coords:

jax.Array

param spin:

The spin of the calculation, integer valued, polarized if non-zero, defaults to zero

type spin:

int

param relativity:

Integer, unused right now, defaults to zero

type relativity:

int

param deriv:

Unused here, defaults to 1

type deriv:

int

param omega:

Hybrid mixing term, unused here, defaults to None

type omega:

float

param verbose:

Unused here, defaults to None

type verbose:

int

return:

ex, vxc, fxc, kxc, where:

ex -> exc, the XC energy density on the grid

vxc -> (vrho, vsigma, vlapl, vtau), gradients of exc w.r.t. the quantities given. Only vrho and vtau are used; vsigma = vlapl = fxc = kxc = None.

    vrho = vs[:, 0] + vs[:, 1]
    vtau = vs[:, 7] + vs[:, 8]

rtype:

tuple
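The assembly of the return value can be sketched as follows. The vs and exc arrays here are placeholder data rather than network output, and the variable result is a hypothetical name. Only the column indexing and the tuple layout are taken from the description above. NumPy is used for brevity; the same indexing applies to a jax.Array.

```python
import numpy as np

n_grid = 3
# Placeholder gradient array with one column per input feature (11 in total).
vs = np.arange(n_grid * 11, dtype=float).reshape(n_grid, 11)
exc = np.zeros(n_grid)             # placeholder energy density on the grid

# Sum the spin-channel columns: 0 and 1 are the densities, 7 and 8 are tau.
vrho = vs[:, 0] + vs[:, 1]
vtau = vs[:, 7] + vs[:, 8]

# vsigma and vlapl are not provided, and neither are fxc and kxc.
vxc = (vrho, None, None, vtau)
fxc = kxc = None
result = (exc, vxc, fxc, kxc)
```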