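"""Super Mario Bros Dueling DQN trainer with a PyQt5 GUI.

Trains a Dueling DQN agent on gym_super_mario_bros (COMPLEX_MOVEMENT actions)
from stacks of four 84x84 grayscale frames, using experience replay, a target
network, and epsilon-greedy exploration. A PyQt5 window shows the live game
frame, training statistics, and a log while training runs in a background
QThread.
"""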
import pickle
import random
import time
from collections import deque
import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QHBoxLayout, QPushButton, QLabel, QComboBox,
QTextEdit, QProgressBar, QTabWidget, QFrame, QGroupBox)
from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap, QFont
import sys
import cv2
# Import your wrappers (make sure this module exists)
try:
from wrappers import *
except ImportError:
# Create a proper wrapper if the module doesn't exist
class SimpleWrapper:
def __init__(self, env):
self.env = env
self.action_space = env.action_space
self.observation_space = env.observation_space
def reset(self):
return self.env.reset()
def step(self, action):
return self.env.step(action)
def render(self, mode='rgb_array'):
return self.env.render(mode)
def close(self):
if hasattr(self.env, 'close'):
self.env.close()
def wrap_mario(env):
return SimpleWrapper(env)
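# Note: the fallback SimpleWrapper above does no frame preprocessing; grayscale
# conversion, resizing, and stacking are handled by FrameStacker below.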
class FrameStacker:
"""Handles frame stacking and preprocessing"""
def __init__(self, frame_size=(84, 84), stack_size=4):
self.frame_size = frame_size
self.stack_size = stack_size
self.frames = deque(maxlen=stack_size)
def preprocess_frame(self, frame):
"""Convert frame to grayscale and resize"""
# Convert to grayscale
gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
# Resize to 84x84
resized = cv2.resize(gray, self.frame_size, interpolation=cv2.INTER_AREA)
# Normalize to [0, 1]
normalized = resized.astype(np.float32) / 255.0
return normalized
def reset(self, frame):
"""Reset frame stack with initial frame"""
self.frames.clear()
processed_frame = self.preprocess_frame(frame)
for _ in range(self.stack_size):
self.frames.append(processed_frame)
return self.get_stacked_frames()
def append(self, frame):
"""Add new frame to stack"""
processed_frame = self.preprocess_frame(frame)
self.frames.append(processed_frame)
return self.get_stacked_frames()
def get_stacked_frames(self):
"""Get stacked frames as numpy array"""
stacked = np.array(self.frames)
return np.ascontiguousarray(stacked)
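# FrameStacker usage (sketch): reset() fills the deque with copies of the first
# frame; append() rolls in each new frame. Both return a (stack_size, 84, 84)
# float32 array ready to be batched as CNN input.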
class replay_memory(object):
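    """Fixed-capacity experience replay buffer (a deque of transitions).

    Each transition pushed by the training loop is a tuple
    (state, reward, action, next_state, not_done), where not_done is 1 while
    the episode continues and 0 at the terminal step.
    """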
def __init__(self, N):
self.memory = deque(maxlen=N)
def push(self, transition):
self.memory.append(transition)
def sample(self, n):
return random.sample(self.memory, n)
def __len__(self):
return len(self.memory)
class DuelingDQNModel(nn.Module):
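    """Dueling DQN: a shared CNN trunk feeding separate value and advantage heads.

    Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'); subtracting the mean advantage
    keeps the value/advantage decomposition identifiable.
    """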
def __init__(self, n_frame, n_action, device):
super(DuelingDQNModel, self).__init__()
# CNN layers for feature extraction
self.conv_layers = nn.Sequential(
nn.Conv2d(n_frame, 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU()
)
# Calculate conv output size
self.conv_out_size = self._get_conv_out((n_frame, 84, 84))
# Advantage stream
self.advantage_stream = nn.Sequential(
nn.Linear(self.conv_out_size, 512),
nn.ReLU(),
nn.Linear(512, n_action)
)
# Value stream
self.value_stream = nn.Sequential(
nn.Linear(self.conv_out_size, 512),
nn.ReLU(),
nn.Linear(512, 1)
)
self.device = device
self.apply(self.init_weights)
def _get_conv_out(self, shape):
with torch.no_grad():
x = torch.zeros(1, *shape)
x = self.conv_layers(x)
return int(np.prod(x.size()))
def init_weights(self, m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
m.bias.data.fill_(0.01)
def forward(self, x):
if not isinstance(x, torch.Tensor):
x = torch.FloatTensor(x).to(self.device)
# Forward through conv layers
x = self.conv_layers(x)
x = x.view(x.size(0), -1)
# Forward through advantage and value streams
advantage = self.advantage_stream(x)
value = self.value_stream(x)
# Combine value and advantage
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
return q_values
def train(q, q_target, memory, batch_size, gamma, optimizer, device):
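    """Run one DQN optimization step on a minibatch sampled from replay memory.

    Builds the target r + gamma * max_a' Q_target(s', a') * not_done, fits the
    online network with smooth L1 (Huber) loss, clips gradient norms at 10, and
    returns the scalar loss (0.0 if the buffer holds fewer than batch_size
    transitions).
    """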
if len(memory) < batch_size:
return 0.0
transitions = memory.sample(batch_size)
    # The last element of each stored transition is the continuation mask (1 = not terminal)
    s, r, a, s_prime, not_done = list(map(list, zip(*transitions)))
# Ensure positive strides for all arrays
s = np.array([np.ascontiguousarray(arr) for arr in s])
s_prime = np.array([np.ascontiguousarray(arr) for arr in s_prime])
# Move computations to device
s_tensor = torch.FloatTensor(s).to(device)
s_prime_tensor = torch.FloatTensor(s_prime).to(device)
    # Get next-state values from the target network (standard DQN max-Q target)
    with torch.no_grad():
        next_q_value = q_target(s_prime_tensor).max(1, keepdim=True)[0]
    # Calculate target Q values; not_done zeroes the bootstrap term on terminal states
    r = torch.FloatTensor(r).unsqueeze(1).to(device)
    not_done = torch.FloatTensor(not_done).unsqueeze(1).to(device)
    target_q_values = r + gamma * next_q_value * not_done
# Get current Q values
a_tensor = torch.LongTensor(a).unsqueeze(1).to(device)
current_q_values = q(s_tensor).gather(1, a_tensor)
# Calculate loss
loss = F.smooth_l1_loss(current_q_values, target_q_values)
# Optimize
optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0)
optimizer.step()
return loss.item()
def copy_weights(q, q_target):
q_dict = q.state_dict()
q_target.load_state_dict(q_dict)
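# copy_weights performs a hard target-network update; the training thread calls
# it once at setup and then every `update_interval` gradient steps.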
class MarioTrainingThread(QThread):
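    """Background QThread that runs the training loop without blocking the GUI.

    Emits update_signal with a dict of statistics / log messages and
    frame_signal with the current rendered game frame.
    """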
update_signal = pyqtSignal(dict)
frame_signal = pyqtSignal(np.ndarray)
def __init__(self, device="cpu"):
super().__init__()
self.device = device
self.running = False
self.env = None
self.q = None
self.q_target = None
self.optimizer = None
self.frame_stacker = None
# Training parameters
self.gamma = 0.99
self.batch_size = 32
self.memory_size = 10000
self.eps = 1.0 # Start with full exploration
self.eps_min = 0.01
self.eps_decay = 0.995
self.update_interval = 1000
self.save_interval = 100
self.print_interval = 10
self.memory = None
self.t = 0
self.k = 0
self.total_score = 0.0
self.loss_accumulator = 0.0
self.best_score = -float('inf')
self.last_x_pos = 0
def setup_training(self):
n_frame = 4 # Number of stacked frames
try:
self.env = gym_super_mario_bros.make("SuperMarioBros-v3")
self.env = JoypadSpace(self.env, COMPLEX_MOVEMENT)
self.env = wrap_mario(self.env)
# Initialize frame stacker
self.frame_stacker = FrameStacker(frame_size=(84, 84), stack_size=n_frame)
self.q = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
self.q_target = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
copy_weights(self.q, self.q_target)
# Set target network to eval mode
self.q_target.eval()
# Optimizer
self.optimizer = optim.Adam(self.q.parameters(), lr=0.0001, weight_decay=1e-5)
self.memory = replay_memory(self.memory_size)
self.log_message(f"โœ… Training setup complete - Actions: {self.env.action_space.n}, Device: {self.device}")
except Exception as e:
self.log_message(f"โŒ Error setting up training: {e}")
import traceback
traceback.print_exc()
self.running = False
def run(self):
self.running = True
self.setup_training()
if not self.running:
return
start_time = time.perf_counter()
score_lst = []
try:
for k in range(1000000):
if not self.running:
break
# Reset environment and frame stacker
frame = self.env.reset()
s = self.frame_stacker.reset(frame)
done = False
episode_loss = 0.0
episode_steps = 0
episode_score = 0.0
self.last_x_pos = 0
while not done and self.running:
# Ensure state has positive strides before processing
s_processed = np.ascontiguousarray(s)
# Epsilon-greedy action selection
if np.random.random() <= self.eps:
a = self.env.action_space.sample()
else:
with torch.no_grad():
# Add batch dimension and create tensor
state_tensor = torch.FloatTensor(s_processed).unsqueeze(0).to(self.device)
q_values = self.q(state_tensor)
                            # Greedy action: index of the max Q-value (device-agnostic)
                            a = int(q_values.argmax(dim=1).item())
# Take action
frame, r, done, info = self.env.step(a)
# Update frame stack
s_prime = self.frame_stacker.append(frame)
episode_score += r
# Enhanced reward shaping
reward = r # Start with original reward
# Bonus for x_pos progress
if 'x_pos' in info:
x_pos = info['x_pos']
x_progress = x_pos - self.last_x_pos
if x_progress > 0:
reward += 0.1 * x_progress
self.last_x_pos = x_pos
# Large bonus for completing the level
if done and info.get('flag_get', False):
reward += 100.0
self.log_message(f"๐ŸŽ‰ LEVEL COMPLETED at episode {k}! ๐ŸŽ‰")
                    # Store transition with contiguous arrays; the last element is the
                    # continuation mask (1 while the episode continues, 0 when done)
s_contiguous = np.ascontiguousarray(s)
s_prime_contiguous = np.ascontiguousarray(s_prime)
self.memory.push((s_contiguous, float(reward), int(a), s_prime_contiguous, int(1 - done)))
s = s_prime
stage = info.get('stage', 1)
world = info.get('world', 1)
# Emit frame for display
try:
                        display_frame = self.env.render(mode='rgb_array')
if display_frame is not None:
# Ensure frame has positive strides
frame_contiguous = np.ascontiguousarray(display_frame)
self.frame_signal.emit(frame_contiguous)
                    except Exception:
                        # Fall back to a blank placeholder frame if rendering fails
                        placeholder = np.zeros((240, 256, 3), dtype=np.uint8)
                        self.frame_signal.emit(placeholder)
# Train only if we have enough samples
if len(self.memory) > self.batch_size:
loss_val = train(self.q, self.q_target, self.memory, self.batch_size,
self.gamma, self.optimizer, self.device)
if loss_val > 0:
self.loss_accumulator += loss_val
episode_loss += loss_val
self.t += 1
# Update target network
if self.t % self.update_interval == 0:
copy_weights(self.q, self.q_target)
episode_steps += 1
# Emit training progress every 10 steps
if episode_steps % 10 == 0:
progress_data = {
'episode': k,
'total_reward': episode_score,
'steps': episode_steps,
'epsilon': self.eps,
'world': world,
'stage': stage,
'loss': episode_loss / (episode_steps + 1e-8),
'memory_size': len(self.memory),
'x_pos': info.get('x_pos', 0),
'score': info.get('score', 0),
'coins': info.get('coins', 0),
'time': info.get('time', 400),
'flag_get': info.get('flag_get', False)
}
self.update_signal.emit(progress_data)
# Epsilon decay after each episode
if self.eps > self.eps_min:
self.eps *= self.eps_decay
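                # Note: with eps_decay = 0.995 applied once per episode, epsilon
                # reaches eps_min (0.01) after roughly ln(0.01)/ln(0.995) ~ 920 episodes.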
# Update total score
self.total_score += episode_score
# Save best model
if episode_score > self.best_score and k > 0:
self.best_score = episode_score
torch.save(self.q.state_dict(), "enhanced_mario_q_best.pth")
torch.save(self.q_target.state_dict(), "enhanced_mario_q_target_best.pth")
self.log_message(f"๐Ÿ’พ New best model saved! Score: {self.best_score:.2f}")
# Save models periodically
if k % self.save_interval == 0 and k > 0:
torch.save(self.q.state_dict(), "enhanced_mario_q.pth")
torch.save(self.q_target.state_dict(), "enhanced_mario_q_target.pth")
self.log_message(f"๐Ÿ’พ Models saved at episode {k}")
# Print progress
if k % self.print_interval == 0 and k > 0:
time_spent = time.perf_counter() - start_time
start_time = time.perf_counter()
avg_loss = self.loss_accumulator / (self.print_interval * max(episode_steps, 1))
avg_score = self.total_score / self.print_interval
log_msg = (
f"{self.device} | Ep: {k} | Score: {avg_score:.2f} | Loss: {avg_loss:.4f} | "
f"Stage: {world}-{stage} | Eps: {self.eps:.3f} | Time: {time_spent:.2f}s | "
f"Mem: {len(self.memory)} | Steps: {episode_steps}"
)
self.log_message(log_msg)
score_lst.append(avg_score)
self.total_score = 0.0
self.loss_accumulator = 0.0
                    try:
                        with open("score.p", "wb") as f:
                            pickle.dump(score_lst, f)
                    except Exception as e:
                        self.log_message(f"⚠️ Could not save scores: {e}")
self.k = k
except Exception as e:
self.log_message(f"โŒ Training error: {e}")
import traceback
traceback.print_exc()
def log_message(self, message):
progress_data = {
'log_message': message
}
self.update_signal.emit(progress_data)
def stop(self):
self.running = False
if self.env:
try:
self.env.close()
            except Exception:
                pass
class MarioRLApp(QMainWindow):
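    """PyQt5 main window: device selection, start/stop controls, a live game
    view, training statistics, and a scrolling training log."""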
def __init__(self):
super().__init__()
self.training_thread = None
self.init_ui()
def init_ui(self):
        self.setWindowTitle('🎮 Super Mario Bros - Dueling DQN Training')
self.setGeometry(100, 100, 1200, 800)
central_widget = QWidget()
self.setCentralWidget(central_widget)
layout = QVBoxLayout(central_widget)
# Title
        title = QLabel('🎮 Super Mario Bros - Enhanced Dueling DQN')
title.setFont(QFont('Arial', 16, QFont.Bold))
title.setAlignment(Qt.AlignCenter)
layout.addWidget(title)
# Control Panel
control_layout = QHBoxLayout()
self.device_combo = QComboBox()
self.device_combo.addItems(['cpu', 'cuda', 'mps'])
self.start_btn = QPushButton('Start Training')
self.start_btn.clicked.connect(self.start_training)
self.stop_btn = QPushButton('Stop Training')
self.stop_btn.clicked.connect(self.stop_training)
self.stop_btn.setEnabled(False)
self.load_btn = QPushButton('Load Model')
self.load_btn.clicked.connect(self.load_model)
control_layout.addWidget(QLabel('Device:'))
control_layout.addWidget(self.device_combo)
control_layout.addWidget(self.start_btn)
control_layout.addWidget(self.stop_btn)
control_layout.addWidget(self.load_btn)
control_layout.addStretch()
layout.addLayout(control_layout)
# Content Area
content_layout = QHBoxLayout()
# Left side - Game Display
left_frame = QFrame()
left_frame.setFrameStyle(QFrame.Box)
left_layout = QVBoxLayout(left_frame)
self.game_display = QLabel()
self.game_display.setMinimumSize(400, 300)
self.game_display.setAlignment(Qt.AlignCenter)
self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')
left_layout.addWidget(QLabel('Mario Game Display:'))
left_layout.addWidget(self.game_display)
# Right side - Training Info
right_frame = QFrame()
right_frame.setFrameStyle(QFrame.Box)
right_layout = QVBoxLayout(right_frame)
# Training stats
stats_group = QGroupBox("Training Statistics")
stats_layout = QVBoxLayout(stats_group)
self.episode_label = QLabel('Episode: 0')
self.world_label = QLabel('World: 1-1')
self.score_label = QLabel('Score: 0')
self.reward_label = QLabel('Episode Reward: 0')
self.steps_label = QLabel('Steps: 0')
self.epsilon_label = QLabel('Epsilon: 1.000')
self.loss_label = QLabel('Loss: 0.0000')
self.memory_label = QLabel('Memory: 0')
self.xpos_label = QLabel('X Position: 0')
self.coins_label = QLabel('Coins: 0')
self.time_label = QLabel('Time: 400')
self.flag_label = QLabel('Flag: No')
stats_layout.addWidget(self.episode_label)
stats_layout.addWidget(self.world_label)
stats_layout.addWidget(self.score_label)
stats_layout.addWidget(self.reward_label)
stats_layout.addWidget(self.steps_label)
stats_layout.addWidget(self.epsilon_label)
stats_layout.addWidget(self.loss_label)
stats_layout.addWidget(self.memory_label)
stats_layout.addWidget(self.xpos_label)
stats_layout.addWidget(self.coins_label)
stats_layout.addWidget(self.time_label)
stats_layout.addWidget(self.flag_label)
right_layout.addWidget(stats_group)
# Training log
right_layout.addWidget(QLabel('Training Log:'))
self.log_text = QTextEdit()
self.log_text.setMaximumHeight(300)
right_layout.addWidget(self.log_text)
content_layout.addWidget(left_frame)
content_layout.addWidget(right_frame)
layout.addLayout(content_layout)
def start_training(self):
device = self.device_combo.currentText()
# Check device availability
if device == "cuda" and not torch.cuda.is_available():
self.log_text.append("โŒ CUDA not available, using CPU instead")
device = "cpu"
elif device == "mps" and not torch.backends.mps.is_available():
self.log_text.append("โŒ MPS not available, using CPU instead")
device = "cpu"
self.training_thread = MarioTrainingThread(device)
self.training_thread.update_signal.connect(self.update_training_info)
self.training_thread.frame_signal.connect(self.update_game_display)
self.training_thread.start()
self.start_btn.setEnabled(False)
self.stop_btn.setEnabled(True)
        self.log_text.append(f'🚀 Started Dueling DQN training on {device}...')
def stop_training(self):
if self.training_thread:
self.training_thread.stop()
self.training_thread.wait()
self.start_btn.setEnabled(True)
self.stop_btn.setEnabled(False)
        self.log_text.append('⏹️ Training stopped.')
def load_model(self):
# Placeholder for model loading functionality
        self.log_text.append('📁 Load model functionality not implemented yet')
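        # A minimal loading sketch (assumes a checkpoint produced by this script,
        # e.g. "enhanced_mario_q.pth", and a created training thread to receive it):
        #   state_dict = torch.load("enhanced_mario_q.pth", map_location="cpu")
        #   self.training_thread.q.load_state_dict(state_dict)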
def update_training_info(self, data):
if 'episode' in data:
self.episode_label.setText(f'Episode: {data["episode"]}')
if 'world' in data and 'stage' in data:
self.world_label.setText(f'World: {data["world"]}-{data["stage"]}')
if 'score' in data:
self.score_label.setText(f'Score: {data["score"]}')
if 'total_reward' in data:
self.reward_label.setText(f'Episode Reward: {data["total_reward"]:.2f}')
if 'steps' in data:
self.steps_label.setText(f'Steps: {data["steps"]}')
if 'epsilon' in data:
self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
if 'loss' in data:
self.loss_label.setText(f'Loss: {data["loss"]:.4f}')
if 'memory_size' in data:
self.memory_label.setText(f'Memory: {data["memory_size"]}')
if 'x_pos' in data:
self.xpos_label.setText(f'X Position: {data["x_pos"]}')
if 'coins' in data:
self.coins_label.setText(f'Coins: {data["coins"]}')
if 'time' in data:
self.time_label.setText(f'Time: {data["time"]}')
if 'flag_get' in data:
flag_text = "Yes" if data["flag_get"] else "No"
self.flag_label.setText(f'Flag: {flag_text}')
if 'log_message' in data:
self.log_text.append(data['log_message'])
# Auto-scroll to bottom
self.log_text.verticalScrollBar().setValue(
self.log_text.verticalScrollBar().maximum()
)
def update_game_display(self, frame):
if frame is not None:
try:
h, w, ch = frame.shape
bytes_per_line = ch * w
# Ensure contiguous array
frame_contiguous = np.ascontiguousarray(frame)
q_img = QImage(frame_contiguous.data, w, h, bytes_per_line, QImage.Format_RGB888)
pixmap = QPixmap.fromImage(q_img)
self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
except Exception as e:
print(f"Error updating display: {e}")
def closeEvent(self, event):
self.stop_training()
event.accept()
def main():
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
app = QApplication(sys.argv)
window = MarioRLApp()
window.show()
sys.exit(app.exec_())
if __name__ == '__main__':
main()