|
35 | 35 | "source": [ |
36 | 36 | "import torch\n", |
37 | 37 | "from torch import nn\n", |
| 38 | + "from evotorch.decorators import pass_info\n", |
38 | 39 | "\n", |
| 40 | + "\n", |
| 41 | + "# The decorator `@pass_info` used below tells the problem class `GymNE`\n", |
| 42 | + "# to pass information regarding the gym environment via keyword arguments\n", |
| 43 | + "# such as `obs_length` and `act_length`.\n", |
| 44 | + "@pass_info\n", |
39 | 45 | "class LinearPolicy(nn.Module):\n", |
40 | | - " \n", |
41 | 46 | " def __init__(\n", |
42 | 47 | " self, \n", |
43 | 48 | " obs_length: int, # Number of observations from the environment\n", |
44 | 49 | " act_length: int, # Number of actions of the environment\n", |
45 | 50 | " bias: bool = True, # Whether the policy should use biases\n", |
46 | 51 | " **kwargs # Anything else that is passed\n", |
47 | | - " ):\n", |
| 52 | + " ):\n", |
48 | 53 | " super().__init__() # Always call super init for nn Modules\n", |
49 | 54 | " self.linear = nn.Linear(obs_length, act_length, bias = bias)\n", |
50 | 55 | " \n", |
|
71 | 76 | "from evotorch.neuroevolution import GymNE\n", |
72 | 77 | "\n", |
73 | 78 | "problem = GymNE(\n", |
74 | | - " env_name=\"LunarLanderContinuous-v2\", # Name of the environment\n", |
| 79 | + " env=\"LunarLanderContinuous-v2\", # Name of the environment\n", |
75 | 80 | " network=LinearPolicy, # Linear policy that we defined earlier\n", |
76 | 81 | " network_args = {'bias': False}, # Linear policy should not use biases\n", |
77 | 82 | " num_actors= 4, # Use 4 available CPUs. Note that you can modify this value, or use 'max' to exploit all available CPUs\n", |
|
189 | 194 | "outputs": [], |
190 | 195 | "source": [ |
191 | 196 | "problem = GymNE(\n", |
192 | | - " env_name=\"LunarLanderContinuous-v2\",\n", |
| 197 | + " env=\"LunarLanderContinuous-v2\",\n", |
193 | 198 | " network=LinearPolicy,\n", |
194 | 199 | " network_args = {'bias': False},\n", |
195 | 200 | " num_actors= 4, \n", |
|
250 | 255 | "id": "3dcb5243", |
251 | 256 | "metadata": {}, |
252 | 257 | "source": [ |
253 | | - "And once again we can visualize the learned policy. As `CoSyNE` is population based, it does not maintain a 'best estimate' of a good policy. Instead, we simply take the best performing solution from the current population. " |
| 258 | + "And once again we can visualize the learned policy. As `Cosyne` is population-based, it does not maintain a 'best estimate' of a good policy. Instead, we simply take the best-performing solution from the current population. "
254 | 259 | ] |
255 | 260 | }, |
256 | 261 | { |
|
296 | 301 | "name": "python", |
297 | 302 | "nbconvert_exporter": "python", |
298 | 303 | "pygments_lexer": "ipython3", |
299 | | - "version": "3.8.13" |
| 304 | + "version": "3.7.13" |
300 | 305 | } |
301 | 306 | }, |
302 | 307 | "nbformat": 4, |
|
0 commit comments