import numpy as np
import equinox as eqx
import jax
import os
import sys
import pickle
import jax.numpy as jnp
from xcquinox.net import get_net
from .utils import lda_x, pw92c_unpolarized
# check if cider is in the environment
# from mldftdat.density import get_exchange_descriptors2
# from mldftdat.analyzers import RKSAnalyzer
# =====================================================================
# =====================================================================
# HEG Classes
# =====================================================================
# =====================================================================
[docs]
class LDA_X(eqx.Module):
[docs]
def __init__(self):
"""
__init__ Constructs an object whose forward pass computes the LDA exchange energy based on a given density input.
"""
super().__init__()
[docs]
def __call__(self, rho):
"""
__call__ Computes the LDA exchange energy for a given value of the density.
.. math:: E_x = -\\frac{3}{4} \\Big(\\frac{3\\rho}{\\pi} \\Big)^{1/3}
:param rho: The value of the density
:type rho: float, broadcastable
:return: The LDA exchange energy for the input value(s).
:rtype: float, broadcastable
"""
return -3/4*(3/np.pi)**(1/3)*rho**(1/3)
params_a_pp = [1, 1, 1]
params_a_alpha1 = [0.21370, 0.20548, 0.11125]
params_a_a = [0.031091, 0.015545, 0.016887]
params_a_beta1 = [7.5957, 14.1189, 10.357]
params_a_beta2 = [3.5876, 6.1977, 3.6231]
params_a_beta3 = [1.6382, 3.3662, 0.88026]
params_a_beta4 = [0.49294, 0.62517, 0.49671]
params_a_fz20 = 1.709921
[docs]
class PW_C(eqx.Module):
[docs]
def __init__(self):
"""
__init__ Constructs an object whose forward pass computes the UEG correlation energy, as parameterized by Perdew & Wang
DOI: `10.1103/PhysRevB.45.13244`_.
.. _10.1103/PhysRevB.45.13244: https://doi.org/10.1103/PhysRevB.45.13244
"""
super().__init__()
[docs]
def __call__(self, rs, zeta):
"""
__call__ Forward pass, computing the correlation energy per electron for the given input values, rs and zeta, where
.. math:: r_s = \\Big[\\frac{3}{4\\pi (\\rho_\\uparrow+\\rho_\\downarrow)} \\Big]^{1/3}
.. math:: \\zeta = \\frac{\\rho_\\uparrow-\\rho_\\downarrow}{\\rho_\\uparrow+\\rho_\\downarrow}
Atomic units are assumed.
:param rs: The Wigner-Seiz radius corresponding to the given density value
:type rs: float, broadcastable
:param zeta: The spin-polarization
:type zeta: float, broadcastable
"""
def g_aux(k, rs):
return params_a_beta1[k]*jnp.sqrt(rs) + params_a_beta2[k]*rs\
+ params_a_beta3[k]*rs**1.5 + params_a_beta4[k]*rs**(params_a_pp[k] + 1)
def g(k, rs):
return -2*params_a_a[k]*(1 + params_a_alpha1[k]*rs)\
* jnp.log(1 + 1/(2*params_a_a[k]*g_aux(k, rs)))
def f_zeta(zeta):
return ((1+zeta)**(4/3) + (1-zeta)**(4/3) - 2)/(2**(4/3)-2)
def f_pw(rs, zeta):
return g(0, rs) + zeta**4*f_zeta(zeta)*(g(1, rs) - g(0, rs) + g(2, rs)/params_a_fz20)\
- f_zeta(zeta)*g(2, rs)/params_a_fz20
return f_pw(rs, zeta)
# =====================================================================
# =====================================================================
# GGA LEVEL CLASSES
# =====================================================================
# =====================================================================
[docs]
class RXCModel_GGA(eqx.Module):
xnet: eqx.Module
cnet: eqx.Module
[docs]
def __init__(self, xnet, cnet):
'''
Initializes a combined X+C GGA model into one object that can be used in optimizing a functional.
:param xnet: The exchange enhancement factor network
:type xnet: eqx.Module
:param cnet: The correlation enhancement factor network
:type cnet: eqx.Module
'''
self.xnet = xnet
self.cnet = cnet
[docs]
def __call__(self, inputs):
'''
Calls the networks to generate an EPSILON value (rho* xc_energy_density), which libxc expects derivatives of.
Divide by rho to get the energy density itself.
This call calculates: RHO*( ( lda_x(RHO)*Fx(inputs) )+(pw92c_unpolarized(RHO)*Fc(inputs) )
:param inputs: The expected inputs (here, [rho, sigma]) that will be vmapped against while calling the x/c enhancement factor networks.
:type inputs: array
:return: The epsilon values generated by the networks.
:rtype: array
'''
# this generate epsilon, not exc -- divide end result by rho when needed
rho = inputs[0]
sigma = inputs[1]
# print('RXCModel call - inputs {}'.format(inputs))
return rho*(lda_x(rho)*self.xnet(inputs[..., jnp.newaxis]) + pw92c_unpolarized(rho)*self.cnet(inputs[..., jnp.newaxis])).flatten()[0]
# return rho*(lda_x(rho)*self.xnet(inputs) + pw92c_unpolarized(rho)*self.cnet(inputs)).flatten()[0]
# =====================================================================
# =====================================================================
# Meta-GGA LEVEL CLASSES
# =====================================================================
# =====================================================================
[docs]
class RXCModel_MGGA(eqx.Module):
xnet: eqx.Module
cnet: eqx.Module
[docs]
def __init__(self, xnet, cnet):
'''
Initializes a combined X+C MGGA model into one object that can be used in optimizing a functional.
:param xnet: The exchange enhancement factor network
:type xnet: eqx.Module
:param cnet: The correlation enhancement factor network
:type cnet: eqx.Module
'''
self.xnet = xnet
self.cnet = cnet
[docs]
def __call__(self, inputs):
'''
Calls the networks to generate an EPSILON value (rho* xc_energy_density), which libxc expects derivatives of.
Divide by rho to get the energy density itself.
This call calculates: RHO*( ( lda_x(RHO)*Fx(inputs) )+(pw92c_unpolarized(RHO)*Fc(inputs) )
:param inputs: The expected inputs (here, [rho, sigma, lapl, tau]) that will be vmapped against while calling the x/c enhancement factor networks.
:type inputs: array
:return: The epsilon values generated by the networks.
:rtype: array
'''
# inputs expected -- [rho, sigma, lapl, tau]
# this generates epsilon, not exc -- divide end result by rho when needed
rho = inputs[0]
return rho*(lda_x(rho)*self.xnet(inputs[..., jnp.newaxis]) + pw92c_unpolarized(rho)*self.cnet(inputs[..., jnp.newaxis])).flatten()[0]
# =====================================================================
# =====================================================================
# DEPRECATED CLASSES -- TO BE REMOVED
# =====================================================================
# =====================================================================
[docs]
class eXC(eqx.Module):
grid_models: list
heg_mult: bool
pw_mult: bool
level: int
exx_a: jax.Array
epsilon: jax.Array
loge: jax.Array
s_gam: jax.Array
heg_model: eqx.Module
pw_model: eqx.Module
model_mult: list
debug: bool
nlstart_i: int
nlend_i: int
verbose: bool
[docs]
def __init__(self, grid_models=[], heg_mult=True, pw_mult=True,
level=1, exx_a=None, epsilon=1e-6, debug=False,
nlstart_i=3, nlend_i=12, verbose=False):
"""
__init__ Defines the XC functional
Constructed with two MLPs -- one for the local exchange energy on the grid, the other for the local correlation energy.
:param grid_models: list of eX (local exchange) or eC (local correlation). Defines the xc-models/enhancement factors, defaults to []
:type grid_models: list, optional
:param heg_mult: Use homoegeneous electron gas exchange (multiplicative if grid_models is not empty), defaults to True
:type heg_mult: bool, optional
:param pw_mult: Use homoegeneous electron gas correlation (Perdew & Wang), defaults to True
:type pw_mult: bool, optional
:param level: Controls the number of density "descriptors" generated. 1: LDA, 2: GGA, 3:meta-GGA, 4: meta-GGA + electrostatic (nonlocal), defaults to 1
:type level: int, optional
:param exx_a: Exact exchange mixing parameter, defaults to None
:type exx_a: float, optional
:param epsilon: Offset to avoid div/0 in calculations, defaults to 1e-8
:type epsilon: float, optional
:param debug: Controls printing of various stats throughout, defaults to False
:type debug: bool, optional
:param nlstart_i: If level > 3, this controls the number of CIDER Nonlocal parameters are selected, defaults to 3 as the first nonlocal CIDER descriptor
:type nlstart_i: int, optional
:param nlend_i: If level > 3, this controls the number of CIDER Nonlocal parameters are selected, defaults to 12 as the last nonlocal CIDER descriptor
:type nlend_i: int, optional
:param verbose: Flag to determine printing of a lot more extra information to the terminal, useful for debuggin, defaults to False
:type verbose: bool, optional
"""
super().__init__()
self.heg_mult = heg_mult
self.pw_mult = pw_mult
self.level = level
self.grid_models = grid_models
self.epsilon = epsilon
if level > 3:
print('WARNING: External module "mldftdat" required for non-local descriptor use.')
self.loge = 1e-5
self.s_gam = 1
self.debug = debug
self.nlstart_i = nlstart_i
self.nlend_i = nlend_i
self.verbose = verbose
if heg_mult:
self.heg_model = LDA_X()
if pw_mult:
self.pw_model = PW_C()
self.model_mult = [1 for m in self.grid_models]
if not exx_a:
self.exx_a = 0
else:
self.exx_a = exx_a
[docs]
def __call__(self, dm, ao_eval, grid_weights, mf=None, coor=None):
"""
__call__ Forward call for the XC network to get the grid point e_xc
Generates the density-on-grid from the density matrix, atomic orbital evaluation, and the grid weights from a :pyscfad: calculation.
:param dm: Density matrix
:type dm: jax.Array
:param ao_eval: Atomic orbitals evaluated on the grid
:type ao_eval: jax.Array
:param grid_weights: Grid weights associated to the grid on which the atomic orbitals are evaluated
:type grid_weights: jax.Array
:param mf: A pyscf(ad) converged calculation kernel if self.level > 3, used for building the CIDER nonlocal descriptors, defaults to None
:type mf: pyscfad.dft.RKS kernel
:param coor: Grid coordinates associated to the grid on which the atomic orbitals are evaluated, passed into eval_grid_models, defaults to None
:type coor: jax.Array
:return: Exc, exchange-correlation energy from integrating the network calls across the grid
:rtype: float
"""
Exc = 0
if self.level > 3:
assert mf is not None
if self.grid_models or self.heg_mult:
if ao_eval.ndim == 2:
ao_eval = jnp.expand_dims(ao_eval, 0)
else:
ao_eval = ao_eval
# Create density (and gradients) from atomic orbitals evaluated on grid
# and density matrix
# rho[ijsp]: del_i phi del_j phi dm (s: spin, p: grid point index)
rho = jnp.einsum('xij,yik,...jk->xy...i', ao_eval, ao_eval, dm)+1e-10
rho0 = rho[0, 0]
drho = rho[0, 1:4] + rho[1:4, 0]
tau = 0.5*(rho[1, 1] + rho[2, 2] + rho[3, 3])
non_loc = jnp.zeros_like(tau)
if dm.ndim == 3: # If unrestricted (open-shell) calculation
# Density
rho0_a = rho0[0]
rho0_b = rho0[1]
# jnp.einsumed density gradient
gamma_a, gamma_b = jnp.einsum(
'ij,ij->j', drho[:, 0], drho[:, 0]), jnp.einsum('ij,ij->j', drho[:, 1], drho[:, 1])
gamma_ab = jnp.einsum('ij,ij->j', drho[:, 0], drho[:, 1])
# Kinetic energy density
tau_a, tau_b = tau
# E.-static
non_loc_a, non_loc_b = non_loc
else:
rho0_a = rho0_b = rho0*0.5
gamma_a = gamma_b = gamma_ab = jnp.einsum('ij,ij->j', drho[:], drho[:])*0.25
tau_a = tau_b = tau*0.5
non_loc_a = non_loc_b = non_loc*0.5
# xc-energy per unit particle
exc = self.eval_grid_models(jnp.concatenate([jnp.expand_dims(rho0_a, -1),
jnp.expand_dims(rho0_b, -1),
jnp.expand_dims(gamma_a, -1),
jnp.expand_dims(gamma_ab, -1),
jnp.expand_dims(gamma_b, -1),
jnp.expand_dims(jnp.zeros_like(rho0_a), -
1), # Dummy for laplacian
jnp.expand_dims(jnp.zeros_like(rho0_a), -
1), # Dummy for laplacian
jnp.expand_dims(tau_a, -1),
jnp.expand_dims(tau_b, -1),
jnp.expand_dims(non_loc_a, -1),
jnp.expand_dims(non_loc_b, -1)], axis=-1),
mf=mf, dm=dm, ao=ao_eval, gw=grid_weights, coor=coor)
excnans = jnp.sum(jnp.isnan(exc[:, 0]))
self.vprint(f'NaNs in exc from self.eval_grid_models = {excnans}')
Exc += jnp.sum(((rho0_a + rho0_b)*exc[:, 0])*grid_weights)
return Exc
[docs]
def vprint(self, str):
"""
Prints if verbose tag is set on the object
:param str: The string to be printed
:type str: str
"""
if self.verbose:
print(str)
# Density (rho)
[docs]
def l_1(self, rho):
"""
l_1 Level 1 (LDA-level) Descriptor -- Creates dimensionless quantity from rho.
Equation 3 from the `base paper`_.
.. _base paper: https://link.aps.org/doi/10.1103/PhysRevB.104.L161109
.. math:: x_0 = \\rho^{1/3}
:param rho: density
:type rho: jax.Array
:return: Scaled density
:rtype: jax.Array
"""
self.vprint('l_1, descr shape: {}'.format(rho.shape))
descrnans = jnp.sum(jnp.isnan(rho))
self.vprint(f'NaNs in descr from self.l_1 = {descrnans}')
return rho**(1/3)
# Reduced density gradient s
[docs]
def l_2(self, rho, gamma):
"""
l_2 Level 2 (GGA-level) Descriptor -- Reduced gradient density
Equation 5 from the `base paper`_.
.. _base paper: https://link.aps.org/doi/10.1103/PhysRevB.104.L161109
.. math:: x_2=s=\\frac{1}{2(3\\pi^2)^{1/3}} \\frac{|\\nabla \\rho|}{\\rho^{4/3}}
:param rho: density
:type rho: jax.Array
:param gamma: squared density gradient
:type gamma: jax.Array
:return: reduced density gradient s
:rtype: jax.Array
"""
l2_1 = jnp.sqrt(gamma)
l2_2 = 2*(3*np.pi**2)**(1/3)*rho**(4/3)+self.epsilon
l2_nans = (jnp.sum(jnp.isnan(l2_1)), jnp.sum(jnp.isnan(l2_2)))
self.vprint(f'l_2 fraction nans: num = {l2_nans[0]}, denom = {l2_nans[1]}')
l2 = l2_1 / l2_2
self.vprint('l_2, descr shape: {}'.format(l2.shape))
descrnans = jnp.sum(jnp.isnan(l2))
self.vprint(f'NaNs in descr from self.l_2 = {descrnans}')
return l2
# Reduced kinetic energy density alpha
[docs]
def l_3(self, rho, gamma, tau):
"""
l_3 Level 3 (MGGA-level) Descriptor -- Reduced kinetic energy density
Equation 6 from the `base paper`_.
.. _base paper: https://link.aps.org/doi/10.1103/PhysRevB.104.L161109
.. math:: \\tau^W = \\frac{|\\nabla \\rho|^2}{8\\rho}, \\tau^{unif} = \\frac{3}{10} (3\\pi^2)^{2/3}\\rho^{5/3}.
:param rho: density
:type rho: jax.Array
:param gamma: squared density gradient
:type gamma: jax.Array
:param tau: kinetic energy density
:type tau: jax.Array
:return: reduced kinetic energy density
:rtype: jax.Array
"""
uniform_factor = (3/10)*(3*np.pi**2)**(2/3)
tw = gamma/(8*rho+self.epsilon)
l3 = (tau - tw)/(uniform_factor*rho**(5/3)+self.epsilon)
self.vprint('l_3, descr shape: {}'.format(l3.shape))
descrnans = jnp.sum(jnp.isnan(l3))
self.vprint(f'NaNs in descr from self.l_3 = {descrnans}')
return jnp.clip(l3, 0, None)
# Unit-less electrostatic potential
[docs]
def l_4(self, rho, nl):
"""
l_4 Level 4 (Non-local level) Descriptor -- Unitless electrostatic potential
.. todo:: document/implement in a more descriptive manner
:param rho: density
:type rho: jax.Array
:param nl: non-local values arising from density contractions
:type nl: jax.Array
:return: the non-local descriptors
:rtype: jax.Array
"""
u = nl[:, :1]/((jnp.expand_dims(rho, -1)**(1/3))*self.nl_ueg[:, :1] + self.epsilon)
wu = nl[:, 1:]/((jnp.expand_dims(rho, -1))*self.nl_ueg[:, 1:] + self.epsilon)
return jax.nn.relu(jnp.concatenate([u, wu], axis=-1))
[docs]
def nl_4(self, mf, dm, ao=None, gw=None, coor=None):
"""
Level 4 (non-local level) descriptor generator -- generates the CIDER nonlocal descriptors, given a density matrix and a converged kernel with associated mol and grids
Optional parameters not provided will use the values already associated with `mf.grids`; provide custom values for things like grid subsets.
:param mf: The converged calculation kernel, passed to RKSAnalyzer for descriptor generation
:type mf: pyscf(ad).dft.RKS kernel
:param dm: Density matrix to use in nonlocal desciptor generation
:type dm: jax.Array
:param ao: Atomic orbital values to use in nonlocal desciptor generation, defaults to None
:type ao: jax.Array
:param gw: Grid weights to use in nonlocal desciptor generation, defaults to None
:type gw: jax.Array
:param coor: Grid coordinates to use in nonlocal desciptor generation, defaults to None
:type coor: jax.Array
:return: The non-local CIDER descriptors, from self.nlstart_i to self.nlend_i
:rtype: jax.Array
"""
self.vprint('Constructing non-local CIDER descriptor generator')
self.vprint(f'Sending mf={mf} to RKSAnalyzer')
self.vprint(f'mf.e_tot={mf.e_tot}')
an = RKSAnalyzer(mf, idm=True, dm=dm,
coor=coor, weight=gw,
require_converged=False)
if len(dm.shape) == 2:
restric = True
elif len(dm.shape) == 3:
restric = False
def descr5_func(x): return jax.lax.stop_gradient(jnp.asarray(get_exchange_descriptors2(an, restricted=restric, version='c', auxbasis=mf.mol.basis,
rdm1=True, dm=x, inmol=True, mol=mf.mol, ingrid=True, grid=mf.grids, weights=gw, coords=coor)
))
descr5 = descr5_func(dm)
self.vprint('nl_4, descr5 shape: {}'.format(descr5.shape))
descrnans = jnp.sum(jnp.isnan(descr5))
self.vprint(f'NaNs in descr from self.l_1 = {descrnans}')
return descr5
# @eqx.filter_jit
[docs]
def get_descriptors(self, rho0_a, rho0_b, gamma_a, gamma_b, gamma_ab, tau_a, tau_b, spin_scaling=False,
mf=None, dm=None, ao=None, gw=None, coor=None):
"""
get_descriptors Creates 'ML-compatible' descriptors from the electron density and its gradients, a & b correspond to spin channels
:param rho0_a: :math:`\\rho` in spin-channel a
:type rho0_a: jax.Array
:param rho0_b: :math:`\\rho` in spin-channel b
:type rho0_b: jax.Array
:param gamma_a: :math:`|\\nabla \\rho|^2` in spin-channel b
:type gamma_a: jax.Array
:param gamma_b: :math:`|\\nabla \\rho|^2` in spin-channel b
:type gamma_b: jax.Array
:param gamma_ab: :math:`|\\nabla \\rho|^2`, contracted from both spin channels
:type gamma_ab: jax.Array
:param tau_a: KE density in spin-channel a
:type tau_a: jax.Array
:param tau_b: KE density in spin-channel b
:type tau_b: jax.Array
:param spin_scaling: Flag for spin-scaling, defaults to False
:type spin_scaling: bool, optional
:param mf: The converged calculation kernel, passed to RKSAnalyzer for descriptor generation if self.level > 3, defaults to None
:type mf: pyscf(ad).dft.RKS kernel
:param dm: Density matrix to use in nonlocal desciptor generation if self.level > 3, defaults to None
:type dm: jax.Array
:param ao: Atomic orbital values to use in nonlocal desciptor generation if self.level > 3, defaults to None
:type ao: jax.Array
:param gw: Grid weights to use in nonlocal desciptor generation if self.level > 3, defaults to None
:type gw: jax.Array
:param coor: Grid coordinates to use in nonlocal desciptor generation if self.level > 3, defaults to None
:type coor: jax.Array
:return: Array of the machine-learning descriptors on the grid
:rtype: jax.Array
"""
self.vprint(f'get_descriptors called. Input stats: min/max')
self.vprint(f'rho0_a.min/max: {jnp.min(rho0_a)}, {jnp.max(rho0_a)}')
self.vprint(f'rho0_b.min/max: {jnp.min(rho0_b)}, {jnp.max(rho0_b)}')
self.vprint(f'gamma_a.min/max: {jnp.min(gamma_a)}, {jnp.max(gamma_a)}')
self.vprint(f'gamma_b.min/max: {jnp.min(gamma_b)}, {jnp.max(gamma_b)}')
self.vprint(f'gamma_ab.min/max: {jnp.min(gamma_ab)}, {jnp.max(gamma_ab)}')
self.vprint(f'tau_a.min/max: {jnp.min(tau_a)}, {jnp.max(tau_a)}')
self.vprint(f'tau_b.min/max: {jnp.min(tau_b)}, {jnp.max(tau_b)}')
self.vprint(f'get_descriptors called. Input stats: NaNs')
self.vprint(f'rho0_a.min/max: {jnp.isnan(rho0_a).sum()}')
self.vprint(f'rho0_b.min/max: {jnp.isnan(rho0_b).sum()}')
self.vprint(f'gamma_a.min/max: {jnp.isnan(gamma_a).sum()}')
self.vprint(f'gamma_b.min/max: {jnp.isnan(gamma_b).sum()}')
self.vprint(f'gamma_ab.min/max: {jnp.isnan(gamma_ab).sum()}')
self.vprint(f'tau_a.min/max: {jnp.isnan(tau_a).sum()}')
self.vprint(f'tau_b.min/max: {jnp.isnan(tau_b).sum()}')
self.vprint(f'spin_scaling: {spin_scaling}')
if not spin_scaling:
# If no spin-scaling, calculate polarization and use for X1
zeta = (rho0_a - rho0_b)/(rho0_a + rho0_b + self.epsilon)
spinscale = 0.5*((1+zeta)**(4/3) + (1-zeta)**(4/3)) # zeta
self.vprint(f'not spin_scaling: zeta nans: {jnp.isnan(zeta).sum()}')
self.vprint(f'not spin_scaling: spinscale nans: {jnp.isnan(spinscale).sum()}')
if self.level > 0: # LDA
if spin_scaling:
descr1 = jnp.log(self.l_1(2*rho0_a) + self.loge)
descr2 = jnp.log(self.l_1(2*rho0_b) + self.loge)
else:
descr1 = jnp.log(self.l_1(rho0_a + rho0_b) + self.loge) # rho
descr2 = jnp.log(spinscale) # zeta
self.vprint(f'self.level > 0; descr1 Nans = {jnp.sum(jnp.isnan(descr1))}')
self.vprint(f'self.level > 0; descr2 Nans = {jnp.sum(jnp.isnan(descr2))}')
descr = jnp.concatenate([jnp.expand_dims(descr1, -1), jnp.expand_dims(descr2, -1)], axis=-1)
self.vprint(
f'get_descriptors -> self.level > 0\ndescr1.shape={descr1.shape}, descr2.shape={descr2.shape}, descr.shape={descr.shape}')
if self.level > 1: # GGA
if spin_scaling:
descr3a = self.l_2(2*rho0_a, 4*gamma_a) # s
descr3b = self.l_2(2*rho0_b, 4*gamma_b) # s
descr3 = jnp.concatenate([jnp.expand_dims(descr3a, -1), jnp.expand_dims(descr3b, -1)], axis=-1)
descr3 = (1-jnp.exp(-descr3**2/self.s_gam))*jnp.log(descr3 + 1)
self.vprint(f'self.level > 1; descr3a Nans = {jnp.sum(jnp.isnan(descr3a))}')
self.vprint(f'self.level > 1; descr3b Nans = {jnp.sum(jnp.isnan(descr3b))}')
self.vprint(
f'get_descriptors -> self.level > 1 and spin_scaling\ndescr3a.shape={descr3a.shape}, descr3b.shape={descr3b.shape}')
else:
descr3 = self.l_2(rho0_a + rho0_b, gamma_a + gamma_b + 2*gamma_ab) # s
descr3 = descr3/((1+zeta)**(2/3) + (1-zeta)**2/3)
descr3 = jnp.expand_dims(descr3, -1)
descr3 = (1-jnp.exp(-descr3**2/self.s_gam))*jnp.log(descr3 + 1)
self.vprint(f'self.level > 1; descr3 Nans = {jnp.sum(jnp.isnan(descr3))}')
descr = jnp.concatenate([descr, descr3], axis=-1)
self.vprint(f'get_descriptors -> self.level > 1\ndescr3.shape={descr3.shape}, descr.shape={descr.shape}')
if self.level > 2: # meta-GGA
if spin_scaling:
descr4a = self.l_3(2*rho0_a, 4*gamma_a, 2*tau_a)
descr4b = self.l_3(2*rho0_b, 4*gamma_b, 2*tau_b)
descr4 = jnp.concatenate([jnp.expand_dims(descr4a, -1), jnp.expand_dims(descr4b, -1)], axis=-1)
descr4 = descr4**3/(descr4**2+self.epsilon)
self.vprint(f'self.level > 2; descr4a Nans = {jnp.sum(jnp.isnan(descr4a))}')
self.vprint(f'self.level > 2; descr4b Nans = {jnp.sum(jnp.isnan(descr4b))}')
self.vprint(f'descr4a.min/max: {jnp.min(descr4a)}, {jnp.max(descr4a)}')
self.vprint(f'descr4b.min/max: {jnp.min(descr4b)}, {jnp.max(descr4b)}')
self.vprint(
f'get_descriptors -> self.level > 2 and spin_scaling\ndescr4a.shape={descr4a.shape}, descr4b.shape={descr4b.shape}')
else:
descr4 = self.l_3(rho0_a + rho0_b, gamma_a + gamma_b + 2*gamma_ab, tau_a + tau_b)
descr4 = 2*descr4/((1+zeta)**(5/3) + (1-zeta)**(5/3))
descr4 = descr4**3/(descr4**2+self.epsilon)
descr4 = jnp.expand_dims(descr4, -1)
self.vprint(f'self.level > 2; pre-log descr4 Nans = {jnp.sum(jnp.isnan(descr4))}')
self.vprint(f'descr4.min/max: {jnp.min(descr4)}, {jnp.max(descr4)}')
descr4 = jnp.log((descr4 + 1)/2)
self.vprint(f'self.level > 2; descr4 Nans = {jnp.sum(jnp.isnan(descr4))}')
descr = jnp.concatenate([descr, descr4], axis=-1)
self.vprint(f'get_descriptors -> self.level > 2\ndescr4.shape={descr4.shape}, descr.shape={descr.shape}')
if self.level > 3: # meta-GGA + V_estat
# if spin_scaling:
# descr5a = self.l_4(2*rho0_a, 2*nl_a)
# descr5b = self.l_4(2*rho0_b, 2*nl_b)
# descr5 = jnp.log(jnp.stack([descr5a, descr5b],axis=-1) + self.loge)
# descr5 = descr5.view(descr5.size()[0],-1)
# else:
# descr5= jnp.log(self.l_4(rho0_a + rho0_b, nl_a + nl_b) + self.loge)
# descr = jnp.concatenate([descr, descr5],axis=-1)
def convert_dm(dm):
res_shape = jax.ShapeDtypeStruct(dm.shape, dm.dtype)
return jax.lax.stop_gradient(jax.pure_callback(np.asarray, res_shape, dm))
dmnp = np.asarray(jax.lax.stop_gradient(dm))
descr5 = self.nl_4(mf, dmnp, ao=ao, gw=gw, coor=coor).T
self.vprint(f'get_descriptors -> self.level > 3 -> returned descr5 shape={descr5.shape}')
if len(descr5.shape) == 3:
# descr5 returned (spin, 12 descriptors, Ngrid)
# transpose makes (Ngrid, 12 descriptors, spin)
# want (spin, Ngrid, 12 descriptors)
descr5 = jnp.transpose(descr5, (2, 0, 1))
self.vprint(f'reshaped descr5.shape={descr5.shape}')
if spin_scaling:
self.vprint(f'spin_scaling and self.level > 3, descr5.shape={descr5.shape}')
if len(descr5.shape) == 2:
self.vprint('decomposing descriptors into spin channels, half each')
descr5 = jnp.reshape(jnp.concatenate([0.5*descr5, 0.5*descr5]), (2, descr5.shape[0], -1))
self.vprint(f'new descr5.shape={descr5.shape}')
else:
if len(descr5.shape) == 3:
self.vprint(
f'not spin_scaling and self.level > 3 but spin polarized NL descriptors, averaging descr5 spin channels')
descr5 = 0.5*(descr5[0] + descr5[1])
if spin_scaling:
descr = jnp.transpose(jnp.reshape(descr, (jnp.shape(descr)[0], -1, 2)), (2, 0, 1))
if self.level > 3:
descr = jnp.concatenate([descr, descr5], axis=-1)
self.vprint(f'get_descriptors -> self.level > 3, descr5.shape={descr5.shape}')
self.vprint(f'spin_scaling, get_descriptors -> reshaping -> descr.shape={descr.shape}')
else:
if self.level > 3:
descr = jnp.concatenate([descr, descr5], axis=-1)
self.vprint(f'get_descriptors not_spin_scaling -> self.level > 3 -> descr5.shape={descr5.shape}')
self.vprint(f'get_descriptors, not spin_scaling -> descr.shape={descr.shape}')
return descr
[docs]
def eval_grid_models(self, rho, mf=None, dm=None, ao=None,
gw=None, coor=None):
"""
eval_grid_models Evaluates all models stored in self.grid_models along with HEG exchange and correlation
:param rho: List/array with [rho0_a,rho0_b,gamma_a,gamma_ab,gamma_b, dummy for laplacian, dummy for laplacian, tau_a, tau_b, non_loc_a, non_loc_b]
Shape assumes, for instance, that rho0_a = rho[:, 0], etc.
:type rho: jax.Array
:param mf: The converged calculation kernel, passed to RKSAnalyzer for nonlocal descriptor generation
:type mf: pyscf(ad).dft.RKS kernel
:param dm: Density matrix to use in nonlocal desciptor generation
:type dm: jax.Array
:param ao: Atomic orbitals to use in non-local descriptor generation, defaults to None
:type ao: jax.Array
:param gw: Grid weights to use in non-local descr generation, defaults to None
:type gw: jax.Array
:param coor: Grid coordinates to use in non-local descr generation, defaults to None
:type coor: jax.Array
:return: The exchange-correlation energy density (on the grid)
:rtype: jax.Array
"""
Exc = 0
rho0_a = rho[:, 0]
rho0_b = rho[:, 1]
gamma_a = rho[:, 2]
gamma_ab = rho[:, 3]
gamma_b = rho[:, 4]
tau_a = rho[:, 7]
tau_b = rho[:, 8]
nl = rho[:, 9:]
nl_size = jnp.size(nl, -1)//2
nl_a = nl[:, :nl_size]
nl_b = nl[:, nl_size:]
C_F = 3/10*(3*np.pi**2)**(2/3)
rho0_a_ueg = rho0_a
rho0_b_ueg = rho0_b
if gw is not None:
try:
self.vprint(
f'custom gw and coor present in eval_grid_models; shapes: gw={gw.shape}, coor={coor.shape}')
except:
self.vprint(
f'custom gw and coor present in eval_grid_models but shape print error; shapes: gw={gw}, coor={coor}')
zeta = (rho0_a_ueg - rho0_b_ueg)/(rho0_a_ueg + rho0_b_ueg + 1e-8)
rs = (4*np.pi/3*(rho0_a_ueg+rho0_b_ueg + 1e-8))**(-1/3)
rs_a = (4*np.pi/3*(rho0_a_ueg + 1e-8))**(-1/3)
rs_b = (4*np.pi/3*(rho0_b_ueg + 1e-8))**(-1/3)
# initialize zero values for the ex/ec/grid values
exc_a = jnp.zeros_like(rho0_a)
exc_b = jnp.zeros_like(rho0_a)
exc_ab = jnp.zeros_like(rho0_a)
self.vprint('eval_grid_models initial nan summary:')
self.vprint('zeta, rs, rs_a, rs_b, exc_a, exc_b, exc_ab')
self.vprint('{}, {}, {}, {}, {}, {}, {}'.format(
jnp.sum((jnp.isnan(zeta))),
jnp.sum((jnp.isnan(rs))),
jnp.sum((jnp.isnan(rs_a))),
jnp.sum((jnp.isnan(rs_b))),
jnp.sum((jnp.isnan(exc_a))),
jnp.sum((jnp.isnan(exc_b))),
jnp.sum((jnp.isnan(exc_ab))),
))
descr_dict = {}
rho_tot = rho0_a + rho0_b
# spin scaling false descriptors; no NL inputs, as not used in get_descriptors
descr_dict[False] = self.get_descriptors(rho0_a, rho0_b, gamma_a, gamma_b,
gamma_ab, tau_a, tau_b, spin_scaling=False,
mf=mf, dm=dm, ao=ao, gw=gw, coor=coor)
# spin scaling true descriptors; no NL inputs, as not used in get_descriptors
descr_dict[True] = self.get_descriptors(rho0_a, rho0_b, gamma_a, gamma_b,
gamma_ab, tau_a, tau_b, spin_scaling=True,
mf=mf, dm=dm, ao=ao, gw=gw, coor=coor)
self.vprint(f'NaNs in descr_dict[False] = {jnp.sum(jnp.isnan(descr_dict[False]))}')
self.vprint(f'NaNs in descr_dict[True] = {jnp.sum(jnp.isnan(descr_dict[True]))}')
def else_test_fun(exc, exc_b):
if self.heg_mult:
exc_b += (1 + exc[1])*self.heg_model(2*rho0_b_ueg)*(1-self.exx_a)
else:
exc_b += exc[1]*(1-self.exx_a)
return exc_b
def if_not_test_fun(exc, exc_b):
exc_b += exc[0]*0
return exc_b
def gm_eval_func(grid_model, exc_a, exc_b, exc_ab):
if not grid_model.spin_scaling:
descr = descr_dict[False]
# print(f"spin_scaling = False; input descr to exc shape: {descr.shape}")
exc = grid_model(descr)
excnans = jnp.sum(jnp.isnan(exc))
self.vprint(10*'=')
self.vprint(f'exc.shape, descr with spin_scaling=False -> {exc.shape}')
self.vprint(f'NaNs in exc from gm_eval_func, spin_scaling=False -> = {excnans}')
self.vprint(10*'=')
if jnp.ndim(exc) == 2: # If using spin decomposition
pw_alpha = self.pw_model(rs_a, jnp.ones_like(rs_a))
pw_beta = self.pw_model(rs_b, jnp.ones_like(rs_b))
pw = self.pw_model(rs, zeta)
ec_alpha = (1 + exc[:, 0])*pw_alpha*rho0_a/(rho_tot+1e-8)
ec_beta = (1 + exc[:, 1])*pw_beta*rho0_b/(rho_tot+1e-8)
ec_mixed = (1 + exc[:, 2])*(pw*rho_tot - pw_alpha*rho0_a - pw_beta*rho0_b)/(rho_tot+1e-8)
exc_ab += ec_alpha + ec_beta + ec_mixed
else:
if self.pw_mult:
exc_ab += (1 + exc)*self.pw_model(rs, zeta)
else:
exc_ab += exc
else:
descr = descr_dict[True]
# print(f"spin_scaling = True; input descr to exc shape: {descr.shape}")
exc = grid_model(descr)
excnans = jnp.sum(jnp.isnan(exc))
self.vprint(10*'=')
self.vprint(f'exc.shape, descr with spin_scaling=True -> {exc.shape}')
self.vprint(f'NaNs in exc from gm_eval_func, spin_scaling=True -> = {excnans}')
self.vprint(10*'=')
if self.heg_mult:
exc_a += (1 + exc[0])*self.heg_model(2*rho0_a_ueg)*(1-self.exx_a)
else:
exc_a += exc[0]*(1-self.exx_a)
test = jnp.sum(jnp.abs(rho0_b))
exc_b = jax.lax.cond(test, else_test_fun, if_not_test_fun, exc, exc_b)
return (exc_a, exc_b, exc_ab)
if self.grid_models:
self.vprint('Grid models present; looping over separate networks to construct exc')
gm_range = jnp.arange(0, len(self.grid_models))
for gmidx, gm in enumerate(self.grid_models):
self.vprint('Evaluating gm: {}'.format(gm))
exc_a, exc_b, exc_ab = gm_eval_func(gm, exc_a, exc_b, exc_ab)
self.vprint('eval_grid_models gm_eval_func [{}] nan summary:'.format(gmidx))
self.vprint('exc_a, exc_b, exc_ab')
self.vprint('{}, {}, {}'.format(
jnp.sum((jnp.isnan(exc_a))),
jnp.sum((jnp.isnan(exc_b))),
jnp.sum((jnp.isnan(exc_ab))),
))
# exc_a, exc_b, exc_ab = gm_eval_func(self.grid_models[0], exc_a, exc_b, exc_ab)
# self.vprint('eval_grid_models gm_eval_func [0] nan summary:')
# self.vprint('exc_a, exc_b, exc_ab')
# self.vprint('{}, {}, {}'.format(
# jnp.sum((jnp.isnan(exc_a))),
# jnp.sum((jnp.isnan(exc_b))),
# jnp.sum((jnp.isnan(exc_ab))),
# ))
# exc_a, exc_b, exc_ab = gm_eval_func(self.grid_models[1], exc_a, exc_b, exc_ab)
# self.vprint('eval_grid_models gm_eval_func [1] nan summary:')
# self.vprint('exc_a, exc_b, exc_ab')
# self.vprint('{}, {}, {}'.format(
# jnp.sum((jnp.isnan(exc_a))),
# jnp.sum((jnp.isnan(exc_b))),
# jnp.sum((jnp.isnan(exc_ab))),
# ))
else:
if self.heg_mult:
exc_a = self.heg_model(2*rho0_a_ueg)
exc_b = self.heg_model(2*rho0_b_ueg)
if self.pw_mult:
exc_ab = self.pw_model(rs, zeta)
exc = exc_a * (rho0_a_ueg / (rho_tot + self.epsilon)) + exc_b*(rho0_b_ueg / (rho_tot + self.epsilon)) + exc_ab
self.vprint('eval_grid_models final nan summary:')
self.vprint('zeta, rs, rs_a, rs_b, exc_a, exc_b, exc_ab')
self.vprint('{}, {}, {}, {}, {}, {}, {}'.format(
jnp.sum((jnp.isnan(zeta))),
jnp.sum((jnp.isnan(rs))),
jnp.sum((jnp.isnan(rs_a))),
jnp.sum((jnp.isnan(rs_b))),
jnp.sum((jnp.isnan(exc_a))),
jnp.sum((jnp.isnan(exc_b))),
jnp.sum((jnp.isnan(exc_ab))),
))
self.vprint(f'returning exc shape, pre-expanded dims: {exc.shape}')
return jnp.expand_dims(exc, -1)
[docs]
def make_xcfunc(level, x_net_path, c_net_path, configfile='network.config',
xdsfile='xc.eqx', cdsfile='xc.eqx', savepath=None):
'''
Constructs the combined :xcquinox.xc.eXC: object from previously-created and separate :xcquinox.net.eX: and :xquinox.net.eC: objects
:param level: one of ['GGA', 'MGGA', 'NONLOCAL', 'NL'], indicating the desired rung of Jacob's Ladder. NONLOCAL = NL
:type level: str
:param x_net_path: Location of the saved exchange network. Must have a {configfile}.pkl parameter file within.
:type x_net_path: str
:param c_net_path: Location of the saved correlation network. Must have a {configfile}.pkl parameter file within.
:type c_net_path: str
:param configfile: Name for the configuration file, needed when reading in the network to re-generate the same structure, defaults to 'network.config'
:type configfile: str, optional
:param xdsfile: Name for the exchange network file, needed when reading in the network overwrite generated random weights, defaults to 'xc.eqx'
:type xdsfile: str, optional
:param cdsfile: Name for the correlation network file, needed when reading in the network overwrite generated random weights, defaults to 'xc.eqx'
:type cdsfile: str, optional
:param savepath: Location to save the generated network and associated config file, defaults to None
:type savepath: str, optional
:return: The resulting exchange-correlation functional.
:rtype: :xcquinox.xc.eXC:
'''
level_dict = {'GGA': 2, 'MGGA': 3, 'NONLOCAL': 4, 'NL': 4}
try:
with open(os.path.join(x_net_path, configfile+'.pkl'), 'rb') as f:
xparams = pickle.load(f)
with open(os.path.join(c_net_path, configfile+'.pkl'), 'rb') as f:
cparams = pickle.load(f)
except:
print('BOTH exchange and correlation networks require a network.config.pkl file to generate the XC functional object.')
raise
# create the network to generate the descriptors for saving
xnet, xparams = get_net(xorc='X', level=level, net_path=x_net_path)
cnet, cparams = get_net(xorc='C', level=level, net_path=c_net_path)
print('XNET spin scaling: {}'.format(xnet.spin_scaling))
print('CNET spin scaling: {}'.format(cnet.spin_scaling))
if xdsfile:
xnet = eqx.tree_deserialise_leaves(os.path.join(x_net_path, xdsfile), xnet)
if cdsfile:
cnet = eqx.tree_deserialise_leaves(os.path.join(c_net_path, cdsfile), cnet)
xc = eXC(grid_models=[xnet, cnet], heg_mult=True, level=level_dict[level.upper()])
if savepath:
try:
os.makedirs(savepath)
except Exception as e:
print(e)
print(f'Exception raised in creating {savepath}.')
with open(os.path.join(savepath, 'x'+configfile+'.pkl'), 'wb') as f:
pickle.dump(xparams, f)
with open(os.path.join(savepath, 'c'+configfile+'.pkl'), 'wb') as f:
pickle.dump(cparams, f)
eqx.tree_serialise_leaves(os.path.join(savepath, 'xc.eqx'), xc)
return xc
[docs]
def get_xcfunc(level, xc_net_path, configfile='network.config', xcdsfile='xc.eqx'):
'''
Retrieves an XC functional object based on configuration files in given directory.
:param level: one of ['GGA', 'MGGA', 'NONLOCAL', 'NL'], indicating the desired rung of Jacob's Ladder. NONLOCAL = NL
:type level: str
:param xc_net_path: Location of the saved functional. Must have BOTH a 'x'+{configfile}+'.pkl' and 'c'+{configfile}+'.pkl' parameter files within.
:type xc_net_path: str
:param configfile: Base name for the configuration files to be read in, defaults to 'network.config'
:type configfile: str, optional
:param xcdsfile: If present, the network weights will be overwritten by what's present in this file, defaults to 'xc.eqx'
:type xcdsfile: str, optional
:return: The loaded functional
:rtype: :xcquinox.xc.eXC:
'''
level_dict = {'GGA': 2, 'MGGA': 3, 'NONLOCAL': 4, 'NL': 4}
try:
with open(os.path.join(xc_net_path, 'x'+configfile+'.pkl'), 'rb') as f:
xparams = pickle.load(f)
with open(os.path.join(xc_net_path, 'c'+configfile+'.pkl'), 'rb') as f:
cparams = pickle.load(f)
except:
print('Error in opening separate exchange/correlation configuration files. Both must be present to re-create the network architecture.')
raise
# create the network to generate the descriptors for saving
xnet, xparams = get_net(xorc='X', level=level, net_path=xc_net_path, configfile='x'+configfile, netfile=None)
cnet, cparams = get_net(xorc='C', level=level, net_path=xc_net_path, configfile='c'+configfile, netfile=None)
print('XNET spin scaling: {}'.format(xnet.spin_scaling))
print('CNET spin scaling: {}'.format(cnet.spin_scaling))
xc = eXC(grid_models=[xnet, cnet], heg_mult=True, level=level_dict[level.upper()])
if xcdsfile:
print('Deserializing XC Functional over created object')
xc = eqx.tree_deserialise_leaves(os.path.join(xc_net_path, xcdsfile), xc)
return xc