# Quellcode für qiskit.circuit.library.arithmetic.linear_amplitude_function

# This code is part of Qiskit.
#
# (C) Copyright IBM 2017, 2021.
#
# obtain a copy of this license in the LICENSE.txt file in the root directory
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""A class implementing a (piecewise-) linear function on qubit amplitudes."""

from __future__ import annotations
import numpy as np
from qiskit.circuit import QuantumCircuit

from .piecewise_linear_pauli_rotations import PiecewiseLinearPauliRotations

[Doku]class LinearAmplitudeFunction(QuantumCircuit):
r"""A circuit implementing a (piecewise) linear function on qubit amplitudes.

An amplitude function :math:F of a function :math:f is a mapping

.. math::

F|x\rangle|0\rangle = \sqrt{1 - \hat{f}(x)} |x\rangle|0\rangle + \sqrt{\hat{f}(x)}
|x\rangle|1\rangle.

for a function :math:\hat{f}: \{ 0, ..., 2^n - 1 \} \rightarrow [0, 1], where
:math:|x\rangle is a :math:n qubit state.

This circuit implements :math:F for piecewise linear functions :math:\hat{f}.
In this case, the mapping :math:F can be approximately implemented using a Taylor expansion
and linearly controlled Pauli-Y rotations, see [1, 2] for more detail. This approximation
uses a rescaling_factor to determine the accuracy of the Taylor expansion.

In general, the function of interest :math:f is defined from some interval :math:[a,b],
the domain to :math:[c,d], the image, instead of :math:\{ 1, ..., N \} to
:math:[0, 1]. Using an affine transformation we can rescale :math:f to :math:\hat{f}:

.. math::

\hat{f}(x) = \frac{f(\phi(x)) - c}{d - c}

with

.. math::

\phi(x) = a + \frac{b - a}{2^n - 1} x.

If :math:f is a piecewise linear function on :math:m intervals
:math:[p_{i-1}, p_i], i \in \{1, ..., m\} with slopes :math:\alpha_i and
offsets :math:\beta_i it can be written as

.. math::

f(x) = \sum_{i=1}^m 1_{[p_{i-1}, p_i]}(x) (\alpha_i x + \beta_i)

where :math:1_{[a, b]} is an indication function that is 1 if the argument is in the interval
:math:[a, b] and otherwise 0. The breakpoints :math:p_i can be specified by the
breakpoints argument.

References:

: Woerner, S., & Egger, D. J. (2018).
Quantum Risk Analysis.
arXiv:1806.06893 <http://arxiv.org/abs/1806.06893>_

: Gacon, J., Zoufal, C., & Woerner, S. (2020).
Quantum-Enhanced Simulation-Based Optimization.
arXiv:2005.10780 <http://arxiv.org/abs/2005.10780>_
"""

def __init__(
self,
num_state_qubits: int,
slope: float | list[float],
offset: float | list[float],
domain: tuple[float, float],
image: tuple[float, float],
rescaling_factor: float = 1,
breakpoints: list[float] | None = None,
name: str = "F",
) -> None:
r"""
Args:
num_state_qubits: The number of qubits used to encode the variable :math:x.
slope: The slope of the linear function. Can be a list of slopes if it is a piecewise
linear function.
offset: The offset of the linear function. Can be a list of offsets if it is a piecewise
linear function.
domain: The domain of the function as tuple :math:(x_\min{}, x_\max{}).
image: The image of the function as tuple :math:(f_\min{}, f_\max{}).
rescaling_factor: The rescaling factor to adjust the accuracy in the Taylor
approximation.
breakpoints: The breakpoints if the function is piecewise linear. If None, the function
is not piecewise.
name: Name of the circuit.
"""
if not hasattr(slope, "__len__"):
slope = [slope]
if not hasattr(offset, "__len__"):
offset = [offset]

# ensure that the breakpoints include the first point of the domain
if breakpoints is None:
breakpoints = [domain]
else:
if not np.isclose(breakpoints, domain):
breakpoints = [domain] + breakpoints

_check_sizes_match(slope, offset, breakpoints)
_check_sorted_and_in_range(breakpoints, domain)

self._domain = domain
self._image = image
self._rescaling_factor = rescaling_factor

# do rescalings
a, b = domain
c, d = image

mapped_breakpoints = []
mapped_slope = []
mapped_offset = []
for i, point in enumerate(breakpoints):
mapped_breakpoint = (point - a) / (b - a) * (2**num_state_qubits - 1)
mapped_breakpoints += [mapped_breakpoint]

# factor (upper - lower) / (2^n - 1) is for the scaling of x to [l,u]
# note that the +l for mapping to [l,u] is already included in
# the offsets given as parameters
mapped_slope += [slope[i] * (b - a) / (2**num_state_qubits - 1)]
mapped_offset += [offset[i]]

# approximate linear behavior by scaling and contracting around pi/4
slope_angles = np.zeros(len(breakpoints))
offset_angles = np.pi / 4 * (1 - rescaling_factor) * np.ones(len(breakpoints))
for i in range(len(breakpoints)):
slope_angles[i] = np.pi * rescaling_factor * mapped_slope[i] / 2 / (d - c)
offset_angles[i] += np.pi * rescaling_factor * (mapped_offset[i] - c) / 2 / (d - c)

# use PWLPauliRotations to implement the function
pwl_pauli_rotation = PiecewiseLinearPauliRotations(
num_state_qubits, mapped_breakpoints, 2 * slope_angles, 2 * offset_angles, name=name
)

super().__init__(*pwl_pauli_rotation.qregs, name=name)
self.append(pwl_pauli_rotation.to_gate(), self.qubits)

[Doku]    def post_processing(self, scaled_value: float) -> float:
r"""Map the function value of the approximated :math:\hat{f} to :math:f.

Args:
scaled_value: A function value from the Taylor expansion of :math:\hat{f}(x).

Returns:
The scaled_value mapped back to the domain of :math:f, by first inverting
the transformation used for the Taylor approximation and then mapping back from
:math:[0, 1] to the original domain.
"""
# revert the mapping applied in the Taylor approximation
value = scaled_value - 1 / 2 + np.pi / 4 * self._rescaling_factor
value *= 2 / np.pi / self._rescaling_factor

# map the value from [0, 1] back to the original domain
value *= self._image - self._image
value += self._image

return value

def _check_sorted_and_in_range(breakpoints, domain):
if breakpoints is None:
return

# check if sorted
if not np.all(np.diff(breakpoints) > 0):
raise ValueError("Breakpoints must be unique and sorted.")

if breakpoints < domain or breakpoints[-1] > domain:
raise ValueError("Breakpoints must be included in domain.")

def _check_sizes_match(slope, offset, breakpoints):
size = len(slope)
if len(offset) != size:
raise ValueError(f"Size mismatch of slope ({size}) and offset ({len(offset)}).")
if breakpoints is not None:
if len(breakpoints) != size:
raise ValueError(
f"Size mismatch of slope ({size}) and breakpoints ({len(breakpoints)})."
)