import jax
import os
import pickle
import jax.numpy as jnp
import equinox as eqx
from warnings import warn
import json
from typing import Union
import numpy as np
# =====================================================================
# =====================================================================
# Lieb-Oxford Bound Enforcer
# =====================================================================
# =====================================================================
[docs]
class LOB(eqx.Module):
    """Lieb-Oxford bound enforcer.

    Squashes an unbounded value into the open interval ``(-1, limit - 1)`` with a
    shifted, scaled sigmoid. The shift ``log(limit - 1)`` is chosen so that an
    input of exactly 0 is mapped to exactly 0.
    """
    limit: float

    def __init__(self, limit: float = 1.804):
        """
        Utility function to squash output to the (-1, limit-1) interval.

        :param limit: The Lieb-Oxford bound value to impose, defaults to 1.804.
            It must be greater than 1, otherwise ``log(limit - 1)`` used in the
            mapping is not a finite real number.
        :type limit: float, optional
        """
        super().__init__()
        self.limit = limit

    def __call__(self, x):
        """
        Map the input into the bounded region.

        :param x: Value to map back into the bounded region.
        :type x: float
        :return: ``limit * sigmoid(x - log(limit - 1)) - 1``, which lies in (-1, limit-1).
        :rtype: float
        """
        # sigmoid(-log(limit - 1)) == 1 / limit, hence x == 0 maps to 0.
        shifted = x - jnp.log(self.limit - 1)
        return self.limit * jax.nn.sigmoid(shifted) - 1
# =====================================================================
# =====================================================================
# GGA LEVEL NETWORKS
# =====================================================================
# =====================================================================
# Base Fx/Fc networks:
# Define the neural network module for Fx
[docs]
class GGA_FxNet_s(eqx.Module):
    """S: Exchange enhancement factor for GGA.

    The input to the network is the reduced density gradient, s, and the output is the enhancement factor, Fx.
    """
    name: str
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    net: eqx.nn.MLP
    lobf: eqx.Module

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=1.804):
        """
        Constructor for the exchange enhancement factor object, for the GGA case.

        In a GGA XC functional, the relevant quantities are (rho, grad_rho). Here, the network's input size is
        hard-coded to 1 -- just the gradient information is passed to the network, to guarantee that the energy
        yielded from this multiplicative factor behaves correctly under uniform scaling of the electron density
        and obeys the spin-scaling relation.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed to initiate baseline weight values for the network
        :type seed: int
        :param lob_lim: The Lieb-Oxford bound to respect, defaults to 1.804
        :type lob_lim: float, optional
        """
        self.name = 'GGA_FxNet_s'
        self.depth = depth
        self.nodes = nodes
        self.seed = seed
        self.lob_lim = lob_lim
        # Only the gradient descriptor is fed to the network (in_size=1).
        self.net = eqx.nn.MLP(in_size=1,  # Input is ONLY gradient_descriptor
                              out_size=1,  # Output is Fx
                              depth=self.depth,
                              width_size=self.nodes,
                              activation=jax.nn.gelu,
                              key=jax.random.PRNGKey(self.seed))
        self.lobf = LOB(limit=lob_lim)

    def __call__(self, inputs):
        """
        The network's forward pass, resulting in the enhancement factor associated to the input gradient descriptor.

        *NOTE*: This forward pass is explicitly NOT vectorized -- it expects one grid point worth of data, the
        (rho, gradient_descriptor) values at that point. Any :jax.vmap: call is expected to be coded OUTSIDE of
        the network class.

        *NOTE*: The gradient descriptor is assumed to be the reduced density gradient, :s:. The network output is
        multiplied by tanh(s)**2 before being passed through the bound enforcer, so when s = 0 the result is
        Fx = 1 and e = Fx*e_heg = e_heg.

        :param inputs: One grid point of data in order (rho, gradient_descriptor). Only element 1 is read.
        :type inputs: tuple, list, array of size 2 in order (rho, gradient_descriptor)
        :return: The enhancement factor value
        :rtype: float
        """
        # Normalize to an array so that tuples/lists, as documented, support the
        # ``[1, jnp.newaxis]`` indexing below (a plain sequence would raise TypeError).
        inputs = jnp.asarray(inputs)
        # takes forever if the tanh input has an extended shape, i.e. (1,1) as opposed to scalar shape (1,)
        damping = jnp.tanh(inputs[1])**2
        net_out = self.net(inputs[1, jnp.newaxis]).squeeze()
        return 1 + self.lobf(damping * net_out)
# Define the neural network module for Fc
[docs]
class GGA_FcNet_s(eqx.Module):
    """S: Correlation enhancement factor for GGA.

    The input to the network is the pair (rho, s), with s the reduced density gradient, and the output is the
    enhancement factor, Fc.
    """
    name: str
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    net: eqx.nn.MLP
    lobf: eqx.Module

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=2.0):
        """
        Constructor for the correlation enhancement factor object, for the GGA case.

        In a GGA XC functional, the relevant quantities are (rho, grad_rho). Here, the network's input size is
        hard-coded to 2 -- both the density and gradient information is passed to the network.
        The default Lieb-Oxford bound the outputs are wrapped with here is set to 2.0, to enforce the
        non-negativity of the correlation energy.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed to initiate baseline weight values for the network
        :type seed: int
        :param lob_lim: The Lieb-Oxford bound to respect, defaults to 2
        :type lob_lim: float, optional
        """
        self.name = 'GGA_FcNet_s'
        self.depth = depth
        self.nodes = nodes
        self.seed = seed
        self.lob_lim = lob_lim
        self.net = eqx.nn.MLP(in_size=2,  # Input is rho, gradient_descriptor
                              out_size=1,  # Output is Fc
                              depth=self.depth,
                              width_size=self.nodes,
                              activation=jax.nn.gelu,
                              key=jax.random.PRNGKey(self.seed))
        self.lobf = LOB(limit=lob_lim)

    def __call__(self, inputs):
        """
        The network's forward pass, resulting in the enhancement factor associated to the input descriptors.

        *NOTE*: This forward pass is explicitly NOT vectorized -- it expects one grid point worth of data, the
        (rho, gradient_descriptor) values at that point. Any :jax.vmap: call is expected to be coded OUTSIDE of
        the network class.

        *NOTE*: The gradient descriptor is assumed to be the reduced density gradient, :s:. The network output is
        multiplied by tanh(s)**2 before being passed through the bound enforcer, so when s = 0 the result is
        Fc = 1 and e = Fc*e_heg = e_heg.

        :param inputs: One grid point of data in order (rho, gradient_descriptor); both are fed to the network.
        :type inputs: tuple, list, array of size 2 in order (rho, gradient_descriptor)
        :return: The enhancement factor value
        :rtype: float
        """
        # Normalize to an array so that tuples/lists, as documented, can be fed to the MLP.
        inputs = jnp.asarray(inputs)
        # takes forever if the tanh input has an extended shape, i.e. (1,1) as opposed to scalar shape (1,)
        damping = jnp.tanh(inputs[1])**2
        net_out = self.net(inputs).squeeze()
        return 1 + self.lobf(damping * net_out)
