diff --git a/main.py b/main.py index 4e1cc87..62adc2b 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,8 @@ from scipy.stats import norm def train(data): D = pd.DataFrame(data) - D["qoi"] = D["obs"].apply(lambda o: np.sum(o, axis=0) / np.sqrt(len(o))) + sd = np.array([1.0, 0.25, 0.5, 0.1]) + D["qoi"] = D["obs"].apply(lambda o: np.sum(o, axis=0) / sd / np.sqrt(len(o))) D["i"] = D["lam"].apply(lambda l: norm.pdf(l).prod()) D["o"] = D["qoi"].apply(lambda q: norm.pdf(q).prod()) Q = np.array(D["qoi"].to_list()).reshape(-1, 4) @@ -19,6 +20,7 @@ def train(data): D["u"] = D["i"] * D["o"] / D["p"] mud_point_idx = D["u"].argmax() mud_point = D["lam"].iloc[mud_point_idx] + print(f"MUD Point ({mud_point_idx}: {mud_point}") return mud_point @@ -44,5 +46,4 @@ def test(decision=np.array([-0.09, -0.71, -0.43, -0.74]), seed=1992): if __name__ == "__main__": data = pickle.load(open("data.pkl", "rb")) mud_point = train(data) - print(f"MUD Point: {mud_point}") test(mud_point)