main file.

This commit is contained in:
mm 2022-03-21 01:22:22 +00:00
parent 60b29e26bf
commit cce9a052ab

47
main.py Normal file
View File

@ -0,0 +1,47 @@
#!/usr/bin/env python
import pickle
import gym
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde as gkde
def train(data):
D = pd.DataFrame(data)
D['qoi'] = D['obs'].apply(lambda o: np.sum(o,axis=0)/np.sqrt(len(o)))
D['i'] = D['lam'].apply(lambda l: norm.pdf(l).prod())
D['o'] = D['qoi'].apply(lambda q: norm.pdf(q).prod())
Q = np.array(D['qoi'].to_list()).reshape(-1,4)
K = [gkde(Q[:,i]) for i in range(4)]
D['p'] = D['qoi'].apply(lambda q: np.prod([ K[i].pdf(q[i]) for i in range(4)]))
D['u'] = D['i']*D['o']/D['p']
mud_point_idx = D['u'].argmax()
mud_point = D['lam'].iloc[mud_point_idx]
return mud_point
def test(decision=np.array([-0.09, -0.71, -0.43 , -0.74]), seed=1992):
env = gym.make("CartPole-v1")
observation, info = env.reset(seed=seed, return_info=True)
score = 1
for i in range(10000):
action = 1 if decision.T @ observation < 0 else 0
observation, reward, done, info = env.step(action)
score += reward
env.render()
if done:
if score == 500:
print("WIN")
else:
print("LOSE: {int(score)}")
score = 1 # reset score
observation, info = env.reset(return_info=True)
env.close()
if __name__ == "__main__":
data = pickle.load(open('data.pkl','rb'))
mud_point = train(data)
print("MUD Point: {mud_point}")
test(mud_point)