"""Arcade RL trainer: Dueling DQN and PPO agents for ALE (Atari) environments,
with a PyQt5 GUI for live training visualization."""

import sys
import os
import random
from collections import deque

import numpy as np
import gymnasium as gym
import ale_py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                             QHBoxLayout, QPushButton, QLabel, QComboBox,
                             QTextEdit, QProgressBar, QTabWidget, QFrame)
from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap, QFont

# Register the ALE environments with Gymnasium.
gym.register_envs(ale_py)


def create_env(env_name='ALE/Breakout-v5'):
    """
    Create an ALE environment using the Gymnasium API.

    Available environments include:
    - ALE/Breakout-v5, ALE/Pong-v5, ALE/SpaceInvaders-v5
    - ALE/Assault-v5, ALE/BeamRider-v5, ALE/Enduro-v5
    """
    env = gym.make(env_name, render_mode='rgb_array')
    return env


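# Usage sketch for the factory above (assumption: Atari ROMs are available,
# e.g. via the gymnasium[atari] / ale-py extras). Kept commented out so it is
# never executed by the app; it only illustrates the Gymnasium reset/step API
# used throughout this file.
#
#   env = create_env('ALE/Pong-v5')
#   obs, info = env.reset(seed=0)
#   obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
#   env.close()

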
class DuelingDQN(nn.Module):
    """Dueling DQN: shared convolutional trunk with separate value and advantage heads."""

    def __init__(self, input_shape, n_actions):
        super(DuelingDQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)

        self.fc_advantage = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

        self.fc_value = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def _get_conv_out(self, shape):
        # Infer the flattened conv output size with a dummy forward pass.
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        advantage = self.fc_advantage(conv_out)
        value = self.fc_value(conv_out)
        # Aggregate the two streams; subtract the per-sample mean advantage
        # over the action dimension so batched inputs are handled correctly.
        return value + advantage - advantage.mean(dim=1, keepdim=True)


class PPONetwork(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(PPONetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)

        self.actor = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
            nn.Softmax(dim=-1)
        )

        self.critic = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.actor(conv_out), self.critic(conv_out)


class DuelingDQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, epsilon=1.0,
                 epsilon_min=0.01, epsilon_decay=0.995, memory_size=10000, batch_size=32):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size

        self.memory = deque(maxlen=memory_size)
        self.model = DuelingDQN(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # Epsilon-greedy exploration.
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_dim)

        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.model(state)
        return int(torch.argmax(q_values).item())

    def replay(self):
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states = torch.FloatTensor(np.array([e[0] for e in batch]))
        actions = torch.LongTensor([e[1] for e in batch])
        rewards = torch.FloatTensor([e[2] for e in batch])
        next_states = torch.FloatTensor(np.array([e[3] for e in batch]))
        dones = torch.BoolTensor([e[4] for e in batch])

        current_q_values = self.model(states).gather(1, actions.unsqueeze(1))
        with torch.no_grad():
            next_q_values = self.model(next_states).max(1)[0]
            # Zero out the bootstrap term for terminal transitions.
            target_q_values = rewards + self.gamma * next_q_values * (~dones).float()

        loss = self.criterion(current_q_values.squeeze(), target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay


class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2,
                 entropy_coef=0.01, value_coef=0.5):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon  # PPO clip range
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef

        self.model = PPONetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.memory = []

    def remember(self, state, action, reward, value, log_prob):
        self.memory.append((state, action, reward, value, log_prob))

    def act(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            probs, value = self.model(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action), value.squeeze()

    def train(self):
        if not self.memory:
            return

        states, actions, rewards, values, log_probs = zip(*self.memory)

        # Discounted returns, accumulated backwards over the episode.
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + self.gamma * R
            returns.insert(0, R)

        returns = torch.FloatTensor(returns)
        values = torch.stack(values)          # 0-dim value tensors -> shape (N,)
        advantages = returns - values

        # Normalize advantages for more stable updates.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        old_log_probs = torch.cat(log_probs)  # shape-(1,) log-probs -> shape (N,)

        new_probs, new_values = self.model(states)
        dist = Categorical(new_probs)
        new_log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()

        # Clipped surrogate objective.
        ratio = (new_log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()

        critic_loss = F.mse_loss(new_values.squeeze(), returns)

        total_loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        self.memory = []


class TrainingThread(QThread):
    """Background worker that runs the training loop and streams progress and frames to the GUI."""

    update_signal = pyqtSignal(dict)
    frame_signal = pyqtSignal(np.ndarray)

    def __init__(self, algorithm='dqn', env_name='ALE/Breakout-v5'):
        super().__init__()
        self.algorithm = algorithm
        self.env_name = env_name
        self.running = False
        self.env = None
        self.agent = None

    def preprocess_state(self, state):
        # HWC uint8 RGB frame -> CHW float32 scaled to [0, 1].
        state = state.transpose((2, 0, 1)).astype(np.float32)
        state = state / 255.0
        return state

    def run(self):
        self.running = True
        try:
            self.env = create_env(self.env_name)
            state, info = self.env.reset()
            state = self.preprocess_state(state)

            n_actions = self.env.action_space.n
            state_dim = state.shape

            print(f"Environment: {self.env_name}")
            print(f"State shape: {state_dim}, Actions: {n_actions}")

            if self.algorithm == 'dqn':
                self.agent = DuelingDQNAgent(state_dim, n_actions)
            else:
                self.agent = PPOAgent(state_dim, n_actions)

            episode = 0
            total_reward = 0
            steps = 0
            episode_rewards = []

            while self.running:
                try:
                    if self.algorithm == 'dqn':
                        action = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, next_state, done)
                        self.agent.replay()
                    else:
                        action, log_prob, value = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, value, log_prob)
                        if done:
                            self.agent.train()

                    state = next_state
                    total_reward += reward
                    steps += 1

                    # Stream the current frame to the GUI; fall back to a black
                    # frame if rendering fails.
                    try:
                        frame = self.env.render()
                        if frame is not None:
                            self.frame_signal.emit(frame)
                    except Exception:
                        frame = np.zeros((210, 160, 3), dtype=np.uint8)
                        self.frame_signal.emit(frame)

                    # Emit training statistics every 10 steps.
                    if steps % 10 == 0:
                        progress_data = {
                            'episode': episode,
                            'total_reward': total_reward,
                            'steps': steps,
                            # For PPO the displayed value is the fixed clip range.
                            'epsilon': self.agent.epsilon if self.algorithm == 'dqn' else 0.2,
                            'env_name': self.env_name,
                            'lives': info.get('lives', 0) if isinstance(info, dict) else 0
                        }
                        self.update_signal.emit(progress_data)

                    if terminated or truncated:
                        episode_rewards.append(total_reward)
                        avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward

                        print(f"Episode {episode}: Total Reward: {total_reward:.2f}, "
                              f"Steps: {steps}, Avg Reward (last 10): {avg_reward:.2f}")

                        episode += 1
                        state, info = self.env.reset()
                        state = self.preprocess_state(state)
                        total_reward = 0
                        steps = 0

                except Exception as e:
                    print(f"Error in training loop: {e}")
                    import traceback
                    traceback.print_exc()
                    break

        except Exception as e:
            print(f"Error setting up environment: {e}")
            import traceback
            traceback.print_exc()

    def stop(self):
        self.running = False
        if self.env:
            self.env.close()


class ALE_RLApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.training_thread = None
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle('🎮 ALE Arcade RL Training')
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        # Title
        title = QLabel('🎮 Arcade Reinforcement Learning (ALE)')
        title.setFont(QFont('Arial', 16, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Control bar: algorithm/environment selection and start/stop buttons
        control_layout = QHBoxLayout()

        self.algorithm_combo = QComboBox()
        self.algorithm_combo.addItems(['Dueling DQN', 'PPO'])

        self.env_combo = QComboBox()
        self.env_combo.addItems([
            'ALE/Breakout-v5',
            'ALE/Pong-v5',
            'ALE/SpaceInvaders-v5',
            'ALE/Assault-v5',
            'ALE/BeamRider-v5',
            'ALE/Enduro-v5',
            'ALE/Seaquest-v5',
            'ALE/Qbert-v5'
        ])

        self.start_btn = QPushButton('Start Training')
        self.start_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton('Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        control_layout.addWidget(QLabel('Algorithm:'))
        control_layout.addWidget(self.algorithm_combo)
        control_layout.addWidget(QLabel('Environment:'))
        control_layout.addWidget(self.env_combo)
        control_layout.addWidget(self.start_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addStretch()

        layout.addLayout(control_layout)

        # Main content: game display on the left, stats and log on the right
        content_layout = QHBoxLayout()

        left_frame = QFrame()
        left_frame.setFrameStyle(QFrame.Box)
        left_layout = QVBoxLayout(left_frame)

        self.game_display = QLabel()
        self.game_display.setMinimumSize(400, 300)
        self.game_display.setAlignment(Qt.AlignCenter)
        self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
        self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')

        left_layout.addWidget(QLabel('Game Display:'))
        left_layout.addWidget(self.game_display)

        right_frame = QFrame()
        right_frame.setFrameStyle(QFrame.Box)
        right_layout = QVBoxLayout(right_frame)

        # Live training statistics
        self.env_label = QLabel('Environment: Not started')
        self.episode_label = QLabel('Episode: 0')
        self.reward_label = QLabel('Total Reward: 0')
        self.steps_label = QLabel('Steps: 0')
        self.epsilon_label = QLabel('Epsilon: 0')
        self.lives_label = QLabel('Lives: 0')

        right_layout.addWidget(self.env_label)
        right_layout.addWidget(self.episode_label)
        right_layout.addWidget(self.reward_label)
        right_layout.addWidget(self.steps_label)
        right_layout.addWidget(self.epsilon_label)
        right_layout.addWidget(self.lives_label)

        # Training log
        right_layout.addWidget(QLabel('Training Log:'))
        self.log_text = QTextEdit()
        self.log_text.setMaximumHeight(200)
        right_layout.addWidget(self.log_text)

        content_layout.addWidget(left_frame)
        content_layout.addWidget(right_frame)
        layout.addLayout(content_layout)

    def start_training(self):
        algorithm = 'dqn' if self.algorithm_combo.currentText() == 'Dueling DQN' else 'ppo'
        env_name = self.env_combo.currentText()

        self.training_thread = TrainingThread(algorithm, env_name)
        self.training_thread.update_signal.connect(self.update_training_info)
        self.training_thread.frame_signal.connect(self.update_game_display)
        self.training_thread.start()

        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

        self.log_text.append(f'Started {self.algorithm_combo.currentText()} training on {env_name}...')

    def stop_training(self):
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()

        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.log_text.append('Training stopped.')

    def update_training_info(self, data):
        self.env_label.setText(f'Environment: {data.get("env_name", "Unknown")}')
        self.episode_label.setText(f'Episode: {data["episode"]}')
        self.reward_label.setText(f'Total Reward: {data["total_reward"]:.2f}')
        self.steps_label.setText(f'Steps: {data["steps"]}')
        self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
        self.lives_label.setText(f'Lives: {data.get("lives", 0)}')

    def update_game_display(self, frame):
        if frame is not None:
            try:
                # QImage wraps the buffer without copying, so make sure the
                # frame is contiguous in memory before handing it over.
                frame = np.ascontiguousarray(frame)
                h, w, ch = frame.shape
                bytes_per_line = ch * w
                q_img = QImage(frame.data, w, h, bytes_per_line, QImage.Format_RGB888)
                pixmap = QPixmap.fromImage(q_img)
                self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
            except Exception as e:
                print(f"Error updating display: {e}")

    def closeEvent(self, event):
        self.stop_training()
        event.accept()


def main():
    # Seed the RNGs for reproducibility.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    app = QApplication(sys.argv)
    window = ALE_RLApp()
    window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()