get_tf_dataset

NNDataset.get_tf_dataset(split='train', output_channels=None, is_conditional=False, repeat_y=False, add_c_to_y=False, shuffled=False)

Return a tf.data.Dataset for the requested split.

Parameters
  • split (str) – One of 'train', 'val' or 'test'.

  • output_channels (Optional[Iterable[str]]) – Channels that should be predicted by the neural network. Defaults to all input channels.

  • is_conditional (bool) – Whether to add condition information to the network input x.

  • repeat_y (Union[bool, int]) – Repeat y so that the number of targets matches the number of losses. Keras requires one target per loss, even for losses that do not use y.

  • add_c_to_y (bool) – Append the condition to y. Required for the adversarial loss.

  • shuffled (bool) – Shuffle the indices before generating data. The shuffle is deterministic, so the same order is produced on every call.
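To make the interplay of these flags more concrete, here is a minimal pure-Python sketch of how individual (x, y) elements could be assembled before being wrapped in a tf.data.Dataset. It is not the NNDataset implementation. The function make_elements, the index-based channel selection, and the n_losses and seed parameters are hypothetical. Concatenating c onto y for add_c_to_y and reading repeat_y=True as "one copy per loss" are assumptions made for illustration.

```python
import random
from typing import Iterator, List, Optional, Sequence, Tuple, Union


def make_elements(
    xs: Sequence[List[float]],
    cs: Sequence[List[float]],
    output_idx: Optional[Sequence[int]] = None,
    is_conditional: bool = False,
    repeat_y: Union[bool, int] = False,
    add_c_to_y: bool = False,
    shuffled: bool = False,
    n_losses: int = 1,  # hypothetical: copies of y used when repeat_y is True
    seed: int = 0,  # hypothetical: fixed seed for the deterministic shuffle
) -> Iterator[Tuple[object, object]]:
    """Yield (x, y) pairs; a hypothetical stand-in, not the NNDataset code."""
    indices = list(range(len(xs)))
    if shuffled:
        # A fixed seed makes the shuffled order identical on every call.
        random.Random(seed).shuffle(indices)
    for i in indices:
        x_i, c_i = list(xs[i]), list(cs[i])
        # Default: predict all input channels.
        channels = range(len(x_i)) if output_idx is None else output_idx
        y_i = [x_i[k] for k in channels]
        if add_c_to_y:
            # Assumption: the condition is concatenated onto the target.
            y_i = y_i + c_i
        # Conditional models receive the condition alongside x.
        x_out = (x_i, c_i) if is_conditional else x_i
        # bool is a subclass of int, so check for bool first.
        if isinstance(repeat_y, bool):
            n_copies = n_losses if repeat_y else 1
        else:
            n_copies = repeat_y
        # One independent copy of y per loss, since Keras expects one target per loss.
        y_out = tuple(list(y_i) for _ in range(n_copies)) if n_copies > 1 else y_i
        yield x_out, y_out
```

A generator like this could then be passed to tf.data.Dataset.from_generator together with a matching output_signature, so that the element structure of the resulting dataset follows the flags above.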

Returns

The dataset.

Return type

tf.data.Dataset