add notebook, support interactive output to enable visualizing results #1
86
DemoGym.ipynb
Normal file
86
DemoGym.ipynb
Normal file
@ -0,0 +1,86 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "71383e8d-63f1-462c-bd77-688d8d34a60a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Demonstration of `gym`: Visualize Interactive Results in Jupyter Notebook"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eae51654-4ccf-44ed-aaac-f1d993d7e4a1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from pyvirtualdisplay import Display\n",
|
||||
"display = Display(visible=0, size=(1400, 900))\n",
|
||||
"display.start()\n",
|
||||
"\n",
|
||||
"is_ipython = 'inline' in plt.get_backend()\n",
|
||||
"if is_ipython:\n",
|
||||
" from IPython import display\n",
|
||||
"\n",
|
||||
"plt.ion()\n",
|
||||
"\n",
|
||||
"# Load the gym environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "be872e01-e4fd-4940-874e-d46e97fb3519",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import random\n",
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"env = gym.make('LunarLander-v2')\n",
|
||||
"env.seed(23)\n",
|
||||
"\n",
|
||||
"# Let's watch how an untrained agent moves around\n",
|
||||
"\n",
|
||||
"state = env.reset()\n",
|
||||
"img = plt.imshow(env.render(mode='rgb_array'))\n",
|
||||
"for j in range(200):\n",
|
||||
"# action = agent.act(state)\n",
|
||||
" action = random.choice(range(4))\n",
|
||||
" img.set_data(env.render(mode='rgb_array')) \n",
|
||||
" plt.axis('off')\n",
|
||||
" display.display(plt.gcf())\n",
|
||||
" display.clear_output(wait=True)\n",
|
||||
" state, reward, done, _ = env.step(action)\n",
|
||||
" if done:\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
"env.close()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
74
DemoMUD.ipynb
Normal file
74
DemoMUD.ipynb
Normal file
@ -0,0 +1,74 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2e848ca9-c915-4aa2-a7cc-a5654ed06863",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Demonstration of Training and Testing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a9506e99-a947-4f69-8355-a3ce696793fa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from main import train, test\n",
|
||||
"import pickle"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de500d9e-40d1-4b6b-900f-96c2ec69e464",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = pickle.load(open(\"data.pkl\", \"rb\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "39ca7791-c844-4231-9f3b-e8ae80fe8103",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mud_point = train(data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8d8c70ab-d055-418c-b67e-ba5109d989f3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test(mud_point)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -45,3 +45,7 @@ Using the following presumptions, we can establish better values for the "data v
|
||||
|
||||
> The angular momentum of the pole is the most important thing to stabilize.
|
||||
|
||||
# headless mode
|
||||
|
||||
Run `./headless.sh` (requires `sudo`) to install virtual displays so you can view results in a Jupyter notebook.
|
||||
|
||||
|
3
headless.sh
Executable file
3
headless.sh
Executable file
@ -0,0 +1,3 @@
|
||||
#!/bin/sh
|
||||
sudo apt update && sudo apt install build-essential xvfb swig
|
||||
pip install box2d-py pyvirtualdisplay
|
24
main.py
24
main.py
@ -7,6 +7,20 @@ import pandas as pd
|
||||
from scipy.stats import gaussian_kde as gkde
|
||||
from scipy.stats import norm
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
try:
|
||||
from pyvirtualdisplay import Display
|
||||
display = Display(visible=0, size=(1400, 900))
|
||||
display.start()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
is_ipython = 'inline' in plt.get_backend()
|
||||
if is_ipython:
|
||||
from IPython import display
|
||||
|
||||
plt.ion()
|
||||
|
||||
def train(data):
|
||||
D = pd.DataFrame(data)
|
||||
@ -28,11 +42,19 @@ def test(decision=np.array([-0.09, -0.71, -0.43, -0.74]), seed=1992):
|
||||
env = gym.make("CartPole-v1")
|
||||
observation, info = env.reset(seed=seed, return_info=True)
|
||||
score = 0
|
||||
if is_ipython:
|
||||
img = plt.imshow(env.render(mode='rgb_array'))
|
||||
for i in range(10000):
|
||||
action = 1 if decision.T @ observation < 0 else 0
|
||||
observation, reward, done, info = env.step(action)
|
||||
score += reward
|
||||
env.render()
|
||||
if not is_ipython:
|
||||
env.render()
|
||||
else:
|
||||
img.set_data(env.render(mode='rgb_array'))
|
||||
plt.axis('off')
|
||||
display.display(plt.gcf())
|
||||
display.clear_output(wait=True)
|
||||
if done:
|
||||
if score == 500:
|
||||
print("WIN")
|
||||
|
Loading…
Reference in New Issue
Block a user