14 changes: 12 additions & 2 deletions pytorch/actor_critic.py
@@ -83,13 +83,23 @@ def update(self, state, action_prob, reward, next_state, done):

for episode in range(2000):
state = env.reset()
if isinstance(state, tuple):
state = state[0]
done = False
total_reward = 0

while not done:
action, prob = agent.get_action(state)
next_state, reward, done, info = env.step(action)

            step_result = env.step(action)

            # Handle different gym versions
            if len(step_result) == 5:  # Newer gym versions (>=0.26.0) return 5 values
                next_state, reward, terminated, truncated, info = step_result
                done = terminated or truncated
            else:  # Older gym: 4 values; info['TimeLimit.truncated'] flags a time-limit stop, which already sets done=True
                next_state, reward, done, info = step_result

agent.update(state, prob, reward, next_state, done)

state = next_state
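The reset/step version handling above is repeated in every training loop touched by this PR. A small shared shim along the following lines could centralize it; this is only a sketch, and the helper names reset_env and step_env are invented here rather than taken from the repository.

# Sketch of a compatibility shim for old/new gym APIs; reset_env and step_env
# are illustrative names, not functions that exist in this codebase.
def reset_env(env):
    result = env.reset()
    if isinstance(result, tuple):  # Gym >= 0.26 returns (observation, info)
        return result[0]
    return result  # older gym returns just the observation

def step_env(env, action):
    result = env.step(action)
    if len(result) == 5:  # Gym >= 0.26: (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = result
        return obs, reward, terminated or truncated, info
    return result  # older gym: (obs, reward, done, info)

Each episode loop could then call state = reset_env(env) and next_state, reward, done, info = step_env(env, action) and behave the same across gym versions.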
17 changes: 13 additions & 4 deletions pytorch/dqn.py
@@ -24,10 +24,10 @@ def __len__(self):
def get_batch(self):
data = random.sample(self.buffer, self.batch_size)

state = torch.tensor(np.stack([x[0] for x in data]))
state = torch.tensor(np.array([x[0] for x in data]), dtype=torch.float32)
action = torch.tensor(np.array([x[1] for x in data]).astype(np.long))
reward = torch.tensor(np.array([x[2] for x in data]).astype(np.float32))
next_state = torch.tensor(np.stack([x[3] for x in data]))
next_state = torch.tensor(np.array([x[3] for x in data]), dtype=torch.float32)
done = torch.tensor(np.array([x[4] for x in data]).astype(np.int32))
return state, action, reward, next_state, done

@@ -64,7 +64,7 @@ def get_action(self, state):
if np.random.rand() < self.epsilon:
return np.random.choice(self.action_size)
else:
state = torch.tensor(state[np.newaxis, :])
state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
qs = self.qnet(state)
return qs.argmax().item()

@@ -102,12 +102,21 @@ def sync_qnet(self):

for episode in range(episodes):
state = env.reset()
if isinstance(state, tuple): # Handle new gym API
state, _ = state
done = False
total_reward = 0

while not done:
action = agent.get_action(state)
next_state, reward, done, info = env.step(action)
step_result = env.step(action)

# Handle different gym versions
if len(step_result) == 5: # Newer gym version (step returns 5 values)
next_state, reward, terminated, truncated, info = step_result
done = terminated or truncated
else: # Older gym version (step returns 4 values)
next_state, reward, done, info = step_result

agent.update(state, action, reward, next_state, done)
state = next_state
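Two notes on the batch construction above: np.long on the unchanged action line has been deprecated and removed in recent NumPy releases (torch.int64 is the safe equivalent), and the explicit dtypes matter because float32 observations match the default parameter dtype of nn.Linear while integer indices are needed to pick Q-values per action. A minimal standalone sketch with toy transitions, not the repository's code (the gather() lookup is just one common way to index Q-values by action):

import numpy as np
import torch

# Two toy (state, action, reward, next_state, done) transitions with 4-d states, as in CartPole
batch = [
    (np.zeros(4, dtype=np.float32), 1, 1.0, np.ones(4, dtype=np.float32), False),
    (np.ones(4, dtype=np.float32), 0, 1.0, np.zeros(4, dtype=np.float32), True),
]

state = torch.tensor(np.array([x[0] for x in batch]), dtype=torch.float32)
action = torch.tensor(np.array([x[1] for x in batch]), dtype=torch.int64)  # gather() requires int64 indices
reward = torch.tensor(np.array([x[2] for x in batch]), dtype=torch.float32)
done = torch.tensor(np.array([x[4] for x in batch]), dtype=torch.float32)

q_values = torch.randn(len(batch), 2)                         # stand-in for qnet(state)
q_taken = q_values.gather(1, action.unsqueeze(1)).squeeze(1)  # Q(s, a) for the actions taken
print(q_taken.shape)  # torch.Size([2])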
5 changes: 4 additions & 1 deletion pytorch/reinforce.py
@@ -30,6 +30,8 @@ def __init__(self):
self.optimizer = optim.Adam(self.pi.parameters(), lr=self.lr)

def get_action(self, state):
if isinstance(state, tuple):
state = state[0] # Extract the actual state from the tuple
state = torch.tensor(state[np.newaxis, :])
probs = self.pi(state)
probs = probs[0]
@@ -64,7 +66,8 @@ def update(self):

while not done:
action, prob = agent.get_action(state)
next_state, reward, done, info = env.step(action)
next_state, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated

agent.add(reward, prob)
state = next_state
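Unlike dqn.py, the loops in reinforce.py (above) and simple_pg.py (below) now unpack five values from env.step unconditionally, so they require Gym >= 0.26; on older gym versions the unpacking raises a ValueError. For reference, a minimal sketch of the episode-loop contract these files now assume (CartPole-v1 is only an example environment, and the random policy is just to exercise the API):

import gym

env = gym.make("CartPole-v1")
state, info = env.reset()               # reset() returns (observation, info)
done = False
while not done:
    action = env.action_space.sample()  # random action stands in for the learned policy
    state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated      # truncation (time limit) also ends the episode
env.close()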
9 changes: 6 additions & 3 deletions pytorch/simple_pg.py
@@ -30,7 +30,9 @@ def __init__(self):
self.optimizer = optim.Adam(self.pi.parameters(), lr=self.lr)

def get_action(self, state):
state = torch.tensor(state[np.newaxis, :])
if isinstance(state, tuple):
state = state[0] # Extract the observation from tuple if needed
state = torch.FloatTensor(state).unsqueeze(0)
probs = self.pi(state)
probs = probs[0]
m = Categorical(probs)
@@ -60,13 +62,14 @@ def update(self):
reward_history = []

for episode in range(3000):
state = env.reset()
state, _ = env.reset()
done = False
total_reward = 0

while not done:
action, prob = agent.get_action(state)
next_state, reward, done, info = env.step(action)
next_state, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated

agent.add(reward, prob)
state = next_state