[docs]
class GGA_FxNet_G(eqx.Module):
    """S: Exchange enhancement factor for GGA.

    It takes rho and grad_rho as input and outputs the exchange enhancement factor, Fx.
    The input is first converted to the reduced density gradient, s, which is the only quantity
    passed through the network.
    """
    name: str
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    net: eqx.nn.MLP
    lobf: eqx.Module

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=1.804):
        """
        Constructor for the exchange enhancement factor object taking (rho, grad_rho) inputs.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed to initiate baseline weight values for the network
        :type seed: int
        :param lob_lim: The Lieb-Oxford bound to respect, defaults to 1.804
        :type lob_lim: float, optional
        """
        self.name = 'GGA_FxNet_G'
        self.depth = depth
        self.nodes = nodes
        self.seed = seed
        self.lob_lim = lob_lim
        # The network only ever sees s, hence a single input.
        mlp = eqx.nn.MLP(in_size=1,  # Input is ONLY s
                         out_size=1,  # Output is Fx
                         depth=depth,
                         width_size=nodes,
                         activation=jax.nn.gelu,
                         key=jax.random.PRNGKey(seed))
        self.net = mlp
        self.lobf = LOB(limit=lob_lim)

    def __call__(self, inputs):
        """
        Forward pass: build s from the inputs and return the bounded enhancement factor.

        The density is used as given; no lower cutoff is applied to it before dividing.

        :param inputs: Indexable with ``inputs[0]`` = rho and ``inputs[1]`` = the quantity divided by
            ``2 * k_F * rho`` to form s (presumably the gradient magnitude -- confirm against callers).
            Both elements must provide ``.flatten()``.
        :type inputs: array
        :return: The enhancement factor value, ``1 + LOB(tanh(s)**2 * net(s))``
        :rtype: float
        """
        density = inputs[0].flatten()
        fermi_k = (3 * jnp.pi**2 * density)**(1/3)
        s = (inputs[1].flatten() / (2 * fermi_k * density)).flatten()
        # tanh(s)**2 vanishes at s = 0, which pins the factor to 1 there.
        bounded = self.lobf(jnp.tanh(s)**2 * self.net(s))
        return 1 + bounded.squeeze()
[docs]
class GGA_FcNet_G(eqx.Module):
    """S: Correlation enhancement factor for GGA.

    It takes rho and grad_rho as input, converts them to (rho, s) with s the reduced density
    gradient, and passes that pair through the network to produce Fc.
    """
    name: str
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    net: eqx.nn.MLP
    lobf: eqx.Module

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=2.0):
        """
        Constructor for the correlation enhancement factor object, for the GGA case.

        In a GGA XC functional, the relevant quantities are (rho, grad_rho). Here, the network's input size is
        hard-coded to 2 -- both the density and gradient information is passed to the network.
        The default Lieb-Oxford bound the outputs are wrapped with here is set to 2.0, to enforce the
        non-negativity of the correlation energy.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed to initiate baseline weight values for the network
        :type seed: int
        :param lob_lim: The Lieb-Oxford bound to respect, defaults to 2
        :type lob_lim: float, optional
        """
        self.name = 'GGA_FcNet_G'
        self.depth = depth
        self.nodes = nodes
        self.seed = seed
        self.lob_lim = lob_lim
        mlp = eqx.nn.MLP(in_size=2,  # Input is rho, s
                         out_size=1,  # Output is Fc
                         depth=depth,
                         width_size=nodes,
                         activation=jax.nn.gelu,
                         key=jax.random.PRNGKey(seed))
        self.net = mlp
        self.lobf = LOB(limit=lob_lim)

    def __call__(self, inputs):
        """
        Forward pass: build (rho, s) from the inputs and return the bounded enhancement factor.

        :param inputs: Indexable with ``inputs[0]`` = rho and ``inputs[1]`` = the quantity divided by
            ``2 * k_F * rho`` to form s (presumably the gradient magnitude -- confirm against callers).
            Both elements must provide ``.flatten()``.
        :type inputs: array
        :return: The enhancement factor value, ``1 + LOB(tanh(s)**2 * net([rho, s]))``
        :rtype: float
        """
        density = inputs[0].flatten()
        fermi_k = (3 * jnp.pi**2 * density)**(1/3)
        s = (inputs[1].flatten() / (2 * fermi_k * density)).flatten()
        features = jnp.stack([density, s], axis=0).flatten()
        # tanh(s)**2 vanishes at s = 0, which pins the factor to 1 there.
        bounded = self.lobf(jnp.tanh(s)**2 * self.net(features))
        return 1 + bounded.squeeze()
# Define the neural network module for Fx
[docs]
class GGA_FxNet_sigma(eqx.Module):
    """S: Exchange enhancement factor for GGA.

    It takes rho and sigma (grad_rho²) as input and outputs the exchange enhancement factor, Fx.
    The input is converted to the reduced density gradient, s, which is then passed through the network.
    """
    name: str
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    lower_rho_cutoff: float
    net: eqx.nn.MLP
    lobf: eqx.Module

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=1.804, lower_rho_cutoff=1e-12):
        """
        Constructor for the exchange enhancement factor object, for the GGA case.

        In a GGA XC functional, the relevant quantities are (rho, grad_rho). Here, the network's input size is
        hard-coded to 1 -- just the gradient information is passed to the network, to guarantee that the energy
        yielded from this multiplicative factor behaves correctly under uniform scaling of the electron density
        and obeys the spin-scaling relation.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed to initiate baseline weight values for the network
        :type seed: int
        :param lob_lim: The Lieb-Oxford bound to respect, defaults to 1.804
        :type lob_lim: float, optional
        :param lower_rho_cutoff: a cut-off intended to bypass division by zero in the division by rho,
            defaults to 1e-12. It is stored on the object; the forward pass of this class does not apply it.
        :type lower_rho_cutoff: float, optional
        """
        self.name = 'GGA_FxNet_sigma'
        self.depth = depth
        self.nodes = nodes
        self.seed = seed
        self.lob_lim = lob_lim
        self.lower_rho_cutoff = lower_rho_cutoff
        # The network only ever sees s, hence a single input.
        mlp = eqx.nn.MLP(in_size=1,  # Input is ONLY gradient_descriptor
                         out_size=1,  # Output is Fx
                         depth=depth,
                         width_size=nodes,
                         activation=jax.nn.gelu,
                         key=jax.random.PRNGKey(seed))
        self.net = mlp
        self.lobf = LOB(limit=lob_lim)

    def __call__(self, inputs):
        """
        The network's forward pass, resulting in the enhancement factor associated to the input gradient descriptor.

        *NOTE*: This forward pass is explicitly NOT vectorized -- it expects one grid point worth of data, the
        (rho, sigma) values at that point. Any :jax.vmap: call is expected to be coded OUTSIDE of the network class.

        *NOTE*: The gradient descriptor is assumed to be Libxc's/PySCF's internal variable for the density
        gradient -- sigma. Inside the call, sigma is translated to the reduced density gradient, :s:, which is
        what the network is parameterized by. The network output is multiplied by tanh(s)**2 before the bound
        enforcer, so when s = 0, Fx = 1 and e = Fx*e_heg = e_heg.

        *NOTE*: rho and sigma are used exactly as given; no clamping with ``lower_rho_cutoff`` is done here.

        :param inputs: One grid point of data in order (rho, sigma)
        :type inputs: tuple, list, array of size 2 in order (rho, sigma)
        :return: The enhancement factor value
        :rtype: float
        """
        rho, sigma = inputs[0], inputs[1]
        fermi_k = (3 * jnp.pi**2 * rho)**(1/3)
        # s must be one-dimensional before it is handed to the MLP.
        s = (jnp.sqrt(sigma) / (2 * fermi_k * rho)).flatten()
        bounded = self.lobf(jnp.tanh(s)**2 * self.net(s))
        return 1 + bounded.squeeze()
