{ "cells": [ { "cell_type": "markdown", "id": "secondary-copying", "metadata": {}, "source": [ "# Torch Connector and Hybrid QNNs\n", "\n", "This tutorial introduces the `TorchConnector` class, and demonstrates how it allows for a natural integration of any `NeuralNetwork` from Qiskit Machine Learning into a PyTorch workflow. `TorchConnector` takes a `NeuralNetwork` and makes it available as a PyTorch `Module`. The resulting module can be seamlessly incorporated into PyTorch classical architectures and trained jointly without additional considerations, enabling the development and testing of novel **hybrid quantum-classical** machine learning architectures.\n", "\n", "## Content:\n", "\n", "[Part 1: Simple Classification & Regression](#Part-1:-Simple-Classification-&-Regression)\n", "\n", "The first part of this tutorial shows how quantum neural networks can be trained using PyTorch's automatic differentiation engine (`torch.autograd`, [link](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)) for simple classification and regression tasks. \n", "\n", "1. [Classification](#1.-Classification)\n", " 1. Classification with PyTorch and `EstimatorQNN`\n", " 2. Classification with PyTorch and `SamplerQNN`\n", "2. [Regression](#2.-Regression)\n", " 1. 
Regression with PyTorch and `EstimatorQNN`\n", "\n", "[Part 2: MNIST Classification, Hybrid QNNs](#Part-2:-MNIST-Classification,-Hybrid-QNNs)\n", "\n", "The second part of this tutorial illustrates how to embed a (Quantum) `NeuralNetwork` into a target PyTorch workflow (in this case, a typical CNN architecture) to classify MNIST data in a hybrid quantum-classical manner.\n", "\n", "***" ] }, { "cell_type": "code", "execution_count": 1, "id": "banned-helicopter", "metadata": {}, "outputs": [], "source": [ "# Necessary imports\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "from torch import Tensor\n", "from torch.nn import Linear, CrossEntropyLoss, MSELoss\n", "from torch.optim import LBFGS\n", "\n", "from qiskit import QuantumCircuit\n", "from qiskit.circuit import Parameter\n", "from qiskit.circuit.library import RealAmplitudes, ZZFeatureMap\n", "from qiskit_algorithms.utils import algorithm_globals\n", "from qiskit_machine_learning.neural_networks import SamplerQNN, EstimatorQNN\n", "from qiskit_machine_learning.connectors import TorchConnector\n", "\n", "# Set seed for random generators\n", "algorithm_globals.random_seed = 42" ] }, { "cell_type": "markdown", "id": "unique-snapshot", "metadata": {}, "source": [ "## Part 1: Simple Classification & Regression" ] }, { "cell_type": "markdown", "id": "surgical-penetration", "metadata": {}, "source": [ "### 1. Classification\n", "\n", "First, we show how `TorchConnector` allows us to train a quantum `NeuralNetwork` on classification tasks using PyTorch's automatic differentiation engine. To illustrate this, we will perform **binary classification** on a randomly generated dataset." ] }, { "cell_type": "code", "execution_count": 2, "id": "secure-tragedy", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Generate random dataset\n", "\n", "# Select dataset dimension (num_inputs) and size (num_samples)\n", "num_inputs = 2\n", "num_samples = 20\n", "\n", "# Generate random input coordinates (X) and binary labels (y)\n", "X = 2 * algorithm_globals.random.random([num_samples, num_inputs]) - 1\n", "y01 = 1 * (np.sum(X, axis=1) >= 0) # in {0, 1}, y01 will be used for SamplerQNN example\n", "y = 2 * y01 - 1 # in {-1, +1}, y will be used for EstimatorQNN example\n", "\n", "# Convert to torch Tensors\n", "X_ = Tensor(X)\n", "y01_ = Tensor(y01).reshape(len(y)).long()\n", "y_ = Tensor(y).reshape(len(y), 1)\n", "\n", "# Plot dataset\n", "for x, y_target in zip(X, y):\n", " if y_target == 1:\n", " plt.plot(x[0], x[1], \"bo\")\n", " else:\n", " plt.plot(x[0], x[1], \"go\")\n", "plt.plot([-1, 1], [1, -1], \"--\", color=\"black\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "hazardous-rehabilitation", "metadata": {}, "source": [ "#### A. Classification with PyTorch and `EstimatorQNN`\n", "\n", "Linking an `EstimatorQNN` to PyTorch is relatively straightforward. Here we illustrate this with an `EstimatorQNN` constructed from a feature map and an ansatz."
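, "

The following is a conceptual sketch, not the actual `TorchConnector` implementation, of how a model evaluated outside PyTorch can be exposed to `torch.autograd`. It assumes a single-qubit circuit made of `RY(x)` followed by `RY(theta)` and measured in the Z basis, whose expectation value is cos(x + theta). The forward pass is computed with NumPy and the backward pass uses the parameter-shift rule. `ToyExpectation` and `expectation` are hypothetical names chosen for this example.

```python
import numpy as np
import torch


def expectation(x, theta):
    # Analytic Z expectation of RY(theta) RY(x) applied to the zero state: cos(x + theta).
    # Stand-in for a QNN forward pass that runs outside PyTorch.
    return np.cos(x + theta)


class ToyExpectation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, theta):
        ctx.save_for_backward(x, theta)
        out = expectation(x.detach().numpy(), theta.detach().numpy())
        return torch.as_tensor(out, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        x, theta = ctx.saved_tensors
        xs, ts = x.detach().numpy(), theta.detach().numpy()
        shift = np.pi / 2
        # Parameter-shift rule: df/dtheta = (f(theta + pi/2) - f(theta - pi/2)) / 2.
        d_theta = (expectation(xs, ts + shift) - expectation(xs, ts - shift)) / 2
        # x enters through the same kind of rotation, so the same rule gives df/dx.
        d_x = (expectation(xs + shift, ts) - expectation(xs - shift, ts)) / 2
        grad_x = grad_output * torch.as_tensor(d_x, dtype=x.dtype)
        # theta is one weight shared by the whole batch: sum its per-sample gradients.
        grad_theta = (grad_output * torch.as_tensor(d_theta, dtype=x.dtype)).sum()
        return grad_x, grad_theta.reshape(theta.shape)


x = torch.tensor([0.1, 0.5, -0.3])
theta = torch.tensor([0.2], requires_grad=True)
out = ToyExpectation.apply(x, theta)
out.sum().backward()
print(out)
print(theta.grad)
```

For Pauli rotations such as `RY`, the parameter-shift rule obtains exact gradients from two extra evaluations of the same function, so a wrapper of this kind only needs a way to compute expectation values.
"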
] }, { "cell_type": "code", "execution_count": 3, "id": "fewer-desperate", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAACuCAYAAADDNYx2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArdklEQVR4nO3deXhM1+PH8XdWiRAkQhBLhCCI2Jfam6CKoqjWTpX2p/RLo3vRb/tVS7VVrdJqUS3aaq2tpbXvQa2xhiCSIBJbFll/f6SmRiaRySKd+Lyex/PIvWfOOXOfc+985txlrNLS0tIQERERsVDWBd0BERERkdxQmBERERGLpjAjIiIiFk1hRkRERCyawoyIiIhYNIUZERERsWgKMyIiImLRFGZERETEoinMiIiIiEVTmBERERGLpjAjIiIiFk1hRkRERCyawoyIiIhYNIUZERERsWgKMyIiImLRFGZERETEoinMiIiIiEVTmBERERGLpjAjIiIiFk1hRkRERCyawoyIiIhYNIUZERERsWgKMyIiImLRFGZERETEoinMiIiIiEVTmBERERGLpjAjIiIiFk1hRkRERCyawoyIiIhYNIUZERERsWgKMyIiImLRFGZERETEoinMiIiIiEVTmBERERGLpjAjIiIiFk1hRkRERCyawoyIiIhYNNuC7oCYLy0NUpMKuhfyqLG2Ayur3NWhsSsikDfHk3spzFig1CTYNLOgeyGPmnajwcY+d3Vo7IoI5M3x5F46zSQiIiIWTWFGRERELJrCjIiIiFg0hRkRERGxaAozIiIiYtEUZkRERMSiKcyIiIiIRVOYEREREYumMCMiIiIWTWFGRERELJrCjIiIiFg0hRkRERGxaAozIiIiYtEKfZiJiopi/PjxVKtWDQcHBypWrMiYMWOIjY1l2LBhWFlZMWvWrILupoiIiOSQbUF3ID8dPHiQJ554gsjISJycnPDx8SE8PJyZM2cSEhJCdHQ0AH5+fgXb0XyQmprKr9s/Zc3uOUTGhFLSyY3W9fowqON7ONo7FXT3RDKlsSsi5iq0MzNRUVF07dqVyMhIxo0bR0REBAcOHCAyMpIpU6awZs0agoKCsLKywtfXt6C7m+dmr/oPX64aS6WyPozq/hmtfXuzfPtM3v2mK6mpqQXdPZFMaeyKiLkK7czM6NGjCQsLY9SoUUyfPt1o3fjx4/nhhx84dOgQnp6eODs7F1Av80do5DFW7PiMlnV6MmHQMsNydxdPPl8xms2HltC+/nMF2EMR0zR2RSQnCuXMzPHjx1m6dCmlS5dm8uTJJss0bNgQgHr16hktP3fuHN26daN48eKUKlWKgQMHcu3atXzvc17adHAxaWlp9Gz1itHyzk2H42BXlD8OLCqYjok8gMauiOREoQwzixcvJjU1lX79+lGsWDGTZRwdHQHjMHPr1i3atWtHWFgYixcvZu7cuWzbto0uXbpY1PT2yYtBWFtZU6NSE6Pl9nYOVC3vx6mLQQXUM5GsaeyKSE4UytNMGzduBKBdu3aZlgkLCwOMw8zcuXO5dOkSW7dupVKlSgB4eHjQokULVq5cSffu3fOv03no2s1wnJ1KY29bJMO60iUqEHx+J0nJidjZ2hdA70Qyp7ErIjlRKMPM+fPnAahcubLJ9cnJyezYsQMwDjOrV6+mZcuWhiAD0Lx5c6pWrcqqVatyHGYaNWpEZGRkjl5rir2tI3NHnc50/Z3EOOxMfBikv9YhvUxSnD4QxCzVvauTmByfqzo0dkUETB9P3N3d2bdvX47qK5RhJjY2FoD4eNMH3qVLlxIVFUXx4sXx9PQ0LA8ODqZ3794ZyteuXZvg4OAc9ycyMpJLly7l+PX3c7ArmuX6IvZFib
99xeS6xOSE9DIPqEPkfhHh4SQkxeWqDo1dEYG8OZ7cq1CGGXd3d2JiYjhw4ADNmzc3WhcREUFgYCAAvr6+WFlZGdbFxMRQsmTJDPW5uLhw8uTJXPUnL9nbOma53tW5PBcuB5OYfCfDdH3UjUuUcCqtb7ZitnLly+fJzExWNHZFHg2mjie5+awslGHG39+f48ePM2XKFAICAvD29gYgKCiIAQMGEBUVBTy8h+XldNosMymJsGlm5utrVGzM/lPrOXlhL3WrtjIsT0xK4Gz4QepWbZ2n/ZFHw+lTp7HJZY7Q2BURyJvjyb0K5d1M48ePx9XVlYsXL1K7dm3q1q1L9erVadKkCVWrVqV9+/ZAxtuyS5UqxfXr1zPUFx0djYuLy8Poep5oW+8ZrKys+GXbJ0bLf9vzFQlJcbSv369gOibyABq7IpIThXJmxsPDg23bthEYGMiWLVsIDQ3Fx8eHOXPmMHz4cLy8vICMYaZWrVomr40JDg6mdWvL+UboWa4u3Vr8Hyt2zGLigp40qdmZC1eOs3z7THyrttFDx+RfS2NXRHKiUIYZSA8mq1evzrD89u3bhIaGYm1tTZ06dYzWdenShTfffJOwsDA8PDwA2LNnDyEhIUybNu2h9DuvvNjtE8qWqsJve+ay9/ganJ1K0/2xlxnU8T2srQvlhJwUEhq7ImIuq7S0tLSC7sTDtGfPHpo1a0aNGjU4ceKE0bqbN29St25dSpcuzaRJk0hISGD8+PG4ubmxa9euf82B9EHXHYjkh3ajyfdrZkTk0ZAXx5N7/Ts+nR+iI0eOABlPMQE4OzuzceNGypUrR9++fXn++edp0aIFq1ev/tcEGRERETFWaE8zZSarMAPg5eVl8vSUiIiI/Ds9ctMNDwozIiIiYlkeuZmZu7/bJCIiIoXDIzczIyIiIoWLwoyIiIhYNIUZEZFH2KGQzQQEWrEuaH5Bd8Vs42a3pf//qhgtm7pkMAGBVqZf8JD7Ig/PI3fNjGQtuweB6SM3sXD9RA6f3fLAsgMCJjCww0QgfYfP7DWNvDsyefjabPfVXGcuHWTnseV0aDQYd5cq+daOOaYuGcyG/QsAmDU6iBoVG2Uos2zrx3y5aiwAr/b5lo6NBz/MLko+OhSymVe/bGe0zMHeCQ83b/wbDKD7Yy9jY1Owh+mU1BT6fVCJazfDGdThPfoHvFOg/cmJHUeXExJ+0HAcksJHYUaMvNb3u0zXRUSfZeH6CZRwKo2HWw2ee/wtnmjyvMmySSl3+HLVWOLv3MKncgujdXa2RRjb6+sMr3EtUT53nX+AkPCDfLdhEvW82v5rwsxd9rYOrAv61mSYWRf0Lfa2DiQmJxRAz+RhaOf3LE1qdiaNNGJuRbJh/0K+XDWWC1eO859ecwu0b0EnfufazXDKu3qxft98+vm/jZXVw5/5yK6xvb/ilae/NFq24+hyNuxfoDBTiCnMiBH/hv1NLk9IjGPMrOZYW9vwVv+luDqXw9W5XKb1fPTT88Ql3GRgh0k0qtHBaJ2NtW2m7ViyuIRbFHUonqPXPlanB5sPLmZktxnY2xYxLD95MYhzkUdoX/85Nv71Q151Vf5lqldoYLRPdG3xEsOm1uT3vV8zpNMHlCzmVmB9W7t3HuVdvRjRdQYT5j/FoZDN+FVr9+AXFhBbGzuwsSvobshDpjAj2fLRj0M5G3GYF7pMp3619lmWXbVzNmv3zqOZT1f6++dsSvrazQgWbXiPPSfWEHMrEmen0jSr1YXBnd6nVLEyhnJRN8L5eetH/HX6T65cP8+dpHjKuVQloNEgerd5FRtrGwAWrp/IdxsmARhN6wc0HMT4vvMN679741yGWZv+/6tC2VJV+OjFzf+8LtCKgIaD8G84gIXrJxASfhBvj0aGMicv7mPxnx9w5Nw24u/coqxLFfwbDqRv29dMnjbo2HgImw4uZufR5bT1e8awfF3Qt5R0cqNprS4ZwkxqaiqLN01m/8l1hF09xa34aEoVd6dpzS
cZ0ul9nJ1cDWUjo0MZMNmTAQET8HCrwZKNkwmLOkXJYmXo1Hgo/R5/u8BPZ8g/HO2dqFm5GdsO/0z4tRBDmMnr/eJBYm5dZvfx1fTzf4emNTtTslgZ1u6dZzLM3N1PXnrqU+asGsfxC7txsCvK4w0HMLzzFFJSk/l27dtsOriYm3HXqFmxCWOenkPlsrUMdawLms/0H4cwZfgGjoZuZ13Qt8TcisTDrQbPPv4m7fz6PrDPd0/dbpiW/ks9957avvc0+t1TtuNmt+VyTCiL3gw1qufefebeGZ1bcTF8tWY8O47+SmJSPN4VGzOi60eZ9ie7x4LQyGMs3DCR4NCd3IyNophjKSqVrUXvNq/StNaTD3zfjzodveSBftw8jc2HltK23jP0bjMuy7JHz+3gi5Vj8HDz5vW+32U6HX0jNirDsmKOpbCxtuFKzAVGz2pOckoinZoMo7yrF5eizrB612wOhmzii9H7cHIsAcC5iMPsOPILj9XpQTlXL1JSkwg6sZZ5v71O5LWzvNJrDgAt6/Qk+mYEa/bM5dn2b1KpTPoBtLyrV463y+mwfWw/uozOTYbTodEgw/I9x9cwaUFPypeuRq824yju6ELw+V0sXPcuIeEHeXfATxnqqlahPl7l/Vgb9I0hzCQmJbDp4GI6NhqS/m3zPskpify0eRqt6j5N89pP4WDvxKmLQawNmsfR0O18MWY/drbGP36yK3glEdfO0q3F/+FS3J1dwSv5bsMkLsecJ/CZb3O8LSTvRVwLAcC5qAtAvuwXD7Jh/0JS01IIaDgQGxtbHq/fj9W7vyQ2/oahrXtF3Qjj9bkBtPF7hla+vdh/aj3Lts7AxtqW85ePcScpnr7tXudGbBQ/b5nOxAXdmffq8Qw/F/P1b6+RkBhL1xYvAbA+6Fv+9/2zJCYlmH3N2HOPv0VaWipHzm0zOo1eu0qLLF5lWnJKEm983ZGTF4PwbzCAWpWbERJ+kNfm+uNc1DVD+eweC27GXiNwTvqXxC7NRlK2VGVuxEZxKmwfxy/sUZjJBoUZydKBU38w7/c38HSvy9g+87IsG3UjnP9+1ws72yJMHPSryYMdQEJiLL0mZpw2nxd4nEplajJr+cukpCQx+5W/cCvpYVjf2rc3o2c1Y9m2jw3flHy92rDwjbNGoalnq1f4cPEAft/7NQM6TMTVuRxVy/tSq3Jz1uyZS0PvAOp5tTV/Y9wn9PIxpgzfQANvf8OyxKQEPvpxGDUrNWXaiI2Gb15dmo/Aq3w9vlw1lkMhm02236nxUGavfIWr18NwK+nB9qO/cDv+Oh2bDOXilRMZytvZFmHpuxEUsXP8Z2HzkfhUacGMn55n57HltKnXx+g1Z8MPMWt0ENU9GgDw1GOjmLSgJ+v3zefJZiPwqdws19tFzJeQFMeN2CjS0tKvmVm160vOXPqLmhWb4OHmDZAv+8WDrA36hrqerQ2zlQGNBrFs28ds/OsHurZ4MUP58GshvN3/R9rU6w1A1+YjeemThvy0ZRrNanVl6gt/GPrk7OTKFyvGsP/0BhrX6GhUz43YKOaOPWw4hnRtNpIXZvgyZ9VY2vo9YzzmH6ChdwB/HvieI+e25fr09rqgbzl5MYj+/u8yqOMkw/LKZX2YvfI/lC1V2bDMnGPB0dAdXL99hbf7L82wz0r26NZsyVRkdCgffN+XokWKM3HQrzjaO2VaNik5kfcWPk30rUgC+8ynclmfTMva2zowZfiGDP/KlKxEbPwN9hxfTbPa3bC3c+BGbJThn7tLFSq4VmP/qfWGuorYORoOjknJidyMi+ZGbBSNvDuSmpbKqbB9ebdB7lO1XD2jIAOw//QGYm5fpkPjIdxOuG7U/yY1OwOw757+36t9g37Y2Nix/u+7m9IvCG6Mp3sdk+WtrKwMB/WU1BRux6e35/f3acDjF/ZkeE2D6gGGIHO3jj5txwOw4+iv5rx9yUML10+g10Q3ek8qwwszfFm16wta1unJpMErAApkvzgWup
OLV04QcM+so1f5eoYZRFNKl6hgCDJ31fZsSVpaGt0fe9koXNX1bAXApajTGerp2vxFoy9DTo4l6NJ8JLfiYzgUsvmBfc8vO44tx9rahl73zVB3af4iRR2cjZaZcyxwckh/r3tP/E5sws2H8E4KH83MiEkJiXFMXNCD2/Ex/HfoasqXzvp0zKzlozh+YTd9271OK9+nsyxrbW2TIQTcdeLCXlLTUlm7dx5r95qeCSrnUtXw/5SUZJZs+pAN+xcSfu0MaWlpRmVvx8Vk2ZfcuPuN+V4XLh8H0q8xysz1W5dNLncu6kJzn26s3zcf/wb9OXhmI6O6z8qyD1sO/cjPWz7iTPhfJKckGa27HZ/xvVe65/qEu+4Gz4jos1m2JfnnyaYv0Nq3N8mpSZyLOMLSzVOIuhGGvZ0DABevnnzo+8XavfOwtbGjWvn6XIo6Y1jeqEZHlm6awtnww1Qt72v0GncXzwz1FHcsZXJdsb+X34q9luE1d08D36tymb/H6bWCG6eR187iWrwcTvcFF3vbIpRzqWq0z5lzLKjn1YaAhgNZv28+G//6Hm+PxjSo7k9bv2ey/GIo/1CYEZM+/nk4IeEHGdzxvzSp+USWZdfsnstve76ioXcHhnT6IFftppF+0H28QX86NBxksoz9PVPMX64ay/Idn9G23jM89/hblCxWBltrO05fOsDXv71Galpqttq1IvNbTVNSk00uL2JXNNP+v/DkNLzK+5l8XVa3oHdqPJQ35z3BjJ+GY2tjT7v6z2ZadtuRX3h/0TPUrNiEl7p9ilvJitjbOpCSlsKbX3ciNTV7710KXoXS1Q0Bv0nNJ6jj2ZL/fNGST5eN5K3+Sx76fhF/5zZbDv9IckoSL35S32SZtUHf8NJTnxgts7bK/MJi60wuOr773gpKZtf1ZbbfZ5e5x4LxfRfQu20gQSd+58i5bfy89SN+2PgBL3b7hO6PjcpVXx4FCjOSwc9bZrDxrx9oUfspnnv8rSzLBp/fzefLX8bdxZM3+y3OcCGfuSq4VsPKyorklMRMZ2/u9ceB76hbtTVv9V9itPzStTMZymb1bIzif19keSsu2uhupsSkBKJvRlDetVr2+l+6OpD+4LPs9P9+Db074FbCgwOnN9C+/nMUcyyZadk/93+Hva0D00ZuwsH+n2B1wcT1NYZ1f39bvNf5y8GA8Td7KVi1q7TAv8EANuxfSPeWo6noViPf9gtTthz6kfg7txn6xP8MY/pey7fP5M8Dixj+5NQMF5nnhQtXjtOCp4yWnb/y9zh1NX+cZrnvO7pwOm5/huWmZoDcXauy/9R6YhNuGs3OJCbfISL6rGEWCnJ2LPB0r4Onex36tA3kdvx1Xv6sKfN+e52nWvzfv/rZPv8GumZGjBw8s4mvfhtPRbcajO+7MMsdKPpmJO8tfBpraxsmDPzFcNdFbjg7udKkZme2H/mF4PO7M6xPS0vj+u2rhr+trWzgvin0+MRYftn2cYbXOtoXA+BmXHSGdRX+PmV04PQfRsuXbfs427M7kD4FX7JYGZZs+tBkO3eS4olLuJXp662trRnV43MGBEzgmbavZdmWtbUNVlZWpN3Tv7S0NH744/1MX3Pg9AZOhx0wKv/j5qkAPFa7e5btycPVz/8drK1tWLDu3XzdL0z5fe88ihd1oU+bQFr79srwr1OTYdyMu8bOYyty9yYzsWrXbGLjbxj+jo2/wepdX1LMsSS+VduYXZ9jkcz3fQ83b+Lu3OLEhb2GZampqSa3VQufp0hNTeHnLca3Yq/eNZu4+651MedYcDMuOsNMajHHkriX8uROUpwemJkNmpkRg2s3I3h/UR9SU1NoWfdpdh1bmWnZquV8mfnrS1y7Gc5jdXoQGnmU0MijJsuWKl6Wht4B2e7H6J6z+c/nLRk3uzX+DQdSrXx90tJSiYg+y85jKwhoONBw10Yr316s2T2H9xc9Q4Pq/sTcuszaoG9M3iZZo2JjrK2sWfznB9yOj8HB3gl3F0
9qVWpKg+r+VHSrwYL173Iz7hruLp4cO7ed4xd2U8KpdLb77mjvxPi+C5k4vztDp9agY+OhVChdjdvx17l45QTbj/7CxEG/Znk3VYva3WhRu9sD22rl24ttR5YROKc9/g0HkpKSxI5jy7mTGJfpa6qWr0fgnPbpt2Y7l2PXsRUcOP0H/g0G4FOlebbfp+S/CqWr0a5eX/7863uOnN2Wb/vF/S5cOUHw+Z10aDQ402cPNffphq2NHWv3zstwwW9eKOFUmpc/a0qHxkOA9Fuzr1y/wNjeXxvNQmZXrUrNWLFjFp/98hJNaj2JrY0dNSs1pZyLJ52bvcDPWz9i4oIe9Gg5Bjtbe7Ye/tnkaaaOjYfw2565LPrjPSKjz+FTuTlnwv9i6+GfKO/qZfQac44Ff+xfyLKtH/NYnR6UL10NW2s7Dp/dwr5T62hTr49Zd289qhRmxCDs6knD818Wb/xflmUHBEzgWOgOIP0umKzuhPGt2sasMFOmZEW+eGU/SzdNYeexFfx5YBH2tg64laxIM5+uRrcujuw6g6JFirPl0I/sPLYCt5IVebLpC3hXbMxrc42ndsuUqsS4Pt+wdNMUZv7yIskpSQQ0HEStSk2xsbbhvSEr+Xz5aFbs+AxbG3saenfgoxe38Mrnj2W77wCNa3Rk1pgglm78kD8PLOJG7FWKOZaivKsXT7cai2c53wdXkg3t/PoSf+cWy7Z+zNzVr1LcsRTNfLoyrPOHPD3B9IdWc59u/zw07+pJShYrQz//d3L8cEPJX88+/habDi5mwfp3mT5yU77sF/e7e4Fxy7o9My1TvGgp6nm148DpDVy5fpEyJSvmzRv+2/Odp3Dk3DZW7vyc67cuU8HNmzee+5729Z/LUX3t/J7lzKW/2HxoCVsP/0RqWiqv9vmWci6elHPxZOKg5Xzz+5ssWPcOxZ1c8W8wgE6NhzJ0Wk2jeuxs7fnwhQ18tTqQHceWs/3IMrwrNubD4RuYu/pVLseEGpXP7rHAt2pbzlz6iz3HVxN9MwJraxvcXTx5oct0ntL1MtlilXb/Ze7yr5eSCJtmFnQvxJJk9jRTc7QbDTa5vDxCY1eycvcJwNNHbsqTZ0HJv1deHE/upWtmRERExKIpzIiIiIhFU5gRERERi6YLgEUeAe4uVQy/Iizyb9Wx8WCzf0hSBDQzIyIiIhZOYUZEREQsmsKMiIiIWDSFGREREbFoCjMiIiJi0RRmRERExKIpzIiIiIhFU5gRERERi6YwIxZvXdB8nnqnBC992siwLOb2Fd74qhODplRn+PQ6HD671bBu8g/96DPJnS9WvJKrdvv/rwpDptbgtz1fA+k/5jhudlueeqcEI2b4GZU9cnYbI2b4ERBoxe3467lqV0wLCT/EqJlNGDqtFm981Ynrt68CcChkM0++4ciIGX7E3L4CQEJiHB98/yyDPqzG4CnebD38s6GeuasDee6DSkyY3z1b7S7b+jFDp9ZkyNQafP/nB4blU5cMpu9/K/DJspGGZe8t7MUz/y2fYRzcSYpnxAw/ur5VjB1Hlz+wzdvx13n326cYOrUmIz+uz4kLew3rAgKtGP5RXfYc/w2ATQeXMGKGH8On12H49Dr8tOUjQ1lzx6U52/ib399i+Ed1GTHDjxEz/Nh0cImhnvzcxnctWDeBgEArzlw6aFj26pft6PmuC79s++SBbSYlJzJt6RCGTq3J89Nrs+f4GsO6+/f9e9//3X93kuKB/N3GACt3fsHQabX+3tb1SExKAPJvG5+LOGL0Pvv/rwo933UxlDdnG+clPQFYCgU/r3ZMGrzc8Pe8316nVuVmTB6+lpMXg5i4oAffvXEOWxs73njuexaun5gnoeKtfkupVsEPgKIOzgzp9D6xCTf45ve3jMrVrdqKOWMPEhBoles2xbRpSwfzap9vqVbBj7V7v2Hu6lcZ33cBAB5uNZgz9qCh7E9bpmNnU4QFr58hIvoco2c2xc+rHc5OrrzQZRqVy9Zm57HlD2
zz5MUgth1ZxpdjD2FtZc2bXz9B7cot8KvWDoA+bQPp2eoVQ/kuzUbycs8v6DOprFE9RewcmTP2IONmt83We/127ds08A7gvSErOH85mPcW9uKrcUextk7/fvrxS9so5lgSALcSFZn8/FpcnN2Jjb/BS582xNujIfW82po9Ls3Zxn3aBjL0ifQPxagblxg2rRYNqvtTwql0vm5jgBMX9nIyLIiypSobLZ8+chNTlwzO1ntdvn0mzkVd+Wb8Ca7djGDc7DbU8WyFk4MzYLzvm3r/d+XnNt55dAV/Hviez0btxsmxBNdvX8XGxg4g37axZ7m6Rn347NdRWFn9897M2cZ5STMzYhEuXjnJs+97EHHtLAA/bZ7OG191IjU11WT5LYd+pEuz9G8SNSo2xtW5PIdDtpjd7oyfhvPZr6MAuBkXzcDJXkazPPdyLupCHc+WONg7md2O5M6ZS3/hWKSY4cMloNEgdgWvJCk50WT5LYeW0qV5+vgo5+KJr1dbth/91ex2/ziwiI6Nh2BvWwRbGzs6NRnG+n0LMi3fwNufUsXKmN3O/TYfXELnpsMBqFzWB7eSFTlyzvS4rOP5GC7O7gA4OZagYpmaREaHmt2mudv4bpgCiL9zmzTSSE0zvb9mxdxtnJAYx6zlo3jl6Tlmt2Xc7nd0bf4iAK7O5fDzasf2I7/kqs4HMXcb/7hlGgMCJuDkWAKAksXcsLG2Mbtdc7fxXYlJCWz863s6NR5mdpt5TTMzYhEqlqnB8Cen8d9FfRjRZTord37OZ6P3Gr6J3utm7DVSUpIMB3CAsqWqcOX6BbPbHdX9M17+rBlbDv3EH/sX8kTT5/Gt2jpX70XyXkT0OcP09113EuOIunnJZPkr1y8YfWt3z+H4iIw+x+7gVSzf/hkACUmxuDqXN7sec9yMi+Z2wnVentnUsOzqjYtERJ+jnlfbLF97/nIwwed3Mabnl2a3a+42Bvh1+0xW7vycqOth/Kf31zkKcuZu46/WjKdL8xcpU7Ki2W3d3+7EBT2wsko/xly/fZlSxd0zLR8RHcKLnzTA2sqGjo2H0K3FS2a3ae42vnA5mFNh+/huwySSUu4Q0HAgPVqONrvdnI7j7Ud/oZxLVaMZqoKiMCMWo339ZzkUsok3vurI1BF/UrKYW763aW/nwDsDfuL/ZjbCp1Jz+rZ7Pd/blJypWakpHw5fZ/i718T8Hx8AQzp9QPv6zwKw5/galm6emu9t2ljZGE31v/dd7we+5ur1MN6d/xRjen6JW0mPHLVr7jbu0XI0PVqOJiT8EB8u7k8j7w44O7ma3W52t/H+Uxu4EnOel3vMMrsNUyYPX4erczkg/RqUzFSr0IDFb4Xh5FiCq9fDeGteZ0o4laZNvT5mt2nONk5JTSYy+hwzXtrK7fgYxs1uQzmXqjTz6WJ2uzkZx7/vnUenJgU/KwM6zSQWJCUlmdDIoxQv6kLUjcy/DTo7uWJjbUv0zUjDsssxoZQpWSlH7YZdPYmDvRPXY6+QlGJ6ulcKVjmXqkYzK7EJN0lIjKW0cwWT5cuUrMTlmPOGvyNzOD7ubzcyOpRyLlXNrscczkVdsLdzJObWZcOyyw9oN+pGOK/N9aff42/Tpt6Dg48p5m7je3mVr0dp5wocCtmc63az2sYHz2zk9KUD9P9fFfr/rwpXb4Tx1jed2RW8yux23e9vNyaUcq6m23VycDac6nEr6UG7+s9y5Nw2s9vMyThuV/9ZbKxtKOFUmiY1O3P8wu5ct5udcRwRfY4T53fTvv5zZreXHxRmxGJ8/dvreLjVYMZL25i7+lUuRZ3JtGwr396s3p0+lX7yYhBRNy7h69XGZNkTF/YSOOdxk+uuxFxg5q8vMfWFP6hVqRmzc3kHlOSPahX8sLW2Y/+pDQCs2vkFbeo9g52tvcnyrX17s3pX+viIiD7H4ZDNPFanu8myUTcuMXRqTZPr/BsOYMO+BcQl3OJOUjy/7/2aDo0G5/r9ACzfMYt5v72Rab
vLd6SfEjgWupPbCdep69nKZNlrNyMYP/dx+rR7jQ6NBj2w3SmLB7L9SMbrh8zdxucvBxv+Hx4Vwpnwv6hU1sdk2bzaxsM6T2bJO5dY9GYoi94Mxa2EBx8M/Y3mPl1Nlt9+5FemLB6Yabsr/j7tEnb1NMfP7+KxOj1Mlr12M8Jw/V5cwi12B6+mWvn6JstC3m3jdvWfY9+JtUD6HXGHQjZTtVw9k2Xzehyv2/sNj9XpYXRtVEHSaSaxCLuDV7Pv5Fo+G70XB/uijOg6g/cX9eHT/9tpsvzwJ6fw4eIBDJpSHTsbe15/dhG2f1/lf7/LMaEUsXPMsDwlJZkPvu/L4I7/pXJZH0Z2+5hXZrVg88GltPV7JkP5hMQ4hkz1Jin5DrEJN3j2fQ/8GwxgWOfJuXvzki1vPPc9034cwsxfXqS8azVef25RpmV7tw3kox+HMnCyF9bWNozqMYsSTqVNlo26cQkba9OHSm+PhjzZbAQjP/YjjTQ6Nx1OvUxCM8Bb857kbMQhAJ6fXpsKpavz0YubTZa9cDk402/HQzq+z5QlAxn0YTUc7J1487nFJq8fA1iw7l2uxlzg122f8uu2TwHo0WoMnRoPMVn+VNg+umdy3YU52/irNeOJjD6HjbUdNja2jOo+i8pla5ksm5fb2ByXok5T9O+7k+7Xo+VoPlk2goGTvbC1seM/vb4y3Ml0v21HlrF612xsrG1JSU2mtW9vOmayfSHvtnGv1mP5ZNkIhk3zwcrKipZ1n8505i0vt3Fqairr981nfN+FmZZ52BRmxCI08+lidB64Tb3eWU6XlypelikvrM9W3YdCtpi8FsbGxpZPR/0Tluxti/DFK/szrcfBviiL3w7LVpuS9zzL1eWLMfuyVdbR3om3+y/NVtnDZ7fwTBbXSvVsNYaercZkq64Phq15cKG/nY04zPOdp5hc5+RYgveGrMhWPWN7f8XY3l9lq+z121cpXaICNSo2MrnenG38/tDV2SoHebuN77XozdAs1wef38mL3T4xuc7Wxo5X+3yTrXa6PzaK7o+NylbZvNzG9nYOhtu2HyQvt7G1tTU/vH0xW2UfFp1mEotXxM6RkPCDRg/Ny8rkH/rx54FFhm9ko3t+Th3Plma3W8LJjSmL+xsenJWVuw/OKlWsrOHuCHk4bG3suRV3LcPDxjIzd3UgSzZNpphjKSD9GRv+Dfub3a6TYwlW7vzC5APd7nf3oXkR0Wext3UA4JP/205Rh+Jmt1uqWFnGzW5jeGheVu4flyWLuTHlhQ1mt2kJ2xjSH+h25OwWw+MTJg1ejrtLFbPbzc2+/6ht44fFKi0tLe2htii5lpIIm2YWdC/kUdNuNNiYPnWfbRq7IgJ5czy5l74iioiIiEVTmBERERGLpjAjIiIiFk1hRkRERCyawoyIiIhYtEcizERFRTF+/HiqVauGg4MDFStWZMyYMcTGxjJs2DCsrKyYNStvfstDREREHq5C/9C8gwcP8sQTTxAZGYmTkxM+Pj6Eh4czc+ZMQkJCiI6OBsDPz69gO5rHFm+czOlLBzgdtp/I6HOULVX5gQ+QEvk30NgVEXMV6jATFRVF165diYyMZNy4cUyYMIHixdMfQjV16lRee+01bG1tsbKywtfXt4B7m7e++f1Nihd1oXqFBsTGXy/o7ohkm8auiJirUIeZ0aNHExYWxqhRo5g+fbrRuvHjx/PDDz9w6NAhPD09cXY2/Zsblmrh6yGGX3gdPr0O8Ym3C7hHItmjsSsi5iq018wcP36cpUuXUrp0aSZPNv1Dfw0bNgSgXr1/fmX0bvhp0qQJRYoUwcrK6qH0N69l9lP1Iv92GrsiYq5CG2YWL15Mamoq/fr1o1ixYibLODqm/1LyvWHmzJkzLFu2DHd3dxo3bvxQ+ioiIiI5V2jDzMaNGwFo165dpmXCwtJ/4fjeMNO6dWsiIiJYuXIl/v7++dtJERERybVCe83M+fPnAahcubLJ9cnJyezYsQMwDjPW1n
mf7xo1akRkZGSe1Wdv68jcUafzrD6R7KjuXZ3E5Phc1aGxKyJg+nji7u7Ovn37clRfoQ0zsbGxAMTHmz74Ll26lKioKIoXL46np2e+9iUyMpJLly7lWX0OdkXzrC6R7IoIDychKS5XdWjsigjkzfHkXoU2zLi7uxMTE8OBAwdo3ry50bqIiAgCAwMB8PX1zfeLfN3d3fO0PntbxzytTyQ7ypUvnyczMyIipo4nufmsLLRhxt/fn+PHjzNlyhQCAgLw9vYGICgoiAEDBhAVFQU8nIfl5XTaLDMpibBpZp5WKfJAp0+dxsY+d3Vo7IoI5M3x5F6F9gLg8ePH4+rqysWLF6lduzZ169alevXqNGnShKpVq9K+fXvA+HoZERERsTyFdmbGw8ODbdu2ERgYyJYtWwgNDcXHx4c5c+YwfPhwvLy8gMIbZjbs/44rMekXQV+PvUpySiLf//E+AGVKVSag4YCC7J5IpjR2RcRchTbMANSqVYvVq1dnWH779m1CQ0OxtramTp06BdCz/Ld27zwOn91itGz+uncA8K3aRh8I8q+lsSsi5irUYSYzx44dIy0tDW9vb4oWzXh3xc8//wxAcHCw0d9VqlShUaNGD6+jufDRi5sLugsiOaKxKyLmeiTDzJEjR4DMTzH17t3b5N+DBg1i/vz5+do3ERERMY/CjAlpaWkPszsiIiKSC4X2bqasPCjMiIiIiOV4JGdm7v5uk4iIiFi+R3JmRkRERAoPhRkRERGxaAozIiIiYtEUZkRERMSiKcyIiIiIRVOYEREREYumMCMiIiIWTWFGRERELJrCjIiIiFg0hRkRERGxaAozIiIiYtGs0vQT0RYnLQ1Skwq6F/KosbYDK6vc1aGxKyKQN8eTeynMiIiIiEXTaSYRERGxaAozIiIiYtEUZkRERMSiKcyIiIiIRVOYEREREYumMCMiIiIWTWFGRERELJrCjIiIiFg0hRkRERGxaAozIiIiYtEUZkRERMSiKcyIiIiIRVOYEREREYumMCMiIiIWTWFGRERELJrCjIiIiFg0hRkRERGxaAozIiIiYtEUZkRERMSiKcyIiIiIRVOYEREREYumMCMiIiIWTWFGRERELJrCjIiIiFg0hRkRERGxaP8PJm115LDvRT4AAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Set up a circuit\n", "feature_map = ZZFeatureMap(num_inputs)\n", "ansatz = RealAmplitudes(num_inputs)\n", "qc = QuantumCircuit(num_inputs)\n", "qc.compose(feature_map, inplace=True)\n", "qc.compose(ansatz, inplace=True)\n", "qc.draw(output=\"mpl\", style=\"clifford\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "humanitarian-flavor", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial weights: [-0.01256962 0.06653564 0.04005302 -0.03752667 0.06645196 0.06095287\n", " -0.02250432 -0.04233438]\n" ] } ], "source": [ "# Set up QNN\n", "qnn1 = EstimatorQNN(\n", " circuit=qc, input_params=feature_map.parameters, weight_params=ansatz.parameters\n", ")\n", "\n", "# Set up PyTorch module\n", "# Note: If we don't explicitly declare the initial weights\n", "# they are chosen uniformly at random from [-1, 1].\n", "initial_weights = 0.1 * (2 * algorithm_globals.random.random(qnn1.num_weights) - 1)\n", "model1 = TorchConnector(qnn1, initial_weights=initial_weights)\n", "print(\"Initial weights: \", initial_weights)" ] }, { "cell_type": "code", "execution_count": 5, "id": "likely-grace", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-0.3285], grad_fn=<_TorchNNFunctionBackward>)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test with a single input\n", "model1(X_[0, :])" ] }, { "cell_type": "markdown", "id": "gorgeous-segment", "metadata": {}, "source": [ "##### Optimizer\n", "The choice of optimizer can be crucial to the success of training any machine learning model. When using `TorchConnector`, we get access to all of the optimizers defined in the `torch.optim` package ([link](https://pytorch.org/docs/stable/optim.html)).
Popular choices in classical deep learning include *Adam*, *SGD*, and *Adagrad*. However, for this tutorial we will use the L-BFGS algorithm (`torch.optim.LBFGS`), one of the best-known quasi-Newton methods for numerical optimization, which approximates second-order (curvature) information from successive gradient evaluations. \n", "\n", "##### Loss Function\n", "As for the loss function, we can also take advantage of PyTorch's predefined modules from `torch.nn`, such as the [Cross-Entropy](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) or [Mean Squared Error](https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html) losses.\n", "\n", "\n", "**💡 Clarification:** \n", "In classical machine learning, the general rule of thumb is to apply a Cross-Entropy loss to classification tasks and an MSE loss to regression tasks. However, this recommendation assumes that the output of the classification network is a class probability value in the $[0, 1]$ range (usually achieved through a Softmax layer). The following `EstimatorQNN` example does not include such a layer, and we don't apply any mapping to the output (the next section shows an example of parity mapping with `SamplerQNN`), so the QNN's output can take any value in the range $[-1, 1]$. This is why this particular example uses `MSELoss` for classification even though it is not the norm (but we encourage you to experiment with different loss functions and see how they impact training results).
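
As a purely classical illustration of both choices, the minimal sketch below assumes only PyTorch and replaces the QNN with a hypothetical `Linear` + `Tanh` model, whose outputs are bounded in $[-1, 1]$ like an `EstimatorQNN` expectation value. It trains on toy labels -1 and +1 (the sign of x0 + x1, the same rule as the dataset above) with `MSELoss` and the `LBFGS` closure pattern. Unlike the tutorial, it passes `line_search_fn='strong_wolfe'`, an optional `LBFGS` argument that enforces a sufficient decrease of the loss at each iteration and should make this toy run more stable.

```python
import torch
from torch.nn import Linear, MSELoss, Sequential, Tanh
from torch.optim import LBFGS

torch.manual_seed(0)

# Toy data: 20 points in [-1, 1]^2, labelled +1 if x0 + x1 >= 0, else -1.
X = 2 * torch.rand(20, 2) - 1
y = (2 * (X.sum(dim=1) >= 0).float() - 1).reshape(-1, 1)

# Classical stand-in for the QNN; Tanh keeps every output in [-1, 1].
model = Sequential(Linear(2, 1), Tanh())
optimizer = LBFGS(model.parameters(), line_search_fn='strong_wolfe')
f_loss = MSELoss(reduction='sum')
losses = []


def closure():
    # LBFGS may evaluate the closure several times within one step.
    optimizer.zero_grad()
    loss = f_loss(model(X), y)
    loss.backward()
    losses.append(loss.item())
    return loss


optimizer.step(closure)

with torch.no_grad():
    final_loss = f_loss(model(X), y).item()
    # Classify by the sign of the output, as the tutorial does with np.sign.
    accuracy = (torch.sign(model(X)) == y).float().mean().item()
print('initial loss:', losses[0], 'final loss:', final_loss, 'accuracy:', accuracy)
```

Swapping `MSELoss` for another loss in this sketch is a cheap way to try the experiment suggested above without running a quantum simulation.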
" ] }, { "cell_type": "code", "execution_count": 6, "id": "following-extension", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "25.535646438598633\n", "22.696760177612305\n", "20.039228439331055\n", "19.68790626525879\n", "19.267210006713867\n", "19.025371551513672\n", "18.154708862304688\n", "17.33785629272461\n", "19.082544326782227\n", "17.07332420349121\n", "16.21839141845703\n", "14.992581367492676\n", "14.929339408874512\n", "14.914534568786621\n", "14.907638549804688\n", "14.902363777160645\n", "14.902134895324707\n", "14.90211009979248\n", "14.902111053466797\n" ] }, { "data": { "text/plain": [ "tensor(25.5356, grad_fn=)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define optimizer and loss\n", "optimizer = LBFGS(model1.parameters())\n", "f_loss = MSELoss(reduction=\"sum\")\n", "\n", "# Start training\n", "model1.train() # set model to training mode\n", "\n", "\n", "# Note from (https://pytorch.org/docs/stable/optim.html):\n", "# Some optimization algorithms such as LBFGS need to\n", "# reevaluate the function multiple times, so you have to\n", "# pass in a closure that allows them to recompute your model.\n", "# The closure should clear the gradients, compute the loss,\n", "# and return it.\n", "def closure():\n", " optimizer.zero_grad() # Initialize/clear gradients\n", " loss = f_loss(model1(X_), y_) # Evaluate loss function\n", " loss.backward() # Backward pass\n", " print(loss.item()) # Print loss\n", " return loss\n", "\n", "\n", "# Run optimizer step\n", "optimizer.step(closure)" ] }, { "cell_type": "code", "execution_count": 7, "id": "efficient-bangkok", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.8\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Evaluate model and compute accuracy\n", "model1.eval()\n", "y_predict = []\n", "for x, y_target in zip(X, y):\n", " output = model1(Tensor(x))\n", " y_predict += [np.sign(output.detach().numpy())[0]]\n", "\n", "print(\"Accuracy:\", sum(y_predict == y) / len(y))\n", "\n", "# Plot results\n", "# red == wrongly classified\n", "for x, y_target, y_p in zip(X, y, y_predict):\n", " if y_target == 1:\n", " plt.plot(x[0], x[1], \"bo\")\n", " else:\n", " plt.plot(x[0], x[1], \"go\")\n", " if y_target != y_p:\n", " plt.scatter(x[0], x[1], s=200, facecolors=\"none\", edgecolors=\"r\", linewidths=2)\n", "plt.plot([-1, 1], [1, -1], \"--\", color=\"black\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "abstract-parish", "metadata": {}, "source": [ "The red circles indicate wrongly classified data points." ] }, { "cell_type": "markdown", "id": "typical-cross", "metadata": {}, "source": [ "#### B. Classification with PyTorch and `SamplerQNN`\n", "\n", "Linking a `SamplerQNN` to PyTorch requires a bit more attention than linking an `EstimatorQNN`. Without the correct setup, backpropagation is not possible. \n", "\n", "In particular, we must make sure that the network's forward pass returns a dense array of probabilities (`sparse=False`). This parameter is `False` by default, so we only have to make sure that it has not been changed.\n", "\n", "**⚠️ Attention:** \n", "If we define a custom interpret function (in the example, `parity`), we must remember to explicitly provide the desired output shape (in the example, `2`). For more information on the initial parameter setup for `SamplerQNN`, please check out the [official Qiskit Machine Learning documentation](https://qiskit-community.github.io/qiskit-machine-learning/stubs/qiskit_machine_learning.neural_networks.SamplerQNN.html)."
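, "

To make the role of `interpret` and `output_shape` concrete, the minimal NumPy sketch below assumes a sampler that returns one probability per computational basis state, indexed by the integer value of its bitstring. It sums those probabilities into the class that `parity` assigns to each index, which is conceptually how an interpret function reduces the raw distribution; `aggregate` is a hypothetical helper, not `SamplerQNN` code.

```python
import numpy as np


def parity(x):
    # Same mapping as the tutorial's parity lambda: count of 1 bits, modulo 2.
    return bin(x).count('1') % 2


def aggregate(probs, interpret, output_shape):
    # Add the probability of each basis state to the class it is mapped to.
    out = np.zeros(output_shape)
    for index, p in enumerate(probs):
        out[interpret(index)] += p
    return out


# Hypothetical 2-qubit distribution over the basis states 00, 01, 10, 11 (indices 0 to 3).
probs = np.array([0.4, 0.1, 0.2, 0.3])
print(aggregate(probs, parity, output_shape=2))
```

Nothing in `parity` itself states how many distinct values it can return, so the output length (here 2) has to be supplied separately, which is what `output_shape=2` does for the tutorial's `SamplerQNN`.
"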
] }, { "cell_type": "code", "execution_count": 8, "id": "present-operator", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial weights: [ 0.0364991 -0.0720495 -0.06001836 -0.09852755]\n" ] } ], "source": [ "# Define feature map and ansatz\n", "feature_map = ZZFeatureMap(num_inputs)\n", "ansatz = RealAmplitudes(num_inputs, entanglement=\"linear\", reps=1)\n", "\n", "# Define quantum circuit of num_qubits = input dim\n", "# Append feature map and ansatz\n", "qc = QuantumCircuit(num_inputs)\n", "qc.compose(feature_map, inplace=True)\n", "qc.compose(ansatz, inplace=True)\n", "\n", "# Define SamplerQNN and initial setup\n", "parity = lambda x: \"{:b}\".format(x).count(\"1\") % 2 # optional interpret function\n", "output_shape = 2 # parity = 0, 1\n", "qnn2 = SamplerQNN(\n", " circuit=qc,\n", " input_params=feature_map.parameters,\n", " weight_params=ansatz.parameters,\n", " interpret=parity,\n", " output_shape=output_shape,\n", ")\n", "\n", "# Set up PyTorch module\n", "# Reminder: If we don't explicitly declare the initial weights\n", "# they are chosen uniformly at random from [-1, 1].\n", "initial_weights = 0.1 * (2 * algorithm_globals.random.random(qnn2.num_weights) - 1)\n", "print(\"Initial weights: \", initial_weights)\n", "model2 = TorchConnector(qnn2, initial_weights)" ] }, { "cell_type": "markdown", "id": "liquid-reviewer", "metadata": {}, "source": [ "For a reminder on optimizer and loss function choices, you can go back to [this section](#Optimizer)." 
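, "

One detail worth checking before the next cell: PyTorch's `CrossEntropyLoss` expects unnormalized scores (logits) and applies a log-softmax internally, so the `SamplerQNN` probabilities are effectively treated as logits. The sketch below assumes only PyTorch, uses made-up probability vectors, and verifies this equivalence numerically.

```python
import math

import torch
from torch.nn import CrossEntropyLoss, NLLLoss

# Hypothetical batch of two-class probability vectors, like SamplerQNN outputs.
probs = torch.tensor([[0.5, 0.5], [0.9, 0.1], [0.2, 0.8]])
targets = torch.tensor([0, 0, 1])

ce = CrossEntropyLoss()(probs, targets)
# CrossEntropyLoss is NLLLoss applied to the log-softmax of its inputs.
manual = NLLLoss()(torch.log_softmax(probs, dim=1), targets)
print(ce.item(), manual.item())

# An uninformative output [0.5, 0.5] yields ln 2 per sample.
uniform = CrossEntropyLoss()(torch.tensor([[0.5, 0.5]]), torch.tensor([0]))
print(uniform.item(), math.log(2))

# A fully confident, correct output [1, 0] is only a logit gap of 1.
best = CrossEntropyLoss()(torch.tensor([[1.0, 0.0]]), torch.tensor([0]))
print(best.item(), math.log(1 + math.exp(-1)))
```

An uninformative output gives ln 2 ≈ 0.693, close to the first loss printed in the training log that follows. Because two probabilities that sum to 1 differ by at most 1 when read as logits, the per-sample loss cannot fall below ln(1 + e^(-1)) ≈ 0.313, which may partly explain why the loss in this example decreases only slowly.
"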
" ] }, { "cell_type": "code", "execution_count": 9, "id": "marked-harvest", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6925069093704224\n", "0.6881508231163025\n", "0.6516684293746948\n", "0.6485998034477234\n", "0.6394745111465454\n", "0.7055025100708008\n", "0.6669358611106873\n", "0.6768221259117126\n", "0.6784337759017944\n", "0.7485936284065247\n", "0.6641563773155212\n", "0.6561498045921326\n", "0.66301429271698\n", "0.6441987752914429\n", "0.6511136293411255\n", "0.6289191246032715\n", "0.6247060298919678\n", "0.6366127729415894\n", "0.6195870041847229\n", "0.6179186105728149\n" ] } ], "source": [ "# Define optimizer and loss\n", "optimizer = LBFGS(model2.parameters())\n", "f_loss = CrossEntropyLoss() # SamplerQNN outputs probabilities in the [0, 1] range\n", "\n", "# Start training\n", "model2.train()\n", "\n", "# Define LBFGS closure method (explained in previous section)\n", "def closure():\n", " optimizer.zero_grad(set_to_none=True) # Clear gradients\n", " loss = f_loss(model2(X_), y01_) # Calculate loss\n", " loss.backward() # Backward pass\n", "\n", " print(loss.item()) # Print loss\n", " return loss\n", "\n", "\n", "# Run optimizer (LBFGS requires closure)\n", "optimizer.step(closure);" ] }, { "cell_type": "code", "execution_count": 10, "id": "falling-electronics", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.8\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Evaluate model and compute accuracy\n", "model2.eval()\n", "y_predict = []\n", "for x in X:\n", " output = model2(Tensor(x))\n", " y_predict += [np.argmax(output.detach().numpy())]\n", "\n", "print(\"Accuracy:\", sum(y_predict == y01) / len(y01))\n", "\n", "# Plot results\n", "# red == wrongly classified\n", "for x, y_target, y_p in zip(X, y01, y_predict):\n", " if y_target == 1:\n", " plt.plot(x[0], x[1], \"bo\")\n", " else:\n", " plt.plot(x[0], x[1], \"go\")\n", " if y_target != y_p:\n", " plt.scatter(x[0], x[1], s=200, facecolors=\"none\", edgecolors=\"r\", linewidths=2)\n", "plt.plot([-1, 1], [1, -1], \"--\", color=\"black\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "aboriginal-white", "metadata": {}, "source": [ "The red circles indicate wrongly classified data points." ] }, { "cell_type": "markdown", "id": "scheduled-nicaragua", "metadata": {}, "source": [ "### 2. Regression \n", "\n", "We also use a model based on `EstimatorQNN` to illustrate how to perform a regression task. The dataset in this case is generated by sampling points from a sine wave and adding uniform random noise. " ] }, { "cell_type": "code", "execution_count": 11, "id": "amateur-dubai", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Generate random dataset\n", "\n", "num_samples = 20\n", "eps = 0.2\n", "lb, ub = -np.pi, np.pi\n", "f = lambda x: np.sin(x)\n", "\n", "X = (ub - lb) * algorithm_globals.random.random([num_samples, 1]) + lb\n", "y = f(X) + eps * (2 * algorithm_globals.random.random([num_samples, 1]) - 1)\n", "plt.plot(np.linspace(lb, ub), f(np.linspace(lb, ub)), \"r--\")\n", "plt.plot(X, y, \"bo\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "protected-genre", "metadata": {}, "source": [ "#### A. Regression with PyTorch and `EstimatorQNN`" ] }, { "cell_type": "markdown", "id": "lovely-semiconductor", "metadata": {}, "source": [ "The network definition and training loop will be analogous to those of the classification task using `EstimatorQNN`. In this case, we again define our own feature map and ansatz, but this time we build them by hand, each from a single `RY` rotation." ] }, { "cell_type": "code", "execution_count": 12, "id": "brazilian-adapter", "metadata": {}, "outputs": [], "source": [ "# Construct simple feature map\n", "param_x = Parameter(\"x\")\n", "feature_map = QuantumCircuit(1, name=\"fm\")\n", "feature_map.ry(param_x, 0)\n", "\n", "# Construct simple parameterized ansatz\n", "param_y = Parameter(\"y\")\n", "ansatz = QuantumCircuit(1, name=\"vf\")\n", "ansatz.ry(param_y, 0)\n", "\n", "qc = QuantumCircuit(1)\n", "qc.compose(feature_map, inplace=True)\n", "qc.compose(ansatz, inplace=True)\n", "\n", "# Construct QNN\n", "qnn3 = EstimatorQNN(circuit=qc, input_params=[param_x], weight_params=[param_y])\n", "\n", "# Set up PyTorch module\n", "# Reminder: If we don't explicitly declare the initial weights\n", "# they are chosen uniformly at random from [-1, 1].\n", "initial_weights = 0.1 * (2 * algorithm_globals.random.random(qnn3.num_weights) - 1)\n", "model3 = TorchConnector(qnn3, initial_weights)" ] }, { "cell_type": "markdown", "id": "waiting-competition", "metadata": {}, "source": [ "For a reminder on optimizer and loss
function choices, you can go back to [this section](#Optimizer)." ] }, { "cell_type": "code", "execution_count": 13, "id": "bibliographic-consciousness", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14.947757720947266\n", "2.948650360107422\n", "8.952412605285645\n", "0.37905153632164\n", "0.24995625019073486\n", "0.2483610212802887\n", "0.24835753440856934\n" ] }, { "data": { "text/plain": [ "tensor(14.9478, grad_fn=)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define optimizer and loss function\n", "optimizer = LBFGS(model3.parameters())\n", "f_loss = MSELoss(reduction=\"sum\")\n", "\n", "# Start training\n", "model3.train() # set model to training mode\n", "\n", "# Define objective function\n", "def closure():\n", " optimizer.zero_grad(set_to_none=True) # Initialize gradient\n", " loss = f_loss(model3(Tensor(X)), Tensor(y)) # Compute batch loss\n", " loss.backward() # Backward pass\n", " print(loss.item()) # Print loss\n", " return loss\n", "\n", "\n", "# Run optimizer\n", "optimizer.step(closure)" ] }, { "cell_type": "code", "execution_count": 14, "id": "timely-happiness", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot target function\n", "plt.plot(np.linspace(lb, ub), f(np.linspace(lb, ub)), \"r--\")\n", "\n", "# Plot data\n", "plt.plot(X, y, \"bo\")\n", "\n", "# Plot fitted line\n", "model3.eval()\n", "y_ = []\n", "for x in np.linspace(lb, ub):\n", "    output = model3(Tensor([x]))\n", "    y_ += [output.detach().numpy()[0]]\n", "plt.plot(np.linspace(lb, ub), y_, \"g-\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "individual-georgia", "metadata": {}, "source": [ "***\n", "\n", "## Part 2: MNIST Classification, Hybrid QNNs\n", "\n", "In this second part, we show how to use `TorchConnector` to build a hybrid quantum-classical neural network for a more complex image classification task on the MNIST handwritten digits dataset.\n", "\n", "For a more detailed (pre-`TorchConnector`) explanation of hybrid quantum-classical neural networks, see the corresponding section in the [Qiskit Textbook repository](https://github.com/Qiskit/platypus/blob/main/notebooks/v2/ch-machine-learning/machine-learning-qiskit-pytorch.ipynb)."
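, "\n", "\n", "In Step 2, the hybrid model flattens the output of two convolution and max-pooling stages and feeds it into `Linear(256, 64)`. The 256 can be checked by tracing a single 28x28 input. Below is a rough sketch in plain PyTorch. The names `conv_a`, `conv_b` and `dummy` are illustrative, and dropout is left out because it does not change the shape:

```python
# Sketch: trace a dummy MNIST-sized input through two convolution/pooling stages
# with the same layer shapes as the hybrid model in Step 2 (no quantum layer).
import torch
import torch.nn.functional as F
from torch.nn import Conv2d

conv_a = Conv2d(1, 2, kernel_size=5)  # illustrative stand-in for conv1
conv_b = Conv2d(2, 16, kernel_size=5)  # illustrative stand-in for conv2

dummy = torch.zeros(1, 1, 28, 28)  # one 28x28 grayscale image
h = F.max_pool2d(F.relu(conv_a(dummy)), 2)  # 28 -> 24 (5x5 conv) -> 12 (2x2 pool)
h = F.max_pool2d(F.relu(conv_b(h)), 2)  # 12 -> 8 (5x5 conv) -> 4 (2x2 pool)
flat = h.view(h.shape[0], -1)  # 16 channels * 4 * 4 = 256 features

print(tuple(h.shape))  # (1, 16, 4, 4)
print(flat.shape[1])  # 256
```

If the kernel sizes, the pooling or the channel counts change, the input size of the first fully connected layer has to be updated to the new flattened size."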
] }, { "cell_type": "code", "execution_count": 15, "id": "otherwise-military", "metadata": {}, "outputs": [], "source": [ "# Additional torch-related imports\n", "import torch\n", "from torch import cat, no_grad, manual_seed\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import torch.optim as optim\n", "from torch.nn import (\n", " Module,\n", " Conv2d,\n", " Linear,\n", " Dropout2d,\n", " NLLLoss,\n", " MaxPool2d,\n", " Flatten,\n", " Sequential,\n", " ReLU,\n", ")\n", "import torch.nn.functional as F" ] }, { "cell_type": "markdown", "id": "bronze-encounter", "metadata": {}, "source": [ "### Step 1: Defining Data-loaders for train and test" ] }, { "cell_type": "markdown", "id": "parliamentary-middle", "metadata": {}, "source": [ "We take advantage of the `torchvision` [API](https://pytorch.org/vision/stable/datasets.html) to directly load a subset of the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database) and define torch `DataLoader`s ([link](https://pytorch.org/docs/stable/data.html)) for train and test." 
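, "\n", "\n", "The train and test cells below keep only the digits 0 and 1 by overwriting the `data` and `targets` attributes of the downloaded datasets. One possible alternative is to wrap the dataset in `torch.utils.data.Subset`, which selects samples by index and leaves the wrapped dataset unchanged. The sketch below applies it to a tiny synthetic stand-in for MNIST, so nothing is downloaded. The `toy_*` names and values are hypothetical:

```python
# Sketch: keep only labels 0 and 1 with torch.utils.data.Subset, using a tiny
# synthetic stand-in for MNIST (hypothetical data, nothing is downloaded).
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

toy_images = torch.zeros(10, 1, 28, 28)  # placeholder 28x28 images
toy_targets = torch.tensor([0, 1, 2, 0, 3, 1, 0, 9, 1, 5])  # placeholder labels
toy_dataset = TensorDataset(toy_images, toy_targets)

n_per_class = 2  # keep the first 2 samples of each class
toy_idx = np.append(
    np.where(toy_targets.numpy() == 0)[0][:n_per_class],
    np.where(toy_targets.numpy() == 1)[0][:n_per_class],
)

# Subset only stores the selected indices; toy_dataset itself is not modified
toy_subset = Subset(toy_dataset, toy_idx.tolist())
toy_loader = DataLoader(toy_subset, batch_size=1, shuffle=False)

toy_labels = [int(target) for _, target in toy_loader]
print(toy_labels)  # [0, 0, 1, 1]
```

With the real data, the index array `idx` computed below could be passed as `Subset(X_train, idx.tolist())` instead of overwriting `X_train.data` and `X_train.targets`."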
] }, { "cell_type": "code", "execution_count": 16, "id": "worthy-charlotte", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "71.7%" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100.0%\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100.0%\n", "100.0%" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n", "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", "\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n", "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "100.0%\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n", "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", "\n" ] } ], "source": [ "# Train Dataset\n", "# -------------\n", "\n", "# Set train shuffle seed (for reproducibility)\n", "manual_seed(42)\n", "\n", "batch_size = 1\n", "n_samples = 100 # We will concentrate on the first 100 
samples\n", "\n", "# Use pre-defined torchvision function to load MNIST train data\n", "X_train = datasets.MNIST(\n", " root=\"./data\", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()])\n", ")\n", "\n", "# Filter out labels (originally 0-9), leaving only labels 0 and 1\n", "idx = np.append(\n", " np.where(X_train.targets == 0)[0][:n_samples], np.where(X_train.targets == 1)[0][:n_samples]\n", ")\n", "X_train.data = X_train.data[idx]\n", "X_train.targets = X_train.targets[idx]\n", "\n", "# Define torch dataloader with filtered data\n", "train_loader = DataLoader(X_train, batch_size=batch_size, shuffle=True)" ] }, { "cell_type": "markdown", "id": "completed-spring", "metadata": {}, "source": [ "If we perform a quick visualization we can see that the train dataset consists of images of handwritten 0s and 1s." ] }, { "cell_type": "code", "execution_count": 17, "id": "medieval-bibliography", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAACZCAYAAABHTieHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAdNklEQVR4nO3deXBUVdrH8SeEbEDICLJvMYAgCKKACLLNwAgIOOwoKJuACiXIphQWq2yKoKMooBSMJTKAwDDqMIBiMoALDEqkGESWF5GwyRrCFgi57x8pOnlOQi9Jn3Q6fD9VVvWv+y7H9JPuHO4594Q4juMIAAAAAPhZkUA3AAAAAEDhRGcDAAAAgBV0NgAAAABYQWcDAAAAgBV0NgAAAABYQWcDAAAAgBV0NgAAAABYQWcDAAAAgBV0NgAAAABYETSdjV9//VVCQkLkzTff9NsxExISJCQkRBISEvx2zClTpkhISIjfjofboyZgoiaQFfUAEzUBEzVhn9XOxt/+9jcJCQmRnTt32jxNofPtt99K8+bNpVixYlK+fHkZMWKEXLp0KdDN8gtqIneoCZgKa01QD77btGmTPPvss3L//fdLaGioxMbGBrpJfkVN+I6aQE4C9b0RNFc27hSJiYnSpk0buXLlisybN08GDx4sH3zwgfTs2TPQTUOAUBMwURPIavny5bJ8+XKJiYmRihUrBro5KACoCZgC+b1R1PoZ4JMJEybIXXfdJQkJCVKyZEkREYmNjZUhQ4bIpk2b5LHHHgtwC5HfqAmYqAlkNXPmTPnwww8lLCxMOnXqJHv27Al0kxBg1ARMgfzeCPiVjevXr8ukSZOkYcOGEhMTI8WLF5cWLVpIfHz8bfd56623pFq1ahIVFSWtWrXK8Zdo37590qNHDylVqpRERkZKo0aN5LPPPvOqTdu3b5f27dtLTEyMFCtWTFq1aiXffPNNtu22bdsmjRs3lsjISKlevbosWrQox+OdOXNG9u3bJ1euXHF73osXL8qXX34pTz
/9tKsQRET69esnJUqUkFWrVnnV/mBHTWSiJjJQE5moCerBVLFiRQkLC/OqnYUVNaFRE9REVgH/3nAsWrp0qSMizn//+9/bbnP69GmnQoUKzujRo50FCxY4b7zxhlOrVi0nLCzM2bVrl2u7w4cPOyLi1KtXz4mNjXVef/11Z+rUqU6pUqWcMmXKOCdPnnRtu2fPHicmJsapU6eO8/rrrzvz5893WrZs6YSEhDhr1651bRcfH++IiBMfH+96bvPmzU54eLjTtGlTZ+7cuc5bb73l1K9f3wkPD3e2b9/u2m737t1OVFSUU7VqVWfWrFnOa6+95pQrV86pX7++Y/5YJ0+enO08Odm2bZsjIs7KlSuzvda8eXPnoYcecrt/MKAmMlATmaiJDNREBuohg7f1YOrYsaNTrVo1n/Yp6KiJDNREJmoiQ7B8bwS8s5GWluakpqaq586fP++UK1fOGTRokOu5W8UQFRXlJCUluZ7fvn27IyLOqFGjXM+1adPGqVevnnPt2jXXc+np6U6zZs2cmjVrup4ziyE9Pd2pWbOm065dOyc9Pd213ZUrV5x77rnH+fOf/+x6rkuXLk5kZKRz5MgR13N79+51QkNDc10Mn376qSMizpYtW7K91rNnT6d8+fJu9w8G1EQGaiITNZGBmshAPWTgD8tM1EQGaiITNZEhWL43Aj6MKjQ0VMLDw0VEJD09Xc6dOydpaWnSqFEj+fHHH7Nt36VLF6lUqZIrP/zww9KkSRNZv369iIicO3dOvv76a+nVq5ekpKTImTNn5MyZM3L27Flp166dHDhwQI4dO5ZjWxITE+XAgQPSp08fOXv2rGvfy5cvS5s2bWTLli2Snp4uN2/elI0bN0qXLl2katWqrv3vu+8+adeuXbbjTpkyRRzHkdatW7v9WVy9elVERCIiIrK9FhkZ6Xq9sKMmMlETGaiJTNQE9YDsqAmYqIlMgf7eKBATxD/66COZO3eu7Nu3T27cuOF6/p577sm2bc2aNbM9d++997rGmx08eFAcx5GJEyfKxIkTczzf77//rgrqlgMHDoiISP/+/W/b1uTkZElNTZWrV6/m2JZatWq5CtNXUVFRIiKSmpqa7bVr1665Xr8TUBMZqIlM1EQGaiID9QATNQETNZEh0N8bAe9sLFu2TAYMGCBdunSRcePGSdmyZSU0NFRmzZolhw4d8vl46enpIiIyduzYHHuBIiI1atRwu++cOXOkQYMGOW5TokSJHN8sf6hQoYKIiJw4cSLbaydOnLhjbl9HTWSiJjJQE5moCeoB2VETMFETmQL9vRHwzsbq1aslLi5O1q5dq1ZGnDx5co7b3+odZrV//37XgjVxcXEiIhIWFiZt27b1qS3Vq1cXEZGSJUu63bdMmTISFRWVY1t++eUXn86Z1f333y9FixaVnTt3Sq9evVzPX79+XRITE9VzhRk1kYmayEBNZKImqAdkR03ARE1kCvT3RoGYsyEi4jiO67nt27fLd999l+P269atU2PiduzYIdu3b5cOHTqIiEjZsmWldevWsmjRohx7cKdPn75tWxo2bCjVq1eXN998M8cVFW/tGxoaKu3atZN169bJb7/95nr9559/lo0bN2bbz9tbk8XExEjbtm1l2bJlkpKS4nr+448/lkuXLt0xC3ZRE5moiQzURCZqgnpAdtQETNREpkB/b+TLlY0lS5bIhg0bsj0/cuRI6dSpk6xdu1a6du0qHTt2lMOHD8vChQulTp06Ob4hNWrUkObNm8sLL7wgqamp8vbbb0vp0qXl5Zdfdm3z3nvvSfPmzaVevXoyZMgQiYuLk1OnTsl3330nSUlJ8tNPP+XYziJFisjixYulQ4cOUrduXRk4cKBUqlRJjh07JvHx8VKyZEn5/PPPRURk6tSpsmHDBmnRooUMGzZM0tLS5N1335W6devK7t271XHnz58vU6dOlfj4eI+TeGbMmCHNmjWTVq1aydChQyUpKUnmzp0rjz32mLRv397Tjz
poUBPUhImaoCayoh68r4fdu3e77vN/8OBBSU5OlunTp4uIyAMPPCCdO3d2u3+woCaoCRM1ESTfGzZvdXXr1mS3++/o0aNOenq6M3PmTKdatWpORESE8+CDDzpffPGF079/f3Wrtlu3JpszZ44zd+5cp0qVKk5ERITTokUL56effsp27kOHDjn9+vVzypcv74SFhTmVKlVyOnXq5Kxevdq1TU73QXYcx9m1a5fTrVs3p3Tp0k5ERIRTrVo1p1evXs7mzZvVdv/5z3+chg0bOuHh4U5cXJyzcOFC123IsvL1dnVbt251mjVr5kRGRjplypRxhg8f7ly8eNGrfQs6aiIDNZGJmshATWSgHjL4Ug/ufmb9+/f3uH9BR01koCYyURMZguV7I8RxslxfAgAAAAA/CficDQAAAACFE50NAAAAAFbQ2QAAAABgBZ0NAAAAAFbQ2QAAAABgBZ0NAAAAAFZ4tahfenq6HD9+XKKjo9WS7yj4HMeRlJQUqVixohQp4r++JTURvKgJZGWrHkSoiWDFZwRM1ARMvtSEV52N48ePS5UqVfzSOATG0aNHpXLlyn47HjUR/KgJZOXvehChJoIdnxEwURMweVMTXnVPo6Oj/dIgBI6/30NqIvhRE8jKxvtHTQQ3PiNgoiZg8uY99KqzwaWt4Ofv95CaCH7UBLKy8f5RE8GNzwiYqAmYvHkPmSAOAAAAwAo6GwAAAACsoLMBAAAAwAo6GwAAAACsoLMBAAAAwAo6GwAAAACsoLMBAAAAwAo6GwAAAACsoLMBAAAAwIqigW4AEIyefPJJlWvVqqXyzJkzVb5x44b1NgEAABQ0XNkAAAAAYAWdDQAAAABW0NkAAAAAYAVzNkSkaFH9Y3j00UdVnjt3rsoNGzZUOSkpSeU1a9ao/P7776u8f//+XLUTdpUuXdr12JyDYb5n69evV3nq1Kkqz58/X+WzZ8/6o4nIZ7Gxsa7Hq1atUq81btxY5fT0dLfH+uyzz1QeOnSoyqdPn85FCxFMFixYoPLzzz/vdvsKFSqofPLkSb+3CUDBERUVpfL06dNVDg8PV7lnz54q7927V+W+ffuqfOLEibw2MVe4sgEAAADACjobAAAAAKygswEAAADAijt2zkb16tVdj2fMmKFe69Wrl9t9HcdRuVKlSiqPGDFC5YEDB6o8b948lc3x/rAjIiJC5Tlz5qg8ZMgQ12NzXOS5c+dUbtCggcqvvvqqyi1btlT5H//4h09tRf7o0aOHys2bN1e5fv36rscPPviges2co+FpzkanTp1UXrRokcrdunVz31gEPbO+PNUM8l/btm1VNufJ9O7d2+tj/fbbbyofPnxY5fLly6ucmJio8p49e7w+F4JD8eLFVTbX7Fq4cKHKoaGhPh3/rrvuytP+tnBlAwAAAIAVdDYAAAAAWEFnAwAAAIAVd+ycjQ4dOrgee5qjkVfR0dEqv/jiiyqb63Bwv307XnnlFZWHDx+u8vLly12PX3rpJfVaWlqaysnJySqvXr1a5cWLF6u8YcMGla9eveq5wfDInFtTtWpVladNm6Zyo0aNVC5VqpTKMTExfmyde507d1bZrJnBgwfnW1tgR5MmTVQ2x+ij4Pnqq6/cvm5zHkXWtZ5Esq/tY/tvFfhfiRIlVB47dqzKkyZNsnp+c45xoHBlAwAAAIAVdDYAAAAAWEFnAwAAAIAVhXbORlhYmMoLFixQecCAAV4f6/LlyyqvXLlS5aSkJJXHjx+vsjmu3Byvb47TZM6GHeb9rI8cOaLy5MmTXY/Pnj2bp3M98cQTKi9ZskTlb7/9Nk/HRwZz/tPs2bPdbl+kiP73lYK0zkG5cuUC3QTkUWRkpMoffvihyuYcIdOmTZtUvnjxon8ahnxRpUoVlbt3767yqVOnVP773/+ucpkyZVQ255ghOGSd+7d27Vr12h//+EefjuVpf3NdDXOOxo0bN3w6ny1c2QAAAABgBZ0NAAAAAFbQ2QAAAABgRaGZs1G/fn
2V16xZo3L16tVvu29qaqrKo0ePVvnf//63yr/++qvKTz/9tLfNFBGRu+++W+UaNWqovG/fPp+Oh5w9+uijKteuXVvll19+WeX/+7//89u5Q0JC/HYsZHrhhRdUtn2Pcnfn2rVrl8qe7mfeqVMnlZ9//nn/NAwFRteuXVWuW7eu2+3N+YDm2ipXrlzxT8OQL/r27avyzJkzVb5586bK9957r8rmfFAEB3MttRUrVrget2rVyu2+R48eVXnMmDEqb968WWVP67yYf3sUK1bM7fb5hSsbAAAAAKygswEAAADACjobAAAAAKwoNHM2OnfurLK7ORqmefPmqWyuyeFJQkKCyuY6GuY6G+vXr1f566+/9ul88E63bt1UNsfLehr7mBerV6922xbW2fBOhw4dVJ4/f36ejvfDDz+o3LBhQ7fbv/POO67Hs2bNytO5N27cqLK5RgiCT8uWLVV+++23fdrfvIf+sWPH8tokBJD5d4gpNDRU5YMHD9psDiypVq2aysuXL1e5adOmt933k08+UXnChAkqm3M4sq7ZISJStKhvf7azzgYAAACAQo3OBgAAAAAr6GwAAAAAsCJo52ysW7dO5ccff9yn/X/66SfX4ylTpuSpLea6HCVLlnS7fdWqVVU25xLAP7p3767yuXPnVN6wYYO1cycnJ6vcu3dvlceOHWvt3IXJ3r17VTbnR5lj5j3p2bOnykuWLHF7PE9rZ+DONm3aNJXNNZQ8OX36tD+bgwBo3bq16/FDDz3kdlvzO+jnn392u/3cuXNz3S7YY67R5W6Ohrnm26BBg1T2NKfiT3/6k8plypRxu735nWXOIQ4UrmwAAAAAsILOBgAAAAAr6GwAAAAAsCJo5mxUrlxZZXMcm6d7D+/fv1/lp556yvU4r/chPnz4sMohISFutzfH9UZFRamcmpqap/Yggzk3JpD3NDfbAu/Exsaq/MADD/i0/5AhQ1Q+cuSIyvHx8Sq7G3ub38yxuX/9619VHjlypMqM/88fbdu2dT2uU6eOT/ueP39e5azruCA4jR8/3vU4IiLC7bbm7+yPP/6osrkm1/Hjx/PYOuRGdHS0yrNnz1Z56NChbvfPOhdnzJgx6jVf/968dOmST9snJiaqzJwNAAAAAIUanQ0AAAAAVtDZAAAAAGBFgZ2zUbt2bZXXrl2rcokSJdzub46FfP7551Xet29frttmzg9p1aqVT/svWrRI5QsXLuS6LfCeeb9rFHzFihVTOSYmxqf9T5w44fb16dOnq/zMM8/4dHx/MuenmGuAtGjRQuWTJ0+qbK73A/8wx2/PmDHD9bh06dI+HevZZ59V+ejRo7lvGAoEX2rgzJkzbl+/fv2624z80aFDB5UHDhyocmhoqMrm33AjRoxwPf7tt9/y1Jasc8S8ERYWpjJzNgAAAAAUanQ2AAAAAFhBZwMAAACAFQV2zsbHH3+ssjmHw2TeP3/cuHEq79y50z8NE5EiRXQfrV69em63T09PVzkhIcFvbUGmcuXKqWyud/L777/nW1sCeW4Ep/vuu09lc44GAuPxxx9XuVGjRl7v+80336i8efNmv7QJgWN+z/g6jwwFT4MGDVQ21zSKjIxUOes6GiIir776qspbt27NdVtq1Kih8oABA3za31xno6DgygYAAAAAK+hsAAAAALCiwAyjModJmZe1PNm0aZPK8fHxeW3SbT355JMq16xZ0+325jCqX375xe9tQvbhDo7jqGxeCrXJPPcXX3yRb+cuzMwhjJ6Yw9k8MS9/B/J31dP/68iRI1U2PwM3bNjg9zbdCdq3b6/ye++95/W+5q1NJ0+erPKlS5dy3zAUCH369FE567AX8/PGvCXq2bNnrbULuffcc8+pbA6VM3311Vcqr1u3zt9NcvG0zIPJHOJVUHBlAwAAAIAVdDYAAAAAWEFnAwAAAIAVBWbOhnmrWnM5eNO2bdtUHjZsmN/bdEvJkiVVNm9z5om5XP358+fz3Cb4zhw/a1PVqlVVvn
z5cr6duzAz5z95MnToUJU9zWMYPHiwz22yxdf/V3OeELwTHR2t8sSJE1W+6667vD7WJ598orLNuYMIjIEDB972NfN3cM+ePSr78xb8yL37779f5b/85S9ut9+/f7/Kb731lt/bdMupU6dUNm+bX61aNbf7e3o9ULiyAQAAAMAKOhsAAAAArKCzAQAAAMCKgM3Z+MMf/qBymzZt3G5/48YNlc37l9+8edMv7RLJPkfjjTfeUNnTuhqpqakqm21FYHi6d7Y/tW3bVuWCeu/rwq5z586BbgIKuHbt2qn8yCOPeL3vRx99pLI53wPBr2XLlirHxsbedtuUlBSVfVmjJSeVK1dWOSkpKU/HQ4YJEyaoXL58ebfbjxo1SuVff/3V301yiYuLU7l48eJut09LS1P5yy+/9Hub/IErGwAAAACsoLMBAAAAwAo6GwAAAACsCNicjaeeekplc10Ck7muhj/vXx4WFqbyvHnzVB40aJBPxzt48KDK5r3XERgRERHWjt28eXOVzXk///rXv6yduzA7cuSIyrt371a5fv36Ph3PvJ/6P//5z9w1DEHLHHP/4Ycf+rT/6dOnXY/N+Xysp1P4mOP5S5QocdttzXU0VqxY4dO5zDkanTp1UnnhwoU+HQ8ZzLVRevXq5Xb7Tz/9VOXNmzf7vU23FC2q/wzv06ePynfffbfb/ZOTk1XesmWLfxrmZ1zZAAAAAGAFnQ0AAAAAVtDZAAAAAGBFwOZslClTxqftv//+e7+eP+uY+jVr1qjXPK35Yco6hldE5Jlnnsl9w5BrS5cuVXnJkiUqP/fccyrPmTNH5XPnznl9rujoaJWXLVumcnh4uMrmnCN4Z+/evSr37dtX5U2bNqlcoUIFt8dbtGiRyunp6Sp//vnnvjYRQeall15S2ZxfZTpz5ozKPXr0cD3et2+f39qFgql06dJeb2v+LeArcz6IuW4HvGPOzxw8eLDKRYrof2c313FbvHixytevX/db20JDQ1Xu2LGjyuPGjXO7v7mOW9euXVW+cOFC7htnEVc2AAAAAFhBZwMAAACAFXQ2AAAAAFgRsDkb1atX92l7c+0KT8xxceZaGTNnznQ99mVMpkj28ffmvfvPnz/v0/Fgx44dO1Ru1KiRyiNGjFB52rRprsfmWH5zLZbXX39d5SpVqqi8cuVKlXft2uVFi+GJOUbe/F3s3bu32/3LlSun8rp161Q+fvy4ykOHDnU9Nu+hn9fx2SazPs1xxaZ33nlH5Y0bN/q1PYVF586dVc7ruknMv7qzjBw50u3rV69edT025wH6ypw7wLotuVO7dm2VH374YbfbHz58WGV//o6ba6eYa35MnTrV7f7mHI0PPvhA5WD5POLKBgAAAAAr6GwAAAAAsILOBgAAAAArAjZnY/369Sqb9883xyvXqlVLZXMM/ahRo1Tu3r27yo0bN/a6bWlpaSq/9957Kr/77rsqM0ejYOrQoYPKiYmJKk+cOFHlrOP3zffU3NYcd+lurL+IHtcL/xk2bJjK5lytbt26ud3fnJtTvnx5lT/77DPXY3MNji+//FLlBQsWuD1Xq1atVDbvj/7iiy+6bZvJcRy3ryPDE088oXLx4sV92v/tt9/2Y2tQ2GT92+OHH37I07EOHTrkNsM75u94SEiI2+3NdTSKFSumsjnXr23btm6PV7duXddjc40PT58/5ue6+ffm7Nmz3e5fUHFlAwAAAIAVdDYAAAAAWEFnAwAAAIAVIY4XA38vXrwoMTExfj2xOW7t+++/VznrmDfbli5dqvKUKVNUPnr0aL61xZbk5GQpWbKk345noyZsM8c+muP9r1y54npsjpePjo5W+dixYyq3b99e5f/973+5bmd+KYw1YZ7/7rvvVnn//v0qe5oX4c61a9dUNuftmMyftdk2c56a2bYjR46o/Mgjj6h85swZt+f3xN/1IBKYmjDn3Zj30A8PD3e7f//+/VVetWqVyub47sKsMH5GeNKkSROVN2zYoLLZ/q
xrOJhr8RRGwVAT5ppF5t+XntYw8sScA+LL/DlzX3MdDXOtFnO+aEHkTU1wZQMAAACAFXQ2AAAAAFhBZwMAAACAFQFbZ+Py5csqr169WuUqVaqonNcxguYY+jFjxrgeJyQkqNfupDG5d5LRo0erfPr0aZVfeeUV12PzPtunTp1SuWPHjioHwxyNO0FycrLb/M4776hsjrXt16+fyu7GEps1EhcX53U7c2PNmjUq53WORmFljon2NEfD/Dnu2rVLZb4P7iwjR45U2fwM2LNnj8qFYU5nYZOUlKSy+TscGRmZb225cOGCyuYcsPHjx7vdvrDgygYAAAAAK+hsAAAAALCCzgYAAAAAKwI2Z8M0depUlc17yg8fPlzlhg0buj3e9OnTVV6xYoXKe/fu9bWJCHI3btxQedq0aW4zCp9Ro0a5fX3t2rUqZ10PyPyMaty4sf8aJtnv0T9p0iSVt27d6tfzFVaXLl1SOTExUeWqVauq/Nprr6nM/Ks7W9OmTVU217t5//33VTbn8yHwTp48qXKzZs1UNudJlC1bVuXWrVurfPbsWZU3b96s8vr161VOSUm57bYXL168TasLN65sAAAAALCCzgYAAAAAK+hsAAAAALAixDFvNJ+Dixcvur3fPAq+5OTkPK9VkhU1EfyoCWTl73oQoSaC3Z34GXH48GGVd+zYoXLv3r3zszkFzp1YE3DPm5rgygYAAAAAK+hsAAAAALCCzgYAAAAAKwrMOhsAAACBdM899wS6CUChw5UNAAAAAFbQ2QAAAABgBZ0NAAAAAFbQ2QAAAABgBZ0NAAAAAFbQ2QAAAABgBZ0NAAAAAFbQ2QAAAABgBZ0NAAAAAFbQ2QAAAABghVedDcdxbLcDlvn7PaQmgh81gaxsvH/URHDjMwImagImb95DrzobKSkpeW4MAsvf7yE1EfyoCWRl4/2jJoIbnxEwURMwefMehjhedEnS09Pl+PHjEh0dLSEhIX5pHPKH4ziSkpIiFStWlCJF/DdqjpoIXtQEsrJVDyLURLDiMwImagImX2rCq84GAAAAAPiKCeIAAAAArKCzAQAAAMAKOhsAAAAArKCzAQAAAMAKOhsAAAAArKCzAQAAAMAKOhsAAAAArPh/7ZEnYlSSkpMAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "n_samples_show = 6\n", "\n", "data_iter = iter(train_loader)\n", "fig, axes = plt.subplots(nrows=1, ncols=n_samples_show, figsize=(10, 3))\n", "\n", "while n_samples_show > 0:\n", "    images, targets = next(data_iter)\n", "\n", "    axes[n_samples_show - 1].imshow(images[0, 0].numpy().squeeze(), cmap=\"gray\")\n", "    axes[n_samples_show - 1].set_xticks([])\n", "    axes[n_samples_show - 1].set_yticks([])\n", "    axes[n_samples_show - 1].set_title(\"Labeled: {}\".format(targets[0].item()))\n", "\n", "    n_samples_show -= 1" ] }, { "cell_type": "code", "execution_count": 18, "id": "structural-chuck", "metadata": {}, "outputs": [], "source": [ "# Test Dataset\n", "# -------------\n", "\n", "# Optionally, set a test shuffle seed for reproducibility\n", "# manual_seed(5)\n", "\n", "n_samples = 50\n", "\n", "# Use pre-defined torchvision function to load MNIST test data\n", "X_test = datasets.MNIST(\n", "    root=\"./data\", train=False, download=True, transform=transforms.Compose([transforms.ToTensor()])\n", ")\n", "\n", "# Filter labels (originally 0-9), keeping only labels 0 and 1\n", "idx = np.append(\n", "    np.where(X_test.targets == 0)[0][:n_samples], np.where(X_test.targets == 1)[0][:n_samples]\n", ")\n", "X_test.data = X_test.data[idx]\n", "X_test.targets = X_test.targets[idx]\n", "\n", "# Define torch dataloader with filtered data\n", "test_loader = DataLoader(X_test, batch_size=batch_size, shuffle=True)" ] }, { "cell_type": "markdown", "id": "abroad-morris", "metadata": {}, "source": [ "### Step 2: Defining the QNN and Hybrid Model" ] }, { "cell_type": "markdown", "id": "super-tokyo", "metadata": {}, "source": [ "This second step shows the power of the `TorchConnector`. 
After defining our quantum neural network layer (in this case, an `EstimatorQNN`), we can embed it as a layer in our torch `Module` by wrapping it with `TorchConnector(qnn)`.\n", "\n", "**⚠️ Attention:**\n", "To ensure proper gradient backpropagation through hybrid models, we MUST set the `input_gradients` parameter to `True` when initializing the QNN." ] }, { "cell_type": "code", "execution_count": 19, "id": "urban-purse", "metadata": {}, "outputs": [], "source": [ "# Define and create QNN\n", "def create_qnn():\n", "    feature_map = ZZFeatureMap(2)\n", "    ansatz = RealAmplitudes(2, reps=1)\n", "    qc = QuantumCircuit(2)\n", "    qc.compose(feature_map, inplace=True)\n", "    qc.compose(ansatz, inplace=True)\n", "\n", "    # REMEMBER TO SET input_gradients=True FOR ENABLING HYBRID GRADIENT BACKPROP\n", "    qnn = EstimatorQNN(\n", "        circuit=qc,\n", "        input_params=feature_map.parameters,\n", "        weight_params=ansatz.parameters,\n", "        input_gradients=True,\n", "    )\n", "    return qnn\n", "\n", "\n", "qnn4 = create_qnn()" ] }, { "cell_type": "code", "execution_count": 20, "id": "exclusive-productivity", "metadata": {}, "outputs": [], "source": [ "# Define torch NN module\n", "\n", "\n", "class Net(Module):\n", "    def __init__(self, qnn):\n", "        super().__init__()\n", "        self.conv1 = Conv2d(1, 2, kernel_size=5)\n", "        self.conv2 = Conv2d(2, 16, kernel_size=5)\n", "        self.dropout = Dropout2d()\n", "        self.fc1 = Linear(256, 64)\n", "        self.fc2 = Linear(64, 2)  # 2-dimensional input to QNN\n", "        self.qnn = TorchConnector(qnn)  # Apply torch connector, weights chosen\n", "        # uniformly at random from interval [-1,1].\n", "        self.fc3 = Linear(1, 1)  # 1-dimensional output from QNN\n", "\n", "    def forward(self, x):\n", "        x = F.relu(self.conv1(x))\n", "        x = F.max_pool2d(x, 2)\n", "        x = F.relu(self.conv2(x))\n", "        x = F.max_pool2d(x, 2)\n", "        x = self.dropout(x)\n", "        x = x.view(x.shape[0], -1)\n", "        x = F.relu(self.fc1(x))\n", "        x = self.fc2(x)\n", "        x = self.qnn(x)  # apply 
QNN\n", " x = self.fc3(x)\n", " return cat((x, 1 - x), -1)\n", "\n", "\n", "model4 = Net(qnn4)" ] }, { "cell_type": "markdown", "id": "academic-specific", "metadata": {}, "source": [ "### Step 3: Training" ] }, { "cell_type": "code", "execution_count": 21, "id": "precious-career", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training [10%]\tLoss: -1.1630\n", "Training [20%]\tLoss: -1.5294\n", "Training [30%]\tLoss: -1.7855\n", "Training [40%]\tLoss: -1.9863\n", "Training [50%]\tLoss: -2.2257\n", "Training [60%]\tLoss: -2.4513\n", "Training [70%]\tLoss: -2.6758\n", "Training [80%]\tLoss: -2.8832\n", "Training [90%]\tLoss: -3.1006\n", "Training [100%]\tLoss: -3.3061\n" ] } ], "source": [ "# Define model, optimizer, and loss function\n", "optimizer = optim.Adam(model4.parameters(), lr=0.001)\n", "loss_func = NLLLoss()\n", "\n", "# Start training\n", "epochs = 10 # Set number of epochs\n", "loss_list = [] # Store loss history\n", "model4.train() # Set model to training mode\n", "\n", "for epoch in range(epochs):\n", " total_loss = []\n", " for batch_idx, (data, target) in enumerate(train_loader):\n", " optimizer.zero_grad(set_to_none=True) # Initialize gradient\n", " output = model4(data) # Forward pass\n", " loss = loss_func(output, target) # Calculate loss\n", " loss.backward() # Backward pass\n", " optimizer.step() # Optimize weights\n", " total_loss.append(loss.item()) # Store loss\n", " loss_list.append(sum(total_loss) / len(total_loss))\n", " print(\"Training [{:.0f}%]\\tLoss: {:.4f}\".format(100.0 * (epoch + 1) / epochs, loss_list[-1]))" ] }, { "cell_type": "code", "execution_count": 22, "id": "spoken-stationery", "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot loss convergence\n", "plt.plot(loss_list)\n", "plt.title(\"Hybrid NN Training Convergence\")\n", "plt.xlabel(\"Training Iterations\")\n", "plt.ylabel(\"Neg. Log Likelihood Loss\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "physical-closure", "metadata": {}, "source": [ "Next, we save the trained model to show how a hybrid model can be saved and reused later for inference. Hybrid models that use `TorchConnector` are saved and loaded following the standard PyTorch recommendations, i.e., by saving and loading the model's `state_dict`." ] }, { "cell_type": "code", "execution_count": 23, "id": "regulation-bread", "metadata": {}, "outputs": [], "source": [ "torch.save(model4.state_dict(), \"model4.pt\")" ] }, { "cell_type": "markdown", "id": "pacific-flour", "metadata": {}, "source": [ "### Step 4: Evaluation" ] }, { "cell_type": "markdown", "id": "fabulous-tribe", "metadata": {}, "source": [ "We start by recreating the model and loading its state from the previously saved file. The QNN layer of the recreated model can be built on a different simulator or on real hardware. For example, you can train a model on real hardware available in the cloud and then use a simulator for inference, or vice versa. For the sake of simplicity, we create a new quantum neural network in the same way as above."
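, "\n", "\n", "A state dictionary holds the model's parameter tensors. This includes the `TorchConnector` weights, which are module parameters; that is how the optimizer above trains them. Loading therefore succeeds as long as the new model has the same layers and parameter shapes, which for the hybrid model means a QNN with the same number of weights. Below is a minimal sketch of this round trip. It uses a small classical stand-in for `Net` (the `build_model` helper is hypothetical) and an in-memory `io.BytesIO` buffer instead of the `model4.pt` file:

```python
# Sketch: save a state_dict and load it into a freshly built model with the same
# architecture; a small classical stand-in replaces the hybrid Net, and an
# in-memory buffer replaces the model4.pt file.
import io

import torch
from torch.nn import Linear, ReLU, Sequential


def build_model():
    # Hypothetical helper: every call returns the same architecture
    # with newly (randomly) initialized weights.
    return Sequential(Linear(2, 4), ReLU(), Linear(4, 1))


trained = build_model()  # stands in for the trained model4
fresh = build_model()  # stands in for the recreated model5

buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)  # rewind before reading the saved bytes back
fresh.load_state_dict(torch.load(buffer))

# After loading, every parameter of the fresh model equals the saved one
same = all(torch.equal(a, b) for a, b in zip(trained.parameters(), fresh.parameters()))
print(same)  # True
```

Only parameter values are restored. The QNN inside the recreated model is a newly constructed object, which is why, in principle, it can use a different simulator or backend."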
] }, { "cell_type": "code", "execution_count": 24, "id": "prospective-flooring", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "qnn5 = create_qnn()\n", "model5 = Net(qnn5)\n", "model5.load_state_dict(torch.load(\"model4.pt\"))" ] }, { "cell_type": "code", "execution_count": 25, "id": "spectacular-conservative", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Performance on test data:\n", "\tLoss: -3.3585\n", "\tAccuracy: 100.0%\n" ] } ], "source": [ "model5.eval()  # set model to evaluation mode\n", "with no_grad():\n", "\n", "    correct = 0\n", "    total_loss = []  # reset, so that losses from the last training epoch are not included\n", "    for batch_idx, (data, target) in enumerate(test_loader):\n", "        output = model5(data)\n", "        if len(output.shape) == 1:\n", "            output = output.reshape(1, *output.shape)\n", "\n", "        pred = output.argmax(dim=1, keepdim=True)\n", "        correct += pred.eq(target.view_as(pred)).sum().item()\n", "\n", "        loss = loss_func(output, target)\n", "        total_loss.append(loss.item())\n", "\n", "    print(\n", "        \"Performance on test data:\\n\\tLoss: {:.4f}\\n\\tAccuracy: {:.1f}%\".format(\n", "            sum(total_loss) / len(total_loss), correct / len(test_loader) / batch_size * 100\n", "        )\n", "    )" ] }, { "cell_type": "code", "execution_count": 26, "id": "color-brave", "metadata": { "tags": [ "nbsphinx-thumbnail" ] }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAxsAAACZCAYAAABHTieHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVpUlEQVR4nO3deWxUVRvH8WdqgS4MoFigC1ZWqyAWKkoUBaWCFCS4YzQpKIKiCBqCu7hUCbiBUBsJBiOigigBlEWIVUQUF0piQUUNImQqRSyKZaml5/3Dl+ucU9pOyz13ZtrvJyG5v95ZHpmHiw/3nrk+pZQSAAAAAHBZTLgLAAAAANA4MWwAAAAAsIJhAwAAAIAVDBsAAAAArGDYAAAAAGAFwwYAAAAAKxg2AAAAAFjBsAEAAADACoYNAAAAAFY0umHjzDPPlNGjRzv5448/Fp/PJx9//HHYajKZNcIuegImegImegImegImeqJhXB02XnvtNfH5fM6vuLg46d69u9x9992yd+9eN9/KulWrVsnjjz8e1hoWL14st9xyi3Tr1k18Pp8MHDgwrPU0BD3hLnoiskRCT4iIrFixQvr06SNxcXFyxhlnyLRp06SysjLcZYWMnnAXx4nIQk+4g55wl5c9EWvjRZ988knp1KmTHDlyRDZu3CgFBQWyatUqKS4uloSEBBtvWaNLL71UDh8+LM2bN6/X81atWiX5+flhbYaCggL55ptvpG/fvrJ///6w1eEGesId9IQd0dwTq1evlpEjR8rAgQNlzpw58u2330peXp6UlpZKQUFB2OpqCHrCHRwn7KAnIgM94Q4ve8LKsDF06FA5//zzRURk7Nix0rZtW3nhhRdk+fLlctNNN53wOeXl5ZKYmOh6LTExMRIXF+f663ph4cKFkpqaKjExMdKzZ89wl3NS6Al30BP0hGnKlCnSq1cv+fDDDyU29t9DeqtWreSZZ56RSZMmSUZGRpgrDB094Q6OE/SEiZ6gJ0xe9oQnazYuv/xyERHZuXOniIiMHj1aWrZsKT///LPk5OSI3++Xm2++WUREqqqqZNasWdKjRw+Ji4uT9u3by/jx46WsrEx7TaWU5OXlSVpamiQkJMhll10m27Ztq/beNV1Pt3nzZsnJyZFTTz1VEhMTpVevXjJ79mynvvz8fBER7ZTdcW7XWJOOHTtKTEyjW1YjIvQEPVEdPVH/nti+fbts375dxo0b5wwaIiITJkwQpZQsXbo0pNeJVPQExwkTPUFPmOiJyO8JK2c2TD///LOIiLRt29b5WWVlpQwZMkT69+8vzz33nHPqa/z48fLaa6/JmDFj5J577pGdO3fK3LlzpaioSD777DNp1qyZiIg89thjkpeXJzk5OZKTkyNbtmyRwYMHS0VFRZ31rFu3ToYPHy7JyckyadIk6dChg3z33Xfy/vvvy6RJk2T8+PESCARk3bp1snDhwmrP96LGxo6eoCdM9ET9aywqKhIRcf6V77iUlBRJS0tz9kcreoLjhImeoCdM9EQU9IRy0YIFC5SIqPXr16t9+/ap3bt3q7ffflu1bdtWxcfHqz179iillMrNzVUioh544AHt+Z9++qkSEbVo0SLt52vWrNF+Xlpaqpo3b66GDRumqqqqnMc99NBDSkRUbm6u87PCwkIlIqqwsFAppVRlZaXq1KmTSk9PV2VlZdr7BL/WXXfdpU7022OjxlD06NFDDRgwoF7PiQT0BD1hoifc64lnn31WiYj69ddfq+3r27ev6tevX63PjxT0BMcJEz1BT5joiejtCSvnT7KzsyUpKUk6duwoo0aNkpYtW8qyZcskNTVVe9ydd96p5XfeeUdat24tV1xxhfz+++/Or6ysLGnZsqUUFhaKiMj69euloqJCJk6cqJ16mjx5cp21FRUVyc6dO2Xy5MnSpk0bbV/wa9XEixobI3qCnjDREyffE4cPHxYRkRYtWlTbFxcX5+yPFvQExwkTPUFPmOiJ6OsJK5dR5efnS/f
u3SU2Nlbat28vZ511VrXrwmJjYyUtLU372Y8//ih//vmntGvX7oSvW1paKiIiu3btEhGRbt26afuTkpLk1FNPrbW246fbGroYxosaGyN6gp4w0RMn3xPx8fEiInL06NFq+44cOeLsjxb0BMcJEz1BT5joiejrCSvDxgUXXFDtGmJTixYtqjVHVVWVtGvXThYtWnTC5yQlJblWY0NFQ42RiJ6AiZ44ecnJySIiUlJSIh07dtT2lZSUyAUXXODK+3iFnoCJnoCJnog+niwQD1WXLl1k/fr1cvHFF9f6L3Lp6eki8u8E2LlzZ+fn+/btq7Za/0TvISJSXFws2dnZNT6uptNdXtSI/9ATMNET/8nMzBQRka+//lobLAKBgOzZs0fGjRtX52s0BvQETPQETPRE+ETU96DdcMMNcuzYMXnqqaeq7ausrJQDBw6IyL/X6zVr1kzmzJkjSinnMbNmzarzPfr06SOdOnWSWbNmOa93XPBrHf8+ZvMxXtSI/9ATMNET/+nRo4dkZGTIvHnz5NixY87PCwoKxOfzyXXXXRfS60Q7egImegImeiJ8IurMxoABA2T8+PEyffp02bp1qwwePFiaNWsmP/74o7zzzjsye/Zsue666yQpKUmmTJki06dPl+HDh0tOTo4UFRXJ6tWr5fTTT6/1PWJiYqSgoECuuuoqyczMlDFjxkhycrJ8//33sm3bNlm7dq2IiGRlZYmIyD333CNDhgyRU045RUaNGuVJjcdt2LBBNmzYICL/Tqvl5eWSl5cnIv/etfLSSy9t6G911KAndPQEPWF69tlnZcSIETJ48GAZNWqUFBcXy9y5c2Xs2LFy9tlnn9xvdpSgJ3QcJ+gJEz1BT5g87Qk3v9rq+NeSffXVV7U+Ljc3VyUmJta4f968eSorK0vFx8crv9+vzj33XDV16lQVCAScxxw7dkw98cQTKjk5WcXHx6uBAweq4uJilZ6eXuvXkh23ceNGdcUVVyi/368SExNVr1691Jw5c5z9lZWVauLEiSopKUn5fL5qX1HmZo01mTZtmhKRE/6aNm1anc+PBPQEPWGiJ9ztCaWUWrZsmcrMzFQtWrRQaWlp6pFHHlEVFRUhPTcS0BMcJ0z0BD1hoieityd8SgWdfwEAAAAAl0TUmg0AAAAAjQfDBgAAAAArGDYAAAAAWMGwAQAAAMAKhg0AAAAAVjBsAAAAALAipJv6VVVVSSAQEL/fX+Mt1hGZlFJy8OBBSUlJkZgY92ZLeiJ60RMIZqsfROiJaMUxAiZ6Aqb69ERIw0YgEJCOHTu6UhzCY/fu3ZKWluba69ET0Y+eQDC3+0GEnoh2HCNgoidgCqUnQhpP/X6/KwUhfNz+DOmJ6EdPIJiNz4+eiG4cI2CiJ2AK5TMMadjg1Fb0c/szpCeiHz2BYDY+P3oiunGMgImegCmUz5AF4gAAAACsYNgAAAAAYAXDBgAAAAArGDYAAAAAWMGwAQAAAMAKhg0AAAAAVjBsAAAAALCCYQMAAACAFQwbAAAAAKxg2AAAAABgBcMGAAAAACsYNgAAAABYwbABAAAAwIrYcBcARIqkpCRn+95779X27d27V8uzZ8/2pCZ4KyUlRcsjRoxwth9++GFt35IlS2p9rb///lvL+fn5Wi4tLW1IiYgi3bt31/L333+v5UmTJml5zpw51mvCyRk7dqyW582b52zHxPDvt9CtXbtWy4sWLdLy66+/7mU5YcOfDAAAAABWMGwAAAAAsIJhAwAAAIAVTXbNhlLK2a6qqqr1sRMmTNDyK6+8YqUmhNeQIUOc7QcffFDb9/LLL3tdDizw+/1avuaaa7Q8d+5cLSckJDjbwccMkerX25t8Pp+W77zzTi0XFhZq+amnntJycXFxra+PyNe7d28tm3/X7Nmzx8ty4ALzz715XACC1+58+eWX2r7Nmzd7XU5E4MwGAAAAACsYNgAAAABYwbABAAAAwIomu2Yj+NrZuq65NL8fnzUbjdMbb7zhbC9
YsCCMlcAt5vXVt956q5Z79OhR6/OD+2D+/Pn1eu/JkydruV+/flq+9tprtXzRRRdpedCgQVresWNHvd4f4ZeZmanl8vJyLS9btszDauCGb7/9VsvnnHNOmCpBpApe/2nen+mtt97yupyIwJkNAAAAAFYwbAAAAACwgmEDAAAAgBVNds0GUB/mtdeITFlZWVp+8skntZyYmKhl894Wr776qpZP5j4Io0aN0vL999+v5aefflrLycnJWh42bJiWWbMR+Xr27Knlu+++W8sLFy70shy4YOjQoVq+/vrra3xsmzZttHzgwAELFSHS1XbvtpycHC1v377ddjkRgTMbAAAAAKxg2AAAAABgBcMGAAAAACtYswGEoG/fvuEuASEw71vQqlUrLW/atEnLTzzxhPWajpsxY4aWt23bpuX33ntPy1OmTNGyeX+fQ4cOuVgd3JCRkaFlc43Q4sWLvSwHLigpKdFyRUWFluPi4pztq6++WtvH/ZpgKisrC3cJYcGZDQAAAABWMGwAAAAAsIJhAwAAAIAVTXbNhs/nC3cJAFymlNKy+X3ngUDAy3Jq9f7772t5y5YtWjbXCeXn52t5zJgxdgpDg02dOlXLu3bt0vLXX3/tZTlwwdatW7VsrpWKj493tm+99VZtH2s2mqYLL7ywxn3t2rXzsJLIwZkNAAAAAFYwbAAAAACwgmEDAAAAgBVNds1G8LXd5nXeAKLDoEGDtJyUlKRl8/rq559/3npNDZWXl6fl5cuXa7lr165eloMQnHnmmVo+//zztbxjxw4tl5eX2y4Jli1dulTL48aNc7ZTU1O1faeddpqW//jjD3uFARGMMxsAAAAArGDYAAAAAGAFwwYAAAAAK5rsmg0A0c+8Jrp58+ZaXrlypZa/+OIL6zU1lHnfDUS+AQMG1Lp/3759HlUCr7z33ntaDl6z4ff7tX3B9+AAREQ++uijcJcQFpzZAAAAAGAFwwYAAAAAKxg2AAAAAFjBmg0AUevee+/Vss/n0/KGDRu8LMdVn376qZa7dOmi5fT0dC3v2rXLek3QnXvuubXunzlzpkeVwCu7d++ucd+RI0e0XFFRYbscRJn9+/eHu4Sw4MwGAAAAACsYNgAAAABYwbABAAAAwIoms2ajrmu7a/PJJ5+4XQ6ABujevbuWu3btqmWllJflWPXNN99ouX///lpOTk7WMms2vNGvXz9ne8yYMdq+oqIiLa9bt86TmhAZUlNTtWzeB4j7rqCp4swGAAAAACsYNgAAAABY0WQuozIFX25R16UX5uUMaJwyMjKcbfMyu23btnldDk7AvCzBzI1Jt27dtGx+ZeLvv//uZTn4v+zsbGfb7L81a9Zo2fwqVES/w4cPa/ngwYPOduvWrb0uB4gKnNkAAAAAYAXDBgAAAAArGDYAAAAAWNFk12zUR1ZWVrhLgAcqKytr3PfZZ595WAncMn/+/HCX0GDDhg3T8qZNm7T8008/eVkO/u+8885zts31fkuXLvW6HHjM/IrpN99809m+4447tH3XXHONlqdPn26vMESMNm3aONs//PCDtu/XX3/1uJrIwJkNAAAAAFYwbAAAAACwgmEDAAAAgBVNZs3Gli1bGvxc7rPRNARfA29ei928eXOvy0EIzPuhmIK/Az/SJCQkaHnx4sVaNv/bysvLrdeE6jp06KDlSy65xNk2r8detmyZJzUhcnz11VfOtrlmo2fPnl6XgwgwaNAgZ7ukpETbV1FR4XU5EYEzGwAAAACsYNgAAAAAYAXDBgAAAAArmsyajU8++UTLdV3r3dDHInqlpqY62+Znnpubq+Xbb7/dk5pQO3NtjSklJUXLgUDAZjm18vv9Wl6wYIGWhw4dquWjR49qecaMGXYKQ61Gjx6t5Xbt2jnbq1ev9rgaRJq//vqrxn3Z2dkeVoJw6dy5s5YzMjKc7by8PK/LiUic2QAAAABgBcMGAAAAACsYNgAAAABY0WTWbJiCr/Wu67rvuvajcThw4ICzzWf
eONx4441afvHFFz17b/M+GuYajZEjR9b6/JdeeknLhYWFrtSF+klPT69xX1lZmYeVIBK9++674S4BYRYbq/+vdLNmzcJUSeTizAYAAAAAKxg2AAAAAFjBsAEAAADAiia7ZqM+fvrpp3CXAA+Ul5eHuwTU4bffftNySUmJlpOTk7Xcp08f6zXVZOrUqVqua43GhAkTtPzWW2+5XRIaYPjw4TXuW7lypYeVINKZ92dKSkrS8ogRI7S8YsUK6zXBPvPvod27d4epksjFmQ0AAAAAVjBsAAAAALCCYQMAAACAFazZCMEHH3wQ7hIAiMgvv/yi5fvuu0/Lb7/9tpZvvvlmLR86dEjLDz30kJb379+v5eDvTzfXg3Tt2lXLjz32mJYHDBig5aNHj2p54sSJWp4/f74g/Pr376/lDh06hKkSRBvz/kzcr6lp6N27t5bPOOMMZ3vv3r1elxOROLMBAAAAwAqGDQAAAABWMGwAAAAAsKLJrtkoKipyts3r7QBEh88//1zLxcXFWu7Zs6eWb7vtNi0PGTJEy5s3b9ZyQkKCsz106NB61WauL5kxY4aWWaMRma6++motn3LKKVoO/rtjw4YNntQEIHJdfPHFWt6zZ4+zvWTJEq/LiUic2QAAAABgBcMGAAAAACsYNgAAAABY0WTXbAQCAWc7MzMzfIUgIm3atEnLl1xyiZZvueUWLb/xxhvWa0J1wdfGiohceeWVWi4sLNSyeW+MtLS0WrPP53O26/rOfPN+PHfccYeWS0pKan0+wiN4XY6ISE5OTq2PX7p0qbN97NgxKzUhOh04cEDLrVu3Dk8hCKu//vrL2TZ7oqnizAYAAAAAKxg2AAAAAFjBsAEAAADAiia7ZmPVqlXOtnmN7kcffaTl0tJST2pC5MjNzdXyrFmztGx+9z4ig7kuol+/flo212w8+uijWh42bJiWX3zxxRpfe/HixbW+N9fzR4d//vlHy2VlZVpesWKFlmfPnm29JkQn8z4+5r10Nm7c6GU5CBNzLSE4swEAAADAEoYNAAAAAFYwbAAAAACwwqfq+vJ4+fc7gxvb90XHxv63XOWRRx7R9h06dEjLM2fO9KQmm/78809p1aqVa6/XGHuiqaEnEMztfhChJ6IdxwiY6AmYQukJzmwAAAAAsIJhAwAAAIAVDBsAAAAArGiy99morKx0th9//PHwFQIAAAA0UpzZAAAAAGAFwwYAAAAAKxg2AAAAAFjBsAEAAADACoYNAAAAAFYwbAAAAACwgmEDAAAAgBUMGwAAAACsYNgAAAAAYAXDBgAAAAArQho2lFK264Blbn+G9ET0oycQzMbnR09EN44RMNETMIXyGYY0bBw8ePCki0F4uf0Z0hPRj55AMBufHz0R3ThGwERPwBTKZ+hTIYwkVVVVEggExO/3i8/nc6U4eEMpJQcPHpSUlBSJiXHvqjl6InrREwhmqx9E6IloxTECJnoCpvr0REjDBgAAAADUFwvEAQAAAFjBsAEAAADACoYNAAAAAFYwbAAAAACwgmEDAAAAgBUMGwAAAACsYNgAAAAAYMX/AD6N543lY9uVAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot predicted labels for the first few test images\n", "\n", "n_samples_show = 6\n", "count = 0\n", "fig, axes = plt.subplots(nrows=1, ncols=n_samples_show, figsize=(10, 3))\n", "\n", "model5.eval()  # switch the hybrid model to evaluation mode\n", "with no_grad():  # gradients are not needed for inference\n", "    for batch_idx, (data, target) in enumerate(test_loader):\n", "        if count == n_samples_show:\n", "            break\n", "        # Classify only the first image of each batch\n", "        output = model5(data[0:1])\n", "        # Ensure a 2-D (batch, classes) shape before taking the argmax\n", "        if len(output.shape) == 1:\n", "            output = output.reshape(1, *output.shape)\n", "\n", "        pred = output.argmax(dim=1, keepdim=True)\n", "\n", "        axes[count].imshow(data[0].numpy().squeeze(), cmap=\"gray\")\n", "\n", "        axes[count].set_xticks([])\n", "        axes[count].set_yticks([])\n", "        axes[count].set_title(\"Predicted {}\".format(pred.item()))\n", "\n", "        count += 1" ] }, { "cell_type": "markdown", "id": "prompt-visibility", "metadata": {}, "source": [ "🎉🎉🎉🎉\n", "**You are now able to experiment with your own hybrid datasets and architectures using Qiskit Machine Learning.** \n", "**Good luck!**" ] }, { "cell_type": "code", "execution_count": 27, "id": "related-wheat", "metadata": {}, "outputs": [ { "data": { "text/html": [ "

Version Information

Software | Version
qiskit | 0.45.2
qiskit_algorithms | 0.2.2
qiskit_machine_learning | 0.7.1
System information
Python version | 3.10.13
Python compiler | GCC 9.4.0
Python build | main, Jan 10 2024 19:45:45
OS | Linux
CPUs | 1
Memory (GB) | 7.744113922119141
Wed Jan 24 06:55:41 2024 UTC
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "

This code is a part of Qiskit

© Copyright IBM 2017, 2024.

This code is licensed under the Apache License, Version 2.0. You may
obtain a copy of this license in the LICENSE.txt file in the root directory
of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.

Any modifications or derivative works of this code must retain this
copyright notice, and modified files need to carry a notice indicating
that they have been altered from the originals.

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import tutorial_magics\n", "\n", "%qiskit_version_table\n", "%qiskit_copyright" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }