From cce9a052ab1fa72fc4897397c099140a840847d6 Mon Sep 17 00:00:00 2001 From: mm Date: Mon, 21 Mar 2022 01:22:22 +0000 Subject: [PATCH] main file. --- main.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 main.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..88a0a81 --- /dev/null +++ b/main.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +import pickle + +import gym +import numpy as np +import pandas as pd +from scipy.stats import gaussian_kde as gkde + + +def train(data): + D = pd.DataFrame(data) + D['qoi'] = D['obs'].apply(lambda o: np.sum(o,axis=0)/np.sqrt(len(o))) + D['i'] = D['lam'].apply(lambda l: norm.pdf(l).prod()) + D['o'] = D['qoi'].apply(lambda q: norm.pdf(q).prod()) + Q = np.array(D['qoi'].to_list()).reshape(-1,4) + K = [gkde(Q[:,i]) for i in range(4)] + D['p'] = D['qoi'].apply(lambda q: np.prod([ K[i].pdf(q[i]) for i in range(4)])) + D['u'] = D['i']*D['o']/D['p'] + mud_point_idx = D['u'].argmax() + mud_point = D['lam'].iloc[mud_point_idx] + return mud_point + + +def test(decision=np.array([-0.09, -0.71, -0.43 , -0.74]), seed=1992): + env = gym.make("CartPole-v1") + observation, info = env.reset(seed=seed, return_info=True) + score = 1 + for i in range(10000): + action = 1 if decision.T @ observation < 0 else 0 + observation, reward, done, info = env.step(action) + score += reward + env.render() + if done: + if score == 500: + print("WIN") + else: + print("LOSE: {int(score)}") + score = 1 # reset score + observation, info = env.reset(return_info=True) + env.close() + + +if __name__ == "__main__": + data = pickle.load(open('data.pkl','rb')) + mud_point = train(data) + print("MUD Point: {mud_point}") + test(mud_point)