PyTorchDiscriminator

class PyTorchDiscriminator(n_features=1, n_out=1)[source]

Bases: DiscriminativeNetwork

Discriminator based on PyTorch

Parameters:
  • n_features (int) – Dimension of input data vector.

  • n_out (int) – Dimension of the discriminator’s output vector.
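
To illustrate how the two parameters shape a network, here is a minimal sketch of a PyTorch module whose input width is n_features and whose output width is n_out. The class name SketchDiscriminatorNet, the hidden width of 64, and the choice of activations are assumptions made for illustration only; they are not taken from this class's source.

```python
import torch
from torch import nn


class SketchDiscriminatorNet(nn.Module):
    """Hypothetical stand-in network; layer sizes are illustrative assumptions."""

    def __init__(self, n_features: int = 1, n_out: int = 1) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, 64),  # input vector of dimension n_features
            nn.LeakyReLU(0.2),
            nn.Linear(64, n_out),       # output vector of dimension n_out
            nn.Sigmoid(),               # squash scores into [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


net = SketchDiscriminatorNet(n_features=2, n_out=1)
batch = torch.rand(8, 2)   # 8 samples, each with 2 features
scores = net(batch)        # one score per sample, shape (8, 1)
```

Each row of scores can be read as the network's estimate that the corresponding sample is real rather than generated.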

Attributes

discriminator_net

Get the discriminator network.

Methods

get_label(x[, detach])

Get the labels of data samples, i.e. real or fake.

gradient_penalty(x[, lambda_, k, c])

Compute gradient penalty for discriminator optimization

load_model(load_dir)

Load discriminator model

loss(x, y[, weights])

Loss function

save_model(snapshot_dir)

Save discriminator model

set_seed(seed)

Set seed.

train(data, weights[, penalty, ...])

Perform one training step w.r.t. the discriminator's parameters.
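
The methods above fit together roughly as follows. This is a sketch in plain PyTorch, not the class's actual implementation: the helper names weighted_bce and gradient_penalty, the default values for lambda_, k and c, the network layout, and the way the loss terms are combined are all assumptions chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the discriminator network (assumed layout).
net = nn.Sequential(nn.Linear(1, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1), nn.Sigmoid())
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)


def weighted_bce(pred, target, weights):
    """Binary cross-entropy with one weight per sample, summed over the batch."""
    eps = 1e-7
    pred = pred.clamp(eps, 1 - eps)  # avoid log(0)
    per_sample = -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred))
    return (weights * per_sample).sum()


def gradient_penalty(model, x, lambda_=5.0, k=0.01, c=1.0):
    """Penalize deviation of the input-gradient norm from k near the data.

    The defaults for lambda_, k and c are assumptions, not documented values.
    """
    # Perturb the inputs with uniform noise scaled by c and track gradients.
    z = (x + c * torch.rand_like(x)).requires_grad_(True)
    out = model(z)
    grads = torch.autograd.grad(
        out, z, grad_outputs=torch.ones_like(out), create_graph=True
    )[0]
    norms = grads.view(z.size(0), -1).norm(p=2, dim=1)
    return lambda_ * ((norms - k) ** 2).mean()


# One training step on a toy batch of real and generated samples.
real = torch.rand(8, 1)
fake = torch.rand(8, 1)
w = torch.full((8, 1), 1.0 / 8)  # uniform sample weights

optimizer.zero_grad()
loss_real = weighted_bce(net(real), torch.ones(8, 1), w)   # real samples labelled 1
loss_fake = weighted_bce(net(fake), torch.zeros(8, 1), w)  # generated samples labelled 0
penalty = gradient_penalty(net, real)
total = 0.5 * (loss_real + loss_fake) + penalty
total.backward()
optimizer.step()
```

Passing create_graph=True lets the penalty itself be differentiated, so its gradient reaches the network parameters during total.backward().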