7. Loss Computation Classes

7.1. The compute_loss_mae Function

xcquinox.loss.compute_loss_mae(model, inputs, ref)

Computes the mean-absolute-error loss of the model’s prediction using the given inputs against the provided reference.

Parameters:
  • model (eqx.Module) – The model that jax.vmap maps over the inputs to generate predictions.

  • inputs (array) – The input points passed to the network. Their shape depends on the network architecture.

  • ref (array) – The reference values against which the prediction error is computed.

Returns:

The MAE that will be used in backpropagation.

Return type:

float
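As a minimal sketch of what the MAE computation amounts to (not the library's source), the NumPy code below applies a model to each input along the leading axis with a plain loop, standing in for jax.vmap, and averages the absolute errors. The function name compute_loss_mae_sketch and the toy model are hypothetical. In JAX the same idea would presumably read jnp.mean(jnp.abs(jax.vmap(model)(inputs) - ref)); the float() conversion used here must be dropped if the function is to be differentiated with jax.grad.

```python
import numpy as np


def compute_loss_mae_sketch(model, inputs, ref):
    """Mean absolute error of a model's predictions against a reference.

    The list comprehension plays the role of jax.vmap(model)(inputs):
    the model is applied to each input along the leading axis.
    """
    preds = np.stack([np.asarray(model(x)) for x in inputs])
    # MAE = mean(|prediction - reference|)
    return float(np.mean(np.abs(preds - np.asarray(ref))))


# Hypothetical toy model: twice the sum of each 2-feature input point.
def toy_model(x):
    return 2.0 * np.sum(x)


inputs = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 3.0]])
ref = np.array([2.0, 1.5, 6.0])
print(compute_loss_mae_sketch(toy_model, inputs, ref))
```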

7.2. The E_loss Class

class xcquinox.loss.E_loss

Bases: Module

__init__()

The standard energy loss module: the RMSE of predicted vs. reference energies.

__call__(model, inp_dm, ref_en, ao_eval, grid_weights)

Computes the energy loss for a given model from the input density matrix, the atomic orbitals evaluated on the grid, and the grid weights.

Because the loss is an RMSE, the predicted energy may be a jax.Array of SCF guesses rather than a single value.

Parameters:
  • model (xcquinox.xc.eXC) – The XC object whose forward pass predicts the XC energy based on the inputs here.

  • inp_dm (jax.Array) – The density matrix to pass into the network for density creation on the grid.

  • ref_en (jax.Array) – The reference energy to take the loss with respect to.

  • ao_eval (jax.Array) – Atomic orbitals evaluated on the grid.

  • grid_weights (jax.Array) – pyscfad’s grid weights for the reference calculation.

Returns:

The RMSE loss.

Return type:

jax.Array
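The energy-loss idea can be illustrated with a small NumPy sketch under stated assumptions: the density on the grid is built from the density matrix and the AO values as rho(r_g) = sum_mn phi_m(r_g) D_mn phi_n(r_g), an energy is integrated with the grid weights, and the RMSE is taken against the reference. The spin-unpolarized LDA (Dirac) exchange functional is a hypothetical placeholder for the learned eXC model, and density_on_grid, lda_exchange_energy, and energy_rmse are illustrative names, not xcquinox functions. Passing an array of predicted energies (e.g. several SCF guesses) yields the RMSE over all of them.

```python
import numpy as np


def density_on_grid(ao_eval, dm):
    # rho(r_g) = sum_{m,n} phi_m(r_g) D_mn phi_n(r_g)
    # ao_eval: (n_grid, n_ao); dm: (n_ao, n_ao)
    return np.einsum("gm,mn,gn->g", ao_eval, dm, ao_eval)


def lda_exchange_energy(rho, grid_weights):
    # Hypothetical stand-in for the learned XC model: spin-unpolarized
    # Dirac/LDA exchange, E_x = -(3/4)(3/pi)^(1/3) * sum_g w_g rho_g^(4/3)
    cx = 0.75 * (3.0 / np.pi) ** (1.0 / 3.0)
    rho = np.clip(rho, 0.0, None)  # guard against tiny negative round-off
    return float(-cx * np.sum(grid_weights * rho ** (4.0 / 3.0)))


def energy_rmse(pred_en, ref_en):
    # Works for a single prediction or an array of them (e.g. SCF guesses).
    pred_en = np.atleast_1d(np.asarray(pred_en, dtype=float))
    return float(np.sqrt(np.mean((pred_en - ref_en) ** 2)))


rng = np.random.default_rng(0)
ao_eval = rng.normal(size=(50, 4))    # toy data: 50 grid points, 4 AOs
c = rng.normal(size=(4, 2))           # two doubly occupied toy orbitals
dm = 2.0 * c @ c.T                    # positive semidefinite density matrix
grid_weights = np.full(50, 0.1)

e_pred = lda_exchange_energy(density_on_grid(ao_eval, dm), grid_weights)
print(energy_rmse([e_pred, e_pred + 0.01], ref_en=e_pred))
```

jax.numpy offers the same einsum, clip, sum, mean, and sqrt calls, so the sketch ports directly once the float() conversions are removed.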

7.3. The NL_E_loss Class

class xcquinox.loss.NL_E_loss

Bases: Module

__init__()

The standard energy loss module for training with non-local descriptors: the RMSE of predicted vs. reference energies.

__call__(model, inp_dm, ref_en, ao_eval, grid_weights, mf)

Computes the energy loss for a given model from the input density matrix, the atomic orbitals evaluated on the grid, and the grid weights.

Because the loss is an RMSE, the predicted energy may be a jax.Array of SCF guesses rather than a single value.

Parameters:
  • model (xcquinox.xc.eXC) – The XC object whose forward pass predicts the XC energy based on the inputs here.

  • inp_dm (jax.Array) – The density matrix to pass into the network for density creation on the grid.

  • ref_en (jax.Array) – The reference energy to take the loss with respect to.

  • ao_eval (jax.Array) – Atomic orbitals evaluated on the grid.

  • grid_weights (jax.Array) – pyscfad’s grid weights for the reference calculation.

  • mf (pyscfad.dft.RKS kernel) – A converged pyscf(ad) calculation kernel, used when self.level > 3 to build the CIDER nonlocal descriptors, defaults to None

Returns:

The RMSE loss.

Return type:

jax.Array

7.4. The DM_HoLu_loss Class

class xcquinox.loss.DM_HoLu_loss

Bases: Module

__init__()

Creates a DM_HoLu_loss object for use in training.

Optionally computes the RMSE loss of the density matrix, the RMSE HOMO-LUMO gap loss, and the root-integrated-squared loss of the density on the grid.

__call__(model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, holu=None, alpha0=0.7, dmL=1.0, holuL=1.0, dm_to_rho=0.0)

Forward pass to compute the total loss based on the given inputs.

If more than one loss weight evaluates to True (i.e., is nonzero), the total loss returned is the square root of the sum of squares of the individual weighted losses, i.e., total_loss = jnp.sqrt( (dmL*dmLv)**2 + (holuL*holuLv)**2 + (dm_to_rho*rhoLv)**2 ), where dmLv, holuLv, and rhoLv are the unweighted DM, HOMO-LUMO gap, and density losses.

Parameters:
  • model (xcquinox.xc.eXC) – The model for use in generating the Vxc during the DM generation

  • ao_eval (jax.Array) – The atomic orbitals evaluated on the grid for the given molecule

  • gw (jax.Array) – The grid weights associated with the current molecule’s grids

  • dm (jax.Array) – Input reference density matrix for use during the one-shot forward pass to generate the new DM

  • eri (jax.Array) – Electron repulsion integrals associated with this molecule

  • mo_occ (jax.Array) – The molecule’s molecular orbital occupation numbers

  • hc (jax.Array) – The molecule’s core Hamiltonian

  • s (jax.Array) – The molecule’s overlap matrix

  • ogd (jax.Array) – The original dimensions of this molecule’s density matrix, used when the matrices are padded to restrict the eigendecomposition to the relevant block

  • holu (jax.Array, optional) – The reference HOMO-LUMO bandgap, if doing the corresponding loss, defaults to None

  • alpha0 (float, optional) – The mixing parameter for the one-shot density matrix generation, defaults to 0.7

  • dmL (float, optional) – Weight of the RMSE DM loss; a falsy value (0.0) excludes this term, defaults to 1.0

  • holuL (float, optional) – Weight of the RMSE HOMO-LUMO gap loss; a falsy value (0.0) excludes this term, defaults to 1.0

  • dm_to_rho (float, optional) – Weight of the integrated rho-on-grid loss; a falsy value (0.0) excludes this term, defaults to 0.0

Returns:

The root-sum-of-squares loss

Return type:

jax.Array
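The weighted total-loss formula given for __call__ can be sketched as follows. This NumPy illustration is not the class's implementation: it assumes the three component losses are an RMSE over density-matrix elements, an RMSE over the gap, and the square root of the grid-weighted integral of the squared density error, and it assumes (hypothetically) that the gap term is skipped when no reference gap is supplied, mirroring holu=None. A weight of 0.0 is falsy and drops its term, which is how the weights double as on/off flags. The names rmse, rho_loss, and total_holu_loss are illustrative.

```python
import numpy as np


def rmse(pred, ref):
    return float(np.sqrt(np.mean((np.asarray(pred) - np.asarray(ref)) ** 2)))


def rho_loss(rho_pred, rho_ref, gw):
    # sqrt( integral (rho_pred - rho_ref)^2 dr ), integral approximated by grid weights
    return float(np.sqrt(np.sum(gw * (rho_pred - rho_ref) ** 2)))


def total_holu_loss(dm_pred, dm_ref, gap_pred=None, gap_ref=None,
                    rho_pred=None, rho_ref=None, gw=None,
                    dmL=1.0, holuL=1.0, dm_to_rho=0.0):
    sq_terms = []
    # Each weight doubles as a flag: 0.0 is falsy, so that term is skipped.
    if dmL:
        sq_terms.append((dmL * rmse(dm_pred, dm_ref)) ** 2)
    # Assumption: the gap term also needs a reference gap (holu=None skips it).
    if holuL and gap_ref is not None:
        sq_terms.append((holuL * rmse(gap_pred, gap_ref)) ** 2)
    if dm_to_rho:
        sq_terms.append((dm_to_rho * rho_loss(rho_pred, rho_ref, gw)) ** 2)
    # Root of the sum of squared weighted losses
    return float(np.sqrt(sum(sq_terms)))


dm_ref = np.eye(2)
dm_pred = dm_ref + 0.1
print(total_holu_loss(dm_pred, dm_ref, gap_pred=0.30, gap_ref=0.25))
```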

7.5. The Band_gap_1shot_loss Class

class xcquinox.loss.Band_gap_1shot_loss

Bases: Module

__init__()

Initializer for the loss module, which computes the band-gap loss with respect to a reference gap.

__call__(model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, refgap, mf, alpha0=0.7)

Forward pass for the loss object.

NOTE: This differs from DM_HoLu_loss in that it selects the deepest minimum with respect to the LUMO (Fermi energy).

Parameters:
  • model (xcquinox.xc.eXC) – The model that will be used in generating the molecular orbital energies (‘band’ energies)

  • ao_eval (jax.Array) – The atomic orbitals evaluated on the grid for the given molecule

  • gw (jax.Array) – The grid weights associated with the current molecule’s grids

  • dm (jax.Array) – Input reference density matrix for use during the one-shot forward pass to generate the new DM

  • eri (jax.Array) – Electron repulsion integrals associated with this molecule

  • mo_occ (jax.Array) – The molecule’s molecular orbital occupation numbers

  • hc (jax.Array) – The molecule’s core Hamiltonian

  • s (jax.Array) – The molecule’s overlap matrix

  • ogd (jax.Array) – The original dimensions of this molecule’s density matrix, used when the matrices are padded to restrict the eigendecomposition to the relevant block

  • refgap (jax.Array) – The reference gap to optimize against

  • mf (pyscfad.dft.RKS kernel) – A converged pyscf(ad) calculation kernel, used when self.level > 3 to build the CIDER nonlocal descriptors, defaults to None

  • alpha0 (float, optional) – The mixing parameter for the one-shot density matrix generation, defaults to 0.7

Returns:

Root-squared error between the predicted gap (minimum of molecular energies) and the reference gap

Return type:

jax.Array
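Several of these losses take an ogd argument because padded matrices must be cropped before the eigendecomposition: zero-padding the overlap matrix makes it singular. The sketch below is an assumption-laden illustration rather than the library's code. It crops a Fock matrix and an overlap matrix to their original block, solves the generalized eigenproblem F C = S C eps by Löwdin orthogonalization (X = S^(-1/2)), and rebuilds a closed-shell density matrix D = C diag(n) C^T from the occupations. How the Fock matrix is assembled from hc, eri, and the model's Vxc, and how alpha0 mixes the new and input density matrices, are not shown; cropped_generalized_eigh and dm_from_orbitals are hypothetical names.

```python
import numpy as np


def cropped_generalized_eigh(fock, s, ogd):
    """Solve F C = S C eps on the unpadded leading block given by ogd."""
    n = int(ogd[0])
    f_blk, s_blk = fock[:n, :n], s[:n, :n]
    # Loewdin orthogonalization: X = S^(-1/2) turns the generalized problem
    # into an ordinary symmetric eigenproblem for X^T F X.
    s_vals, s_vecs = np.linalg.eigh(s_blk)
    x = s_vecs @ np.diag(s_vals ** -0.5) @ s_vecs.T
    mo_energy, c_prime = np.linalg.eigh(x.T @ f_blk @ x)
    mo_coeff = x @ c_prime  # back-transform to the AO basis
    return mo_energy, mo_coeff


def dm_from_orbitals(mo_coeff, mo_occ):
    # D_mn = sum_i n_i C_mi C_ni
    return (mo_coeff * mo_occ) @ mo_coeff.T


# Toy 2x2 problem zero-padded to 4x4; the padded S is singular, hence the crop.
fock = np.zeros((4, 4))
fock[:2, :2] = [[-1.0, -0.3], [-0.3, 0.5]]
s = np.zeros((4, 4))
s[:2, :2] = [[1.0, 0.2], [0.2, 1.0]]
ogd = (2, 2)

mo_energy, mo_coeff = cropped_generalized_eigh(fock, s, ogd)
dm_new = dm_from_orbitals(mo_coeff, np.array([2.0, 0.0]))
print(mo_energy)
```

Because the orbitals come out S-orthonormal (C^T S C = I), the trace of D S equals the total occupation, which is a convenient sanity check on the rebuilt density matrix.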

7.6. The DM_Gap_loss Class

class xcquinox.loss.DM_Gap_loss

Bases: Module

__init__()

Initializer for the DM_Gap_loss, a semi-local loss calculator that optimizes the one-shot density matrices and the band gaps for Gamma-Gamma direct transitions in PBC structures.

The band gap optimized here is the HOMO-LUMO gap, which is why the transitions are specifically Gamma-Gamma.

__call__(model, ao, hc, eri, s, gw, inp_dm, mo_occ, ogd, refDM, refGap)

Forward pass to calculate the DM (sum of squared error) and gap (squared error) loss for a given molecule.

Individual molecule loss is jnp.sqrt( dmL + gapL ), where dmL = jnp.sum( (dm_pred-dm_ref)**2 ) and gapL = (gap_pred - gap_ref)**2

Parameters:
  • model (xcquinox.xc.eXC) – The model to use in the predictions, here to generate DM and molecular energies

  • ao (jax.Array) – Atomic orbitals evaluated on a grid

  • hc (jax.Array) – Core Hamiltonian

  • eri (jax.Array) – Electron repulsion integrals

  • s (jax.Array) – Overlap matrix

  • gw (jax.Array) – Weights for the grid being used

  • inp_dm (jax.Array) – Initial density matrix guess, from mf.get_init_guess(), used in the one-shot DM generation to produce a mixed DM that is optimized against the reference

  • mo_occ (jax.Array) – Molecular orbital occupations

  • ogd (tuple) – The original dimensions of the density matrix

  • refDM (jax.Array) – Reference density matrix from a high-accuracy method (e.g., CCSD(T)).

  • refGap (jax.Array) – Reference band gap (e.g. from the Borlido 2019 dataset).

Returns:

The molecule’s loss

Return type:

jax.Array/scalar
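The per-molecule DM_Gap formula translates directly into code. In this NumPy sketch the predicted gap is taken as the lowest unoccupied minus the highest occupied orbital energy, which is an assumption about how gap_pred is obtained; homo_lumo_gap and dm_gap_loss are hypothetical names. Reference gaps such as those in the Borlido 2019 dataset are typically reported in eV, while pyscf(ad) orbital energies are in Hartree, so both gaps must be in the same units before they are compared.

```python
import numpy as np


def homo_lumo_gap(mo_energy, mo_occ):
    # Assumption: gap = lowest unoccupied minus highest occupied orbital energy
    occupied = mo_occ > 0
    return float(np.min(mo_energy[~occupied]) - np.max(mo_energy[occupied]))


def dm_gap_loss(dm_pred, dm_ref, gap_pred, gap_ref):
    dmL = np.sum((dm_pred - dm_ref) ** 2)  # sum of squared DM errors
    gapL = (gap_pred - gap_ref) ** 2       # squared gap error
    return float(np.sqrt(dmL + gapL))


mo_energy = np.array([-0.5, -0.2, 0.1, 0.4])  # toy orbital energies
mo_occ = np.array([2.0, 2.0, 0.0, 0.0])
gap_pred = homo_lumo_gap(mo_energy, mo_occ)   # 0.1 - (-0.2)

dm_ref = np.zeros((2, 2))
dm_pred = np.array([[0.3, 0.0], [0.0, 0.0]])
print(dm_gap_loss(dm_pred, dm_ref, gap_pred, gap_ref=0.7))
```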

7.7. The DM_Gap_Loop_loss Class

class xcquinox.loss.DM_Gap_Loop_loss

Bases: Module

__init__()

Initializer for the DM_Gap_Loop_loss, a semi-local loss calculator that optimizes the one-shot density matrices and the band gaps for Gamma-Gamma direct transitions in PBC structures.

The band gap optimized here is the HOMO-LUMO gap, which is why the transitions are specifically Gamma-Gamma.

This is a BATCH-CUMULATIVE LOSS class: it loops over the inputs and accumulates each DM_Gap loss before returning the total, which is then used for optimization.

__call__(model, aos, hcs, eris, ss, gws, inp_dms, mo_occs, ogds, refDMs, refGaps)

Forward pass to calculate the DM (sum of squared error) and gap (squared error) loss across a given dataset.

Individual molecule loss is jnp.sqrt( dmL + gapL ), where dmL = jnp.sum( (dm_pred-dm_ref)**2 ) and gapL = (gap_pred - gap_ref)**2

Parameters:
  • model (xcquinox.xc.eXC) – The model to use in the predictions, here to generate DM and molecular energies

  • aos (list of jax.Arrays) – Atomic orbitals evaluated on a grid

  • hcs (list of jax.Arrays) – Core Hamiltonians

  • eris (list of jax.Arrays) – Electron repulsion integrals

  • ss (list of jax.Arrays) – Overlap matrices

  • gws (list of jax.Arrays) – Weights for the grids being used

  • inp_dms (list of jax.Arrays) – Initial density matrix guesses, from mf.get_init_guess(), to be used in the one-shot DM generation to produce a mixed DM to optimize against reference

  • mo_occs (list of jax.Arrays) – Molecular orbital occupations

  • ogds (list of tuples) – The original dimensions of the density matrices

  • refDMs (list of jax.Arrays) – List of reference density matrices from high-accuracy method (e.g., CCSD(T)).

  • refGaps (list of jax.Arrays) – List of reference band gaps (e.g. from the Borlido 2019 dataset).

Returns:

The cumulative loss across the dataset

Return type:

jax.Array/scalar
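The batch-cumulative behavior can be sketched as a loop that sums per-molecule DM_Gap losses. To stay self-contained, this hypothetical dm_gap_loop_loss takes already-generated predictions rather than running the model's one-shot step, so it illustrates only the accumulation, not the class itself.

```python
import numpy as np


def dm_gap_loop_loss(dm_preds, ref_dms, gap_preds, ref_gaps):
    """Sum of per-molecule DM_Gap losses over a batch (hypothetical sketch)."""
    if not (len(dm_preds) == len(ref_dms) == len(gap_preds) == len(ref_gaps)):
        raise ValueError("all per-molecule lists must have the same length")
    total = 0.0
    for dm_p, dm_r, g_p, g_r in zip(dm_preds, ref_dms, gap_preds, ref_gaps):
        # Per-molecule loss: sqrt(sum((dm_pred - dm_ref)^2) + (gap_pred - gap_ref)^2)
        total += np.sqrt(np.sum((dm_p - dm_r) ** 2) + (g_p - g_r) ** 2)
    return float(total)


# Two toy "molecules" with density matrices of different sizes.
dm_preds = [np.array([[1.1]]), np.eye(2)]
ref_dms = [np.array([[1.0]]), np.eye(2)]
gap_preds = [0.5, 0.3]
ref_gaps = [0.5, 0.0]
print(dm_gap_loop_loss(dm_preds, ref_dms, gap_preds, ref_gaps))
```

A Python-level loop like this accepts molecules whose matrices differ in size, which a single vmap over stacked arrays would not. Under jax.jit, however, such a loop is unrolled during tracing, so compilation time grows with the number of molecules in the batch.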