# Source code for campa.tl._estimator

from __future__ import annotations

from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    from campa.tl import Experiment
    import numpy as np

import os
import logging

import pandas as pd
import tensorflow as tf

from campa.tl import LossEnum, ModelEnum
from campa.data import NNDataset
from campa.tl._layers import UpdateSparsityLevel


# --- Callbacks ---
class LossWarmup(tf.keras.callbacks.Callback):
    """
    Callback that linearly ramps loss weights up to their final values.

    Parameters
    ----------
    weight_vars
        Mapping from loss name to the backend variable holding its current weight.
    to_weights
        Mapping from loss name to the final weight.
    to_epochs
        Mapping from loss name to the epoch from which the final weight applies.
    """

    def __init__(self, weight_vars, to_weights, to_epochs):
        super().__init__()
        self.weight_vars = weight_vars
        self.to_weights = to_weights
        self.to_epochs = to_epochs

    def on_epoch_begin(self, epoch, logs=None):
        """Set every warmed-up loss weight for ``epoch`` and report the values."""
        backend = tf.keras.backend
        for name, warmup_end in self.to_epochs.items():
            target = self.to_weights[name]
            reached = warmup_end == 0 or epoch >= warmup_end
            # before ``warmup_end`` the weight grows proportionally to the epoch index
            value = target if reached else target / warmup_end * epoch
            backend.set_value(self.weight_vars[name], value)
            print(f"set {name} loss weight to {backend.get_value(self.weight_vars[name])}")

        # the latent weight is reported whenever it exists
        if "latent" in self.weight_vars:
            print(f"set latent loss weight to {backend.get_value(self.weight_vars['latent'])}")


class AnnealTemperature(tf.keras.callbacks.Callback):
    """
    Callback to anneal a temperature variable linearly over the epochs.

    Parameters
    ----------
    temperature
        Backend variable that is updated at the start of every epoch.
    initial_temperature
        Value used at epoch 0.
    final_temperature
        Value used from epoch ``to_epoch`` onwards.
    to_epoch
        Epoch at which the final temperature is reached. With 0 the final
        temperature is used right away.
    """

    def __init__(self, temperature, initial_temperature, final_temperature, to_epoch):
        super().__init__()
        self.temperature = temperature
        self.to_epoch = to_epoch
        self.initial_temperature = initial_temperature
        self.final_temperature = final_temperature

    def on_epoch_begin(self, epoch, logs=None):
        """Set the temperature for ``epoch`` and report it."""
        backend = tf.keras.backend
        reached = self.to_epoch == 0 or epoch >= self.to_epoch
        if reached:
            value = self.final_temperature
        else:
            # linear interpolation between the initial and the final temperature
            value = (
                self.initial_temperature + (self.final_temperature - self.initial_temperature) / self.to_epoch * epoch
            )
        backend.set_value(self.temperature, value)
        print(f"set temperature to {backend.get_value(self.temperature)}")


