# RL_Models/Super-Mario-RL/duel_dqn.py
import pickle
import random
import time
from collections import deque
import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace
from wrappers import wrap_mario
def arrange(s):
    # Convert an HWC frame stack into a (1, C, H, W) float batch for the network.
    if not isinstance(s, np.ndarray):
        s = np.array(s)
    assert len(s.shape) == 3
    ret = np.transpose(s, (2, 0, 1))
    return np.expand_dims(ret, 0)
class replay_memory(object):
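    # Fixed-size FIFO replay buffer of (s, r, a, s_prime, not_done) transitions.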
def __init__(self, N):
self.memory = deque(maxlen=N)
def push(self, transition):
self.memory.append(transition)
def sample(self, n):
return random.sample(self.memory, n)
def __len__(self):
return len(self.memory)
class model(nn.Module):
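    # Dueling DQN: shared convolutional trunk with separate advantage (q) and state-value (v) heads.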
def __init__(self, n_frame, n_action, device):
super(model, self).__init__()
self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
self.layer2 = nn.Conv2d(32, 64, 3, 1)
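        # 20736 = 64 * 18 * 18, the flattened conv output assuming wrap_mario yields 84x84 frames.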
self.fc = nn.Linear(20736, 512)
self.q = nn.Linear(512, n_action)
self.v = nn.Linear(512, 1)
self.device = device
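        # This Sequential only groups the layers so init_weights can be applied below;
        # forward() calls the layers directly rather than running the Sequential.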
self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)
self.seq.apply(init_weights)
def forward(self, x):
        if not isinstance(x, torch.Tensor):
            x = torch.FloatTensor(x).to(self.device)
x = torch.relu(self.layer1(x))
x = torch.relu(self.layer2(x))
x = x.view(-1, 20736)
x = torch.relu(self.fc(x))
        adv = self.q(x)
        v = self.v(x)
        # Dueling aggregation: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
        q = v + (adv - adv.mean(-1, keepdim=True))
        return q
def init_weights(m):
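    # Xavier-initialize convolutional layers; linear layers keep PyTorch's default initialization.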
    if isinstance(m, nn.Conv2d):
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def train(q, q_target, memory, batch_size, gamma, optimizer, device):
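    # Sample a minibatch of transitions; each is (s, r, a, s_prime, not_done) as pushed in main().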
s, r, a, s_prime, done = list(map(list, zip(*memory.sample(batch_size))))
s = np.array(s).squeeze()
s_prime = np.array(s_prime).squeeze()
    r = torch.FloatTensor(r).unsqueeze(-1).to(device)
    done = torch.FloatTensor(done).unsqueeze(-1).to(device)
    with torch.no_grad():
        # Double DQN target: the online network picks the greedy next action,
        # the target network evaluates it; `done` is a not-done mask, so the
        # bootstrap term vanishes at terminal states.
        a_max = q(s_prime).max(1)[1].unsqueeze(-1)
        y = r + gamma * q_target(s_prime).gather(1, a_max) * done
a = torch.tensor(a).unsqueeze(-1).to(device)
q_value = torch.gather(q(s), dim=1, index=a.view(-1, 1).long())
    loss = F.smooth_l1_loss(q_value, y)  # default reduction already averages over the batch
optimizer.zero_grad()
loss.backward()
optimizer.step()
    return loss.item()  # return a plain float so callers accumulate numbers, not graph tensors
def copy_weights(q, q_target):
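    # Hard update: copy the online network's weights into the target network.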
q_dict = q.state_dict()
q_target.load_state_dict(q_dict)
def main(env, q, q_target, optimizer, device):
    t = 0                    # total gradient-update steps taken
    gamma = 0.99             # discount factor
    batch_size = 256
    N = 50000                # replay buffer capacity
    eps = 0.001              # fixed epsilon for epsilon-greedy exploration
    memory = replay_memory(N)
    update_interval = 50     # gradient steps between target-network syncs and checkpoints
    print_interval = 10      # episodes between logging and score dumps
score_lst = []
total_score = 0.0
loss = 0.0
start_time = time.perf_counter()
for k in range(1000000):
s = arrange(env.reset())
done = False
while not done:
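            # Epsilon-greedy action selection with the online network.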
if eps > np.random.rand():
a = env.action_space.sample()
else:
if device == "cpu":
a = np.argmax(q(s).detach().numpy())
else:
a = np.argmax(q(s).cpu().detach().numpy())
s_prime, r, done, _ = env.step(a)
s_prime = arrange(s_prime)
total_score += r
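            # Reward shaping: compress the raw reward's magnitude while keeping its sign, plus a small linear term.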
r = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
memory.push((s, float(r), int(a), s_prime, int(1 - done)))
s = s_prime
stage = env.unwrapped._stage
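            # Train once the replay buffer holds at least 2000 transitions; every update_interval
            # gradient steps, sync the target network and checkpoint both models.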
if len(memory) > 2000:
loss += train(q, q_target, memory, batch_size, gamma, optimizer, device)
t += 1
if t % update_interval == 0:
copy_weights(q, q_target)
torch.save(q.state_dict(), "mario_q.pth")
torch.save(q_target.state_dict(), "mario_q_target.pth")
if k % print_interval == 0:
time_spent, start_time = (
time.perf_counter() - start_time,
time.perf_counter(),
)
print(
"%s |Epoch : %d | score : %f | loss : %.2f | stage : %d | time spent: %f"
% (
device,
k,
total_score / print_interval,
loss / print_interval,
stage,
time_spent,
)
)
score_lst.append(total_score / print_interval)
total_score = 0
loss = 0.0
            with open("score.p", "wb") as f:
                pickle.dump(score_lst, f)
if __name__ == "__main__":
n_frame = 4
env = gym_super_mario_bros.make("SuperMarioBros-v0")
env = JoypadSpace(env, COMPLEX_MOVEMENT)
env = wrap_mario(env)
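    # Prefer CUDA, then Apple MPS, otherwise fall back to the CPU.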
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
q = model(n_frame, env.action_space.n, device).to(device)
q_target = model(n_frame, env.action_space.n, device).to(device)
optimizer = optim.Adam(q.parameters(), lr=0.0001)
print(device)
main(env, q, q_target, optimizer, device)