# Define the neural network module for Fc
[docs]
class GGA_FcNet_sigma(eqx.Module):
    """S: Correlation enhancement factor for GGA.

    It takes rho and sigma (grad_rho²) as input and outputs the correlation enhancement factor, Fc.
    The input is converted to (rho, s), with s the reduced density gradient, and then passed through the network.
    """
    name: str
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    lower_rho_cutoff: float
    net: eqx.nn.MLP
    lobf: eqx.Module

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=2.0, lower_rho_cutoff=1e-12):
        """
        Constructor for the correlation enhancement factor object, for the GGA case.

        In a GGA XC functional, the relevant quantities are (rho, grad_rho). Here, the network's input size is
        hard-coded to 2 -- both the density and gradient information is passed to the network.
        The default Lieb-Oxford bound the outputs are wrapped with here is set to 2.0, to enforce the
        non-negativity of the correlation energy.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed to initiate baseline weight values for the network
        :type seed: int
        :param lob_lim: The Lieb-Oxford bound to respect, defaults to 2
        :type lob_lim: float, optional
        :param lower_rho_cutoff: lower bound applied to both rho and sigma in the forward pass to bypass
            potential division by zero, defaults to 1e-12
        :type lower_rho_cutoff: float, optional
        """
        self.name = 'GGA_FcNet_sigma'
        self.depth = depth
        self.nodes = nodes
        self.seed = seed
        self.lob_lim = lob_lim
        self.lower_rho_cutoff = lower_rho_cutoff
        mlp = eqx.nn.MLP(in_size=2,  # Input is rho, gradient_descriptor
                         out_size=1,  # Output is Fc
                         depth=depth,
                         width_size=nodes,
                         activation=jax.nn.gelu,
                         key=jax.random.PRNGKey(seed))
        self.net = mlp
        self.lobf = LOB(limit=lob_lim)

    def __call__(self, inputs):
        """
        The network's forward pass, resulting in the enhancement factor associated to the input descriptors.

        *NOTE*: This forward pass is explicitly NOT vectorized -- it expects one grid point worth of data, the
        (rho, sigma) values at that point. Any :jax.vmap: call is expected to be coded OUTSIDE of the network class.

        *NOTE*: The gradient descriptor is assumed to be Libxc's/PySCF's internal variable for the density
        gradient -- sigma. Inside the call, sigma is translated to the reduced density gradient, :s:, and the
        network is evaluated on (rho, s). The network output is multiplied by tanh(s)**2 before the bound
        enforcer, so a vanishing s yields Fc = 1.

        :param inputs: One grid point of data in order (rho, sigma)
        :type inputs: tuple, list, array of size 2 in order (rho, sigma)
        :return: The enhancement factor value
        :rtype: float
        """
        floor = self.lower_rho_cutoff
        # Both rho and sigma are clamped from below by the same cutoff.
        rho = jnp.maximum(floor, inputs[0]).flatten()  # Prevents division by 0
        sigma = jnp.maximum(floor, inputs[1]).flatten()
        fermi_k = (3 * jnp.pi**2 * rho)**(1/3)
        s = (jnp.sqrt(sigma) / (2 * fermi_k * rho)).flatten()
        features = jnp.stack([rho, s], axis=0).flatten()
        bounded = self.lobf(jnp.tanh(s)**2 * self.net(features))
        return 1 + bounded.squeeze()
# Saving models
[docs]
def save_xcquinox_model(model, path: str = '', fixing: Union[str, None] = None,
                        tail_info: Union[str, None] = None, loss: Union[list[float], None] = None):
    """Save our NN model to a file.

    Three files may be written, all sharing the base name
    ``{name}_d{depth}_n{nodes}_s{seed}[_{fixing}][_{tail_info}]``:
    ``.eqx`` (serialised model leaves), ``.json`` (depth, nodes, seed and name, as needed to
    rebuild the model skeleton on load) and, if ``loss`` is given, ``_loss.txt``.

    :param model: The model to save. It must expose ``name``, ``depth``, ``nodes`` and ``seed`` attributes.
    :type model: eqx.Module.
    :param path: The directory to save the model to, defaults to the current working directory.
    :type path: str, optional.
    :param fixing: A string to append to the model name, defaults to None. Useful to determine the type of fixing used in the model.
    :type fixing: Union[str, None], optional.
    :param tail_info: A string to append to the model name, defaults to None. Useful to determine any additional information about the model.
    :type tail_info: Union[str, None], optional.
    :param loss: A list of loss values to save, defaults to None. Useful to determine the loss values during training. Will be saved in a separate file.
    :type loss: Union[list[float], None], optional.
    """
    fixing_tag = '' if fixing is None else f'_{fixing}'
    tail_tag = '' if tail_info is None else f'_{tail_info}'
    save_name = (f'{model.name}_d{model.depth}_n{model.nodes}_s{model.seed}'
                 f'{fixing_tag}{tail_tag}')
    # os.path.join ignores an empty ``path``, so the default writes relative to the
    # current directory; plain f'{path}/...' would have pointed at the filesystem root.
    base = os.path.join(path, save_name)
    metadata = {'depth': model.depth, 'nodes': model.nodes,
                'seed': model.seed, 'name': model.name}
    eqx.tree_serialise_leaves(f'{base}.eqx', model)
    with open(f'{base}.json', 'w') as f:
        json.dump(metadata, f)
    print(f'Saved {base}.eqx')
    if loss is not None:
        with open(f'{base}_loss.txt', 'w') as f:
            np.savetxt(f, loss)
        print(f'Saved the loss values in {base}_loss.txt')
