cleanup
This commit is contained in:
		
							parent
							
								
									45ce364166
								
							
						
					
					
						commit
						177c1ff006
					
				
							
								
								
									
										12
									
								
								sample.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								sample.py
									
									
									
									
									
								
							| @ -1,10 +1,7 @@ | |||||||
| import sys |  | ||||||
| import pickle | import pickle | ||||||
| import gym | import gym | ||||||
| import numpy as np | import numpy as np | ||||||
| from matplotlib import pyplot as plt | from matplotlib import pyplot as plt | ||||||
| from matplotlib.animation import FuncAnimation |  | ||||||
| from time import sleep |  | ||||||
| 
 | 
 | ||||||
| # numpy precision for printing | # numpy precision for printing | ||||||
| np.set_printoptions(precision=3, suppress=True) | np.set_printoptions(precision=3, suppress=True) | ||||||
| @ -25,7 +22,7 @@ data = [] | |||||||
| for lam in samples: | for lam in samples: | ||||||
|     breakpoints = [] |     breakpoints = [] | ||||||
|     score = 0 |     score = 0 | ||||||
|     O = [] |     obs = [] | ||||||
|     for n in range(max_steps): |     for n in range(max_steps): | ||||||
|         ax.cla() |         ax.cla() | ||||||
|         # action = env.action_space.sample() |         # action = env.action_space.sample() | ||||||
| @ -33,8 +30,8 @@ for lam in samples: | |||||||
|         # action = 1 if observation[0] - observation[3]  < 0 else 0 |         # action = 1 if observation[0] - observation[3]  < 0 else 0 | ||||||
|         observation, reward, done, info = env.step(action) |         observation, reward, done, info = env.step(action) | ||||||
|         score += reward |         score += reward | ||||||
|         O.append(observation.tolist()) |         obs.append(observation.tolist()) | ||||||
|         o = np.array(O) |         o = np.array(obs) | ||||||
|         var = np.var(o[-int(score) :, :], axis=0) |         var = np.var(o[-int(score) :, :], axis=0) | ||||||
|         for q in range(4): |         for q in range(4): | ||||||
|             lines = np.hstack([o[:, q], np.zeros(max_steps - n)]) |             lines = np.hstack([o[:, q], np.zeros(max_steps - n)]) | ||||||
| @ -59,9 +56,8 @@ for lam in samples: | |||||||
|         fig.show() |         fig.show() | ||||||
|         fig.canvas.flush_events() |         fig.canvas.flush_events() | ||||||
|         env.render() |         env.render() | ||||||
|         # sleep(0.01) |  | ||||||
| 
 | 
 | ||||||
|     data.append({"lam": lam, "obs": O, "break": breakpoints}) |     data.append({"lam": lam, "obs": obs, "break": breakpoints}) | ||||||
|     pickle.dump(data, open("data.pkl", "wb"))  # dump data frequently |     pickle.dump(data, open("data.pkl", "wb"))  # dump data frequently | ||||||
| 
 | 
 | ||||||
| stop = input("Press any key to close.") | stop = input("Press any key to close.") | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user