6. Training Classes

6.1. The Pretrainer Class

class xcquinox.train.Pretrainer(model, optim, inputs, ref, loss, steps=1000, print_every=100)

Bases: Module

__init__(model, optim, inputs, ref, loss, steps=1000, print_every=100)

The Pretrainer object aids in the initial pre-training of enhancement factor networks, giving them a more physical starting point for further network optimization. This class is meant to pre-train a randomly initialized network to fit the values of a specific XC functional’s enhancement factor (either X or C; in principle, it could also be a combined XC enhancement factor).

Parameters:
  • model – The enhancement factor network to be pre-trained

  • optim (optax.GradientTransformation) – The optimizer that will control the weight updates given a loss and gradient

  • inputs (jnp.array) – The inputs the network itself is expecting in its forward pass function

  • ref (jnp.array) – The reference values the network is expected to reproduce

  • loss (Callable) – A function from xcquinox.loss that is decorated with @eqx.filter_value_and_grad

  • steps (int, optional) – Number of epochs to train over, defaults to 1000

  • print_every (int, optional) – How often (in epochs) to print the loss statistic, defaults to 100

__call__()

The training loop itself: it loops over the specified number of epochs, training the network to fit the reference values.

Returns:

The trained model and an array of the losses during training.

Return type:

(xcquinox.net class, array)

make_step(model, inputs, ref, opt_state)

The function that performs each epoch’s network update. It computes the loss and gradient with the chosen xcquinox.loss function (which must be decorated with @eqx.filter_value_and_grad, and whose body explicitly returns only the loss value) for the given inputs, reference values, and initial optimization state.

Parameters:
  • model – The enhancement factor network to be pre-trained

  • inputs (jnp.array) – The inputs the network itself is expecting in its forward pass function

  • ref (jnp.array) – The reference values the network is expected to reproduce

  • opt_state (optax.OptState) – The optimization state to update from; the initial one is typically generated via self.optim.init(eqx.filter(self.model, eqx.is_array))

Returns:

The loss value for this step, the updated model after that loss is calculated, and the new optimization state to pass to the next step

Return type:

tuple

6.2. The Optimizer Class

class xcquinox.train.Optimizer(model, optim, mols, refs, loss, steps=1000, print_every=100)

Bases: Module

__init__(model, optim, mols, refs, loss, steps=1000, print_every=100)

The Pretrainer object aids in the initial pre-training of enhancement factor networks, giving them a more physical starting point for further network optimization. This class is meant to pre-train a randomly initialized network to fit the values of a specific XC functional’s enhancement factor (either X or C; in principle, it could also be a combined XC enhancement factor).

Parameters:
  • model – The enhancement factor network to be pre-trained

  • optim (optax.GradientTransformation) – The optimizer that will control the weight updates given a loss and gradient

  • inputs (jnp.array) – The inputs the network itself is expecting in its forward pass function

  • ref (jnp.array) – The reference values the network is expected to reproduce

  • loss (Callable) – A function from xcquinox.loss that is decorated with @eqx.filter_value_and_grad

  • steps (int, optional) – Number of epochs to train over, defaults to 1000

  • print_every (int, optional) – How often (in epochs) to print the loss statistic, defaults to 100

__call__()

The training loop itself: it loops over the specified number of epochs, training the network to fit the reference values.

Returns:

The trained model and an array of the losses during training.

Return type:

(xcquinox.net class, array)

make_step(model, inputs, ref, opt_state)

The function that performs each epoch’s network update. It computes the loss and gradient with the chosen xcquinox.loss function (which must be decorated with @eqx.filter_value_and_grad, and whose body explicitly returns only the loss value) for the given inputs, reference values, and initial optimization state.

Parameters:
  • model – The enhancement factor network to be pre-trained

  • inputs (jnp.array) – The inputs the network itself is expecting in its forward pass function

  • ref (jnp.array) – The reference values the network is expected to reproduce

  • opt_state (optax.OptState) – The optimization state to update from; the initial one is typically generated via self.optim.init(eqx.filter(self.model, eqx.is_array))

Returns:

The loss value for this step, the updated model after that loss is calculated, and the new optimization state to pass to the next step

Return type:

tuple

6.3. The xcTrainer Class

class xcquinox.train.xcTrainer(model, optim, loss, steps=50, print_every=1, clear_every=1, memory_profile=False, verbose=False, do_jit=True, serialize_every=1, logfile='')

Bases: Module

__init__(model, optim, loss, steps=50, print_every=1, clear_every=1, memory_profile=False, verbose=False, do_jit=True, serialize_every=1, logfile='')

The base xcTrainer class, whose forward pass runs the training loop.

Parameters:
  • model (xcquinox.xc.eXC) – The network which will be trained

  • optim (optax.GradientTransformation) – An optax optimizer object, e.g. optax.adamw(1e-4)

  • loss (eqx.Module) – The loss function object that computes the loss to be trained against.

  • steps (int, optional) – Length of the training cycle, i.e. the number of epochs, defaults to 50

  • print_every (int, optional) – The number of epochs between loss information printing, defaults to 1

  • clear_every (int, optional) – The number of epochs between calls to clear the cache, defaults to 1

  • memory_profile (bool, optional) – If True, writes a memory profile at every epoch for inspection with pprof, defaults to False

  • verbose (bool, optional) – If True, prints various extra information during the training cycle, defaults to False

  • do_jit (bool, optional) – Whether the update function is JIT-compiled; setting it to False is useful for debugging, defaults to True

  • serialize_every (int, optional) – How often the network checkpoint is written to disk, defaults to 1

__call__(epoch_batch_len, model, *loss_input_lists)

Forward pass of the xcTrainer object, which goes through the training cycle.

*loss_input_lists are positional arguments, each a list of epoch_batch_len elements, supplied in the input order expected by the self.loss function. For example, for the E_loss object these would be [density matrix list], [reference energy list], [ao_eval list], [grid_weight list].
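In other words, the i-th element of every list belongs to the i-th molecule, and an epoch visits the molecules one at a time. A minimal sketch of that pairing, with placeholder strings in place of real arrays; `iterate_batches` is a hypothetical helper, and forwarding each tuple as self.loss(model, *args) is an assumption about how the trainer consumes it:

```python
def iterate_batches(epoch_batch_len, *loss_input_lists):
    """Yield one tuple of loss inputs per molecule (hypothetical helper)."""
    for lst in loss_input_lists:
        # Every list must provide exactly one entry per molecule.
        if len(lst) != epoch_batch_len:
            raise ValueError("each input list must have epoch_batch_len elements")
    for i in range(epoch_batch_len):
        # Take the i-th entry of every list, keeping the order self.loss expects.
        yield tuple(lst[i] for lst in loss_input_lists)


# Placeholder stand-ins for E_loss inputs for two molecules.
dms = ["dm_H2", "dm_LiH"]
e_refs = ["e_ref_H2", "e_ref_LiH"]
ao_evals = ["ao_H2", "ao_LiH"]
grid_weights = ["w_H2", "w_LiH"]

for args in iterate_batches(2, dms, e_refs, ao_evals, grid_weights):
    print(args)  # assumed to be forwarded as self.loss(model, *args)
```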

Parameters:
  • epoch_batch_len (int) – The number of batches in a given epoch (i.e., the number of molecules being trained on, which are looped over)

  • model (xcquinox.xc.eXC) – The baseline model to update in the training process

Returns:

The updated model after the training cycle completes

Return type:

xcquinox.xc.eXC

make_step(model, opt_state, *args)

The update step for the training cycle.

*args are the inputs passed to the self.loss function.

Parameters:
  • model (xcquinox.xc.eXC) – The model whose weights and biases are to be updated given the loss in self.loss

  • opt_state (result of optim.update) – The state of the optimizer to drive the update

Returns:

The updated model

Return type:

xcquinox.xc.eXC

clear_caches()

A function that attempts to clear memory associated with JAX caching.
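One plausible way to do this in current JAX is jax.clear_caches(), which clears JAX's compilation and staging caches, followed by a Python garbage-collection pass. This is a sketch under that assumption; which calls xcTrainer.clear_caches actually makes is not stated here:

```python
import gc

import jax


def clear_caches():
    # Hypothetical sketch: drop JAX's compilation/staging caches...
    jax.clear_caches()
    # ...then let Python reclaim objects that are no longer referenced.
    gc.collect()
```

With clear_every=N, such a call would presumably run every N epochs. Clearing trades memory for time: jitted functions are retraced and recompiled on their next call.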

vprint(output)

Custom print function. If self.verbose is True, prints the given output.

Parameters:

output (printable object) – The string or value to be printed.