Upload 32 files
- .gitattributes +4 -0
- Super-Mario-RL-PyQt5/app.py +673 -0
- Super-Mario-RL-PyQt5/app_2.py +676 -0
- Super-Mario-RL-PyQt5/enhanced_mario_q_best.pth +3 -0
- Super-Mario-RL-PyQt5/enhanced_mario_q_target_best.pth +3 -0
- Super-Mario-RL-PyQt5/requirements.txt +9 -0
- Super-Mario-RL-PyQt5/score.p +0 -0
- Super-Mario-RL/README.md +85 -0
- Super-Mario-RL/__pycache__/wrappers.cpython-313.pyc +0 -0
- Super-Mario-RL/duel_dqn.py +178 -0
- Super-Mario-RL/duel_dqn_2.py +237 -0
- Super-Mario-RL/enhanced_duel_dqn.py +257 -0
- Super-Mario-RL/enhanced_mario_q.pth +3 -0
- Super-Mario-RL/enhanced_mario_q_best.pth +3 -0
- Super-Mario-RL/enhanced_mario_q_target.pth +3 -0
- Super-Mario-RL/enhanced_mario_q_target_best.pth +3 -0
- Super-Mario-RL/eval.py +111 -0
- Super-Mario-RL/mario1.gif +3 -0
- Super-Mario-RL/mario1.mp4 +3 -0
- Super-Mario-RL/mario14.gif +3 -0
- Super-Mario-RL/mario14.mp4 +3 -0
- Super-Mario-RL/mario_q.pth +3 -0
- Super-Mario-RL/mario_q_best.pth +3 -0
- Super-Mario-RL/mario_q_target.pth +3 -0
- Super-Mario-RL/mario_q_target_best.pth +3 -0
- Super-Mario-RL/ppo.py +272 -0
- Super-Mario-RL/requirements.txt +8 -0
- Super-Mario-RL/score.p +0 -0
- Super-Mario-RL/terminal.txt +5 -0
- Super-Mario-RL/wrappers.py +361 -0
- ale_pyqt5/app.py +514 -0
- ale_pyqt5/app_2.py +559 -0
- ale_pyqt5/installed_packages_ale_py.txt +30 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+Super-Mario-RL/mario1.gif filter=lfs diff=lfs merge=lfs -text
+Super-Mario-RL/mario1.mp4 filter=lfs diff=lfs merge=lfs -text
+Super-Mario-RL/mario14.gif filter=lfs diff=lfs merge=lfs -text
+Super-Mario-RL/mario14.mp4 filter=lfs diff=lfs merge=lfs -text
Super-Mario-RL-PyQt5/app.py
ADDED
@@ -0,0 +1,673 @@
import pickle
import random
import time
from collections import deque

import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace

from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                             QHBoxLayout, QPushButton, QLabel, QComboBox,
                             QTextEdit, QProgressBar, QTabWidget, QFrame, QGroupBox)
from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap, QFont
import sys
import cv2

# Import your wrappers (make sure this module exists)
try:
    from wrappers import *
except ImportError:
    # Create a proper wrapper if the module doesn't exist
    class SimpleWrapper:
        def __init__(self, env):
            self.env = env
            self.action_space = env.action_space
            self.observation_space = env.observation_space

        def reset(self):
            return self.env.reset()

        def step(self, action):
            return self.env.step(action)

        def render(self, mode='rgb_array'):
            return self.env.render(mode)

        def close(self):
            if hasattr(self.env, 'close'):
                self.env.close()

    def wrap_mario(env):
        return SimpleWrapper(env)


class FrameStacker:
    """Handles frame stacking and preprocessing"""
    def __init__(self, frame_size=(84, 84), stack_size=4):
        self.frame_size = frame_size
        self.stack_size = stack_size
        self.frames = deque(maxlen=stack_size)

    def preprocess_frame(self, frame):
        """Convert frame to grayscale and resize"""
        # Convert to grayscale
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # Resize to 84x84
        resized = cv2.resize(gray, self.frame_size, interpolation=cv2.INTER_AREA)
        # Normalize to [0, 1]
        normalized = resized.astype(np.float32) / 255.0
        return normalized

    def reset(self, frame):
        """Reset frame stack with initial frame"""
        self.frames.clear()
        processed_frame = self.preprocess_frame(frame)
        for _ in range(self.stack_size):
            self.frames.append(processed_frame)
        return self.get_stacked_frames()

    def append(self, frame):
        """Add new frame to stack"""
        processed_frame = self.preprocess_frame(frame)
        self.frames.append(processed_frame)
        return self.get_stacked_frames()

    def get_stacked_frames(self):
        """Get stacked frames as numpy array"""
        stacked = np.array(self.frames)
        return np.ascontiguousarray(stacked)


class replay_memory(object):
    def __init__(self, N):
        self.memory = deque(maxlen=N)

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, n):
        return random.sample(self.memory, n)

    def __len__(self):
        return len(self.memory)


class DuelingDQNModel(nn.Module):
    def __init__(self, n_frame, n_action, device):
        super(DuelingDQNModel, self).__init__()

        # CNN layers for feature extraction
        self.conv_layers = nn.Sequential(
            nn.Conv2d(n_frame, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        # Calculate conv output size
        self.conv_out_size = self._get_conv_out((n_frame, 84, 84))

        # Advantage stream
        self.advantage_stream = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_action)
        )

        # Value stream
        self.value_stream = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

        self.device = device
        self.apply(self.init_weights)

    def _get_conv_out(self, shape):
        with torch.no_grad():
            x = torch.zeros(1, *shape)
            x = self.conv_layers(x)
            return int(np.prod(x.size()))

    def init_weights(self, m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    def forward(self, x):
        if not isinstance(x, torch.Tensor):
            x = torch.FloatTensor(x).to(self.device)

        # Forward through conv layers
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)

        # Forward through advantage and value streams
        advantage = self.advantage_stream(x)
        value = self.value_stream(x)

        # Combine value and advantage
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))

        return q_values


def train(q, q_target, memory, batch_size, gamma, optimizer, device):
    if len(memory) < batch_size:
        return 0.0

    transitions = memory.sample(batch_size)
    s, r, a, s_prime, done = list(map(list, zip(*transitions)))

    # Ensure positive strides for all arrays
    s = np.array([np.ascontiguousarray(arr) for arr in s])
    s_prime = np.array([np.ascontiguousarray(arr) for arr in s_prime])

    # Move computations to device
    s_tensor = torch.FloatTensor(s).to(device)
    s_prime_tensor = torch.FloatTensor(s_prime).to(device)

    # Get next Q values from target network
    with torch.no_grad():
        next_q_values = q_target(s_prime_tensor)
        next_actions = next_q_values.max(1)[1].unsqueeze(1)
        next_q_value = next_q_values.gather(1, next_actions)

    # Calculate target Q values
    r = torch.FloatTensor(r).unsqueeze(1).to(device)
    done = torch.FloatTensor(done).unsqueeze(1).to(device)
    target_q_values = r + gamma * next_q_value * (1 - done)

    # Get current Q values
    a_tensor = torch.LongTensor(a).unsqueeze(1).to(device)
    current_q_values = q(s_tensor).gather(1, a_tensor)

    # Calculate loss
    loss = F.smooth_l1_loss(current_q_values, target_q_values)

    # Optimize
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0)

    optimizer.step()
    return loss.item()


def copy_weights(q, q_target):
    q_dict = q.state_dict()
    q_target.load_state_dict(q_dict)


class MarioTrainingThread(QThread):
    update_signal = pyqtSignal(dict)
    frame_signal = pyqtSignal(np.ndarray)

    def __init__(self, device="cpu"):
        super().__init__()
        self.device = device
        self.running = False
        self.env = None
        self.q = None
        self.q_target = None
        self.optimizer = None
        self.frame_stacker = None

        # Training parameters
        self.gamma = 0.99
        self.batch_size = 32
        self.memory_size = 10000
        self.eps = 1.0  # Start with full exploration
        self.eps_min = 0.01
        self.eps_decay = 0.995
        self.update_interval = 1000
        self.save_interval = 100
        self.print_interval = 10

        self.memory = None
        self.t = 0
        self.k = 0
        self.total_score = 0.0
        self.loss_accumulator = 0.0
        self.best_score = -float('inf')
        self.last_x_pos = 0

    def setup_training(self):
        n_frame = 4  # Number of stacked frames
        try:
            self.env = gym_super_mario_bros.make("SuperMarioBros-v3")
            self.env = JoypadSpace(self.env, COMPLEX_MOVEMENT)
            self.env = wrap_mario(self.env)

            # Initialize frame stacker
            self.frame_stacker = FrameStacker(frame_size=(84, 84), stack_size=n_frame)

            self.q = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
            self.q_target = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)

            copy_weights(self.q, self.q_target)

            # Set target network to eval mode
            self.q_target.eval()

            # Optimizer
            self.optimizer = optim.Adam(self.q.parameters(), lr=0.0001, weight_decay=1e-5)

            self.memory = replay_memory(self.memory_size)

            self.log_message(f"✅ Training setup complete - Actions: {self.env.action_space.n}, Device: {self.device}")

        except Exception as e:
            self.log_message(f"❌ Error setting up training: {e}")
            import traceback
            traceback.print_exc()
            self.running = False

    def run(self):
        self.running = True
        self.setup_training()

        if not self.running:
            return

        start_time = time.perf_counter()
        score_lst = []

        try:
            for k in range(1000000):
                if not self.running:
                    break

                # Reset environment and frame stacker
                frame = self.env.reset()
                s = self.frame_stacker.reset(frame)
                done = False
                episode_loss = 0.0
                episode_steps = 0
                episode_score = 0.0
                self.last_x_pos = 0

                while not done and self.running:
                    # Ensure state has positive strides before processing
                    s_processed = np.ascontiguousarray(s)

                    # Epsilon-greedy action selection
                    if np.random.random() <= self.eps:
                        a = self.env.action_space.sample()
                    else:
                        with torch.no_grad():
                            # Add batch dimension and create tensor
                            state_tensor = torch.FloatTensor(s_processed).unsqueeze(0).to(self.device)
                            q_values = self.q(state_tensor)

                            if self.device == "cuda" or self.device == "mps":
                                a = np.argmax(q_values.cpu().numpy())
                            else:
                                a = np.argmax(q_values.detach().numpy())

                    # Take action
                    frame, r, done, info = self.env.step(a)

                    # Update frame stack
                    s_prime = self.frame_stacker.append(frame)

                    episode_score += r

                    # Enhanced reward shaping
                    reward = r  # Start with original reward

                    # Bonus for x_pos progress
                    if 'x_pos' in info:
                        x_pos = info['x_pos']
                        x_progress = x_pos - self.last_x_pos
                        if x_progress > 0:
                            reward += 0.1 * x_progress
                        self.last_x_pos = x_pos

                    # Large bonus for completing the level
                    if done and info.get('flag_get', False):
                        reward += 100.0
                        self.log_message(f"🏁 LEVEL COMPLETED at episode {k}! 🎉")

                    # Store transition with contiguous arrays
                    s_contiguous = np.ascontiguousarray(s)
                    s_prime_contiguous = np.ascontiguousarray(s_prime)
                    self.memory.push((s_contiguous, float(reward), int(a), s_prime_contiguous, int(1 - done)))

                    s = s_prime
                    stage = info.get('stage', 1)
                    world = info.get('world', 1)

                    # Emit frame for display
                    try:
                        display_frame = self.env.render()
                        if display_frame is not None:
                            # Ensure frame has positive strides
                            frame_contiguous = np.ascontiguousarray(display_frame)
                            self.frame_signal.emit(frame_contiguous)
                    except Exception as e:
                        # Create a placeholder frame if rendering fails
                        frame = np.zeros((240, 256, 3), dtype=np.uint8)
                        self.frame_signal.emit(frame)

                    # Train only if we have enough samples
                    if len(self.memory) > self.batch_size:
                        loss_val = train(self.q, self.q_target, self.memory, self.batch_size,
                                         self.gamma, self.optimizer, self.device)
                        if loss_val > 0:
                            self.loss_accumulator += loss_val
                            episode_loss += loss_val
                        self.t += 1

                        # Update target network
                        if self.t % self.update_interval == 0:
                            copy_weights(self.q, self.q_target)

                    episode_steps += 1

                    # Emit training progress every 10 steps
                    if episode_steps % 10 == 0:
                        progress_data = {
                            'episode': k,
                            'total_reward': episode_score,
                            'steps': episode_steps,
                            'epsilon': self.eps,
                            'world': world,
                            'stage': stage,
                            'loss': episode_loss / (episode_steps + 1e-8),
                            'memory_size': len(self.memory),
                            'x_pos': info.get('x_pos', 0),
                            'score': info.get('score', 0),
                            'coins': info.get('coins', 0),
                            'time': info.get('time', 400),
                            'flag_get': info.get('flag_get', False)
                        }
                        self.update_signal.emit(progress_data)

                # Epsilon decay after each episode
                if self.eps > self.eps_min:
                    self.eps *= self.eps_decay

                # Update total score
                self.total_score += episode_score

                # Save best model
                if episode_score > self.best_score and k > 0:
                    self.best_score = episode_score
                    torch.save(self.q.state_dict(), "enhanced_mario_q_best.pth")
                    torch.save(self.q_target.state_dict(), "enhanced_mario_q_target_best.pth")
                    self.log_message(f"💾 New best model saved! Score: {self.best_score:.2f}")

                # Save models periodically
                if k % self.save_interval == 0 and k > 0:
                    torch.save(self.q.state_dict(), "enhanced_mario_q.pth")
                    torch.save(self.q_target.state_dict(), "enhanced_mario_q_target.pth")
                    self.log_message(f"💾 Models saved at episode {k}")

                # Print progress
                if k % self.print_interval == 0 and k > 0:
                    time_spent = time.perf_counter() - start_time
                    start_time = time.perf_counter()

                    avg_loss = self.loss_accumulator / (self.print_interval * max(episode_steps, 1))
                    avg_score = self.total_score / self.print_interval

                    log_msg = (
                        f"{self.device} | Ep: {k} | Score: {avg_score:.2f} | Loss: {avg_loss:.4f} | "
                        f"Stage: {world}-{stage} | Eps: {self.eps:.3f} | Time: {time_spent:.2f}s | "
                        f"Mem: {len(self.memory)} | Steps: {episode_steps}"
                    )
                    self.log_message(log_msg)

                    score_lst.append(avg_score)
                    self.total_score = 0.0
                    self.loss_accumulator = 0.0

                    try:
                        pickle.dump(score_lst, open("score.p", "wb"))
                    except Exception as e:
                        self.log_message(f"⚠️ Could not save scores: {e}")

                self.k = k

        except Exception as e:
            self.log_message(f"❌ Training error: {e}")
            import traceback
            traceback.print_exc()

    def log_message(self, message):
        progress_data = {
            'log_message': message
        }
        self.update_signal.emit(progress_data)

    def stop(self):
        self.running = False
        if self.env:
            try:
                self.env.close()
            except:
                pass


class MarioRLApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.training_thread = None
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle('🎮 Super Mario Bros - Dueling DQN Training')
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        # Title
        title = QLabel('🎮 Super Mario Bros - Enhanced Dueling DQN')
        title.setFont(QFont('Arial', 16, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Control Panel
        control_layout = QHBoxLayout()

        self.device_combo = QComboBox()
        self.device_combo.addItems(['cpu', 'cuda', 'mps'])

        self.start_btn = QPushButton('Start Training')
        self.start_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton('Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        self.load_btn = QPushButton('Load Model')
        self.load_btn.clicked.connect(self.load_model)

        control_layout.addWidget(QLabel('Device:'))
        control_layout.addWidget(self.device_combo)
        control_layout.addWidget(self.start_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addWidget(self.load_btn)
        control_layout.addStretch()

        layout.addLayout(control_layout)

        # Content Area
        content_layout = QHBoxLayout()

        # Left side - Game Display
        left_frame = QFrame()
        left_frame.setFrameStyle(QFrame.Box)
        left_layout = QVBoxLayout(left_frame)

        self.game_display = QLabel()
        self.game_display.setMinimumSize(400, 300)
        self.game_display.setAlignment(Qt.AlignCenter)
        self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
        self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')

        left_layout.addWidget(QLabel('Mario Game Display:'))
        left_layout.addWidget(self.game_display)

        # Right side - Training Info
        right_frame = QFrame()
        right_frame.setFrameStyle(QFrame.Box)
        right_layout = QVBoxLayout(right_frame)

        # Training stats
        stats_group = QGroupBox("Training Statistics")
        stats_layout = QVBoxLayout(stats_group)

        self.episode_label = QLabel('Episode: 0')
        self.world_label = QLabel('World: 1-1')
        self.score_label = QLabel('Score: 0')
        self.reward_label = QLabel('Episode Reward: 0')
        self.steps_label = QLabel('Steps: 0')
        self.epsilon_label = QLabel('Epsilon: 1.000')
        self.loss_label = QLabel('Loss: 0.0000')
        self.memory_label = QLabel('Memory: 0')
        self.xpos_label = QLabel('X Position: 0')
        self.coins_label = QLabel('Coins: 0')
        self.time_label = QLabel('Time: 400')
        self.flag_label = QLabel('Flag: No')

        stats_layout.addWidget(self.episode_label)
        stats_layout.addWidget(self.world_label)
        stats_layout.addWidget(self.score_label)
        stats_layout.addWidget(self.reward_label)
        stats_layout.addWidget(self.steps_label)
        stats_layout.addWidget(self.epsilon_label)
        stats_layout.addWidget(self.loss_label)
        stats_layout.addWidget(self.memory_label)
        stats_layout.addWidget(self.xpos_label)
        stats_layout.addWidget(self.coins_label)
        stats_layout.addWidget(self.time_label)
        stats_layout.addWidget(self.flag_label)

        right_layout.addWidget(stats_group)

        # Training log
        right_layout.addWidget(QLabel('Training Log:'))
        self.log_text = QTextEdit()
        self.log_text.setMaximumHeight(300)
        right_layout.addWidget(self.log_text)

        content_layout.addWidget(left_frame)
        content_layout.addWidget(right_frame)
        layout.addLayout(content_layout)

    def start_training(self):
        device = self.device_combo.currentText()

        # Check device availability
        if device == "cuda" and not torch.cuda.is_available():
            self.log_text.append("❌ CUDA not available, using CPU instead")
            device = "cpu"
        elif device == "mps" and not torch.backends.mps.is_available():
            self.log_text.append("❌ MPS not available, using CPU instead")
            device = "cpu"

        self.training_thread = MarioTrainingThread(device)
        self.training_thread.update_signal.connect(self.update_training_info)
        self.training_thread.frame_signal.connect(self.update_game_display)
        self.training_thread.start()

        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

        self.log_text.append(f'🚀 Started Dueling DQN training on {device}...')

    def stop_training(self):
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()

        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.log_text.append('⏹️ Training stopped.')

    def load_model(self):
        # Placeholder for model loading functionality
        self.log_text.append('📁 Load model functionality not implemented yet')

    def update_training_info(self, data):
        if 'episode' in data:
            self.episode_label.setText(f'Episode: {data["episode"]}')
        if 'world' in data and 'stage' in data:
            self.world_label.setText(f'World: {data["world"]}-{data["stage"]}')
        if 'score' in data:
            self.score_label.setText(f'Score: {data["score"]}')
        if 'total_reward' in data:
            self.reward_label.setText(f'Episode Reward: {data["total_reward"]:.2f}')
        if 'steps' in data:
            self.steps_label.setText(f'Steps: {data["steps"]}')
        if 'epsilon' in data:
            self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
        if 'loss' in data:
            self.loss_label.setText(f'Loss: {data["loss"]:.4f}')
        if 'memory_size' in data:
            self.memory_label.setText(f'Memory: {data["memory_size"]}')
        if 'x_pos' in data:
            self.xpos_label.setText(f'X Position: {data["x_pos"]}')
        if 'coins' in data:
            self.coins_label.setText(f'Coins: {data["coins"]}')
        if 'time' in data:
            self.time_label.setText(f'Time: {data["time"]}')
        if 'flag_get' in data:
            flag_text = "Yes" if data["flag_get"] else "No"
            self.flag_label.setText(f'Flag: {flag_text}')
        if 'log_message' in data:
            self.log_text.append(data['log_message'])
            # Auto-scroll to bottom
            self.log_text.verticalScrollBar().setValue(
                self.log_text.verticalScrollBar().maximum()
            )

    def update_game_display(self, frame):
        if frame is not None:
            try:
                h, w, ch = frame.shape
                bytes_per_line = ch * w
                # Ensure contiguous array
                frame_contiguous = np.ascontiguousarray(frame)
                q_img = QImage(frame_contiguous.data, w, h, bytes_per_line, QImage.Format_RGB888)
                pixmap = QPixmap.fromImage(q_img)
                self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
            except Exception as e:
                print(f"Error updating display: {e}")

    def closeEvent(self, event):
        self.stop_training()
        event.accept()


def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    app = QApplication(sys.argv)
    window = MarioRLApp()
    window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()
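For reference, a minimal evaluation sketch (not part of this commit) showing how the `enhanced_mario_q_best.pth` checkpoint saved by the file above could be loaded back into `DuelingDQNModel` and run greedily for one episode. It assumes the file above is importable as `app`, that the checkpoint sits in the working directory, and it mirrors the old-gym 4-tuple `step()` API used in the code above; the file name `eval_sketch.py` is hypothetical.

# eval_sketch.py -- illustrative only
import gym_super_mario_bros
import numpy as np
import torch
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace

from app import DuelingDQNModel, FrameStacker, wrap_mario  # defined in app.py above

device = "cpu"
env = wrap_mario(JoypadSpace(gym_super_mario_bros.make("SuperMarioBros-v3"), COMPLEX_MOVEMENT))
stacker = FrameStacker(frame_size=(84, 84), stack_size=4)

# Rebuild the network with the same input shape used during training, then load weights
q = DuelingDQNModel(4, env.action_space.n, device).to(device)
q.load_state_dict(torch.load("enhanced_mario_q_best.pth", map_location=device))
q.eval()

s = stacker.reset(env.reset())
done, total = False, 0.0
while not done:
    with torch.no_grad():
        q_values = q(torch.FloatTensor(np.ascontiguousarray(s)).unsqueeze(0).to(device))
    a = int(q_values.argmax(dim=1).item())  # greedy action, no epsilon
    frame, r, done, info = env.step(a)
    s = stacker.append(frame)
    total += r
print(f"episode reward: {total}")
env.close()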
Super-Mario-RL-PyQt5/app_2.py
ADDED
@@ -0,0 +1,676 @@
import pickle
import random
import time
from collections import deque

import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace

from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                             QHBoxLayout, QPushButton, QLabel, QComboBox,
                             QTextEdit, QProgressBar, QTabWidget, QFrame, QGroupBox)
from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap, QFont
import sys
import cv2

# Import your wrappers (make sure this module exists)
try:
    from wrappers import *
except ImportError:
    # Create a proper wrapper if the module doesn't exist
    class SimpleWrapper:
        def __init__(self, env):
            self.env = env
            self.action_space = env.action_space
            self.observation_space = env.observation_space

        def reset(self):
            return self.env.reset()

        def step(self, action):
            return self.env.step(action)

        def render(self, mode='rgb_array'):
            return self.env.render(mode)

        def close(self):
            if hasattr(self.env, 'close'):
                self.env.close()

    def wrap_mario(env):
        return SimpleWrapper(env)


class FrameStacker:
    """Handles frame stacking and preprocessing for the neural network"""
    def __init__(self, frame_size=(84, 84), stack_size=4):
        self.frame_size = frame_size
        self.stack_size = stack_size
        self.frames = deque(maxlen=stack_size)

    def preprocess_frame(self, frame):
        """Convert frame to grayscale and resize for the neural network"""
        # Convert to grayscale
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # Resize to 84x84
        resized = cv2.resize(gray, self.frame_size, interpolation=cv2.INTER_AREA)
        # Normalize to [0, 1]
        normalized = resized.astype(np.float32) / 255.0
        return normalized

    def reset(self, frame):
        """Reset frame stack with initial frame"""
        self.frames.clear()
        processed_frame = self.preprocess_frame(frame)
        for _ in range(self.stack_size):
            self.frames.append(processed_frame)
        return self.get_stacked_frames()

    def append(self, frame):
        """Add new frame to stack"""
        processed_frame = self.preprocess_frame(frame)
        self.frames.append(processed_frame)
        return self.get_stacked_frames()

    def get_stacked_frames(self):
        """Get stacked frames as numpy array for the neural network"""
        stacked = np.array(self.frames)
        return np.ascontiguousarray(stacked)


class replay_memory(object):
    def __init__(self, N):
        self.memory = deque(maxlen=N)

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, n):
        return random.sample(self.memory, n)

    def __len__(self):
        return len(self.memory)


class DuelingDQNModel(nn.Module):
    def __init__(self, n_frame, n_action, device):
        super(DuelingDQNModel, self).__init__()

        # CNN layers for feature extraction
        self.conv_layers = nn.Sequential(
            nn.Conv2d(n_frame, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        # Calculate conv output size
        self.conv_out_size = self._get_conv_out((n_frame, 84, 84))

        # Advantage stream
        self.advantage_stream = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_action)
        )

        # Value stream
        self.value_stream = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

        self.device = device
        self.apply(self.init_weights)

    def _get_conv_out(self, shape):
        with torch.no_grad():
            x = torch.zeros(1, *shape)
            x = self.conv_layers(x)
            return int(np.prod(x.size()))

    def init_weights(self, m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    def forward(self, x):
        if not isinstance(x, torch.Tensor):
            x = torch.FloatTensor(x).to(self.device)

        # Forward through conv layers
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)

        # Forward through advantage and value streams
        advantage = self.advantage_stream(x)
        value = self.value_stream(x)

        # Combine value and advantage
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))

        return q_values


def train(q, q_target, memory, batch_size, gamma, optimizer, device):
    if len(memory) < batch_size:
        return 0.0

    transitions = memory.sample(batch_size)
    s, r, a, s_prime, done = list(map(list, zip(*transitions)))

    # Ensure positive strides for all arrays
    s = np.array([np.ascontiguousarray(arr) for arr in s])
    s_prime = np.array([np.ascontiguousarray(arr) for arr in s_prime])

    # Move computations to device
    s_tensor = torch.FloatTensor(s).to(device)
    s_prime_tensor = torch.FloatTensor(s_prime).to(device)

    # Get next Q values from target network
    with torch.no_grad():
        next_q_values = q_target(s_prime_tensor)
        next_actions = next_q_values.max(1)[1].unsqueeze(1)
        next_q_value = next_q_values.gather(1, next_actions)

    # Calculate target Q values
    r = torch.FloatTensor(r).unsqueeze(1).to(device)
    done = torch.FloatTensor(done).unsqueeze(1).to(device)
    target_q_values = r + gamma * next_q_value * (1 - done)

    # Get current Q values
    a_tensor = torch.LongTensor(a).unsqueeze(1).to(device)
    current_q_values = q(s_tensor).gather(1, a_tensor)

    # Calculate loss
    loss = F.smooth_l1_loss(current_q_values, target_q_values)

    # Optimize
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0)

    optimizer.step()
    return loss.item()


def copy_weights(q, q_target):
    q_dict = q.state_dict()
    q_target.load_state_dict(q_dict)


class MarioTrainingThread(QThread):
    update_signal = pyqtSignal(dict)
    frame_signal = pyqtSignal(np.ndarray)

    def __init__(self, device="cpu"):
        super().__init__()
        self.device = device
        self.running = False
        self.env = None
        self.q = None
        self.q_target = None
        self.optimizer = None
        self.frame_stacker = None

        # Training parameters
        self.gamma = 0.99
        self.batch_size = 32
        self.memory_size = 10000
        self.eps = 1.0  # Start with full exploration
        self.eps_min = 0.01
        self.eps_decay = 0.995
        self.update_interval = 1000
        self.save_interval = 100
        self.print_interval = 10

        self.memory = None
        self.t = 0
        self.k = 0
        self.total_score = 0.0
        self.loss_accumulator = 0.0
        self.best_score = -float('inf')
        self.last_x_pos = 0

    def setup_training(self):
        n_frame = 4  # Number of stacked frames
        try:
            self.env = gym_super_mario_bros.make("SuperMarioBros-v3")
            self.env = JoypadSpace(self.env, COMPLEX_MOVEMENT)
            self.env = wrap_mario(self.env)

            # Initialize frame stacker
            self.frame_stacker = FrameStacker(frame_size=(84, 84), stack_size=n_frame)

            self.q = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
            self.q_target = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)

            copy_weights(self.q, self.q_target)

            # Set target network to eval mode
            self.q_target.eval()

            # Optimizer
            self.optimizer = optim.Adam(self.q.parameters(), lr=0.0001, weight_decay=1e-5)

            self.memory = replay_memory(self.memory_size)

            self.log_message(f"✅ Training setup complete - Actions: {self.env.action_space.n}, Device: {self.device}")

        except Exception as e:
            self.log_message(f"❌ Error setting up training: {e}")
            import traceback
            traceback.print_exc()
            self.running = False

    def run(self):
        self.running = True
        self.setup_training()

        if not self.running:
            return

        start_time = time.perf_counter()
        score_lst = []

        try:
            for k in range(1000000):
                if not self.running:
                    break

                # Reset environment and frame stacker
                frame = self.env.reset()
                s = self.frame_stacker.reset(frame)
                done = False
                episode_loss = 0.0
                episode_steps = 0
                episode_score = 0.0
                self.last_x_pos = 0

                while not done and self.running:
                    # Ensure state has positive strides before processing
                    s_processed = np.ascontiguousarray(s)

                    # Epsilon-greedy action selection
                    if np.random.random() <= self.eps:
                        a = self.env.action_space.sample()
                    else:
                        with torch.no_grad():
                            # Add batch dimension and create tensor
                            state_tensor = torch.FloatTensor(s_processed).unsqueeze(0).to(self.device)
                            q_values = self.q(state_tensor)

                            if self.device == "cuda" or self.device == "mps":
                                a = np.argmax(q_values.cpu().numpy())
                            else:
                                a = np.argmax(q_values.detach().numpy())

                    # Take action
                    next_frame, r, done, info = self.env.step(a)

                    # Update frame stack for neural network
                    s_prime = self.frame_stacker.append(next_frame)

                    episode_score += r

                    # Enhanced reward shaping
                    reward = r  # Start with original reward

                    # Bonus for x_pos progress
                    if 'x_pos' in info:
                        x_pos = info['x_pos']
                        x_progress = x_pos - self.last_x_pos
                        if x_progress > 0:
                            reward += 0.1 * x_progress
                        self.last_x_pos = x_pos

                    # Large bonus for completing the level
                    if done and info.get('flag_get', False):
                        reward += 100.0
                        self.log_message(f"🏁 LEVEL COMPLETED at episode {k}! 🎉")

                    # Store transition with contiguous arrays
                    s_contiguous = np.ascontiguousarray(s)
                    s_prime_contiguous = np.ascontiguousarray(s_prime)
                    self.memory.push((s_contiguous, float(reward), int(a), s_prime_contiguous, int(1 - done)))

                    s = s_prime
                    stage = info.get('stage', 1)
                    world = info.get('world', 1)

                    # Emit ORIGINAL COLOR FRAME for display (not preprocessed)
                    try:
                        # Get the original color frame for display
                        display_frame = self.env.render()
                        if display_frame is not None:
                            # Ensure frame has positive strides and emit original color frame
                            frame_contiguous = np.ascontiguousarray(display_frame)
                            self.frame_signal.emit(frame_contiguous)
                    except Exception as e:
                        # Create a placeholder frame if rendering fails
                        frame = np.zeros((240, 256, 3), dtype=np.uint8)
                        self.frame_signal.emit(frame)

                    # Train only if we have enough samples
                    if len(self.memory) > self.batch_size:
                        loss_val = train(self.q, self.q_target, self.memory, self.batch_size,
                                         self.gamma, self.optimizer, self.device)
                        if loss_val > 0:
                            self.loss_accumulator += loss_val
                            episode_loss += loss_val
                        self.t += 1

                        # Update target network
                        if self.t % self.update_interval == 0:
                            copy_weights(self.q, self.q_target)
                            self.log_message(f"🔄 Target network updated at step {self.t}")

                    episode_steps += 1

                    # Emit training progress every 5 steps for more frequent updates
                    if episode_steps % 5 == 0:
                        progress_data = {
                            'episode': k,
                            'total_reward': episode_score,
                            'steps': episode_steps,
                            'epsilon': self.eps,
                            'world': world,
                            'stage': stage,
                            'loss': episode_loss / (episode_steps + 1e-8),
                            'memory_size': len(self.memory),
                            'x_pos': info.get('x_pos', 0),
                            'score': info.get('score', 0),
                            'coins': info.get('coins', 0),
                            'time': info.get('time', 400),
                            'flag_get': info.get('flag_get', False)
                        }
                        self.update_signal.emit(progress_data)

                # Epsilon decay after each episode
                if self.eps > self.eps_min:
                    self.eps *= self.eps_decay

                # Update total score
                self.total_score += episode_score

                # Save best model
                if episode_score > self.best_score and k > 0:
                    self.best_score = episode_score
                    torch.save(self.q.state_dict(), "enhanced_mario_q_best.pth")
                    torch.save(self.q_target.state_dict(), "enhanced_mario_q_target_best.pth")
                    self.log_message(f"💾 New best model saved! Score: {self.best_score:.2f}")

                # Save models periodically
                if k % self.save_interval == 0 and k > 0:
                    torch.save(self.q.state_dict(), "enhanced_mario_q.pth")
                    torch.save(self.q_target.state_dict(), "enhanced_mario_q_target.pth")
                    self.log_message(f"💾 Models saved at episode {k}")

                # Print progress
                if k % self.print_interval == 0 and k > 0:
                    time_spent = time.perf_counter() - start_time
                    start_time = time.perf_counter()

                    avg_loss = self.loss_accumulator / (self.print_interval * max(episode_steps, 1))
                    avg_score = self.total_score / self.print_interval

                    log_msg = (
                        f"{self.device} | Ep: {k} | Score: {avg_score:.2f} | Loss: {avg_loss:.4f} | "
                        f"Stage: {world}-{stage} | Eps: {self.eps:.3f} | Time: {time_spent:.2f}s | "
                        f"Mem: {len(self.memory)} | Steps: {episode_steps}"
                    )
                    self.log_message(log_msg)

                    score_lst.append(avg_score)
                    self.total_score = 0.0
                    self.loss_accumulator = 0.0

                    try:
                        pickle.dump(score_lst, open("score.p", "wb"))
                    except Exception as e:
                        self.log_message(f"⚠️ Could not save scores: {e}")

                self.k = k

        except Exception as e:
            self.log_message(f"❌ Training error: {e}")
            import traceback
            traceback.print_exc()

    def log_message(self, message):
        progress_data = {
            'log_message': message
        }
        self.update_signal.emit(progress_data)

    def stop(self):
        self.running = False
        if self.env:
            try:
                self.env.close()
            except:
                pass


class MarioRLApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.training_thread = None
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle('🎮 Super Mario Bros - Dueling DQN Training')
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        # Title
        title = QLabel('🎮 Super Mario Bros - Enhanced Dueling DQN')
        title.setFont(QFont('Arial', 16, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Control Panel
        control_layout = QHBoxLayout()

        self.device_combo = QComboBox()
        self.device_combo.addItems(['cpu', 'cuda', 'mps'])

        self.start_btn = QPushButton('Start Training')
        self.start_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton('Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        self.load_btn = QPushButton('Load Model')
        self.load_btn.clicked.connect(self.load_model)

        control_layout.addWidget(QLabel('Device:'))
        control_layout.addWidget(self.device_combo)
        control_layout.addWidget(self.start_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addWidget(self.load_btn)
        control_layout.addStretch()

        layout.addLayout(control_layout)

        # Content Area
        content_layout = QHBoxLayout()

        # Left side - Game Display
        left_frame = QFrame()
        left_frame.setFrameStyle(QFrame.Box)
        left_layout = QVBoxLayout(left_frame)

        self.game_display = QLabel()
        self.game_display.setMinimumSize(400, 300)
        self.game_display.setAlignment(Qt.AlignCenter)
        self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
        self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')

        left_layout.addWidget(QLabel('Mario Game Display:'))
        left_layout.addWidget(self.game_display)

        # Right side - Training Info
        right_frame = QFrame()
        right_frame.setFrameStyle(QFrame.Box)
        right_layout = QVBoxLayout(right_frame)

        # Training stats
        stats_group = QGroupBox("Training Statistics")
        stats_layout = QVBoxLayout(stats_group)

        self.episode_label = QLabel('Episode: 0')
        self.world_label = QLabel('World: 1-1')
        self.score_label = QLabel('Score: 0')
        self.reward_label = QLabel('Episode Reward: 0')
        self.steps_label = QLabel('Steps: 0')
        self.epsilon_label = QLabel('Epsilon: 1.000')
        self.loss_label = QLabel('Loss: 0.0000')
        self.memory_label = QLabel('Memory: 0')
        self.xpos_label = QLabel('X Position: 0')
        self.coins_label = QLabel('Coins: 0')
        self.time_label = QLabel('Time: 400')
        self.flag_label = QLabel('Flag: No')

        stats_layout.addWidget(self.episode_label)
        stats_layout.addWidget(self.world_label)
        stats_layout.addWidget(self.score_label)
        stats_layout.addWidget(self.reward_label)
        stats_layout.addWidget(self.steps_label)
        stats_layout.addWidget(self.epsilon_label)
        stats_layout.addWidget(self.loss_label)
        stats_layout.addWidget(self.memory_label)
        stats_layout.addWidget(self.xpos_label)
        stats_layout.addWidget(self.coins_label)
        stats_layout.addWidget(self.time_label)
        stats_layout.addWidget(self.flag_label)

        right_layout.addWidget(stats_group)

        # Training log
        right_layout.addWidget(QLabel('Training Log:'))
        self.log_text = QTextEdit()
        self.log_text.setMaximumHeight(300)
        right_layout.addWidget(self.log_text)

        content_layout.addWidget(left_frame)
        content_layout.addWidget(right_frame)
        layout.addLayout(content_layout)

    def start_training(self):
        device = self.device_combo.currentText()

        # Check device availability
        if device == "cuda" and not torch.cuda.is_available():
            self.log_text.append("❌ CUDA not available, using CPU instead")
            device = "cpu"
        elif device == "mps" and not torch.backends.mps.is_available():
            self.log_text.append("❌ MPS not available, using CPU instead")
            device = "cpu"

        self.training_thread = MarioTrainingThread(device)
        self.training_thread.update_signal.connect(self.update_training_info)
        self.training_thread.frame_signal.connect(self.update_game_display)
        self.training_thread.start()

        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

        self.log_text.append(f'🚀 Started Dueling DQN training on {device}...')

    def stop_training(self):
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()

        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.log_text.append('⏹️ Training stopped.')

    def load_model(self):
        # Placeholder for model loading functionality
        self.log_text.append('📁 Load model functionality not implemented yet')

    def update_training_info(self, data):
        if 'episode' in data:
            self.episode_label.setText(f'Episode: {data["episode"]}')
        if 'world' in data and 'stage' in data:
            self.world_label.setText(f'World: {data["world"]}-{data["stage"]}')
        if 'score' in data:
            self.score_label.setText(f'Score: {data["score"]}')
        if 'total_reward' in data:
+
self.reward_label.setText(f'Episode Reward: {data["total_reward"]:.2f}')
|
| 620 |
+
if 'steps' in data:
|
| 621 |
+
self.steps_label.setText(f'Steps: {data["steps"]}')
|
| 622 |
+
if 'epsilon' in data:
|
| 623 |
+
self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
|
| 624 |
+
if 'loss' in data:
|
| 625 |
+
self.loss_label.setText(f'Loss: {data["loss"]:.4f}')
|
| 626 |
+
if 'memory_size' in data:
|
| 627 |
+
self.memory_label.setText(f'Memory: {data["memory_size"]}')
|
| 628 |
+
if 'x_pos' in data:
|
| 629 |
+
self.xpos_label.setText(f'X Position: {data["x_pos"]}')
|
| 630 |
+
if 'coins' in data:
|
| 631 |
+
self.coins_label.setText(f'Coins: {data["coins"]}')
|
| 632 |
+
if 'time' in data:
|
| 633 |
+
self.time_label.setText(f'Time: {data["time"]}')
|
| 634 |
+
if 'flag_get' in data:
|
| 635 |
+
flag_text = "Yes" if data["flag_get"] else "No"
|
| 636 |
+
self.flag_label.setText(f'Flag: {flag_text}')
|
| 637 |
+
if 'log_message' in data:
|
| 638 |
+
self.log_text.append(data['log_message'])
|
| 639 |
+
# Auto-scroll to bottom
|
| 640 |
+
self.log_text.verticalScrollBar().setValue(
|
| 641 |
+
self.log_text.verticalScrollBar().maximum()
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
def update_game_display(self, frame):
|
| 645 |
+
if frame is not None:
|
| 646 |
+
try:
|
| 647 |
+
h, w, ch = frame.shape
|
| 648 |
+
bytes_per_line = ch * w
|
| 649 |
+
# Ensure contiguous array and display original color frame
|
| 650 |
+
frame_contiguous = np.ascontiguousarray(frame)
|
| 651 |
+
q_img = QImage(frame_contiguous.data, w, h, bytes_per_line, QImage.Format_RGB888)
|
| 652 |
+
pixmap = QPixmap.fromImage(q_img)
|
| 653 |
+
# Scale to fit the display while maintaining aspect ratio
|
| 654 |
+
self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio, Qt.SmoothTransformation))
|
| 655 |
+
except Exception as e:
|
| 656 |
+
print(f"Error updating display: {e}")
|
| 657 |
+
|
| 658 |
+
def closeEvent(self, event):
|
| 659 |
+
self.stop_training()
|
| 660 |
+
event.accept()
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def main():
|
| 664 |
+
# Set random seeds for reproducibility
|
| 665 |
+
torch.manual_seed(42)
|
| 666 |
+
np.random.seed(42)
|
| 667 |
+
random.seed(42)
|
| 668 |
+
|
| 669 |
+
app = QApplication(sys.argv)
|
| 670 |
+
window = MarioRLApp()
|
| 671 |
+
window.show()
|
| 672 |
+
sys.exit(app.exec_())
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
if __name__ == '__main__':
|
| 676 |
+
main()
|
Super-Mario-RL-PyQt5/enhanced_mario_q_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6c82b8bb39904cf745061a6ba1ca2a207977f22f13fb0e72f1141c3f85045eb0
size 13193617
Super-Mario-RL-PyQt5/enhanced_mario_q_target_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9f6060fca857b4335781a219ffa7450baa20eb8efb9287e9561087978c1a1e0
size 13193949
Super-Mario-RL-PyQt5/requirements.txt
ADDED
@@ -0,0 +1,9 @@
numpy==1.26.4
torch>=1.6.0
torchvision
gym==0.23
nes-py
gym-super-mario-bros==7.2.3
opencv-python
matplotlib
pyqt5
Super-Mario-RL-PyQt5/score.p
ADDED
Binary file (34 Bytes).
Super-Mario-RL/README.md
ADDED
@@ -0,0 +1,85 @@
# :mushroom: Super-Mario-RL

This is a personal project to build a Super Mario agent.

It trains an agent to clear Super Mario Bros with deep reinforcement learning methods.

Here are my Super Mario agents with a dueling network (trained for 7,000 epochs).

**(25-05-20) Super Mario with PPO has been updated!**

<p float="center">
  <img src="/mario1.gif" width="350" />
  <img src="/mario14.gif" width="350" />
</p>

# Get started

## Cloning git

```
git clone https://github.com/jiseongHAN/Super-Mario-RL.git
cd Super-Mario-RL
```

## Install Requirements
```
pip install -r requirements.txt
```

## Or Install Manually
* Install [openAI gym](http://gym.openai.com/)
```
pip install 'gym'
```
* Install [Pytorch](https://pytorch.org/)
```
pip install torch torchvision
```
* Install [nes-py](https://pypi.org/project/nes-py/)
```
pip install nes-py
```
* Install [gym-super-mario-bros](https://pypi.org/project/gym-super-mario-bros/)
```
pip install gym-super-mario-bros
```

# Running

## Train

* Train with dueling DQN.
```
python duel_dqn.py
```

* Train with PPO.

```
python ppo.py
```

### Result
* score.p : saves the total score every 50 episodes
* *.pth : saves the weights of q and q_target every 50 training steps


## Evaluate
* (The pre-trained agent is currently corrupted 😢)
* Test and render a trained agent.
* To test the agent, you need the 'q_target.pth' file generated during training.
* (eval.py with PPO is not supported yet)
```
python eval.py
```
* Or you can use your own agent.
```
python eval.py your_own_agent.pth
```

## Reference
[Wang, Ziyu, et al. "Dueling network architectures for deep reinforcement learning." International conference on machine learning. PMLR, 2016.](https://arxiv.org/pdf/1511.06581.pdf)

[Schulman, J., Wolski, F., Dhariwal, P., Radford, A. & Klimov, O. Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347 (2017).](https://arxiv.org/pdf/1707.06347)
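A side note on the Result bullets in this README: score.p is written with `pickle.dump` and, in duel_dqn.py, holds a plain Python list of running average scores. A minimal sketch for inspecting it (assuming that list format; matplotlib is already listed in requirements.txt) could look like this:

```
import pickle
import matplotlib.pyplot as plt

# score.p is a pickled list of average scores (see pickle.dump(...) in duel_dqn.py)
with open("score.p", "rb") as f:
    scores = pickle.load(f)

plt.plot(scores)
plt.xlabel("print interval")
plt.ylabel("average score")
plt.show()
```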
Super-Mario-RL/__pycache__/wrappers.cpython-313.pyc
ADDED
Binary file (18.7 kB).
Super-Mario-RL/duel_dqn.py
ADDED
@@ -0,0 +1,178 @@
import pickle
import random
import time
from collections import deque

import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace

from wrappers import *


def arrange(s):
    if not type(s) == "numpy.ndarray":
        s = np.array(s)
    assert len(s.shape) == 3
    ret = np.transpose(s, (2, 0, 1))
    return np.expand_dims(ret, 0)


class replay_memory(object):
    def __init__(self, N):
        self.memory = deque(maxlen=N)

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, n):
        return random.sample(self.memory, n)

    def __len__(self):
        return len(self.memory)


class model(nn.Module):
    def __init__(self, n_frame, n_action, device):
        super(model, self).__init__()
        self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
        self.layer2 = nn.Conv2d(32, 64, 3, 1)
        self.fc = nn.Linear(20736, 512)
        self.q = nn.Linear(512, n_action)
        self.v = nn.Linear(512, 1)

        self.device = device
        self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)

        self.seq.apply(init_weights)

    def forward(self, x):
        if type(x) != torch.Tensor:
            x = torch.FloatTensor(x).to(self.device)
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = x.view(-1, 20736)
        x = torch.relu(self.fc(x))
        adv = self.q(x)
        v = self.v(x)
        q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))

        return q


def init_weights(m):
    if type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


def train(q, q_target, memory, batch_size, gamma, optimizer, device):
    s, r, a, s_prime, done = list(map(list, zip(*memory.sample(batch_size))))
    s = np.array(s).squeeze()
    s_prime = np.array(s_prime).squeeze()
    a_max = q(s_prime).max(1)[1].unsqueeze(-1)
    r = torch.FloatTensor(r).unsqueeze(-1).to(device)
    done = torch.FloatTensor(done).unsqueeze(-1).to(device)
    with torch.no_grad():
        y = r + gamma * q_target(s_prime).gather(1, a_max) * done
    a = torch.tensor(a).unsqueeze(-1).to(device)
    q_value = torch.gather(q(s), dim=1, index=a.view(-1, 1).long())

    loss = F.smooth_l1_loss(q_value, y).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss


def copy_weights(q, q_target):
    q_dict = q.state_dict()
    q_target.load_state_dict(q_dict)


def main(env, q, q_target, optimizer, device):
    t = 0
    gamma = 0.99
    batch_size = 256

    N = 50000
    eps = 0.001
    memory = replay_memory(N)
    update_interval = 50
    print_interval = 10

    score_lst = []
    total_score = 0.0
    loss = 0.0
    start_time = time.perf_counter()

    for k in range(1000000):
        s = arrange(env.reset())
        done = False

        while not done:
            if eps > np.random.rand():
                a = env.action_space.sample()
            else:
                if device == "cpu":
                    a = np.argmax(q(s).detach().numpy())
                else:
                    a = np.argmax(q(s).cpu().detach().numpy())
            s_prime, r, done, _ = env.step(a)
            s_prime = arrange(s_prime)
            total_score += r
            r = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
            memory.push((s, float(r), int(a), s_prime, int(1 - done)))
            s = s_prime
            stage = env.unwrapped._stage
            if len(memory) > 2000:
                loss += train(q, q_target, memory, batch_size, gamma, optimizer, device)
                t += 1
            if t % update_interval == 0:
                copy_weights(q, q_target)
                torch.save(q.state_dict(), "mario_q.pth")
                torch.save(q_target.state_dict(), "mario_q_target.pth")

        if k % print_interval == 0:
            time_spent, start_time = (
                time.perf_counter() - start_time,
                time.perf_counter(),
            )
            print(
                "%s |Epoch : %d | score : %f | loss : %.2f | stage : %d | time spent: %f"
                % (
                    device,
                    k,
                    total_score / print_interval,
                    loss / print_interval,
                    stage,
                    time_spent,
                )
            )
            score_lst.append(total_score / print_interval)
            total_score = 0
            loss = 0.0
            pickle.dump(score_lst, open("score.p", "wb"))


if __name__ == "__main__":
    n_frame = 4
    env = gym_super_mario_bros.make("SuperMarioBros-v0")
    env = JoypadSpace(env, COMPLEX_MOVEMENT)
    env = wrap_mario(env)
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    q = model(n_frame, env.action_space.n, device).to(device)
    q_target = model(n_frame, env.action_space.n, device).to(device)
    optimizer = optim.Adam(q.parameters(), lr=0.0001)
    print(device)

    main(env, q, q_target, optimizer, device)
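The forward pass in duel_dqn.py combines the value and advantage streams as Q(s, a) = V(s) + (A(s, a) − mean_a A(s, a)). A small self-contained check of that aggregation (toy tensors chosen for illustration, not taken from the repo) could be:

```
import torch

# Dueling aggregation, matching the expression used in model.forward above:
# Q(s, a) = V(s) + (A(s, a) - mean over actions of A(s, a))
adv = torch.tensor([[1.0, 2.0, 3.0]])   # advantage stream, shape (batch, n_action)
v = torch.tensor([[0.5]])               # value stream, shape (batch, 1)
q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))
print(q)  # tensor([[-0.5000,  0.5000,  1.5000]]): mean advantage removed, V sets the offset
```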
Super-Mario-RL/duel_dqn_2.py
ADDED
@@ -0,0 +1,237 @@
import pickle
import random
import time
from collections import deque

import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace

from wrappers import *


def arrange(s):
    if not type(s) == "numpy.ndarray":
        s = np.array(s)
    assert len(s.shape) == 3
    ret = np.transpose(s, (2, 0, 1))
    return np.expand_dims(ret, 0)


class replay_memory(object):
    def __init__(self, N):
        self.memory = deque(maxlen=N)

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, n):
        return random.sample(self.memory, n)

    def __len__(self):
        return len(self.memory)


class model(nn.Module):
    def __init__(self, n_frame, n_action, device):
        super(model, self).__init__()
        self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
        self.layer2 = nn.Conv2d(32, 64, 3, 1)
        self.fc = nn.Linear(20736, 512)
        self.q = nn.Linear(512, n_action)
        self.v = nn.Linear(512, 1)

        self.device = device
        self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)

        self.seq.apply(init_weights)

    def forward(self, x):
        if type(x) != torch.Tensor:
            x = torch.FloatTensor(x).to(self.device)
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = x.view(-1, 20736)
        x = torch.relu(self.fc(x))
        adv = self.q(x)
        v = self.v(x)
        q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))

        return q


def init_weights(m):
    if type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


def train(q, q_target, memory, batch_size, gamma, optimizer, device):
    s, r, a, s_prime, done = list(map(list, zip(*memory.sample(batch_size))))
    s = np.array(s).squeeze()
    s_prime = np.array(s_prime).squeeze()

    # Move computations to device
    s_tensor = torch.FloatTensor(s).to(device)
    s_prime_tensor = torch.FloatTensor(s_prime).to(device)

    a_max = q(s_prime_tensor).max(1)[1].unsqueeze(-1)
    r = torch.FloatTensor(r).unsqueeze(-1).to(device)
    done = torch.FloatTensor(done).unsqueeze(-1).to(device)

    with torch.no_grad():
        y = r + gamma * q_target(s_prime_tensor).gather(1, a_max) * done

    a = torch.tensor(a).unsqueeze(-1).to(device)
    q_value = torch.gather(q(s_tensor), dim=1, index=a.view(-1, 1).long())

    loss = F.smooth_l1_loss(q_value, y).mean()
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping to prevent explosion
    torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=1.0)

    optimizer.step()
    return loss.item()  # Use .item() to get scalar value


def copy_weights(q, q_target):
    q_dict = q.state_dict()
    q_target.load_state_dict(q_dict)


def main(env, q, q_target, optimizer, device):
    t = 0
    gamma = 0.99
    batch_size = 256

    N = 50000
    eps = 0.1  # Increased exploration
    eps_min = 0.01
    eps_decay = 0.999
    memory = replay_memory(N)
    update_interval = 50  # How often to update target network
    save_interval = 100  # How often to save models (in episodes)
    print_interval = 10

    score_lst = []
    total_score = 0.0
    loss_accumulator = 0.0
    start_time = time.perf_counter()

    for k in range(1000000):
        s = arrange(env.reset())
        done = False

        while not done:
            # Epsilon decay
            if eps > eps_min:
                eps *= eps_decay

            if eps > np.random.rand():
                a = env.action_space.sample()
            else:
                # Get action with proper device handling
                with torch.no_grad():
                    q_values = q(torch.FloatTensor(s).to(device))

                # Move to CPU for numpy conversion regardless of device
                if device == "cuda" or device == "mps":
                    a = np.argmax(q_values.cpu().numpy())
                else:
                    a = np.argmax(q_values.detach().numpy())

            s_prime, r, done, info = env.step(a)
            s_prime = arrange(s_prime)
            total_score += r

            # Enhanced reward shaping
            reward = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r

            # Bonus for x_pos progress
            if 'x_pos' in info:
                x_pos = info['x_pos']
                if hasattr(main, 'last_x_pos'):
                    x_progress = x_pos - main.last_x_pos
                    if x_progress > 0:
                        reward += 0.1 * x_progress  # Small bonus for moving right
                main.last_x_pos = x_pos

            memory.push((s, float(reward), int(a), s_prime, int(1 - done)))
            s = s_prime
            stage = env.unwrapped._stage

            if len(memory) > 2000:
                loss_val = train(q, q_target, memory, batch_size, gamma, optimizer, device)
                loss_accumulator += loss_val
                t += 1

            # Update target network (but don't save every time)
            if t % update_interval == 0:
                copy_weights(q, q_target)

        # Save models less frequently (every save_interval episodes)
        if k % save_interval == 0 and k > 0:
            torch.save(q.state_dict(), "mario_q.pth")
            torch.save(q_target.state_dict(), "mario_q_target.pth")
            print(f"Models saved at episode {k}")

        if k % print_interval == 0:
            time_spent, start_time = (
                time.perf_counter() - start_time,
                time.perf_counter(),
            )

            # Fixed: Use loss_accumulator instead of loss and ensure proper formatting
            avg_loss = loss_accumulator / print_interval if print_interval > 0 else 0.0
            avg_score = total_score / print_interval if print_interval > 0 else 0.0

            print(
                "%s | Epoch : %d | score : %.2f | loss : %.2f | stage : %d | eps : %.3f | time: %.2fs | memory: %d"
                % (
                    device,
                    k,
                    avg_score,
                    avg_loss,
                    stage,
                    eps,
                    time_spent,
                    len(memory)
                )
            )
            score_lst.append(avg_score)
            total_score = 0.0
            loss_accumulator = 0.0
            pickle.dump(score_lst, open("score.p", "wb"))


if __name__ == "__main__":
    n_frame = 4
    env = gym_super_mario_bros.make("SuperMarioBros-v3")
    env = JoypadSpace(env, COMPLEX_MOVEMENT)
    env = wrap_mario(env)

    # Device detection with MPS support
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

    print(f"Using device: {device}")

    q = model(n_frame, env.action_space.n, device).to(device)
    q_target = model(n_frame, env.action_space.n, device).to(device)

    # Copy weights initially
    copy_weights(q, q_target)

    optimizer = optim.Adam(q.parameters(), lr=0.0001, weight_decay=1e-5)  # Added weight decay

    main(env, q, q_target, optimizer, device)
Super-Mario-RL/enhanced_duel_dqn.py
ADDED
@@ -0,0 +1,257 @@
import pickle
import random
import time
from collections import deque

import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace

from wrappers import *


def arrange(s):
    if not type(s) == "numpy.ndarray":
        s = np.array(s)
    assert len(s.shape) == 3
    ret = np.transpose(s, (2, 0, 1))
    return np.expand_dims(ret, 0)


class replay_memory(object):
    def __init__(self, N):
        self.memory = deque(maxlen=N)

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, n):
        return random.sample(self.memory, n)

    def __len__(self):
        return len(self.memory)


class model(nn.Module):
    def __init__(self, n_frame, n_action, device):
        super(model, self).__init__()
        self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
        self.layer2 = nn.Conv2d(32, 64, 3, 1)
        self.fc = nn.Linear(20736, 512)
        self.q = nn.Linear(512, n_action)
        self.v = nn.Linear(512, 1)

        self.device = device
        self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)

        self.seq.apply(init_weights)

    def forward(self, x):
        if type(x) != torch.Tensor:
            x = torch.FloatTensor(x).to(self.device)
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = x.view(-1, 20736)
        x = torch.relu(self.fc(x))
        adv = self.q(x)
        v = self.v(x)
        q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))

        return q


def init_weights(m):
    if type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


def train(q, q_target, memory, batch_size, gamma, optimizer, device):
    s, r, a, s_prime, done = list(map(list, zip(*memory.sample(batch_size))))
    s = np.array(s).squeeze()
    s_prime = np.array(s_prime).squeeze()

    # Move computations to device
    s_tensor = torch.FloatTensor(s).to(device)
    s_prime_tensor = torch.FloatTensor(s_prime).to(device)

    a_max = q(s_prime_tensor).max(1)[1].unsqueeze(-1)
    r = torch.FloatTensor(r).unsqueeze(-1).to(device)
    done = torch.FloatTensor(done).unsqueeze(-1).to(device)

    with torch.no_grad():
        y = r + gamma * q_target(s_prime_tensor).gather(1, a_max) * done

    a = torch.tensor(a).unsqueeze(-1).to(device)
    q_value = torch.gather(q(s_tensor), dim=1, index=a.view(-1, 1).long())

    loss = F.smooth_l1_loss(q_value, y).mean()
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping to prevent explosion
    torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0)  # Increased clipping

    optimizer.step()
    return loss.item()


def copy_weights(q, q_target):
    q_dict = q.state_dict()
    q_target.load_state_dict(q_dict)


def main(env, q, q_target, optimizer, device):
    t = 0
    gamma = 0.99
    batch_size = 256

    N = 50000
    eps = 0.3  # Higher initial exploration
    eps_min = 0.05  # Higher minimum exploration
    eps_decay = 0.995  # Slower decay
    memory = replay_memory(N)
    update_interval = 100  # Less frequent target updates
    save_interval = 100
    print_interval = 10

    score_lst = []
    total_score = 0.0
    loss_accumulator = 0.0
    start_time = time.perf_counter()

    # Track best score for saving
    best_score = -float('inf')

    for k in range(1000000):
        s = arrange(env.reset())
        done = False
        episode_loss = 0.0
        episode_steps = 0

        while not done:
            # Epsilon decay per step
            if eps > eps_min:
                eps *= eps_decay

            if eps > np.random.rand():
                a = env.action_space.sample()
            else:
                with torch.no_grad():
                    q_values = q(torch.FloatTensor(s).to(device))

                if device == "cuda" or device == "mps":
                    a = np.argmax(q_values.cpu().numpy())
                else:
                    a = np.argmax(q_values.detach().numpy())

            s_prime, r, done, info = env.step(a)
            s_prime = arrange(s_prime)
            total_score += r

            # Enhanced reward shaping
            reward = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r

            # Bonus for x_pos progress and stage completion
            if 'x_pos' in info:
                x_pos = info['x_pos']
                if hasattr(main, 'last_x_pos'):
                    x_progress = x_pos - main.last_x_pos
                    if x_progress > 0:
                        reward += 0.05 * x_progress  # Reduced bonus to prevent over-optimization
                main.last_x_pos = x_pos

            # Large bonus for completing the level
            if done and info.get('flag_get', False):
                reward += 50.0
                print(f"🎉 LEVEL COMPLETED at episode {k}! 🎉")

            memory.push((s, float(reward), int(a), s_prime, int(1 - done)))
            s = s_prime
            stage = info.get('stage', 1)
            world = info.get('world', 1)

            # Train only if we have enough samples
            if len(memory) > 5000:  # Increased minimum buffer size
                loss_val = train(q, q_target, memory, batch_size, gamma, optimizer, device)
                loss_accumulator += loss_val
                episode_loss += loss_val
                episode_steps += 1
                t += 1

            # Update target network less frequently
            if t % update_interval == 0:
                copy_weights(q, q_target)

        # Save best model
        current_avg_score = total_score / print_interval if k % print_interval == 0 else total_score
        if current_avg_score > best_score and k > 0:
            best_score = current_avg_score
            torch.save(q.state_dict(), "enhanced_mario_q_best.pth")
            torch.save(q_target.state_dict(), "enhanced_mario_q_target_best.pth")
            print(f"💾 New best model saved! Score: {best_score:.2f}")

        # Save models periodically
        if k % save_interval == 0 and k > 0:
            torch.save(q.state_dict(), "enhanced_mario_q.pth")
            torch.save(q_target.state_dict(), "enhanced_mario_q_target.pth")
            print(f"Models saved at episode {k}")

        if k % print_interval == 0:
            time_spent, start_time = (
                time.perf_counter() - start_time,
                time.perf_counter(),
            )

            avg_loss = loss_accumulator / (print_interval * episode_steps) if episode_steps > 0 else 0.0
            avg_score = total_score / print_interval

            print(
                "%s | Ep: %d | Score: %.2f | Loss: %.2f | Stage: %d-%d | Eps: %.3f | Time: %.2fs | Mem: %d | Steps: %d"
                % (
                    device,
                    k,
                    avg_score,
                    avg_loss,
                    world,
                    stage,
                    eps,
                    time_spent,
                    len(memory),
                    episode_steps
                )
            )
            score_lst.append(avg_score)
            total_score = 0.0
            loss_accumulator = 0.0
            pickle.dump(score_lst, open("score.p", "wb"))


if __name__ == "__main__":
    n_frame = 4
    env = gym_super_mario_bros.make("SuperMarioBros-v3")
    env = JoypadSpace(env, COMPLEX_MOVEMENT)
    env = wrap_mario(env)

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

    print(f"Using device: {device}")

    q = model(n_frame, env.action_space.n, device).to(device)
    q_target = model(n_frame, env.action_space.n, device).to(device)

    copy_weights(q, q_target)

    # Lower learning rate for stability
    optimizer = optim.Adam(q.parameters(), lr=0.00005, weight_decay=1e-5)

    main(env, q, q_target, optimizer, device)
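All three DQN training scripts above shape the raw reward with r' = sign(r)·(√(|r| + 1) − 1) + 0.001·r, which compresses large rewards while preserving their sign and ordering. A quick numeric illustration of that transform (values chosen arbitrarily, not taken from a training run):

```
import numpy as np

def shape(r):
    # Same transform as in duel_dqn.py / duel_dqn_2.py / enhanced_duel_dqn.py:
    # square-root compression of the magnitude plus a small linear term.
    return np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r

for r in [-15, -1, 0, 1, 15, 100]:
    print(r, round(shape(r), 3))
# e.g. 100 -> ~9.15 and 15 -> ~3.02, so large raw rewards are squashed but stay ordered.
```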
Super-Mario-RL/enhanced_mario_q.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:034fc1bde429fd3e9bdde72fb16697707a604dad8db737f49f8e61ad5b442026
size 42607893
Super-Mario-RL/enhanced_mario_q_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6874204a382dc834d328f078a05110c793ead351eee5034b9f37f09ed6c11b9
size 42607973
Super-Mario-RL/enhanced_mario_q_target.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fa5448953e60e5be93f7a6a0843d11af0a279a6e13adafd36cb31f025fdf914
size 42608069
Super-Mario-RL/enhanced_mario_q_target_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a5f341b08473c536af3398f2b984006c55d9e96952b0b0f5263bf1cdd7f7917
size 42608213
Super-Mario-RL/eval.py
ADDED
@@ -0,0 +1,111 @@
import sys
import time

import gym_super_mario_bros
import torch
import torch.nn as nn
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace

from wrappers import *

# Device detection
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

print(f"Using device: {device}")


# Same as the model in duel_dqn.py (you could factor it out into model.py to avoid duplication.)
class model(nn.Module):
    def __init__(self, n_frame, n_action, device):
        super(model, self).__init__()
        self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
        self.layer2 = nn.Conv2d(32, 64, 3, 1)
        self.fc = nn.Linear(20736, 512)
        self.q = nn.Linear(512, n_action)
        self.v = nn.Linear(512, 1)

        self.device = device
        self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)

        self.seq.apply(init_weights)

    def forward(self, x):
        if type(x) != torch.Tensor:
            x = torch.FloatTensor(x).to(self.device)
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = x.view(-1, 20736)
        x = torch.relu(self.fc(x))
        adv = self.q(x)
        v = self.v(x)
        q = v + (adv - 1 / adv.shape[-1] * adv.max(-1, True)[0])
        return q


def init_weights(m):
    if type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


def arange(s):
    if not type(s) == "numpy.ndarray":
        s = np.array(s)
    assert len(s.shape) == 3
    ret = np.transpose(s, (2, 0, 1))
    return np.expand_dims(ret, 0)


if __name__ == "__main__":
    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "mario_q_target.pth"
    print(f"Load ckpt from {ckpt_path}")
    n_frame = 4
    env = gym_super_mario_bros.make("SuperMarioBros-v0")
    env = JoypadSpace(env, COMPLEX_MOVEMENT)
    env = wrap_mario(env)

    q = model(n_frame, env.action_space.n, device).to(device)

    # Load model with proper device mapping
    try:
        q.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device)))
        print(f"Model loaded successfully on {device}")
    except Exception as e:
        print(f"Error loading model with {device}: {e}")
        print("Trying to load with CPU mapping...")
        q.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        q = q.to(device)
        print(f"Model loaded with CPU mapping and moved to {device}")

    total_score = 0.0
    done = False
    s = arange(env.reset())
    i = 0

    # Evaluation loop
    while not done:
        env.render()

        # Get Q-values and action
        with torch.no_grad():
            q_values = q(s)

        # Move to CPU for numpy conversion regardless of device
        if device == "cuda" or device == "mps":
            a = np.argmax(q_values.cpu().numpy())
        else:
            a = np.argmax(q_values.detach().numpy())

        s_prime, r, done, _ = env.step(a)
        s_prime = arange(s_prime)
        total_score += r
        s = s_prime
        time.sleep(0.001)

    stage = env.unwrapped._stage
    print("Total score : %f | stage : %d" % (total_score, stage))
Super-Mario-RL/mario1.gif
ADDED
Git LFS Details
Super-Mario-RL/mario1.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c3e5f49d302a68348659adba9a7f9f1be4d57fd3204214a191a47234aea0cd0
size 643908
Super-Mario-RL/mario14.gif
ADDED
Git LFS Details
Super-Mario-RL/mario14.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:52efcbda754d016c56b6cf67d50da683988061b31ae7d43217f995a5db89474a
size 547827
Super-Mario-RL/mario_q.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2e1d46a2e2822ff25428906b6964f578bcf2f16526cf4edecef0ded6350499d6
size 42607237
Super-Mario-RL/mario_q_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d286a832aaeae13128121500fe34759a50f0e587df5cfd9ac83d780b181ed340
size 42607829
Super-Mario-RL/mario_q_target.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e7811192034bc02b5236043c3bfc8e982038c970bcd6e948b82cdac18706a964
size 39059456
Super-Mario-RL/mario_q_target_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c358203988162769284fe455cdc9fc047e69e6672d4ec05066f79a20dd54404c
size 42607941
Super-Mario-RL/ppo.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
|
| 3 |
+
import gym_super_mario_bros
|
| 4 |
+
import gymnasium as gym
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
|
| 10 |
+
from nes_py.wrappers import JoypadSpace
|
| 11 |
+
|
| 12 |
+
from wrappers import *
|
| 13 |
+
|
| 14 |
+
device = "cpu"
|
| 15 |
+
if torch.cuda.is_available():
|
| 16 |
+
device = "cuda"
|
| 17 |
+
elif torch.backends.mps.is_available():
|
| 18 |
+
device = "mps"
|
| 19 |
+
|
| 20 |
+
print(f"Using device: {device}")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def make_env():
|
| 24 |
+
env = gym_super_mario_bros.make("SuperMarioBros-v0")
|
| 25 |
+
env = JoypadSpace(env, COMPLEX_MOVEMENT)
|
| 26 |
+
env = wrap_mario(env)
|
| 27 |
+
return env
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_reward(r):
|
| 31 |
+
r = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
|
| 32 |
+
return r
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ActorCritic(nn.Module):
|
| 36 |
+
def __init__(self, n_frame, act_dim):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.net = nn.Sequential(
|
| 39 |
+
nn.Conv2d(n_frame, 32, 8, 4),
|
| 40 |
+
nn.ReLU(),
|
| 41 |
+
nn.Conv2d(32, 64, 3, 1),
|
| 42 |
+
nn.ReLU(),
|
| 43 |
+
)
|
| 44 |
+
self.linear = nn.Linear(20736, 512)
|
| 45 |
+
self.policy_head = nn.Linear(512, act_dim)
|
| 46 |
+
self.value_head = nn.Linear(512, 1)
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
if x.dim() == 4:
|
| 50 |
+
x = x.permute(0, 3, 1, 2)
|
| 51 |
+
elif x.dim() == 3:
|
| 52 |
+
x = x.permute(2, 0, 1)
|
| 53 |
+
x = self.net(x)
|
| 54 |
+
x = x.reshape(-1, 20736)
|
| 55 |
+
x = torch.relu(self.linear(x))
|
| 56 |
+
|
| 57 |
+
return self.policy_head(x), self.value_head(x).squeeze(-1)
|
| 58 |
+
|
| 59 |
+
def act(self, obs):
|
| 60 |
+
logits, value = self.forward(obs)
|
| 61 |
+
dist = torch.distributions.Categorical(logits=logits)
|
| 62 |
+
action = dist.sample()
|
| 63 |
+
logprob = dist.log_prob(action)
|
| 64 |
+
return action, logprob, value
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def compute_gae_batch(rewards, values, dones, gamma=0.99, lam=0.95):
|
| 68 |
+
T, N = rewards.shape
|
| 69 |
+
advantages = torch.zeros_like(rewards)
|
| 70 |
+
gae = torch.zeros(N, device=device)
|
| 71 |
+
|
| 72 |
+
for t in reversed(range(T)):
|
| 73 |
+
not_done = 1.0 - dones[t]
|
| 74 |
+
delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
|
| 75 |
+
gae = delta + gamma * lam * not_done * gae
|
| 76 |
+
advantages[t] = gae
|
| 77 |
+
|
| 78 |
+
returns = advantages + values[:-1]
|
| 79 |
+
return advantages, returns
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def rollout_with_bootstrap(envs, model, rollout_steps, init_obs):
|
| 83 |
+
obs = init_obs
|
| 84 |
+
obs = torch.tensor(obs, dtype=torch.float32).to(device)
|
| 85 |
+
obs_buf, act_buf, rew_buf, done_buf, val_buf, logp_buf = [], [], [], [], [], []
|
| 86 |
+
|
| 87 |
+
for _ in range(rollout_steps):
|
| 88 |
+
obs_buf.append(obs)
|
| 89 |
+
|
| 90 |
+
with torch.no_grad():
|
| 91 |
+
action, logp, value = model.act(obs)
|
| 92 |
+
|
| 93 |
+
val_buf.append(value)
|
| 94 |
+
logp_buf.append(logp)
|
| 95 |
+
act_buf.append(action)
|
| 96 |
+
|
| 97 |
+
actions = action.cpu().numpy()
|
| 98 |
+
next_obs, reward, done, infos = envs.step(actions)
|
| 99 |
+
|
| 100 |
+
reward = [get_reward(r) for r in reward]
|
| 101 |
+
# done = np.logical_or(terminated)
|
| 102 |
+
|
| 103 |
+
rew_buf.append(torch.tensor(reward, dtype=torch.float32).to(device))
|
| 104 |
+
done_buf.append(torch.tensor(done, dtype=torch.float32).to(device))
|
| 105 |
+
|
| 106 |
+
for i, d in enumerate(done):
|
| 107 |
+
if d:
|
| 108 |
+
print(f"Env {i} done. Resetting. (info: {infos[i]})")
|
| 109 |
+
next_obs[i] = envs.envs[i].reset()
|
| 110 |
+
|
| 111 |
+
obs = torch.tensor(next_obs, dtype=torch.float32).to(device)
|
| 112 |
+
max_stage = max([i["stage"] for i in infos])
|
| 113 |
+
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
_, last_value = model.forward(obs)
|
| 116 |
+
|
| 117 |
+
obs_buf = torch.stack(obs_buf)
|
| 118 |
+
act_buf = torch.stack(act_buf)
|
| 119 |
+
rew_buf = torch.stack(rew_buf)
|
| 120 |
+
done_buf = torch.stack(done_buf)
|
| 121 |
+
val_buf = torch.stack(val_buf)
|
| 122 |
+
val_buf = torch.cat([val_buf, last_value.unsqueeze(0)], dim=0)
|
| 123 |
+
logp_buf = torch.stack(logp_buf)
|
| 124 |
+
|
| 125 |
+
adv_buf, ret_buf = compute_gae_batch(rew_buf, val_buf, done_buf)
|
| 126 |
+
adv_buf = (adv_buf - adv_buf.mean()) / (adv_buf.std() + 1e-8)
|
| 127 |
+
|
| 128 |
+
return {
|
| 129 |
+
"obs": obs_buf, # [T, N, obs_dim]
|
| 130 |
+
"actions": act_buf,
|
| 131 |
+
"logprobs": logp_buf,
|
| 132 |
+
"advantages": adv_buf,
|
| 133 |
+
"returns": ret_buf,
|
| 134 |
+
"max_stage": max_stage,
|
| 135 |
+
"last_obs": obs,
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def evaluate_policy(env, model, episodes=5, render=False):
|
| 140 |
+
"""
|
| 141 |
+
Function to evaluate the learned policy
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
env: gym.Env single environment (not vector!)
|
| 145 |
+
|
| 146 |
+
model: ActorCritic model
|
| 147 |
+
|
| 148 |
+
episodes: number of episodes to evaluate
|
| 149 |
+
|
| 150 |
+
render: whether to visualize (if True, display on screen)
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
avg_return: average total reward
|
| 154 |
+
"""
|
| 155 |
+
model.eval()
|
| 156 |
+
total_returns = []
|
| 157 |
+
actions = []
|
| 158 |
+
stages = []
|
| 159 |
+
for ep in range(episodes):
|
| 160 |
+
obs = env.reset()
|
| 161 |
+
done = False
|
| 162 |
+
total_reward = 0
|
| 163 |
+
if render:
|
| 164 |
+
env.render()
|
| 165 |
+
while not done:
|
| 166 |
+
obs_tensor = (
|
| 167 |
+
torch.tensor(np.array(obs), dtype=torch.float32).unsqueeze(0).to(device)
|
            )
            with torch.no_grad():
                logits, _ = model(obs_tensor)
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.probs.argmax(dim=-1).item()  # greedy action
            actions.append(action)

            obs, reward, done, info = env.step(action)
            stages.append(info["stage"])
            total_reward += reward

        total_returns.append(total_reward)
    info["action_count"] = Counter(actions)
    model.train()
    return np.mean(total_returns), info, max(stages)


def train_ppo():
    num_env = 8
    envs = gym.vector.SyncVectorEnv([lambda: make_env() for _ in range(num_env)])
    obs_dim = envs.single_observation_space.shape[-1]
    act_dim = envs.single_action_space.n
    print(f"{obs_dim=} {act_dim=}")
    model = ActorCritic(obs_dim, act_dim).to(device)

    # Load model with proper device mapping
    try:
        # Try to load with current device first
        model.load_state_dict(torch.load("mario_1_1.pt", map_location=device))
        print("Model loaded successfully with current device mapping")
    except:
        try:
            # If that fails, try loading with CPU and then moving to device
            model.load_state_dict(torch.load("mario_1_1.pt", map_location="cpu"))
            model = model.to(device)
            print("Model loaded successfully with CPU mapping")
        except Exception as e:
            print(f"Failed to load model: {e}")
            print("Starting with fresh model")

    optimizer = optim.Adam(model.parameters(), lr=2.5e-4)

    rollout_steps = 128
    epochs = 4
    minibatch_size = 64
    clip_eps = 0.2
    vf_coef = 0.5
    ent_coef = 0.01
    eval_env = make_env()
    eval_env.reset()

    init_obs = envs.reset()
    update = 0
    while True:
        update += 1
        batch = rollout_with_bootstrap(envs, model, rollout_steps, init_obs)
        init_obs = batch["last_obs"]

        T, N = rollout_steps, envs.num_envs
        total_size = T * N

        obs = batch["obs"].reshape(total_size, *envs.single_observation_space.shape)
        act = batch["actions"].reshape(total_size)
        logp_old = batch["logprobs"].reshape(total_size)
        adv = batch["advantages"].reshape(total_size)
        ret = batch["returns"].reshape(total_size)

        for _ in range(epochs):
            idx = torch.randperm(total_size)
            for start in range(0, total_size, minibatch_size):
                i = idx[start : start + minibatch_size]
                logits, value = model(obs[i])
                dist = torch.distributions.Categorical(logits=logits)
                logp = dist.log_prob(act[i])
                ratio = torch.exp(logp - logp_old[i])
                clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv[i]
                policy_loss = -torch.min(ratio * adv[i], clipped).mean()
                value_loss = (ret[i] - value).pow(2).mean()
                entropy = dist.entropy().mean()
                loss = policy_loss + vf_coef * value_loss - ent_coef * entropy

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # logging
        avg_return = batch["returns"].mean().item()
        max_stage = batch["max_stage"]
        print(f"Update {update}: avg return = {avg_return:.2f} {max_stage=}")

        # eval and save
        if update % 10 == 0:
            avg_score, info, eval_max_stage = evaluate_policy(
                eval_env, model, episodes=1, render=False
            )
            print(f"[Eval] Update {update}: avg return = {avg_score:.2f} info: {info}")
            if eval_max_stage > 1:
                torch.save(model.state_dict(), "mario_1_1_clear.pt")
                break
        if update > 0 and update % 50 == 0:
            torch.save(model.state_dict(), "mario_1_1_ppo.pt")


if __name__ == "__main__":
    train_ppo()
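The update loop above reads `batch["advantages"]` and `batch["returns"]` from `rollout_with_bootstrap`, which is defined earlier in ppo.py. For reference, PPO advantages are commonly produced with Generalized Advantage Estimation (GAE); the sketch below is a minimal illustration under that assumption, and the helper name `compute_gae`, its signature, and the `(T, N)` tensor layout are illustrative rather than taken from this repository.

```python
import torch

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Hypothetical GAE helper: rewards/values/dones are (T, N) tensors, last_value is (N,)."""
    T, N = rewards.shape
    advantages = torch.zeros(T, N)
    gae = torch.zeros(N)
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t].float()
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        gae = delta + gamma * lam * next_nonterminal * gae
        advantages[t] = gae
    returns = advantages + values  # value targets for the critic loss
    return advantages, returns
```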
Super-Mario-RL/requirements.txt
ADDED
@@ -0,0 +1,8 @@
numpy==1.26.4
torch>=1.6.0
torchvision
gym==0.23
nes-py
gym-super-mario-bros==7.2.3
opencv-python
matplotlib
Super-Mario-RL/score.p
ADDED
Binary file (1.37 kB).
Super-Mario-RL/terminal.txt
ADDED
@@ -0,0 +1,5 @@
python enhanced_duel_dqn.py
python eval.py enhanced_mario_q_best.pth
python eval.py mario_q_target.pth
python eval.py mario_q_best.pth
python eval.py mario_q.pth
Super-Mario-RL/wrappers.py
ADDED
@@ -0,0 +1,361 @@
| 1 |
+
"""
|
| 2 |
+
Code from OpenAI baseline
|
| 3 |
+
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
os.environ.setdefault("PATH", "")
|
| 11 |
+
from collections import deque
|
| 12 |
+
|
| 13 |
+
import cv2
|
| 14 |
+
import gym
|
| 15 |
+
from gym import spaces
|
| 16 |
+
|
| 17 |
+
cv2.ocl.setUseOpenCL(False)
|
| 18 |
+
from gym.wrappers import TimeLimit
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class NoopResetEnv(gym.Wrapper):
|
| 22 |
+
def __init__(self, env, noop_max=30):
|
| 23 |
+
"""Sample initial states by taking random number of no-ops on reset.
|
| 24 |
+
No-op is assumed to be action 0.
|
| 25 |
+
"""
|
| 26 |
+
gym.Wrapper.__init__(self, env)
|
| 27 |
+
self.noop_max = noop_max
|
| 28 |
+
self.override_num_noops = None
|
| 29 |
+
self.noop_action = 0
|
| 30 |
+
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
|
| 31 |
+
|
| 32 |
+
def reset(self, **kwargs):
|
| 33 |
+
"""Do no-op action for a number of steps in [1, noop_max]."""
|
| 34 |
+
self.env.reset(**kwargs)
|
| 35 |
+
if self.override_num_noops is not None:
|
| 36 |
+
noops = self.override_num_noops
|
| 37 |
+
else:
|
| 38 |
+
noops = self.unwrapped.np_random.randint(
|
| 39 |
+
1, self.noop_max + 1
|
| 40 |
+
) # pylint: disable=E1101
|
| 41 |
+
assert noops > 0
|
| 42 |
+
obs = None
|
| 43 |
+
for _ in range(noops):
|
| 44 |
+
obs, _, done, _ = self.env.step(self.noop_action)
|
| 45 |
+
if done:
|
| 46 |
+
obs = self.env.reset(**kwargs)
|
| 47 |
+
return obs
|
| 48 |
+
|
| 49 |
+
def step(self, ac):
|
| 50 |
+
return self.env.step(ac)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class FireResetEnv(gym.Wrapper):
|
| 54 |
+
def __init__(self, env):
|
| 55 |
+
"""Take action on reset for environments that are fixed until firing."""
|
| 56 |
+
gym.Wrapper.__init__(self, env)
|
| 57 |
+
assert env.unwrapped.get_action_meanings()[1] == "FIRE"
|
| 58 |
+
assert len(env.unwrapped.get_action_meanings()) >= 3
|
| 59 |
+
|
| 60 |
+
def reset(self, **kwargs):
|
| 61 |
+
self.env.reset(**kwargs)
|
| 62 |
+
obs, _, done, _ = self.env.step(1)
|
| 63 |
+
if done:
|
| 64 |
+
self.env.reset(**kwargs)
|
| 65 |
+
obs, _, done, _ = self.env.step(2)
|
| 66 |
+
if done:
|
| 67 |
+
self.env.reset(**kwargs)
|
| 68 |
+
return obs
|
| 69 |
+
|
| 70 |
+
def step(self, ac):
|
| 71 |
+
return self.env.step(ac)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class EpisodicLifeEnv(gym.Wrapper):
|
| 75 |
+
def __init__(self, env):
|
| 76 |
+
"""Make end-of-life == end-of-episode, but only reset on true game over.
|
| 77 |
+
Done by DeepMind for the DQN and co. since it helps value estimation.
|
| 78 |
+
"""
|
| 79 |
+
gym.Wrapper.__init__(self, env)
|
| 80 |
+
self.lives = 0
|
| 81 |
+
self.was_real_done = True
|
| 82 |
+
|
| 83 |
+
def step(self, action):
|
| 84 |
+
obs, reward, done, info = self.env.step(action)
|
| 85 |
+
self.was_real_done = done
|
| 86 |
+
# check current lives, make loss of life terminal,
|
| 87 |
+
# then update lives to handle bonus lives
|
| 88 |
+
lives = self.env.unwrapped.ale.lives()
|
| 89 |
+
if lives < self.lives and lives > 0:
|
| 90 |
+
# for Qbert sometimes we stay in lives == 0 condition for a few frames
|
| 91 |
+
# so it's important to keep lives > 0, so that we only reset once
|
| 92 |
+
# the environment advertises done.
|
| 93 |
+
done = True
|
| 94 |
+
self.lives = lives
|
| 95 |
+
return obs, reward, done, info
|
| 96 |
+
|
| 97 |
+
def reset(self, **kwargs):
|
| 98 |
+
"""Reset only when lives are exhausted.
|
| 99 |
+
This way all states are still reachable even though lives are episodic,
|
| 100 |
+
and the learner need not know about any of this behind-the-scenes.
|
| 101 |
+
"""
|
| 102 |
+
if self.was_real_done:
|
| 103 |
+
obs = self.env.reset(**kwargs)
|
| 104 |
+
else:
|
| 105 |
+
# no-op step to advance from terminal/lost life state
|
| 106 |
+
obs, _, _, _ = self.env.step(0)
|
| 107 |
+
self.lives = self.env.unwrapped.ale.lives()
|
| 108 |
+
return obs
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class MaxAndSkipEnv(gym.Wrapper):
|
| 112 |
+
def __init__(self, env, skip=4):
|
| 113 |
+
"""Return only every `skip`-th frame"""
|
| 114 |
+
gym.Wrapper.__init__(self, env)
|
| 115 |
+
# most recent raw observations (for max pooling across time steps)
|
| 116 |
+
self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
|
| 117 |
+
self._skip = skip
|
| 118 |
+
|
| 119 |
+
def step(self, action):
|
| 120 |
+
"""Repeat action, sum reward, and max over last observations."""
|
| 121 |
+
total_reward = 0.0
|
| 122 |
+
done = None
|
| 123 |
+
for i in range(self._skip):
|
| 124 |
+
obs, reward, done, info = self.env.step(action)
|
| 125 |
+
if i == self._skip - 2:
|
| 126 |
+
self._obs_buffer[0] = obs
|
| 127 |
+
if i == self._skip - 1:
|
| 128 |
+
self._obs_buffer[1] = obs
|
| 129 |
+
total_reward += reward
|
| 130 |
+
if done:
|
| 131 |
+
break
|
| 132 |
+
# Note that the observation on the done=True frame
|
| 133 |
+
# doesn't matter
|
| 134 |
+
max_frame = self._obs_buffer.max(axis=0)
|
| 135 |
+
|
| 136 |
+
return max_frame, total_reward, done, info
|
| 137 |
+
|
| 138 |
+
def reset(self, **kwargs):
|
| 139 |
+
return self.env.reset(**kwargs)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class ClipRewardEnv(gym.RewardWrapper):
|
| 143 |
+
def __init__(self, env):
|
| 144 |
+
gym.RewardWrapper.__init__(self, env)
|
| 145 |
+
|
| 146 |
+
def reward(self, reward):
|
| 147 |
+
"""Bin reward to {+1, 0, -1} by its sign."""
|
| 148 |
+
return np.sign(reward)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class WarpFrame(gym.ObservationWrapper):
|
| 152 |
+
def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
|
| 153 |
+
"""
|
| 154 |
+
Warp frames to 84x84 as done in the Nature paper and later work.
|
| 155 |
+
If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
|
| 156 |
+
observation should be warped.
|
| 157 |
+
"""
|
| 158 |
+
super().__init__(env)
|
| 159 |
+
self._width = width
|
| 160 |
+
self._height = height
|
| 161 |
+
self._grayscale = grayscale
|
| 162 |
+
self._key = dict_space_key
|
| 163 |
+
if self._grayscale:
|
| 164 |
+
num_colors = 1
|
| 165 |
+
else:
|
| 166 |
+
num_colors = 3
|
| 167 |
+
|
| 168 |
+
new_space = gym.spaces.Box(
|
| 169 |
+
low=0,
|
| 170 |
+
high=255,
|
| 171 |
+
shape=(self._height, self._width, num_colors),
|
| 172 |
+
dtype=np.uint8,
|
| 173 |
+
)
|
| 174 |
+
if self._key is None:
|
| 175 |
+
original_space = self.observation_space
|
| 176 |
+
self.observation_space = new_space
|
| 177 |
+
else:
|
| 178 |
+
original_space = self.observation_space.spaces[self._key]
|
| 179 |
+
self.observation_space.spaces[self._key] = new_space
|
| 180 |
+
assert original_space.dtype == np.uint8 and len(original_space.shape) == 3
|
| 181 |
+
|
| 182 |
+
def observation(self, obs):
|
| 183 |
+
if self._key is None:
|
| 184 |
+
frame = obs
|
| 185 |
+
else:
|
| 186 |
+
frame = obs[self._key]
|
| 187 |
+
|
| 188 |
+
if self._grayscale:
|
| 189 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
| 190 |
+
frame = cv2.resize(
|
| 191 |
+
frame, (self._width, self._height), interpolation=cv2.INTER_AREA
|
| 192 |
+
)
|
| 193 |
+
if self._grayscale:
|
| 194 |
+
frame = np.expand_dims(frame, -1)
|
| 195 |
+
|
| 196 |
+
if self._key is None:
|
| 197 |
+
obs = frame
|
| 198 |
+
else:
|
| 199 |
+
obs = obs.copy()
|
| 200 |
+
obs[self._key] = frame
|
| 201 |
+
return obs
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class FrameStack(gym.Wrapper):
|
| 205 |
+
def __init__(self, env, k):
|
| 206 |
+
"""Stack k last frames.
|
| 207 |
+
Returns lazy array, which is much more memory efficient.
|
| 208 |
+
See Also
|
| 209 |
+
--------
|
| 210 |
+
baselines.common.atari_wrappers.LazyFrames
|
| 211 |
+
"""
|
| 212 |
+
gym.Wrapper.__init__(self, env)
|
| 213 |
+
self.k = k
|
| 214 |
+
self.frames = deque([], maxlen=k)
|
| 215 |
+
shp = env.observation_space.shape
|
| 216 |
+
self.observation_space = spaces.Box(
|
| 217 |
+
low=0,
|
| 218 |
+
high=255,
|
| 219 |
+
shape=(shp[:-1] + (shp[-1] * k,)),
|
| 220 |
+
dtype=env.observation_space.dtype,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
def reset(self):
|
| 224 |
+
ob = self.env.reset()
|
| 225 |
+
for _ in range(self.k):
|
| 226 |
+
self.frames.append(ob)
|
| 227 |
+
return self._get_ob()
|
| 228 |
+
|
| 229 |
+
def step(self, action):
|
| 230 |
+
ob, reward, done, info = self.env.step(action)
|
| 231 |
+
self.frames.append(ob)
|
| 232 |
+
return self._get_ob(), reward, done, info
|
| 233 |
+
|
| 234 |
+
def _get_ob(self):
|
| 235 |
+
assert len(self.frames) == self.k
|
| 236 |
+
return LazyFrames(list(self.frames))
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class ScaledFloatFrame(gym.ObservationWrapper):
|
| 240 |
+
def __init__(self, env):
|
| 241 |
+
gym.ObservationWrapper.__init__(self, env)
|
| 242 |
+
self.observation_space = gym.spaces.Box(
|
| 243 |
+
low=0, high=1, shape=env.observation_space.shape, dtype=np.float32
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
def observation(self, observation):
|
| 247 |
+
# careful! This undoes the memory optimization, use
|
| 248 |
+
# with smaller replay buffers only.
|
| 249 |
+
return np.array(observation).astype(np.float32) / 255.0
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class LazyFrames(object):
|
| 253 |
+
def __init__(self, frames):
|
| 254 |
+
"""This object ensures that common frames between the observations are only stored once.
|
| 255 |
+
It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
|
| 256 |
+
buffers.
|
| 257 |
+
This object should only be converted to numpy array before being passed to the model.
|
| 258 |
+
You'd not believe how complex the previous solution was."""
|
| 259 |
+
self._frames = frames
|
| 260 |
+
self._out = None
|
| 261 |
+
|
| 262 |
+
def _force(self):
|
| 263 |
+
if self._out is None:
|
| 264 |
+
self._out = np.concatenate(self._frames, axis=-1)
|
| 265 |
+
self._frames = None
|
| 266 |
+
return self._out
|
| 267 |
+
|
| 268 |
+
def __array__(self, dtype=None):
|
| 269 |
+
out = self._force()
|
| 270 |
+
if dtype is not None:
|
| 271 |
+
out = out.astype(dtype)
|
| 272 |
+
return out
|
| 273 |
+
|
| 274 |
+
def __len__(self):
|
| 275 |
+
return len(self._force())
|
| 276 |
+
|
| 277 |
+
def __getitem__(self, i):
|
| 278 |
+
return self._force()[i]
|
| 279 |
+
|
| 280 |
+
def count(self):
|
| 281 |
+
frames = self._force()
|
| 282 |
+
return frames.shape[frames.ndim - 1]
|
| 283 |
+
|
| 284 |
+
def frame(self, i):
|
| 285 |
+
return self._force()[..., i]
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def make_atari(env_id, max_episode_steps=None):
|
| 289 |
+
env = gym.make(env_id)
|
| 290 |
+
assert "NoFrameskip" in env.spec.id
|
| 291 |
+
env = NoopResetEnv(env, noop_max=30)
|
| 292 |
+
env = MaxAndSkipEnv(env, skip=4)
|
| 293 |
+
if max_episode_steps is not None:
|
| 294 |
+
env = TimeLimit(env, max_episode_steps=max_episode_steps)
|
| 295 |
+
return env
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def wrap_deepmind(
|
| 299 |
+
env, episode_life=True, clip_rewards=True, frame_stack=True, scale=True
|
| 300 |
+
):
|
| 301 |
+
"""Configure environment for DeepMind-style Atari."""
|
| 302 |
+
if episode_life:
|
| 303 |
+
env = EpisodicLifeEnv(env)
|
| 304 |
+
if "FIRE" in env.unwrapped.get_action_meanings():
|
| 305 |
+
env = FireResetEnv(env)
|
| 306 |
+
env = WarpFrame(env)
|
| 307 |
+
if scale:
|
| 308 |
+
env = ScaledFloatFrame(env)
|
| 309 |
+
if clip_rewards:
|
| 310 |
+
env = ClipRewardEnv(env)
|
| 311 |
+
if frame_stack:
|
| 312 |
+
env = FrameStack(env, 4)
|
| 313 |
+
return env
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class EpisodicLifeMario(gym.Wrapper):
|
| 317 |
+
def __init__(self, env):
|
| 318 |
+
"""Make end-of-life == end-of-episode, but only reset on true game over.
|
| 319 |
+
Done by DeepMind for the DQN and co. since it helps value estimation.
|
| 320 |
+
"""
|
| 321 |
+
gym.Wrapper.__init__(self, env)
|
| 322 |
+
self.lives = 0
|
| 323 |
+
self.was_real_done = True
|
| 324 |
+
|
| 325 |
+
def step(self, action):
|
| 326 |
+
obs, reward, done, info = self.env.step(action)
|
| 327 |
+
self.was_real_done = done
|
| 328 |
+
# check current lives, make loss of life terminal,
|
| 329 |
+
# then update lives to handle bonus lives
|
| 330 |
+
lives = self.env.unwrapped._life
|
| 331 |
+
if lives < self.lives and lives > 0:
|
| 332 |
+
# for Qbert sometimes we stay in lives == 0 condition for a few frames
|
| 333 |
+
# so it's important to keep lives > 0, so that we only reset once
|
| 334 |
+
# the environment advertises done.
|
| 335 |
+
done = True
|
| 336 |
+
self.lives = lives
|
| 337 |
+
return obs, reward, done, info
|
| 338 |
+
|
| 339 |
+
def reset(self, **kwargs):
|
| 340 |
+
"""Reset only when lives are exhausted.
|
| 341 |
+
This way all states are still reachable even though lives are episodic,
|
| 342 |
+
and the learner need not know about any of this behind-the-scenes.
|
| 343 |
+
"""
|
| 344 |
+
if self.was_real_done:
|
| 345 |
+
obs = self.env.reset(**kwargs)
|
| 346 |
+
else:
|
| 347 |
+
# no-op step to advance from terminal/lost life state
|
| 348 |
+
obs, _, _, _ = self.env.step(0)
|
| 349 |
+
self.lives = self.env.unwrapped._life
|
| 350 |
+
return obs
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def wrap_mario(env):
|
| 354 |
+
env = NoopResetEnv(env, noop_max=30)
|
| 355 |
+
env = MaxAndSkipEnv(env, skip=4)
|
| 356 |
+
env = EpisodicLifeMario(env)
|
| 357 |
+
env = WarpFrame(env)
|
| 358 |
+
env = ScaledFloatFrame(env)
|
| 359 |
+
# env = custom_reward(env)
|
| 360 |
+
env = FrameStack(env, 4)
|
| 361 |
+
return env
|
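`wrap_mario` above is the preprocessing entry point the Mario training and evaluation scripts build on. Below is a minimal usage sketch, assuming the gym 0.23 / nes-py / gym-super-mario-bros 7.2.3 stack pinned in requirements.txt; the stage id and action set here are illustrative, not mandated by the wrappers.

```python
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace

from wrappers import wrap_mario

# Build stage 1-1 with a small discrete action set, then apply the wrapper stack above
env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = wrap_mario(env)  # frame skip, episodic lives, 84x84 grayscale, scaling, 4-frame stack

obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())  # random rollout
env.close()
```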
ale_pyqt5/app.py
ADDED
@@ -0,0 +1,514 @@
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
from collections import deque
|
| 6 |
+
import gymnasium as gym
|
| 7 |
+
import ale_py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.distributions import Categorical
|
| 14 |
+
|
| 15 |
+
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
| 16 |
+
QHBoxLayout, QPushButton, QLabel, QComboBox,
|
| 17 |
+
QTextEdit, QProgressBar, QTabWidget, QFrame)
|
| 18 |
+
from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
|
| 19 |
+
from PyQt5.QtGui import QImage, QPixmap, QFont
|
| 20 |
+
|
| 21 |
+
# Register ALE environments
|
| 22 |
+
gym.register_envs(ale_py)
|
| 23 |
+
|
| 24 |
+
# Environment setup
|
| 25 |
+
def create_env(env_name='ALE/Breakout-v5'):
|
| 26 |
+
"""
|
| 27 |
+
Create ALE environment with Gymnasium API
|
| 28 |
+
Available environments:
|
| 29 |
+
- ALE/Breakout-v5, ALE/Pong-v5, ALE/SpaceInvaders-v5,
|
| 30 |
+
- ALE/Assault-v5, ALE/BeamRider-v5, ALE/Enduro-v5
|
| 31 |
+
"""
|
| 32 |
+
env = gym.make(env_name, render_mode='rgb_array')
|
| 33 |
+
return env
|
| 34 |
+
|
| 35 |
+
# Neural Network for Dueling DQN
|
| 36 |
+
class DuelingDQN(nn.Module):
|
| 37 |
+
def __init__(self, input_shape, n_actions):
|
| 38 |
+
super(DuelingDQN, self).__init__()
|
| 39 |
+
self.conv = nn.Sequential(
|
| 40 |
+
nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
|
| 41 |
+
nn.ReLU(),
|
| 42 |
+
nn.Conv2d(32, 64, kernel_size=4, stride=2),
|
| 43 |
+
nn.ReLU(),
|
| 44 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1),
|
| 45 |
+
nn.ReLU()
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
conv_out_size = self._get_conv_out(input_shape)
|
| 49 |
+
|
| 50 |
+
self.fc_advantage = nn.Sequential(
|
| 51 |
+
nn.Linear(conv_out_size, 512),
|
| 52 |
+
nn.ReLU(),
|
| 53 |
+
nn.Linear(512, n_actions)
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
self.fc_value = nn.Sequential(
|
| 57 |
+
nn.Linear(conv_out_size, 512),
|
| 58 |
+
nn.ReLU(),
|
| 59 |
+
nn.Linear(512, 1)
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def _get_conv_out(self, shape):
|
| 63 |
+
o = self.conv(torch.zeros(1, *shape))
|
| 64 |
+
return int(np.prod(o.size()))
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
conv_out = self.conv(x).view(x.size()[0], -1)
|
| 68 |
+
advantage = self.fc_advantage(conv_out)
|
| 69 |
+
value = self.fc_value(conv_out)
|
| 70 |
+
return value + advantage - advantage.mean()
|
| 71 |
+
|
| 72 |
+
# Neural Network for PPO
|
| 73 |
+
class PPONetwork(nn.Module):
|
| 74 |
+
def __init__(self, input_shape, n_actions):
|
| 75 |
+
super(PPONetwork, self).__init__()
|
| 76 |
+
self.conv = nn.Sequential(
|
| 77 |
+
nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
|
| 78 |
+
nn.ReLU(),
|
| 79 |
+
nn.Conv2d(32, 64, kernel_size=4, stride=2),
|
| 80 |
+
nn.ReLU(),
|
| 81 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1),
|
| 82 |
+
nn.ReLU()
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
conv_out_size = self._get_conv_out(input_shape)
|
| 86 |
+
|
| 87 |
+
self.actor = nn.Sequential(
|
| 88 |
+
nn.Linear(conv_out_size, 512),
|
| 89 |
+
nn.ReLU(),
|
| 90 |
+
nn.Linear(512, n_actions),
|
| 91 |
+
nn.Softmax(dim=-1)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
self.critic = nn.Sequential(
|
| 95 |
+
nn.Linear(conv_out_size, 512),
|
| 96 |
+
nn.ReLU(),
|
| 97 |
+
nn.Linear(512, 1)
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def _get_conv_out(self, shape):
|
| 101 |
+
o = self.conv(torch.zeros(1, *shape))
|
| 102 |
+
return int(np.prod(o.size()))
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
|
| 105 |
+
conv_out = self.conv(x).view(x.size()[0], -1)
|
| 106 |
+
return self.actor(conv_out), self.critic(conv_out)
|
| 107 |
+
|
| 108 |
+
# Dueling DQN Agent
|
| 109 |
+
class DuelingDQNAgent:
|
| 110 |
+
def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, epsilon=1.0,
|
| 111 |
+
epsilon_min=0.01, epsilon_decay=0.995, memory_size=10000, batch_size=32):
|
| 112 |
+
self.state_dim = state_dim
|
| 113 |
+
self.action_dim = action_dim
|
| 114 |
+
self.lr = lr
|
| 115 |
+
self.gamma = gamma
|
| 116 |
+
self.epsilon = epsilon
|
| 117 |
+
self.epsilon_min = epsilon_min
|
| 118 |
+
self.epsilon_decay = epsilon_decay
|
| 119 |
+
self.batch_size = batch_size
|
| 120 |
+
|
| 121 |
+
self.memory = deque(maxlen=memory_size)
|
| 122 |
+
self.model = DuelingDQN(state_dim, action_dim)
|
| 123 |
+
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
| 124 |
+
self.criterion = nn.MSELoss()
|
| 125 |
+
|
| 126 |
+
def remember(self, state, action, reward, next_state, done):
|
| 127 |
+
self.memory.append((state, action, reward, next_state, done))
|
| 128 |
+
|
| 129 |
+
def act(self, state):
|
| 130 |
+
if np.random.random() <= self.epsilon:
|
| 131 |
+
return random.randrange(self.action_dim)
|
| 132 |
+
|
| 133 |
+
state = torch.FloatTensor(state).unsqueeze(0)
|
| 134 |
+
with torch.no_grad():
|
| 135 |
+
q_values = self.model(state)
|
| 136 |
+
return np.argmax(q_values.detach().numpy())
|
| 137 |
+
|
| 138 |
+
def replay(self):
|
| 139 |
+
if len(self.memory) < self.batch_size:
|
| 140 |
+
return
|
| 141 |
+
|
| 142 |
+
batch = random.sample(self.memory, self.batch_size)
|
| 143 |
+
states = torch.FloatTensor(np.array([e[0] for e in batch]))
|
| 144 |
+
actions = torch.LongTensor([e[1] for e in batch])
|
| 145 |
+
rewards = torch.FloatTensor([e[2] for e in batch])
|
| 146 |
+
next_states = torch.FloatTensor(np.array([e[3] for e in batch]))
|
| 147 |
+
dones = torch.BoolTensor([e[4] for e in batch])
|
| 148 |
+
|
| 149 |
+
current_q_values = self.model(states).gather(1, actions.unsqueeze(1))
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
next_q_values = self.model(next_states).max(1)[0]
|
| 152 |
+
target_q_values = rewards + (self.gamma * next_q_values * ~dones)
|
| 153 |
+
|
| 154 |
+
loss = self.criterion(current_q_values.squeeze(), target_q_values)
|
| 155 |
+
|
| 156 |
+
self.optimizer.zero_grad()
|
| 157 |
+
loss.backward()
|
| 158 |
+
self.optimizer.step()
|
| 159 |
+
|
| 160 |
+
if self.epsilon > self.epsilon_min:
|
| 161 |
+
self.epsilon *= self.epsilon_decay
|
| 162 |
+
|
| 163 |
+
# PPO Agent
|
| 164 |
+
class PPOAgent:
|
| 165 |
+
def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2,
|
| 166 |
+
entropy_coef=0.01, value_coef=0.5):
|
| 167 |
+
self.state_dim = state_dim
|
| 168 |
+
self.action_dim = action_dim
|
| 169 |
+
self.gamma = gamma
|
| 170 |
+
self.epsilon = epsilon
|
| 171 |
+
self.entropy_coef = entropy_coef
|
| 172 |
+
self.value_coef = value_coef
|
| 173 |
+
|
| 174 |
+
self.model = PPONetwork(state_dim, action_dim)
|
| 175 |
+
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
| 176 |
+
|
| 177 |
+
self.memory = []
|
| 178 |
+
|
| 179 |
+
def remember(self, state, action, reward, value, log_prob):
|
| 180 |
+
self.memory.append((state, action, reward, value, log_prob))
|
| 181 |
+
|
| 182 |
+
def act(self, state):
|
| 183 |
+
state = torch.FloatTensor(state).unsqueeze(0)
|
| 184 |
+
with torch.no_grad():
|
| 185 |
+
probs, value = self.model(state)
|
| 186 |
+
dist = Categorical(probs)
|
| 187 |
+
action = dist.sample()
|
| 188 |
+
return action.item(), dist.log_prob(action), value.squeeze()
|
| 189 |
+
|
| 190 |
+
def train(self):
|
| 191 |
+
if not self.memory:
|
| 192 |
+
return
|
| 193 |
+
|
| 194 |
+
states, actions, rewards, values, log_probs = zip(*self.memory)
|
| 195 |
+
|
| 196 |
+
# Calculate returns and advantages
|
| 197 |
+
returns = []
|
| 198 |
+
R = 0
|
| 199 |
+
|
| 200 |
+
for r in reversed(rewards):
|
| 201 |
+
R = r + self.gamma * R
|
| 202 |
+
returns.insert(0, R)
|
| 203 |
+
|
| 204 |
+
returns = torch.FloatTensor(returns)
|
| 205 |
+
values = torch.FloatTensor(values)
|
| 206 |
+
advantages = returns - values
|
| 207 |
+
|
| 208 |
+
# Normalize advantages
|
| 209 |
+
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
| 210 |
+
|
| 211 |
+
# Convert to tensors
|
| 212 |
+
states = torch.FloatTensor(np.array(states))
|
| 213 |
+
actions = torch.LongTensor(actions)
|
| 214 |
+
old_log_probs = torch.FloatTensor(log_probs)
|
| 215 |
+
|
| 216 |
+
# Get new probabilities
|
| 217 |
+
new_probs, new_values = self.model(states)
|
| 218 |
+
dist = Categorical(new_probs)
|
| 219 |
+
new_log_probs = dist.log_prob(actions)
|
| 220 |
+
entropy = dist.entropy().mean()
|
| 221 |
+
|
| 222 |
+
# PPO loss
|
| 223 |
+
ratio = (new_log_probs - old_log_probs).exp()
|
| 224 |
+
surr1 = ratio * advantages
|
| 225 |
+
surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
|
| 226 |
+
actor_loss = -torch.min(surr1, surr2).mean()
|
| 227 |
+
|
| 228 |
+
critic_loss = F.mse_loss(new_values.squeeze(), returns)
|
| 229 |
+
|
| 230 |
+
total_loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy
|
| 231 |
+
|
| 232 |
+
self.optimizer.zero_grad()
|
| 233 |
+
total_loss.backward()
|
| 234 |
+
self.optimizer.step()
|
| 235 |
+
|
| 236 |
+
self.memory = []
|
| 237 |
+
|
| 238 |
+
# Training Thread
|
| 239 |
+
class TrainingThread(QThread):
|
| 240 |
+
update_signal = pyqtSignal(dict)
|
| 241 |
+
frame_signal = pyqtSignal(np.ndarray)
|
| 242 |
+
|
| 243 |
+
def __init__(self, algorithm='dqn', env_name='ALE/Breakout-v5'):
|
| 244 |
+
super().__init__()
|
| 245 |
+
self.algorithm = algorithm
|
| 246 |
+
self.env_name = env_name
|
| 247 |
+
self.running = False
|
| 248 |
+
self.env = None
|
| 249 |
+
self.agent = None
|
| 250 |
+
|
| 251 |
+
def preprocess_state(self, state):
|
| 252 |
+
# Convert to CHW format and normalize
|
| 253 |
+
state = state.transpose((2, 0, 1))
|
| 254 |
+
state = state / 255.0
|
| 255 |
+
return state
|
| 256 |
+
|
| 257 |
+
def run(self):
|
| 258 |
+
self.running = True
|
| 259 |
+
try:
|
| 260 |
+
self.env = create_env(self.env_name)
|
| 261 |
+
state, info = self.env.reset()
|
| 262 |
+
state = self.preprocess_state(state)
|
| 263 |
+
|
| 264 |
+
n_actions = self.env.action_space.n
|
| 265 |
+
state_dim = state.shape
|
| 266 |
+
|
| 267 |
+
print(f"Environment: {self.env_name}")
|
| 268 |
+
print(f"State shape: {state_dim}, Actions: {n_actions}")
|
| 269 |
+
|
| 270 |
+
if self.algorithm == 'dqn':
|
| 271 |
+
self.agent = DuelingDQNAgent(state_dim, n_actions)
|
| 272 |
+
else:
|
| 273 |
+
self.agent = PPOAgent(state_dim, n_actions)
|
| 274 |
+
|
| 275 |
+
episode = 0
|
| 276 |
+
total_reward = 0
|
| 277 |
+
steps = 0
|
| 278 |
+
episode_rewards = []
|
| 279 |
+
|
| 280 |
+
while self.running:
|
| 281 |
+
try:
|
| 282 |
+
if self.algorithm == 'dqn':
|
| 283 |
+
action = self.agent.act(state)
|
| 284 |
+
next_state, reward, terminated, truncated, info = self.env.step(action)
|
| 285 |
+
done = terminated or truncated
|
| 286 |
+
next_state = self.preprocess_state(next_state)
|
| 287 |
+
self.agent.remember(state, action, reward, next_state, done)
|
| 288 |
+
self.agent.replay()
|
| 289 |
+
else:
|
| 290 |
+
action, log_prob, value = self.agent.act(state)
|
| 291 |
+
next_state, reward, terminated, truncated, info = self.env.step(action)
|
| 292 |
+
done = terminated or truncated
|
| 293 |
+
next_state = self.preprocess_state(next_state)
|
| 294 |
+
self.agent.remember(state, action, reward, value, log_prob)
|
| 295 |
+
if done:
|
| 296 |
+
self.agent.train()
|
| 297 |
+
|
| 298 |
+
state = next_state
|
| 299 |
+
total_reward += reward
|
| 300 |
+
steps += 1
|
| 301 |
+
|
| 302 |
+
# Emit frame for display
|
| 303 |
+
try:
|
| 304 |
+
frame = self.env.render()
|
| 305 |
+
if frame is not None:
|
| 306 |
+
self.frame_signal.emit(frame)
|
| 307 |
+
except Exception as e:
|
| 308 |
+
# Create a placeholder frame if rendering fails
|
| 309 |
+
frame = np.zeros((210, 160, 3), dtype=np.uint8)
|
| 310 |
+
self.frame_signal.emit(frame)
|
| 311 |
+
|
| 312 |
+
# Emit training progress
|
| 313 |
+
if steps % 10 == 0:
|
| 314 |
+
progress_data = {
|
| 315 |
+
'episode': episode,
|
| 316 |
+
'total_reward': total_reward,
|
| 317 |
+
'steps': steps,
|
| 318 |
+
'epsilon': self.agent.epsilon if self.algorithm == 'dqn' else 0.2,
|
| 319 |
+
'env_name': self.env_name,
|
| 320 |
+
'lives': info.get('lives', 0) if isinstance(info, dict) else 0
|
| 321 |
+
}
|
| 322 |
+
self.update_signal.emit(progress_data)
|
| 323 |
+
|
| 324 |
+
if terminated or truncated:
|
| 325 |
+
episode_rewards.append(total_reward)
|
| 326 |
+
avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward
|
| 327 |
+
|
| 328 |
+
print(f"Episode {episode}: Total Reward: {total_reward:.2f}, "
|
| 329 |
+
f"Steps: {steps}, Avg Reward (last 10): {avg_reward:.2f}")
|
| 330 |
+
|
| 331 |
+
episode += 1
|
| 332 |
+
state, info = self.env.reset()
|
| 333 |
+
state = self.preprocess_state(state)
|
| 334 |
+
total_reward = 0
|
| 335 |
+
steps = 0
|
| 336 |
+
|
| 337 |
+
except Exception as e:
|
| 338 |
+
print(f"Error in training loop: {e}")
|
| 339 |
+
import traceback
|
| 340 |
+
traceback.print_exc()
|
| 341 |
+
break
|
| 342 |
+
|
| 343 |
+
except Exception as e:
|
| 344 |
+
print(f"Error setting up environment: {e}")
|
| 345 |
+
import traceback
|
| 346 |
+
traceback.print_exc()
|
| 347 |
+
|
| 348 |
+
def stop(self):
|
| 349 |
+
self.running = False
|
| 350 |
+
if self.env:
|
| 351 |
+
self.env.close()
|
| 352 |
+
|
| 353 |
+
# Main Application Window
|
| 354 |
+
class ALE_RLApp(QMainWindow):
|
| 355 |
+
def __init__(self):
|
| 356 |
+
super().__init__()
|
| 357 |
+
self.training_thread = None
|
| 358 |
+
self.init_ui()
|
| 359 |
+
|
| 360 |
+
def init_ui(self):
|
| 361 |
+
self.setWindowTitle('๐ฎ ALE Arcade RL Training')
|
| 362 |
+
self.setGeometry(100, 100, 1200, 800)
|
| 363 |
+
|
| 364 |
+
central_widget = QWidget()
|
| 365 |
+
self.setCentralWidget(central_widget)
|
| 366 |
+
layout = QVBoxLayout(central_widget)
|
| 367 |
+
|
| 368 |
+
# Title
|
| 369 |
+
title = QLabel('๐ฎ Arcade Reinforcement Learning (ALE)')
|
| 370 |
+
title.setFont(QFont('Arial', 16, QFont.Bold))
|
| 371 |
+
title.setAlignment(Qt.AlignCenter)
|
| 372 |
+
layout.addWidget(title)
|
| 373 |
+
|
| 374 |
+
# Control Panel
|
| 375 |
+
control_layout = QHBoxLayout()
|
| 376 |
+
|
| 377 |
+
self.algorithm_combo = QComboBox()
|
| 378 |
+
self.algorithm_combo.addItems(['Dueling DQN', 'PPO'])
|
| 379 |
+
|
| 380 |
+
self.env_combo = QComboBox()
|
| 381 |
+
self.env_combo.addItems([
|
| 382 |
+
'ALE/Breakout-v5',
|
| 383 |
+
'ALE/Pong-v5',
|
| 384 |
+
'ALE/SpaceInvaders-v5',
|
| 385 |
+
'ALE/Assault-v5',
|
| 386 |
+
'ALE/BeamRider-v5',
|
| 387 |
+
'ALE/Enduro-v5',
|
| 388 |
+
'ALE/Seaquest-v5',
|
| 389 |
+
'ALE/Qbert-v5'
|
| 390 |
+
])
|
| 391 |
+
|
| 392 |
+
self.start_btn = QPushButton('Start Training')
|
| 393 |
+
self.start_btn.clicked.connect(self.start_training)
|
| 394 |
+
|
| 395 |
+
self.stop_btn = QPushButton('Stop Training')
|
| 396 |
+
self.stop_btn.clicked.connect(self.stop_training)
|
| 397 |
+
self.stop_btn.setEnabled(False)
|
| 398 |
+
|
| 399 |
+
control_layout.addWidget(QLabel('Algorithm:'))
|
| 400 |
+
control_layout.addWidget(self.algorithm_combo)
|
| 401 |
+
control_layout.addWidget(QLabel('Environment:'))
|
| 402 |
+
control_layout.addWidget(self.env_combo)
|
| 403 |
+
control_layout.addWidget(self.start_btn)
|
| 404 |
+
control_layout.addWidget(self.stop_btn)
|
| 405 |
+
control_layout.addStretch()
|
| 406 |
+
|
| 407 |
+
layout.addLayout(control_layout)
|
| 408 |
+
|
| 409 |
+
# Content Area
|
| 410 |
+
content_layout = QHBoxLayout()
|
| 411 |
+
|
| 412 |
+
# Left side - Game Display
|
| 413 |
+
left_frame = QFrame()
|
| 414 |
+
left_frame.setFrameStyle(QFrame.Box)
|
| 415 |
+
left_layout = QVBoxLayout(left_frame)
|
| 416 |
+
|
| 417 |
+
self.game_display = QLabel()
|
| 418 |
+
self.game_display.setMinimumSize(400, 300)
|
| 419 |
+
self.game_display.setAlignment(Qt.AlignCenter)
|
| 420 |
+
self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
|
| 421 |
+
self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')
|
| 422 |
+
|
| 423 |
+
left_layout.addWidget(QLabel('Game Display:'))
|
| 424 |
+
left_layout.addWidget(self.game_display)
|
| 425 |
+
|
| 426 |
+
# Right side - Training Info
|
| 427 |
+
right_frame = QFrame()
|
| 428 |
+
right_frame.setFrameStyle(QFrame.Box)
|
| 429 |
+
right_layout = QVBoxLayout(right_frame)
|
| 430 |
+
|
| 431 |
+
# Progress bars
|
| 432 |
+
self.env_label = QLabel('Environment: Not started')
|
| 433 |
+
self.episode_label = QLabel('Episode: 0')
|
| 434 |
+
self.reward_label = QLabel('Total Reward: 0')
|
| 435 |
+
self.steps_label = QLabel('Steps: 0')
|
| 436 |
+
self.epsilon_label = QLabel('Epsilon: 0')
|
| 437 |
+
self.lives_label = QLabel('Lives: 0')
|
| 438 |
+
|
| 439 |
+
right_layout.addWidget(self.env_label)
|
| 440 |
+
right_layout.addWidget(self.episode_label)
|
| 441 |
+
right_layout.addWidget(self.reward_label)
|
| 442 |
+
right_layout.addWidget(self.steps_label)
|
| 443 |
+
right_layout.addWidget(self.epsilon_label)
|
| 444 |
+
right_layout.addWidget(self.lives_label)
|
| 445 |
+
|
| 446 |
+
# Training log
|
| 447 |
+
right_layout.addWidget(QLabel('Training Log:'))
|
| 448 |
+
self.log_text = QTextEdit()
|
| 449 |
+
self.log_text.setMaximumHeight(200)
|
| 450 |
+
right_layout.addWidget(self.log_text)
|
| 451 |
+
|
| 452 |
+
content_layout.addWidget(left_frame)
|
| 453 |
+
content_layout.addWidget(right_frame)
|
| 454 |
+
layout.addLayout(content_layout)
|
| 455 |
+
|
| 456 |
+
def start_training(self):
|
| 457 |
+
algorithm = 'dqn' if self.algorithm_combo.currentText() == 'Dueling DQN' else 'ppo'
|
| 458 |
+
env_name = self.env_combo.currentText()
|
| 459 |
+
|
| 460 |
+
self.training_thread = TrainingThread(algorithm, env_name)
|
| 461 |
+
self.training_thread.update_signal.connect(self.update_training_info)
|
| 462 |
+
self.training_thread.frame_signal.connect(self.update_game_display)
|
| 463 |
+
self.training_thread.start()
|
| 464 |
+
|
| 465 |
+
self.start_btn.setEnabled(False)
|
| 466 |
+
self.stop_btn.setEnabled(True)
|
| 467 |
+
|
| 468 |
+
self.log_text.append(f'Started {self.algorithm_combo.currentText()} training on {env_name}...')
|
| 469 |
+
|
| 470 |
+
def stop_training(self):
|
| 471 |
+
if self.training_thread:
|
| 472 |
+
self.training_thread.stop()
|
| 473 |
+
self.training_thread.wait()
|
| 474 |
+
|
| 475 |
+
self.start_btn.setEnabled(True)
|
| 476 |
+
self.stop_btn.setEnabled(False)
|
| 477 |
+
self.log_text.append('Training stopped.')
|
| 478 |
+
|
| 479 |
+
def update_training_info(self, data):
|
| 480 |
+
self.env_label.setText(f'Environment: {data.get("env_name", "Unknown")}')
|
| 481 |
+
self.episode_label.setText(f'Episode: {data["episode"]}')
|
| 482 |
+
self.reward_label.setText(f'Total Reward: {data["total_reward"]:.2f}')
|
| 483 |
+
self.steps_label.setText(f'Steps: {data["steps"]}')
|
| 484 |
+
self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
|
| 485 |
+
self.lives_label.setText(f'Lives: {data.get("lives", 0)}')
|
| 486 |
+
|
| 487 |
+
def update_game_display(self, frame):
|
| 488 |
+
if frame is not None:
|
| 489 |
+
try:
|
| 490 |
+
h, w, ch = frame.shape
|
| 491 |
+
bytes_per_line = ch * w
|
| 492 |
+
q_img = QImage(frame.data, w, h, bytes_per_line, QImage.Format_RGB888)
|
| 493 |
+
pixmap = QPixmap.fromImage(q_img)
|
| 494 |
+
self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
|
| 495 |
+
except Exception as e:
|
| 496 |
+
print(f"Error updating display: {e}")
|
| 497 |
+
|
| 498 |
+
def closeEvent(self, event):
|
| 499 |
+
self.stop_training()
|
| 500 |
+
event.accept()
|
| 501 |
+
|
| 502 |
+
def main():
|
| 503 |
+
# Set random seeds for reproducibility
|
| 504 |
+
torch.manual_seed(42)
|
| 505 |
+
np.random.seed(42)
|
| 506 |
+
random.seed(42)
|
| 507 |
+
|
| 508 |
+
app = QApplication(sys.argv)
|
| 509 |
+
window = ALE_RLApp()
|
| 510 |
+
window.show()
|
| 511 |
+
sys.exit(app.exec_())
|
| 512 |
+
|
| 513 |
+
if __name__ == '__main__':
|
| 514 |
+
main()
|
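For context on preprocessing: app.py feeds the full 210x160x3 RGB frame into the convolutional networks, while app_2.py below switches to single-channel grayscale. A more conventional ALE pipeline resizes to 84x84 grayscale and stacks recent frames. The sketch below shows that alternative using Gymnasium's built-in wrappers; it assumes a Gymnasium release (e.g. 0.29.x) where the wrappers are named AtariPreprocessing and FrameStack, and it is an option rather than what app.py currently does.

```python
import ale_py
import gymnasium as gym
import numpy as np
from gymnasium.wrappers import AtariPreprocessing, FrameStack

gym.register_envs(ale_py)

# frameskip=1 on the base env because AtariPreprocessing applies its own 4-frame skip
env = gym.make("ALE/Breakout-v5", render_mode="rgb_array", frameskip=1)
env = AtariPreprocessing(env, grayscale_obs=True, screen_size=84, frame_skip=4)
env = FrameStack(env, 4)

obs, info = env.reset()
print(np.asarray(obs).shape)  # (4, 84, 84)
```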
ale_pyqt5/app_2.py
ADDED
@@ -0,0 +1,559 @@
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
from collections import deque
|
| 6 |
+
import gymnasium as gym
|
| 7 |
+
import ale_py
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.distributions import Categorical
|
| 14 |
+
|
| 15 |
+
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
| 16 |
+
QHBoxLayout, QPushButton, QLabel, QComboBox,
|
| 17 |
+
QTextEdit, QProgressBar, QTabWidget, QFrame)
|
| 18 |
+
from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
|
| 19 |
+
from PyQt5.QtGui import QImage, QPixmap, QFont
|
| 20 |
+
|
| 21 |
+
# Register ALE environments
|
| 22 |
+
gym.register_envs(ale_py)
|
| 23 |
+
|
| 24 |
+
# Environment setup
|
| 25 |
+
def create_env(env_name='ALE/SpaceInvaders-v5'):
|
| 26 |
+
"""
|
| 27 |
+
Create ALE environment with Gymnasium API
|
| 28 |
+
"""
|
| 29 |
+
env = gym.make(env_name, render_mode='rgb_array')
|
| 30 |
+
return env
|
| 31 |
+
|
| 32 |
+
# Enhanced Neural Network for Dueling DQN
|
| 33 |
+
class DuelingDQN(nn.Module):
|
| 34 |
+
def __init__(self, input_shape, n_actions):
|
| 35 |
+
super(DuelingDQN, self).__init__()
|
| 36 |
+
self.conv = nn.Sequential(
|
| 37 |
+
nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
|
| 38 |
+
nn.ReLU(),
|
| 39 |
+
nn.Conv2d(32, 64, kernel_size=4, stride=2),
|
| 40 |
+
nn.ReLU(),
|
| 41 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1),
|
| 42 |
+
nn.ReLU()
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
conv_out_size = self._get_conv_out(input_shape)
|
| 46 |
+
|
| 47 |
+
self.fc_advantage = nn.Sequential(
|
| 48 |
+
nn.Linear(conv_out_size, 256),
|
| 49 |
+
nn.ReLU(),
|
| 50 |
+
nn.Linear(256, n_actions)
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
self.fc_value = nn.Sequential(
|
| 54 |
+
nn.Linear(conv_out_size, 256),
|
| 55 |
+
nn.ReLU(),
|
| 56 |
+
nn.Linear(256, 1)
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def _get_conv_out(self, shape):
|
| 60 |
+
o = self.conv(torch.zeros(1, *shape))
|
| 61 |
+
return int(np.prod(o.size()))
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
conv_out = self.conv(x).view(x.size()[0], -1)
|
| 65 |
+
advantage = self.fc_advantage(conv_out)
|
| 66 |
+
value = self.fc_value(conv_out)
|
| 67 |
+
return value + advantage - advantage.mean()
|
| 68 |
+
|
| 69 |
+
# Enhanced Neural Network for PPO
|
| 70 |
+
class PPONetwork(nn.Module):
|
| 71 |
+
def __init__(self, input_shape, n_actions):
|
| 72 |
+
super(PPONetwork, self).__init__()
|
| 73 |
+
self.conv = nn.Sequential(
|
| 74 |
+
nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
|
| 75 |
+
nn.ReLU(),
|
| 76 |
+
nn.Conv2d(32, 64, kernel_size=4, stride=2),
|
| 77 |
+
nn.ReLU(),
|
| 78 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1),
|
| 79 |
+
nn.ReLU()
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
conv_out_size = self._get_conv_out(input_shape)
|
| 83 |
+
|
| 84 |
+
self.actor = nn.Sequential(
|
| 85 |
+
nn.Linear(conv_out_size, 256),
|
| 86 |
+
nn.ReLU(),
|
| 87 |
+
nn.Linear(256, n_actions),
|
| 88 |
+
nn.Softmax(dim=-1)
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
self.critic = nn.Sequential(
|
| 92 |
+
nn.Linear(conv_out_size, 256),
|
| 93 |
+
nn.ReLU(),
|
| 94 |
+
nn.Linear(256, 1)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
def _get_conv_out(self, shape):
|
| 98 |
+
o = self.conv(torch.zeros(1, *shape))
|
| 99 |
+
return int(np.prod(o.size()))
|
| 100 |
+
|
| 101 |
+
def forward(self, x):
|
| 102 |
+
conv_out = self.conv(x).view(x.size()[0], -1)
|
| 103 |
+
return self.actor(conv_out), self.critic(conv_out)
|
| 104 |
+
|
| 105 |
+
# Enhanced Dueling DQN Agent with better training
|
| 106 |
+
class DuelingDQNAgent:
|
| 107 |
+
def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, epsilon=1.0,
|
| 108 |
+
epsilon_min=0.01, epsilon_decay=0.999, memory_size=50000, batch_size=32):
|
| 109 |
+
self.state_dim = state_dim
|
| 110 |
+
self.action_dim = action_dim
|
| 111 |
+
self.lr = lr
|
| 112 |
+
self.gamma = gamma
|
| 113 |
+
self.epsilon = epsilon
|
| 114 |
+
self.epsilon_min = epsilon_min
|
| 115 |
+
self.epsilon_decay = epsilon_decay
|
| 116 |
+
self.batch_size = batch_size
|
| 117 |
+
|
| 118 |
+
self.memory = deque(maxlen=memory_size)
|
| 119 |
+
self.model = DuelingDQN(state_dim, action_dim)
|
| 120 |
+
self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-5)
|
| 121 |
+
self.criterion = nn.SmoothL1Loss() # Huber loss for better stability
|
| 122 |
+
|
| 123 |
+
# Target network for stable training
|
| 124 |
+
self.target_model = DuelingDQN(state_dim, action_dim)
|
| 125 |
+
self.update_target_network()
|
| 126 |
+
self.target_update_frequency = 1000
|
| 127 |
+
self.train_step = 0
|
| 128 |
+
|
| 129 |
+
def update_target_network(self):
|
| 130 |
+
self.target_model.load_state_dict(self.model.state_dict())
|
| 131 |
+
|
| 132 |
+
def remember(self, state, action, reward, next_state, done):
|
| 133 |
+
self.memory.append((state, action, reward, next_state, done))
|
| 134 |
+
|
| 135 |
+
def act(self, state):
|
| 136 |
+
if np.random.random() <= self.epsilon:
|
| 137 |
+
return random.randrange(self.action_dim)
|
| 138 |
+
|
| 139 |
+
state = torch.FloatTensor(state).unsqueeze(0)
|
| 140 |
+
with torch.no_grad():
|
| 141 |
+
q_values = self.model(state)
|
| 142 |
+
return np.argmax(q_values.detach().numpy())
|
| 143 |
+
|
| 144 |
+
def replay(self):
|
| 145 |
+
if len(self.memory) < self.batch_size:
|
| 146 |
+
return
|
| 147 |
+
|
| 148 |
+
batch = random.sample(self.memory, self.batch_size)
|
| 149 |
+
states = torch.FloatTensor(np.array([e[0] for e in batch]))
|
| 150 |
+
actions = torch.LongTensor([e[1] for e in batch])
|
| 151 |
+
rewards = torch.FloatTensor([e[2] for e in batch])
|
| 152 |
+
next_states = torch.FloatTensor(np.array([e[3] for e in batch]))
|
| 153 |
+
dones = torch.BoolTensor([e[4] for e in batch])
|
| 154 |
+
|
| 155 |
+
current_q_values = self.model(states).gather(1, actions.unsqueeze(1))
|
| 156 |
+
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
next_actions = self.model(next_states).max(1)[1]
|
| 159 |
+
next_q_values = self.target_model(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()
|
| 160 |
+
|
| 161 |
+
target_q_values = rewards + (self.gamma * next_q_values * ~dones)
|
| 162 |
+
|
| 163 |
+
loss = self.criterion(current_q_values.squeeze(), target_q_values)
|
| 164 |
+
|
| 165 |
+
self.optimizer.zero_grad()
|
| 166 |
+
loss.backward()
|
| 167 |
+
# Gradient clipping
|
| 168 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 169 |
+
self.optimizer.step()
|
| 170 |
+
|
| 171 |
+
# Update target network periodically
|
| 172 |
+
self.train_step += 1
|
| 173 |
+
if self.train_step % self.target_update_frequency == 0:
|
| 174 |
+
self.update_target_network()
|
| 175 |
+
|
| 176 |
+
if self.epsilon > self.epsilon_min:
|
| 177 |
+
self.epsilon *= self.epsilon_decay
|
| 178 |
+
|
| 179 |
+
# Enhanced PPO Agent
|
| 180 |
+
class PPOAgent:
|
| 181 |
+
def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2,
|
| 182 |
+
entropy_coef=0.01, value_coef=0.5, ppo_epochs=4, batch_size=64):
|
| 183 |
+
self.state_dim = state_dim
|
| 184 |
+
self.action_dim = action_dim
|
| 185 |
+
self.gamma = gamma
|
| 186 |
+
self.epsilon = epsilon
|
| 187 |
+
self.entropy_coef = entropy_coef
|
| 188 |
+
self.value_coef = value_coef
|
| 189 |
+
self.ppo_epochs = ppo_epochs
|
| 190 |
+
self.batch_size = batch_size
|
| 191 |
+
|
| 192 |
+
self.model = PPONetwork(state_dim, action_dim)
|
| 193 |
+
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
| 194 |
+
|
| 195 |
+
self.memory = []
|
| 196 |
+
|
| 197 |
+
def remember(self, state, action, reward, value, log_prob):
|
| 198 |
+
self.memory.append((state, action, reward, value, log_prob))
|
| 199 |
+
|
| 200 |
+
def act(self, state):
|
| 201 |
+
state = torch.FloatTensor(state).unsqueeze(0)
|
| 202 |
+
with torch.no_grad():
|
| 203 |
+
probs, value = self.model(state)
|
| 204 |
+
dist = Categorical(probs)
|
| 205 |
+
action = dist.sample()
|
| 206 |
+
return action.item(), dist.log_prob(action), value.squeeze()
|
| 207 |
+
|
| 208 |
+
def train(self):
|
| 209 |
+
if len(self.memory) < self.batch_size:
|
| 210 |
+
return
|
| 211 |
+
|
| 212 |
+
states, actions, rewards, values, log_probs = zip(*self.memory)
|
| 213 |
+
|
| 214 |
+
# Calculate returns and advantages
|
| 215 |
+
returns = []
|
| 216 |
+
R = 0
|
| 217 |
+
for r in reversed(rewards):
|
| 218 |
+
R = r + self.gamma * R
|
| 219 |
+
returns.insert(0, R)
|
| 220 |
+
|
| 221 |
+
returns = torch.FloatTensor(returns)
|
| 222 |
+
old_values = torch.FloatTensor(values)
|
| 223 |
+
advantages = returns - old_values
|
| 224 |
+
|
| 225 |
+
# Normalize advantages
|
| 226 |
+
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
| 227 |
+
|
| 228 |
+
# Convert to tensors
|
| 229 |
+
states_tensor = torch.FloatTensor(np.array(states))
|
| 230 |
+
actions_tensor = torch.LongTensor(actions)
|
| 231 |
+
old_log_probs = torch.FloatTensor(log_probs)
|
| 232 |
+
|
| 233 |
+
# PPO epochs
|
| 234 |
+
for _ in range(self.ppo_epochs):
|
| 235 |
+
# Get new probabilities
|
| 236 |
+
new_probs, new_values = self.model(states_tensor)
|
| 237 |
+
dist = Categorical(new_probs)
|
| 238 |
+
new_log_probs = dist.log_prob(actions_tensor)
|
| 239 |
+
entropy = dist.entropy().mean()
|
| 240 |
+
|
| 241 |
+
# PPO loss
|
| 242 |
+
ratio = (new_log_probs - old_log_probs).exp()
|
| 243 |
+
surr1 = ratio * advantages
|
| 244 |
+
surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
|
| 245 |
+
actor_loss = -torch.min(surr1, surr2).mean()
|
| 246 |
+
|
| 247 |
+
critic_loss = F.mse_loss(new_values.squeeze(), returns)
|
| 248 |
+
|
| 249 |
+
total_loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy
|
| 250 |
+
|
| 251 |
+
self.optimizer.zero_grad()
|
| 252 |
+
total_loss.backward()
|
| 253 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
|
| 254 |
+
self.optimizer.step()
|
| 255 |
+
|
| 256 |
+
self.memory = []
|
| 257 |
+
|
| 258 |
+
# Enhanced Training Thread with better state processing
|
| 259 |
+
class TrainingThread(QThread):
|
| 260 |
+
update_signal = pyqtSignal(dict)
|
| 261 |
+
frame_signal = pyqtSignal(np.ndarray)
|
| 262 |
+
|
| 263 |
+
def __init__(self, algorithm='dqn', env_name='ALE/Breakout-v5'):
|
| 264 |
+
super().__init__()
|
| 265 |
+
self.algorithm = algorithm
|
| 266 |
+
self.env_name = env_name
|
| 267 |
+
self.running = False
|
| 268 |
+
self.env = None
|
| 269 |
+
self.agent = None
|
| 270 |
+
|
| 271 |
+
def preprocess_state(self, state):
|
| 272 |
+
# Convert to CHW format, normalize, and convert to grayscale
|
| 273 |
+
if len(state.shape) == 3:
|
| 274 |
+
# Convert to grayscale and resize for faster processing
|
| 275 |
+
state = state.mean(axis=2, keepdims=True) # Convert to grayscale
|
| 276 |
+
state = state.transpose((2, 0, 1))
|
| 277 |
+
state = state / 255.0
|
| 278 |
+
return state
|
| 279 |
+
|
| 280 |
+
    def run(self):
        self.running = True
        try:
            self.env = create_env(self.env_name)
            state, info = self.env.reset()
            state = self.preprocess_state(state)

            n_actions = self.env.action_space.n
            state_dim = state.shape

            print(f"Training on: {self.env_name}")
            print(f"State shape: {state_dim}, Actions: {n_actions}")
            print(f"Algorithm: {self.algorithm}")

            if self.algorithm == 'dqn':
                self.agent = DuelingDQNAgent(state_dim, n_actions)
            else:
                self.agent = PPOAgent(state_dim, n_actions)

            episode = 0
            total_reward = 0
            steps = 0
            episode_rewards = []
            best_reward = -float('inf')

            while self.running:
                try:
                    if self.algorithm == 'dqn':
                        action = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, next_state, done)
                        self.agent.replay()
                    else:
                        action, log_prob, value = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, value, log_prob)
                        if done:
                            self.agent.train()

                    state = next_state
                    total_reward += reward
                    steps += 1

                    # Emit frame for display
                    try:
                        frame = self.env.render()
                        if frame is not None:
                            self.frame_signal.emit(frame)
                    except Exception as e:
                        # Create a placeholder frame if rendering fails
                        frame = np.zeros((210, 160, 3), dtype=np.uint8)
                        self.frame_signal.emit(frame)

                    # Emit training progress more frequently for better feedback
                    if steps % 5 == 0:
                        avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward
                        progress_data = {
                            'episode': episode,
                            'total_reward': total_reward,
                            'steps': steps,
                            'epsilon': self.agent.epsilon if self.algorithm == 'dqn' else 0.2,
                            'env_name': self.env_name,
                            'lives': info.get('lives', 0) if isinstance(info, dict) else 0,
                            'avg_reward': avg_reward,
                            'best_reward': best_reward
                        }
                        self.update_signal.emit(progress_data)

                    if terminated or truncated:
                        episode_rewards.append(total_reward)
                        if total_reward > best_reward:
                            best_reward = total_reward

                        avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward

                        # .epsilon only exists on the DQN agent; use the same fallback as progress_data
                        epsilon_val = self.agent.epsilon if self.algorithm == 'dqn' else 0.2
                        print(f"Episode {episode}: Reward: {total_reward:.1f}, "
                              f"Steps: {steps}, Avg (last 10): {avg_reward:.1f}, "
                              f"Best: {best_reward:.1f}, Epsilon: {epsilon_val:.3f}")

                        episode += 1
                        state, info = self.env.reset()
                        state = self.preprocess_state(state)
                        total_reward = 0
                        steps = 0

                except Exception as e:
                    print(f"Error in training loop: {e}")
                    import traceback
                    traceback.print_exc()
                    break

        except Exception as e:
            print(f"Error setting up environment: {e}")
            import traceback
            traceback.print_exc()

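    # stop() runs on the GUI thread: it clears the flag polled by run() and closes the
    # environment. Closing here can race with a step still executing in run(); closing
    # the env at the end of run() would be the more defensive choice.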
    def stop(self):
        self.running = False
        if self.env:
            self.env.close()

# Enhanced Main Application Window
class ALE_RLApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.training_thread = None
        self.init_ui()

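    # init_ui builds the window: a title, a control row (algorithm and environment
    # selectors plus start/stop buttons), and a two-pane content area with the game
    # display on the left and live training statistics plus a log on the right.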
    def init_ui(self):
        self.setWindowTitle('ALE Arcade RL Training - Enhanced')
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        # Title
        title = QLabel('Arcade Reinforcement Learning (ALE) - Enhanced Training')
        title.setFont(QFont('Arial', 16, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Control Panel
        control_layout = QHBoxLayout()

        self.algorithm_combo = QComboBox()
        self.algorithm_combo.addItems(['Dueling DQN', 'PPO'])

        self.env_combo = QComboBox()
        self.env_combo.addItems([
            'ALE/Breakout-v5',
            'ALE/Pong-v5',
            'ALE/SpaceInvaders-v5',
            'ALE/Assault-v5',
            'ALE/BeamRider-v5',
            'ALE/Enduro-v5',
            'ALE/Seaquest-v5',
            'ALE/Qbert-v5'
        ])

        self.start_btn = QPushButton('Start Training')
        self.start_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton('Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        control_layout.addWidget(QLabel('Algorithm:'))
        control_layout.addWidget(self.algorithm_combo)
        control_layout.addWidget(QLabel('Environment:'))
        control_layout.addWidget(self.env_combo)
        control_layout.addWidget(self.start_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addStretch()

        layout.addLayout(control_layout)

        # Content Area
        content_layout = QHBoxLayout()

        # Left side - Game Display
        left_frame = QFrame()
        left_frame.setFrameStyle(QFrame.Box)
        left_layout = QVBoxLayout(left_frame)

        self.game_display = QLabel()
        self.game_display.setMinimumSize(400, 300)
        self.game_display.setAlignment(Qt.AlignCenter)
        self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
        self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white; font-size: 14px;')

        left_layout.addWidget(QLabel('Game Display:'))
        left_layout.addWidget(self.game_display)

        # Right side - Training Info
        right_frame = QFrame()
        right_frame.setFrameStyle(QFrame.Box)
        right_layout = QVBoxLayout(right_frame)

        # Progress labels with better styling
        self.env_label = QLabel('Environment: Not started')
        self.episode_label = QLabel('Episode: 0')
        self.reward_label = QLabel('Total Reward: 0')
        self.avg_reward_label = QLabel('Avg Reward (last 10): 0')
        self.best_reward_label = QLabel('Best Reward: 0')
        self.steps_label = QLabel('Steps: 0')
        self.epsilon_label = QLabel('Epsilon: 0')
        self.lives_label = QLabel('Lives: 0')

        # Style the labels
        for label in [self.env_label, self.episode_label, self.reward_label,
                      self.avg_reward_label, self.best_reward_label, self.steps_label,
                      self.epsilon_label, self.lives_label]:
            label.setStyleSheet('font-weight: bold; font-size: 12px;')

        right_layout.addWidget(self.env_label)
        right_layout.addWidget(self.episode_label)
        right_layout.addWidget(self.reward_label)
        right_layout.addWidget(self.avg_reward_label)
        right_layout.addWidget(self.best_reward_label)
        right_layout.addWidget(self.steps_label)
        right_layout.addWidget(self.epsilon_label)
        right_layout.addWidget(self.lives_label)

        # Training log
        right_layout.addWidget(QLabel('Training Log:'))
        self.log_text = QTextEdit()
        self.log_text.setMaximumHeight(200)
        self.log_text.setStyleSheet('font-family: monospace; font-size: 10px;')
        right_layout.addWidget(self.log_text)

        content_layout.addWidget(left_frame)
        content_layout.addWidget(right_frame)
        layout.addLayout(content_layout)

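    # update_signal and frame_signal are emitted from the worker QThread; Qt delivers
    # cross-thread signals through queued connections, so the slots below always run on
    # the GUI thread and can safely touch the widgets.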
    def start_training(self):
        algorithm = 'dqn' if self.algorithm_combo.currentText() == 'Dueling DQN' else 'ppo'
        env_name = self.env_combo.currentText()

        self.training_thread = TrainingThread(algorithm, env_name)
        self.training_thread.update_signal.connect(self.update_training_info)
        self.training_thread.frame_signal.connect(self.update_game_display)
        self.training_thread.start()

        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

        self.log_text.append(f'Started {self.algorithm_combo.currentText()} training on {env_name}...')

    def stop_training(self):
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()

        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.log_text.append('Training stopped.')

    def update_training_info(self, data):
        self.env_label.setText(f'Environment: {data.get("env_name", "Unknown")}')
        self.episode_label.setText(f'Episode: {data["episode"]}')
        self.reward_label.setText(f'Total Reward: {data["total_reward"]:.1f}')
        self.avg_reward_label.setText(f'Avg Reward (last 10): {data.get("avg_reward", 0):.1f}')
        self.best_reward_label.setText(f'Best Reward: {data.get("best_reward", 0):.1f}')
        self.steps_label.setText(f'Steps: {data["steps"]}')
        self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
        self.lives_label.setText(f'Lives: {data.get("lives", 0)}')

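    # update_game_display expects an H x W x 3 uint8 RGB frame (what ALE's rgb_array
    # rendering produces). QImage wraps frame.data without copying; QPixmap.fromImage
    # then makes the copy that is actually displayed.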
    def update_game_display(self, frame):
        if frame is not None:
            try:
                h, w, ch = frame.shape
                bytes_per_line = ch * w
                q_img = QImage(frame.data, w, h, bytes_per_line, QImage.Format_RGB888)
                pixmap = QPixmap.fromImage(q_img)
                self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
            except Exception as e:
                print(f"Error updating display: {e}")

    def closeEvent(self, event):
        self.stop_training()
        event.accept()

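# The seeds set in main() cover torch / numpy / random on the agent side; the ALE
# environment keeps its own RNG unless a seed is also supplied via env.reset(seed=...)
# or inside create_env.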
def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    app = QApplication(sys.argv)
    window = ALE_RLApp()
    window.show()
    sys.exit(app.exec_())

if __name__ == '__main__':
    main()
ale_pyqt5/installed_packages_ale_py.txt
ADDED
@@ -0,0 +1,30 @@
ale-py==0.11.2
cloudpickle==3.1.2
contourpy==1.3.3
cycler==0.12.1
Farama-Notifications==0.0.4
filelock==3.20.0
fonttools==4.60.1
fsspec==2025.10.0
gym==0.26.2
gym-notices==0.1.0
gymnasium==1.2.2
Jinja2==3.1.6
kiwisolver==1.4.9
MarkupSafe==3.0.3
matplotlib==3.10.7
mpmath==1.3.0
networkx==3.5
numpy==2.2.6
opencv-python==4.12.0.88
packaging==25.0
pillow==12.0.0
pyglet==1.5.11
pyparsing==3.2.5
python-dateutil==2.9.0.post0
setuptools==80.9.0
six==1.17.0
sympy==1.14.0
torch==2.9.0
tqdm==4.67.1
typing_extensions==4.15.0