import equinox as eqx
import jax
import jax.numpy as jnp
from xcquinox.utils import get_dm_moe, pad_array, pad_array_list
from xcquinox.pyscf import generate_network_eval_xc
[docs]
@eqx.filter_value_and_grad
def compute_loss_mae(model, inputs, ref):
'''
Computes the mean-absolute-error loss of the model's prediction using the given inputs against the provided reference.
:param model: The model which is given to `jax.vmap` to generate predictions using the inputs given.
:type model: eqx.Module
:param inputs: The input points that are given to the network. Shape will be dependent on your network architecture.
:type inputs: array
:param ref: The reference values to be used in generating prediction error.
:type ref: array
:return: The MAE that will be used in backpropagation.
:rtype: float
'''
pred = jax.vmap(model)(inputs)
loss = jnp.mean(jnp.abs(pred - ref))
return loss
# =====================================================================
# =====================================================================
# DEPRECATED CLASSES -- TO BE REMOVED
# =====================================================================
# =====================================================================
[docs]
class E_loss(eqx.Module):
[docs]
def __init__(self):
'''
The standard energy loss module, RMSE loss of predicted vs. reference energies.
'''
super().__init__()
[docs]
def __call__(self, model, inp_dm, ref_en, ao_eval, grid_weights):
'''
Computes the energy loss for a given model and associated input density matrix, atomic orbitals on the grid, and grid weights
Loss is the RMSE energy, so predicted energy can potentially be a jax.Array of SCF guesses.
:param model: The XC object whose forward pass predicts the XC energy based on the inputs here.
:type model: xcquinox.xc.eXC
:param inp_dm: The density matrix to pass into the network for density creation on the grid.
:type inp_dm: jax.Array
:param ref_en: The reference energy to take the loss with respect to.
:type ref_en: jax.Array
:param ao_eval: Atomic orbitals evaluated on the grid
:type ao_eval: jax.Array
:param grid_weights: pyscfad's grid weights for the reference calculation
:type grid_weights: jax.Array
:return: The RMSE error.
:rtype: jax.Array
'''
e_pred = model(inp_dm, ao_eval, grid_weights)
eL = jnp.sqrt(jnp.mean((e_pred-ref_en)**2))
return eL
[docs]
class NL_E_loss(eqx.Module):
[docs]
def __init__(self):
'''
The standard energy loss module for a non-local descriptor training, RMSE loss of predicted vs. reference energies.
'''
super().__init__()
[docs]
def __call__(self, model, inp_dm, ref_en, ao_eval, grid_weights, mf):
'''
Computes the energy loss for a given model and associated input density matrix, atomic orbitals on the grid, and grid weights
Loss is the RMSE energy, so predicted energy can potentially be a jax.Array of SCF guesses.
:param model: The XC object whose forward pass predicts the XC energy based on the inputs here.
:type model: xcquinox.xc.eXC
:param inp_dm: The density matrix to pass into the network for density creation on the grid.
:type inp_dm: jax.Array
:param ref_en: The reference energy to take the loss with respect to.
:type ref_en: jax.Array
:param ao_eval: Atomic orbitals evaluated on the grid
:type ao_eval: jax.Array
:param grid_weights: pyscfad's grid weights for the reference calculation
:type grid_weights: jax.Array
:param mf: A pyscf(ad) converged calculation kernel if self.level > 3, used for building the CIDER nonlocal descriptors, defaults to None
:type mf: pyscfad.dft.RKS kernel
:return: The RMSE error.
:rtype: jax.Array
'''
e_pred = model(inp_dm, ao_eval, grid_weights, mf)
eL = jnp.sqrt(jnp.mean((e_pred-ref_en)**2))
return eL
[docs]
class DM_HoLu_loss(eqx.Module):
[docs]
def __init__(self):
"""
Creates DM_HoLu_loss object for use in training.
Options to compute the RMSE loss with respect to the density matrix, the RMSE homo-lumo gap loss, and the root-integrated-squared loss for the density on the grid.
"""
super().__init__()
[docs]
def __call__(self, model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, holu=None, alpha0=0.7,
dmL=1.0, holuL=1.0, dm_to_rho=0.0):
"""
Forward pass to compute the total loss based on the given inputs.
If more than one loss flag evaluates to True, total loss returned is the rooted sum-of-squares for the individual losses.
i.e. total_loss = jnp.sqrt( (dmL*dmLv)**2 + (holuL*holuLv)**2 + (dm_to_rho*rhoLv)**2)
:param model: The model for use in generating the Vxc during the DM generation
:type model: xcquinox.xc.eXC
:param ao_eval: The atomic orbitals evaluated on the grid for the given molecule
:type ao_eval: jax.Array
:param gw: The grid weights associated to the current molecule's grids
:type gw: jax.Array
:param dm: Input reference density matrix for use during the one-shot forward pass to generate the new DM
:type dm: jax.Array
:param eri: Electron repulsion integrals associated with this molecule
:type eri: jax.Array
:param mo_occ: The molecule's molecular orbital occupation numbers
:type mo_occ: jax.Array
:param hc: The molecule's core Hamiltonian
:type hc: jax.Array
:param s: The molecule's overlap matrix
:type s: jax.Array
:param ogd: The original dimensions of this molecule's density matrix, used if padded to constrict the eigendecomposition to a relevant shape
:type ogd: jax.Array
:param holu: The reference HOMO-LUMO bandgap, if doing the corresponding loss, defaults to None
:type holu: jax.Array, optional
:param alpha0: The mixing parameter for the one-shot density matrix generation, defaults to 0.7
:type alpha0: float, optional
:param dmL: Float to evaluate whether or not to include RMSE DM loss, used as the loss weight, defaults to 1.0
:type dmL: float, optional
:param holuL: Float to evaluate whether or not to include RMSE HOMO-LUMO gap loss, used as the loss weight, defaults to 1.0
:type holuL: float, optional
:param dm_to_rho: Float to evaluate whether or not to include integrated rho-on-grid loss, used as the loss weight, defaults to 0.0
:type dm_to_rho: float, optional
:return: The root-sum of squares loss
:rtype: jax.Array
"""
# create the function for calculating E to take derivative of for vxc
def vgf(x): return model(x, ao_eval, gw)
# predict the network-based DM, mo_e, and mo_coeff
dmp, moep, mocp = get_dm_moe(dm, eri, vgf, mo_occ, hc, s, ogd, alpha0)
# density matrix loss, RMSE
dmLv = jnp.sqrt(jnp.mean((dmp - dm)**2)) if dmL else 0
# holu loss if needed
if holuL:
homo_i = jnp.max(jnp.nonzero(mo_occ, size=dm.shape[0])[0])
holup = moep[homo_i+1] - moep[homo_i]
holuLv = jnp.sqrt(jnp.mean((holu - holup)**2))
else:
holuLv = 0
# rho on grid loss if needed
if dm_to_rho:
rho = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dm)+1e-10
rhop = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dmp)+1e-10
# integrate squared density difference on grid, sqrt then weight by number of electrons
rhoLv = jnp.sqrt(jnp.sum((rho-rhop)**2 * gw))/jnp.sum(mo_occ)
else:
rhoLv = 0
return jnp.sqrt((dmL*dmLv)**2 + (holuL*holuLv)**2 + (dm_to_rho*rhoLv)**2)
[docs]
class Band_gap_1shot_loss(eqx.Module):
[docs]
def __init__(self):
"""
Initializer for the loss module, which attempts to find loss band gaps w.r.t. reference
.. todo: Make more robust for non-local descriptors
"""
super().__init__()
[docs]
def __call__(self, model, ao_eval, gw, dm, eri, mo_occ, hc, s, ogd, refgap, mf, alpha0=0.7):
"""
Forward pass for loss object
NOTE: This differs from HoLu loss in that it selects the deepest minimum w.r.t. the LUMO (Fermi energy)
:param model: The model that will be used in generating the molecular orbital energies ('band' energies)
:type model: xcquinox.xc.eXC
:param ao_eval: The atomic orbitals evaluated on the grid for the given molecule
:type ao_eval: jax.Array
:param gw: The grid weights associated to the current molecule's grids
:type gw: jax.Array
:param dm: Input reference density matrix for use during the one-shot forward pass to generate the new DM
:type dm: jax.Array
:param eri: Electron repulsion integrals associated with this molecule
:type eri: jax.Array
:param mo_occ: The molecule's molecular orbital occupation numbers
:type mo_occ: jax.Array
:param hc: The molecule's core Hamiltonian
:type hc: jax.Array
:param s: The molecule's overlap matrix
:type s: jax.Array
:param ogd: The original dimensions of this molecule's density matrix, used if padded to constrict the eigendecomposition to a relevant shape
:type ogd: jax.Array
:param refgap: The reference gap to optimzie against
:type refgap: jax.Array
:param mf: A pyscf(ad) converged calculation kernel if self.level > 3, used for building the CIDER nonlocal descriptors, defaults to None
:type mf: pyscfad.dft.RKS kernel
:param alpha0: The mixing parameter for the one-shot density matrix generation, defaults to 0.7
:type alpha0: float, optional
:return: Root-squared error between predicted gap (minimum of molecular energies) and the reference
:rtype: jax.Array
"""
def vgf(x): return model(x, ao_eval, gw, mf)
dmp, moep, mocp = get_dm_moe(dm, eri, vgf, mo_occ, hc, s, ogd, alpha0)
efermi = moep[mf.mol.nelectron//2-1]
moep -= efermi
# print(moep)
moep_gap = jnp.min(moep)
# print(moep_gap)
loss = jnp.sqrt((moep_gap - refgap)**2)
# print(loss)
return jnp.sqrt((moep_gap - refgap)**2)
[docs]
class DM_Gap_loss(eqx.Module):
[docs]
def __init__(self):
'''
Initializer for the DM_Gap_loss, a semi-local loss calculator to optimize the one-shot density matrices and band gaps for
Gamma-Gamma direct transitions for PBC structures.
Band gap optimized is the HOMO-LUMO gap here, hence why specifically Gamma-Gamma
'''
super().__init__()
[docs]
def __call__(self, model, ao, hc, eri, s, gw, inp_dm, mo_occ, ogd, refDM, refGap):
'''
Forward pass to calculate the DM (sum of squared error) and gap (squared error) loss for a given molecule.
Individual molecule loss is jnp.sqrt( dmL + gapL ),
where dmL = jnp.sum( (dm_pred-dm_ref)**2 )
and gapL = (gap_pred - gap_ref)**2
:param model: The model to use in the predictions, here to generate DM and molecular energies
:type model: xcquinox.xc.eXC
:param ao: Atomic orbitals evaluated on a grid
:type ao: jax.Array
:param hc: Core Hamiltonian
:type hc: jax.Array
:param eri: Electron repulsion integrals
:type eri: jax.Array
:param s: Overlap matrices
:type s: jax.Array
:param gw: Weights for the grid being used
:type gw: jax.Array
:param inp_dm: Initial density matrix guesses, from mf.get_init_guess(), to be used in the one-shot DM generation to produce a mixed DM to optimize against reference
:type inp_dm: jax.Array
:param mo_occ: Molecular orbital occupations
:type mo_occ: jax.Array
:param ogd: The original dimensions of the density matrix
:type ogd: tuple
:param refDM: Reference density matricex from high-accuracy method (e.g., CCSD(T)).
:type refDM: jax.Array
:param refGap: Reference band gap (e.g. from the Borlido 2019 dataset).
:type refGap: jax.Array
:return: The molecule's loss
:rtype: jax.Array/scalar
'''
homo_i = jnp.max(jnp.nonzero(mo_occ, size=inp_dm.shape[0])[0])
# vxc function for gradient
def vgf(x): return model(x, ao, gw)
dmp, moep, mocoep = get_dm_moe(inp_dm, eri, vgf, mo_occ, hc, s, ogd)
dmp = pad_array(dmp, inp_dm)
moep = pad_array(moep, moep, shape=(dmp.shape[0],))
gap_pred = moep[homo_i+1]-moep[homo_i]
dm_L = jnp.sum((dmp-refDM)**2)
gap_L = (gap_pred-refGap)**2
return jnp.sqrt(dm_L+gap_L)
[docs]
class DM_Gap_Loop_loss(eqx.Module):
[docs]
def __init__(self):
'''
Initializer for the DM_Gap_Loop_loss, a semi-local loss calculator to optimize the one-shot density matrices and band gaps for
Gamma-Gamma direct transitions for PBC structures.
Band gap optimized is the HOMO-LUMO gap here, hence why specifically Gamma-Gamma
This is a BATCH-CUMULATIVE LOSS class -- it will loop over the inputs and accumulate each DM_Gap loss before the loss is returned and used for optimization
'''
super().__init__()
[docs]
def __call__(self, model, aos, hcs, eris, ss, gws, inp_dms, mo_occs, ogds, refDMs, refGaps):
'''
Forward pass to calculate the DM (sum of squared error) and gap (squared error) loss across a given dataset.
Individual molecule loss is jnp.sqrt( dmL + gapL ),
where dmL = jnp.sum( (dm_pred-dm_ref)**2 )
and gapL = (gap_pred - gap_ref)**2
:param model: The model to use in the predictions, here to generate DM and molecular energies
:type model: xcquinox.xc.eXC
:param aos: Atomic orbitals evaluated on a grid
:type aos: list of jax.Arrays
:param hcs: Core Hamiltonians
:type hcs: list of jax.Arrays
:param eris: Electron repulsion integrals
:type eris: list of jax.Arrays
:param ss: Overlap matrices
:type ss: list of jax.Arrays
:param gws: Weights for the grids being used
:type gws: list of jax.Arrays
:param inp_dms: Initial density matrix guesses, from mf.get_init_guess(), to be used in the one-shot DM generation to produce a mixed DM to optimize against reference
:type inp_dms: list of jax.Arrays
:param mo_occs: Molecular orbital occupations
:type mo_occs: list of jax.Arrays
:param ogds: The original dimensions of the density matrices
:type ogds: list of tuples
:param refDMs: List of reference density matrices from high-accuracy method (e.g., CCSD(T)).
:type refDMs: list of jax.Arrays
:param refGaps: List of reference band gaps (e.g. from the Borlido 2019 dataset).
:type refGaps: list of jax.Arrays
:return: The cumulative loss across the dataset
:rtype: jax.Array/scalar
'''
total_loss = 0
for idx in range(len(aos)):
# subselect the individual loss data
mo_occ = mo_occs[idx]
inp_dm = inp_dms[idx]
ao = aos[idx]
gw = gws[idx]
eri = eris[idx]
hc = hcs[idx]
s = ss[idx]
ogd = ogds[idx]
refGap = refGaps[idx]
refDM = refDMs[idx]
homo_i = jnp.max(jnp.nonzero(mo_occ, size=inp_dm.shape[0])[0])
# vxc function for gradient
def vgf(x): return model(x, ao, gw)
dmp, moep, mocoep = get_dm_moe(inp_dm, eri, vgf, mo_occ, hc, s, ogd)
dmp = pad_array(dmp, inp_dm)
moep = pad_array(moep, moep, shape=(dmp.shape[0],))
gap_pred = moep[homo_i+1]-moep[homo_i]
dm_L = jnp.sum((dmp-refDM)**2)
gap_L = (gap_pred-refGap)**2
total_loss += jnp.sqrt(dm_L+gap_L)
return total_loss
class E_PySCFAD_loss(eqx.Module):
def __init__(self):
'''
The standard energy loss module, RMSE loss of predicted vs. reference energies.
'''
super().__init__()
def __call__(self, model, mf, inp_dm, ref_en):
'''
Computes the energy loss for a given model and associated input density matrix, atomic orbitals on the grid, and grid weights
Loss is the RMSE energy, so predicted energy can potentially be a jax.Array of SCF guesses.
:param model: The XC object whose forward pass predicts the XC energy based on the inputs here.
:type model: xcquinox.xc.eXC
:param mf: A pyscf(ad) converged calculation kernel, whose eval_xc is overwritten to use the model calculation
:type mf: pyscfad.dft.RKS kernel
:param inp_dm: The density matrix to pass into the network for density creation on the grid.
:type inp_dm: jax.Array
:param ref_en: The reference energy to take the loss with respect to.
:type ref_en: jax.Array
:return: The RMSE error.
:rtype: jax.Array
'''
print('generating eval_xc function to overwrite')
evxc = generate_network_eval_xc(mf=mf, dm=inp_dm, network=model)
mf.define_xc_(evxc, xctype='MGGA')
print('predicting energy...')
e_pred = mf.kernel()
print('energy predicted')
eL = jnp.sqrt(jnp.mean((e_pred-ref_en)**2))
return eL