# Source code for campa.tl._models

# TODO go through functions and comment
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, Tuple, Optional
from functools import partial
import json
import logging

import tensorflow as tf


class ModelEnum(str, Enum):
    """
    Neural network models.

    Possible values are:

    - ``ModelEnum.BaseAEModel``: "BaseAEModel" (:class:`BaseAEModel`)
    - ``ModelEnum.VAEModel``: "VAEModel" (:class:`VAEModel`)
    - ``ModelEnum.CatVAEModel``: "CatVAEModel" (:class:`CatVAEModel`)
    - ``ModelEnum.GMMVAEModel``: "GMMVAEModel" (:class:`GMMVAEModel`)
    - ``ModelEnum.CondCatVAEModel``: "CondCatVAEModel" (:class:`CondCatVAEModel`)
    """

    BaseAEModel = "BaseAEModel"
    VAEModel = "VAEModel"
    CatVAEModel = "CatVAEModel"
    GMMVAEModel = "GMMVAEModel"
    CondCatVAEModel = "CondCatVAEModel"

    def get_cls(self):
        """
        Get model class from enum string.

        Returns
        -------
        The model class registered for this enum member.

        Raises
        ------
        NotImplementedError
            If no class is registered for this member.
        """
        members = self.__class__
        # dispatch table: enum member -> model class defined in this module
        registry = {
            members.BaseAEModel: BaseAEModel,
            members.VAEModel: VAEModel,
            members.CatVAEModel: CatVAEModel,
            members.GMMVAEModel: GMMVAEModel,
            members.CondCatVAEModel: CondCatVAEModel,
        }
        model_cls = registry.get(self)
        if model_cls is None:
            raise NotImplementedError
        return model_cls
# --- tf functions needed for model definition ---
def expand_and_broadcast(x, s=1):
    """Expand tensor x with shape (batches,n) to shape (batches,s,s,n)."""
    # insert the two singleton axes one after the other, then tile them to size s
    expanded = tf.expand_dims(x, 1)
    expanded = tf.expand_dims(expanded, 1)
    target_shape = [tf.shape(expanded)[0], s, s, tf.shape(expanded)[-1]]
    return tf.broadcast_to(expanded, target_shape)


def reparameterize_gumbel_softmax(latent, temperature=0.1):
    """
    Draw a sample from the Gumbel-Softmax distribution.

    Gumbel(0, 1) noise is added to ``latent`` and the result is passed through a
    softmax scaled by ``temperature``.
    """
    eps = 1e-20  # keeps both logarithms away from log(0)
    uniform = tf.random.uniform(tf.shape(latent), minval=0, maxval=1)
    gumbel_noise = -tf.math.log(-tf.math.log(uniform + eps) + eps)
    perturbed = latent + gumbel_noise
    return tf.nn.softmax(perturbed / temperature)


def reparameterize_gaussian(latent):
    """
    Draw a sample from Gaussian distribution.

    ``latent`` is split in half along the last axis into mean and log-variance.
    """
    mean, log_var = tf.split(latent, 2, axis=-1)
    noise = tf.random.normal(shape=tf.shape(mean))
    return noise * tf.exp(log_var * 0.5) + mean


@tf.custom_gradient
def grad_reverse(x):
    """Reverse gradients for adversarial loss (identity in the forward pass)."""
    forward = tf.identity(x)

    def negate_gradient(dy):
        return -dy

    return forward, negate_gradient


class GradReverse(tf.keras.layers.Layer):
    """Keras layer wrapping :func:`grad_reverse`."""

    def __init__(self):
        super().__init__()

    def call(self, x):
        """Apply gradient reversal."""
        return grad_reverse(x)


