diff --git a/pytorch/actor_critic.py b/pytorch/actor_critic.py
index 65c08de..bc5c72e 100644
--- a/pytorch/actor_critic.py
+++ b/pytorch/actor_critic.py
@@ -83,13 +83,23 @@ def update(self, state, action_prob, reward, next_state, done):
 
 for episode in range(2000):
     state = env.reset()
+    if isinstance(state, tuple):
+        state = state[0]
     done = False
     total_reward = 0
 
     while not done:
         action, prob = agent.get_action(state)
-        next_state, reward, done, info = env.step(action)
-
+        try:
+            # For newer gym versions (>=0.26.0)
+            next_state, reward, terminated, truncated, info = env.step(action)
+            done = terminated or truncated
+        except ValueError:
+            # For older gym versions
+            next_state, reward, done, info = env.step(action)
+            if isinstance(info, dict) and 'TimeLimit.truncated' in info:
+                done = done and not info['TimeLimit.truncated']
+
         agent.update(state, prob, reward, next_state, done)
 
         state = next_state
diff --git a/pytorch/dqn.py b/pytorch/dqn.py
index a9116c7..0e672c8 100644
--- a/pytorch/dqn.py
+++ b/pytorch/dqn.py
@@ -24,10 +24,10 @@ def __len__(self):
 
     def get_batch(self):
         data = random.sample(self.buffer, self.batch_size)
-        state = torch.tensor(np.stack([x[0] for x in data]))
+        state = torch.tensor(np.array([x[0] for x in data]), dtype=torch.float32)
         action = torch.tensor(np.array([x[1] for x in data]).astype(np.long))
         reward = torch.tensor(np.array([x[2] for x in data]).astype(np.float32))
-        next_state = torch.tensor(np.stack([x[3] for x in data]))
+        next_state = torch.tensor(np.array([x[3] for x in data]), dtype=torch.float32)
         done = torch.tensor(np.array([x[4] for x in data]).astype(np.int32))
         return state, action, reward, next_state, done
 
@@ -64,7 +64,7 @@ def get_action(self, state):
         if np.random.rand() < self.epsilon:
             return np.random.choice(self.action_size)
         else:
-            state = torch.tensor(state[np.newaxis, :])
+            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
             qs = self.qnet(state)
             return qs.argmax().item()
 
@@ -102,12 +102,21 @@ def sync_qnet(self):
 
 for episode in range(episodes):
     state = env.reset()
+    if isinstance(state, tuple):  # Handle new gym API
+        state, _ = state
     done = False
     total_reward = 0
 
     while not done:
         action = agent.get_action(state)
-        next_state, reward, done, info = env.step(action)
+        step_result = env.step(action)
+
+        # Handle different gym versions
+        if len(step_result) == 5:  # Newer gym version (step returns 5 values)
+            next_state, reward, terminated, truncated, info = step_result
+            done = terminated or truncated
+        else:  # Older gym version (step returns 4 values)
+            next_state, reward, done, info = step_result
 
         agent.update(state, action, reward, next_state, done)
         state = next_state
diff --git a/pytorch/reinforce.py b/pytorch/reinforce.py
index 50c26b6..4f7cf96 100644
--- a/pytorch/reinforce.py
+++ b/pytorch/reinforce.py
@@ -30,6 +30,8 @@ def __init__(self):
         self.optimizer = optim.Adam(self.pi.parameters(), lr=self.lr)
 
     def get_action(self, state):
+        if isinstance(state, tuple):
+            state = state[0]  # Extract the actual state from the tuple
         state = torch.tensor(state[np.newaxis, :])
         probs = self.pi(state)
         probs = probs[0]
@@ -64,7 +66,8 @@ def update(self):
 
     while not done:
         action, prob = agent.get_action(state)
-        next_state, reward, done, info = env.step(action)
+        next_state, reward, terminated, truncated, info = env.step(action)
+        done = terminated or truncated
 
         agent.add(reward, prob)
         state = next_state
diff --git a/pytorch/simple_pg.py b/pytorch/simple_pg.py
index 1737846..b3873ac 100644
--- a/pytorch/simple_pg.py
+++ b/pytorch/simple_pg.py
@@ -30,7 +30,9 @@ def __init__(self):
         self.optimizer = optim.Adam(self.pi.parameters(), lr=self.lr)
 
     def get_action(self, state):
-        state = torch.tensor(state[np.newaxis, :])
+        if isinstance(state, tuple):
+            state = state[0]  # Extract the observation from tuple if needed
+        state = torch.FloatTensor(state).unsqueeze(0)
         probs = self.pi(state)
         probs = probs[0]
         m = Categorical(probs)
@@ -60,13 +62,14 @@ def update(self):
 reward_history = []
 
 for episode in range(3000):
-    state = env.reset()
+    state, _ = env.reset()
     done = False
     total_reward = 0
 
     while not done:
         action, prob = agent.get_action(state)
-        next_state, reward, done, info = env.step(action)
+        next_state, reward, terminated, truncated, info = env.step(action)
+        done = terminated or truncated
 
         agent.add(reward, prob)
         state = next_state
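
Note: the sketch below is not part of the patch. It is a minimal illustration, assuming a standard gym installation, of how the reset/step version handling that the patch repeats in each training script could be factored into one shared helper. The names reset_env and step_env are hypothetical and do not exist in the repository.

# Hypothetical helper module: absorbs the old/new gym API differences in one place.
import gym


def reset_env(env):
    # gym >= 0.26 returns (obs, info); older versions return obs only.
    result = env.reset()
    return result[0] if isinstance(result, tuple) else result


def step_env(env, action):
    # gym >= 0.26 returns (obs, reward, terminated, truncated, info);
    # older versions return (obs, reward, done, info).
    result = env.step(action)
    if len(result) == 5:
        obs, reward, terminated, truncated, info = result
        return obs, reward, terminated or truncated, info
    return result


if __name__ == '__main__':
    # Usage example: a random-action rollout that works under either gym API.
    env = gym.make('CartPole-v0')
    state = reset_env(env)
    done = False
    while not done:
        action = env.action_space.sample()
        state, reward, done, info = step_env(env, action)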