[docs]
def load_xcquinox_model(path: str):
    """Load a model from a file.

    Note that we must give the path where the model is stored, without the extension.
    I.e, in path, we should have the files path.eqx and path.json.

    The JSON file supplies the class name plus the ``depth``, ``nodes`` and ``seed`` used to build a
    skeleton model (all other constructor arguments take their defaults); the ``.eqx`` file then
    fills in its leaves. Calling this function enables 64-bit mode in JAX globally.

    :param path: Path to the saved model, without extension.
    :type path: str
    :raises ValueError: If the stored model name is not one of the known classes.
    :return: The deserialised model.
    :rtype: eqx.Module
    """
    jax.config.update("jax_enable_x64", True)  # Ensure 64-bit is enabled first
    with open(f"{path}.json", "r") as f:
        metadata = json.load(f)

    # Model selection: stored name -> class used to build the skeleton.
    name = metadata['name']
    registry = {
        'GGA_FxNet_s': GGA_FxNet_s,
        'GGA_FcNet_s': GGA_FcNet_s,
        'GGA_FxNet_G': GGA_FxNet_G,
        'GGA_FcNet_G': GGA_FcNet_G,
        'GGA_FxNet_sigma': GGA_FxNet_sigma,
        'GGA_FcNet_sigma': GGA_FcNet_sigma,
        'MGGA_FxNet_sigma': MGGA_FxNet_sigma,
        'MGGA_FcNet_sigma': MGGA_FcNet_sigma,
        'MGGA_FxNet_sigma_transform': MGGA_FxNet_sigma_transform,
        'MGGA_FcNet_sigma_transform': MGGA_FcNet_sigma_transform
    }
    if name not in registry:
        raise ValueError(f"Model {name} not recognized. Please check the model name.")
    model_cls = registry[name]

    skeleton = model_cls(depth=metadata["depth"],
                         nodes=metadata["nodes"],
                         seed=metadata["seed"])
    # Load the saved leaves into the skeleton structure
    model = eqx.tree_deserialise_leaves(f"{path}.eqx", like=skeleton)
    print(f'Loaded {path}.eqx')
    return model
# unconstrained networks, for testing purposes
[docs]
class GGA_FxNet_sigma_UNC(eqx.Module):
    """Unconstrained exchange network for GGA, for testing purposes.

    It takes rho and sigma as input, converts them to (rho, s) and returns the raw network output:
    no bound enforcer and no tanh(s)**2 factor are applied.
    """
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    lower_rho_cutoff: float
    net: eqx.nn.MLP

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=1.804, lower_rho_cutoff=1e-12):
        """
        Constructor for the unconstrained exchange network, for the GGA case.

        In a GGA XC functional, the relevant quantities are (rho, grad_rho). Here, the network's input size is
        hard-coded to 2 -- both the density and the gradient descriptor are passed to the network.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed to initiate baseline weight values for the network
        :type seed: int
        :param lob_lim: Lieb-Oxford bound value, defaults to 1.804. It is stored on the object; this class
            does not use it in the forward pass.
        :type lob_lim: float, optional
        :param lower_rho_cutoff: lower bound applied to both rho and sigma in the forward pass to bypass
            potential division by zero, defaults to 1e-12
        :type lower_rho_cutoff: float, optional
        """
        self.depth = depth
        self.nodes = nodes
        self.seed = seed
        self.lob_lim = lob_lim
        self.lower_rho_cutoff = lower_rho_cutoff
        mlp = eqx.nn.MLP(in_size=2,  # Input is rho, gradient_descriptor
                         out_size=1,  # Output is Fx
                         depth=depth,
                         width_size=nodes,
                         activation=jax.nn.gelu,
                         key=jax.random.PRNGKey(seed))
        self.net = mlp

    def __call__(self, inputs):
        """
        The network's forward pass: the raw, unconstrained network output for one grid point.

        *NOTE*: This forward pass is explicitly NOT vectorized -- it expects one grid point worth of data, the
        (rho, sigma) values at that point. Any :jax.vmap: call is expected to be coded OUTSIDE of the network class.

        *NOTE*: The gradient descriptor is assumed to be Libxc's/PySCF's internal variable for the density
        gradient -- sigma. Inside the call, sigma is translated to the reduced density gradient, :s:, and the
        network is evaluated on (rho, s). Nothing here forces the output to 1 when s = 0.

        :param inputs: One grid point of data in order (rho, sigma)
        :type inputs: tuple, list, array of size 2 in order (rho, sigma)
        :return: The squeezed network output
        :rtype: float
        """
        floor = self.lower_rho_cutoff
        # Both rho and sigma are clamped from below by the same cutoff.
        rho = jnp.maximum(floor, inputs[0]).flatten()  # Prevents division by 0
        sigma = jnp.maximum(floor, inputs[1]).flatten()
        fermi_k = (3 * jnp.pi**2 * rho)**(1/3)
        s = (jnp.sqrt(sigma) / (2 * fermi_k * rho)).flatten()
        features = jnp.stack([rho, s], axis=0).flatten()
        return self.net(features).squeeze()
# Define the neural network module for Fc
[docs]
class GGA_FcNet_sigma_UNC(eqx.Module):
    """Unconstrained correlation network for GGA, for testing purposes.

    It takes rho and sigma as input, converts them to (rho, s) and returns the raw network output:
    no bound enforcer and no tanh(s)**2 factor are applied.
    """
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    lower_rho_cutoff: float
    net: eqx.nn.MLP

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=2.0, lower_rho_cutoff=1e-12):
        """
        Constructor for the unconstrained correlation network, for the GGA case.

        In a GGA XC functional, the relevant quantities are (rho, grad_rho). Here, the network's input size is
        hard-coded to 2 -- both the density and gradient information is passed to the network.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed to initiate baseline weight values for the network
        :type seed: int
        :param lob_lim: Lieb-Oxford bound value, defaults to 2. It is stored on the object; this class
            does not use it in the forward pass.
        :type lob_lim: float, optional
        :param lower_rho_cutoff: lower bound applied to both rho and sigma in the forward pass to bypass
            potential division by zero, defaults to 1e-12
        :type lower_rho_cutoff: float, optional
        """
        self.depth = depth
        self.nodes = nodes
        self.seed = seed
        self.lob_lim = lob_lim
        self.lower_rho_cutoff = lower_rho_cutoff
        mlp = eqx.nn.MLP(in_size=2,  # Input is rho, gradient_descriptor
                         out_size=1,  # Output is Fc
                         depth=depth,
                         width_size=nodes,
                         activation=jax.nn.gelu,
                         key=jax.random.PRNGKey(seed))
        self.net = mlp

    def __call__(self, inputs):
        """
        The network's forward pass: the raw, unconstrained network output for one grid point.

        *NOTE*: This forward pass is explicitly NOT vectorized -- it expects one grid point worth of data, the
        (rho, sigma) values at that point. Any :jax.vmap: call is expected to be coded OUTSIDE of the network class.

        *NOTE*: The gradient descriptor is assumed to be Libxc's/PySCF's internal variable for the density
        gradient -- sigma. Inside the call, sigma is translated to the reduced density gradient, :s:, and the
        network is evaluated on (rho, s). Nothing here forces the output to 1 when s = 0.

        :param inputs: One grid point of data in order (rho, sigma)
        :type inputs: tuple, list, array of size 2 in order (rho, sigma)
        :return: The squeezed network output
        :rtype: float
        """
        floor = self.lower_rho_cutoff
        # Both rho and sigma are clamped from below by the same cutoff.
        rho = jnp.maximum(floor, inputs[0]).flatten()  # Prevents division by 0
        sigma = jnp.maximum(floor, inputs[1]).flatten()
        fermi_k = (3 * jnp.pi**2 * rho)**(1/3)
        s = (jnp.sqrt(sigma) / (2 * fermi_k * rho)).flatten()
        features = jnp.stack([rho, s], axis=0).flatten()
        return self.net(features).squeeze()
