{ "cells": [ { "cell_type": "markdown", "id": "3906ee5a-5a8a-4c77-a357-378a1d129895", "metadata": {}, "source": [ "## Exercise 8: Policy gradient\n", "\n", "In this exercise, we solve a continuous-action version of cart pole using REINFORCE." ] }, { "cell_type": "markdown", "id": "b31144ea-ef4e-449f-932f-9f0f082ecff0", "metadata": {}, "source": [ "### 8.1 REINFORCE\n", "\n", "In exercise07 we looked at DQN, where we used function approximation to learn on continuous state spaces. To additionally handle continuous action spaces, one typically uses policy gradient methods." ] }, { "cell_type": "markdown", "id": "3253e1a0-994b-451d-a0e7-4c3e3cbbe400", "metadata": {}, "source": [ "#### 8.1 Parameterized Policies\n", "\n", "For this we need a parameterized policy $\\pi_\\theta$ that takes a state $x$ as input and returns a distribution over possible actions $\\pi_\\theta( \\cdot | x)$.\n", "\n", "Similar to the previous exercise, we also use a neural network for this. The following implementation uses a neural network to predict the mean $\\mu$ of a Gaussian, whereas the standard deviation $\\sigma$ is fixed."
] }, { "cell_type": "code", "execution_count": 2, "id": "2151a12f-d37d-49a7-9fa2-af90ee3fbbc8", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.distributions import Normal\n", "\n", "\n", "class Policy(nn.Module):\n", "    def __init__(self,\n", "                 env,\n", "                 sigma=0.1):\n", "        super(Policy, self).__init__()\n", "        self.env = env\n", "        self.set_sigma(sigma)\n", "\n", "        self.features = nn.Linear(np.prod(env.observation_space.shape),\n", "                                  np.prod(env.action_space.shape),\n", "                                  bias=False)\n", "\n", "    def set_sigma(self, sigma):\n", "        self.sigma = torch.full(self.env.action_space.shape, sigma)\n", "\n", "    def forward(self, x, action=None):\n", "        x = torch.as_tensor(x).float().flatten()\n", "        x = self.features(x).reshape(self.sigma.shape)\n", "\n", "        distr = Normal(x, self.sigma)  # The output of the linear layer is used as the mean.\n", "\n", "        if action is not None:\n", "            return distr.log_prob(action)  # The log of the probability density function.\n", "\n", "        action = distr.sample()\n", "\n", "        # We clip the action to ensure that it stays within the action bounds.\n", "        return action.clip(self.env.action_space.low.item(),\n", "                           self.env.action_space.high.item())" ] }, { "cell_type": "markdown", "id": "10dcacd8-7acd-457b-a80a-852a021d06a4", "metadata": {}, "source": [ "**Sampling an action**:\n", "\n", "Using the policy, we can now sample an action:" ] }, { "cell_type": "code", "execution_count": 3, "id": "66007620-4d15-4325-b5e1-bcaf63229229", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jasperoni/.environments/mpcNrl/lib/python3.8/site-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", " warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n" ] }, { "data": { "text/plain": [ "Text(0.5, 1.0, 'Action Distribution')" ] }, "execution_count": 3, "metadata":
{}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVb0lEQVR4nO3df7RlZX3f8fdHUGzEFpARYSAOmNFIVutoJ4SoiT9wCYIVXCU4GmW02NEU29iaNqMul/QHzdjVaOMy0ZJIwERBgrqcBowiwrK6BB1ciPwQGXQQxoEZQYiGSAW//eM8N24u9879de69Mw/v11pnnX2e/et7nnPv5+7znH32TVUhSerLY5a7AEnS+BnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMty1x0nyoSTvWqZ9/0aSm8e4vc8kWd+mX5/kS2Pc9m8n+dy4tqe+xPPcNU5JrgSeBTylqh6YxfKvB95YVc9f5NJIchbwTuAnrWkH8Dng7KraMY9t/VJVvXYO67yeeT7XJKuA7wKPraoH57q+Hn08ctfYtAD6DaCAVyxvNdP6eFU9ETgIeCXwFOCaJIeOcycZ8fdLy8YfPo3T6cBVwHnA+uGMJEck+WSSXUnuTvKBJM8EPgT8epIfJ7m3LXtekv82WPdfJ9ma5J4km5McNphXSd6c5JYk9yb54ySZqdCq+mlV3QC8CtgFvK1t74VJ7hhs//eTbE/yoyQ3JzkuyQnAO4BXtbq/0Za9MsnZSb4M3A8c1dre+PCuyAeS3JfkW0mOG8zYluQlg8dnJfnL9vCL7f7ets9fnzzMk+S5Sb7Wtv21JM8dzLsyyX9N8uX2XD6X5OCZ+kl7L8Nd43Q68NF2Oz7JIQBJ9gH+GrgNWAWsBC6sqpuANwNfqar9q+qAyRtM8mLgD4DTgEPbNi6ctNjLgV8F/llb7vjZFlxVDwGfZvSOY/K+nwG8BfjVdrR/PLCtqv4G+O+M3gXsX1XPGqz2OmAD8MRW62S/BtwKHAy8G/hkkoNmUepvtvsD2j6/MqnWg4BLgPcDTwLeC1yS5EmDxV4DvAF4MvA44PdmsV/tpQx3jUWS5wNPBS6qqmsYBdhr2uxjgMOA/1hVf1dVP6mq2X6w+NvAuVX19TaG/3ZGR/qrBstsqqp7q+p7wBXAmjmW/31GwzSTPQTsBxyd5LFVta2qbp1hW+dV1Q1V9WBV/XSK+TuB/9XeOXwcuBk4aY71TuUk4Jaq+ou27wuAbwH/YrDMn1fVt6vq74GLmHs/aS9iuGtc1gOfq6oftMcf4+dDM0cAt83zg8DDGBwBV9WPgbsZHf1PuHMwfT+w/xz3sRK4Z3JjVW0F3gqcBexMcuFwSGgat88wf3s9/CyG2xg9x4V6WD8Ntj3OftJexHDXgiX5R4yGQ16Q5M4kdwL/HnhWkmcxCrxfTLLvFKvPdLrW9xm9I5jY1xMYDTtsH1Ptj2F0dPt/p5pfVR9rZ7c8tdX6nolZ02xypuezctJnAr/I6DkC/B3wC4N5T5nDdh/WT4Ntj6WftPcx3DUOpzAawjia0Vv9NcAzGQXm6cBXGZ12uCnJE5I8Psnz2rp3AYcnedw0274AeEOSNUn2YzTWfXVVbVtIwUn2bR/oXsAoRN87xTLPSPLitt+fAH8P/GxQ96p5nBHzZODfJXlskt9i1E+XtnnXAuvavLXAqYP1drV9HzXNdi8Fnp7kNe25vYrR6/HXc6xPnTDcNQ7rGY3nfq+q7py4AR9gNGYeRkfHvwR8D7iD0VkqAF8AbgDuTPKDyRuuqs8D7wI+wegPxNOAdQuo9VVJfgzcB2xmNMTzz6vq+1Msux+wCfgBoyGNJzMa8wf4q3Z/d5Kvz2H/VwOr2zbPBk6tqrvbvHcxen4/BP4zo6EtAKrq/rb8l9tZQccON9q28XJGZ/3cDfwn4OWDYTI9yvglJkn
qkEfuktQhw12SOmS4S1KHDHdJ6tBU5x0vuYMPPrhWrVq13GVI0l7lmmuu+UFVrZhq3h4R7qtWrWLLli3LXYYk7VWSTHX9IsBhGUnqkuEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6tAe8Q1VaSarNl6ybPvetmkc/79aWloeuUtShwx3SeqQ4S5JHZox3JMckeSKJDcmuSHJ77b2s5JsT3Jtu504WOftSbYmuTnJ8Yv5BCRJjzSbD1QfBN5WVV9P8kTgmiSXtXnvq6r/OVw4ydGM/jv9rwCHAZ9P8vSqemichUuSpjfjkXtV7aiqr7fpHwE3ASt3s8rJwIVV9UBVfRfYChwzjmIlSbMzpzH3JKuAZwNXt6a3JLkuyblJDmxtK4HbB6vdwRR/DJJsSLIlyZZdu3bNvXJJ0rRmHe5J9gc+Aby1qv4W+CDwNGANsAP4w7nsuKrOqaq1VbV2xYop/0uUJGmeZhXuSR7LKNg/WlWfBKiqu6rqoar6GfCn/HzoZTtwxGD1w1ubJGmJzOZsmQAfBm6qqvcO2g8dLPZK4Po2vRlYl2S/JEcCq4Gvjq9kSdJMZnO2zPOA1wHfTHJta3sH8Ooka4ACtgFvAqiqG5JcBNzI6EybMz1TRpKW1ozhXlVfAjLFrEt3s87ZwNkLqEuStAB+Q1WSOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDu273AVIe7pVGy9Zlv1u23TSsuxXffDIXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDs0Y7kmOSHJFkhuT3JDkd1v7QUkuS3JLuz+wtSfJ+5NsTXJdkucs9pOQJD3cbI7cHwTeVlVHA8cCZyY5GtgIXF5Vq4HL22OAlwGr220D8MGxVy1J2q0Zw72qdlTV19v0j4CbgJXAycD5bbHzgVPa9MnAR2rkKuCAJIeOu3BJ0vTmNOaeZBXwbOBq4JCq2tFm3Qkc0qZXArcPVrujtU3e1oYkW5Js2bVr11zrliTtxqzDPcn+wCeAt1bV3w7nVVUBNZcdV9U5VbW2qtauWLFiLqtKkmYwq3BP8lhGwf7Rqvpka75rYril3e9s7duBIwarH97aJElLZDZnywT4MHBTVb13MGszsL5Nrwc+PWg/vZ01cyxw32D4RpK0BGZzyd/nAa8Dvpnk2tb2DmATcFGSM4DbgNPavEuBE4GtwP3AG8ZZsCRpZjOGe1V9Ccg0s4+bYvkCzlxgXZKkBfAbqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6NJtry0j/YNXGS5a7BEmz4JG7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIa8tI+2hlus6Pts2nbQs+9V4eeQuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6tCM4Z7k3CQ7k1w/aDsryfYk17bbiYN5b0+yNcnNSY5frMIlSdObzZH7ecAJU7S/r6rWtNulAEmOBtYBv9LW+ZMk+4yrWEnS7MwY7lX1ReCeWW7vZODCqnqgqr4LbAWOWUB9kqR5WMiY+1uSXNeGbQ5sbSuB2wfL3NHaHiHJhiRbkmzZtWvXAsqQJE0233D/IPA0YA2wA/jDuW6gqs6pqrVVtXbFihXzLEOSNJV5hXtV3VVVD1XVz4A/5edDL9uBIwaLHt7aJElLaF7hnuTQwcNXAhNn0mwG1iXZL8mRwGrgqwsrUZI0VzNezz3JBcALgYOT3AG8G3hhkjVAAduANwFU1Q1JLgJuBB4EzqyqhxalcknStGYM96p69RTNH97N8mcDZy+kKEnSwvgNVUnqkOEuSR0y3CWpQ4a7JHX
IcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOjRjuCc5N8nOJNcP2g5KclmSW9r9ga09Sd6fZGuS65I8ZzGLlyRNbTZH7ucBJ0xq2whcXlWrgcvbY4CXAavbbQPwwfGUKUmaixnDvaq+CNwzqflk4Pw2fT5wyqD9IzVyFXBAkkPHVKskaZbmO+Z+SFXtaNN3Aoe06ZXA7YPl7mhtj5BkQ5ItSbbs2rVrnmVIkqay4A9Uq6qAmsd651TV2qpau2LFioWWIUkamG+43zUx3NLud7b27cARg+UOb22SpCU033DfDKxv0+uBTw/aT29nzRwL3DcYvpEkLZF9Z1ogyQXAC4GDk9wBvBvYBFyU5AzgNuC0tvilwInAVuB+4A2LULMkaQYzhntVvXqaWcdNsWwBZy60KEnSwvgNVUnqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtSh2b8Zx2SHl1Wbbxk2fa9bdNJy7bv3njkLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjrkqZB7oeU8VU3S3sEjd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOrSg89yTbAN+BDwEPFhVa5McBHwcWAVsA06rqh8urExJ0lyM48j9RVW1pqrWtscbgcurajVweXssSVpCizEsczJwfps+HzhlEfYhSdqNhYZ7AZ9Lck2SDa3tkKra0abvBA6ZasUkG5JsSbJl165dCyxDkjS00GvLPL+qtid5MnBZkm8NZ1ZVJampVqyqc4BzANauXTvlMpKk+VnQkXtVbW/3O4FPAccAdyU5FKDd71xokZKkuZl3uCd5QpInTkwDLwWuBzYD69ti64FPL7RISdLcLGRY5hDgU0kmtvOxqvqbJF8DLkpyBnAbcNrCy5QkzcW8w72qvgM8a4r2u4HjFlKUJGlh/IaqJHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOLeR/qD7qrdp4yXKXIElT8shdkjrkkbukPcZyvRvetumkZdnvYvLIXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktShvf5LTF4CQJIeySN3SerQoh25JzkB+CNgH+DPqmrTYu1LkhZiOUcAFuvSB4ty5J5kH+CPgZcBRwOvTnL0YuxLkvRIizUscwywtaq+U1X/D7gQOHmR9iVJmmSxhmVWArcPHt8B/NpwgSQbgA3t4Y+T3DzPfR0M/GCe6y6mPbUu2HNrs665sa652SPrynsWVNdTp5uxbGfLVNU5wDkL3U6SLVW1dgwljdWeWhfsubVZ19xY19w82uparGGZ7cARg8eHtzZJ0hJYrHD/GrA6yZFJHgesAzYv0r4kSZMsyrBMVT2Y5C3AZxmdCnluVd2wGPtiDEM7i2RPrQv23Nqsa26sa24eVXWlqhZju5KkZeQ3VCWpQ4a7JHVorwj3JL+V5IYkP0sy7SlDSU5IcnOSrUk2DtqPTHJ1a/94+5B3HHUdlOSyJLe0+wOnWOZFSa4d3H6S5JQ277wk3x3MW7NUdbXlHhrse/OgfTn7a02Sr7TX+7okrxrMG2t/TffzMpi/X3v+W1t/rBrMe3trvznJ8QupYx51/YckN7b+uTzJUwfzpnxNl6iu1yfZNdj/Gwfz1rfX/ZYk65e4rvcNavp2knsH8xazv85NsjPJ9dPMT5L3t7qvS/KcwbyF91dV7fE34JnAM4ArgbXTLLMPcCtwFPA44BvA0W3eRcC6Nv0
h4HfGVNf/ADa26Y3Ae2ZY/iDgHuAX2uPzgFMXob9mVRfw42nal62/gKcDq9v0YcAO4IBx99fufl4Gy/wb4ENteh3w8TZ9dFt+P+DItp19lrCuFw1+hn5noq7dvaZLVNfrgQ9Mse5BwHfa/YFt+sClqmvS8v+W0Qkei9pfbdu/CTwHuH6a+ScCnwECHAtcPc7+2iuO3Kvqpqqa6RusU17yIEmAFwMXt+XOB04ZU2knt+3NdrunAp+pqvvHtP/pzLWuf7Dc/VVV366qW9r094GdwIox7X9oNpfIGNZ7MXBc65+TgQur6oGq+i6wtW1vSeqqqisGP0NXMfoeyWJbyCVFjgcuq6p7quqHwGXACctU16uBC8a0792qqi8yOpibzsnAR2rkKuCAJIcypv7aK8J9lqa65MFK4EnAvVX14KT2cTikqna06TuBQ2ZYfh2P/ME6u70le1+S/Za4rscn2ZLkqomhIvag/kpyDKOjsVsHzePqr+l+XqZcpvXHfYz6ZzbrLmZdQ2cwOvqbMNVrupR1/cv2+lycZOKLjHtEf7XhqyOBLwyaF6u/ZmO62sfSX3vMP+tI8nngKVPMemdVfXqp65mwu7qGD6qqkkx7Xmn7i/xPGZ37P+HtjELucYzOdf194L8sYV1PrartSY4CvpDkm4wCbN7G3F9/Aayvqp+15nn3V4+SvBZYC7xg0PyI17Sqbp16C2P3f4ALquqBJG9i9K7nxUu079lYB1xcVQ8N2pazvxbVHhPuVfWSBW5iukse3M3o7c6+7ehrTpdC2F1dSe5KcmhV7WhhtHM3mzoN+FRV/XSw7Ymj2AeS/Dnwe0tZV1Vtb/ffSXIl8GzgEyxzfyX5x8AljP6wXzXY9rz7awqzuUTGxDJ3JNkX+CeMfp4W8/Ias9p2kpcw+oP5gqp6YKJ9mtd0HGE1Y11Vdffg4Z8x+oxlYt0XTlr3yjHUNKu6BtYBZw4bFrG/ZmO62sfSXz0Ny0x5yYMafUJxBaPxboD1wLjeCWxu25vNdh8x1tcCbmKc+xRgyk/VF6OuJAdODGskORh4HnDjcvdXe+0+xWgs8uJJ88bZX7O5RMaw3lOBL7T+2Qysy+hsmiOB1cBXF1DLnOpK8mzgfwOvqKqdg/YpX9MlrOvQwcNXADe16c8CL231HQi8lIe/g13Uulptv8zow8mvDNoWs79mYzNwejtr5ljgvnYAM57+WqxPisd5A17JaNzpAeAu4LOt/TDg0sFyJwLfZvSX952D9qMY/fJtBf4K2G9MdT0JuBy4Bfg8cFBrX8vov09NLLeK0V/jx0xa/wvANxmF1F8C+y9VXcBz276/0e7P2BP6C3gt8FPg2sFtzWL011Q/L4yGeV7Rph/fnv/W1h9HDdZ9Z1vvZuBlY/55n6muz7ffg4n+2TzTa7pEdf0BcEPb/xXALw/W/VetH7cCb1jKutrjs4BNk9Zb7P66gNHZXj9llF9nAG8G3tzmh9E/Nbq17X/tYN0F95eXH5CkDvU0LCNJagx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1KH/DzBiojotxYhhAAAAAElFTkSuQmCC\n", "text/plain": [ "