Compare commits
3 Commits
dc353276d7
...
fe8a5ee1d7
Author | SHA1 | Date | |
---|---|---|---|
|
fe8a5ee1d7 | ||
|
53d2e1fdcb | ||
|
6020752d75 |
31
main.py
31
main.py
@ -5,26 +5,27 @@ import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.stats import gaussian_kde as gkde
|
||||
from scipy.stats import norm
|
||||
|
||||
|
||||
def train(data):
|
||||
D = pd.DataFrame(data)
|
||||
D['qoi'] = D['obs'].apply(lambda o: np.sum(o,axis=0)/np.sqrt(len(o)))
|
||||
D['i'] = D['lam'].apply(lambda l: norm.pdf(l).prod())
|
||||
D['o'] = D['qoi'].apply(lambda q: norm.pdf(q).prod())
|
||||
Q = np.array(D['qoi'].to_list()).reshape(-1,4)
|
||||
K = [gkde(Q[:,i]) for i in range(4)]
|
||||
D['p'] = D['qoi'].apply(lambda q: np.prod([ K[i].pdf(q[i]) for i in range(4)]))
|
||||
D['u'] = D['i']*D['o']/D['p']
|
||||
mud_point_idx = D['u'].argmax()
|
||||
mud_point = D['lam'].iloc[mud_point_idx]
|
||||
D["qoi"] = D["obs"].apply(lambda o: np.sum(o, axis=0) / np.sqrt(len(o)))
|
||||
D["i"] = D["lam"].apply(lambda l: norm.pdf(l).prod())
|
||||
D["o"] = D["qoi"].apply(lambda q: norm.pdf(q).prod())
|
||||
Q = np.array(D["qoi"].to_list()).reshape(-1, 4)
|
||||
K = [gkde(Q[:, i]) for i in range(4)]
|
||||
D["p"] = D["qoi"].apply(lambda q: np.prod([K[i].pdf(q[i]) for i in range(4)]))
|
||||
D["u"] = D["i"] * D["o"] / D["p"]
|
||||
mud_point_idx = D["u"].argmax()
|
||||
mud_point = D["lam"].iloc[mud_point_idx]
|
||||
return mud_point
|
||||
|
||||
|
||||
def test(decision=np.array([-0.09, -0.71, -0.43 , -0.74]), seed=1992):
|
||||
def test(decision=np.array([-0.09, -0.71, -0.43, -0.74]), seed=1992):
|
||||
env = gym.make("CartPole-v1")
|
||||
observation, info = env.reset(seed=seed, return_info=True)
|
||||
score = 1
|
||||
score = 0
|
||||
for i in range(10000):
|
||||
action = 1 if decision.T @ observation < 0 else 0
|
||||
observation, reward, done, info = env.step(action)
|
||||
@ -34,14 +35,14 @@ def test(decision=np.array([-0.09, -0.71, -0.43 , -0.74]), seed=1992):
|
||||
if score == 500:
|
||||
print("WIN")
|
||||
else:
|
||||
print("LOSE: {int(score)}")
|
||||
score = 1 # reset score
|
||||
print(f"LOSE: {int(score)}")
|
||||
score = 0 # reset score
|
||||
observation, info = env.reset(return_info=True)
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data = pickle.load(open('data.pkl','rb'))
|
||||
data = pickle.load(open("data.pkl", "rb"))
|
||||
mud_point = train(data)
|
||||
print("MUD Point: {mud_point}")
|
||||
print(f"MUD Point: {mud_point}")
|
||||
test(mud_point)
|
||||
|
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
scipy
|
||||
numpy
|
||||
gym[classic_control]
|
||||
matplotlib
|
||||
pandas
|
Loading…
Reference in New Issue
Block a user