# =====================================================================
# =====================================================================
# Meta-GGA LEVEL NETWORKS
# =====================================================================
# =====================================================================
[docs]
class MGGA_FxNet_sigma(eqx.Module):
    """Exchange enhancement factor for meta-GGA.

    It takes rho, sigma and tau as input, converts them to the reduced density gradient, s, and to
    alpha, and passes (s, alpha) through the network to produce Fx.
    """
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    lower_rho_cutoff: float
    net: eqx.nn.MLP
    lobf: eqx.Module
    name: str

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=1.174, lower_rho_cutoff=1e-12):
        """
        Constructor for the exchange enhancement factor object, for the MGGA case.

        In a MGGA XC functional, the relevant quantities are (rho, grad_rho, laplacian_rho, tau=kinetic energy
        density). Here, the network's input size is hard-coded to 2 -- just the gradient and alpha (related to
        tau) information is passed to the network, to guarantee that the energy yielded from this multiplicative
        factor behaves correctly under uniform scaling of the electron density and obeys the spin-scaling relation.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed to initiate baseline weight values for the network
        :type seed: int
        :param lob_lim: The Lieb-Oxford bound to respect, defaults to 1.174
        :type lob_lim: float, optional
        :param lower_rho_cutoff: a cut-off intended to bypass division by zero in the division by rho,
            defaults to 1e-12. It is stored on the object; the forward pass of this class does not apply it.
        :type lower_rho_cutoff: float, optional
        """
        self.name = 'MGGA_FxNet_sigma'
        self.depth = depth
        self.nodes = nodes
        self.seed = seed
        self.lob_lim = lob_lim
        self.lower_rho_cutoff = lower_rho_cutoff
        # The network only ever sees (s, alpha), never rho itself.
        mlp = eqx.nn.MLP(in_size=2,  # Input is gradient_descriptor, tau_descriptor
                         out_size=1,  # Output is Fx
                         depth=depth,
                         width_size=nodes,
                         activation=jax.nn.gelu,
                         key=jax.random.PRNGKey(seed))
        self.net = mlp
        self.lobf = LOB(limit=lob_lim)

    def __call__(self, inputs):
        """
        The network's forward pass, resulting in the enhancement factor associated to the input descriptors.

        *NOTE*: This forward pass is explicitly NOT vectorized -- it expects one grid point worth of data.
        Any :jax.vmap: call is expected to be coded OUTSIDE of the network class.

        *NOTE*: The gradient descriptor is assumed to be Libxc's/PySCF's internal variable for the density
        gradient -- sigma. Inside the call, sigma is translated to the reduced density gradient, :s:, and
        alpha is computed as ``(tau - tau_w) / tau_unif`` with ``tau_w = sigma / (8 * rho)`` and
        ``tau_unif = (3/10) * (3*pi**2)**(2/3) * rho**(5/3)``. The network output is multiplied by
        ``tanh(s)**2 + tanh(alpha - 1)**2`` before the bound enforcer, so when s = 0 and alpha = 1, Fx = 1.

        *NOTE*: rho and sigma are used exactly as given; no clamping with ``lower_rho_cutoff`` is done here.

        :param inputs: A one-dimensional list/array in order [rho, sigma, laplacian_rho, tau]. Element 2 is
            not read; element 3 is used as tau.
        :type inputs: tuple, list, one-dimensional array of size 4
        :return: The enhancement factor value
        :rtype: float
        """
        rho, sigma, tau = inputs[0], inputs[1], inputs[3]
        tau_weizsacker = sigma/(8*rho)
        tau_ueg = (3/10)*(3*jnp.pi**2)**(2/3)*rho**(5/3)
        alpha = ((tau - tau_weizsacker)/tau_ueg).flatten()
        fermi_k = (3 * jnp.pi**2 * rho)**(1/3)
        s = (jnp.sqrt(sigma) / (2 * fermi_k * rho)).flatten()
        damping = jnp.tanh(s)**2 + jnp.tanh(alpha-1)**2
        net_out = self.net(jnp.array([s, alpha]).flatten())
        bounded = self.lobf(damping*net_out)
        return 1 + bounded.squeeze()
[docs]
class MGGA_FcNet_sigma(eqx.Module):
    """Correlation enhancement factor network for the MGGA case.

    The network is fed three descriptors per grid point -- the density, the
    reduced density gradient ``s`` and ``alpha`` -- all of which are derived
    inside the forward pass from the raw ``(rho, sigma, tau)`` inputs.
    """
    depth: int
    nodes: int
    seed: int
    lob_lim: float
    lower_rho_cutoff: float
    net: eqx.nn.MLP
    lobf: eqx.Module
    name: str

    def __init__(self, depth: int, nodes: int, seed: int, lob_lim=1.174, lower_rho_cutoff=1e-12):
        '''
        Build the correlation enhancement factor object for the MGGA case.

        The MLP input size is hard-coded to 3 (rho, s, alpha) and its output size to 1.

        :param depth: Depth of the neural network
        :type depth: int
        :param nodes: Number of nodes in each layer
        :type nodes: int
        :param seed: The random seed used to initialise the network weights
        :type seed: int
        :param lob_lim: The Lieb-Oxford bound value handed to the LOB squashing function, defaults to 1.174
        :type lob_lim: float, optional
        :param lower_rho_cutoff: Cut-off intended to guard divisions by rho, defaults to 1e-12.
            It is stored on the module, but the forward pass does not currently apply it.
        :type lower_rho_cutoff: float, optional
        '''
        self.depth, self.nodes, self.seed = depth, nodes, seed
        self.lob_lim = lob_lim
        self.lower_rho_cutoff = lower_rho_cutoff
        # Three descriptors in (rho, s, alpha), one enhancement-factor value out.
        init_key = jax.random.PRNGKey(seed)
        self.net = eqx.nn.MLP(in_size=3,
                              out_size=1,
                              width_size=nodes,
                              depth=depth,
                              activation=jax.nn.gelu,
                              key=init_key)
        self.lobf = LOB(limit=lob_lim)
        self.name = 'MGGA_FcNet_sigma'

    def __call__(self, inputs):
        '''
        Forward pass: evaluate the enhancement factor for a single grid point.

        *NOTE*: This call is NOT vectorized; any :jax.vmap: over grid points is expected to be applied OUTSIDE of this class.

        *NOTE*: The gradient input is sigma, which is converted to the reduced density gradient :s: inside the call.
        The network output is multiplied by tanh(s)**2 + tanh(alpha-1)**2 before being squashed, so when s = 0 and
        alpha = 1 the squashing function receives 0 and the returned factor is 1.

        :param inputs: Values in the order [rho, sigma, laplacian_rho, tau]. Index 2 is not read; index 3 is used as tau,
            from which alpha is computed internally.
        :type inputs: tuple, list, or array of size 4
        :return: The enhancement factor value
        :rtype: float
        '''
        density = inputs[0].flatten()
        grad_sq = inputs[1]
        kinetic = inputs[3]
        # alpha = (tau - tau_w) / tau_unif, with tau_w = sigma/(8 rho)
        tau_weiz = grad_sq/(8*density)
        ueg_prefactor = (3/10)*(3*jnp.pi**2)**(2/3)
        tau_ueg = ueg_prefactor*density**(5/3)
        alpha = ((kinetic - tau_weiz)/tau_ueg).flatten()
        # s = sqrt(sigma) / (2 k_F rho)
        fermi_k = (3 * jnp.pi**2 * density)**(1/3)
        s = (jnp.sqrt(grad_sq) / (2 * fermi_k * density)).flatten()
        # Prefactor that is zero at s = 0, alpha = 1
        switch = jnp.tanh(s)**2 + jnp.tanh(alpha-1)**2
        features = jnp.array([density, s, alpha]).flatten()
        bounded = self.lobf(switch*self.net(features))
        return 1+bounded.squeeze()
