7. Loss Computation Classes
7.1. The compute_loss_mae Function
- xcquinox.loss.compute_loss_mae(model, inputs, ref)
Computes the mean absolute error (MAE) of the model’s predictions on the given inputs with respect to the provided reference.
- Parameters:
model (eqx.Module) – The model, which is passed to jax.vmap to generate predictions from the given inputs.
inputs (array) – The input points given to the network. Their shape depends on your network architecture.
ref (array) – The reference values to be used in generating prediction error.
- Returns:
The MAE that will be used in backpropagation.
- Return type:
float
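A minimal sketch of the MAE computation, assuming `model` maps a single input point to a scalar prediction so that `jax.vmap(model)` maps it over the leading axis of `inputs`. `mae_sketch` and the toy linear model are hypothetical stand-ins rather than xcquinox's implementation; a plain Python callable replaces the eqx.Module to keep the example dependency-light.

```python
import jax
import jax.numpy as jnp


def mae_sketch(model, inputs, ref):
    """Hypothetical MAE loss: vmap the model over the leading (batch) axis
    of `inputs`, then average |prediction - ref|."""
    pred = jax.vmap(model)(inputs)
    return jnp.mean(jnp.abs(pred - ref))


# Toy "model": one scalar weight applied to the sum of each input point's features.
def make_model(w):
    return lambda x: w * jnp.sum(x)


inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # two input points, 2 features each
ref = jnp.array([3.0, 7.0])                   # reproduced exactly when w = 1.0

loss = mae_sketch(make_model(1.0), inputs, ref)

# The loss is differentiable, so it can drive gradient-based training.
grad_w = jax.grad(lambda w: mae_sketch(make_model(w), inputs, ref))(2.0)
```

At `w = 2.0` every prediction overshoots, so the gradient is the mean of the per-point slopes (3 and 7), i.e. 5.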
7.2. The E_loss Class
- class xcquinox.loss.E_loss
Bases: Module
- __call__(model, inp_dm, ref_en, ao_eval, grid_weights)
Computes the energy loss for a given model and its associated input density matrix, atomic orbitals evaluated on the grid, and grid weights.
The loss is the RMSE of the energy, so the predicted energy may be a jax.Array of SCF guesses.
- Parameters:
model (xcquinox.xc.eXC) – The XC object whose forward pass predicts the XC energy based on the inputs here.
inp_dm (jax.Array) – The density matrix to pass into the network for density creation on the grid.
ref_en (jax.Array) – The reference energy to take the loss with respect to.
ao_eval (jax.Array) – Atomic orbitals evaluated on the grid
grid_weights (jax.Array) – pyscfad’s grid weights for the reference calculation
- Returns:
The RMSE.
- Return type:
jax.Array
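To make the roles of `inp_dm`, `ao_eval`, and `grid_weights` concrete, the sketch below builds the density on the grid as ρ(r_g) = Σ_μν φ_μ(r_g) D_μν φ_ν(r_g) (real atomic orbitals assumed), integrates it with the grid weights, and takes an RMSE against a reference energy. The energy functional is a stand-in for the eXC forward pass: spin-unpolarized LDA exchange, E_x = -(3/4)(3/π)^(1/3) Σ_g w_g ρ_g^(4/3). All function names are hypothetical.

```python
import jax.numpy as jnp


def density_on_grid(dm, ao_eval):
    """rho_g = sum_{mu,nu} ao_eval[g, mu] * dm[mu, nu] * ao_eval[g, nu] (real AOs)."""
    return jnp.einsum("gu,uv,gv->g", ao_eval, dm, ao_eval)


def lda_exchange_energy(rho, grid_weights):
    """Stand-in for the eXC forward pass: spin-unpolarized LDA exchange energy."""
    cx = 0.75 * (3.0 / jnp.pi) ** (1.0 / 3.0)
    return -cx * jnp.sum(grid_weights * rho ** (4.0 / 3.0))


def rmse_energy_loss(pred_en, ref_en):
    """RMSE over all predictions, e.g. energies from several SCF guesses."""
    return jnp.sqrt(jnp.mean((pred_en - ref_en) ** 2))


# Toy data: 3 grid points, 2 real AOs.
ao_eval = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
dm = jnp.array([[2.0, 0.0], [0.0, 0.0]])
grid_weights = jnp.array([0.5, 0.5, 1.0])

rho = density_on_grid(dm, ao_eval)       # density at each grid point
n_elec = jnp.sum(grid_weights * rho)     # quadrature estimate of the electron count
e_x = lda_exchange_energy(rho, grid_weights)

# Several predicted energies (e.g. one per SCF guess) against one reference energy.
pred_en = jnp.array([-1.0, -1.2, -0.9])
loss = rmse_energy_loss(pred_en, -1.0)
```

Because the RMSE averages over whatever predictions are supplied, a single energy and a jax.Array of SCF-guess energies are handled by the same expression.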
7.3. The NL_E_loss Class
- class xcquinox.loss.NL_E_loss
Bases: Module
- __init__()
The standard energy loss module for a non-local descriptor training, RMSE loss of predicted vs. reference energies.
- __call__(model, inp_dm, ref_en, ao_eval, grid_weights, mf)
Computes the energy loss for a given model and its associated input density matrix, atomic orbitals evaluated on the grid, and grid weights.
The loss is the RMSE of the energy, so the predicted energy may be a jax.Array of SCF guesses.
- Parameters:
model (xcquinox.xc.eXC) – The XC object whose forward pass predicts the XC energy based on the inputs here.
inp_dm (jax.Array) – The density matrix to pass into the network for density creation on the grid.
ref_en (jax.Array) – The reference energy to take the loss with respect to.
ao_eval (jax.Array) – Atomic orbitals evaluated on the grid
grid_weights (jax.Array) – pyscfad’s grid weights for the reference calculation
mf (pyscfad.dft.RKS kernel) – A pyscf(ad) converged calculation kernel if self.level > 3, used for building the CIDER nonlocal descriptors, defaults to None
- Returns:
The RMSE.
- Return type:
jax.Array
7.4. The DM_HoLu_loss Class
- class xcquinox.loss.DM_HoLu_loss
Bases: Module
- __init__()
Creates a DM_HoLu_loss object for use in training.
Provides options to compute the RMSE loss with respect to the density matrix, the RMSE HOMO-LUMO gap loss, and the root-integrated-squared loss for the density on the grid.
- __call__(model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, holu=None, alpha0=0.7, dmL=1.0, holuL=1.0, dm_to_rho=0.0)
Forward pass to compute the total loss based on the given inputs.
If more than one loss flag evaluates to True, the total loss returned is the root of the sum of squares of the weighted individual losses, i.e., total_loss = jnp.sqrt( (dmL*dmLv)**2 + (holuL*holuLv)**2 + (dm_to_rho*rhoLv)**2)
- Parameters:
model (xcquinox.xc.eXC) – The model for use in generating the Vxc during the DM generation
ao_eval (jax.Array) – The atomic orbitals evaluated on the grid for the given molecule
gw (jax.Array) – The grid weights associated to the current molecule’s grids
dm (jax.Array) – Input reference density matrix for use during the one-shot forward pass to generate the new DM
eri (jax.Array) – Electron repulsion integrals associated with this molecule
mo_occ (jax.Array) – The molecule’s molecular orbital occupation numbers
hc (jax.Array) – The molecule’s core Hamiltonian
s (jax.Array) – The molecule’s overlap matrix
ogd (jax.Array) – The original dimensions of this molecule’s density matrix, used if the matrix is padded to restrict the eigendecomposition to the relevant shape
holu (jax.Array, optional) – The reference HOMO-LUMO bandgap, if doing the corresponding loss, defaults to None
alpha0 (float, optional) – The mixing parameter for the one-shot density matrix generation, defaults to 0.7
dmL (float, optional) – Weight of the RMSE DM loss; the term is included only if this value evaluates to True (is nonzero), defaults to 1.0
holuL (float, optional) – Weight of the RMSE HOMO-LUMO gap loss; the term is included only if this value evaluates to True (is nonzero), defaults to 1.0
dm_to_rho (float, optional) – Weight of the integrated rho-on-grid loss; the term is included only if this value evaluates to True (is nonzero), defaults to 0.0
- Returns:
The root-sum of squares loss
- Return type:
jax.Array
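The combination rule for the total loss can be sketched as follows, assuming the individual terms are computed from already-generated predictions: an RMSE between predicted and reference density matrices, the HOMO-LUMO gap error (root of its square), and a root-integrated-squared density error sqrt(Σ_g w_g (ρ_pred - ρ_ref)²). The occupation-based HOMO/LUMO selection and all function names are illustrative assumptions; the one-shot DM generation governed by `alpha0` is not reproduced.

```python
import jax.numpy as jnp


def homo_lumo_gap(mo_energy, mo_occ):
    """LUMO minus HOMO energy, using occupations to split occupied/virtual orbitals."""
    occupied = mo_occ > 0
    homo = jnp.max(jnp.where(occupied, mo_energy, -jnp.inf))
    lumo = jnp.min(jnp.where(occupied, jnp.inf, mo_energy))
    return lumo - homo


def total_holu_loss(dm_pred, dm_ref, gap_pred, gap_ref, rho_pred, rho_ref, gw,
                    dmL=1.0, holuL=1.0, dm_to_rho=0.0):
    dmLv = jnp.sqrt(jnp.mean((dm_pred - dm_ref) ** 2))        # RMSE of the DM
    holuLv = jnp.abs(gap_pred - gap_ref)                       # root of squared gap error
    rhoLv = jnp.sqrt(jnp.sum(gw * (rho_pred - rho_ref) ** 2))  # root-integrated-squared rho error
    # Rooted sum of squares of the weighted terms; a weight of 0.0 removes a term.
    return jnp.sqrt((dmL * dmLv) ** 2 + (holuL * holuLv) ** 2
                    + (dm_to_rho * rhoLv) ** 2)


gap = homo_lumo_gap(jnp.array([-0.5, -0.3, 0.1, 0.4]),
                    jnp.array([2.0, 2.0, 0.0, 0.0]))

loss = total_holu_loss(
    dm_pred=jnp.eye(2), dm_ref=jnp.array([[1.0, 0.0], [0.0, 0.0]]),
    gap_pred=gap, gap_ref=0.1,
    rho_pred=jnp.ones(3), rho_ref=jnp.ones(3), gw=jnp.ones(3),
)
```

With the default weights the density term is switched off, so the result here combines only the DM RMSE (0.5) and the gap error (0.3).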
7.5. The Band_gap_1shot_loss Class
- class xcquinox.loss.Band_gap_1shot_loss
Bases: Module
- __init__()
Initializer for the loss module, which computes the band-gap loss with respect to a reference gap.
- __call__(model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, refgap, mf, alpha0=0.7)
Forward pass for the loss object.
NOTE: This differs from the HoLu loss (DM_HoLu_loss) in that it selects the deepest minimum w.r.t. the LUMO (Fermi energy).
- Parameters:
model (xcquinox.xc.eXC) – The model that will be used in generating the molecular orbital energies (‘band’ energies)
ao_eval (jax.Array) – The atomic orbitals evaluated on the grid for the given molecule
gw (jax.Array) – The grid weights associated to the current molecule’s grids
dm (jax.Array) – Input reference density matrix for use during the one-shot forward pass to generate the new DM
eri (jax.Array) – Electron repulsion integrals associated with this molecule
mo_occ (jax.Array) – The molecule’s molecular orbital occupation numbers
hc (jax.Array) – The molecule’s core Hamiltonian
s (jax.Array) – The molecule’s overlap matrix
ogd (jax.Array) – The original dimensions of this molecule’s density matrix, used if the matrix is padded to restrict the eigendecomposition to the relevant shape
refgap (jax.Array) – The reference gap to optimize against
mf (pyscfad.dft.RKS kernel) – A pyscf(ad) converged calculation kernel if self.level > 3, used for building the CIDER nonlocal descriptors, defaults to None
alpha0 (float, optional) – The mixing parameter for the one-shot density matrix generation, defaults to 0.7
- Returns:
Root-squared error between predicted gap (minimum of molecular energies) and the reference
- Return type:
jax.Array
7.6. The DM_Gap_loss Class
- class xcquinox.loss.DM_Gap_loss
Bases: Module
- __init__()
Initializer for the DM_Gap_loss, a semi-local loss calculator to optimize the one-shot density matrices and band gaps for Gamma-Gamma direct transitions for PBC structures.
The band gap optimized here is the HOMO-LUMO gap, hence the restriction to Gamma-Gamma transitions.
- __call__(model, ao, hc, eri, s, gw, inp_dm, mo_occ, ogd, refDM, refGap)
Forward pass to calculate the DM (sum of squared error) and gap (squared error) loss for a given molecule.
The per-molecule loss is jnp.sqrt( dmL + gapL ), where dmL = jnp.sum( (dm_pred-dm_ref)**2 ) and gapL = (gap_pred - gap_ref)**2
- Parameters:
model (xcquinox.xc.eXC) – The model to use in the predictions, here to generate DM and molecular energies
ao (jax.Array) – Atomic orbitals evaluated on a grid
hc (jax.Array) – Core Hamiltonian
eri (jax.Array) – Electron repulsion integrals
s (jax.Array) – Overlap matrices
gw (jax.Array) – Weights for the grid being used
inp_dm (jax.Array) – Initial density matrix guesses, from mf.get_init_guess(), to be used in the one-shot DM generation to produce a mixed DM to optimize against reference
mo_occ (jax.Array) – Molecular orbital occupations
ogd (tuple) – The original dimensions of the density matrix
refDM (jax.Array) – Reference density matrix from a high-accuracy method (e.g., CCSD(T)).
refGap (jax.Array) – Reference band gap (e.g. from the Borlido 2019 dataset).
- Returns:
The molecule’s loss
- Return type:
jax.Array/scalar
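The per-molecule loss formula can be sketched as below, taking the predicted DM and gap as given (the real class generates them from `model` and the molecular inputs). The `trim_to` helper is a hypothetical illustration of using `ogd` to slice a zero-padded matrix back to its original dimensions before comparing it with the reference; the actual padding handling in xcquinox may differ.

```python
import jax.numpy as jnp


def trim_to(arr, ogd):
    """Hypothetical helper: slice a zero-padded array back to its original dims."""
    return arr[tuple(slice(0, n) for n in ogd)]


def dm_gap_loss(dm_pred, dm_ref, gap_pred, gap_ref, ogd):
    dm_pred, dm_ref = trim_to(dm_pred, ogd), trim_to(dm_ref, ogd)
    dmL = jnp.sum((dm_pred - dm_ref) ** 2)  # sum of squared DM errors
    gapL = (gap_pred - gap_ref) ** 2        # squared gap error
    return jnp.sqrt(dmL + gapL)


# A 2x2 system whose predicted DM was zero-padded to 3x3.
dm_pred = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
dm_ref = jnp.array([[1.0, 0.0], [0.0, 0.0]])
loss = dm_gap_loss(dm_pred, dm_ref, gap_pred=1.5, gap_ref=1.0, ogd=(2, 2))
```

Without trimming, the padded 3x3 prediction could not be subtracted from the 2x2 reference at all.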
7.7. The DM_Gap_Loop_loss Class
- class xcquinox.loss.DM_Gap_Loop_loss
Bases: Module
- __init__()
Initializer for the DM_Gap_Loop_loss, a semi-local loss calculator to optimize the one-shot density matrices and band gaps for Gamma-Gamma direct transitions for PBC structures.
The band gap optimized here is the HOMO-LUMO gap, hence the restriction to Gamma-Gamma transitions.
This is a BATCH-CUMULATIVE LOSS class: it loops over the inputs and accumulates each DM_Gap loss before returning the total for use in optimization.
- __call__(model, aos, hcs, eris, ss, gws, inp_dms, mo_occs, ogds, refDMs, refGaps)
Forward pass to calculate the DM (sum of squared error) and gap (squared error) loss across a given dataset.
The per-molecule loss is jnp.sqrt( dmL + gapL ), where dmL = jnp.sum( (dm_pred-dm_ref)**2 ) and gapL = (gap_pred - gap_ref)**2
- Parameters:
model (xcquinox.xc.eXC) – The model to use in the predictions, here to generate DM and molecular energies
aos (list of jax.Arrays) – Atomic orbitals evaluated on a grid
hcs (list of jax.Arrays) – Core Hamiltonians
eris (list of jax.Arrays) – Electron repulsion integrals
ss (list of jax.Arrays) – Overlap matrices
gws (list of jax.Arrays) – Weights for the grids being used
inp_dms (list of jax.Arrays) – Initial density matrix guesses, from mf.get_init_guess(), to be used in the one-shot DM generation to produce a mixed DM to optimize against reference
mo_occs (list of jax.Arrays) – Molecular orbital occupations
ogds (list of tuples) – The original dimensions of the density matrices
refDMs (list of jax.Arrays) – List of reference density matrices from a high-accuracy method (e.g., CCSD(T)).
refGaps (list of jax.Arrays) – List of reference band gaps (e.g. from the Borlido 2019 dataset).
- Returns:
The cumulative loss across the dataset
- Return type:
jax.Array/scalar
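A minimal sketch of the batch-cumulative pattern, assuming "accumulate" means summing the per-molecule losses and taking the predictions as already generated (the real class produces them from `model` and the listed inputs). The function names are hypothetical.

```python
import jax.numpy as jnp


def dm_gap_loss(dm_pred, dm_ref, gap_pred, gap_ref):
    """Per-molecule loss: sqrt(sum of squared DM errors + squared gap error)."""
    return jnp.sqrt(jnp.sum((dm_pred - dm_ref) ** 2) + (gap_pred - gap_ref) ** 2)


def dm_gap_loop_loss(dm_preds, dm_refs, gap_preds, gap_refs):
    """Accumulate per-molecule losses in a Python loop (assumed here to be a sum).

    A loop, unlike jax.vmap, accepts molecules whose arrays differ in shape."""
    total = 0.0
    for dm_p, dm_r, g_p, g_r in zip(dm_preds, dm_refs, gap_preds, gap_refs):
        total = total + dm_gap_loss(dm_p, dm_r, g_p, g_r)
    return total


# Two molecules of different sizes.
dm_preds = [jnp.eye(2), jnp.eye(3)]
dm_refs = [jnp.array([[1.0, 0.0], [0.0, 0.0]]), jnp.eye(3)]
loss = dm_gap_loop_loss(dm_preds, dm_refs, gap_preds=[1.5, 2.0], gap_refs=[1.0, 1.0])
```

Summing inside one call means a single gradient step sees the whole batch, at the cost of the loop being unrolled if the function is traced with jax.jit.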