Source code for xcquinox.train

import equinox as eqx
import optax
import sys
import gc
import jax
import os
from jax.interpreters import xla
import jax.numpy as jnp
from typing import Callable
import inspect
import warnings
from functools import partial


[docs] class xcTrainer(eqx.Module): model: eqx.Module optim: optax.GradientTransformation loss: eqx.Module steps: int print_every: int clear_every: int memory_profile: bool verbose: bool do_jit: bool opt_state: tuple serialize_every: int logfile: str loss_v: float
[docs] def __init__(self, model, optim, loss, steps=50, print_every=1, clear_every=1, memory_profile=False, verbose=False, do_jit=True, serialize_every=1, logfile=''): ''' The base xcTrainer class, whose forward pass computes the training loop. :param model: The network which will be trained :type model: xcquinox.xc.eXC :param optim: optax optimizer object, e.g. optax.adamw(1e-4) :type optim: optax.GradientTransformation :param loss: The loss function object that computes the loss to be trained against. :type loss: eqx.Module :param steps: Length of the training cycle, i.e. the number of epochs, defaults to 50 :type steps: int, optional :param print_every: The number of epochs between loss information printing, defaults to 1 :type print_every: int, optional :param clear_every: The number of epochs between calls to clear the cache, defaults to 1 :type clear_every: int, optional :param memory_profile: If True, will write memory profiles at every epoch, to be used with `pprof`, defaults to False :type memory_profile: bool, optional :param verbose: If true, will print various extra information during the training cycle, defaults to False :type verbose: bool, optional :param do_jit: Controls whether the update function is jitted or not, useful for debugging if False, defaults to True :type do_jit: bool, optional :param serialize_every: Controls how often the checkpoint network is written to disk, defaults to 1 :type serialize_every: int, optional ''' super().__init__() self.model = model self.optim = optim self.loss = loss self.steps = steps self.print_every = print_every self.clear_every = clear_every self.memory_profile = memory_profile self.verbose = verbose self.do_jit = do_jit self.serialize_every = serialize_every self.opt_state = self.optim.init(eqx.filter(self.model, eqx.is_array)) self.logfile = logfile self.loss_v = 0
[docs] def __call__(self, epoch_batch_len, model, *loss_input_lists): ''' Forward pass of the xcTrainer object, which goes through the training cycle. *loss_input_lists are positional arguments, each a list of length [epoch_batch_len elements], corresponding to the proper input order and values for the self.loss function I.e., for the E_loss object, these would be [density matrix list], [reference energy list], [ao_eval list], [grid_weight list] .. todo: Remove model from inputs, not used since retrieved from self at first iteration :param epoch_batch_len: The number of batches in a given epoch (i.e., the number of molecules one is training on that are looped over) :type epoch_batch_len: int :param model: The baseline model to update in the training process :type model: xcquinox.xc.eXC :return: The updated model after the training cycle completes :rtype: xcquinox.xc.eXC ''' BEST_LOSS = 1e10 for step in range(self.steps): jax.debug.print('Epoch {}'.format(step)) epoch_loss = 0 if step == 0 and self.logfile: with open(self.logfile+'.dat', 'w') as f: f.write('#Epoch\tLoss\tBest\n') with open(self.logfile+'batch.dat', 'w') as f: f.write('#Epoch\tBatch\tLoss\tBest\n') if step == 0 and self.do_jit: fmake_step = eqx.filter_jit(self.make_step) elif ((step % self.clear_every) == 0) and (step > 0) and self.do_jit: fmake_step = eqx.filter_jit(self.make_step) else: fmake_step = self.make_step if step == 0: print('Step = 0: initializing inp_model and inp_opt_state.') inp_model = self.model start_model = self.model inp_opt_state = self.opt_state for idx in range(epoch_batch_len): jax.debug.print('Epoch {} :: Batch {}/{}'.format(step, idx+1, epoch_batch_len)) # loops over every iterable in loss_input_lists, selecting one batch's input data # assumes separate lists, each having inputs for multiple cases in the training set loss_inputs = [inp[idx] for inp in loss_input_lists] # this_loss = self.loss(inp_model, *loss_inputs).item() inp_model, inp_opt_state, this_loss = fmake_step(inp_model, inp_opt_state, *loss_inputs) if self.memory_profile: this_loss.block_until_ready() jax.profiler.save_device_memory_profile(f"memory{step}_{idx}.prof") jax.debug.print('Batch Loss = {}'.format(this_loss)) if self.logfile: with open(self.logfile+'batch.dat', 'a') as f: f.write(f'{step}\t{idx}\t{this_loss}\t{BEST_LOSS}\n') epoch_loss += this_loss if ((step % self.clear_every) == 0) and ((step > 0) == 0): eqx.clear_caches() jax.clear_backends() jax.clear_caches() # update self.loss_v to epoch's loss object.__setattr__(self, 'loss_v', epoch_loss.item()) if ((step % self.serialize_every) == 0) and (epoch_loss.item() < BEST_LOSS): eqx.tree_serialise_leaves('xc.eqx.{}'.format(step), start_model) BEST_LOSS = epoch_loss.item() # this will persist until next pass # the inp_model output just now is fed into get the loss next time, so if better we want to save this # not the updated one. start_model = inp_model if ((step % self.print_every) == 0) or (step == self.steps - 1): jax.debug.print( f"{step}, epoch_train_loss={epoch_loss}" ) if ((step % self.clear_every) == 0) and (step > 0): eqx.clear_caches() jax.clear_backends() jax.clear_caches() if self.logfile: with open(self.logfile+'.dat', 'a') as f: f.write(f'{step}\t{epoch_loss}\t{BEST_LOSS}\n') return inp_model
[docs] def make_step(self, model, opt_state, *args): ''' The update step for the training cycle. *args input are the inputs to the self.loss function. :param model: The model whose weights and biases are to be updated given the loss in self.loss :type model: xcquinox.xc.eXC :param opt_state: The state of the optimizer to drive the update :type opt_state: result of optim.update :return: The updated model :rtype: xcquinox.xc.eXC ''' self.vprint('loss_value, grads') loss_value, grads = eqx.filter_value_and_grad(self.loss)(model, *args) self.vprint('updates, opt_state') updates, opt_state = self.optim.update(grads, opt_state, model) self.vprint('model update') model = eqx.apply_updates(model, updates) return model, opt_state, loss_value
# def __post_init__(self, attr, value): # object.__setattr__(self, attr, value)
[docs] def clear_caches(self): ''' A function that attempts to clear memory associated to jax caching ''' for module_name, module in sys.modules.items(): if module_name.startswith("jax"): if module_name not in ["jax.interpreters.partial_eval"]: for obj_name in dir(module): obj = getattr(module, obj_name) if hasattr(obj, "cache_clear"): try: obj.cache_clear() except: pass gc.collect()
[docs] def vprint(self, output): ''' Custom print function. If self.verbose, will print the called output. :param output: The string or value to be printed. :type output: printable object ''' if self.verbose: jax.debug.print(output)
# Pre-trainer
[docs] class Pretrainer(eqx.Module): model: eqx.Module optim: optax.GradientTransformation steps: int print_every: int opt_state: tuple inputs: jnp.array ref: jnp.array loss: Callable
[docs] def __init__(self, model, optim, inputs, ref, loss, steps=1000, print_every=100): ''' The Pretrainer object aids in the initial pre-training of enhancement factor networks to have a more physical starting point for further network optimization. This class is meant to pre-train a randomly initialized network to fit the values of a specific XC functional's enhancement factor (either X or C, in principle it could also be a combined XC enhancement facator) :param model: The enhancement factor network to be pre-trained :type model: :xcquinox.net: class :param optim: The optimizer than will control the weight updates given a loss and gradient :type optim: optax.GradientTransformation :param inputs: The inputs the network itself is expecting in its forward pass function :type inputs: jnp.array :param ref: The reference values the network is expected to reproduce :type ref: jnp.array :param loss: A function from :xcquinox.loss: that is decorated with @eqx.filter_value_and_grad :type loss: Callable :param steps: Number of epochs to train over, defaults to 1000 :type steps: int, optional :param print_every: How often to print loss statistic, defaults to 100 :type print_every: int, optional ''' super().__init__() self.model = model self.optim = optim self.inputs = inputs self.ref = ref self.steps = steps self.print_every = print_every self.opt_state = self.optim.init(eqx.filter(self.model, eqx.is_array)) self.loss = loss
[docs] def __call__(self): ''' The training loop itself. Here, a loop over the specifed epochs takes place to train the network to fit reference values. :return: The trained model and an array of the losses during training. :rtype: (:xcquinox.net: class, array) ''' losses = [] for epoch in range(self.steps): if epoch == 0: this_model = self.model this_opt_state = self.opt_state loss, this_model, this_opt_state = self.make_step(this_model, self.inputs, self.ref, this_opt_state) lossi = loss.item() losses.append(lossi) if epoch % self.print_every == 0: print(f'Epoch {epoch}: Loss = {lossi}') return this_model, losses
[docs] @eqx.filter_jit def make_step(self, model, inputs, ref, opt_state): ''' The function that does each epoch's network update. It generates a loss and gradient using the specific :xcquinox.loss: function (that must be decorated with @eqx.filter_value_and_grad and only explicitly returns the loss value inside the function proper) given the specified inputs and reference values and initial optimization state. :param model: The enhancement factor network to be pre-trained :type model: :xcquinox.net: class :param inputs: The inputs the network itself is expecting in its forward pass function :type inputs: jnp.array :param ref: The reference values the network is expected to reproduce :type ref: jnp.array :param opt_state: The INITIAL optimization state to work against, typically generated via :self.optim.init(eqx.filter(self.model, eqx.is_array)): :type opt_state: The type of the above :return: The loss value for this step, the updated model after that loss is calculated, and the new optimization state for this step to use next time :rtype: tuple ''' loss, grad = self.loss(model, inputs, ref) updates, opt_state = self.optim.update(grad, opt_state) model = eqx.apply_updates(model, updates) return loss, model, opt_state
# Pre-trainer class Pretrainer_deriv(eqx.Module): model: eqx.Module optim: optax.GradientTransformation steps: int print_every: int opt_state: tuple inputs: jnp.array ref_v: jnp.array ref_g: tuple loss: Callable def __init__(self, model, optim, inputs, ref, loss, steps=1000, print_every=100): ''' The Pretrainer object aids in the initial pre-training of enhancement factor networks to have a more physical starting point for further network optimization. This class is meant to pre-train a randomly initialized network to fit the values of a specific XC functional's enhancement factor (either X or C, in principle it could also be a combined XC enhancement facator) :param model: The enhancement factor network to be pre-trained :type model: :xcquinox.net: class :param optim: The optimizer than will control the weight updates given a loss and gradient :type optim: optax.GradientTransformation :param inputs: The inputs the network itself is expecting in its forward pass function :type inputs: jnp.array :param ref: The reference values the network is expected to reproduce. If the shape is (N), it is assumed that the network is being trained to reproduce values only. If the shape is (N, 3), it is assumed that the network is being trained to reproduce values and gradients. The first column of the array is the values, the second and third are the gradients. :type ref: jnp.array :param loss: A function from :xcquinox.loss: that is decorated with @eqx.filter_value_and_grad :type loss: Callable :param steps: Number of epochs to train over, defaults to 1000 :type steps: int, optional :param print_every: How often to print loss statistic, defaults to 100 :type print_every: int, optional ''' super().__init__() self.model = model self.optim = optim self.inputs = inputs if len(ref.shape) == 1: self.ref_v = ref self.ref_g = None elif len(ref.shape) == 2 and ref.shape[1] == 3: self.ref_v = ref[:, 0] self.ref_g = (ref[:, 1], ref[:, 2]) else: raise NotImplementedError(f'Invalid reference shape: {ref.shape} - only implemented for training with values and gradients for GGA') self.steps = steps self.print_every = print_every self.opt_state = self.optim.init(eqx.filter(self.model, eqx.is_array)) # We see how many imputs does loss need, and fix the references. params_loss = len(inspect.signature(loss).parameters) if params_loss == 3: print('Pretrainer prepared to train using values') if self.ref_g is not None: warnings.warn('Reference values provided with gradients, but loss function only takes values. The gradients will be ignored.') fixed_loss = partial(loss, ref=self.ref_v) elif params_loss == 4: print('Pretrainer prepared to train using values and gradients') if self.ref_g is None: raise ValueError('Reference values provided without gradients, but loss function expects values and gradients.') else: fixed_loss = partial(loss, ref_v=self.ref_v, ref_g=self.ref_g) self.loss = fixed_loss def __call__(self): ''' The training loop itself. Here, a loop over the specifed epochs takes place to train the network to fit reference values. :return: The trained model and an array of the losses during training. :rtype: (:xcquinox.net: class, array) ''' losses = [] for epoch in range(self.steps): if epoch == 0: this_model = self.model this_opt_state = self.opt_state loss, this_model, this_opt_state = self.make_step(this_model, self.inputs, this_opt_state) lossi = loss.item() losses.append(lossi) if epoch % self.print_every == 0: print(f'Epoch {epoch}: Loss = {lossi}') return this_model, losses @eqx.filter_jit def make_step(self, model, inputs, opt_state): ''' The function that does each epoch's network update. It generates a loss and gradient using the specific :xcquinox.loss: function (that must be decorated with @eqx.filter_value_and_grad and only explicitly returns the loss value inside the function proper) given the specified inputs and reference values and initial optimization state. :param model: The enhancement factor network to be pre-trained :type model: :xcquinox.net: class :param inputs: The inputs the network itself is expecting in its forward pass function :type inputs: jnp.array :param ref: The reference values the network is expected to reproduce :type ref: jnp.array :param opt_state: The INITIAL optimization state to work against, typically generated via :self.optim.init(eqx.filter(self.model, eqx.is_array)): :type opt_state: The type of the above :return: The loss value for this step, the updated model after that loss is calculated, and the new optimization state for this step to use next time :rtype: tuple ''' loss, grad = self.loss(model, inputs) updates, opt_state = self.optim.update(grad, opt_state) model = eqx.apply_updates(model, updates) return loss, model, opt_state # Optimizer
[docs] class Optimizer(eqx.Module): model: eqx.Module optim: optax.GradientTransformation steps: int print_every: int opt_state: tuple mols: list refs: jnp.array loss: Callable
[docs] def __init__(self, model, optim, mols, refs, loss, steps=1000, print_every=100): ''' The Pretrainer object aids in the initial pre-training of enhancement factor networks to have a more physical starting point for further network optimization. This class is meant to pre-train a randomly initialized network to fit the values of a specific XC functional's enhancement factor (either X or C, in principle it could also be a combined XC enhancement facator) :param model: The enhancement factor network to be pre-trained :type model: :xcquinox.net: class :param optim: The optimizer than will control the weight updates given a loss and gradient :type optim: optax.GradientTransformation :param inputs: The inputs the network itself is expecting in its forward pass function :type inputs: jnp.array :param ref: The reference values the network is expected to reproduce :type ref: jnp.array :param loss: A function from :xcquinox.loss: that is decorated with @eqx.filter_value_and_grad :type loss: Callable :param steps: Number of epochs to train over, defaults to 1000 :type steps: int, optional :param print_every: How often to print loss statistic, defaults to 100 :type print_every: int, optional ''' super().__init__() self.model = model self.optim = optim self.mols = mols self.refs = refs self.steps = steps self.print_every = print_every self.opt_state = self.optim.init(eqx.filter(self.model, eqx.is_array)) self.loss = loss
[docs] def __call__(self): ''' The training loop itself. Here, a loop over the specifed epochs takes place to train the network to fit reference values. :return: The trained model and an array of the losses during training. :rtype: (:xcquinox.net: class, array) ''' losses = [] for epoch in range(self.steps): if epoch == 0: this_model = self.model this_opt_state = self.opt_state loss, this_model, this_opt_state = self.make_step(this_model, self.mols, self.refs, this_opt_state) lossi = loss.item() losses.append(lossi) if epoch % self.print_every == 0: print(f'Epoch {epoch}: Loss = {lossi}') return this_model, losses
# @eqx.filter_jit
[docs] def make_step(self, model, inputs, ref, opt_state): ''' The function that does each epoch's network update. It generates a loss and gradient using the specific :xcquinox.loss: function (that must be decorated with @eqx.filter_value_and_grad and only explicitly returns the loss value inside the function proper) given the specified inputs and reference values and initial optimization state. :param model: The enhancement factor network to be pre-trained :type model: :xcquinox.net: class :param inputs: The inputs the network itself is expecting in its forward pass function :type inputs: jnp.array :param ref: The reference values the network is expected to reproduce :type ref: jnp.array :param opt_state: The INITIAL optimization state to work against, typically generated via :self.optim.init(eqx.filter(self.model, eqx.is_array)): :type opt_state: The type of the above :return: The loss value for this step, the updated model after that loss is calculated, and the new optimization state for this step to use next time :rtype: tuple ''' loss, grad = self.loss(model, self.mols, self.refs) updates, opt_state = self.optim.update(grad, opt_state) model = eqx.apply_updates(model, updates) return loss, model, opt_state