import pickle
import random
import sys
import time
from collections import deque

import cv2
import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QFont, QImage, QPixmap
from PyQt5.QtWidgets import (QApplication, QComboBox, QFrame, QGroupBox,
                             QHBoxLayout, QLabel, QMainWindow, QPushButton,
                             QTextEdit, QVBoxLayout, QWidget)

try:
    from wrappers import *
except ImportError:
    class SimpleWrapper:
        def __init__(self, env):
            self.env = env
            self.action_space = env.action_space
            self.observation_space = env.observation_space

        def reset(self):
            return self.env.reset()

        def step(self, action):
            return self.env.step(action)

        def render(self, mode='rgb_array'):
            return self.env.render(mode)

        def close(self):
            if hasattr(self.env, 'close'):
                self.env.close()

    def wrap_mario(env):
        return SimpleWrapper(env)
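

# Preprocessing follows the standard DQN recipe (Mnih et al., 2015): convert
# each RGB frame to grayscale, resize to 84x84, scale to [0, 1], and stack the
# most recent `stack_size` frames so the network can infer motion from
# consecutive frames.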
class FrameStacker:
    """Handles frame stacking and preprocessing."""

    def __init__(self, frame_size=(84, 84), stack_size=4):
        self.frame_size = frame_size
        self.stack_size = stack_size
        self.frames = deque(maxlen=stack_size)

    def preprocess_frame(self, frame):
        """Convert a frame to grayscale, resize it, and scale it to [0, 1]."""
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, self.frame_size, interpolation=cv2.INTER_AREA)
        normalized = resized.astype(np.float32) / 255.0
        return normalized

    def reset(self, frame):
        """Reset the stack, filling every slot with the initial frame."""
        self.frames.clear()
        processed_frame = self.preprocess_frame(frame)
        for _ in range(self.stack_size):
            self.frames.append(processed_frame)
        return self.get_stacked_frames()

    def append(self, frame):
        """Add a new frame to the stack and return the updated stack."""
        processed_frame = self.preprocess_frame(frame)
        self.frames.append(processed_frame)
        return self.get_stacked_frames()

    def get_stacked_frames(self):
        """Return the stacked frames as a contiguous (stack_size, H, W) array."""
        stacked = np.array(self.frames)
        return np.ascontiguousarray(stacked)
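

# Uniform experience replay: transitions go into a fixed-size ring buffer and
# minibatches are sampled uniformly at random, which breaks the temporal
# correlation between consecutive frames and stabilizes Q-learning.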
class replay_memory(object):
    def __init__(self, N):
        self.memory = deque(maxlen=N)

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, n):
        return random.sample(self.memory, n)

    def __len__(self):
        return len(self.memory)
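

# Dueling architecture (Wang et al., 2016): the network splits into a scalar
# state-value stream V(s) and a per-action advantage stream A(s, a), combined
# as Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'). Subtracting the mean
# advantage makes the decomposition identifiable: shifting a constant from V
# to A would otherwise leave Q unchanged.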
class DuelingDQNModel(nn.Module):
    def __init__(self, n_frame, n_action, device):
        super(DuelingDQNModel, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(n_frame, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.conv_out_size = self._get_conv_out((n_frame, 84, 84))

        self.advantage_stream = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_action)
        )

        self.value_stream = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

        self.device = device
        self.apply(self.init_weights)

    def _get_conv_out(self, shape):
        with torch.no_grad():
            x = torch.zeros(1, *shape)
            x = self.conv_layers(x)
        return int(np.prod(x.size()))

    def init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    def forward(self, x):
        if not isinstance(x, torch.Tensor):
            x = torch.FloatTensor(x).to(self.device)

        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)

        advantage = self.advantage_stream(x)
        value = self.value_stream(x)

        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
        return q_values
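

# One DQN optimization step: sample a minibatch, build the bootstrapped TD
# target y = r + gamma * max_a' Q_target(s', a') for non-terminal transitions
# (y = r at episode ends), and regress Q(s, a) toward y with the Huber
# (smooth L1) loss, which is less sensitive to outlier TD errors than MSE.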
def train(q, q_target, memory, batch_size, gamma, optimizer, device):
    if len(memory) < batch_size:
        return 0.0

    transitions = memory.sample(batch_size)
    # Transitions are stored as (s, r, a, s', not_done); not_done is 1 while
    # the episode continues and 0 at termination.
    s, r, a, s_prime, not_done = list(map(list, zip(*transitions)))

    s = np.array([np.ascontiguousarray(arr) for arr in s])
    s_prime = np.array([np.ascontiguousarray(arr) for arr in s_prime])

    s_tensor = torch.FloatTensor(s).to(device)
    s_prime_tensor = torch.FloatTensor(s_prime).to(device)

    # Bootstrapped value of the successor state from the frozen target network.
    with torch.no_grad():
        next_q_value = q_target(s_prime_tensor).max(1, keepdim=True)[0]

    r = torch.FloatTensor(r).unsqueeze(1).to(device)
    not_done = torch.FloatTensor(not_done).unsqueeze(1).to(device)
    # Multiplying by the not-done mask zeroes the bootstrap term for terminal
    # transitions, so their target is just the immediate reward.
    target_q_values = r + gamma * next_q_value * not_done

    a_tensor = torch.LongTensor(a).unsqueeze(1).to(device)
    current_q_values = q(s_tensor).gather(1, a_tensor)

    loss = F.smooth_l1_loss(current_q_values, target_q_values)

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to guard against destabilizing updates.
    torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0)
    optimizer.step()

    return loss.item()
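

# Hard target-network update: the target weights are replaced wholesale every
# `update_interval` gradient steps, keeping the bootstrap targets fixed
# between syncs.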
def copy_weights(q, q_target):
    q_dict = q.state_dict()
    q_target.load_state_dict(q_dict)
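

# Training runs in a QThread so the Qt event loop stays responsive; the worker
# communicates with the GUI exclusively through the two signals below.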
class MarioTrainingThread(QThread):
    update_signal = pyqtSignal(dict)
    frame_signal = pyqtSignal(np.ndarray)

    def __init__(self, device="cpu"):
        super().__init__()
        self.device = device
        self.running = False
        self.env = None
        self.q = None
        self.q_target = None
        self.optimizer = None
        self.frame_stacker = None

        # Hyperparameters
        self.gamma = 0.99
        self.batch_size = 32
        self.memory_size = 10000
        self.eps = 1.0
        self.eps_min = 0.01
        self.eps_decay = 0.995
        self.update_interval = 1000
        self.save_interval = 100
        self.print_interval = 10

        # Training state
        self.memory = None
        self.t = 0
        self.k = 0
        self.total_score = 0.0
        self.loss_accumulator = 0.0
        self.best_score = -float('inf')
        self.last_x_pos = 0
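
    # With eps_decay = 0.995 applied once per episode, epsilon falls from 1.0
    # to eps_min = 0.01 after about ln(0.01) / ln(0.995) ~= 919 episodes, after
    # which the policy stays 1% random.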
    def setup_training(self):
        n_frame = 4
        try:
            self.env = gym_super_mario_bros.make("SuperMarioBros-v3")
            self.env = JoypadSpace(self.env, COMPLEX_MOVEMENT)
            self.env = wrap_mario(self.env)

            self.frame_stacker = FrameStacker(frame_size=(84, 84), stack_size=n_frame)

            self.q = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
            self.q_target = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)

            copy_weights(self.q, self.q_target)

            # The target network is only ever evaluated, never trained directly.
            self.q_target.eval()

            self.optimizer = optim.Adam(self.q.parameters(), lr=0.0001, weight_decay=1e-5)

            self.memory = replay_memory(self.memory_size)

            self.log_message(f"✅ Training setup complete - Actions: {self.env.action_space.n}, Device: {self.device}")

        except Exception as e:
            self.log_message(f"❌ Error setting up training: {e}")
            import traceback
            traceback.print_exc()
            self.running = False
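
    # Main training loop: per episode, act epsilon-greedily, shape the reward,
    # store the transition, take one gradient step per environment step, and
    # periodically checkpoint and sync the target network.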
    def run(self):
        self.running = True
        self.setup_training()

        if not self.running:
            return

        start_time = time.perf_counter()
        score_lst = []

        try:
            for k in range(1000000):
                if not self.running:
                    break

                frame = self.env.reset()
                s = self.frame_stacker.reset(frame)
                done = False
                episode_loss = 0.0
                episode_steps = 0
                episode_score = 0.0
                self.last_x_pos = 0

                while not done and self.running:
                    s_processed = np.ascontiguousarray(s)

                    # Epsilon-greedy action selection.
                    if np.random.random() <= self.eps:
                        a = self.env.action_space.sample()
                    else:
                        with torch.no_grad():
                            state_tensor = torch.FloatTensor(s_processed).unsqueeze(0).to(self.device)
                            q_values = self.q(state_tensor)
                            # .item() works for cpu, cuda, and mps tensors
                            # alike, so no per-device branch is needed.
                            a = int(q_values.argmax(dim=1).item())
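
                    # Advance the environment one step. The raw reward r is
                    # shaped below with a small bonus proportional to rightward
                    # progress and a large bonus for reaching the flag, which
                    # densifies the learning signal.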
                    frame, r, done, info = self.env.step(a)
                    s_prime = self.frame_stacker.append(frame)

                    episode_score += r
                    reward = r

                    if 'x_pos' in info:
                        x_pos = info['x_pos']
                        x_progress = x_pos - self.last_x_pos
                        if x_progress > 0:
                            reward += 0.1 * x_progress
                        self.last_x_pos = x_pos

                    if done and info.get('flag_get', False):
                        reward += 100.0
                        self.log_message(f"🏁 LEVEL COMPLETED at episode {k}! 🏁")

                    # Store the transition with a not-done mask (1 while the
                    # episode continues, 0 at termination).
                    s_contiguous = np.ascontiguousarray(s)
                    s_prime_contiguous = np.ascontiguousarray(s_prime)
                    self.memory.push((s_contiguous, float(reward), int(a), s_prime_contiguous, int(1 - done)))

                    s = s_prime
                    stage = info.get('stage', 1)
                    world = info.get('world', 1)

                    # Mirror the current frame to the GUI.
                    try:
                        display_frame = self.env.render()
                        if display_frame is not None:
                            frame_contiguous = np.ascontiguousarray(display_frame)
                            self.frame_signal.emit(frame_contiguous)
                    except Exception:
                        # Fall back to a blank NES-sized frame; a separate name
                        # keeps the observation in `frame` from being clobbered.
                        blank_frame = np.zeros((240, 256, 3), dtype=np.uint8)
                        self.frame_signal.emit(blank_frame)
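
                    # One gradient step per environment step once the replay
                    # buffer holds more than a full batch; the target network
                    # is hard-synced every update_interval gradient steps.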
                    if len(self.memory) > self.batch_size:
                        loss_val = train(self.q, self.q_target, self.memory, self.batch_size,
                                         self.gamma, self.optimizer, self.device)
                        if loss_val > 0:
                            self.loss_accumulator += loss_val
                            episode_loss += loss_val
                        self.t += 1

                        if self.t % self.update_interval == 0:
                            copy_weights(self.q, self.q_target)

                    episode_steps += 1

                    # Emit a progress snapshot to the GUI every 10 steps.
                    if episode_steps % 10 == 0:
                        progress_data = {
                            'episode': k,
                            'total_reward': episode_score,
                            'steps': episode_steps,
                            'epsilon': self.eps,
                            'world': world,
                            'stage': stage,
                            'loss': episode_loss / (episode_steps + 1e-8),
                            'memory_size': len(self.memory),
                            'x_pos': info.get('x_pos', 0),
                            'score': info.get('score', 0),
                            'coins': info.get('coins', 0),
                            'time': info.get('time', 400),
                            'flag_get': info.get('flag_get', False)
                        }
                        self.update_signal.emit(progress_data)

                # Decay exploration once per episode.
                if self.eps > self.eps_min:
                    self.eps *= self.eps_decay

                self.total_score += episode_score

                # Checkpoint whenever an episode sets a new best score.
                if episode_score > self.best_score and k > 0:
                    self.best_score = episode_score
                    torch.save(self.q.state_dict(), "enhanced_mario_q_best.pth")
                    torch.save(self.q_target.state_dict(), "enhanced_mario_q_target_best.pth")
                    self.log_message(f"💾 New best model saved! Score: {self.best_score:.2f}")

                # Periodic checkpoint regardless of score.
                if k % self.save_interval == 0 and k > 0:
                    torch.save(self.q.state_dict(), "enhanced_mario_q.pth")
                    torch.save(self.q_target.state_dict(), "enhanced_mario_q_target.pth")
                    self.log_message(f"💾 Models saved at episode {k}")

                # Aggregated logging every print_interval episodes.
                if k % self.print_interval == 0 and k > 0:
                    time_spent = time.perf_counter() - start_time
                    start_time = time.perf_counter()

                    avg_loss = self.loss_accumulator / (self.print_interval * max(episode_steps, 1))
                    avg_score = self.total_score / self.print_interval

                    log_msg = (
                        f"{self.device} | Ep: {k} | Score: {avg_score:.2f} | Loss: {avg_loss:.4f} | "
                        f"Stage: {world}-{stage} | Eps: {self.eps:.3f} | Time: {time_spent:.2f}s | "
                        f"Mem: {len(self.memory)} | Steps: {episode_steps}"
                    )
                    self.log_message(log_msg)

                    score_lst.append(avg_score)
                    self.total_score = 0.0
                    self.loss_accumulator = 0.0

                    try:
                        with open("score.p", "wb") as f:
                            pickle.dump(score_lst, f)
                    except Exception as e:
                        self.log_message(f"⚠️ Could not save scores: {e}")

                self.k = k

        except Exception as e:
            self.log_message(f"❌ Training error: {e}")
            import traceback
            traceback.print_exc()

    def log_message(self, message):
        self.update_signal.emit({'log_message': message})

    def stop(self):
        self.running = False
        if self.env:
            try:
                self.env.close()
            except Exception:
                pass
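

# All widgets are owned by the GUI thread. Training data arrives through the
# worker thread's signals; since those connections cross threads, Qt queues
# the emissions as events on the main thread, so the slots below can touch
# widgets safely.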
class MarioRLApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.training_thread = None
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle('🎮 Super Mario Bros - Dueling DQN Training')
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        # Title
        title = QLabel('🎮 Super Mario Bros - Enhanced Dueling DQN')
        title.setFont(QFont('Arial', 16, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Control row: device picker plus start/stop/load buttons.
        control_layout = QHBoxLayout()

        self.device_combo = QComboBox()
        self.device_combo.addItems(['cpu', 'cuda', 'mps'])

        self.start_btn = QPushButton('Start Training')
        self.start_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton('Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        self.load_btn = QPushButton('Load Model')
        self.load_btn.clicked.connect(self.load_model)

        control_layout.addWidget(QLabel('Device:'))
        control_layout.addWidget(self.device_combo)
        control_layout.addWidget(self.start_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addWidget(self.load_btn)
        control_layout.addStretch()

        layout.addLayout(control_layout)

        # Main content: game view on the left, stats and log on the right.
        content_layout = QHBoxLayout()

        left_frame = QFrame()
        left_frame.setFrameStyle(QFrame.Box)
        left_layout = QVBoxLayout(left_frame)

        self.game_display = QLabel()
        self.game_display.setMinimumSize(400, 300)
        self.game_display.setAlignment(Qt.AlignCenter)
        self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
        self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')

        left_layout.addWidget(QLabel('Mario Game Display:'))
        left_layout.addWidget(self.game_display)

        right_frame = QFrame()
        right_frame.setFrameStyle(QFrame.Box)
        right_layout = QVBoxLayout(right_frame)

        stats_group = QGroupBox("Training Statistics")
        stats_layout = QVBoxLayout(stats_group)

        self.episode_label = QLabel('Episode: 0')
        self.world_label = QLabel('World: 1-1')
        self.score_label = QLabel('Score: 0')
        self.reward_label = QLabel('Episode Reward: 0')
        self.steps_label = QLabel('Steps: 0')
        self.epsilon_label = QLabel('Epsilon: 1.000')
        self.loss_label = QLabel('Loss: 0.0000')
        self.memory_label = QLabel('Memory: 0')
        self.xpos_label = QLabel('X Position: 0')
        self.coins_label = QLabel('Coins: 0')
        self.time_label = QLabel('Time: 400')
        self.flag_label = QLabel('Flag: No')

        for label in (self.episode_label, self.world_label, self.score_label,
                      self.reward_label, self.steps_label, self.epsilon_label,
                      self.loss_label, self.memory_label, self.xpos_label,
                      self.coins_label, self.time_label, self.flag_label):
            stats_layout.addWidget(label)

        right_layout.addWidget(stats_group)

        right_layout.addWidget(QLabel('Training Log:'))
        self.log_text = QTextEdit()
        self.log_text.setMaximumHeight(300)
        right_layout.addWidget(self.log_text)

        content_layout.addWidget(left_frame)
        content_layout.addWidget(right_frame)
        layout.addLayout(content_layout)

    def start_training(self):
        device = self.device_combo.currentText()

        # Fall back to CPU when the requested accelerator is unavailable.
        if device == "cuda" and not torch.cuda.is_available():
            self.log_text.append("⚠️ CUDA not available, using CPU instead")
            device = "cpu"
        elif device == "mps" and not torch.backends.mps.is_available():
            self.log_text.append("⚠️ MPS not available, using CPU instead")
            device = "cpu"

        self.training_thread = MarioTrainingThread(device)
        self.training_thread.update_signal.connect(self.update_training_info)
        self.training_thread.frame_signal.connect(self.update_game_display)
        self.training_thread.start()

        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

        self.log_text.append(f'🚀 Started Dueling DQN training on {device}...')

    def stop_training(self):
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()

        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.log_text.append('⏹️ Training stopped.')

    def load_model(self):
        self.log_text.append('📁 Load model functionality not implemented yet')

    def update_training_info(self, data):
        if 'episode' in data:
            self.episode_label.setText(f'Episode: {data["episode"]}')
        if 'world' in data and 'stage' in data:
            self.world_label.setText(f'World: {data["world"]}-{data["stage"]}')
        if 'score' in data:
            self.score_label.setText(f'Score: {data["score"]}')
        if 'total_reward' in data:
            self.reward_label.setText(f'Episode Reward: {data["total_reward"]:.2f}')
        if 'steps' in data:
            self.steps_label.setText(f'Steps: {data["steps"]}')
        if 'epsilon' in data:
            self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
        if 'loss' in data:
            self.loss_label.setText(f'Loss: {data["loss"]:.4f}')
        if 'memory_size' in data:
            self.memory_label.setText(f'Memory: {data["memory_size"]}')
        if 'x_pos' in data:
            self.xpos_label.setText(f'X Position: {data["x_pos"]}')
        if 'coins' in data:
            self.coins_label.setText(f'Coins: {data["coins"]}')
        if 'time' in data:
            self.time_label.setText(f'Time: {data["time"]}')
        if 'flag_get' in data:
            flag_text = "Yes" if data["flag_get"] else "No"
            self.flag_label.setText(f'Flag: {flag_text}')
        if 'log_message' in data:
            self.log_text.append(data['log_message'])
            # Keep the log scrolled to the newest entry.
            self.log_text.verticalScrollBar().setValue(
                self.log_text.verticalScrollBar().maximum()
            )

    def update_game_display(self, frame):
        if frame is not None:
            try:
                h, w, ch = frame.shape
                bytes_per_line = ch * w
                # QImage wraps the buffer without copying, so keep the
                # contiguous array alive until QPixmap.fromImage makes a copy.
                frame_contiguous = np.ascontiguousarray(frame)
                q_img = QImage(frame_contiguous.data, w, h, bytes_per_line, QImage.Format_RGB888)
                pixmap = QPixmap.fromImage(q_img)
                self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
            except Exception as e:
                print(f"Error updating display: {e}")

    def closeEvent(self, event):
        self.stop_training()
        event.accept()
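

# Entry point: seed Python's, NumPy's, and PyTorch's RNGs so weight
# initialization and epsilon-greedy rolls are repeatable, then start the Qt
# event loop.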
def main():
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    app = QApplication(sys.argv)
    window = MarioRLApp()
    window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()