# Source code for qiskit_machine_learning.algorithms.distribution_learners.qgan.pytorch_discriminator

# This code is part of Qiskit.
# (C) Copyright IBM 2019, 2022.
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""PyTorch Discriminator Neural Network"""

from typing import Dict, Any, Iterable, Optional, Sequence, cast
import os
import numpy as np
from qiskit.utils import QuantumInstance
import qiskit_machine_learning.optionals as _optionals
from ....deprecation import deprecate_function
from .discriminative_network import DiscriminativeNetwork

@_optionals.HAS_TORCH.require_in_instance
class PyTorchDiscriminator(DiscriminativeNetwork):
    """
    Discriminator based on PyTorch.

    Wraps a ``DiscriminatorNet`` together with an Adam optimizer and exposes the
    ``DiscriminativeNetwork`` interface (labelling, loss, gradient penalty and a
    single training step).
    """

    @deprecate_function(
        "0.5.0",
        additional_msg="with no direct replacement for it. "
        "Instead, please refer to the new QGAN tutorial",
        stack_level=3,
    )
    def __init__(self, n_features: int = 1, n_out: int = 1) -> None:
        """
        Args:
            n_features: Dimension of input data vector.
            n_out: Dimension of the discriminator's output vector.
        """
        super().__init__()
        # pylint: disable=import-outside-toplevel
        from torch import optim

        self._n_features = n_features
        self._n_out = n_out
        # discriminator_net: torch.nn.Module or None, Discriminator network.
        # pylint: disable=import-outside-toplevel
        from ._pytorch_discriminator_net import DiscriminatorNet

        self._discriminator = DiscriminatorNet(self._n_features, self._n_out)
        # optimizer: torch.optim.Optimizer, initialized w.r.t. the discriminator
        # network parameters (Adam, lr=1e-5, AMSGrad variant enabled).
        self._optimizer = optim.Adam(self._discriminator.parameters(), lr=1e-5, amsgrad=True)

        # Result dict that ``train`` updates in place and returns on every call.
        self._ret = {}  # type: Dict[str, Any]

    def set_seed(self, seed: int):
        """
        Set the global PyTorch random seed.

        Args:
            seed: seed passed to ``torch.manual_seed``.
        """
        # pylint: disable=import-outside-toplevel
        import torch

        torch.manual_seed(seed)

    def save_model(self, snapshot_dir: str):
        """
        Save the discriminator model.

        The whole ``torch.nn.Module`` is serialized with ``torch.save`` into
        ``<snapshot_dir>/discriminator.pt``.

        Args:
            snapshot_dir: directory path for saving the model
        """
        # pylint: disable=import-outside-toplevel
        import torch

        torch.save(self._discriminator, os.path.join(snapshot_dir, "discriminator.pt"))

    def load_model(self, load_dir: str):
        """
        Load a discriminator model and replace the current network with it.

        Args:
            load_dir: file with stored PyTorch discriminator model to be loaded
                (e.g. the ``discriminator.pt`` file written by ``save_model``).
        """
        # pylint: disable=import-outside-toplevel
        import torch

        # SECURITY: torch.load unpickles arbitrary Python objects; only load files
        # from trusted sources.
        # NOTE(review): newer torch releases default to ``weights_only=True``, which
        # refuses to unpickle a full nn.Module as saved by ``save_model`` -- confirm
        # against the supported torch version range.
        self._discriminator = torch.load(load_dir)

    @property
    def discriminator_net(self):
        """
        Get discriminator

        Returns:
            object: discriminator object
        """
        return self._discriminator

    @discriminator_net.setter
    def discriminator_net(self, net):
        self._discriminator = net

    @staticmethod
    def _as_tensor(x):
        """Return ``x`` unchanged if it is a ``torch.Tensor``, else a float32 tensor copy."""
        # pylint: disable=not-callable, no-member, import-outside-toplevel
        import torch

        if isinstance(x, torch.Tensor):
            return x
        # ``torch.autograd.Variable`` is deprecated and a no-op wrapper for a freshly
        # created tensor, so a plain tensor is used here.
        return torch.tensor(x, dtype=torch.float32)

    def get_label(self, x, detach=False):  # pylint: disable=arguments-differ
        """
        Get data sample labels, i.e. true or fake.

        Args:
            x (Union(numpy.ndarray, torch.Tensor)): Discriminator input, i.e. data sample.
                Non-tensor input is converted to a float32 tensor.
            detach (bool): if True, detach the output from the autograd graph and
                return it as a ``numpy.ndarray`` (optional).

        Returns:
            Union(torch.Tensor, numpy.ndarray): Discriminator output, i.e. data label;
            a ``numpy.ndarray`` when ``detach`` is True, else a ``torch.Tensor``.
        """
        output = self._discriminator.forward(self._as_tensor(x))
        if detach:
            return output.detach().numpy()
        return output

    def loss(self, x, y, weights=None):
        """
        Binary cross-entropy loss function.

        Args:
            x (torch.Tensor): Discriminator output.
            y (torch.Tensor): Label of the data point
            weights (torch.Tensor): Data weights. When given, the per-sample losses
                are weighted and summed; otherwise the unweighted mean is used.

        Returns:
            torch.Tensor: Loss w.r.t to the generated data points.
        """
        # pylint: disable=import-outside-toplevel
        from torch import nn

        if weights is not None:
            loss_funct = nn.BCELoss(weight=weights, reduction="sum")
        else:
            loss_funct = nn.BCELoss()
        return loss_funct(x, y)

    def gradient_penalty(self, x, lambda_=5.0, k=0.01, c=1.0):
        """
        Compute gradient penalty for discriminator optimization.

        The input is perturbed by ``c`` times a uniform [0, 1) sample; the penalty is
        ``lambda_ * mean((||d D(z) / d z||_2 - k) ** 2)`` over the batch, where ``z`` is
        the perturbed input.

        Args:
            x (numpy.ndarray): Generated data sample.
            lambda_ (float): Gradient penalty coefficient 1.
            k (float): Gradient penalty coefficient 2.
            c (float): Gradient penalty coefficient 3.

        Returns:
            torch.Tensor: Gradient penalty.
        """
        # pylint: disable=not-callable, no-member, import-outside-toplevel
        import torch

        x = self._as_tensor(x)
        delta_ = torch.rand(x.size()) * c
        # Fresh leaf tensor so the gradient is taken w.r.t. the perturbed input only
        # (equivalent to the deprecated ``Variable(..., requires_grad=True)``).
        z = (x + delta_).detach().requires_grad_(True)
        o_l = self.get_label(z)
        # create_graph=True keeps the penalty differentiable so it can be backpropagated.
        d_g = torch.autograd.grad(o_l, z, grad_outputs=torch.ones(o_l.size()), create_graph=True)[
            0
        ].view(z.size(0), -1)
        return lambda_ * ((d_g.norm(p=2, dim=1) - k) ** 2).mean()

    def train(
        self,
        data: Iterable,
        weights: Iterable,
        penalty: bool = False,
        quantum_instance: Optional[QuantumInstance] = None,
        shots: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Perform one training step w.r.t to the discriminator's parameters.

        Args:
            data: Data batch; ``data[0]`` holds the real samples and ``data[1]`` the
                generated samples.
            weights: Data sample weights, indexed like ``data`` (``weights[0]`` for the
                real samples, ``weights[1]`` for the generated samples).
            penalty: Indicate whether or not penalty function is applied to the loss
                function. If True, the gradient penalty evaluated on the real batch is
                backpropagated as well.
            quantum_instance (QuantumInstance): used to run Quantum network.
                Ignored for a classical network.
            shots: Number of shots for hardware or qasm execution.
                Ignored for classical network

        Returns:
            dict: with discriminator loss (key ``"loss"``) and updated parameters
            (key ``"params"``, a list of numpy arrays). The same dict object is
            updated and returned on every call.
        """
        # pylint: disable=E1101
        # pylint: disable=E1102
        # pylint: disable=import-outside-toplevel
        import torch

        # Reset gradients
        self._optimizer.zero_grad()
        real_batch = cast(Sequence, data)[0]
        real_prob = cast(Sequence, weights)[0]
        generated_batch = cast(Sequence, data)[1]
        generated_prob = cast(Sequence, weights)[1]

        real_batch = np.reshape(real_batch, (len(real_batch), self._n_features))
        real_batch = torch.tensor(real_batch, dtype=torch.float32)
        real_prob = np.reshape(real_prob, (len(real_prob), 1))
        real_prob = torch.tensor(real_prob, dtype=torch.float32)

        # Train on Real Data (target label 1)
        prediction_real = self.get_label(real_batch)

        # Calculate error and back propagate
        error_real = self.loss(prediction_real, torch.ones(len(prediction_real), 1), real_prob)
        error_real.backward()

        # Train on Generated Data (target label 0)
        generated_batch = np.reshape(generated_batch, (len(generated_batch), self._n_features))
        generated_prob = np.reshape(generated_prob, (len(generated_prob), 1))
        generated_prob = torch.tensor(generated_prob, dtype=torch.float32)
        prediction_fake = self.get_label(generated_batch)

        # Calculate error and back propagate
        error_fake = self.loss(
            prediction_fake, torch.zeros(len(prediction_fake), 1), generated_prob
        )
        error_fake.backward()

        if penalty:
            self.gradient_penalty(real_batch).backward()
        # pylint: enable=E1101
        # pylint: enable=E1102
        # Update weights with the accumulated gradients
        self._optimizer.step()

        # Return the averaged error and the current network parameters
        loss_ret = 0.5 * (error_real + error_fake)
        self._ret["loss"] = loss_ret.detach().numpy()
        # NOTE: ``Tensor.numpy()`` shares memory with the tensor, so these arrays
        # reflect later in-place optimizer updates.
        self._ret["params"] = [
            param.data.detach().numpy() for param in self._discriminator.parameters()
        ]
        return self._ret