This commit is contained in:
Michael Pilosov 2022-03-20 20:04:49 -06:00
parent 45ce364166
commit 177c1ff006

View File

@ -1,10 +1,7 @@
import sys
import pickle import pickle
import gym import gym
import numpy as np import numpy as np
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from time import sleep
# numpy precision for printing # numpy precision for printing
np.set_printoptions(precision=3, suppress=True) np.set_printoptions(precision=3, suppress=True)
@ -25,7 +22,7 @@ data = []
for lam in samples: for lam in samples:
breakpoints = [] breakpoints = []
score = 0 score = 0
O = [] obs = []
for n in range(max_steps): for n in range(max_steps):
ax.cla() ax.cla()
# action = env.action_space.sample() # action = env.action_space.sample()
@ -33,8 +30,8 @@ for lam in samples:
# action = 1 if observation[0] - observation[3] < 0 else 0 # action = 1 if observation[0] - observation[3] < 0 else 0
observation, reward, done, info = env.step(action) observation, reward, done, info = env.step(action)
score += reward score += reward
O.append(observation.tolist()) obs.append(observation.tolist())
o = np.array(O) o = np.array(obs)
var = np.var(o[-int(score) :, :], axis=0) var = np.var(o[-int(score) :, :], axis=0)
for q in range(4): for q in range(4):
lines = np.hstack([o[:, q], np.zeros(max_steps - n)]) lines = np.hstack([o[:, q], np.zeros(max_steps - n)])
@ -59,9 +56,8 @@ for lam in samples:
fig.show() fig.show()
fig.canvas.flush_events() fig.canvas.flush_events()
env.render() env.render()
# sleep(0.01)
data.append({"lam": lam, "obs": O, "break": breakpoints}) data.append({"lam": lam, "obs": obs, "break": breakpoints})
pickle.dump(data, open("data.pkl", "wb")) # dump data frequently pickle.dump(data, open("data.pkl", "wb")) # dump data frequently
stop = input("Press any key to close.") stop = input("Press any key to close.")