diff --git a/main.py b/main.py index a34176b..4e1cc87 100644 --- a/main.py +++ b/main.py @@ -10,19 +10,19 @@ from scipy.stats import norm def train(data): D = pd.DataFrame(data) - D['qoi'] = D['obs'].apply(lambda o: np.sum(o,axis=0)/np.sqrt(len(o))) - D['i'] = D['lam'].apply(lambda l: norm.pdf(l).prod()) - D['o'] = D['qoi'].apply(lambda q: norm.pdf(q).prod()) - Q = np.array(D['qoi'].to_list()).reshape(-1,4) - K = [gkde(Q[:,i]) for i in range(4)] - D['p'] = D['qoi'].apply(lambda q: np.prod([ K[i].pdf(q[i]) for i in range(4)])) - D['u'] = D['i']*D['o']/D['p'] - mud_point_idx = D['u'].argmax() - mud_point = D['lam'].iloc[mud_point_idx] + D["qoi"] = D["obs"].apply(lambda o: np.sum(o, axis=0) / np.sqrt(len(o))) + D["i"] = D["lam"].apply(lambda l: norm.pdf(l).prod()) + D["o"] = D["qoi"].apply(lambda q: norm.pdf(q).prod()) + Q = np.array(D["qoi"].to_list()).reshape(-1, 4) + K = [gkde(Q[:, i]) for i in range(4)] + D["p"] = D["qoi"].apply(lambda q: np.prod([K[i].pdf(q[i]) for i in range(4)])) + D["u"] = D["i"] * D["o"] / D["p"] + mud_point_idx = D["u"].argmax() + mud_point = D["lam"].iloc[mud_point_idx] return mud_point -def test(decision=np.array([-0.09, -0.71, -0.43 , -0.74]), seed=1992): +def test(decision=np.array([-0.09, -0.71, -0.43, -0.74]), seed=1992): env = gym.make("CartPole-v1") observation, info = env.reset(seed=seed, return_info=True) score = 0 @@ -42,7 +42,7 @@ def test(decision=np.array([-0.09, -0.71, -0.43 , -0.74]), seed=1992): if __name__ == "__main__": - data = pickle.load(open('data.pkl','rb')) + data = pickle.load(open("data.pkl", "rb")) mud_point = train(data) print(f"MUD Point: {mud_point}") test(mud_point)