# --- Estimator class ---
class Estimator:
    """
    Neural network estimator.

    Handles training and evaluation of models.

    Parameters
    ----------
    exp
        Experiment with model config.
    """

    def __init__(self, exp: Experiment):
        self.log = logging.getLogger(self.__class__.__name__)
        self.exp = exp
        self.config = exp.estimator_config
        # replace the configured loss / metric identifiers by the functions they name
        self.config["training"]["loss"] = {
            key: LossEnum(val).get_fn() for key, val in self.config["training"]["loss"].items()
        }
        self.config["training"]["metrics"] = {
            key: LossEnum(val).get_fn() for key, val in self.config["training"]["metrics"].items()
        }
        self.callbacks: list[object] = []

        # create model
        self.optimizer = None
        self.epoch = 0
        self.create_model()
        # create_model may already have compiled the model in order to load weights;
        # the flag is reset so that training / evaluation compile it again.
        self.compiled_model = False

        # train and val datasets
        # config params impacting y
        self.output_channels = self.config["data"]["output_channels"]
        self.repeat_y = len(self.config["training"]["loss"].keys())
        if self.repeat_y == 1:
            self.repeat_y = False
        self.add_c_to_y = False
        if "adv_head" in self.config["training"]["loss"].keys():
            self.add_c_to_y = True
            self.repeat_y = self.repeat_y - 1
        self.ds = NNDataset(
            self.config["data"]["dataset_name"],
            data_config=self.config["data"]["data_config"],
        )
        # the tf datasets are created lazily by the properties below
        self._train_dataset, self._val_dataset, self._test_dataset = None, None, None

        # set up model weights and history paths for saving/loading later
        self.weights_name = os.path.join(self.exp.full_path, "weights_epoch{:03d}")  # noqa: P103
        self.history_name = os.path.join(self.exp.full_path, "history.csv")

    @property
    def train_dataset(self) -> tf.data.Dataset:
        """
        Shuffled :class:`tf.data.Dataset` of train split.
        """
        if self._train_dataset is None:
            self._train_dataset = self._get_dataset("train", shuffled=True)
        return self._train_dataset

    @property
    def val_dataset(self) -> tf.data.Dataset:
        """
        :class:`tf.data.Dataset` of val split.
        """
        if self._val_dataset is None:
            self._val_dataset = self._get_dataset("val")
        return self._val_dataset

    @property
    def test_dataset(self) -> tf.data.Dataset:
        """
        :class:`tf.data.Dataset` of test split.
        """
        if self._test_dataset is None:
            self._test_dataset = self._get_dataset("test")
        return self._test_dataset

    def _get_dataset(self, split: str, shuffled: bool = False) -> tf.data.Dataset:
        """Build the tf dataset of ``split`` with the y-layout derived from the loss config."""
        return self.ds.get_tf_dataset(
            split=split,
            output_channels=self.output_channels,
            is_conditional=self.model.is_conditional,
            repeat_y=self.repeat_y,
            add_c_to_y=self.add_c_to_y,
            shuffled=shuffled,
        )

    def create_model(self):
        """
        Initialise neural network model.

        Adds ``self.model``. If ``config["model"]["init_with_weights"]`` is a path
        (or ``True``, meaning the latest checkpoint in the experiment folder),
        the model is compiled and the weights are loaded from there.
        """
        ModelClass = ModelEnum(self.config["model"]["model_cls"]).get_cls()
        self.model = ModelClass(**self.config["model"]["model_kwargs"])
        weights_path = self.config["model"]["init_with_weights"]
        if weights_path is True:
            weights_path = tf.train.latest_checkpoint(self.exp.full_path)
            if weights_path is None:
                self.log.warning(
                    f"WARNING: weights_path set to true but no trained model found in {self.exp.full_path}"
                )
        if isinstance(weights_path, str):
            # first need to compile the model
            self._compile_model()
            self.log.info(f"Initializing model with weights from {weights_path}")
            # sanity check: the weights of one layer must differ before and after loading
            w1 = self.model.model.layers[5].get_weights()
            self.model.model.load_weights(weights_path).assert_nontrivial_match().assert_existing_objects_matched()
            w2 = self.model.model.layers[5].get_weights()
            assert (w1[0] != w2[0]).any()
            assert (w1[1] != w2[1]).any()
            self.epoch = self.exp.epoch
            # TODO when fine-tuning need to reset self.epoch!

    def _compile_model(self):
        """
        Compile ``self.model.model`` and register the training callbacks.

        Creates ``self.loss_weights`` (one backend variable per loss) and, on
        the first call, the optimizer. May be called repeatedly; callbacks
        registered by an earlier call are replaced rather than duplicated.
        """
        config = self.config["training"]
        # set loss weights
        self.loss_weights = {key: tf.keras.backend.variable(val) for key, val in config["loss_weights"].items()}
        # Drop the callbacks a previous compilation registered: they would run twice
        # per epoch, and an old LossWarmup would update variables that are no longer
        # the loss weights of the compiled model. Other callbacks are kept.
        managed = (LossWarmup, UpdateSparsityLevel, AnnealTemperature)
        self.callbacks[:] = [cb for cb in self.callbacks if not isinstance(cb, managed)]
        # callback to update weights before each epoch
        self.callbacks.append(
            LossWarmup(
                self.loss_weights,
                config["loss_weights"],
                config["loss_warmup_to_epoch"],
            )
        )
        self.callbacks.append(UpdateSparsityLevel())
        if hasattr(self.model, "temperature"):
            self.callbacks.append(
                AnnealTemperature(
                    self.model.temperature,
                    self.model.config["initial_temperature"],
                    self.model.config["temperature"],
                    self.model.config["anneal_epochs"],
                )
            )
        # create optimizer
        if self.optimizer is None:
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=config["learning_rate"])
        self.model.model.compile(
            optimizer=self.optimizer,
            loss=config["loss"],
            loss_weights=self.loss_weights,
            metrics=config["metrics"],
        )
        self.compiled_model = True

    def train_model(self):
        """
        Train neural network model.

        Needs an initialised model in ``self.model``.

        Returns
        -------
        :class:`pandas.DataFrame`
            Training history indexed by epoch. If history saving is enabled and
            a previous history file is not overwritten, it is prepended.
        """
        config = self.config["training"]
        if not self.compiled_model:
            self._compile_model()
        # reset epoch when overwriting history
        if config["overwrite_history"]:
            self.epoch = 0
        self.log.info(f"Training model for {config['epochs']} epochs")
        history = self.model.model.fit(
            # TODO this is only shuffling the first 10000 samples, but as data is shuffled already should be ok
            x=self.train_dataset.shuffle(10000).batch(config["batch_size"]).prefetch(1),
            validation_data=self.val_dataset.batch(config["batch_size"]).prefetch(1),
            epochs=config["epochs"],
            verbose=1,
            callbacks=self.callbacks,
        )
        self.epoch += config["epochs"]
        history = pd.DataFrame.from_dict(history.history)
        # label the rows with the absolute epochs covered by this call
        history["epoch"] = range(self.epoch - config["epochs"], self.epoch)
        history = history.set_index("epoch")
        if config["save_model_weights"]:
            weights_name = self.weights_name.format(self.epoch)
            self.log.info(f"Saving model to {weights_name}")
            self.model.model.save_weights(weights_name)
        if config["save_history"]:
            if os.path.exists(self.history_name) and not config["overwrite_history"]:
                # if there is a previous history, concatenate to this
                prev_history = pd.read_csv(self.history_name, index_col=0)
                history = pd.concat([prev_history, history])
            history.to_csv(self.history_name)
        return history

    def predict_model(self, data: tf.data.Dataset | np.ndarray, batch_size: int | None = None) -> Any:
        """
        Predict all elements in ``data``.

        Parameters
        ----------
        data
            Data to predict, with first dimension the number of elements.
        batch_size
            Batch size. If None, the training batch size is used.

        Returns
        -------
        ``Iterable``
            prediction
        """
        if batch_size is None:
            batch_size = self.config["training"]["batch_size"]
        if isinstance(data, tf.data.Dataset):
            # datasets are batched here; keras must not get a batch_size for them
            data = data.batch(batch_size)
            batch_size = None
        pred = self.model.model.predict(data, batch_size=batch_size)
        if isinstance(pred, list):
            # multiple output model, but only care about first output
            pred = pred[0]
        return pred

    def evaluate_model(self, dataset: tf.data.Dataset | None = None) -> Any:
        """
        Evaluate model using :class:`tf.data.Dataset`.

        Parameters
        ----------
        dataset
            Dataset to evaluate. If None, :meth:`Estimator.val_dataset` is used.

        Returns
        -------
        ``Iterable[float]``
            Scores.
        """
        if not self.compiled_model:
            self._compile_model()
        if dataset is None:
            dataset = self.val_dataset
        self.model.model.reset_metrics()
        scores = self.model.model.evaluate(dataset.batch(self.config["training"]["batch_size"]))
        return scores