# --- Model classes ---
BASE_MODEL_CONFIG: Dict[str, Any] = {
    "name": None,
    # --- input definition ---
    "num_neighbors": 1,
    "num_channels": 35,
    # num_channels can be split up in num_input_channels and num_output_channels
    "num_output_channels": None,
    "num_input_channels": None,
    # conditions are appended to the input and the latent representation; they are assumed to be 1d.
    # if > 0, the model is conditional and an additional input w/ conditions is assumed.
    "num_conditions": 0,
    # if number or list, the condition is encoded using dense layers with this number of nodes
    "encode_condition": None,
    # indices of the encoder and decoder layers that the condition is applied to
    "condition_injection_layers": [0],
    # --- encoder architecture ---
    # 'gaussian' or 'dropout': adds noise to encoder input
    "input_noise": None,
    "noise_scale": 0,
    "encoder_conv_layers": [32],
    "encoder_conv_kernel_size": [1],
    "encoder_fc_layers": [32, 16],
    # from last encoder layer, a linear fcl to latent_dim is applied.
    # number of nodes in latent space (for some models == number of classes)
    "latent_dim": 16,
    # --- decoder architecture ---
    # from last decoder layer, a linear fcl to num_output_channels is applied
    "decoder_fc_layers": [],
    # decoder regularizer: 'l1' or 'l2'
    "decoder_regularizer": None,
    "decoder_regularizer_weight": 0,
    # for adversarial models, add adversarial layers (only works with categorical conditions)
    "adversarial_layers": None,
}
class BaseAEModel:
    """
    Base class for AE and VAE models.

    This model can have neighbours, conditions (concatenated to input + decoder), and an adversarial head.
    The class defines initialisation functions for setting up AE with encoder and decoder.
    In addition, encoder and decoder input layers are defined (can be overwritten in subclassed functions).

    Subclassed models can define:

    - :attr:`BaseAEModel.default_config` (as class variable).
    - :meth:`BaseAEModel.create_encoder` function, returning encoder model and latent (for KL loss).
    - :meth:`BaseAEModel.create_decoder` function, returning decoder model.
    - :meth:`BaseAEModel.create_model` function, creating overall model (put encoder and decoder together).

    Default architecture of this model:

    - Encoder: ``(noise) - conv layers - fc layers - linear layer to latent_dim``
    - Decoder: ``fc_layers - linear (regularized) layer to num_output_channels``

    Models whose encoder returns a latent additionally output `latent`, the latent space
    (for KL loss computation).
    Adversarial models additionally output `adv_head`, the output of the adversarial head
    (for adversarial loss computation).
    ``(adv_latent - reverse_gradients - adversarial_layers - linear layer to num_conditions)``.

    TODO adversarial models are not tested in this version of the code.

    Parameters
    ----------
    name: Optional[str]
        Model name.
    num_neighbors: int
        Number of neighbours used in input data.
    num_channels: int
        Number of channels for input and output.
        Can be split up in ``num_input_channels`` and ``num_output_channels``.
    num_output_channels: Optional[int]
        Number of output channels. Defaults to ``num_channels``.
    num_input_channels: Optional[int]
        Number of input channels. Defaults to ``num_channels``.
    num_conditions: int
        Number of conditions. If > 0, the model is conditional and an additional input w/ conditions is assumed.
        Conditions are appended to the input and the latent representation. They are assumed to be 1d.
    encode_condition: Optional[Union[int, Iterable[int]]]
        If number or list, the condition is encoded using dense layers with this number of nodes.
    condition_injection_layers: Iterable[int]
        Which layers of encoder and decoder to apply condition to. Give index of layer in encoder and decoder.
    input_noise: Optional[str]
        One of `gaussian`, `dropout`, adds noise to encoder input.
    noise_scale: float
        Scale of Gaussian noise (or dropout rate for `dropout`).
    encoder_conv_layers: Iterable[int]
        Size of convolutional layers for encoder.
    encoder_conv_kernel_size: Union[int, Iterable[int]]
        Kernel size for each encoder convolutional layer. A single int is used for every layer.
    encoder_fc_layers: Iterable[int]
        Size of fully connected encoder layers.
        From last encoder layer, a linear fully connected layer to ``latent_dim`` is applied.
    latent_dim: int
        Number of nodes in latent space (for some models == number of classes).
    decoder_fc_layers: Iterable[int]
        Decoder architecture.
        From last decoder layer, a linear fully connected layer to ``num_output_channels`` is applied.
    decoder_regularizer: Optional[str]
        Regularizer for the final decoder layer, `l1` or `l2`.
    decoder_regularizer_weight: float
        Weight of regularizer.
    adversarial_layers
        For adversarial models, add adversarial layers. Only works with categorical conditions.
    """

    default_config: Dict[str, Any] = {"name": "BaseAEModel"}
    """
    Default config used in every model.
    """

    def __init__(self, **kwargs):
        # set up log and config: base config < class defaults < user kwargs
        self.log = logging.getLogger(self.__class__.__name__)
        self.config = deepcopy(BASE_MODEL_CONFIG)
        self.config.update(self.default_config)
        self.config.update(kwargs)
        # input / output channels fall back to the shared num_channels
        for key in ("num_output_channels", "num_input_channels"):
            if self.config[key] is None:
                self.config[key] = self.config["num_channels"]
        # a single kernel size is applied to every conv layer
        if isinstance(self.config["encoder_conv_kernel_size"], int):
            self.config["encoder_conv_kernel_size"] = [
                self.config["encoder_conv_kernel_size"] for _ in self.config["encoder_conv_layers"]
            ]
        self.log.info("Creating model")
        self.log.debug(f"Creating model with config: {json.dumps(self.config, indent=4)}")

        # set up model
        # input layers for encoder and decoder.
        # The encoder consumes num_input_channels (previously num_channels was used here,
        # which silently ignored a configured num_input_channels).
        self.encoder_input = tf.keras.layers.Input(
            (
                self.config["num_neighbors"],
                self.config["num_neighbors"],
                self.config["num_input_channels"],
            )
        )
        self.decoder_input = tf.keras.layers.Input((self.config["latent_dim"],))
        if self.is_conditional:
            # conditional models take the condition vector as a second input
            self.encoder_input = [
                self.encoder_input,
                tf.keras.layers.Input((self.config["num_conditions"],)),
            ]
            self.decoder_input = [
                self.decoder_input,
                tf.keras.layers.Input((self.config["num_conditions"],)),
            ]
        # set self.encoder, self.latent, self.decoder, self.model_output
        self.model = self.create_model()

        # expose layers and summary here
        self.layers = self.model.layers
        summary = []
        self.model.summary(print_fn=summary.append)
        self.encoder.summary(print_fn=summary.append)
        self.decoder.summary(print_fn=summary.append)
        if self.is_adversarial:
            self.adv_head.summary(print_fn=summary.append)
        self.summary = "\n".join(summary)

    def create_model(self) -> tf.keras.Model:
        """
        Create :class:`tf.keras.Model`.

        Use :meth:`BaseAEModel.create_encoder` and :meth:`BaseAEModel.create_decoder` functions
        to set ``self.encoder``, ``self.latent``, ``self.decoder``, ``self.model_output`` attributes.

        Returns
        -------
        Neural network model.
        """
        # encoder and decoder
        self.encoder, self.latent = self.create_encoder()
        self.decoder = self.create_decoder()
        if self.is_adversarial:
            self.adv_head = self.create_adversarial_head()

        # create model
        if self.is_conditional:
            self.model_output = self.decoder([self.encoder(self.encoder_input), self.encoder_input[1]])
        else:
            self.model_output = self.decoder([self.encoder(self.encoder_input)])
        if self.latent is not None:
            # model should return both output + latent (for KL loss)
            self.model_output = [self.model_output, self.latent]
        if self.is_adversarial:
            adv_output = self.adv_head(self.encoder(self.encoder_input))
            if isinstance(self.model_output, list):
                self.model_output = self.model_output + [adv_output]
            else:
                self.model_output = [self.model_output, adv_output]
        model = tf.keras.Model(self.encoder_input, self.model_output, name=self.config["name"])
        return model

    @property
    def is_conditional(self):
        """Flag set based on ``num_conditions``."""
        return self.config["num_conditions"] > 0

    @property
    def is_adversarial(self):
        """Model is adversarial if it is conditional and adversarial layers are defined."""
        return self.config["adversarial_layers"] is not None and self.is_conditional

    def encode_condition(self, C):
        """
        Apply condition encoder to C.

        The encoder is created on first use and cached as ``self.condition_encoder``, so that
        encoder and decoder share the same condition-encoding weights.

        Parameters
        ----------
        C
            condition (maybe one-hot encoded)

        Returns
        -------
        Encoded condition (``C`` itself if ``encode_condition`` is not configured).
        """
        if self.config["encode_condition"] is None:
            return C
        if not hasattr(self, "condition_encoder"):
            enc_l = self.config["encode_condition"]
            if isinstance(enc_l, int):
                enc_l = [enc_l]
            inpt = tf.keras.layers.Input((self.config["num_conditions"],))
            x = inpt
            for layer in enc_l:
                x = tf.keras.layers.Dense(layer, activation=tf.nn.relu)(x)
            self.condition_encoder = tf.keras.Model(inpt, x, name="condition_encoder")
        return self.condition_encoder(C)

    def _create_base_encoder(self):
        """
        Create base encoder structure with convolutional layers and fully connected layers.

        Does not apply the last (linear) layer to latent_dim - useful for VAE which does this differently.

        Layer indices in ``condition_injection_layers`` count conv layers first, followed by
        the fully connected layers.

        Returns
        -------
        Output tensor of the last fully connected layer.
        """
        C = None
        if self.is_conditional:
            X, C = self.encoder_input
            C = self.encode_condition(C)
        else:
            X = self.encoder_input
        if self.config["input_noise"] is not None:
            # add noise
            X = self.add_noise(X)

        cond_layers = self.config["condition_injection_layers"]
        conv_layers = self.config["encoder_conv_layers"]

        # conv layers
        for i, num_filters in enumerate(conv_layers):
            # check if need to concatenate current X with C
            if self.is_conditional and i in cond_layers:
                # need to broadcast C to fit to X
                fn = partial(expand_and_broadcast, s=self.config["num_neighbors"])
                C_bcast = tf.keras.layers.Lambda(fn)(C)
                X = tf.keras.layers.concatenate([X, C_bcast], axis=-1)
            k = self.config["encoder_conv_kernel_size"][i]
            X = tf.keras.layers.Conv2D(num_filters, kernel_size=(k, k), activation=tf.nn.relu)(X)
        X = tf.keras.layers.Flatten()(X)

        # fully connected layers.
        # Their layer index continues after the conv layers. Use the number of conv layers
        # explicitly instead of the leftover loop variable, which is undefined when there
        # are no conv layers.
        fc_offset = len(conv_layers)
        for j, num_units in enumerate(self.config["encoder_fc_layers"]):
            # check if need to concatenate current X with C
            if self.is_conditional and fc_offset + j in cond_layers:
                X = tf.keras.layers.concatenate([X, C], axis=-1)
            X = tf.keras.layers.Dense(num_units, activation=tf.nn.relu)(X)
        return X

    def create_encoder(self) -> Tuple[tf.keras.Model, Optional[tf.Tensor]]:
        """
        Create encoder.

        Encoder outputs reparameterized latent.
        Latent is potentially returned by overall model for loss calculation, e.g. for VAE.

        Returns
        -------
        Encoder and latent (None for BaseAEModel).
        """
        X = self._create_base_encoder()
        # linear layer to latent
        X = tf.keras.layers.Dense(self.config["latent_dim"], activation=None, name="latent")(X)
        # define encoder model
        encoder = tf.keras.Model(self.encoder_input, X, name="encoder")
        return encoder, None

    def create_decoder(self):
        """
        Create decoder.

        Returns
        -------
        tf.keras.Model
            Decoder.

        Raises
        ------
        NotImplementedError
            If ``decoder_regularizer`` is not one of ``None``, ``"l1"``, ``"l2"``.
        """
        X = self.decoder_input
        C = None
        if self.is_conditional:
            X, C = self.decoder_input
            C = self.encode_condition(C)

        # fully-connected layers
        cond_layers = self.config["condition_injection_layers"]
        for i, num_units in enumerate(self.config["decoder_fc_layers"]):
            # check if need to concatenate current X with C
            if self.is_conditional and i in cond_layers:
                X = tf.keras.layers.concatenate([X, C], axis=-1)
            X = tf.keras.layers.Dense(num_units, activation=tf.nn.relu)(X)
            if i == 0 and self.is_conditional:
                self.entangled_latent = X  # might need this later on

        # if no fully-connected layers are built, need to still concatenate current X with C
        if len(self.config["decoder_fc_layers"]) == 0 and self.is_conditional:
            X = tf.keras.layers.concatenate([X, C], axis=-1)

        # linear layer to num_output_channels (optionally regularized).
        # The regularizer used to be constructed and thrown away, so the config had no effect.
        # NOTE(review): applied as a kernel (weight) regularizer here; confirm that weight
        # rather than activity regularization is intended, and that the training loop adds
        # ``model.losses`` to the objective.
        regularizer_name = self.config["decoder_regularizer"]
        if regularizer_name is None:
            regularizer = None
        elif regularizer_name == "l1":
            regularizer = tf.keras.regularizers.l1(self.config["decoder_regularizer_weight"])
        elif regularizer_name == "l2":
            regularizer = tf.keras.regularizers.l2(self.config["decoder_regularizer_weight"])
        else:
            raise NotImplementedError(f"Unknown decoder_regularizer: {regularizer_name}")
        decoder_output = tf.keras.layers.Dense(
            self.config["num_output_channels"],
            activation=None,
            kernel_regularizer=regularizer,
        )(X)

        # define decoder model
        decoder = tf.keras.Model(self.decoder_input, decoder_output, name="decoder")
        return decoder

    def create_adversarial_head(self) -> tf.keras.Model:
        """
        Create adversarial head: ``reverse_gradient - adversarial_layers - num_conditions``.

        Returns
        -------
        adversarial head.
        """
        assert self.is_conditional
        assert self.is_adversarial

        inpt = tf.keras.layers.Input((self.config["latent_dim"],))
        X = inpt
        # gradients flowing back into the latent are negated
        X = GradReverse()(X)
        for layer in self.config["adversarial_layers"]:
            X = tf.keras.layers.Dense(layer, activation=tf.nn.relu)(X)
        # linear layer to num_conditions
        adv_head_output = tf.keras.layers.Dense(self.config["num_conditions"], activation=None)(X)

        # define adv_head model
        adv_head = tf.keras.Model(inpt, adv_head_output, name="adv_head")
        return adv_head

    def add_noise(self, X):
        """
        Add noise to X.

        Parameters
        ----------
        X
            inputs.

        Returns
        -------
        X with noise applied.

        Raises
        ------
        NotImplementedError
            If ``input_noise`` is neither ``"dropout"`` nor ``"gaussian"``.
        """
        if self.config["input_noise"] == "dropout":
            X = tf.keras.layers.Dropout(self.config["noise_scale"])(X)
        elif self.config["input_noise"] == "gaussian":
            X = tf.keras.layers.GaussianNoise(self.config["noise_scale"])(X)
        else:
            raise NotImplementedError
        return X
