add notebook, support interactive output to enable visualizing results (#1)
needed to install some virtual display software. `gcc` may be required to add to `./headless.sh` Co-authored-by: Michael Pilosov <consistentbayes@gmail.com> Reviewed-on: #1 Co-authored-by: mm <mm@clfx.cc> Co-committed-by: mm <mm@clfx.cc>
This commit is contained in:
parent
f85551a16e
commit
cdee785166
86
DemoGym.ipynb
Normal file
86
DemoGym.ipynb
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "71383e8d-63f1-462c-bd77-688d8d34a60a",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Demonstration of `gym`: Visualize Interactive Results in Jupyter Notebook"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "eae51654-4ccf-44ed-aaac-f1d993d7e4a1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"from pyvirtualdisplay import Display\n",
|
||||||
|
"display = Display(visible=0, size=(1400, 900))\n",
|
||||||
|
"display.start()\n",
|
||||||
|
"\n",
|
||||||
|
"is_ipython = 'inline' in plt.get_backend()\n",
|
||||||
|
"if is_ipython:\n",
|
||||||
|
" from IPython import display\n",
|
||||||
|
"\n",
|
||||||
|
"plt.ion()\n",
|
||||||
|
"\n",
|
||||||
|
"# Load the gym environment"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "be872e01-e4fd-4940-874e-d46e97fb3519",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import gym\n",
|
||||||
|
"import random\n",
|
||||||
|
"%matplotlib inline\n",
|
||||||
|
"\n",
|
||||||
|
"env = gym.make('LunarLander-v2')\n",
|
||||||
|
"env.seed(23)\n",
|
||||||
|
"\n",
|
||||||
|
"# Let's watch how an untrained agent moves around\n",
|
||||||
|
"\n",
|
||||||
|
"state = env.reset()\n",
|
||||||
|
"img = plt.imshow(env.render(mode='rgb_array'))\n",
|
||||||
|
"for j in range(200):\n",
|
||||||
|
"# action = agent.act(state)\n",
|
||||||
|
" action = random.choice(range(4))\n",
|
||||||
|
" img.set_data(env.render(mode='rgb_array')) \n",
|
||||||
|
" plt.axis('off')\n",
|
||||||
|
" display.display(plt.gcf())\n",
|
||||||
|
" display.clear_output(wait=True)\n",
|
||||||
|
" state, reward, done, _ = env.step(action)\n",
|
||||||
|
" if done:\n",
|
||||||
|
" break\n",
|
||||||
|
"\n",
|
||||||
|
"env.close()"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
74
DemoMUD.ipynb
Normal file
74
DemoMUD.ipynb
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "2e848ca9-c915-4aa2-a7cc-a5654ed06863",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Demonstration of Training and Testing"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "a9506e99-a947-4f69-8355-a3ce696793fa",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from main import train, test\n",
|
||||||
|
"import pickle"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "de500d9e-40d1-4b6b-900f-96c2ec69e464",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"data = pickle.load(open(\"data.pkl\", \"rb\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "39ca7791-c844-4231-9f3b-e8ae80fe8103",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"mud_point = train(data)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "8d8c70ab-d055-418c-b67e-ba5109d989f3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"test(mud_point)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -45,3 +45,7 @@ Using the following presumptions, we can establish better values for the "data v
|
|||||||
|
|
||||||
> The angular momentum of the pole is the most important thing to stabilize.
|
> The angular momentum of the pole is the most important thing to stabilize.
|
||||||
|
|
||||||
|
# headless mode
|
||||||
|
|
||||||
|
Run `./headless.sh` (requires `sudo`) to install virtual displays so you can view results in a Jupyter notebook.
|
||||||
|
|
||||||
|
3
headless.sh
Executable file
3
headless.sh
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
sudo apt update && sudo apt install build-essential xvfb swig
|
||||||
|
pip install box2d-py pyvirtualdisplay
|
24
main.py
24
main.py
@ -7,6 +7,20 @@ import pandas as pd
|
|||||||
from scipy.stats import gaussian_kde as gkde
|
from scipy.stats import gaussian_kde as gkde
|
||||||
from scipy.stats import norm
|
from scipy.stats import norm
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyvirtualdisplay import Display
|
||||||
|
display = Display(visible=0, size=(1400, 900))
|
||||||
|
display.start()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
is_ipython = 'inline' in plt.get_backend()
|
||||||
|
if is_ipython:
|
||||||
|
from IPython import display
|
||||||
|
|
||||||
|
plt.ion()
|
||||||
|
|
||||||
def train(data):
|
def train(data):
|
||||||
D = pd.DataFrame(data)
|
D = pd.DataFrame(data)
|
||||||
@ -28,11 +42,19 @@ def test(decision=np.array([-0.09, -0.71, -0.43, -0.74]), seed=1992):
|
|||||||
env = gym.make("CartPole-v1")
|
env = gym.make("CartPole-v1")
|
||||||
observation, info = env.reset(seed=seed, return_info=True)
|
observation, info = env.reset(seed=seed, return_info=True)
|
||||||
score = 0
|
score = 0
|
||||||
|
if is_ipython:
|
||||||
|
img = plt.imshow(env.render(mode='rgb_array'))
|
||||||
for i in range(10000):
|
for i in range(10000):
|
||||||
action = 1 if decision.T @ observation < 0 else 0
|
action = 1 if decision.T @ observation < 0 else 0
|
||||||
observation, reward, done, info = env.step(action)
|
observation, reward, done, info = env.step(action)
|
||||||
score += reward
|
score += reward
|
||||||
env.render()
|
if not is_ipython:
|
||||||
|
env.render()
|
||||||
|
else:
|
||||||
|
img.set_data(env.render(mode='rgb_array'))
|
||||||
|
plt.axis('off')
|
||||||
|
display.display(plt.gcf())
|
||||||
|
display.clear_output(wait=True)
|
||||||
if done:
|
if done:
|
||||||
if score == 500:
|
if score == 500:
|
||||||
print("WIN")
|
print("WIN")
|
||||||
|
Loading…
Reference in New Issue
Block a user