"""Roll out random linear policies on CartPole-v1 with a live observation plot.

For each of ``num_samples`` random weight vectors ``lam`` the script runs
``max_steps`` environment steps using the policy
``action = 1 if lam @ observation < 0 else 0``, redraws the observation
traces after every step, and appends the rollout to ``data``.  The
accumulated ``data`` list is re-pickled to ``data.pkl`` after every sample.
"""

import pickle

import gym
import numpy as np
from matplotlib import pyplot as plt

# numpy precision for printing
np.set_printoptions(precision=3, suppress=True)

plt.ion()  # interactive plotting
fig, ax = plt.subplots()
# One colour per observation component (the loop below plots four traces).
colors = ["xkcd:orange", "xkcd:forest green", "xkcd:gray", "xkcd:light blue"]

env = gym.make("CartPole-v1")
# NOTE(review): `return_info=True` and the 4-tuple returned by `env.step`
# below match the older Gym API -- confirm against the installed version.
observation, info = env.reset(seed=42, return_info=True)
max_steps = 100
num_samples = 500

# One random weight vector per rollout; the four columns line up with the
# four observation components plotted below.
samples = np.random.randn(num_samples, 4)
data = []

# The x axis is fixed at max_steps + 1 points; traces are zero-padded to fit.
x_axis = range(max_steps + 1)

try:
    for lam in samples:
        breakpoints = []  # step indices at which the environment was reset
        score = 0
        obs = []
        for n in range(max_steps):
            ax.cla()

            # Linear threshold policy on the current observation.
            action = 1 if lam @ observation < 0 else 0
            observation, reward, done, info = env.step(action)
            score += reward
            obs.append(observation.tolist())

            o = np.array(obs)
            # Variance over the last int(score) rows -- presumably the steps
            # of the current episode, assuming a reward of 1 per step.
            var = np.var(o[-int(score):, :], axis=0)

            # Pad with zeros so every trace has max_steps + 1 points.
            padded = np.vstack([o, np.zeros((max_steps - n, o.shape[1]))])
            for trace, color in zip(padded.T, colors):
                ax.plot(x_axis, trace, c=color)
            ax.set_title(f"Reward: {int(score)}, Variance: {var}")
            ax.set_ylim([-3, 3])

            # Reset when the episode ends or the step budget is used up.
            # The original compared against `max_steps`, which range() never
            # yields, so an unfinished episode carried over into the next
            # sample's rollout.
            if done or n == max_steps - 1:
                breakpoints.append(n)
                observation, info = env.reset(return_info=True)
                score = 0  # reset score

            # draw break-point lines at every recorded reset
            if breakpoints:
                ax.vlines(breakpoints, o.min(), o.max(), color="black", lw=2)

            fig.canvas.draw()
            fig.show()
            fig.canvas.flush_events()
            env.render()

        data.append({"lam": lam, "obs": obs, "break": breakpoints})
        # dump data frequently; the context manager closes the file each time
        with open("data.pkl", "wb") as fh:
            pickle.dump(data, fh)

    input("Press any key to close.")
finally:
    # Release the figure and the environment even if the run is interrupted.
    plt.close()
    env.close()