class VAEModel(BaseAEModel):
    """
    VAE with simple Gaussian prior (trainable with KL loss).

    Inherits from :class:`BaseAEModel`.

    Model architecture:

    - Encoder: ``(noise) - conv layers - fc layers - linear layer to latent_dim * 2``
    - Latent: split the ``latent_dim * 2`` output in half (mean / log-variance),
      re-sample using Gaussian prior
    - Decoder: ``fc_layers - linear (regularized) layer to num_output_channels``
    """

    default_config = {"name": "VAEModel"}

    def create_encoder(self) -> Tuple[tf.keras.Model, Optional[tf.Tensor]]:
        """
        Create encoder.

        The encoder model outputs the reparameterized (sampled) latent. The raw
        ``latent_dim * 2`` layer output is returned alongside it so that the overall
        model can expose it for loss calculation.

        Returns
        -------
        Encoder and latent.
        """
        features = self._create_base_encoder()
        # linear layer producing the latent_dim * 2 values that are split into two halves
        latent_layer = tf.keras.layers.Dense(self.config["latent_dim"] * 2, activation=None, name="latent")
        latent = latent_layer(features)
        # sample from the Gaussian described by ``latent``
        sampled_latent = reparameterize_gaussian(latent)
        return tf.keras.Model(self.encoder_input, sampled_latent, name="encoder"), latent
