control systems with MUD points
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

70 lines
2.0 KiB

import sys
import pickle
import gym
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from time import sleep
# Roll out random linear policies on CartPole-v1, live-plot the observation
# traces, and pickle the recorded trajectories to "data.pkl".

# Limit numpy print precision so the variance vector in the plot title
# stays short.
np.set_printoptions(precision=3, suppress=True)
plt.ion()  # interactive plotting: the figure is redrawn after every step
fig, ax = plt.subplots()
colors = ["xkcd:orange", "xkcd:forest green", "xkcd:gray", "xkcd:light blue"]
plots = [None] * 4
env = gym.make("CartPole-v1")
# NOTE(review): `return_info=True` and the 4-tuple returned by `env.step`
# below are specific to certain gym releases -- confirm the installed version.
observation, info = env.reset(seed=42, return_info=True)
max_steps = 100
num_samples = 500
# One random 4-vector of policy weights per sample.
samples = np.random.randn(num_samples, 4)
data = []
# x-axis shared by every trace; each trace is zero-padded to this length.
step_axis = range(max_steps + 1)

for lam in samples:
    breakpoints = []  # step indices at which the environment was reset
    score = 0  # reward accumulated since the last reset
    O = []  # observations recorded for this sample, one row per step
    for n in range(max_steps):
        ax.cla()
        # Linear policy: action 1 when the weighted observation is negative.
        action = 1 if lam.T @ observation < 0 else 0
        observation, reward, done, info = env.step(action)
        score += reward
        O.append(observation.tolist())
        o = np.array(O)
        # Per-dimension variance over the last int(score) recorded steps.
        var = np.var(o[-int(score) :, :], axis=0)
        for q in range(4):
            # `o` has n + 1 rows here; pad with zeros up to max_steps + 1
            # points so every trace matches `step_axis`.
            lines = np.hstack([o[:, q], np.zeros(max_steps - n)])
            ax.plot(step_axis, lines, c=colors[q])
        ax.set_title(f"Reward: {int(score)}, Variance: {var}")
        ax.set_ylim([-3, 3])
        # `n` only reaches max_steps - 1 inside range(max_steps); comparing
        # against max_steps made the step-cap branch unreachable, so the next
        # sample would start from the previous sample's state.
        if done or n == max_steps - 1:
            breakpoints.append(n)
            observation, info = env.reset(return_info=True)
            score = 0  # reset score
        # Draw break-point lines where the environment was reset.
        for b in breakpoints:
            ax.vlines(b, o.min(), o.max(), color="black", lw=2)
        fig.canvas.draw()
        fig.show()
        fig.canvas.flush_events()
        env.render()
    data.append({"lam": lam, "obs": O, "break": breakpoints})
    # Dump after every sample so partial results survive an interruption;
    # the context manager closes the file handle each time.
    with open("data.pkl", "wb") as fh:
        pickle.dump(data, fh)

stop = input("Press any key to close.")
plt.close()
env.close()