import sys
import os
import random
from collections import deque

import numpy as np
import gymnasium as gym
import ale_py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                             QHBoxLayout, QPushButton, QLabel, QComboBox,
                             QTextEdit, QProgressBar, QTabWidget, QFrame)
from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap, QFont

# Register ALE environments with Gymnasium
gym.register_envs(ale_py)


# Environment setup
def create_env(env_name='ALE/Breakout-v5'):
    """
    Create an ALE environment using the Gymnasium API.

    Available environments include:
    - ALE/Breakout-v5, ALE/Pong-v5, ALE/SpaceInvaders-v5,
    - ALE/Assault-v5, ALE/BeamRider-v5, ALE/Enduro-v5
    """
    env = gym.make(env_name, render_mode='rgb_array')
    return env


# Neural network for Dueling DQN
class DuelingDQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DuelingDQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        conv_out_size = self._get_conv_out(input_shape)
        self.fc_advantage = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )
        self.fc_value = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        advantage = self.fc_advantage(conv_out)
        value = self.fc_value(conv_out)
        # Dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a),
        # with the mean taken per sample over the action dimension.
        return value + advantage - advantage.mean(dim=1, keepdim=True)


# Neural network for PPO (shared conv trunk, separate actor/critic heads)
class PPONetwork(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(PPONetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        conv_out_size = self._get_conv_out(input_shape)
        self.actor = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.actor(conv_out), self.critic(conv_out)


# Dueling DQN agent
class DuelingDQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99,
                 epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995,
                 memory_size=10000, batch_size=32):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.memory = deque(maxlen=memory_size)
        self.model = DuelingDQN(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # Epsilon-greedy exploration
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_dim)
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.model(state)
        return int(torch.argmax(q_values, dim=1).item())

    def replay(self):
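        # Sample a random minibatch and fit Q(s, a) toward the one-step TD
        # target r + gamma * max_a' Q(s', a'), masking out the bootstrap term
        # on terminal transitions. Note: this simplified variant reuses the
        # online network for the target (no separate target network, unlike
        # the original Dueling DQN setup).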
        if len(self.memory) < self.batch_size:
            return
        batch = random.sample(self.memory, self.batch_size)
        states = torch.FloatTensor(np.array([e[0] for e in batch]))
        actions = torch.LongTensor([e[1] for e in batch])
        rewards = torch.FloatTensor([e[2] for e in batch])
        next_states = torch.FloatTensor(np.array([e[3] for e in batch]))
        dones = torch.BoolTensor([e[4] for e in batch])

        current_q_values = self.model(states).gather(1, actions.unsqueeze(1))
        with torch.no_grad():
            next_q_values = self.model(next_states).max(1)[0]
        target_q_values = rewards + (self.gamma * next_q_values * ~dones)

        loss = self.criterion(current_q_values.squeeze(), target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay exploration rate
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay


# PPO agent
class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 epsilon=0.2, entropy_coef=0.01, value_coef=0.5):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon  # PPO clipping parameter
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.model = PPONetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.memory = []

    def remember(self, state, action, reward, value, log_prob):
        self.memory.append((state, action, reward, value, log_prob))

    def act(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            probs, value = self.model(state)
        dist = Categorical(probs)
        action = dist.sample()
        # Return plain Python floats so they can be batched back into
        # flat tensors cleanly when train() is called.
        return action.item(), dist.log_prob(action).item(), value.item()

    def train(self):
        if not self.memory:
            return
        states, actions, rewards, values, log_probs = zip(*self.memory)

        # Discounted returns over the stored episode
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + self.gamma * R
            returns.insert(0, R)
        returns = torch.FloatTensor(returns)
        values = torch.FloatTensor(values)
        advantages = returns - values

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Convert the rollout to tensors
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        old_log_probs = torch.FloatTensor(log_probs)

        # Re-evaluate the current policy on the stored states
        new_probs, new_values = self.model(states)
        dist = Categorical(new_probs)
        new_log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()

        # Clipped surrogate objective
        ratio = (new_log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()
        critic_loss = F.mse_loss(new_values.squeeze(), returns)

        total_loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        self.memory = []


# Training thread (keeps the GUI responsive while the agent trains)
class TrainingThread(QThread):
    update_signal = pyqtSignal(dict)
    frame_signal = pyqtSignal(np.ndarray)

    def __init__(self, algorithm='dqn', env_name='ALE/Breakout-v5'):
        super().__init__()
        self.algorithm = algorithm
        self.env_name = env_name
        self.running = False
        self.env = None
        self.agent = None

    def preprocess_state(self, state):
        # Convert HWC uint8 frames to CHW float32 in [0, 1]
        state = state.transpose((2, 0, 1))
        state = state.astype(np.float32) / 255.0
        return state

    def run(self):
        self.running = True
        try:
            self.env = create_env(self.env_name)
            state, info = self.env.reset()
            state = self.preprocess_state(state)
            n_actions = self.env.action_space.n
            state_dim = state.shape

            print(f"Environment: {self.env_name}")
            print(f"State shape: {state_dim}, Actions: {n_actions}")
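            # Note: observations here are the full 210x160 RGB frames, i.e.
            # state_dim is (3, 210, 160). There is no grayscale conversion,
            # resizing, or frame stacking, so the networks are considerably
            # heavier than with the canonical 84x84x4 Atari preprocessing
            # (see the optional sketch at the end of this file).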
            if self.algorithm == 'dqn':
                self.agent = DuelingDQNAgent(state_dim, n_actions)
            else:
                self.agent = PPOAgent(state_dim, n_actions)

            episode = 0
            total_reward = 0
            steps = 0
            episode_rewards = []

            while self.running:
                try:
                    if self.algorithm == 'dqn':
                        action = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, next_state, done)
                        self.agent.replay()
                    else:
                        action, log_prob, value = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, value, log_prob)
                        if done:
                            self.agent.train()

                    state = next_state
                    total_reward += reward
                    steps += 1

                    # Emit frame for display
                    try:
                        frame = self.env.render()
                        if frame is not None:
                            self.frame_signal.emit(frame)
                    except Exception:
                        # Emit a black placeholder frame if rendering fails
                        frame = np.zeros((210, 160, 3), dtype=np.uint8)
                        self.frame_signal.emit(frame)

                    # Emit training progress
                    if steps % 10 == 0:
                        progress_data = {
                            'episode': episode,
                            'total_reward': total_reward,
                            'steps': steps,
                            'epsilon': self.agent.epsilon if self.algorithm == 'dqn' else 0.2,
                            'env_name': self.env_name,
                            'lives': info.get('lives', 0) if isinstance(info, dict) else 0
                        }
                        self.update_signal.emit(progress_data)

                    if done:
                        episode_rewards.append(total_reward)
                        avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward
                        print(f"Episode {episode}: Total Reward: {total_reward:.2f}, "
                              f"Steps: {steps}, Avg Reward (last 10): {avg_reward:.2f}")
                        episode += 1
                        state, info = self.env.reset()
                        state = self.preprocess_state(state)
                        total_reward = 0
                        steps = 0
                except Exception as e:
                    print(f"Error in training loop: {e}")
                    import traceback
                    traceback.print_exc()
                    break
        except Exception as e:
            print(f"Error setting up environment: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Close the environment from the worker thread itself so the GUI
            # thread never touches it while a step/render is in flight.
            if self.env is not None:
                self.env.close()
                self.env = None

    def stop(self):
        self.running = False


# Main application window
class ALE_RLApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.training_thread = None
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle('🎮 ALE Arcade RL Training')
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        # Title
        title = QLabel('🎮 Arcade Reinforcement Learning (ALE)')
        title.setFont(QFont('Arial', 16, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Control panel
        control_layout = QHBoxLayout()
        self.algorithm_combo = QComboBox()
        self.algorithm_combo.addItems(['Dueling DQN', 'PPO'])
        self.env_combo = QComboBox()
        self.env_combo.addItems([
            'ALE/Breakout-v5', 'ALE/Pong-v5', 'ALE/SpaceInvaders-v5',
            'ALE/Assault-v5', 'ALE/BeamRider-v5', 'ALE/Enduro-v5',
            'ALE/Seaquest-v5', 'ALE/Qbert-v5'
        ])
        self.start_btn = QPushButton('Start Training')
        self.start_btn.clicked.connect(self.start_training)
        self.stop_btn = QPushButton('Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        control_layout.addWidget(QLabel('Algorithm:'))
        control_layout.addWidget(self.algorithm_combo)
        control_layout.addWidget(QLabel('Environment:'))
        control_layout.addWidget(self.env_combo)
        control_layout.addWidget(self.start_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addStretch()
        layout.addLayout(control_layout)

        # Content area
        content_layout = QHBoxLayout()

        # Left side - game display
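        # (The left panel renders frames emitted via frame_signal from the
        # training thread; the right panel shows live training statistics.)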
        left_frame = QFrame()
        left_frame.setFrameStyle(QFrame.Box)
        left_layout = QVBoxLayout(left_frame)
        self.game_display = QLabel()
        self.game_display.setMinimumSize(400, 300)
        self.game_display.setAlignment(Qt.AlignCenter)
        self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
        self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')
        left_layout.addWidget(QLabel('Game Display:'))
        left_layout.addWidget(self.game_display)

        # Right side - training info
        right_frame = QFrame()
        right_frame.setFrameStyle(QFrame.Box)
        right_layout = QVBoxLayout(right_frame)

        # Training status labels
        self.env_label = QLabel('Environment: Not started')
        self.episode_label = QLabel('Episode: 0')
        self.reward_label = QLabel('Total Reward: 0')
        self.steps_label = QLabel('Steps: 0')
        self.epsilon_label = QLabel('Epsilon: 0')
        self.lives_label = QLabel('Lives: 0')
        right_layout.addWidget(self.env_label)
        right_layout.addWidget(self.episode_label)
        right_layout.addWidget(self.reward_label)
        right_layout.addWidget(self.steps_label)
        right_layout.addWidget(self.epsilon_label)
        right_layout.addWidget(self.lives_label)

        # Training log
        right_layout.addWidget(QLabel('Training Log:'))
        self.log_text = QTextEdit()
        self.log_text.setMaximumHeight(200)
        right_layout.addWidget(self.log_text)

        content_layout.addWidget(left_frame)
        content_layout.addWidget(right_frame)
        layout.addLayout(content_layout)

    def start_training(self):
        algorithm = 'dqn' if self.algorithm_combo.currentText() == 'Dueling DQN' else 'ppo'
        env_name = self.env_combo.currentText()
        self.training_thread = TrainingThread(algorithm, env_name)
        self.training_thread.update_signal.connect(self.update_training_info)
        self.training_thread.frame_signal.connect(self.update_game_display)
        self.training_thread.start()
        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)
        self.log_text.append(f'Started {self.algorithm_combo.currentText()} training on {env_name}...')

    def stop_training(self):
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()
        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.log_text.append('Training stopped.')

    def update_training_info(self, data):
        self.env_label.setText(f'Environment: {data.get("env_name", "Unknown")}')
        self.episode_label.setText(f'Episode: {data["episode"]}')
        self.reward_label.setText(f'Total Reward: {data["total_reward"]:.2f}')
        self.steps_label.setText(f'Steps: {data["steps"]}')
        self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
        self.lives_label.setText(f'Lives: {data.get("lives", 0)}')

    def update_game_display(self, frame):
        if frame is not None:
            try:
                # QImage expects a contiguous buffer; copy() detaches the image
                # from the numpy array so the buffer cannot dangle.
                frame = np.ascontiguousarray(frame)
                h, w, ch = frame.shape
                bytes_per_line = ch * w
                q_img = QImage(frame.data, w, h, bytes_per_line, QImage.Format_RGB888).copy()
                pixmap = QPixmap.fromImage(q_img)
                self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
            except Exception as e:
                print(f"Error updating display: {e}")

    def closeEvent(self, event):
        self.stop_training()
        event.accept()


def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    app = QApplication(sys.argv)
    window = ALE_RLApp()
    window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()
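# ---------------------------------------------------------------------------
# Optional sketch (not used by the GUI above): the canonical Atari setup feeds
# agents 84x84 grayscale observations with frame skipping and a stack of 4
# frames. Gymnasium ships wrappers for this; the commented-out helper below is
# only an assumption about how they could be wired into create_env(). Wrapper
# names vary by Gymnasium version (FrameStackObservation in >=1.0, FrameStack
# in older releases), so treat this as a starting point, not a drop-in.
#
# def create_preprocessed_env(env_name='ALE/Breakout-v5'):
#     from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation
#     # The v5 ALE envs already skip frames, so disable the built-in skip
#     # before letting AtariPreprocessing apply its own frame_skip=4.
#     env = gym.make(env_name, render_mode='rgb_array', frameskip=1)
#     env = AtariPreprocessing(env, screen_size=84, grayscale_obs=True,
#                              scale_obs=True)
#     return FrameStackObservation(env, stack_size=4)
# ---------------------------------------------------------------------------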