# =====================================================================
# =====================================================================
# DEPRECATED CLASSES -- TO BE REMOVED
# =====================================================================
# =====================================================================
[docs]
class eX(eqx.Module):
    """Deprecated MLP-based local exchange model.

    Field order is kept exactly as originally declared because it defines the
    module's pytree layout. ``sig`` and ``shift`` are not used by the forward
    pass but remain fields for that reason.
    """
    n_input: int
    n_hidden: int
    ueg_limit: bool
    spin_scaling: bool
    lob: float
    use: list
    net: eqx.Module
    tanh: jax.named_call
    lobf: eqx.Module
    sig: jax.named_call
    shift: jax.Array
    seed: int
    depth: int

    def __init__(self, n_input, n_hidden=16, depth=3, use=None, ueg_limit=False, lob=1.804, seed=92017, spin_scaling=True):
        """
        __init__ Local exchange model based on MLP.

        Receives density descriptors in this order : [rho, s, alpha, nl], where the input may be truncated depending on XC-level of approximation.
        The MLP generated is hard-coded to have one output value -- the predicted exchange energy given a specific input from the grid.

        :param n_input: Input dimensions (LDA: 1, GGA: 2, meta-GGA: 3, ...)
        :type n_input: int
        :param n_hidden: Number of hidden nodes (three hidden layers used by default), defaults to 16
        :type n_hidden: int, optional
        :param depth: Depth of the MLP, defaults to 3
        :type depth: int, optional
        :param use: Only these indices are used as input to the model (can be used to omit density as input to enforce uniform density scaling). These indices are also used to enforce UEG where the assumed order is [s, alpha, ...]. An empty list or None selects all ``n_input`` indices, defaults to None
        :type use: list, optional
        :param ueg_limit: Flag to determine whether or not to enforce uniform homogeneous electron gas limit, defaults to False
        :type ueg_limit: bool, optional
        :param lob: Enforce this value as local Lieb-Oxford bound (don't enforce if set to 0), defaults to 1.804
        :type lob: float, optional
        :param seed: Random seed used to generate initial weights and biases for the MLP, defaults to 92017
        :type seed: int, optional
        :param spin_scaling: Selects the doubly-vmapped, transposed evaluation path in the forward pass, defaults to True
        :type spin_scaling: bool, optional
        """
        super().__init__()
        self.ueg_limit = ueg_limit
        self.spin_scaling = spin_scaling
        self.lob = lob
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.seed = seed
        self.depth = depth
        # Falsy `use` (None or an empty list) means "use every descriptor".
        if not use:
            self.use = jnp.arange(n_input)
        else:
            self.use = use
        self.net = eqx.nn.MLP(in_size=self.n_input,
                              out_size=1,
                              width_size=self.n_hidden,
                              depth=self.depth,
                              activation=jax.nn.gelu,
                              key=jax.random.PRNGKey(self.seed))
        self.tanh = jnp.tanh
        self.lobf = LOB(limit=self.lob)
        self.sig = jax.nn.sigmoid
        self.shift = 1/(1+jnp.exp(-1e-3))
        warn("WARNING - DEPRECATED. This class is not the currently working class and will be removed in the future.")

    def __call__(self, rho, **kwargs):
        """
        __call__ Forward pass for the exchange network.

        Uses :jax.vmap: to vectorize evaluation of the MLP on the descriptors; the original documentation assumes a shape [batch, *, n_input].

        .. todo: Make sure the :vmap: call can work with specific :use: values beyond the defaults assumed in the previous implementation.

        :param rho: The descriptors to the MLP -- transformed densities and gradients appropriate to the XC-level. This network will only use the dimensions specified in self.use.
        :type rho: jax.Array
        :param kwargs: Accepted and ignored.
        :return: The exchange energy on the grid
        :rtype: jax.Array
        """
        print(f"eX.__call__, rho shape: {rho.shape}")
        print(f"eX.__call__, rho nans: {jnp.isnan(rho).sum()}")
        if self.spin_scaling:
            squeezed = jnp.squeeze(jax.vmap(jax.vmap(self.net), in_axes=1)(rho[..., self.use])).T
        else:
            squeezed = jnp.squeeze(jax.vmap(self.net)(rho[..., self.use]))

        # Multiplicative factor applied to the raw network output.
        if self.ueg_limit:
            ueg_lim = rho[..., self.use[0]]
            if len(self.use) > 1:
                ueg_lim_a = jnp.power(self.tanh(rho[..., self.use[1]]), 2)
            else:
                ueg_lim_a = 0
            if len(self.use) > 2:
                ueg_lim_nl = jnp.sum(rho[..., self.use[2:]], axis=-1)
            else:
                ueg_lim_nl = 0
        else:
            ueg_lim = 1
            ueg_lim_a = 0
            ueg_lim_nl = 0

        # A falsy `lob` skips the Lieb-Oxford squashing entirely.
        if self.lob:
            result = self.lobf(squeezed*(ueg_lim + ueg_lim_a + ueg_lim_nl))
        else:
            result = squeezed*(ueg_lim + ueg_lim_a + ueg_lim_nl)
        return result
