# Source code for qiskit_machine_learning.algorithms.distribution_learners.qgan.discriminative_network

```
# This code is part of Qiskit.
#
# (C) Copyright IBM 2019, 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
""" Discriminative Quantum or Classical Neural Networks."""
from typing import List, Iterable, Optional, Dict
from abc import ABC, abstractmethod
import numpy as np
from qiskit.utils import QuantumInstance
from ....deprecation import deprecate_function
class DiscriminativeNetwork(ABC):
    """
    Base class for discriminative Quantum or Classical Neural Networks.

    Concrete subclasses should initialize the network in ``__init__`` but
    raise an exception if a required component of the network is not available.
    """

    @abstractmethod
    @deprecate_function(
        "0.5.0",
        additional_msg="with no direct replacement for it. "
        "Instead, please refer to the new QGAN tutorial",
        stack_level=3,
    )
    def __init__(self) -> None:
        super().__init__()
        # Empty defaults; nothing else in this base class modifies them, so
        # concrete subclasses presumably overwrite them during initialization.
        self._num_parameters = 0
        self._num_qubits = 0
        self._bounds: List[object] = []

    @abstractmethod
    def set_seed(self, seed: int):
        """
        Set seed.

        Args:
            seed: Seed value.

        Raises:
            NotImplementedError: not implemented
        """
        raise NotImplementedError()

    @abstractmethod
    def get_label(self, x: Iterable):
        """
        Apply the quantum/classical neural network to the given input sample and
        compute the respective data label.

        Args:
            x: Input, i.e. data sample.

        Raises:
            NotImplementedError: not implemented
        """
        raise NotImplementedError()

    @abstractmethod
    def save_model(self, snapshot_dir: str):
        """
        Save the discriminator model.

        Args:
            snapshot_dir: Directory to save the model in.

        Raises:
            NotImplementedError: not implemented
        """
        raise NotImplementedError()

    @abstractmethod
    def loss(self, x: Iterable, y: Iterable, weights: Optional[np.ndarray] = None):
        """
        Loss function used for optimization.

        Args:
            x: Output.
            y: The data point.
            weights: Data weights.

        Returns:
            Loss w.r.t. the generated data points.

        Raises:
            NotImplementedError: not implemented
        """
        raise NotImplementedError()

    @abstractmethod
    def train(
        self,
        data: Iterable,
        weights: Iterable,
        penalty: bool = False,
        quantum_instance: Optional[QuantumInstance] = None,
        shots: Optional[int] = None,
    ) -> Dict:
        """
        Perform one training step w.r.t. the discriminator's parameters.

        Args:
            data: Data batch.
            weights: Data sample weights.
            penalty: Indicate whether or not a penalty function is applied to the
                loss function. Ignored if no penalty function is defined.
            quantum_instance: Used to run the quantum network.
                Ignored for a classical network.
            shots: Number of shots for hardware or qasm execution.
                Ignored for a classical network.

        Returns:
            Dictionary with the discriminator loss and updated parameters.

        Raises:
            NotImplementedError: not implemented
        """
        raise NotImplementedError()
```