class CatVAEModel(BaseAEModel):
    """
    Categorical VAE Model.

    VAE with categorical prior (softmax gumbel) (trainable with categorical loss)

    Encoder: (noise) - conv layers - fc layers - linear layer to latent_dim
    Latent: resample using Gumbel-Softmax with the current ``temperature``
    Decoder: fc_layers - linear (regularized) layer to num_output_channels
    """

    default_config: Dict[str, Any] = {
        "name": "CatVAEModel",
        # temperature for scaling gumbel_softmax. values close to 0 are close to true categorical distribution
        "temperature": 0.1,
        "initial_temperature": 10,
        "anneal_epochs": 0,
    }

    def __init__(self, **kwargs):
        # the temperature variable has to exist before the model graph is built in super().__init__
        self.temperature = tf.Variable(
            initial_value=kwargs.get("initial_temperature", self.default_config["initial_temperature"]),
            trainable=False,
            dtype=tf.float32,
        )
        super().__init__(**kwargs)

    def create_encoder(self):
        """
        Create encoder.

        Encoder outputs reparameterised latent.
        Latent is potentially returned by overall model for loss calculation.

        Returns
        -------
        tf.keras.Model, tf.Tensor
            Encoder and latent (output of the linear ``latent`` layer before resampling).
        """
        X = self._create_base_encoder()
        # linear layer to latent
        latent = tf.keras.layers.Dense(self.config["latent_dim"], activation=None, name="latent")(X)
        # reparameterise
        reparam_latent = reparameterize_gumbel_softmax(latent, self.temperature)
        # define encoder
        encoder = tf.keras.Model(self.encoder_input, reparam_latent, name="encoder")
        return encoder, latent


