6. Training Classes
6.1. The Pretrainer Class
- class xcquinox.train.Pretrainer(model, optim, inputs, ref, loss, steps=1000, print_every=100)
Bases: Module
- __init__(model, optim, inputs, ref, loss, steps=1000, print_every=100)
The Pretrainer object helps with the initial pre-training of enhancement factor networks, giving them a more physical starting point for further network optimization. This class is meant to pre-train a randomly initialized network to fit the values of a specific XC functional’s enhancement factor (either X or C; in principle, it could also be a combined XC enhancement factor).
- Parameters:
model – The enhancement factor network to be pre-trained
optim (optax.GradientTransformation) – The optimizer that will control the weight updates given a loss and gradient
inputs (jnp.array) – The inputs the network itself is expecting in its forward pass function
ref (jnp.array) – The reference values the network is expected to reproduce
loss (Callable) – A function from :xcquinox.loss: that is decorated with @eqx.filter_value_and_grad
steps (int, optional) – Number of epochs to train over, defaults to 1000
print_every (int, optional) – How often to print loss statistics, defaults to 100
- __call__()
The training loop itself. It loops over the specified number of epochs, training the network to fit the reference values.
- Returns:
The trained model and an array of the losses during training.
- Return type:
(:xcquinox.net: class, array)
- make_step(model, inputs, ref, opt_state)
The function that performs each epoch’s network update. Given the specified inputs, reference values, and optimization state, it computes a loss and gradient using the chosen :xcquinox.loss: function. That function must be decorated with @eqx.filter_value_and_grad, and its body must explicitly return only the loss value; the decorator supplies the gradient.
- Parameters:
model – The enhancement factor network to be pre-trained
inputs (jnp.array) – The inputs the network itself is expecting in its forward pass function
ref (jnp.array) – The reference values the network is expected to reproduce
opt_state (optax.OptState) – The INITIAL optimization state to work against, typically generated via self.optim.init(eqx.filter(self.model, eqx.is_array))
- Returns:
The loss value for this step, the updated model after that loss is calculated, and the new optimization state for this step to use next time
- Return type:
tuple
6.2. The Optimizer Class
- class xcquinox.train.Optimizer(model, optim, mols, refs, loss, steps=1000, print_every=100)
Bases: Module
- __init__(model, optim, mols, refs, loss, steps=1000, print_every=100)
The Pretrainer object helps with the initial pre-training of enhancement factor networks, giving them a more physical starting point for further network optimization. This class is meant to pre-train a randomly initialized network to fit the values of a specific XC functional’s enhancement factor (either X or C; in principle, it could also be a combined XC enhancement factor).
- Parameters:
model – The enhancement factor network to be pre-trained
optim (optax.GradientTransformation) – The optimizer that will control the weight updates given a loss and gradient
inputs (jnp.array) – The inputs the network itself is expecting in its forward pass function
ref (jnp.array) – The reference values the network is expected to reproduce
loss (Callable) – A function from :xcquinox.loss: that is decorated with @eqx.filter_value_and_grad
steps (int, optional) – Number of epochs to train over, defaults to 1000
print_every (int, optional) – How often to print loss statistics, defaults to 100
- __call__()
The training loop itself. It loops over the specified number of epochs, training the network to fit the reference values.
- Returns:
The trained model and an array of the losses during training.
- Return type:
(:xcquinox.net: class, array)
- make_step(model, inputs, ref, opt_state)
The function that performs each epoch’s network update. Given the specified inputs, reference values, and optimization state, it computes a loss and gradient using the chosen :xcquinox.loss: function. That function must be decorated with @eqx.filter_value_and_grad, and its body must explicitly return only the loss value; the decorator supplies the gradient.
- Parameters:
model – The enhancement factor network to be pre-trained
inputs (jnp.array) – The inputs the network itself is expecting in its forward pass function
ref (jnp.array) – The reference values the network is expected to reproduce
opt_state (optax.OptState) – The INITIAL optimization state to work against, typically generated via self.optim.init(eqx.filter(self.model, eqx.is_array))
- Returns:
The loss value for this step, the updated model after that loss is calculated, and the new optimization state for this step to use next time
- Return type:
tuple
6.3. The xcTrainer Class
- class xcquinox.train.xcTrainer(model, optim, loss, steps=50, print_every=1, clear_every=1, memory_profile=False, verbose=False, do_jit=True, serialize_every=1, logfile='')
Bases: Module
- __init__(model, optim, loss, steps=50, print_every=1, clear_every=1, memory_profile=False, verbose=False, do_jit=True, serialize_every=1, logfile='')
The base xcTrainer class, whose forward pass computes the training loop.
- Parameters:
model (xcquinox.xc.eXC) – The network which will be trained
optim (optax.GradientTransformation) – optax optimizer object, e.g. optax.adamw(1e-4)
loss (eqx.Module) – The loss function object that computes the loss to be trained against.
steps (int, optional) – Length of the training cycle, i.e. the number of epochs, defaults to 50
print_every (int, optional) – The number of epochs between loss information printing, defaults to 1
clear_every (int, optional) – The number of epochs between calls to clear the cache, defaults to 1
memory_profile (bool, optional) – If True, will write memory profiles at every epoch, to be used with pprof, defaults to False
verbose (bool, optional) – If True, will print various extra information during the training cycle, defaults to False
do_jit (bool, optional) – Controls whether the update function is jitted; setting it to False is useful for debugging, defaults to True
serialize_every (int, optional) – Controls how often the checkpoint network is written to disk, defaults to 1
- __call__(epoch_batch_len, model, *loss_input_lists)
Forward pass of the xcTrainer object, which goes through the training cycle.
*loss_input_lists are positional arguments, each a list of epoch_batch_len elements, given in the order in which the self.loss function expects its inputs. For example, for the E_loss object these would be [density matrix list], [reference energy list], [ao_eval list], [grid_weight list].
- Parameters:
epoch_batch_len (int) – The number of batches in a given epoch (i.e., the number of molecules one is training on that are looped over)
model (xcquinox.xc.eXC) – The baseline model to update in the training process
- Returns:
The updated model after the training cycle completes
- Return type:
xcquinox.xc.eXC
- make_step(model, opt_state, *args)
The update step for the training cycle.
*args are the inputs to the self.loss function.
- Parameters:
model (xcquinox.xc.eXC) – The model whose weights and biases are to be updated given the loss in self.loss
opt_state (optax.OptState) – The state of the optimizer that drives the update, as returned by optim.init or a previous optim.update call
- Returns:
The updated model
- Return type:
xcquinox.xc.eXC