You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
72 lines
2.1 KiB
72 lines
2.1 KiB
3 years ago
|
#!/usr/bin/env python
|
||
|
import pickle
|
||
|
|
||
|
import gym
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
from scipy.stats import gaussian_kde as gkde
|
||
|
from scipy.stats import norm
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
try:
    from pyvirtualdisplay import Display
except ImportError:
    # pyvirtualdisplay is optional; without it, rendering uses whatever display
    # the process already has.
    pass
else:
    import atexit

    # Start a hidden (visible=0) virtual display, presumably so env.render()
    # can run on a machine without a physical screen.
    # It is kept under its own name so that a later `from IPython import display`
    # cannot shadow it and drop the reference. It is stopped at interpreter exit
    # instead of being left running.
    virtual_display = Display(visible=0, size=(1400, 900))
    virtual_display.start()
    atexit.register(virtual_display.stop)
|
||
|
|
||
|
# An "inline" matplotlib backend signals a Jupyter/IPython session. In that case,
# test() draws frames with imshow plus IPython's display helpers instead of
# calling env.render() directly.
is_ipython = plt.get_backend().find('inline') != -1
if is_ipython:
    # Binds the module-level name `display` (shadowing any earlier binding) to
    # IPython's display module, which test() uses for frame output.
    from IPython import display

# Turn on matplotlib interactive mode.
plt.ion()
|
||
|
|
||
|
def train(data, sd=(1.0, 0.5, 0.2, 0.5)):
    """Pick the parameter sample with the highest updated density ("MUD point").

    Each record in ``data`` pairs a parameter sample ``"lam"`` with the
    observations ``"obs"`` it produced. The observations are reduced to a scaled
    quantity of interest (QoI). For every sample the code then evaluates

        u = initial(lam) * observed(qoi) / predicted(qoi)

    where:

    * initial and observed are products of standard-normal pdfs, and
    * predicted is a product of 1-D kernel density estimates over all QoIs.

    The ``"lam"`` entry with the largest ``u`` is printed and returned.

    Parameters
    ----------
    data
        Anything ``pd.DataFrame`` accepts that yields ``"lam"`` and ``"obs"``
        columns. Each ``"lam"`` entry is a parameter vector. Each ``"obs"``
        entry is presumably a sequence of per-step observation vectors
        (summed over axis 0); their width must broadcast against ``sd``.
    sd
        Per-component scale used to standardize the summed observations.
        The default is the 4-component scale this script was written for.

    Returns
    -------
    The ``"lam"`` entry (as stored in ``data``) of the highest-density sample.

    Raises
    ------
    ValueError
        If ``data`` has fewer than two samples (a KDE cannot be built).
    """
    D = pd.DataFrame(data)
    if len(D) < 2:
        raise ValueError("train() needs at least two samples to build a KDE")
    sd = np.asarray(sd, dtype=float)

    # QoI: sum over steps, divided by sd and sqrt(#steps), so it can be
    # compared against the standard-normal observed density below.
    # Episodes may have different lengths, so this stays row-wise.
    D["qoi"] = D["obs"].apply(lambda obs: np.sum(obs, axis=0) / sd / np.sqrt(len(obs)))

    lam = np.array(D["lam"].to_list(), dtype=float).reshape(len(D), -1)
    Q = np.array(D["qoi"].to_list(), dtype=float).reshape(len(D), -1)

    # Initial density: independent standard normals on each parameter component.
    D["i"] = norm.pdf(lam).prod(axis=1)
    # Observed density: independent standard normals on each QoI component.
    D["o"] = norm.pdf(Q).prod(axis=1)
    # Predicted density, approximated as a product of 1-D KDE marginals.
    # Each KDE is evaluated on all samples in one call rather than point by point.
    D["p"] = np.prod([gkde(Q[:, j])(Q[:, j]) for j in range(Q.shape[1])], axis=0)
    D["u"] = D["i"] * D["o"] / D["p"]

    mud_point_idx = D["u"].argmax()
    mud_point = D["lam"].iloc[mud_point_idx]
    print(f"MUD Point {mud_point_idx}: {mud_point}")
    return mud_point
|
||
|
|
||
|
|
||
|
def test(decision=(-0.09, -0.71, -0.43, -0.74), seed=1992, steps=10000):
    """Run a linear policy on CartPole-v1 and report every finished episode.

    At each step the policy chooses action 1 when ``decision @ observation < 0``
    and action 0 otherwise. Frames are shown via ``env.render()``, or via
    matplotlib + IPython display when ``is_ipython`` is set.

    When an episode ends, ``"WIN"`` is printed if its total reward is 500;
    otherwise ``"LOSE: <score>"`` is printed. The environment is then reset.

    Parameters
    ----------
    decision
        Linear policy weights: any array-like with one weight per observation
        component. It is flattened, so a row or column vector also works.
        The default is an immutable tuple, avoiding a shared mutable default.
    seed
        Seed for the first ``env.reset``. Later resets are unseeded.
    steps
        Total number of environment steps to run, across all episodes.
    """
    decision = np.asarray(decision, dtype=float).ravel()
    env = gym.make("CartPole-v1")
    try:
        # NOTE(review): reset(return_info=True), render(mode=...) and the
        # 4-tuple returned by step() follow the pre-0.26 gym API. Newer gym
        # releases changed these signatures.
        observation, info = env.reset(seed=seed, return_info=True)
        score = 0

        if is_ipython:
            img = plt.imshow(env.render(mode='rgb_array'))
            plt.axis('off')  # loop-invariant, so set once instead of per frame

        for _ in range(steps):
            action = 1 if decision @ observation < 0 else 0
            observation, reward, done, info = env.step(action)
            score += reward

            if is_ipython:
                img.set_data(env.render(mode='rgb_array'))
                display.display(plt.gcf())
                display.clear_output(wait=True)
            else:
                env.render()

            if done:
                # 500 is presumably the CartPole-v1 episode cap, i.e. the pole
                # stayed up for the whole episode.
                if score == 500:
                    print("WIN")
                else:
                    print(f"LOSE: {int(score)}")
                score = 0  # reset score
                observation, info = env.reset(return_info=True)
    finally:
        # Always release the environment/render window, even on error or Ctrl-C.
        env.close()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Close the data file deterministically instead of leaking the handle.
    # NOTE: pickle.load can execute arbitrary code; only load a trusted data.pkl.
    with open("data.pkl", "rb") as fh:
        data = pickle.load(fh)
    mud_point = train(data)
    test(mud_point)
|