# DEPRECATED
[docs]
class eC(eqx.Module):
    """Deprecated MLP-based local correlation model.

    Field order is kept exactly as originally declared because it defines the
    module's pytree layout. ``sig`` is not used by the forward pass but
    remains a field for that reason.
    """
    n_input: int
    n_hidden: int
    ueg_limit: bool
    spin_scaling: bool
    lob: float
    use: list
    net: eqx.Module
    tanh: jax.named_call
    lobf: eqx.Module
    sig: jax.named_call
    seed: int
    depth: int

    def __init__(self, n_input=2, n_hidden=16, depth=3, use=None, ueg_limit=False, lob=2.0, seed=92017, spin_scaling=False):
        """
        __init__ Local correlation model based on MLP.

        Receives density descriptors in this order : [rho, spinscale, s, alpha, nl], where the input may be truncated depending on XC-level of approximation

        .. todo: Make sure the :vmap: call can work with specific :use: values beyond the defaults assumed in the previous implementation.

        :param n_input: Input dimensions (LDA: 2, GGA: 3 , meta-GGA: 4), defaults to 2.
        :type n_input: int
        :param n_hidden: Number of hidden nodes (three hidden layers used by default), defaults to 16
        :type n_hidden: int, optional
        :param depth: Depth of the MLP, defaults to 3
        :type depth: int, optional
        :param use: Only these indices are used as input to the model. These indices are also used to enforce UEG where the assumed order is [s, alpha, ...]. An empty list or None selects all ``n_input`` indices, defaults to None
        :type use: list, optional
        :param ueg_limit: Flag to determine whether or not to enforce uniform homogeneous electron gas limit, defaults to False
        :type ueg_limit: bool, optional
        :param lob: Value enforced as local Lieb-Oxford bound, defaults to 2.0. A falsy value (e.g. 0) does not disable the bound; it is replaced by 1000.0.
        :type lob: float, optional
        :param seed: Random seed used to generate initial weights and biases for the MLP, defaults to 92017
        :type seed: int, optional
        :param spin_scaling: Selects the doubly-vmapped, transposed evaluation path in the forward pass, defaults to False
        :type spin_scaling: bool, optional
        """
        super().__init__()
        self.spin_scaling = spin_scaling
        self.ueg_limit = ueg_limit
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.seed = seed
        self.depth = depth
        # Falsy `use` (None or an empty list) means "use every descriptor".
        if not use:
            self.use = jnp.arange(n_input)
        else:
            self.use = use
        self.net = eqx.nn.MLP(in_size=self.n_input,
                              out_size=1,
                              width_size=self.n_hidden,
                              depth=self.depth,
                              activation=jax.nn.gelu,
                              final_activation=jax.nn.softplus,
                              key=jax.random.PRNGKey(self.seed))
        self.sig = jax.nn.sigmoid
        self.tanh = jnp.tanh
        # A falsy bound falls back to 1000.0, so `self.lob` is always truthy
        # after construction and the squashing function always exists.
        self.lob = lob if lob else 1000.0
        self.lobf = LOB(self.lob)
        warn("WARNING - DEPRECATED. This class is not the currently working class and will be removed in the future.")

    def __call__(self, rho, **kwargs):
        """
        __call__ Forward pass for the correlation network.

        Uses :jax.vmap: to vectorize evaluation of the MLP on the descriptors; the original documentation assumes a shape [*, n_input].
        The raw network output is negated before any scaling is applied.

        :param rho: The descriptors to the MLP -- transformed densities and gradients appropriate to the XC-level. The dimensions specified in self.use are fed to the network and used in determining the UEG limits.
        :type rho: jax.Array
        :param kwargs: Accepted and ignored.
        :return: The correlation energy on the grid
        :rtype: jax.Array
        """
        print(f"eC.__call__, rho shape: {rho.shape}")
        print(f"eC.__call__, rho nans: {jnp.isnan(rho).sum()}")
        if self.spin_scaling:
            squeezed = -jnp.squeeze(jax.vmap(jax.vmap(self.net), in_axes=1)(rho[..., self.use])).T
        else:
            squeezed = -jnp.squeeze(jax.vmap(self.net)(rho[..., self.use]))

        if self.ueg_limit:
            ueg_lim = self.tanh(rho[..., self.use[0]])
            if len(self.use) > 1:
                # jnp.power (not the newer jnp.pow alias) for parity with eX
                # and compatibility with older JAX releases.
                ueg_lim_a = jnp.power(self.tanh(rho[..., self.use[1]]), 2)
            else:
                ueg_lim_a = 0
            if len(self.use) > 2:
                ueg_lim_nl = jnp.sum(self.tanh(rho[..., self.use[2:]])**2, axis=-1)
            else:
                ueg_lim_nl = 0
            ueg_factor = ueg_lim + ueg_lim_a + ueg_lim_nl
        else:
            ueg_factor = 1

        # `self.lob` is truthy for any instance built through __init__; the
        # unbounded fallback is kept for instances whose field was replaced.
        if self.lob:
            return self.lobf(squeezed*ueg_factor)
        return squeezed*ueg_factor
