generate sample data
This commit is contained in:
parent
cce9a052ab
commit
9b2b0f6a90
69
sample.py
Normal file
69
sample.py
Normal file
@ -0,0 +1,69 @@
|
||||
import sys
|
||||
import pickle
|
||||
import gym
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.animation import FuncAnimation
|
||||
from time import sleep
|
||||
|
||||
# numpy precision for printing
|
||||
np.set_printoptions(precision=3, suppress=True)
|
||||
|
||||
plt.ion() # interactive plotting
|
||||
fig, ax = plt.subplots()
|
||||
colors = ["xkcd:orange", "xkcd:forest green", "xkcd:gray", "xkcd:light blue"]
|
||||
plots = [None] * 4
|
||||
|
||||
env = gym.make("CartPole-v1")
|
||||
observation, info = env.reset(seed=42, return_info=True)
|
||||
|
||||
max_steps = 100
|
||||
num_samples = 500
|
||||
samples = np.random.randn(num_samples, 4)
|
||||
|
||||
data = []
|
||||
for lam in samples:
|
||||
breakpoints = []
|
||||
score = 0
|
||||
O = []
|
||||
for n in range(max_steps):
|
||||
ax.cla()
|
||||
# action = env.action_space.sample()
|
||||
action = 1 if lam.T @ observation < 0 else 0
|
||||
# action = 1 if observation[0] - observation[3] < 0 else 0
|
||||
observation, reward, done, info = env.step(action)
|
||||
score += reward
|
||||
O.append(observation.tolist())
|
||||
o = np.array(O)
|
||||
var = np.var(o[-int(score) :, :], axis=0)
|
||||
for q in range(4):
|
||||
lines = np.hstack([o[:, q], np.zeros(max_steps - n)])
|
||||
ax.plot(range(max_steps + 1), lines, c=colors[q])
|
||||
|
||||
ax.set_title(f"Reward: {int(score)}, Variance: {var}")
|
||||
ax.set_ylim([-3, 3])
|
||||
|
||||
if done or n == max_steps:
|
||||
breakpoints.append(n)
|
||||
observation, info = env.reset(return_info=True)
|
||||
# print(score, observation)
|
||||
score = 0 # reset score
|
||||
|
||||
# draw break-point lines when game is lost
|
||||
for b in breakpoints:
|
||||
ax.vlines(
|
||||
b, np.min(o, axis=0).min(), np.max(o, axis=0).max(), color="black", lw=2
|
||||
)
|
||||
|
||||
fig.canvas.draw()
|
||||
fig.show()
|
||||
fig.canvas.flush_events()
|
||||
env.render()
|
||||
# sleep(0.01)
|
||||
|
||||
data.append({"lam": lam, "obs": O, "break": breakpoints})
|
||||
pickle.dump(data, open("data.pkl", "wb")) # dump data frequently
|
||||
|
||||
stop = input("Press any key to close.")
|
||||
plt.close()
|
||||
env.close()
|
Loading…
Reference in New Issue
Block a user