class CondCatVAEModel(CatVAEModel):
    """
    Conditional Categorical VAE model.

    Conditional Categorical VAE using another concatenation scheme
    when adding the condition to the latent space.

    This model first calculates a fully connected layer to a vector with length
    #output_channels x #(encoded) conditions and then takes the dot product with the
    (encoded) condition.

    IGNORES decoder_fc_layers - only supports linear decoder!
    """

    def create_decoder(self):
        """
        Create decoder.

        Returns
        -------
        tf.keras.Model
            Decoder.

        Raises
        ------
        ValueError
            If the model is not conditional (``num_conditions`` == 0).
        """
        if not self.is_conditional:
            raise ValueError("CondCatVAEModel requires num_conditions > 0")
        X, C = self.decoder_input

        # width of the condition vector after encode_condition():
        # - not encoded: the raw condition (num_conditions)
        # - int: a single dense layer of that size
        # - list: the size of the last dense layer (raw condition if the list is empty)
        # Previously the config value was used directly, which only worked for an int.
        enc = self.config["encode_condition"]
        if enc is None:
            cond_dim = self.config["num_conditions"]
        elif isinstance(enc, int):
            cond_dim = enc
        else:
            enc = list(enc)
            cond_dim = enc[-1] if enc else self.config["num_conditions"]

        # dense layer to num_output_channels x cond_dim
        X = tf.keras.layers.Dense(
            self.config["num_output_channels"] * cond_dim,
            activation=None,
        )(X)
        X = tf.keras.layers.Reshape((self.config["num_output_channels"], cond_dim))(X)
        C = self.encode_condition(C)
        # multiply X by conditions
        decoder_output = tf.keras.layers.Dot(axes=[2, 1])([X, C])

        # define decoder model
        decoder = tf.keras.Model(self.decoder_input, decoder_output, name="decoder")
        return decoder


