mm
3 years ago
1 changed files with 47 additions and 0 deletions
@ -0,0 +1,47 @@ |
|||
#!/usr/bin/env python |
|||
import pickle |
|||
|
|||
import gym |
|||
import numpy as np |
|||
import pandas as pd |
|||
from scipy.stats import gaussian_kde as gkde |
|||
|
|||
|
|||
def train(data): |
|||
D = pd.DataFrame(data) |
|||
D['qoi'] = D['obs'].apply(lambda o: np.sum(o,axis=0)/np.sqrt(len(o))) |
|||
D['i'] = D['lam'].apply(lambda l: norm.pdf(l).prod()) |
|||
D['o'] = D['qoi'].apply(lambda q: norm.pdf(q).prod()) |
|||
Q = np.array(D['qoi'].to_list()).reshape(-1,4) |
|||
K = [gkde(Q[:,i]) for i in range(4)] |
|||
D['p'] = D['qoi'].apply(lambda q: np.prod([ K[i].pdf(q[i]) for i in range(4)])) |
|||
D['u'] = D['i']*D['o']/D['p'] |
|||
mud_point_idx = D['u'].argmax() |
|||
mud_point = D['lam'].iloc[mud_point_idx] |
|||
return mud_point |
|||
|
|||
|
|||
def test(decision=np.array([-0.09, -0.71, -0.43 , -0.74]), seed=1992): |
|||
env = gym.make("CartPole-v1") |
|||
observation, info = env.reset(seed=seed, return_info=True) |
|||
score = 1 |
|||
for i in range(10000): |
|||
action = 1 if decision.T @ observation < 0 else 0 |
|||
observation, reward, done, info = env.step(action) |
|||
score += reward |
|||
env.render() |
|||
if done: |
|||
if score == 500: |
|||
print("WIN") |
|||
else: |
|||
print("LOSE: {int(score)}") |
|||
score = 1 # reset score |
|||
observation, info = env.reset(return_info=True) |
|||
env.close() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
data = pickle.load(open('data.pkl','rb')) |
|||
mud_point = train(data) |
|||
print("MUD Point: {mud_point}") |
|||
test(mud_point) |
Loading…
Reference in new issue