[docs]
def make_net(xorc, level, depth, nhidden, ninput=None, use=None, spin_scaling=None, lob=None, ueg_limit=None,
             random_seed=None, savepath=None, configfile='network.config'):
    '''
    make_net is a utility function designed to easily create new, individual exchange or correlation networks. Any argument left as None is filled in from a per-level, per-network-type table of defaults.

    :param xorc: 'X' or 'C' (case-insensitive) -- the type of network to generate, exchange or correlation
    :type xorc: str
    :param level: one of ['GGA', 'MGGA', 'NONLOCAL', 'NL'] (case-insensitive), indicating the desired rung of Jacob's Ladder. NONLOCAL = NL
    :type level: str
    :param depth: The number of hidden layers in the generated network. None selects the level default.
    :type depth: int
    :param nhidden: The number of nodes in a hidden layer. None selects the level default.
    :type nhidden: int
    :param ninput: The number of inputs the network will expect, defaults to None for automatic selection based on level
    :type ninput: int, optional
    :param use: The indices of the descriptors to evaluate the network on, defaults to None
    :type use: list of ints, optional
    :param spin_scaling: Whether or not to enforce the spin-scaling constraint in the generated network, defaults to None
    :type spin_scaling: bool, optional
    :param lob: Lieb-Oxford bound: If non-zero (i.e., truthy), the output values of e_x or e_c will be squashed between [-1, lob-1], defaults to None
    :type lob: float, optional
    :param ueg_limit: Whether or not to enforce the UEG scaling constraint, defaults to None. See the NOTE(review) in the body: the value is recorded in the config but not forwarded to the network constructor.
    :type ueg_limit: bool, optional
    :param random_seed: The random seed to use in generating initial network weights, defaults to None (92017 is then used)
    :type random_seed: int, optional
    :param savepath: Directory in which to save the generated network ('xc.eqx') and the config files; created if missing. Nothing is written when falsy, defaults to None
    :type savepath: str, optional
    :param configfile: Name for the configuration file, needed when reading in the network to re-generate the same structure, defaults to 'network.config'
    :type configfile: str, optional
    :raises AssertionError: if ``level`` is not one of the supported levels
    :raises ValueError: if ``xorc`` is neither 'X' nor 'C'
    :return: A tuple ``(net, config)``: the resulting exchange or correlation network and the dictionary of resolved settings used to build it.
    :rtype: tuple(:xcquinox.net.eX: or :xcquinox.net.eC:, dict)
    '''
    defaults_dct = {'GGA': {'X': {'ninput': 1, 'depth': 3, 'nhidden': 16, 'use': [1], 'spin_scaling': True, 'lob': 1.804, 'ueg_limit': True},
                            'C': {'ninput': 1, 'depth': 3, 'nhidden': 16, 'use': [2], 'spin_scaling': False, 'lob': 2.0, 'ueg_limit': True}
                            },
                    'MGGA': {'X': {'ninput': 2, 'depth': 3, 'nhidden': 16, 'use': [1, 2], 'spin_scaling': True, 'lob': 1.174, 'ueg_limit': True},
                             'C': {'ninput': 2, 'depth': 3, 'nhidden': 16, 'use': [2, 3], 'spin_scaling': False, 'lob': 2.0, 'ueg_limit': True}
                             },
                    'NONLOCAL': {'X': {'ninput': 15, 'depth': 3, 'nhidden': 16, 'use': None, 'spin_scaling': True, 'lob': 1.174, 'ueg_limit': True},
                                 'C': {'ninput': 16, 'depth': 3, 'nhidden': 16, 'use': None, 'spin_scaling': False, 'lob': 2.0, 'ueg_limit': True}
                                 },
                    }
    # 'NL' is an alias for 'NONLOCAL'.
    defaults_dct['NL'] = defaults_dct['NONLOCAL']

    assert level.upper() in ['GGA', 'MGGA', 'NONLOCAL', 'NL']
    # Reject an unknown network type up front; otherwise it surfaces as a
    # KeyError in the defaults lookup or as an unbound `net` further down.
    if xorc.upper() not in ('X', 'C'):
        raise ValueError(f"xorc must be 'X' or 'C', got {xorc!r}")

    defaults = defaults_dct[level.upper()][xorc.upper()]
    ninput = ninput if ninput is not None else defaults['ninput']
    depth = depth if depth is not None else defaults['depth']
    nhidden = nhidden if nhidden is not None else defaults['nhidden']
    use = use if use is not None else defaults['use']
    spin_scaling = spin_scaling if spin_scaling is not None else defaults['spin_scaling']
    ueg_limit = ueg_limit if ueg_limit is not None else defaults['ueg_limit']
    lob = lob if lob is not None else defaults['lob']
    random_seed = random_seed if random_seed is not None else 92017

    config = {'ninput': ninput,
              'depth': depth,
              'nhidden': nhidden,
              'use': use,
              'spin_scaling': spin_scaling,
              'ueg_limit': ueg_limit,
              'lob': lob,
              'random_seed': random_seed}

    # NOTE(review): `ueg_limit` is resolved and stored in `config`, but it is
    # not passed to eX/eC here, so the network is built with the constructor
    # default (ueg_limit=False). Forwarding it would change the output of
    # existing and re-loaded networks -- confirm intent before changing.
    if xorc.upper() == 'X':
        net = eX(n_input=ninput, use=use, depth=depth, n_hidden=nhidden,
                 spin_scaling=spin_scaling, lob=lob, seed=random_seed)
    else:
        net = eC(n_input=ninput, use=use, depth=depth, n_hidden=nhidden,
                 spin_scaling=spin_scaling, lob=lob, seed=random_seed)

    if savepath:
        # An already-existing directory is fine; any other failure is raised
        # here instead of being printed and then failing at open() below.
        os.makedirs(savepath, exist_ok=True)
        # Human-readable config, one tab-separated key/value pair per line.
        with open(os.path.join(savepath, configfile), 'w') as f:
            for k, v in config.items():
                f.write(f'{k}\t{v}\n')
        # Pickled config: this is the file get_net reads back.
        with open(os.path.join(savepath, configfile+'.pkl'), 'wb') as f:
            pickle.dump(config, f)
        eqx.tree_serialise_leaves(os.path.join(savepath, 'xc.eqx'), net)
    return net, config
[docs]
def get_net(xorc, level, net_path, configfile='network.config', netfile='xc.eqx'):
    '''
    A utility function to easily load in a previously generated network. Functionally creates a random network of the same architecture, then overwrites the weights with those of the saved network.

    :param xorc: 'X' or 'C' -- the type of network to generate, exchange or correlation
    :type xorc: str
    :param level: one of ['GGA', 'MGGA', 'NONLOCAL', 'NL'], indicating the desired rung of Jacob's Ladder. NONLOCAL = NL
    :type level: str
    :param net_path: Location of the saved network. Must have a {configfile}.pkl parameter file within.
    :type net_path: str
    :param configfile: Name for the configuration file, needed when reading in the network to re-generate the same structure, defaults to 'network.config'
    :type configfile: str, optional
    :param netfile: Name for the network file, needed when reading in the network to overwrite the generated random weights, defaults to 'xc.eqx'. Every file in ``net_path`` whose name contains this string is a candidate; when several match, the one with the highest integer suffix after the last '.' is used. If Falsy, just generates random network based on config file.
    :type netfile: str, optional
    :return: A tuple ``(net, params)``: the requested exchange or correlation network and the configuration dictionary read from the pickle file.
    :rtype: tuple(:xcquinox.net.eX: or :xcquinox.net.eC:, dict)
    '''
    def _checkpoint_order(fname):
        # Numbered checkpoints sort by their integer suffix; names without one
        # (e.g. the base 'xc.eqx') sort before every numbered checkpoint
        # instead of raising ValueError inside the sort.
        try:
            return (1, int(fname.split('.')[-1]), fname)
        except ValueError:
            return (0, 0, fname)

    # NOTE(security): pickle.load can execute arbitrary code from the file;
    # only point net_path at directories whose contents are trusted.
    with open(os.path.join(net_path, configfile+'.pkl'), 'rb') as f:
        params = pickle.load(f)

    # Rebuild a randomly initialised network with the saved architecture.
    net, _ = make_net(xorc=xorc, level=level, depth=params['depth'], nhidden=params['nhidden'],
                      ninput=params['ninput'], use=params['use'], spin_scaling=params['spin_scaling'],
                      lob=params['lob'], ueg_limit=params['ueg_limit'], random_seed=params['random_seed'],
                      configfile=configfile)

    if netfile:
        # make sure the netfile is actually there
        netfs = [i for i in os.listdir(net_path) if netfile in i]
        if len(netfs) == 1:
            print('SINGLE NETFILE MATCH FOUND. DESERIALIZING...')
            net = eqx.tree_deserialise_leaves(os.path.join(net_path, netfs[0]), net)
        elif len(netfs) > 1:
            # several matches: there was training -- select the last checkpoint
            print('NETFILE MATCHES FOUND -- MULTIPLE. SELECTING LAST ONE.')
            netf = max(netfs, key=_checkpoint_order)
            print('ATTEMPTING TO DESERIALIZE {}'.format(netf))
            net = eqx.tree_deserialise_leaves(os.path.join(net_path, netf), net)
        else:
            print('NETFILE SPECIFIED BUT NO MATCHING FILE FOUND.')
    return net, params