From 9b2b0f6a90bcb834abc756767a12eab83ae046a6 Mon Sep 17 00:00:00 2001 From: mm Date: Mon, 21 Mar 2022 01:28:21 +0000 Subject: [PATCH] generate sample data --- sample.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 sample.py diff --git a/sample.py b/sample.py new file mode 100644 index 0000000..7200d7a --- /dev/null +++ b/sample.py @@ -0,0 +1,69 @@ +import sys +import pickle +import gym +import numpy as np +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from time import sleep + +# numpy precision for printing +np.set_printoptions(precision=3, suppress=True) + +plt.ion() # interactive plotting +fig, ax = plt.subplots() +colors = ["xkcd:orange", "xkcd:forest green", "xkcd:gray", "xkcd:light blue"] +plots = [None] * 4 + +env = gym.make("CartPole-v1") +observation, info = env.reset(seed=42, return_info=True) + +max_steps = 100 +num_samples = 500 +samples = np.random.randn(num_samples, 4) + +data = [] +for lam in samples: + breakpoints = [] + score = 0 + O = [] + for n in range(max_steps): + ax.cla() + # action = env.action_space.sample() + action = 1 if lam.T @ observation < 0 else 0 + # action = 1 if observation[0] - observation[3] < 0 else 0 + observation, reward, done, info = env.step(action) + score += reward + O.append(observation.tolist()) + o = np.array(O) + var = np.var(o[-int(score) :, :], axis=0) + for q in range(4): + lines = np.hstack([o[:, q], np.zeros(max_steps - n)]) + ax.plot(range(max_steps + 1), lines, c=colors[q]) + + ax.set_title(f"Reward: {int(score)}, Variance: {var}") + ax.set_ylim([-3, 3]) + + if done or n == max_steps: + breakpoints.append(n) + observation, info = env.reset(return_info=True) + # print(score, observation) + score = 0 # reset score + + # draw break-point lines when game is lost + for b in breakpoints: + ax.vlines( + b, np.min(o, axis=0).min(), np.max(o, axis=0).max(), color="black", lw=2 + ) + + fig.canvas.draw() + fig.show() + fig.canvas.flush_events() + env.render() + # sleep(0.01) + + data.append({"lam": lam, "obs": O, "break": breakpoints}) + pickle.dump(data, open("data.pkl", "wb")) # dump data frequently + +stop = input("Press any key to close.") +plt.close() +env.close()