#!/usr/bin/env python
"""Pick a linear CartPole controller from sampled data and replay it.

``train`` scores every sampled parameter vector ``lam`` by the ratio
``i * o / p`` (see its docstring) and returns the best one, printed as the
"MUD Point". ``test`` then uses that vector as a linear controller in
``CartPole-v1``.
"""
import pickle

import gym
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde as gkde
from scipy.stats import norm
import matplotlib.pyplot as plt

# Start a virtual display when pyvirtualdisplay is installed (presumably so
# env.render() works on a machine without a screen). It is kept under its own
# name: the IPython branch below imports a module called ``display``, which
# previously overwrote this object and left the started display unreachable.
try:
    from pyvirtualdisplay import Display

    virtual_display = Display(visible=0, size=(1400, 900))
    virtual_display.start()
except ImportError:
    virtual_display = None

# An "inline" matplotlib backend is treated as running inside IPython/Jupyter;
# frames are then drawn with imshow and pushed through IPython.display.
is_ipython = 'inline' in plt.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# Default controller coefficients for ``test``. Stored as a tuple so no shared
# mutable object is used as a default argument.
_DEFAULT_DECISION = (-0.09, -0.71, -0.43, -0.74)


def train(data):
    """Return the sampled ``lam`` that maximizes ``u = i * o / p``.

    For each row:
      * ``qoi`` = column-wise sum of ``obs``, divided element-wise by ``sd``
        and by ``sqrt(len(obs))``;
      * ``i``   = product of standard-normal pdfs evaluated at ``lam``;
      * ``o``   = product of standard-normal pdfs evaluated at ``qoi``;
      * ``p``   = product of 1-D Gaussian KDEs (one per ``qoi`` component,
        each fitted on that component across all rows) evaluated at ``qoi``.

    Args:
        data: Anything ``pd.DataFrame`` accepts that yields a ``"lam"`` column
            (one parameter vector per row) and an ``"obs"`` column (a
            sequence of observation vectors per row). Observations are
            assumed to have ``sd.size`` (= 4) components — TODO confirm
            against the data generator.

    Returns:
        The ``"lam"`` value of the row with the largest ``u``. Its positional
        index and value are also printed.
    """
    D = pd.DataFrame(data)
    sd = np.array([1.0, 0.5, 0.2, 0.5])
    n_dims = sd.size

    D["qoi"] = D["obs"].apply(lambda obs: np.sum(obs, axis=0) / sd / np.sqrt(len(obs)))
    D["i"] = D["lam"].apply(lambda lam: norm.pdf(lam).prod())
    D["o"] = D["qoi"].apply(lambda q: norm.pdf(q).prod())

    # One marginal KDE per QoI component; their product stands in for the
    # joint density of the QoI samples.
    Q = np.array(D["qoi"].to_list()).reshape(-1, n_dims)
    kdes = [gkde(Q[:, j]) for j in range(n_dims)]
    D["p"] = D["qoi"].apply(lambda q: np.prod([k.pdf(x) for k, x in zip(kdes, q)]))

    D["u"] = D["i"] * D["o"] / D["p"]
    # Series.argmax returns a *positional* index, hence .iloc (not .loc).
    mud_point_idx = D["u"].argmax()
    mud_point = D["lam"].iloc[mud_point_idx]
    print(f"MUD Point {mud_point_idx}: {mud_point}")
    return mud_point


def test(decision=None, seed=1992):
    """Run ``decision`` as a linear bang-bang controller in CartPole-v1.

    At each step the action is 1 when ``decision.T @ observation < 0`` and 0
    otherwise. The loop runs for 10000 steps. Whenever an episode ends, its
    score is reported ("WIN" if exactly 500, otherwise "LOSE: <score>") and
    the environment is reset. The environment is always closed, even if the
    loop is interrupted.

    Args:
        decision: Controller coefficients (any array-like). ``None`` selects
            the built-in default ``(-0.09, -0.71, -0.43, -0.74)``.
        seed: Seed passed to the first ``env.reset`` only; later resets are
            unseeded.
    """
    if decision is None:
        decision = _DEFAULT_DECISION
    decision = np.asarray(decision)

    env = gym.make("CartPole-v1")
    try:
        observation, info = env.reset(seed=seed, return_info=True)
        score = 0
        if is_ipython:
            img = plt.imshow(env.render(mode='rgb_array'))
            plt.axis('off')  # same axes every frame, so set once

        for _ in range(10000):
            action = 1 if decision.T @ observation < 0 else 0
            observation, reward, done, info = env.step(action)
            score += reward

            if not is_ipython:
                env.render()
            else:
                img.set_data(env.render(mode='rgb_array'))
                display.display(plt.gcf())
                display.clear_output(wait=True)

            if done:
                if score == 500:
                    print("WIN")
                else:
                    print(f"LOSE: {int(score)}")
                score = 0  # reset score
                observation, info = env.reset(return_info=True)
    finally:
        env.close()


if __name__ == "__main__":
    # NOTE(security): pickle.load can execute arbitrary code; only load a
    # data.pkl that comes from a trusted source.
    with open("data.pkl", "rb") as fh:
        data = pickle.load(fh)
    mud_point = train(data)
    test(mud_point)