{ "cells": [ { "cell_type": "markdown", "id": "3906ee5a-5a8a-4c77-a357-378a1d129895", "metadata": {}, "source": [ "## Exercise 8: Policy gradient\n", "\n", "In this exercise, we solve a continuous-action version of cart pole using REINFORCE." ] }, { "cell_type": "markdown", "id": "b31144ea-ef4e-449f-932f-9f0f082ecff0", "metadata": {}, "source": [ "### 8.1 REINFORCE\n", "\n", "In exercise07 we looked at DQN, where we used function approximation to learn on continuous state spaces. DQN still selects actions by maximizing the Q-value over a finite set of actions, so it does not directly carry over to continuous action spaces. To handle those as well, one typically uses policy gradient methods." ] }, { "cell_type": "markdown", "id": "3253e1a0-994b-451d-a0e7-4c3e3cbbe400", "metadata": {}, "source": [ "#### 8.1.1 Parameterized Policies\n", "\n", "For this we need a parameterized policy $\\pi_\\theta$ that takes a state $x$ as input and returns a distribution over actions $\\pi_\\theta(\\cdot \\mid x)$.\n", "\n", "As in the previous exercise, we use a neural network for this. The following implementation uses the network (here a single linear layer without bias) to predict the mean $\\mu$ of a Gaussian, whereas the standard deviation $\\sigma$ is fixed."
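, "\n", "\n", "To make this concrete, consider a single scalar action (we assume a one-dimensional action space, as in the continuous cart pole, and write $\\mu_\\theta(x) = \\theta^\\top x$ for the bias-free linear layer `self.features`). The log-density that `forward` returns via `distr.log_prob(action)` is\n", "\n", "$$\n", "\\log \\pi_\\theta(a \\mid x) = -\\frac{\\left(a - \\theta^\\top x\\right)^2}{2\\sigma^2} - \\log\\left(\\sigma\\sqrt{2\\pi}\\right),\n", "$$\n", "\n", "and, since $\\sigma$ is fixed, its gradient with respect to the parameters is\n", "\n", "$$\n", "\\nabla_\\theta \\log \\pi_\\theta(a \\mid x) = \\frac{a - \\theta^\\top x}{\\sigma^2}\\, x.\n", "$$\n", "\n", "The standard REINFORCE estimator weights this score by the return $G_t$ collected after each action, $\\nabla_\\theta J(\\theta) \\approx \\sum_t G_t \\nabla_\\theta \\log \\pi_\\theta(a_t \\mid x_t)$. Gradient ascent on this estimate therefore pulls the mean $\\theta^\\top x_t$ towards actions followed by a positive return and pushes it away from actions followed by a negative one."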
] }, { "cell_type": "code", "execution_count": 2, "id": "2151a12f-d37d-49a7-9fa2-af90ee3fbbc8", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.distributions import Normal\n", "\n", "\n", "class Policy(nn.Module):\n", "    def __init__(self,\n", "                 env,\n", "                 sigma=0.1):\n", "        super(Policy, self).__init__()\n", "        self.env = env\n", "        self.set_sigma(sigma)\n", "\n", "        # Linear map from the flattened observation to the mean of the action distribution.\n", "        self.features = nn.Linear(np.prod(env.observation_space.shape),\n", "                                  np.prod(env.action_space.shape),\n", "                                  bias=False)\n", "\n", "    def set_sigma(self, sigma):\n", "        # Fixed (not learned) standard deviation, one entry per action dimension.\n", "        self.sigma = torch.full(self.env.action_space.shape, sigma)\n", "\n", "    def forward(self, x, action=None):\n", "        x = torch.as_tensor(x).float().flatten()\n", "        x = self.features(x).reshape(self.sigma.shape)\n", "\n", "        distr = Normal(x, self.sigma)  # The output of the features is used as the mean.\n", "\n", "        if action is not None:\n", "            return distr.log_prob(action)  # The log of the probability density function.\n", "\n", "        action = distr.sample()\n", "\n", "        # We clip the action to ensure that it is within bounds\n", "        # (.item() assumes a one-dimensional action space).\n", "        return action.clip(self.env.action_space.low.item(),\n", "                           self.env.action_space.high.item())" ] }, { "cell_type": "markdown", "id": "10dcacd8-7acd-457b-a80a-852a021d06a4", "metadata": {}, "source": [ "**Sampling an action**:\n", "\n", "Using the policy, we can now sample an action:" ] }, { "cell_type": "code", "execution_count": 3, "id": "66007620-4d15-4325-b5e1-bcaf63229229", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jasperoni/.environments/mpcNrl/lib/python3.8/site-packages/gym/logger.py:30: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", " warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))\n" ] }, { "data": { "text/plain": [ "Text(0.5, 1.0, 'Action Distribution')" ] }, "execution_count": 3, "metadata": 
{}, "output_type": "execute_result" }, { "data": { "image/png": 