Source code for qiskit_experiments.data_processing.sklearn_discriminators

# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Discriminators that wrap SKLearn."""

from typing import Any, List, Dict, TYPE_CHECKING

from qiskit_experiments.data_processing.discriminator import BaseDiscriminator
from qiskit_experiments.warnings import HAS_SKLEARN

if TYPE_CHECKING:
    from sklearn.discriminant_analysis import (
        LinearDiscriminantAnalysis,
        QuadraticDiscriminantAnalysis,
    )


class SkLDA(BaseDiscriminator):
    """A wrapper for the scikit-learn linear discriminant analysis.

    The wrapped estimator is stored as-is; ``predict`` and ``fit`` delegate
    directly to it, while ``config`` and ``from_config`` convert the estimator
    to and from a plain dictionary.

    .. note::
        This class requires that scikit-learn is installed.
    """

    def __init__(self, lda: "LinearDiscriminantAnalysis"):
        """
        Args:
            lda: The sklearn linear discriminant analysis. This may be a trained or an
                untrained discriminator.
        """
        self._lda = lda

        # Names of the estimator attributes that ``config`` reads off the wrapped
        # object. Any name the object does not have is recorded as ``None``.
        self.attributes = [
            "coef_",
            "intercept_",
            "covariance_",
            "explained_variance_ratio_",
            "means_",
            "priors_",
            "scalings_",
            "xbar_",
            "classes_",
            "n_features_in_",
            "feature_names_in_",
        ]

    @property
    def discriminator(self) -> Any:
        """Return the SKLearn object."""
        return self._lda

    def is_trained(self) -> bool:
        """Return True if the discriminator has been trained on data.

        The wrapped object counts as trained when it has a ``classes_``
        attribute whose value is not ``None``.
        """
        return getattr(self._lda, "classes_", None) is not None

    def predict(self, data: List):
        """Wrap the predict method of the LDA.

        Args:
            data: The data to classify; passed unchanged to the wrapped object.

        Returns:
            Whatever the wrapped object's ``predict`` returns.
        """
        return self._lda.predict(data)

    def fit(self, data: List, labels: List):
        """Fit the LDA.

        Args:
            data: The independent data.
            labels: The labels corresponding to data.
        """
        self._lda.fit(data, labels)

    def config(self) -> Dict[str, Any]:
        """Return the configuration of the LDA.

        Returns:
            A dictionary with two entries: ``"params"`` holds the result of the
            wrapped object's ``get_params()`` and ``"attributes"`` maps every name
            in ``self.attributes`` to its current value on the wrapped object, or
            to ``None`` if the object has no such attribute.
        """
        attr_conf = {attr: getattr(self._lda, attr, None) for attr in self.attributes}
        return {"params": self._lda.get_params(), "attributes": attr_conf}

    @classmethod
    @HAS_SKLEARN.require_in_call
    def from_config(cls, config: Dict[str, Any]) -> "SkLDA":
        """Deserialize from an object.

        Args:
            config: A dictionary with the keys ``"params"`` and ``"attributes"``,
                laid out as produced by :meth:`config`.

        Returns:
            A new discriminator wrapping a freshly built estimator.
        """
        # Imported here so that scikit-learn is only needed when this method is
        # actually called.
        from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

        lda = LinearDiscriminantAnalysis()
        lda.set_params(**config["params"])

        # ``None`` entries are skipped so that attributes which were absent when
        # the configuration was created stay absent on the new estimator; an
        # untrained discriminator therefore remains untrained after a round trip.
        for name, value in config["attributes"].items():
            if value is not None:
                setattr(lda, name, value)

        # Build through ``cls`` so that subclasses get an instance of their own type.
        return cls(lda)
class SkQDA(BaseDiscriminator):
    """A wrapper for the scikit-learn quadratic discriminant analysis.

    The wrapped estimator is stored as-is; ``predict`` and ``fit`` delegate
    directly to it, while ``config`` and ``from_config`` convert the estimator
    to and from a plain dictionary.

    .. note::
        This class requires that scikit-learn is installed.
    """

    def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
        """
        Args:
            qda: The sklearn quadratic discriminant analysis. This may be a trained or an
                untrained discriminator.
        """
        self._qda = qda

        # Names of the estimator attributes that ``config`` reads off the wrapped
        # object. Any name the object does not have is recorded as ``None``.
        self.attributes = [
            "coef_",
            "intercept_",
            "covariance_",
            "explained_variance_ratio_",
            "means_",
            "priors_",
            "scalings_",
            "xbar_",
            "classes_",
            "n_features_in_",
            "feature_names_in_",
            "rotations_",
        ]

    @property
    def discriminator(self) -> Any:
        """Return the SKLearn object."""
        return self._qda

    def is_trained(self) -> bool:
        """Return True if the discriminator has been trained on data.

        The wrapped object counts as trained when it has a ``classes_``
        attribute whose value is not ``None``.
        """
        return getattr(self._qda, "classes_", None) is not None

    def predict(self, data: List):
        """Wrap the predict method of the QDA.

        Args:
            data: The data to classify; passed unchanged to the wrapped object.

        Returns:
            Whatever the wrapped object's ``predict`` returns.
        """
        return self._qda.predict(data)

    def fit(self, data: List, labels: List):
        """Fit the QDA.

        Args:
            data: The independent data.
            labels: The labels corresponding to data.
        """
        self._qda.fit(data, labels)

    def config(self) -> Dict[str, Any]:
        """Return the configuration of the QDA.

        Returns:
            A dictionary with two entries: ``"params"`` holds the result of the
            wrapped object's ``get_params()`` and ``"attributes"`` maps every name
            in ``self.attributes`` to its current value on the wrapped object, or
            to ``None`` if the object has no such attribute.
        """
        attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes}
        return {"params": self._qda.get_params(), "attributes": attr_conf}

    @classmethod
    @HAS_SKLEARN.require_in_call
    def from_config(cls, config: Dict[str, Any]) -> "SkQDA":
        """Deserialize from an object.

        Args:
            config: A dictionary with the keys ``"params"`` and ``"attributes"``,
                laid out as produced by :meth:`config`.

        Returns:
            A new discriminator wrapping a freshly built estimator.
        """
        # Imported here so that scikit-learn is only needed when this method is
        # actually called.
        from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

        qda = QuadraticDiscriminantAnalysis()
        qda.set_params(**config["params"])

        # ``None`` entries are skipped so that attributes which were absent when
        # the configuration was created stay absent on the new estimator; an
        # untrained discriminator therefore remains untrained after a round trip.
        for name, value in config["attributes"].items():
            if value is not None:
                setattr(qda, name, value)

        # Build through ``cls`` so that subclasses get an instance of their own type.
        return cls(qda)