# Listing metadata from the hosting page: 70 lines, 2.0 KiB, Python
"""Random search over linear CartPole policies with live plotting.

Draws ``num_samples`` random weight vectors and, for each one, steps the
``CartPole-v1`` environment ``max_steps`` times using the policy "take action
1 when ``lam @ observation`` is negative, otherwise action 0".  The observation
traces are plotted live, and everything collected so far is pickled to
``data.pkl`` after each weight vector.

NOTE(review): this is written against the gym API in which
``reset(return_info=True)`` returns ``(observation, info)`` and ``step``
returns ``(observation, reward, done, info)`` -- confirm the installed gym
version matches.
"""
import pickle
import sys
from time import sleep

import gym
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation

# numpy precision for printing
np.set_printoptions(precision=3, suppress=True)

plt.ion()  # interactive plotting
fig, ax = plt.subplots()
# One colour per plotted observation component.
colors = ["xkcd:orange", "xkcd:forest green", "xkcd:gray", "xkcd:light blue"]
plots = [None] * 4  # not used below; kept so the module-level name still exists

env = gym.make("CartPole-v1")
observation, info = env.reset(seed=42, return_info=True)

max_steps = 100  # environment steps simulated per weight vector
num_samples = 500  # number of random weight vectors to try
samples = np.random.randn(num_samples, 4)

data = []
for lam in samples:
    breakpoints = []  # step indices at which the environment was reset
    score = 0  # reward accumulated since the last reset
    obs_log = []  # observation of every step taken with this weight vector
    for n in range(max_steps):
        ax.cla()
        # Linear threshold policy.  Alternatives tried during development:
        #   env.action_space.sample()
        #   1 if observation[0] - observation[3] < 0 else 0
        action = 1 if lam @ observation < 0 else 0
        observation, reward, done, info = env.step(action)
        score += reward
        obs_log.append(observation.tolist())
        obs_arr = np.array(obs_log)
        # Per-component variance over the last int(score) rows.
        # NOTE(review): this equals "steps since the last reset" only if the
        # reward is 1 per step -- confirm for the environment in use.
        var = np.var(obs_arr[-int(score) :, :], axis=0)

        for q, color in enumerate(colors):
            # obs_arr has n + 1 rows; zero-pad so every trace has
            # max_steps + 1 points and the x-axis stays fixed while it grows.
            lines = np.hstack([obs_arr[:, q], np.zeros(max_steps - n)])
            ax.plot(range(max_steps + 1), lines, c=color)

        ax.set_title(f"Reward: {int(score)}, Variance: {var}")
        ax.set_ylim([-3, 3])

        # Reset when the episode ends OR on the last step of this sample, so
        # the next weight vector starts from a fresh episode.  The original
        # test ``n == max_steps`` could never be true inside
        # ``range(max_steps)``, which let a surviving episode leak into the
        # next sample.
        if done or n == max_steps - 1:
            breakpoints.append(n)
            observation, info = env.reset(return_info=True)
            score = 0  # reset score

        # draw break-point lines at every step where the environment was reset
        lowest, highest = obs_arr.min(), obs_arr.max()
        for b in breakpoints:
            ax.vlines(b, lowest, highest, color="black", lw=2)

        fig.canvas.draw()
        fig.show()
        fig.canvas.flush_events()
        env.render()

    data.append({"lam": lam, "obs": obs_log, "break": breakpoints})
    # Dump data frequently so an interrupted run keeps its results; the
    # context manager closes the file each time instead of leaking a handle.
    with open("data.pkl", "wb") as fh:
        pickle.dump(data, fh)

stop = input("Press any key to close.")
plt.close()
env.close()