Lessons in the Research-to-Production Pipeline: From Data Science to Software Engineering
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
1.9 KiB

3 years ago
import pickle
import gym
import numpy as np
from matplotlib import pyplot as plt
# Print numpy arrays with 3 decimals and without scientific notation.
np.set_printoptions(suppress=True, precision=3)

# Interactive figure so the rollout loop can redraw it without blocking.
plt.ion()
fig, ax = plt.subplots(nrows=1, ncols=1)
# One colour per plotted observation component (the loop draws 4 traces).
colors = [
    "xkcd:orange",
    "xkcd:forest green",
    "xkcd:gray",
    "xkcd:light blue",
]
plots = [None for _ in range(4)]  # NOTE(review): not referenced in the visible code

# Environment; the first reset seeds the env's RNG.
env = gym.make("CartPole-v1")
observation, info = env.reset(return_info=True, seed=42)

# Experiment size: steps rolled out per controller, number of controllers.
max_steps = 100
num_samples = 500
# One Gaussian weight vector per controller; each is dotted with the observation
# to pick an action. NOTE(review): the global numpy RNG is not seeded, so these
# controllers differ from run to run even though the env is seeded.
samples = np.random.standard_normal((num_samples, 4))
data = []  # one record per controller, appended and pickled by the main loop
# Roll out each random linear controller ``lam`` for ``max_steps`` steps,
# redrawing the observation traces live and pickling all rollouts so far
# after every controller. The try/finally releases the figure and the
# environment even if the run raises or is interrupted (e.g. Ctrl-C).
try:
    for lam in samples:
        breakpoints = []  # step indices where an episode ended or the budget ran out
        score = 0  # reward accumulated since the last env reset
        obs = []  # every observation seen for this controller, as plain lists
        for n in range(max_steps):
            ax.cla()
            # Alternative policies kept for experimentation:
            # action = env.action_space.sample()
            # action = 1 if observation[0] - observation[3] < 0 else 0
            # Linear threshold policy: action 1 when lam . observation < 0, else 0.
            # (`lam` is a 1-D row of `samples`, so no transpose is needed.)
            action = 1 if lam @ observation < 0 else 0
            observation, reward, done, info = env.step(action)
            score += reward
            obs.append(observation.tolist())
            o = np.array(obs)
            # Per-component variance over the last int(score) observations,
            # i.e. the current episode if every step pays a reward of 1
            # (NOTE(review): confirm for this env).
            var = np.var(o[-int(score) :, :], axis=0)
            # Pad each trace with zeros to max_steps + 1 points (n + 1 observed
            # + max_steps - n zeros) so the x-axis stays fixed during the rollout.
            for q, color in enumerate(colors):
                trace = np.hstack([o[:, q], np.zeros(max_steps - n)])
                ax.plot(range(max_steps + 1), trace, c=color)
            ax.set_title(f"Reward: {int(score)}, Variance: {var}")
            ax.set_ylim([-3, 3])
            # `n` only reaches max_steps - 1 inside range(max_steps); comparing
            # against max_steps (as before) never fired, so the final step was
            # never recorded and the env was never reset for the next controller.
            if done or n == max_steps - 1:
                breakpoints.append(n)
                observation, info = env.reset(return_info=True)
                # print(score, observation)
                score = 0  # reset score
            # Draw a vertical line at every episode end / budget end so far,
            # spanning the full range of observed values.
            y_lo, y_hi = o.min(), o.max()
            for b in breakpoints:
                ax.vlines(b, y_lo, y_hi, color="black", lw=2)
            fig.canvas.draw()
            fig.show()
            fig.canvas.flush_events()
            env.render()
        data.append({"lam": lam, "obs": obs, "break": breakpoints})
        # Re-dump everything after each controller so a crash loses at most one
        # rollout; `with` closes (and flushes) the file deterministically.
        with open("data.pkl", "wb") as fh:
            pickle.dump(data, fh)
    # Note: input() returns on Enter, not on an arbitrary key.
    stop = input("Press any key to close.")
finally:
    plt.close()
    env.close()