class GMMVAEModel(BaseAEModel):
    """
    Gaussian Mixture Model VAE.

    VAE with gmm prior (trainable with categorical loss for y and weighted kl loss for z)

    Encoder y: (noise) - conv layers y - fc layers y - softmax layer to k
    Encoder: (noise) + y - conv layers - fc layers - linear layer to latent_dim * 2
    Latent: split the latent_dim * 2 output in half, resample using Gaussian prior
    Decoder: fc_layers - linear (regularized) layer to num_output_channels
    """

    default_config: Dict[str, Any] = {
        "name": "GMMVAEModel",
        # y encoder architecture (each defaults to the corresponding encoder_* setting)
        "y_conv_layers": None,
        "y_conv_kernel_size": None,
        "y_fc_layers": None,
        # pz (gmm prior for zmean and zvar from categorical y); defaults to encoder_fc_layers
        "pz_fc_layers": None,
        # number of different gaussians
        "k": 10,
        # temperature for categorical loss on y <- might not need to anneal!
        # temperature for scaling gumbel_softmax. values close to 0 are close to true categorical distribution
        "temperature": 0.1,
        "initial_temperature": 10,
        "anneal_epochs": 0,
    }

    def __init__(self, **kwargs):
        # make sure y_... and pz_... are defined before the model graph is built
        config = deepcopy(BASE_MODEL_CONFIG)
        config.update(self.default_config)
        config.update(kwargs)
        if config["y_conv_layers"] is None:
            config["y_conv_layers"] = config["encoder_conv_layers"]
        if config["y_conv_kernel_size"] is None:
            config["y_conv_kernel_size"] = config["encoder_conv_kernel_size"]
        if config["y_fc_layers"] is None:
            config["y_fc_layers"] = config["encoder_fc_layers"]
        if config["pz_fc_layers"] is None:
            config["pz_fc_layers"] = config["encoder_fc_layers"]
        # a single kernel size is applied to every y conv layer; the y encoder indexes
        # this setting per layer, so a bare int (given directly or inherited from
        # encoder_conv_kernel_size) has to be expanded here
        if isinstance(config["y_conv_kernel_size"], int):
            config["y_conv_kernel_size"] = [config["y_conv_kernel_size"] for _ in config["y_conv_layers"]]
        super().__init__(**config)

    def qy_graph(self, input_shape):
        """
        Return model calculating Y from X (convolutional layers + fully connected layers).

        Parameters
        ----------
        input_shape
            Shape of X (without batch dimension).

        Returns
        -------
        tf.keras.Model
            Model mapping X to a softmax over the ``k`` categories.
        """
        X_input = tf.keras.layers.Input(input_shape)
        X = X_input
        # conv layers
        for i, num_filters in enumerate(self.config["y_conv_layers"]):
            k = self.config["y_conv_kernel_size"][i]
            X = tf.keras.layers.Conv2D(num_filters, kernel_size=(k, k), activation=tf.nn.relu)(X)
        X = tf.keras.layers.Flatten()(X)
        # fully connected layers
        for layer in self.config["y_fc_layers"]:
            X = tf.keras.layers.Dense(layer, activation=tf.nn.relu)(X)
        # softmax layer to k categories
        Y = tf.keras.layers.Dense(self.config["k"], activation="softmax")(X)
        model = tf.keras.Model(X_input, Y)
        return model

    def create_y_encoder(self, X):
        """
        Return Y calculated from X (convolutional layers + fully connected layers).

        Parameters
        ----------
        X
            Input tensor.

        Returns
        -------
        Softmax output over the ``k`` categories (layer name ``latent_y``).
        """
        # conv layers
        for i, num_filters in enumerate(self.config["y_conv_layers"]):
            k = self.config["y_conv_kernel_size"][i]
            X = tf.keras.layers.Conv2D(num_filters, kernel_size=(k, k), activation=tf.nn.relu)(X)
        X = tf.keras.layers.Flatten()(X)
        # fully connected layers
        for layer in self.config["y_fc_layers"]:
            X = tf.keras.layers.Dense(layer, activation=tf.nn.relu)(X)
        # softmax layer to k categories
        Y = tf.keras.layers.Dense(self.config["k"], activation="softmax", name="latent_y")(X)
        return Y

    def qz_graph(self, X_input_shape):
        """
        Return model calculating Z from X and Y.

        Parameters
        ----------
        X_input_shape
            Shape of X (without batch dimension).

        Returns
        -------
        tf.keras.Model
            Model mapping ``(X, Y)`` to a tensor with ``latent_dim * 2`` values.
        """
        X_input = tf.keras.layers.Input(X_input_shape)
        Y_input = tf.keras.layers.Input((self.config["k"],))
        # concatenate Y with X
        fn = partial(expand_and_broadcast, s=self.config["num_neighbors"])
        Y = tf.keras.layers.Lambda(fn)(Y_input)
        X = tf.keras.layers.concatenate([X_input, Y], axis=-1)
        # this is now the input for the normal encoder
        # conv layers
        for i, num_filters in enumerate(self.config["encoder_conv_layers"]):
            k = self.config["encoder_conv_kernel_size"][i]
            X = tf.keras.layers.Conv2D(num_filters, kernel_size=(k, k), activation=tf.nn.relu)(X)
        X = tf.keras.layers.Flatten()(X)
        # fully connected layers
        for layer in self.config["encoder_fc_layers"]:
            X = tf.keras.layers.Dense(layer, activation=tf.nn.relu)(X)
        # linear layer to latent
        Z = tf.keras.layers.Dense(self.config["latent_dim"] * 2, activation=None)(X)
        model = tf.keras.Model((X_input, Y_input), Z, name="qz_model")
        return model

    def pz_graph(self):
        """Prior distribution of Z for different categories (Y should be 1-hot encoded vector)."""
        Y_input = tf.keras.layers.Input((self.config["k"],))
        X = Y_input
        # fully connected layers
        for layer in self.config["pz_fc_layers"]:
            X = tf.keras.layers.Dense(layer, activation=tf.nn.relu)(X)
        Z = tf.keras.layers.Dense(self.config["latent_dim"] * 2, activation=None)(X)
        model = tf.keras.Model(Y_input, Z, name="pz_model")
        return model

    def create_encoder(self):
        """
        Create encoder.

        Encoder outputs reparameterised latent.
        The latents are returned by the overall model for loss calculation.

        Returns
        -------
        tf.keras.Model, tf.Tensor, tf.Tensor
            Encoder, ``latent_y`` (categorical output of the y encoder) and ``latent_zy``
            (``pZ``, ``qZ`` and broadcast ``qY`` stacked along axis 1).
        """
        if self.is_conditional:
            X, C = self.encoder_input
            # broadcast C to fit to X
            fn = partial(expand_and_broadcast, s=self.config["num_neighbors"])
            C = tf.keras.layers.Lambda(fn)(C)
        else:
            X = self.encoder_input
        if self.config["input_noise"] is not None:
            X = self.add_noise(X)
        if self.is_conditional:
            # concatenate input and conditions
            X = tf.keras.layers.concatenate([X, C], axis=-1)

        # define qz and pz subgraphs
        shape_X = X.shape.as_list()[1:]
        qz_model = self.qz_graph(shape_X)
        pz_model = self.pz_graph()

        # get qy
        latent_y = self.create_y_encoder(X)
        # get qz for qy value (for reconstruction)
        Z = qz_model([X, latent_y])
        # reparameterise
        reparam_Z = reparameterize_gaussian(Z)
        # define encoder
        encoder = tf.keras.Model(self.encoder_input, reparam_Z, name="encoder")

        # get pz, qz for different values of y
        pZ = []
        qZ = []

        # functions for expanding Y to have batch_size dim
        def expand_and_broadcastY(Y, s):
            Y = tf.expand_dims(Y, 0)
            Y = tf.broadcast_to(Y, [s, self.config["k"]])
            return Y

        exp_fn = partial(expand_and_broadcastY, s=tf.shape(X)[0])
        for i in range(0, self.config["k"]):
            Yi = tf.one_hot(i, depth=self.config["k"])
            Yi = tf.keras.layers.Lambda(exp_fn)(Yi)
            pZ.append(pz_model(Yi))
            qZ.append(qz_model([X, Yi]))
        # stack together (shape: None, k, latent_dim * 2)
        pZ = tf.keras.layers.Lambda(lambda x: tf.stack(x, axis=1))(pZ)
        qZ = tf.keras.layers.Lambda(lambda x: tf.stack(x, axis=1))(qZ)

        # expand latent_y to have shape: None, k, latent_dim * 2
        def expand_and_broadcast_qY(Y, s):
            Y = tf.expand_dims(Y, axis=-1)
            Y = tf.broadcast_to(Y, [tf.shape(Y)[0], tf.shape(Y)[1], s])
            return Y

        fn = partial(expand_and_broadcast_qY, s=self.config["latent_dim"] * 2)
        qY = tf.keras.layers.Lambda(fn)(latent_y)
        # stack pZ, qZ, qY to one output
        latent_zy = tf.keras.layers.Lambda(lambda x: tf.stack(x, axis=1), name="latent_zy")([pZ, qZ, qY])

        return encoder, latent_y, latent_zy

    def create_model(self):
        """
        Create keras model using create_encoder and create_decoder functions.

        Set self.encoder, self.latent_y, self.latent_zy, self.encoder_y, self.decoder and
        self.model_output attributes.

        Returns
        -------
        tf.keras.Model
            Model returning the decoder output, ``latent_y`` and ``latent_zy``.
        """
        # encoder and decoder
        self.encoder, self.latent_y, self.latent_zy = self.create_encoder()
        self.decoder = self.create_decoder()
        # add encoder_y model
        self.encoder_y = tf.keras.Model(self.encoder_input, self.latent_y, name="encoder_y")

        # create model
        if self.is_conditional:
            self.model_output = self.decoder([self.encoder(self.encoder_input), self.encoder_input[1]])
        else:
            self.model_output = self.decoder([self.encoder(self.encoder_input)])
        # model should return both output + latent_y (for cat loss) + latent_zy (for KL loss)
        self.model_output = [self.model_output, self.latent_y, self.latent_zy]
        model = tf.keras.Model(self.encoder_input, self.model_output, name=self.config["name"])
        return model