mm
3 years ago
1 changed files with 69 additions and 0 deletions
@ -0,0 +1,69 @@ |
|||
import sys |
|||
import pickle |
|||
import gym |
|||
import numpy as np |
|||
from matplotlib import pyplot as plt |
|||
from matplotlib.animation import FuncAnimation |
|||
from time import sleep |
|||
|
|||
# numpy precision for printing |
|||
np.set_printoptions(precision=3, suppress=True) |
|||
|
|||
plt.ion() # interactive plotting |
|||
fig, ax = plt.subplots() |
|||
colors = ["xkcd:orange", "xkcd:forest green", "xkcd:gray", "xkcd:light blue"] |
|||
plots = [None] * 4 |
|||
|
|||
env = gym.make("CartPole-v1") |
|||
observation, info = env.reset(seed=42, return_info=True) |
|||
|
|||
max_steps = 100 |
|||
num_samples = 500 |
|||
samples = np.random.randn(num_samples, 4) |
|||
|
|||
data = [] |
|||
for lam in samples: |
|||
breakpoints = [] |
|||
score = 0 |
|||
O = [] |
|||
for n in range(max_steps): |
|||
ax.cla() |
|||
# action = env.action_space.sample() |
|||
action = 1 if lam.T @ observation < 0 else 0 |
|||
# action = 1 if observation[0] - observation[3] < 0 else 0 |
|||
observation, reward, done, info = env.step(action) |
|||
score += reward |
|||
O.append(observation.tolist()) |
|||
o = np.array(O) |
|||
var = np.var(o[-int(score) :, :], axis=0) |
|||
for q in range(4): |
|||
lines = np.hstack([o[:, q], np.zeros(max_steps - n)]) |
|||
ax.plot(range(max_steps + 1), lines, c=colors[q]) |
|||
|
|||
ax.set_title(f"Reward: {int(score)}, Variance: {var}") |
|||
ax.set_ylim([-3, 3]) |
|||
|
|||
if done or n == max_steps: |
|||
breakpoints.append(n) |
|||
observation, info = env.reset(return_info=True) |
|||
# print(score, observation) |
|||
score = 0 # reset score |
|||
|
|||
# draw break-point lines when game is lost |
|||
for b in breakpoints: |
|||
ax.vlines( |
|||
b, np.min(o, axis=0).min(), np.max(o, axis=0).max(), color="black", lw=2 |
|||
) |
|||
|
|||
fig.canvas.draw() |
|||
fig.show() |
|||
fig.canvas.flush_events() |
|||
env.render() |
|||
# sleep(0.01) |
|||
|
|||
data.append({"lam": lam, "obs": O, "break": breakpoints}) |
|||
pickle.dump(data, open("data.pkl", "wb")) # dump data frequently |
|||
|
|||
stop = input("Press any key to close.") |
|||
plt.close() |
|||
env.close() |
Loading…
Reference in new issue