TroglodyteDerivations committed on
Commit
2db463d
·
verified ·
1 Parent(s): b037d57

Upload 32 files

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Super-Mario-RL/mario1.gif filter=lfs diff=lfs merge=lfs -text
37
+ Super-Mario-RL/mario1.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ Super-Mario-RL/mario14.gif filter=lfs diff=lfs merge=lfs -text
39
+ Super-Mario-RL/mario14.mp4 filter=lfs diff=lfs merge=lfs -text
Super-Mario-RL-PyQt5/app.py ADDED
@@ -0,0 +1,673 @@
1
+ import pickle
2
+ import random
3
+ import time
4
+ from collections import deque
5
+
6
+ import gym_super_mario_bros
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.optim as optim
12
+ from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
13
+ from nes_py.wrappers import JoypadSpace
14
+
15
+ from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
16
+ QHBoxLayout, QPushButton, QLabel, QComboBox,
17
+ QTextEdit, QProgressBar, QTabWidget, QFrame, QGroupBox)
18
+ from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
19
+ from PyQt5.QtGui import QImage, QPixmap, QFont
20
+ import sys
21
+ import cv2
22
+
23
+ # Import your wrappers (make sure this module exists)
24
+ try:
25
+ from wrappers import *
26
+ except ImportError:
27
+ # Create a proper wrapper if the module doesn't exist
28
+ class SimpleWrapper:
29
+ def __init__(self, env):
30
+ self.env = env
31
+ self.action_space = env.action_space
32
+ self.observation_space = env.observation_space
33
+
34
+ def reset(self):
35
+ return self.env.reset()
36
+
37
+ def step(self, action):
38
+ return self.env.step(action)
39
+
40
+ def render(self, mode='rgb_array'):
41
+ return self.env.render(mode)
42
+
43
+ def close(self):
44
+ if hasattr(self.env, 'close'):
45
+ self.env.close()
46
+
47
+ def wrap_mario(env):
48
+ return SimpleWrapper(env)
49
+
50
+
51
+ class FrameStacker:
52
+ """Handles frame stacking and preprocessing"""
53
+ def __init__(self, frame_size=(84, 84), stack_size=4):
54
+ self.frame_size = frame_size
55
+ self.stack_size = stack_size
56
+ self.frames = deque(maxlen=stack_size)
57
+
58
+ def preprocess_frame(self, frame):
59
+ """Convert frame to grayscale and resize"""
60
+ # Convert to grayscale
61
+ gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
62
+ # Resize to 84x84
63
+ resized = cv2.resize(gray, self.frame_size, interpolation=cv2.INTER_AREA)
64
+ # Normalize to [0, 1]
65
+ normalized = resized.astype(np.float32) / 255.0
66
+ return normalized
67
+
68
+ def reset(self, frame):
69
+ """Reset frame stack with initial frame"""
70
+ self.frames.clear()
71
+ processed_frame = self.preprocess_frame(frame)
72
+ for _ in range(self.stack_size):
73
+ self.frames.append(processed_frame)
74
+ return self.get_stacked_frames()
75
+
76
+ def append(self, frame):
77
+ """Add new frame to stack"""
78
+ processed_frame = self.preprocess_frame(frame)
79
+ self.frames.append(processed_frame)
80
+ return self.get_stacked_frames()
81
+
82
+ def get_stacked_frames(self):
83
+ """Get stacked frames as numpy array"""
84
+ stacked = np.array(self.frames)
85
+ return np.ascontiguousarray(stacked)
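+ # The stacked array has shape (stack_size, 84, 84) with float32 values in [0, 1],
+ # matching the (n_frame, H, W) input expected by the model's first conv layer.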
86
+
87
+
88
+ class replay_memory(object):
89
+ def __init__(self, N):
90
+ self.memory = deque(maxlen=N)
91
+
92
+ def push(self, transition):
93
+ self.memory.append(transition)
94
+
95
+ def sample(self, n):
96
+ return random.sample(self.memory, n)
97
+
98
+ def __len__(self):
99
+ return len(self.memory)
100
+
101
+
102
+ class DuelingDQNModel(nn.Module):
103
+ def __init__(self, n_frame, n_action, device):
104
+ super(DuelingDQNModel, self).__init__()
105
+
106
+ # CNN layers for feature extraction
107
+ self.conv_layers = nn.Sequential(
108
+ nn.Conv2d(n_frame, 32, kernel_size=8, stride=4),
109
+ nn.ReLU(),
110
+ nn.Conv2d(32, 64, kernel_size=4, stride=2),
111
+ nn.ReLU(),
112
+ nn.Conv2d(64, 64, kernel_size=3, stride=1),
113
+ nn.ReLU()
114
+ )
115
+
116
+ # Calculate conv output size
117
+ self.conv_out_size = self._get_conv_out((n_frame, 84, 84))
118
+
119
+ # Advantage stream
120
+ self.advantage_stream = nn.Sequential(
121
+ nn.Linear(self.conv_out_size, 512),
122
+ nn.ReLU(),
123
+ nn.Linear(512, n_action)
124
+ )
125
+
126
+ # Value stream
127
+ self.value_stream = nn.Sequential(
128
+ nn.Linear(self.conv_out_size, 512),
129
+ nn.ReLU(),
130
+ nn.Linear(512, 1)
131
+ )
132
+
133
+ self.device = device
134
+ self.apply(self.init_weights)
135
+
136
+ def _get_conv_out(self, shape):
137
+ with torch.no_grad():
138
+ x = torch.zeros(1, *shape)
139
+ x = self.conv_layers(x)
140
+ return int(np.prod(x.size()))
141
+
142
+ def init_weights(self, m):
143
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
144
+ torch.nn.init.xavier_uniform_(m.weight)
145
+ if m.bias is not None:
146
+ m.bias.data.fill_(0.01)
147
+
148
+ def forward(self, x):
149
+ if not isinstance(x, torch.Tensor):
150
+ x = torch.FloatTensor(x).to(self.device)
151
+
152
+ # Forward through conv layers
153
+ x = self.conv_layers(x)
154
+ x = x.view(x.size(0), -1)
155
+
156
+ # Forward through advantage and value streams
157
+ advantage = self.advantage_stream(x)
158
+ value = self.value_stream(x)
159
+
160
+ # Combine value and advantage
161
+ q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
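+ # Dueling aggregation (Wang et al., 2016): Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').
+ # Subtracting the mean advantage keeps V and A identifiable without changing the greedy action.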
162
+
163
+ return q_values
164
+
165
+
166
+ def train(q, q_target, memory, batch_size, gamma, optimizer, device):
167
+ if len(memory) < batch_size:
168
+ return 0.0
169
+
170
+ transitions = memory.sample(batch_size)
171
+ s, r, a, s_prime, done = list(map(list, zip(*transitions)))
172
+
173
+ # Ensure positive strides for all arrays
174
+ s = np.array([np.ascontiguousarray(arr) for arr in s])
175
+ s_prime = np.array([np.ascontiguousarray(arr) for arr in s_prime])
176
+
177
+ # Move computations to device
178
+ s_tensor = torch.FloatTensor(s).to(device)
179
+ s_prime_tensor = torch.FloatTensor(s_prime).to(device)
180
+
181
+ # Get next Q values from target network
182
+ with torch.no_grad():
183
+ next_q_values = q_target(s_prime_tensor)
184
+ next_actions = next_q_values.max(1)[1].unsqueeze(1)
185
+ next_q_value = next_q_values.gather(1, next_actions)
186
+
187
+ # Calculate target Q values
188
+ r = torch.FloatTensor(r).unsqueeze(1).to(device)
189
+ done = torch.FloatTensor(done).unsqueeze(1).to(device)
190
+ target_q_values = r + gamma * next_q_value * done  # 'done' holds the not-done mask stored as int(1 - done)
191
+
192
+ # Get current Q values
193
+ a_tensor = torch.LongTensor(a).unsqueeze(1).to(device)
194
+ current_q_values = q(s_tensor).gather(1, a_tensor)
195
+
196
+ # Calculate loss
197
+ loss = F.smooth_l1_loss(current_q_values, target_q_values)
198
+
199
+ # Optimize
200
+ optimizer.zero_grad()
201
+ loss.backward()
202
+
203
+ # Gradient clipping
204
+ torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0)
205
+
206
+ optimizer.step()
207
+ return loss.item()
208
+
209
+
210
+ def copy_weights(q, q_target):
211
+ q_dict = q.state_dict()
212
+ q_target.load_state_dict(q_dict)
213
+
214
+
215
+ class MarioTrainingThread(QThread):
216
+ update_signal = pyqtSignal(dict)
217
+ frame_signal = pyqtSignal(np.ndarray)
218
+
219
+ def __init__(self, device="cpu"):
220
+ super().__init__()
221
+ self.device = device
222
+ self.running = False
223
+ self.env = None
224
+ self.q = None
225
+ self.q_target = None
226
+ self.optimizer = None
227
+ self.frame_stacker = None
228
+
229
+ # Training parameters
230
+ self.gamma = 0.99
231
+ self.batch_size = 32
232
+ self.memory_size = 10000
233
+ self.eps = 1.0 # Start with full exploration
234
+ self.eps_min = 0.01
235
+ self.eps_decay = 0.995
236
+ self.update_interval = 1000
237
+ self.save_interval = 100
238
+ self.print_interval = 10
239
+
240
+ self.memory = None
241
+ self.t = 0
242
+ self.k = 0
243
+ self.total_score = 0.0
244
+ self.loss_accumulator = 0.0
245
+ self.best_score = -float('inf')
246
+ self.last_x_pos = 0
247
+
248
+ def setup_training(self):
249
+ n_frame = 4 # Number of stacked frames
250
+ try:
251
+ self.env = gym_super_mario_bros.make("SuperMarioBros-v3")
252
+ self.env = JoypadSpace(self.env, COMPLEX_MOVEMENT)
253
+ self.env = wrap_mario(self.env)
254
+
255
+ # Initialize frame stacker
256
+ self.frame_stacker = FrameStacker(frame_size=(84, 84), stack_size=n_frame)
257
+
258
+ self.q = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
259
+ self.q_target = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
260
+
261
+ copy_weights(self.q, self.q_target)
262
+
263
+ # Set target network to eval mode
264
+ self.q_target.eval()
265
+
266
+ # Optimizer
267
+ self.optimizer = optim.Adam(self.q.parameters(), lr=0.0001, weight_decay=1e-5)
268
+
269
+ self.memory = replay_memory(self.memory_size)
270
+
271
+ self.log_message(f"✅ Training setup complete - Actions: {self.env.action_space.n}, Device: {self.device}")
272
+
273
+ except Exception as e:
274
+ self.log_message(f"❌ Error setting up training: {e}")
275
+ import traceback
276
+ traceback.print_exc()
277
+ self.running = False
278
+
279
+ def run(self):
280
+ self.running = True
281
+ self.setup_training()
282
+
283
+ if not self.running:
284
+ return
285
+
286
+ start_time = time.perf_counter()
287
+ score_lst = []
288
+
289
+ try:
290
+ for k in range(1000000):
291
+ if not self.running:
292
+ break
293
+
294
+ # Reset environment and frame stacker
295
+ frame = self.env.reset()
296
+ s = self.frame_stacker.reset(frame)
297
+ done = False
298
+ episode_loss = 0.0
299
+ episode_steps = 0
300
+ episode_score = 0.0
301
+ self.last_x_pos = 0
302
+
303
+ while not done and self.running:
304
+ # Ensure state has positive strides before processing
305
+ s_processed = np.ascontiguousarray(s)
306
+
307
+ # Epsilon-greedy action selection
308
+ if np.random.random() <= self.eps:
309
+ a = self.env.action_space.sample()
310
+ else:
311
+ with torch.no_grad():
312
+ # Add batch dimension and create tensor
313
+ state_tensor = torch.FloatTensor(s_processed).unsqueeze(0).to(self.device)
314
+ q_values = self.q(state_tensor)
315
+
316
+ if self.device == "cuda" or self.device == "mps":
317
+ a = np.argmax(q_values.cpu().numpy())
318
+ else:
319
+ a = np.argmax(q_values.detach().numpy())
320
+
321
+ # Take action
322
+ frame, r, done, info = self.env.step(a)
323
+
324
+ # Update frame stack
325
+ s_prime = self.frame_stacker.append(frame)
326
+
327
+ episode_score += r
328
+
329
+ # Enhanced reward shaping
330
+ reward = r # Start with original reward
331
+
332
+ # Bonus for x_pos progress
333
+ if 'x_pos' in info:
334
+ x_pos = info['x_pos']
335
+ x_progress = x_pos - self.last_x_pos
336
+ if x_progress > 0:
337
+ reward += 0.1 * x_progress
338
+ self.last_x_pos = x_pos
339
+
340
+ # Large bonus for completing the level
341
+ if done and info.get('flag_get', False):
342
+ reward += 100.0
343
+ self.log_message(f"🎉 LEVEL COMPLETED at episode {k}! 🎉")
344
+
345
+ # Store transition with contiguous arrays
346
+ s_contiguous = np.ascontiguousarray(s)
347
+ s_prime_contiguous = np.ascontiguousarray(s_prime)
348
+ self.memory.push((s_contiguous, float(reward), int(a), s_prime_contiguous, int(1 - done)))
349
+
350
+ s = s_prime
351
+ stage = info.get('stage', 1)
352
+ world = info.get('world', 1)
353
+
354
+ # Emit frame for display
355
+ try:
356
+ display_frame = self.env.render()
357
+ if display_frame is not None:
358
+ # Ensure frame has positive strides
359
+ frame_contiguous = np.ascontiguousarray(display_frame)
360
+ self.frame_signal.emit(frame_contiguous)
361
+ except Exception as e:
362
+ # Create a placeholder frame if rendering fails
363
+ frame = np.zeros((240, 256, 3), dtype=np.uint8)
364
+ self.frame_signal.emit(frame)
365
+
366
+ # Train only if we have enough samples
367
+ if len(self.memory) > self.batch_size:
368
+ loss_val = train(self.q, self.q_target, self.memory, self.batch_size,
369
+ self.gamma, self.optimizer, self.device)
370
+ if loss_val > 0:
371
+ self.loss_accumulator += loss_val
372
+ episode_loss += loss_val
373
+ self.t += 1
374
+
375
+ # Update target network
376
+ if self.t % self.update_interval == 0:
377
+ copy_weights(self.q, self.q_target)
378
+
379
+ episode_steps += 1
380
+
381
+ # Emit training progress every 10 steps
382
+ if episode_steps % 10 == 0:
383
+ progress_data = {
384
+ 'episode': k,
385
+ 'total_reward': episode_score,
386
+ 'steps': episode_steps,
387
+ 'epsilon': self.eps,
388
+ 'world': world,
389
+ 'stage': stage,
390
+ 'loss': episode_loss / (episode_steps + 1e-8),
391
+ 'memory_size': len(self.memory),
392
+ 'x_pos': info.get('x_pos', 0),
393
+ 'score': info.get('score', 0),
394
+ 'coins': info.get('coins', 0),
395
+ 'time': info.get('time', 400),
396
+ 'flag_get': info.get('flag_get', False)
397
+ }
398
+ self.update_signal.emit(progress_data)
399
+
400
+ # Epsilon decay after each episode
401
+ if self.eps > self.eps_min:
402
+ self.eps *= self.eps_decay
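+ # With eps_decay = 0.995 applied once per episode, epsilon drops from 1.0 to eps_min (0.01) after roughly 920 episodes.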
403
+
404
+ # Update total score
405
+ self.total_score += episode_score
406
+
407
+ # Save best model
408
+ if episode_score > self.best_score and k > 0:
409
+ self.best_score = episode_score
410
+ torch.save(self.q.state_dict(), "enhanced_mario_q_best.pth")
411
+ torch.save(self.q_target.state_dict(), "enhanced_mario_q_target_best.pth")
412
+ self.log_message(f"💾 New best model saved! Score: {self.best_score:.2f}")
413
+
414
+ # Save models periodically
415
+ if k % self.save_interval == 0 and k > 0:
416
+ torch.save(self.q.state_dict(), "enhanced_mario_q.pth")
417
+ torch.save(self.q_target.state_dict(), "enhanced_mario_q_target.pth")
418
+ self.log_message(f"💾 Models saved at episode {k}")
419
+
420
+ # Print progress
421
+ if k % self.print_interval == 0 and k > 0:
422
+ time_spent = time.perf_counter() - start_time
423
+ start_time = time.perf_counter()
424
+
425
+ avg_loss = self.loss_accumulator / (self.print_interval * max(episode_steps, 1))
426
+ avg_score = self.total_score / self.print_interval
427
+
428
+ log_msg = (
429
+ f"{self.device} | Ep: {k} | Score: {avg_score:.2f} | Loss: {avg_loss:.4f} | "
430
+ f"Stage: {world}-{stage} | Eps: {self.eps:.3f} | Time: {time_spent:.2f}s | "
431
+ f"Mem: {len(self.memory)} | Steps: {episode_steps}"
432
+ )
433
+ self.log_message(log_msg)
434
+
435
+ score_lst.append(avg_score)
436
+ self.total_score = 0.0
437
+ self.loss_accumulator = 0.0
438
+
439
+ try:
440
+ pickle.dump(score_lst, open("score.p", "wb"))
441
+ except Exception as e:
442
+ self.log_message(f"⚠️ Could not save scores: {e}")
443
+
444
+ self.k = k
445
+
446
+ except Exception as e:
447
+ self.log_message(f"❌ Training error: {e}")
448
+ import traceback
449
+ traceback.print_exc()
450
+
451
+ def log_message(self, message):
452
+ progress_data = {
453
+ 'log_message': message
454
+ }
455
+ self.update_signal.emit(progress_data)
456
+
457
+ def stop(self):
458
+ self.running = False
459
+ if self.env:
460
+ try:
461
+ self.env.close()
462
+ except:
463
+ pass
464
+
465
+
466
+ class MarioRLApp(QMainWindow):
467
+ def __init__(self):
468
+ super().__init__()
469
+ self.training_thread = None
470
+ self.init_ui()
471
+
472
+ def init_ui(self):
473
+ self.setWindowTitle('🎮 Super Mario Bros - Dueling DQN Training')
474
+ self.setGeometry(100, 100, 1200, 800)
475
+
476
+ central_widget = QWidget()
477
+ self.setCentralWidget(central_widget)
478
+ layout = QVBoxLayout(central_widget)
479
+
480
+ # Title
481
+ title = QLabel('🎮 Super Mario Bros - Enhanced Dueling DQN')
482
+ title.setFont(QFont('Arial', 16, QFont.Bold))
483
+ title.setAlignment(Qt.AlignCenter)
484
+ layout.addWidget(title)
485
+
486
+ # Control Panel
487
+ control_layout = QHBoxLayout()
488
+
489
+ self.device_combo = QComboBox()
490
+ self.device_combo.addItems(['cpu', 'cuda', 'mps'])
491
+
492
+ self.start_btn = QPushButton('Start Training')
493
+ self.start_btn.clicked.connect(self.start_training)
494
+
495
+ self.stop_btn = QPushButton('Stop Training')
496
+ self.stop_btn.clicked.connect(self.stop_training)
497
+ self.stop_btn.setEnabled(False)
498
+
499
+ self.load_btn = QPushButton('Load Model')
500
+ self.load_btn.clicked.connect(self.load_model)
501
+
502
+ control_layout.addWidget(QLabel('Device:'))
503
+ control_layout.addWidget(self.device_combo)
504
+ control_layout.addWidget(self.start_btn)
505
+ control_layout.addWidget(self.stop_btn)
506
+ control_layout.addWidget(self.load_btn)
507
+ control_layout.addStretch()
508
+
509
+ layout.addLayout(control_layout)
510
+
511
+ # Content Area
512
+ content_layout = QHBoxLayout()
513
+
514
+ # Left side - Game Display
515
+ left_frame = QFrame()
516
+ left_frame.setFrameStyle(QFrame.Box)
517
+ left_layout = QVBoxLayout(left_frame)
518
+
519
+ self.game_display = QLabel()
520
+ self.game_display.setMinimumSize(400, 300)
521
+ self.game_display.setAlignment(Qt.AlignCenter)
522
+ self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
523
+ self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')
524
+
525
+ left_layout.addWidget(QLabel('Mario Game Display:'))
526
+ left_layout.addWidget(self.game_display)
527
+
528
+ # Right side - Training Info
529
+ right_frame = QFrame()
530
+ right_frame.setFrameStyle(QFrame.Box)
531
+ right_layout = QVBoxLayout(right_frame)
532
+
533
+ # Training stats
534
+ stats_group = QGroupBox("Training Statistics")
535
+ stats_layout = QVBoxLayout(stats_group)
536
+
537
+ self.episode_label = QLabel('Episode: 0')
538
+ self.world_label = QLabel('World: 1-1')
539
+ self.score_label = QLabel('Score: 0')
540
+ self.reward_label = QLabel('Episode Reward: 0')
541
+ self.steps_label = QLabel('Steps: 0')
542
+ self.epsilon_label = QLabel('Epsilon: 1.000')
543
+ self.loss_label = QLabel('Loss: 0.0000')
544
+ self.memory_label = QLabel('Memory: 0')
545
+ self.xpos_label = QLabel('X Position: 0')
546
+ self.coins_label = QLabel('Coins: 0')
547
+ self.time_label = QLabel('Time: 400')
548
+ self.flag_label = QLabel('Flag: No')
549
+
550
+ stats_layout.addWidget(self.episode_label)
551
+ stats_layout.addWidget(self.world_label)
552
+ stats_layout.addWidget(self.score_label)
553
+ stats_layout.addWidget(self.reward_label)
554
+ stats_layout.addWidget(self.steps_label)
555
+ stats_layout.addWidget(self.epsilon_label)
556
+ stats_layout.addWidget(self.loss_label)
557
+ stats_layout.addWidget(self.memory_label)
558
+ stats_layout.addWidget(self.xpos_label)
559
+ stats_layout.addWidget(self.coins_label)
560
+ stats_layout.addWidget(self.time_label)
561
+ stats_layout.addWidget(self.flag_label)
562
+
563
+ right_layout.addWidget(stats_group)
564
+
565
+ # Training log
566
+ right_layout.addWidget(QLabel('Training Log:'))
567
+ self.log_text = QTextEdit()
568
+ self.log_text.setMaximumHeight(300)
569
+ right_layout.addWidget(self.log_text)
570
+
571
+ content_layout.addWidget(left_frame)
572
+ content_layout.addWidget(right_frame)
573
+ layout.addLayout(content_layout)
574
+
575
+ def start_training(self):
576
+ device = self.device_combo.currentText()
577
+
578
+ # Check device availability
579
+ if device == "cuda" and not torch.cuda.is_available():
580
+ self.log_text.append("❌ CUDA not available, using CPU instead")
581
+ device = "cpu"
582
+ elif device == "mps" and not torch.backends.mps.is_available():
583
+ self.log_text.append("❌ MPS not available, using CPU instead")
584
+ device = "cpu"
585
+
586
+ self.training_thread = MarioTrainingThread(device)
587
+ self.training_thread.update_signal.connect(self.update_training_info)
588
+ self.training_thread.frame_signal.connect(self.update_game_display)
589
+ self.training_thread.start()
590
+
591
+ self.start_btn.setEnabled(False)
592
+ self.stop_btn.setEnabled(True)
593
+
594
+ self.log_text.append(f'🚀 Started Dueling DQN training on {device}...')
595
+
596
+ def stop_training(self):
597
+ if self.training_thread:
598
+ self.training_thread.stop()
599
+ self.training_thread.wait()
600
+
601
+ self.start_btn.setEnabled(True)
602
+ self.stop_btn.setEnabled(False)
603
+ self.log_text.append('⏹️ Training stopped.')
604
+
605
+ def load_model(self):
606
+ # Placeholder for model loading functionality
607
+ self.log_text.append('📁 Load model functionality not implemented yet')
608
+
609
+ def update_training_info(self, data):
610
+ if 'episode' in data:
611
+ self.episode_label.setText(f'Episode: {data["episode"]}')
612
+ if 'world' in data and 'stage' in data:
613
+ self.world_label.setText(f'World: {data["world"]}-{data["stage"]}')
614
+ if 'score' in data:
615
+ self.score_label.setText(f'Score: {data["score"]}')
616
+ if 'total_reward' in data:
617
+ self.reward_label.setText(f'Episode Reward: {data["total_reward"]:.2f}')
618
+ if 'steps' in data:
619
+ self.steps_label.setText(f'Steps: {data["steps"]}')
620
+ if 'epsilon' in data:
621
+ self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
622
+ if 'loss' in data:
623
+ self.loss_label.setText(f'Loss: {data["loss"]:.4f}')
624
+ if 'memory_size' in data:
625
+ self.memory_label.setText(f'Memory: {data["memory_size"]}')
626
+ if 'x_pos' in data:
627
+ self.xpos_label.setText(f'X Position: {data["x_pos"]}')
628
+ if 'coins' in data:
629
+ self.coins_label.setText(f'Coins: {data["coins"]}')
630
+ if 'time' in data:
631
+ self.time_label.setText(f'Time: {data["time"]}')
632
+ if 'flag_get' in data:
633
+ flag_text = "Yes" if data["flag_get"] else "No"
634
+ self.flag_label.setText(f'Flag: {flag_text}')
635
+ if 'log_message' in data:
636
+ self.log_text.append(data['log_message'])
637
+ # Auto-scroll to bottom
638
+ self.log_text.verticalScrollBar().setValue(
639
+ self.log_text.verticalScrollBar().maximum()
640
+ )
641
+
642
+ def update_game_display(self, frame):
643
+ if frame is not None:
644
+ try:
645
+ h, w, ch = frame.shape
646
+ bytes_per_line = ch * w
647
+ # Ensure contiguous array
648
+ frame_contiguous = np.ascontiguousarray(frame)
649
+ q_img = QImage(frame_contiguous.data, w, h, bytes_per_line, QImage.Format_RGB888)
650
+ pixmap = QPixmap.fromImage(q_img)
651
+ self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
652
+ except Exception as e:
653
+ print(f"Error updating display: {e}")
654
+
655
+ def closeEvent(self, event):
656
+ self.stop_training()
657
+ event.accept()
658
+
659
+
660
+ def main():
661
+ # Set random seeds for reproducibility
662
+ torch.manual_seed(42)
663
+ np.random.seed(42)
664
+ random.seed(42)
665
+
666
+ app = QApplication(sys.argv)
667
+ window = MarioRLApp()
668
+ window.show()
669
+ sys.exit(app.exec_())
670
+
671
+
672
+ if __name__ == '__main__':
673
+ main()
Super-Mario-RL-PyQt5/app_2.py ADDED
@@ -0,0 +1,676 @@
1
+ import pickle
2
+ import random
3
+ import time
4
+ from collections import deque
5
+
6
+ import gym_super_mario_bros
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.optim as optim
12
+ from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
13
+ from nes_py.wrappers import JoypadSpace
14
+
15
+ from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
16
+ QHBoxLayout, QPushButton, QLabel, QComboBox,
17
+ QTextEdit, QProgressBar, QTabWidget, QFrame, QGroupBox)
18
+ from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
19
+ from PyQt5.QtGui import QImage, QPixmap, QFont
20
+ import sys
21
+ import cv2
22
+
23
+ # Import your wrappers (make sure this module exists)
24
+ try:
25
+ from wrappers import *
26
+ except ImportError:
27
+ # Create a proper wrapper if the module doesn't exist
28
+ class SimpleWrapper:
29
+ def __init__(self, env):
30
+ self.env = env
31
+ self.action_space = env.action_space
32
+ self.observation_space = env.observation_space
33
+
34
+ def reset(self):
35
+ return self.env.reset()
36
+
37
+ def step(self, action):
38
+ return self.env.step(action)
39
+
40
+ def render(self, mode='rgb_array'):
41
+ return self.env.render(mode)
42
+
43
+ def close(self):
44
+ if hasattr(self.env, 'close'):
45
+ self.env.close()
46
+
47
+ def wrap_mario(env):
48
+ return SimpleWrapper(env)
49
+
50
+
51
+ class FrameStacker:
52
+ """Handles frame stacking and preprocessing for the neural network"""
53
+ def __init__(self, frame_size=(84, 84), stack_size=4):
54
+ self.frame_size = frame_size
55
+ self.stack_size = stack_size
56
+ self.frames = deque(maxlen=stack_size)
57
+
58
+ def preprocess_frame(self, frame):
59
+ """Convert frame to grayscale and resize for the neural network"""
60
+ # Convert to grayscale
61
+ gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
62
+ # Resize to 84x84
63
+ resized = cv2.resize(gray, self.frame_size, interpolation=cv2.INTER_AREA)
64
+ # Normalize to [0, 1]
65
+ normalized = resized.astype(np.float32) / 255.0
66
+ return normalized
67
+
68
+ def reset(self, frame):
69
+ """Reset frame stack with initial frame"""
70
+ self.frames.clear()
71
+ processed_frame = self.preprocess_frame(frame)
72
+ for _ in range(self.stack_size):
73
+ self.frames.append(processed_frame)
74
+ return self.get_stacked_frames()
75
+
76
+ def append(self, frame):
77
+ """Add new frame to stack"""
78
+ processed_frame = self.preprocess_frame(frame)
79
+ self.frames.append(processed_frame)
80
+ return self.get_stacked_frames()
81
+
82
+ def get_stacked_frames(self):
83
+ """Get stacked frames as numpy array for the neural network"""
84
+ stacked = np.array(self.frames)
85
+ return np.ascontiguousarray(stacked)
86
+
87
+
88
+ class replay_memory(object):
89
+ def __init__(self, N):
90
+ self.memory = deque(maxlen=N)
91
+
92
+ def push(self, transition):
93
+ self.memory.append(transition)
94
+
95
+ def sample(self, n):
96
+ return random.sample(self.memory, n)
97
+
98
+ def __len__(self):
99
+ return len(self.memory)
100
+
101
+
102
+ class DuelingDQNModel(nn.Module):
103
+ def __init__(self, n_frame, n_action, device):
104
+ super(DuelingDQNModel, self).__init__()
105
+
106
+ # CNN layers for feature extraction
107
+ self.conv_layers = nn.Sequential(
108
+ nn.Conv2d(n_frame, 32, kernel_size=8, stride=4),
109
+ nn.ReLU(),
110
+ nn.Conv2d(32, 64, kernel_size=4, stride=2),
111
+ nn.ReLU(),
112
+ nn.Conv2d(64, 64, kernel_size=3, stride=1),
113
+ nn.ReLU()
114
+ )
115
+
116
+ # Calculate conv output size
117
+ self.conv_out_size = self._get_conv_out((n_frame, 84, 84))
118
+
119
+ # Advantage stream
120
+ self.advantage_stream = nn.Sequential(
121
+ nn.Linear(self.conv_out_size, 512),
122
+ nn.ReLU(),
123
+ nn.Linear(512, n_action)
124
+ )
125
+
126
+ # Value stream
127
+ self.value_stream = nn.Sequential(
128
+ nn.Linear(self.conv_out_size, 512),
129
+ nn.ReLU(),
130
+ nn.Linear(512, 1)
131
+ )
132
+
133
+ self.device = device
134
+ self.apply(self.init_weights)
135
+
136
+ def _get_conv_out(self, shape):
137
+ with torch.no_grad():
138
+ x = torch.zeros(1, *shape)
139
+ x = self.conv_layers(x)
140
+ return int(np.prod(x.size()))
141
+
142
+ def init_weights(self, m):
143
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
144
+ torch.nn.init.xavier_uniform_(m.weight)
145
+ if m.bias is not None:
146
+ m.bias.data.fill_(0.01)
147
+
148
+ def forward(self, x):
149
+ if not isinstance(x, torch.Tensor):
150
+ x = torch.FloatTensor(x).to(self.device)
151
+
152
+ # Forward through conv layers
153
+ x = self.conv_layers(x)
154
+ x = x.view(x.size(0), -1)
155
+
156
+ # Forward through advantage and value streams
157
+ advantage = self.advantage_stream(x)
158
+ value = self.value_stream(x)
159
+
160
+ # Combine value and advantage
161
+ q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
162
+
163
+ return q_values
164
+
165
+
166
+ def train(q, q_target, memory, batch_size, gamma, optimizer, device):
167
+ if len(memory) < batch_size:
168
+ return 0.0
169
+
170
+ transitions = memory.sample(batch_size)
171
+ s, r, a, s_prime, done = list(map(list, zip(*transitions)))
172
+
173
+ # Ensure positive strides for all arrays
174
+ s = np.array([np.ascontiguousarray(arr) for arr in s])
175
+ s_prime = np.array([np.ascontiguousarray(arr) for arr in s_prime])
176
+
177
+ # Move computations to device
178
+ s_tensor = torch.FloatTensor(s).to(device)
179
+ s_prime_tensor = torch.FloatTensor(s_prime).to(device)
180
+
181
+ # Get next Q values from target network
182
+ with torch.no_grad():
183
+ next_q_values = q_target(s_prime_tensor)
184
+ next_actions = next_q_values.max(1)[1].unsqueeze(1)
185
+ next_q_value = next_q_values.gather(1, next_actions)
186
+
187
+ # Calculate target Q values
188
+ r = torch.FloatTensor(r).unsqueeze(1).to(device)
189
+ done = torch.FloatTensor(done).unsqueeze(1).to(device)
190
+ target_q_values = r + gamma * next_q_value * done  # 'done' holds the not-done mask stored as int(1 - done)
191
+
192
+ # Get current Q values
193
+ a_tensor = torch.LongTensor(a).unsqueeze(1).to(device)
194
+ current_q_values = q(s_tensor).gather(1, a_tensor)
195
+
196
+ # Calculate loss
197
+ loss = F.smooth_l1_loss(current_q_values, target_q_values)
198
+
199
+ # Optimize
200
+ optimizer.zero_grad()
201
+ loss.backward()
202
+
203
+ # Gradient clipping
204
+ torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0)
205
+
206
+ optimizer.step()
207
+ return loss.item()
208
+
209
+
210
+ def copy_weights(q, q_target):
211
+ q_dict = q.state_dict()
212
+ q_target.load_state_dict(q_dict)
213
+
214
+
215
+ class MarioTrainingThread(QThread):
216
+ update_signal = pyqtSignal(dict)
217
+ frame_signal = pyqtSignal(np.ndarray)
218
+
219
+ def __init__(self, device="cpu"):
220
+ super().__init__()
221
+ self.device = device
222
+ self.running = False
223
+ self.env = None
224
+ self.q = None
225
+ self.q_target = None
226
+ self.optimizer = None
227
+ self.frame_stacker = None
228
+
229
+ # Training parameters
230
+ self.gamma = 0.99
231
+ self.batch_size = 32
232
+ self.memory_size = 10000
233
+ self.eps = 1.0 # Start with full exploration
234
+ self.eps_min = 0.01
235
+ self.eps_decay = 0.995
236
+ self.update_interval = 1000
237
+ self.save_interval = 100
238
+ self.print_interval = 10
239
+
240
+ self.memory = None
241
+ self.t = 0
242
+ self.k = 0
243
+ self.total_score = 0.0
244
+ self.loss_accumulator = 0.0
245
+ self.best_score = -float('inf')
246
+ self.last_x_pos = 0
247
+
248
+ def setup_training(self):
249
+ n_frame = 4 # Number of stacked frames
250
+ try:
251
+ self.env = gym_super_mario_bros.make("SuperMarioBros-v3")
252
+ self.env = JoypadSpace(self.env, COMPLEX_MOVEMENT)
253
+ self.env = wrap_mario(self.env)
254
+
255
+ # Initialize frame stacker
256
+ self.frame_stacker = FrameStacker(frame_size=(84, 84), stack_size=n_frame)
257
+
258
+ self.q = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
259
+ self.q_target = DuelingDQNModel(n_frame, self.env.action_space.n, self.device).to(self.device)
260
+
261
+ copy_weights(self.q, self.q_target)
262
+
263
+ # Set target network to eval mode
264
+ self.q_target.eval()
265
+
266
+ # Optimizer
267
+ self.optimizer = optim.Adam(self.q.parameters(), lr=0.0001, weight_decay=1e-5)
268
+
269
+ self.memory = replay_memory(self.memory_size)
270
+
271
+ self.log_message(f"✅ Training setup complete - Actions: {self.env.action_space.n}, Device: {self.device}")
272
+
273
+ except Exception as e:
274
+ self.log_message(f"❌ Error setting up training: {e}")
275
+ import traceback
276
+ traceback.print_exc()
277
+ self.running = False
278
+
279
+ def run(self):
280
+ self.running = True
281
+ self.setup_training()
282
+
283
+ if not self.running:
284
+ return
285
+
286
+ start_time = time.perf_counter()
287
+ score_lst = []
288
+
289
+ try:
290
+ for k in range(1000000):
291
+ if not self.running:
292
+ break
293
+
294
+ # Reset environment and frame stacker
295
+ frame = self.env.reset()
296
+ s = self.frame_stacker.reset(frame)
297
+ done = False
298
+ episode_loss = 0.0
299
+ episode_steps = 0
300
+ episode_score = 0.0
301
+ self.last_x_pos = 0
302
+
303
+ while not done and self.running:
304
+ # Ensure state has positive strides before processing
305
+ s_processed = np.ascontiguousarray(s)
306
+
307
+ # Epsilon-greedy action selection
308
+ if np.random.random() <= self.eps:
309
+ a = self.env.action_space.sample()
310
+ else:
311
+ with torch.no_grad():
312
+ # Add batch dimension and create tensor
313
+ state_tensor = torch.FloatTensor(s_processed).unsqueeze(0).to(self.device)
314
+ q_values = self.q(state_tensor)
315
+
316
+ if self.device == "cuda" or self.device == "mps":
317
+ a = np.argmax(q_values.cpu().numpy())
318
+ else:
319
+ a = np.argmax(q_values.detach().numpy())
320
+
321
+ # Take action
322
+ next_frame, r, done, info = self.env.step(a)
323
+
324
+ # Update frame stack for neural network
325
+ s_prime = self.frame_stacker.append(next_frame)
326
+
327
+ episode_score += r
328
+
329
+ # Enhanced reward shaping
330
+ reward = r # Start with original reward
331
+
332
+ # Bonus for x_pos progress
333
+ if 'x_pos' in info:
334
+ x_pos = info['x_pos']
335
+ x_progress = x_pos - self.last_x_pos
336
+ if x_progress > 0:
337
+ reward += 0.1 * x_progress
338
+ self.last_x_pos = x_pos
339
+
340
+ # Large bonus for completing the level
341
+ if done and info.get('flag_get', False):
342
+ reward += 100.0
343
+ self.log_message(f"๐ŸŽ‰ LEVEL COMPLETED at episode {k}! ๐ŸŽ‰")
344
+
345
+ # Store transition with contiguous arrays
346
+ s_contiguous = np.ascontiguousarray(s)
347
+ s_prime_contiguous = np.ascontiguousarray(s_prime)
348
+ self.memory.push((s_contiguous, float(reward), int(a), s_prime_contiguous, int(1 - done)))
349
+
350
+ s = s_prime
351
+ stage = info.get('stage', 1)
352
+ world = info.get('world', 1)
353
+
354
+ # Emit ORIGINAL COLOR FRAME for display (not preprocessed)
355
+ try:
356
+ # Get the original color frame for display
357
+ display_frame = self.env.render()
358
+ if display_frame is not None:
359
+ # Ensure frame has positive strides and emit original color frame
360
+ frame_contiguous = np.ascontiguousarray(display_frame)
361
+ self.frame_signal.emit(frame_contiguous)
362
+ except Exception as e:
363
+ # Create a placeholder frame if rendering fails
364
+ frame = np.zeros((240, 256, 3), dtype=np.uint8)
365
+ self.frame_signal.emit(frame)
366
+
367
+ # Train only if we have enough samples
368
+ if len(self.memory) > self.batch_size:
369
+ loss_val = train(self.q, self.q_target, self.memory, self.batch_size,
370
+ self.gamma, self.optimizer, self.device)
371
+ if loss_val > 0:
372
+ self.loss_accumulator += loss_val
373
+ episode_loss += loss_val
374
+ self.t += 1
375
+
376
+ # Update target network
377
+ if self.t % self.update_interval == 0:
378
+ copy_weights(self.q, self.q_target)
379
+ self.log_message(f"🔄 Target network updated at step {self.t}")
380
+
381
+ episode_steps += 1
382
+
383
+ # Emit training progress every 5 steps for more frequent updates
384
+ if episode_steps % 5 == 0:
385
+ progress_data = {
386
+ 'episode': k,
387
+ 'total_reward': episode_score,
388
+ 'steps': episode_steps,
389
+ 'epsilon': self.eps,
390
+ 'world': world,
391
+ 'stage': stage,
392
+ 'loss': episode_loss / (episode_steps + 1e-8),
393
+ 'memory_size': len(self.memory),
394
+ 'x_pos': info.get('x_pos', 0),
395
+ 'score': info.get('score', 0),
396
+ 'coins': info.get('coins', 0),
397
+ 'time': info.get('time', 400),
398
+ 'flag_get': info.get('flag_get', False)
399
+ }
400
+ self.update_signal.emit(progress_data)
401
+
402
+ # Epsilon decay after each episode
403
+ if self.eps > self.eps_min:
404
+ self.eps *= self.eps_decay
405
+
406
+ # Update total score
407
+ self.total_score += episode_score
408
+
409
+ # Save best model
410
+ if episode_score > self.best_score and k > 0:
411
+ self.best_score = episode_score
412
+ torch.save(self.q.state_dict(), "enhanced_mario_q_best.pth")
413
+ torch.save(self.q_target.state_dict(), "enhanced_mario_q_target_best.pth")
414
+ self.log_message(f"💾 New best model saved! Score: {self.best_score:.2f}")
415
+
416
+ # Save models periodically
417
+ if k % self.save_interval == 0 and k > 0:
418
+ torch.save(self.q.state_dict(), "enhanced_mario_q.pth")
419
+ torch.save(self.q_target.state_dict(), "enhanced_mario_q_target.pth")
420
+ self.log_message(f"💾 Models saved at episode {k}")
421
+
422
+ # Print progress
423
+ if k % self.print_interval == 0 and k > 0:
424
+ time_spent = time.perf_counter() - start_time
425
+ start_time = time.perf_counter()
426
+
427
+ avg_loss = self.loss_accumulator / (self.print_interval * max(episode_steps, 1))
428
+ avg_score = self.total_score / self.print_interval
429
+
430
+ log_msg = (
431
+ f"{self.device} | Ep: {k} | Score: {avg_score:.2f} | Loss: {avg_loss:.4f} | "
432
+ f"Stage: {world}-{stage} | Eps: {self.eps:.3f} | Time: {time_spent:.2f}s | "
433
+ f"Mem: {len(self.memory)} | Steps: {episode_steps}"
434
+ )
435
+ self.log_message(log_msg)
436
+
437
+ score_lst.append(avg_score)
438
+ self.total_score = 0.0
439
+ self.loss_accumulator = 0.0
440
+
441
+ try:
442
+ pickle.dump(score_lst, open("score.p", "wb"))
443
+ except Exception as e:
444
+ self.log_message(f"⚠️ Could not save scores: {e}")
445
+
446
+ self.k = k
447
+
448
+ except Exception as e:
449
+ self.log_message(f"❌ Training error: {e}")
450
+ import traceback
451
+ traceback.print_exc()
452
+
453
+ def log_message(self, message):
454
+ progress_data = {
455
+ 'log_message': message
456
+ }
457
+ self.update_signal.emit(progress_data)
458
+
459
+ def stop(self):
460
+ self.running = False
461
+ if self.env:
462
+ try:
463
+ self.env.close()
464
+ except:
465
+ pass
466
+
467
+
468
+ class MarioRLApp(QMainWindow):
469
+ def __init__(self):
470
+ super().__init__()
471
+ self.training_thread = None
472
+ self.init_ui()
473
+
474
+ def init_ui(self):
475
+ self.setWindowTitle('🎮 Super Mario Bros - Dueling DQN Training')
476
+ self.setGeometry(100, 100, 1200, 800)
477
+
478
+ central_widget = QWidget()
479
+ self.setCentralWidget(central_widget)
480
+ layout = QVBoxLayout(central_widget)
481
+
482
+ # Title
483
+ title = QLabel('🎮 Super Mario Bros - Enhanced Dueling DQN')
484
+ title.setFont(QFont('Arial', 16, QFont.Bold))
485
+ title.setAlignment(Qt.AlignCenter)
486
+ layout.addWidget(title)
487
+
488
+ # Control Panel
489
+ control_layout = QHBoxLayout()
490
+
491
+ self.device_combo = QComboBox()
492
+ self.device_combo.addItems(['cpu', 'cuda', 'mps'])
493
+
494
+ self.start_btn = QPushButton('Start Training')
495
+ self.start_btn.clicked.connect(self.start_training)
496
+
497
+ self.stop_btn = QPushButton('Stop Training')
498
+ self.stop_btn.clicked.connect(self.stop_training)
499
+ self.stop_btn.setEnabled(False)
500
+
501
+ self.load_btn = QPushButton('Load Model')
502
+ self.load_btn.clicked.connect(self.load_model)
503
+
504
+ control_layout.addWidget(QLabel('Device:'))
505
+ control_layout.addWidget(self.device_combo)
506
+ control_layout.addWidget(self.start_btn)
507
+ control_layout.addWidget(self.stop_btn)
508
+ control_layout.addWidget(self.load_btn)
509
+ control_layout.addStretch()
510
+
511
+ layout.addLayout(control_layout)
512
+
513
+ # Content Area
514
+ content_layout = QHBoxLayout()
515
+
516
+ # Left side - Game Display
517
+ left_frame = QFrame()
518
+ left_frame.setFrameStyle(QFrame.Box)
519
+ left_layout = QVBoxLayout(left_frame)
520
+
521
+ self.game_display = QLabel()
522
+ self.game_display.setMinimumSize(400, 300)
523
+ self.game_display.setAlignment(Qt.AlignCenter)
524
+ self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
525
+ self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')
526
+
527
+ left_layout.addWidget(QLabel('Mario Game Display:'))
528
+ left_layout.addWidget(self.game_display)
529
+
530
+ # Right side - Training Info
531
+ right_frame = QFrame()
532
+ right_frame.setFrameStyle(QFrame.Box)
533
+ right_layout = QVBoxLayout(right_frame)
534
+
535
+ # Training stats
536
+ stats_group = QGroupBox("Training Statistics")
537
+ stats_layout = QVBoxLayout(stats_group)
538
+
539
+ self.episode_label = QLabel('Episode: 0')
540
+ self.world_label = QLabel('World: 1-1')
541
+ self.score_label = QLabel('Score: 0')
542
+ self.reward_label = QLabel('Episode Reward: 0')
543
+ self.steps_label = QLabel('Steps: 0')
544
+ self.epsilon_label = QLabel('Epsilon: 1.000')
545
+ self.loss_label = QLabel('Loss: 0.0000')
546
+ self.memory_label = QLabel('Memory: 0')
547
+ self.xpos_label = QLabel('X Position: 0')
548
+ self.coins_label = QLabel('Coins: 0')
549
+ self.time_label = QLabel('Time: 400')
550
+ self.flag_label = QLabel('Flag: No')
551
+
552
+ stats_layout.addWidget(self.episode_label)
553
+ stats_layout.addWidget(self.world_label)
554
+ stats_layout.addWidget(self.score_label)
555
+ stats_layout.addWidget(self.reward_label)
556
+ stats_layout.addWidget(self.steps_label)
557
+ stats_layout.addWidget(self.epsilon_label)
558
+ stats_layout.addWidget(self.loss_label)
559
+ stats_layout.addWidget(self.memory_label)
560
+ stats_layout.addWidget(self.xpos_label)
561
+ stats_layout.addWidget(self.coins_label)
562
+ stats_layout.addWidget(self.time_label)
563
+ stats_layout.addWidget(self.flag_label)
564
+
565
+ right_layout.addWidget(stats_group)
566
+
567
+ # Training log
568
+ right_layout.addWidget(QLabel('Training Log:'))
569
+ self.log_text = QTextEdit()
570
+ self.log_text.setMaximumHeight(300)
571
+ right_layout.addWidget(self.log_text)
572
+
573
+ content_layout.addWidget(left_frame)
574
+ content_layout.addWidget(right_frame)
575
+ layout.addLayout(content_layout)
576
+
577
+ def start_training(self):
578
+ device = self.device_combo.currentText()
579
+
580
+ # Check device availability
581
+ if device == "cuda" and not torch.cuda.is_available():
582
+ self.log_text.append("❌ CUDA not available, using CPU instead")
583
+ device = "cpu"
584
+ elif device == "mps" and not torch.backends.mps.is_available():
585
+ self.log_text.append("❌ MPS not available, using CPU instead")
586
+ device = "cpu"
587
+
588
+ self.training_thread = MarioTrainingThread(device)
589
+ self.training_thread.update_signal.connect(self.update_training_info)
590
+ self.training_thread.frame_signal.connect(self.update_game_display)
591
+ self.training_thread.start()
592
+
593
+ self.start_btn.setEnabled(False)
594
+ self.stop_btn.setEnabled(True)
595
+
596
+ self.log_text.append(f'🚀 Started Dueling DQN training on {device}...')
597
+
598
+ def stop_training(self):
599
+ if self.training_thread:
600
+ self.training_thread.stop()
601
+ self.training_thread.wait()
602
+
603
+ self.start_btn.setEnabled(True)
604
+ self.stop_btn.setEnabled(False)
605
+ self.log_text.append('⏹️ Training stopped.')
606
+
607
+ def load_model(self):
608
+ # Placeholder for model loading functionality
609
+ self.log_text.append('📁 Load model functionality not implemented yet')
610
+
611
+ def update_training_info(self, data):
612
+ if 'episode' in data:
613
+ self.episode_label.setText(f'Episode: {data["episode"]}')
614
+ if 'world' in data and 'stage' in data:
615
+ self.world_label.setText(f'World: {data["world"]}-{data["stage"]}')
616
+ if 'score' in data:
617
+ self.score_label.setText(f'Score: {data["score"]}')
618
+ if 'total_reward' in data:
619
+ self.reward_label.setText(f'Episode Reward: {data["total_reward"]:.2f}')
620
+ if 'steps' in data:
621
+ self.steps_label.setText(f'Steps: {data["steps"]}')
622
+ if 'epsilon' in data:
623
+ self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
624
+ if 'loss' in data:
625
+ self.loss_label.setText(f'Loss: {data["loss"]:.4f}')
626
+ if 'memory_size' in data:
627
+ self.memory_label.setText(f'Memory: {data["memory_size"]}')
628
+ if 'x_pos' in data:
629
+ self.xpos_label.setText(f'X Position: {data["x_pos"]}')
630
+ if 'coins' in data:
631
+ self.coins_label.setText(f'Coins: {data["coins"]}')
632
+ if 'time' in data:
633
+ self.time_label.setText(f'Time: {data["time"]}')
634
+ if 'flag_get' in data:
635
+ flag_text = "Yes" if data["flag_get"] else "No"
636
+ self.flag_label.setText(f'Flag: {flag_text}')
637
+ if 'log_message' in data:
638
+ self.log_text.append(data['log_message'])
639
+ # Auto-scroll to bottom
640
+ self.log_text.verticalScrollBar().setValue(
641
+ self.log_text.verticalScrollBar().maximum()
642
+ )
643
+
644
+ def update_game_display(self, frame):
645
+ if frame is not None:
646
+ try:
647
+ h, w, ch = frame.shape
648
+ bytes_per_line = ch * w
649
+ # Ensure contiguous array and display original color frame
650
+ frame_contiguous = np.ascontiguousarray(frame)
651
+ q_img = QImage(frame_contiguous.data, w, h, bytes_per_line, QImage.Format_RGB888)
652
+ pixmap = QPixmap.fromImage(q_img)
653
+ # Scale to fit the display while maintaining aspect ratio
654
+ self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio, Qt.SmoothTransformation))
655
+ except Exception as e:
656
+ print(f"Error updating display: {e}")
657
+
658
+ def closeEvent(self, event):
659
+ self.stop_training()
660
+ event.accept()
661
+
662
+
663
+ def main():
664
+ # Set random seeds for reproducibility
665
+ torch.manual_seed(42)
666
+ np.random.seed(42)
667
+ random.seed(42)
668
+
669
+ app = QApplication(sys.argv)
670
+ window = MarioRLApp()
671
+ window.show()
672
+ sys.exit(app.exec_())
673
+
674
+
675
+ if __name__ == '__main__':
676
+ main()
Super-Mario-RL-PyQt5/enhanced_mario_q_best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c82b8bb39904cf745061a6ba1ca2a207977f22f13fb0e72f1141c3f85045eb0
3
+ size 13193617
Super-Mario-RL-PyQt5/enhanced_mario_q_target_best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9f6060fca857b4335781a219ffa7450baa20eb8efb9287e9561087978c1a1e0
3
+ size 13193949
Super-Mario-RL-PyQt5/requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ numpy==1.26.4
2
+ torch>=1.6.0
3
+ torchvision
4
+ gym==0.23
5
+ nes-py
6
+ gym-super-mario-bros==7.2.3
7
+ opencv-python
8
+ matplotlib
9
+ pyqt5
Super-Mario-RL-PyQt5/score.p ADDED
Binary file (34 Bytes).
 
Super-Mario-RL/README.md ADDED
@@ -0,0 +1,85 @@
1
+ # :mushroom: Super-Mario-RL
2
+
3
+ This is a personal project to build a Super Mario agent.
4
+
5
+ It trains an agent to clear Super Mario Bros with deep reinforcement learning.
6
+
7
+ Here are my Super Mario agents with a dueling network (trained for 7,000 epochs).
8
+
9
+ **(25-05-20) Super Mario with PPO has been updated!**
10
+
11
+ <p float="center">
12
+ <img src="/mario1.gif" width="350" />
13
+ <img src="/mario14.gif" width="350" />
14
+ </p>
15
+
16
+ # Get started
17
+
18
+ ## Cloning the repository
19
+
20
+ ```
21
+ git clone https://github.com/jiseongHAN/Super-Mario-RL.git
22
+ cd Super-Mario-RL
23
+ ```
24
+
25
+ ## Install Requirements
26
+ ```
27
+ pip install -r requirements.txt
28
+ ```
29
+
30
+ ## Or Install Manually
31
+ * Install [openAI gym](http://gym.openai.com/)
32
+ ```
33
+ pip install 'gym'
34
+ ```
35
+ * Install [Pytorch](https://pytorch.org/)
36
+ ```
37
+ pip install torch torchvision
38
+ ```
39
+ * Install [nes-py](https://pypi.org/project/nes-py/)
40
+ ```
41
+ pip install nes-py
42
+ ```
43
+ * Install [gym-super-mario-bros](https://pypi.org/project/gym-super-mario-bros/)
44
+ ```
45
+ pip install gym-super-mario-bros
46
+ ```
47
+
48
+ # Running
49
+
50
+ ## Train
51
+
52
+ * Train with dueling dqn (the aggregation formula is shown below).
53
+ ```
54
+ python duel_dqn.py
55
+ ```
56
+
57
+ * Train with PPO.
58
+
59
+ ```
60
+ python ppo.py
61
+ ```
62
+
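+ The dueling head in duel_dqn.py splits the network output into a state value V(s) and action advantages A(s, a), combined with the mean-subtracted form from Wang et al. (2016):
+
+ $$Q(s, a) = V(s) + \Big( A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a') \Big)$$
+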
63
+ ### Result
64
+ * score.p : saves the total score every 50 episodes (see the loading sketch below)
65
+ * *.pth : saves the weights of q and q_target every 50 training updates
66
+
67
+
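+ A minimal sketch for inspecting score.p, assuming it holds the pickled list of average scores written by the training loop (matplotlib is already in requirements.txt):
+
+ ```
+ import pickle
+ import matplotlib.pyplot as plt
+
+ with open("score.p", "rb") as f:
+     scores = pickle.load(f)  # list of average scores, one entry per logging interval
+
+ plt.plot(scores)
+ plt.xlabel("logging interval")
+ plt.ylabel("average score")
+ plt.show()
+ ```
+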
68
+ ## Evaluate
69
+ * (The pre-trained agent is currently corrupted 😢)
70
+ * Test and render the trained agent.
71
+ * To test the agent, you need the 'q_target.pth' file generated during training.
72
+ * (eval.py does not currently support PPO)
73
+ ```
74
+ python eval.py
75
+ ```
76
+ * Or you can use your own agent.
77
+ ```
78
+ python eval.py your_own_agent.pth
79
+ ```
80
+
81
+ ## Reference
82
+ [Wang, Ziyu, et al. "Dueling network architectures for deep reinforcement learning." International conference on machine learning. PMLR, 2016.](https://arxiv.org/pdf/1511.06581.pdf)
83
+
84
+ [Schulman, J., Wolski, F., Dhariwal, P., Radford, A. & Klimov, O. Proximal policy optimization
85
+ algorithms. arXiv preprint arXiv:1707.06347 (2017).](https://arxiv.org/pdf/1707.06347)
Super-Mario-RL/__pycache__/wrappers.cpython-313.pyc ADDED
Binary file (18.7 kB).
 
Super-Mario-RL/duel_dqn.py ADDED
@@ -0,0 +1,178 @@
1
+ import pickle
2
+ import random
3
+ import time
4
+ from collections import deque
5
+
6
+ import gym_super_mario_bros
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.optim as optim
12
+ from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
13
+ from nes_py.wrappers import JoypadSpace
14
+
15
+ from wrappers import *
16
+
17
+
18
+ def arrange(s):
19
+ if not isinstance(s, np.ndarray):
20
+ s = np.array(s)
21
+ assert len(s.shape) == 3
22
+ ret = np.transpose(s, (2, 0, 1))
23
+ return np.expand_dims(ret, 0)
24
+
25
+
26
+ class replay_memory(object):
27
+ def __init__(self, N):
28
+ self.memory = deque(maxlen=N)
29
+
30
+ def push(self, transition):
31
+ self.memory.append(transition)
32
+
33
+ def sample(self, n):
34
+ return random.sample(self.memory, n)
35
+
36
+ def __len__(self):
37
+ return len(self.memory)
38
+
39
+
40
+ class model(nn.Module):
41
+ def __init__(self, n_frame, n_action, device):
42
+ super(model, self).__init__()
43
+ self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
44
+ self.layer2 = nn.Conv2d(32, 64, 3, 1)
45
+ self.fc = nn.Linear(20736, 512)
46
+ self.q = nn.Linear(512, n_action)
47
+ self.v = nn.Linear(512, 1)
48
+
49
+ self.device = device
50
+ self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)
51
+
52
+ self.seq.apply(init_weights)
53
+
54
+ def forward(self, x):
55
+ if type(x) != torch.Tensor:
56
+ x = torch.FloatTensor(x).to(self.device)
57
+ x = torch.relu(self.layer1(x))
58
+ x = torch.relu(self.layer2(x))
59
+ x = x.view(-1, 20736)
60
+ x = torch.relu(self.fc(x))
61
+ adv = self.q(x)
62
+ v = self.v(x)
63
+ q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))
64
+
65
+ return q
66
+
67
+
68
+ def init_weights(m):
69
+ if type(m) == nn.Conv2d:
70
+ torch.nn.init.xavier_uniform_(m.weight)
71
+ m.bias.data.fill_(0.01)
72
+
73
+
74
+ def train(q, q_target, memory, batch_size, gamma, optimizer, device):
75
+ s, r, a, s_prime, done = list(map(list, zip(*memory.sample(batch_size))))
76
+ s = np.array(s).squeeze()
77
+ s_prime = np.array(s_prime).squeeze()
78
+ a_max = q(s_prime).max(1)[1].unsqueeze(-1)
79
+ r = torch.FloatTensor(r).unsqueeze(-1).to(device)
80
+ done = torch.FloatTensor(done).unsqueeze(-1).to(device)
81
+ with torch.no_grad():
82
+ y = r + gamma * q_target(s_prime).gather(1, a_max) * done
83
+ a = torch.tensor(a).unsqueeze(-1).to(device)
84
+ q_value = torch.gather(q(s), dim=1, index=a.view(-1, 1).long())
85
+
86
+ loss = F.smooth_l1_loss(q_value, y).mean()
87
+ optimizer.zero_grad()
88
+ loss.backward()
89
+ optimizer.step()
90
+ return loss
91
+
92
+
93
+ def copy_weights(q, q_target):
94
+ q_dict = q.state_dict()
95
+ q_target.load_state_dict(q_dict)
96
+
97
+
98
+ def main(env, q, q_target, optimizer, device):
99
+ t = 0
100
+ gamma = 0.99
101
+ batch_size = 256
102
+
103
+ N = 50000
104
+ eps = 0.001
105
+ memory = replay_memory(N)
106
+ update_interval = 50
107
+ print_interval = 10
108
+
109
+ score_lst = []
110
+ total_score = 0.0
111
+ loss = 0.0
112
+ start_time = time.perf_counter()
113
+
114
+ for k in range(1000000):
115
+ s = arrange(env.reset())
116
+ done = False
117
+
118
+ while not done:
119
+ if eps > np.random.rand():
120
+ a = env.action_space.sample()
121
+ else:
122
+ if device == "cpu":
123
+ a = np.argmax(q(s).detach().numpy())
124
+ else:
125
+ a = np.argmax(q(s).cpu().detach().numpy())
126
+ s_prime, r, done, _ = env.step(a)
127
+ s_prime = arrange(s_prime)
128
+ total_score += r
129
+ r = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
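+ # Sign-preserving square-root compression of the raw reward, with a small linear term
+ # so very large rewards are not squashed entirely; this keeps TD targets in a narrow range.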
130
+ memory.push((s, float(r), int(a), s_prime, int(1 - done)))
131
+ s = s_prime
132
+ stage = env.unwrapped._stage
133
+ if len(memory) > 2000:
134
+ loss += train(q, q_target, memory, batch_size, gamma, optimizer, device)
135
+ t += 1
136
+ if t % update_interval == 0:
137
+ copy_weights(q, q_target)
138
+ torch.save(q.state_dict(), "mario_q.pth")
139
+ torch.save(q_target.state_dict(), "mario_q_target.pth")
140
+
141
+ if k % print_interval == 0:
142
+ time_spent, start_time = (
143
+ time.perf_counter() - start_time,
144
+ time.perf_counter(),
145
+ )
146
+ print(
147
+ "%s |Epoch : %d | score : %f | loss : %.2f | stage : %d | time spent: %f"
148
+ % (
149
+ device,
150
+ k,
151
+ total_score / print_interval,
152
+ loss / print_interval,
153
+ stage,
154
+ time_spent,
155
+ )
156
+ )
157
+ score_lst.append(total_score / print_interval)
158
+ total_score = 0
159
+ loss = 0.0
160
+ pickle.dump(score_lst, open("score.p", "wb"))
161
+
162
+
163
+ if __name__ == "__main__":
164
+ n_frame = 4
165
+ env = gym_super_mario_bros.make("SuperMarioBros-v0")
166
+ env = JoypadSpace(env, COMPLEX_MOVEMENT)
167
+ env = wrap_mario(env)
168
+ device = "cpu"
169
+ if torch.cuda.is_available():
170
+ device = "cuda"
171
+ elif torch.backends.mps.is_available():
172
+ device = "mps"
173
+ q = model(n_frame, env.action_space.n, device).to(device)
174
+ q_target = model(n_frame, env.action_space.n, device).to(device)
175
+ optimizer = optim.Adam(q.parameters(), lr=0.0001)
176
+ print(device)
177
+
178
+ main(env, q, q_target, optimizer, device)
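Note on the dueling head in the model above: the forward pass recombines the state-value output v and the per-action advantage output adv as q = v + (adv - mean(adv)), which is exactly what 1 / adv.shape[-1] * adv.sum(-1, keepdim=True) computes; subtracting the mean keeps the value/advantage split identifiable. A minimal standalone sketch of that aggregation on made-up numbers (assumes only PyTorch; the tensor values are illustrative, not taken from training):

import torch

v = torch.tensor([[0.5]])                   # value head output, shape (batch, 1)
adv = torch.tensor([[1.0, 3.0, 2.0]])       # advantage head output, shape (batch, n_actions)
q = v + (adv - adv.mean(-1, keepdim=True))  # equivalent to v + (adv - 1/n * adv.sum(-1, keepdim=True))
print(q)  # tensor([[-0.5000, 1.5000, 0.5000]]); the greedy action equals argmax(adv)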
Super-Mario-RL/duel_dqn_2.py ADDED
@@ -0,0 +1,237 @@
1
+ import pickle
2
+ import random
3
+ import time
4
+ from collections import deque
5
+
6
+ import gym_super_mario_bros
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.optim as optim
12
+ from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
13
+ from nes_py.wrappers import JoypadSpace
14
+
15
+ from wrappers import *
16
+
17
+
18
+ def arrange(s):
19
+ if not isinstance(s, np.ndarray):
20
+ s = np.array(s)
21
+ assert len(s.shape) == 3
22
+ ret = np.transpose(s, (2, 0, 1))
23
+ return np.expand_dims(ret, 0)
24
+
25
+
26
+ class replay_memory(object):
27
+ def __init__(self, N):
28
+ self.memory = deque(maxlen=N)
29
+
30
+ def push(self, transition):
31
+ self.memory.append(transition)
32
+
33
+ def sample(self, n):
34
+ return random.sample(self.memory, n)
35
+
36
+ def __len__(self):
37
+ return len(self.memory)
38
+
39
+
40
+ class model(nn.Module):
41
+ def __init__(self, n_frame, n_action, device):
42
+ super(model, self).__init__()
43
+ self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
44
+ self.layer2 = nn.Conv2d(32, 64, 3, 1)
45
+ self.fc = nn.Linear(20736, 512)
46
+ self.q = nn.Linear(512, n_action)
47
+ self.v = nn.Linear(512, 1)
48
+
49
+ self.device = device
50
+ self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)
51
+
52
+ self.seq.apply(init_weights)
53
+
54
+ def forward(self, x):
55
+ if type(x) != torch.Tensor:
56
+ x = torch.FloatTensor(x).to(self.device)
57
+ x = torch.relu(self.layer1(x))
58
+ x = torch.relu(self.layer2(x))
59
+ x = x.view(-1, 20736)
60
+ x = torch.relu(self.fc(x))
61
+ adv = self.q(x)
62
+ v = self.v(x)
63
+ q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))
64
+
65
+ return q
66
+
67
+
68
+ def init_weights(m):
69
+ if type(m) == nn.Conv2d:
70
+ torch.nn.init.xavier_uniform_(m.weight)
71
+ m.bias.data.fill_(0.01)
72
+
73
+
74
+ def train(q, q_target, memory, batch_size, gamma, optimizer, device):
75
+ s, r, a, s_prime, done = list(map(list, zip(*memory.sample(batch_size))))
76
+ s = np.array(s).squeeze()
77
+ s_prime = np.array(s_prime).squeeze()
78
+
79
+ # Move computations to device
80
+ s_tensor = torch.FloatTensor(s).to(device)
81
+ s_prime_tensor = torch.FloatTensor(s_prime).to(device)
82
+
83
+ a_max = q(s_prime_tensor).max(1)[1].unsqueeze(-1)
84
+ r = torch.FloatTensor(r).unsqueeze(-1).to(device)
85
+ done = torch.FloatTensor(done).unsqueeze(-1).to(device)
86
+
87
+ with torch.no_grad():
88
+ y = r + gamma * q_target(s_prime_tensor).gather(1, a_max) * done
89
+
90
+ a = torch.tensor(a).unsqueeze(-1).to(device)
91
+ q_value = torch.gather(q(s_tensor), dim=1, index=a.view(-1, 1).long())
92
+
93
+ loss = F.smooth_l1_loss(q_value, y).mean()
94
+ optimizer.zero_grad()
95
+ loss.backward()
96
+
97
+ # Gradient clipping to prevent explosion
98
+ torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=1.0)
99
+
100
+ optimizer.step()
101
+ return loss.item() # Use .item() to get scalar value
102
+
103
+
104
+ def copy_weights(q, q_target):
105
+ q_dict = q.state_dict()
106
+ q_target.load_state_dict(q_dict)
107
+
108
+
109
+ def main(env, q, q_target, optimizer, device):
110
+ t = 0
111
+ gamma = 0.99
112
+ batch_size = 256
113
+
114
+ N = 50000
115
+ eps = 0.1 # Increased exploration
116
+ eps_min = 0.01
117
+ eps_decay = 0.999
118
+ memory = replay_memory(N)
119
+ update_interval = 50 # How often to update target network
120
+ save_interval = 100 # How often to save models (in episodes)
121
+ print_interval = 10
122
+
123
+ score_lst = []
124
+ total_score = 0.0
125
+ loss_accumulator = 0.0
126
+ start_time = time.perf_counter()
127
+
128
+ for k in range(1000000):
129
+ s = arrange(env.reset())
130
+ done = False
131
+
132
+ while not done:
133
+ # Epsilon decay
134
+ if eps > eps_min:
135
+ eps *= eps_decay
136
+
137
+ if eps > np.random.rand():
138
+ a = env.action_space.sample()
139
+ else:
140
+ # Get action with proper device handling
141
+ with torch.no_grad():
142
+ q_values = q(torch.FloatTensor(s).to(device))
143
+
144
+ # Move to CPU for numpy conversion regardless of device
145
+ if device == "cuda" or device == "mps":
146
+ a = np.argmax(q_values.cpu().numpy())
147
+ else:
148
+ a = np.argmax(q_values.detach().numpy())
149
+
150
+ s_prime, r, done, info = env.step(a)
151
+ s_prime = arrange(s_prime)
152
+ total_score += r
153
+
154
+ # Enhanced reward shaping
155
+ reward = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
156
+
157
+ # Bonus for x_pos progress
158
+ if 'x_pos' in info:
159
+ x_pos = info['x_pos']
160
+ if hasattr(main, 'last_x_pos'):
161
+ x_progress = x_pos - main.last_x_pos
162
+ if x_progress > 0:
163
+ reward += 0.1 * x_progress # Small bonus for moving right
164
+ main.last_x_pos = x_pos
165
+
166
+ memory.push((s, float(reward), int(a), s_prime, int(1 - done)))
167
+ s = s_prime
168
+ stage = env.unwrapped._stage
169
+
170
+ if len(memory) > 2000:
171
+ loss_val = train(q, q_target, memory, batch_size, gamma, optimizer, device)
172
+ loss_accumulator += loss_val
173
+ t += 1
174
+
175
+ # Update target network (but don't save every time)
176
+ if t % update_interval == 0:
177
+ copy_weights(q, q_target)
178
+
179
+ # Save models less frequently (every save_interval episodes)
180
+ if k % save_interval == 0 and k > 0:
181
+ torch.save(q.state_dict(), "mario_q.pth")
182
+ torch.save(q_target.state_dict(), "mario_q_target.pth")
183
+ print(f"Models saved at episode {k}")
184
+
185
+ if k % print_interval == 0:
186
+ time_spent, start_time = (
187
+ time.perf_counter() - start_time,
188
+ time.perf_counter(),
189
+ )
190
+
191
+ # Fixed: Use loss_accumulator instead of loss and ensure proper formatting
192
+ avg_loss = loss_accumulator / print_interval if print_interval > 0 else 0.0
193
+ avg_score = total_score / print_interval if print_interval > 0 else 0.0
194
+
195
+ print(
196
+ "%s | Epoch : %d | score : %.2f | loss : %.2f | stage : %d | eps : %.3f | time: %.2fs | memory: %d"
197
+ % (
198
+ device,
199
+ k,
200
+ avg_score,
201
+ avg_loss,
202
+ stage,
203
+ eps,
204
+ time_spent,
205
+ len(memory)
206
+ )
207
+ )
208
+ score_lst.append(avg_score)
209
+ total_score = 0.0
210
+ loss_accumulator = 0.0
211
+ pickle.dump(score_lst, open("score.p", "wb"))
212
+
213
+
214
+ if __name__ == "__main__":
215
+ n_frame = 4
216
+ env = gym_super_mario_bros.make("SuperMarioBros-v3")
217
+ env = JoypadSpace(env, COMPLEX_MOVEMENT)
218
+ env = wrap_mario(env)
219
+
220
+ # Device detection with MPS support
221
+ device = "cpu"
222
+ if torch.cuda.is_available():
223
+ device = "cuda"
224
+ elif torch.backends.mps.is_available():
225
+ device = "mps"
226
+
227
+ print(f"Using device: {device}")
228
+
229
+ q = model(n_frame, env.action_space.n, device).to(device)
230
+ q_target = model(n_frame, env.action_space.n, device).to(device)
231
+
232
+ # Copy weights initially
233
+ copy_weights(q, q_target)
234
+
235
+ optimizer = optim.Adam(q.parameters(), lr=0.0001, weight_decay=1e-5) # Added weight decay
236
+
237
+ main(env, q, q_target, optimizer, device)
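The reward shaping used in these training loops, r -> sign(r) * (sqrt(|r| + 1) - 1) + 0.001 * r, compresses large raw rewards while keeping their sign plus a small linear component. A quick numeric check of that mapping (standalone sketch, NumPy only; the sample rewards are illustrative):

import numpy as np

def shape_reward(r):
    # Square-root compression plus a small linear term, as in the training loops above.
    return np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r

for r in [1, 15, 100, -15]:
    print(r, round(float(shape_reward(r)), 3))
# 1 -> 0.415, 15 -> 3.015, 100 -> 9.15, -15 -> -3.015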
Super-Mario-RL/enhanced_duel_dqn.py ADDED
@@ -0,0 +1,257 @@
1
+ import pickle
2
+ import random
3
+ import time
4
+ from collections import deque
5
+
6
+ import gym_super_mario_bros
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.optim as optim
12
+ from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
13
+ from nes_py.wrappers import JoypadSpace
14
+
15
+ from wrappers import *
16
+
17
+
18
+ def arrange(s):
19
+ if not isinstance(s, np.ndarray):
20
+ s = np.array(s)
21
+ assert len(s.shape) == 3
22
+ ret = np.transpose(s, (2, 0, 1))
23
+ return np.expand_dims(ret, 0)
24
+
25
+
26
+ class replay_memory(object):
27
+ def __init__(self, N):
28
+ self.memory = deque(maxlen=N)
29
+
30
+ def push(self, transition):
31
+ self.memory.append(transition)
32
+
33
+ def sample(self, n):
34
+ return random.sample(self.memory, n)
35
+
36
+ def __len__(self):
37
+ return len(self.memory)
38
+
39
+
40
+ class model(nn.Module):
41
+ def __init__(self, n_frame, n_action, device):
42
+ super(model, self).__init__()
43
+ self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
44
+ self.layer2 = nn.Conv2d(32, 64, 3, 1)
45
+ self.fc = nn.Linear(20736, 512)
46
+ self.q = nn.Linear(512, n_action)
47
+ self.v = nn.Linear(512, 1)
48
+
49
+ self.device = device
50
+ self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)
51
+
52
+ self.seq.apply(init_weights)
53
+
54
+ def forward(self, x):
55
+ if type(x) != torch.Tensor:
56
+ x = torch.FloatTensor(x).to(self.device)
57
+ x = torch.relu(self.layer1(x))
58
+ x = torch.relu(self.layer2(x))
59
+ x = x.view(-1, 20736)
60
+ x = torch.relu(self.fc(x))
61
+ adv = self.q(x)
62
+ v = self.v(x)
63
+ q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))
64
+
65
+ return q
66
+
67
+
68
+ def init_weights(m):
69
+ if type(m) == nn.Conv2d:
70
+ torch.nn.init.xavier_uniform_(m.weight)
71
+ m.bias.data.fill_(0.01)
72
+
73
+
74
+ def train(q, q_target, memory, batch_size, gamma, optimizer, device):
75
+ s, r, a, s_prime, done = list(map(list, zip(*memory.sample(batch_size))))
76
+ s = np.array(s).squeeze()
77
+ s_prime = np.array(s_prime).squeeze()
78
+
79
+ # Move computations to device
80
+ s_tensor = torch.FloatTensor(s).to(device)
81
+ s_prime_tensor = torch.FloatTensor(s_prime).to(device)
82
+
83
+ a_max = q(s_prime_tensor).max(1)[1].unsqueeze(-1)
84
+ r = torch.FloatTensor(r).unsqueeze(-1).to(device)
85
+ done = torch.FloatTensor(done).unsqueeze(-1).to(device)
86
+
87
+ with torch.no_grad():
88
+ y = r + gamma * q_target(s_prime_tensor).gather(1, a_max) * done
89
+
90
+ a = torch.tensor(a).unsqueeze(-1).to(device)
91
+ q_value = torch.gather(q(s_tensor), dim=1, index=a.view(-1, 1).long())
92
+
93
+ loss = F.smooth_l1_loss(q_value, y).mean()
94
+ optimizer.zero_grad()
95
+ loss.backward()
96
+
97
+ # Gradient clipping to prevent explosion
98
+ torch.nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0) # Increased clipping
99
+
100
+ optimizer.step()
101
+ return loss.item()
102
+
103
+
104
+ def copy_weights(q, q_target):
105
+ q_dict = q.state_dict()
106
+ q_target.load_state_dict(q_dict)
107
+
108
+
109
+ def main(env, q, q_target, optimizer, device):
110
+ t = 0
111
+ gamma = 0.99
112
+ batch_size = 256
113
+
114
+ N = 50000
115
+ eps = 0.3 # Higher initial exploration
116
+ eps_min = 0.05 # Higher minimum exploration
117
+ eps_decay = 0.995 # Slower decay
118
+ memory = replay_memory(N)
119
+ update_interval = 100 # Less frequent target updates
120
+ save_interval = 100
121
+ print_interval = 10
122
+
123
+ score_lst = []
124
+ total_score = 0.0
125
+ loss_accumulator = 0.0
126
+ start_time = time.perf_counter()
127
+
128
+ # Track best score for saving
129
+ best_score = -float('inf')
130
+
131
+ for k in range(1000000):
132
+ s = arrange(env.reset())
133
+ done = False
134
+ episode_loss = 0.0
135
+ episode_steps = 0
136
+
137
+ while not done:
138
+ # Epsilon decay per step
139
+ if eps > eps_min:
140
+ eps *= eps_decay
141
+
142
+ if eps > np.random.rand():
143
+ a = env.action_space.sample()
144
+ else:
145
+ with torch.no_grad():
146
+ q_values = q(torch.FloatTensor(s).to(device))
147
+
148
+ if device == "cuda" or device == "mps":
149
+ a = np.argmax(q_values.cpu().numpy())
150
+ else:
151
+ a = np.argmax(q_values.detach().numpy())
152
+
153
+ s_prime, r, done, info = env.step(a)
154
+ s_prime = arrange(s_prime)
155
+ total_score += r
156
+
157
+ # Enhanced reward shaping
158
+ reward = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
159
+
160
+ # Bonus for x_pos progress and stage completion
161
+ if 'x_pos' in info:
162
+ x_pos = info['x_pos']
163
+ if hasattr(main, 'last_x_pos'):
164
+ x_progress = x_pos - main.last_x_pos
165
+ if x_progress > 0:
166
+ reward += 0.05 * x_progress # Reduced bonus to prevent over-optimization
167
+ main.last_x_pos = x_pos
168
+
169
+ # Large bonus for completing the level
170
+ if done and info.get('flag_get', False):
171
+ reward += 50.0
172
+ print(f"🎉 LEVEL COMPLETED at episode {k}! 🎉")
173
+
174
+ memory.push((s, float(reward), int(a), s_prime, int(1 - done)))
175
+ s = s_prime
176
+ stage = info.get('stage', 1)
177
+ world = info.get('world', 1)
178
+
179
+ # Train only if we have enough samples
180
+ if len(memory) > 5000: # Increased minimum buffer size
181
+ loss_val = train(q, q_target, memory, batch_size, gamma, optimizer, device)
182
+ loss_accumulator += loss_val
183
+ episode_loss += loss_val
184
+ episode_steps += 1
185
+ t += 1
186
+
187
+ # Update target network less frequently
188
+ if t % update_interval == 0:
189
+ copy_weights(q, q_target)
190
+
191
+ # Save best model
192
+ current_avg_score = total_score / print_interval if k % print_interval == 0 else total_score
193
+ if current_avg_score > best_score and k > 0:
194
+ best_score = current_avg_score
195
+ torch.save(q.state_dict(), "enhanced_mario_q_best.pth")
196
+ torch.save(q_target.state_dict(), "enhanced_mario_q_target_best.pth")
197
+ print(f"💾 New best model saved! Score: {best_score:.2f}")
198
+
199
+ # Save models periodically
200
+ if k % save_interval == 0 and k > 0:
201
+ torch.save(q.state_dict(), "enhanced_mario_q.pth")
202
+ torch.save(q_target.state_dict(), "enhanced_mario_q_target.pth")
203
+ print(f"Models saved at episode {k}")
204
+
205
+ if k % print_interval == 0:
206
+ time_spent, start_time = (
207
+ time.perf_counter() - start_time,
208
+ time.perf_counter(),
209
+ )
210
+
211
+ avg_loss = loss_accumulator / (print_interval * episode_steps) if episode_steps > 0 else 0.0
212
+ avg_score = total_score / print_interval
213
+
214
+ print(
215
+ "%s | Ep: %d | Score: %.2f | Loss: %.2f | Stage: %d-%d | Eps: %.3f | Time: %.2fs | Mem: %d | Steps: %d"
216
+ % (
217
+ device,
218
+ k,
219
+ avg_score,
220
+ avg_loss,
221
+ world,
222
+ stage,
223
+ eps,
224
+ time_spent,
225
+ len(memory),
226
+ episode_steps
227
+ )
228
+ )
229
+ score_lst.append(avg_score)
230
+ total_score = 0.0
231
+ loss_accumulator = 0.0
232
+ pickle.dump(score_lst, open("score.p", "wb"))
233
+
234
+
235
+ if __name__ == "__main__":
236
+ n_frame = 4
237
+ env = gym_super_mario_bros.make("SuperMarioBros-v3")
238
+ env = JoypadSpace(env, COMPLEX_MOVEMENT)
239
+ env = wrap_mario(env)
240
+
241
+ device = "cpu"
242
+ if torch.cuda.is_available():
243
+ device = "cuda"
244
+ elif torch.backends.mps.is_available():
245
+ device = "mps"
246
+
247
+ print(f"Using device: {device}")
248
+
249
+ q = model(n_frame, env.action_space.n, device).to(device)
250
+ q_target = model(n_frame, env.action_space.n, device).to(device)
251
+
252
+ copy_weights(q, q_target)
253
+
254
+ # Lower learning rate for stability
255
+ optimizer = optim.Adam(q.parameters(), lr=0.00005, weight_decay=1e-5)
256
+
257
+ main(env, q, q_target, optimizer, device)
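One property of the schedule above worth noting: eps is multiplied by eps_decay on every environment step, not once per episode, so with eps = 0.3, eps_min = 0.05 and eps_decay = 0.995 exploration reaches its floor after only a few hundred steps. A small standalone calculation (the constants are copied from the script above):

import math

eps, eps_min, eps_decay = 0.3, 0.05, 0.995
steps_to_floor = math.ceil(math.log(eps_min / eps) / math.log(eps_decay))
print(steps_to_floor)  # about 358 steps until eps is clamped at eps_min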
Super-Mario-RL/enhanced_mario_q.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:034fc1bde429fd3e9bdde72fb16697707a604dad8db737f49f8e61ad5b442026
3
+ size 42607893
Super-Mario-RL/enhanced_mario_q_best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6874204a382dc834d328f078a05110c793ead351eee5034b9f37f09ed6c11b9
3
+ size 42607973
Super-Mario-RL/enhanced_mario_q_target.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fa5448953e60e5be93f7a6a0843d11af0a279a6e13adafd36cb31f025fdf914
3
+ size 42608069
Super-Mario-RL/enhanced_mario_q_target_best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a5f341b08473c536af3398f2b984006c55d9e96952b0b0f5263bf1cdd7f7917
3
+ size 42608213
Super-Mario-RL/eval.py ADDED
@@ -0,0 +1,111 @@
1
+ import sys
2
+ import time
3
+
4
+ import gym_super_mario_bros
5
+ import torch
6
+ import torch.nn as nn
7
+ from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
8
+ from nes_py.wrappers import JoypadSpace
9
+
10
+ from wrappers import *
11
+
12
+ # Device detection
13
+ device = "cpu"
14
+ if torch.cuda.is_available():
15
+ device = "cuda"
16
+ elif torch.backends.mps.is_available():
17
+ device = "mps"
18
+
19
+ print(f"Using device: {device}")
20
+
21
+
22
+ # Same model as in duel_dqn.py (this could be factored out into a shared model.py to avoid duplication)
23
+ class model(nn.Module):
24
+ def __init__(self, n_frame, n_action, device):
25
+ super(model, self).__init__()
26
+ self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
27
+ self.layer2 = nn.Conv2d(32, 64, 3, 1)
28
+ self.fc = nn.Linear(20736, 512)
29
+ self.q = nn.Linear(512, n_action)
30
+ self.v = nn.Linear(512, 1)
31
+
32
+ self.device = device
33
+ self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)
34
+
35
+ self.seq.apply(init_weights)
36
+
37
+ def forward(self, x):
38
+ if type(x) != torch.Tensor:
39
+ x = torch.FloatTensor(x).to(self.device)
40
+ x = torch.relu(self.layer1(x))
41
+ x = torch.relu(self.layer2(x))
42
+ x = x.view(-1, 20736)
43
+ x = torch.relu(self.fc(x))
44
+ adv = self.q(x)
45
+ v = self.v(x)
46
+ q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))
47
+ return q
48
+
49
+
50
+ def init_weights(m):
51
+ if type(m) == nn.Conv2d:
52
+ torch.nn.init.xavier_uniform_(m.weight)
53
+ m.bias.data.fill_(0.01)
54
+
55
+
56
+ def arange(s):
57
+ if not isinstance(s, np.ndarray):
58
+ s = np.array(s)
59
+ assert len(s.shape) == 3
60
+ ret = np.transpose(s, (2, 0, 1))
61
+ return np.expand_dims(ret, 0)
62
+
63
+
64
+ if __name__ == "__main__":
65
+ ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "mario_q_target.pth"
66
+ print(f"Load ckpt from {ckpt_path}")
67
+ n_frame = 4
68
+ env = gym_super_mario_bros.make("SuperMarioBros-v0")
69
+ env = JoypadSpace(env, COMPLEX_MOVEMENT)
70
+ env = wrap_mario(env)
71
+
72
+ q = model(n_frame, env.action_space.n, device).to(device)
73
+
74
+ # Load model with proper device mapping
75
+ try:
76
+ q.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device)))
77
+ print(f"Model loaded successfully on {device}")
78
+ except Exception as e:
79
+ print(f"Error loading model with {device}: {e}")
80
+ print("Trying to load with CPU mapping...")
81
+ q.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
82
+ q = q.to(device)
83
+ print(f"Model loaded with CPU mapping and moved to {device}")
84
+
85
+ total_score = 0.0
86
+ done = False
87
+ s = arange(env.reset())
88
+ i = 0
89
+
90
+ # Evaluation loop
91
+ while not done:
92
+ env.render()
93
+
94
+ # Get Q-values and action
95
+ with torch.no_grad():
96
+ q_values = q(s)
97
+
98
+ # Move to CPU for numpy conversion regardless of device
99
+ if device == "cuda" or device == "mps":
100
+ a = np.argmax(q_values.cpu().numpy())
101
+ else:
102
+ a = np.argmax(q_values.detach().numpy())
103
+
104
+ s_prime, r, done, _ = env.step(a)
105
+ s_prime = arange(s_prime)
106
+ total_score += r
107
+ s = s_prime
108
+ time.sleep(0.001)
109
+
110
+ stage = env.unwrapped._stage
111
+ print("Total score : %f | stage : %d" % (total_score, stage))
Super-Mario-RL/mario1.gif ADDED

Git LFS Details

  • SHA256: e5c7637a136766dcfa9a71503488bd90e6bee3d2677941a8620053380ceb3d0c
  • Pointer size: 133 Bytes
  • Size of remote file: 10.1 MB
Super-Mario-RL/mario1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c3e5f49d302a68348659adba9a7f9f1be4d57fd3204214a191a47234aea0cd0
3
+ size 643908
Super-Mario-RL/mario14.gif ADDED

Git LFS Details

  • SHA256: 1ec48d5fb9641eaad80a3d69cc19cd227a9a1e31feb9a58d4c98ed098f7938dd
  • Pointer size: 132 Bytes
  • Size of remote file: 9.65 MB
Super-Mario-RL/mario14.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52efcbda754d016c56b6cf67d50da683988061b31ae7d43217f995a5db89474a
3
+ size 547827
Super-Mario-RL/mario_q.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e1d46a2e2822ff25428906b6964f578bcf2f16526cf4edecef0ded6350499d6
3
+ size 42607237
Super-Mario-RL/mario_q_best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d286a832aaeae13128121500fe34759a50f0e587df5cfd9ac83d780b181ed340
3
+ size 42607829
Super-Mario-RL/mario_q_target.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7811192034bc02b5236043c3bfc8e982038c970bcd6e948b82cdac18706a964
3
+ size 39059456
Super-Mario-RL/mario_q_target_best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c358203988162769284fe455cdc9fc047e69e6672d4ec05066f79a20dd54404c
3
+ size 42607941
Super-Mario-RL/ppo.py ADDED
@@ -0,0 +1,272 @@
1
+ from collections import Counter
2
+
3
+ import gym_super_mario_bros
4
+ import gymnasium as gym
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
10
+ from nes_py.wrappers import JoypadSpace
11
+
12
+ from wrappers import *
13
+
14
+ device = "cpu"
15
+ if torch.cuda.is_available():
16
+ device = "cuda"
17
+ elif torch.backends.mps.is_available():
18
+ device = "mps"
19
+
20
+ print(f"Using device: {device}")
21
+
22
+
23
+ def make_env():
24
+ env = gym_super_mario_bros.make("SuperMarioBros-v0")
25
+ env = JoypadSpace(env, COMPLEX_MOVEMENT)
26
+ env = wrap_mario(env)
27
+ return env
28
+
29
+
30
+ def get_reward(r):
31
+ r = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
32
+ return r
33
+
34
+
35
+ class ActorCritic(nn.Module):
36
+ def __init__(self, n_frame, act_dim):
37
+ super().__init__()
38
+ self.net = nn.Sequential(
39
+ nn.Conv2d(n_frame, 32, 8, 4),
40
+ nn.ReLU(),
41
+ nn.Conv2d(32, 64, 3, 1),
42
+ nn.ReLU(),
43
+ )
44
+ self.linear = nn.Linear(20736, 512)
45
+ self.policy_head = nn.Linear(512, act_dim)
46
+ self.value_head = nn.Linear(512, 1)
47
+
48
+ def forward(self, x):
49
+ if x.dim() == 4:
50
+ x = x.permute(0, 3, 1, 2)
51
+ elif x.dim() == 3:
52
+ x = x.permute(2, 0, 1)
53
+ x = self.net(x)
54
+ x = x.reshape(-1, 20736)
55
+ x = torch.relu(self.linear(x))
56
+
57
+ return self.policy_head(x), self.value_head(x).squeeze(-1)
58
+
59
+ def act(self, obs):
60
+ logits, value = self.forward(obs)
61
+ dist = torch.distributions.Categorical(logits=logits)
62
+ action = dist.sample()
63
+ logprob = dist.log_prob(action)
64
+ return action, logprob, value
65
+
66
+
67
+ def compute_gae_batch(rewards, values, dones, gamma=0.99, lam=0.95):
68
+ T, N = rewards.shape
69
+ advantages = torch.zeros_like(rewards)
70
+ gae = torch.zeros(N, device=device)
71
+
72
+ for t in reversed(range(T)):
73
+ not_done = 1.0 - dones[t]
74
+ delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
75
+ gae = delta + gamma * lam * not_done * gae
76
+ advantages[t] = gae
77
+
78
+ returns = advantages + values[:-1]
79
+ return advantages, returns
80
+
81
+
82
+ def rollout_with_bootstrap(envs, model, rollout_steps, init_obs):
83
+ obs = init_obs
84
+ obs = torch.tensor(obs, dtype=torch.float32).to(device)
85
+ obs_buf, act_buf, rew_buf, done_buf, val_buf, logp_buf = [], [], [], [], [], []
86
+
87
+ for _ in range(rollout_steps):
88
+ obs_buf.append(obs)
89
+
90
+ with torch.no_grad():
91
+ action, logp, value = model.act(obs)
92
+
93
+ val_buf.append(value)
94
+ logp_buf.append(logp)
95
+ act_buf.append(action)
96
+
97
+ actions = action.cpu().numpy()
98
+ next_obs, reward, done, infos = envs.step(actions)
99
+
100
+ reward = [get_reward(r) for r in reward]
101
+ # done = np.logical_or(terminated)
102
+
103
+ rew_buf.append(torch.tensor(reward, dtype=torch.float32).to(device))
104
+ done_buf.append(torch.tensor(done, dtype=torch.float32).to(device))
105
+
106
+ for i, d in enumerate(done):
107
+ if d:
108
+ print(f"Env {i} done. Resetting. (info: {infos[i]})")
109
+ next_obs[i] = envs.envs[i].reset()
110
+
111
+ obs = torch.tensor(next_obs, dtype=torch.float32).to(device)
112
+ max_stage = max([i["stage"] for i in infos])
113
+
114
+ with torch.no_grad():
115
+ _, last_value = model.forward(obs)
116
+
117
+ obs_buf = torch.stack(obs_buf)
118
+ act_buf = torch.stack(act_buf)
119
+ rew_buf = torch.stack(rew_buf)
120
+ done_buf = torch.stack(done_buf)
121
+ val_buf = torch.stack(val_buf)
122
+ val_buf = torch.cat([val_buf, last_value.unsqueeze(0)], dim=0)
123
+ logp_buf = torch.stack(logp_buf)
124
+
125
+ adv_buf, ret_buf = compute_gae_batch(rew_buf, val_buf, done_buf)
126
+ adv_buf = (adv_buf - adv_buf.mean()) / (adv_buf.std() + 1e-8)
127
+
128
+ return {
129
+ "obs": obs_buf, # [T, N, obs_dim]
130
+ "actions": act_buf,
131
+ "logprobs": logp_buf,
132
+ "advantages": adv_buf,
133
+ "returns": ret_buf,
134
+ "max_stage": max_stage,
135
+ "last_obs": obs,
136
+ }
137
+
138
+
139
+ def evaluate_policy(env, model, episodes=5, render=False):
140
+ """
141
+ Function to evaluate the learned policy
142
+
143
+ Args:
144
+ env: gym.Env single environment (not vector!)
145
+
146
+ model: ActorCritic model
147
+
148
+ episodes: number of episodes to evaluate
149
+
150
+ render: whether to visualize (if True, display on screen)
151
+
152
+ Returns:
153
+ avg_return: average total reward
154
+ """
155
+ model.eval()
156
+ total_returns = []
157
+ actions = []
158
+ stages = []
159
+ for ep in range(episodes):
160
+ obs = env.reset()
161
+ done = False
162
+ total_reward = 0
163
+ if render:
164
+ env.render()
165
+ while not done:
166
+ obs_tensor = (
167
+ torch.tensor(np.array(obs), dtype=torch.float32).unsqueeze(0).to(device)
168
+ )
169
+ with torch.no_grad():
170
+ logits, _ = model(obs_tensor)
171
+ dist = torch.distributions.Categorical(logits=logits)
172
+ action = dist.probs.argmax(dim=-1).item() # greedy action
173
+ actions.append(action)
174
+
175
+ obs, reward, done, info = env.step(action)
176
+ stages.append(info["stage"])
177
+ total_reward += reward
178
+
179
+ total_returns.append(total_reward)
180
+ info["action_count"] = Counter(actions)
181
+ model.train()
182
+ return np.mean(total_returns), info, max(stages)
183
+
184
+
185
+ def train_ppo():
186
+ num_env = 8
187
+ envs = gym.vector.SyncVectorEnv([lambda: make_env() for _ in range(num_env)])
188
+ obs_dim = envs.single_observation_space.shape[-1]
189
+ act_dim = envs.single_action_space.n
190
+ print(f"{obs_dim=} {act_dim=}")
191
+ model = ActorCritic(obs_dim, act_dim).to(device)
192
+
193
+ # Load model with proper device mapping
194
+ try:
195
+ # Try to load with current device first
196
+ model.load_state_dict(torch.load("mario_1_1.pt", map_location=device))
197
+ print("Model loaded successfully with current device mapping")
198
+ except Exception:
199
+ try:
200
+ # If that fails, try loading with CPU and then moving to device
201
+ model.load_state_dict(torch.load("mario_1_1.pt", map_location="cpu"))
202
+ model = model.to(device)
203
+ print("Model loaded successfully with CPU mapping")
204
+ except Exception as e:
205
+ print(f"Failed to load model: {e}")
206
+ print("Starting with fresh model")
207
+
208
+ optimizer = optim.Adam(model.parameters(), lr=2.5e-4)
209
+
210
+ rollout_steps = 128
211
+ epochs = 4
212
+ minibatch_size = 64
213
+ clip_eps = 0.2
214
+ vf_coef = 0.5
215
+ ent_coef = 0.01
216
+ eval_env = make_env()
217
+ eval_env.reset()
218
+
219
+ init_obs = envs.reset()
220
+ update = 0
221
+ while True:
222
+ update += 1
223
+ batch = rollout_with_bootstrap(envs, model, rollout_steps, init_obs)
224
+ init_obs = batch["last_obs"]
225
+
226
+ T, N = rollout_steps, envs.num_envs
227
+ total_size = T * N
228
+
229
+ obs = batch["obs"].reshape(total_size, *envs.single_observation_space.shape)
230
+ act = batch["actions"].reshape(total_size)
231
+ logp_old = batch["logprobs"].reshape(total_size)
232
+ adv = batch["advantages"].reshape(total_size)
233
+ ret = batch["returns"].reshape(total_size)
234
+
235
+ for _ in range(epochs):
236
+ idx = torch.randperm(total_size)
237
+ for start in range(0, total_size, minibatch_size):
238
+ i = idx[start : start + minibatch_size]
239
+ logits, value = model(obs[i])
240
+ dist = torch.distributions.Categorical(logits=logits)
241
+ logp = dist.log_prob(act[i])
242
+ ratio = torch.exp(logp - logp_old[i])
243
+ clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv[i]
244
+ policy_loss = -torch.min(ratio * adv[i], clipped).mean()
245
+ value_loss = (ret[i] - value).pow(2).mean()
246
+ entropy = dist.entropy().mean()
247
+ loss = policy_loss + vf_coef * value_loss - ent_coef * entropy
248
+
249
+ optimizer.zero_grad()
250
+ loss.backward()
251
+ optimizer.step()
252
+
253
+ # logging
254
+ avg_return = batch["returns"].mean().item()
255
+ max_stage = batch["max_stage"]
256
+ print(f"Update {update}: avg return = {avg_return:.2f} {max_stage=}")
257
+
258
+ # eval and save
259
+ if update % 10 == 0:
260
+ avg_score, info, eval_max_stage = evaluate_policy(
261
+ eval_env, model, episodes=1, render=False
262
+ )
263
+ print(f"[Eval] Update {update}: avg return = {avg_score:.2f} info: {info}")
264
+ if eval_max_stage > 1:
265
+ torch.save(model.state_dict(), "mario_1_1_clear.pt")
266
+ break
267
+ if update > 0 and update % 50 == 0:
268
+ torch.save(model.state_dict(), "mario_1_1_ppo.pt")
269
+
270
+
271
+ if __name__ == "__main__":
272
+ train_ppo()
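compute_gae_batch above implements the usual GAE recursion: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t) and A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}, with returns R_t = A_t + V(s_t). A minimal single-environment sketch of the same recursion with made-up numbers (standalone, PyTorch only; the rewards and values are illustrative, not taken from a real rollout):

import torch

gamma, lam = 0.99, 0.95
rewards = torch.tensor([1.0, 0.0, 2.0])       # r_0 .. r_2
values  = torch.tensor([0.5, 0.4, 0.6, 0.7])  # V(s_0) .. V(s_3); the last entry is the bootstrap value
dones   = torch.tensor([0.0, 0.0, 0.0])

adv = torch.zeros(3)
gae = 0.0
for t in reversed(range(3)):
    delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
    gae = delta + gamma * lam * (1 - dones[t]) * gae
    adv[t] = gae
returns = adv + values[:-1]
print(adv, returns)  # advantages and bootstrapped returns for the 3-step toy rollout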
Super-Mario-RL/requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ numpy==1.26.4
2
+ torch>=1.6.0
3
+ torchvision
4
+ gym==0.23
5
+ nes-py
6
+ gym-super-mario-bros==7.2.3
7
+ opencv-python
8
+ matplotlib
Super-Mario-RL/score.p ADDED
Binary file (1.37 kB).
 
Super-Mario-RL/terminal.txt ADDED
@@ -0,0 +1,5 @@
1
+ python enhanced_duel_dqn.py
2
+ python eval.py enhanced_mario_q_best.pth
3
+ python eval.py mario_q_target.pth
4
+ python eval.py mario_q_best.pth
5
+ python eval.py mario_q.pth
Super-Mario-RL/wrappers.py ADDED
@@ -0,0 +1,361 @@
1
+ """
2
+ Code from OpenAI baseline
3
+ https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
4
+ """
5
+
6
+ import os
7
+
8
+ import numpy as np
9
+
10
+ os.environ.setdefault("PATH", "")
11
+ from collections import deque
12
+
13
+ import cv2
14
+ import gym
15
+ from gym import spaces
16
+
17
+ cv2.ocl.setUseOpenCL(False)
18
+ from gym.wrappers import TimeLimit
19
+
20
+
21
+ class NoopResetEnv(gym.Wrapper):
22
+ def __init__(self, env, noop_max=30):
23
+ """Sample initial states by taking random number of no-ops on reset.
24
+ No-op is assumed to be action 0.
25
+ """
26
+ gym.Wrapper.__init__(self, env)
27
+ self.noop_max = noop_max
28
+ self.override_num_noops = None
29
+ self.noop_action = 0
30
+ assert env.unwrapped.get_action_meanings()[0] == "NOOP"
31
+
32
+ def reset(self, **kwargs):
33
+ """Do no-op action for a number of steps in [1, noop_max]."""
34
+ self.env.reset(**kwargs)
35
+ if self.override_num_noops is not None:
36
+ noops = self.override_num_noops
37
+ else:
38
+ noops = self.unwrapped.np_random.randint(
39
+ 1, self.noop_max + 1
40
+ ) # pylint: disable=E1101
41
+ assert noops > 0
42
+ obs = None
43
+ for _ in range(noops):
44
+ obs, _, done, _ = self.env.step(self.noop_action)
45
+ if done:
46
+ obs = self.env.reset(**kwargs)
47
+ return obs
48
+
49
+ def step(self, ac):
50
+ return self.env.step(ac)
51
+
52
+
53
+ class FireResetEnv(gym.Wrapper):
54
+ def __init__(self, env):
55
+ """Take action on reset for environments that are fixed until firing."""
56
+ gym.Wrapper.__init__(self, env)
57
+ assert env.unwrapped.get_action_meanings()[1] == "FIRE"
58
+ assert len(env.unwrapped.get_action_meanings()) >= 3
59
+
60
+ def reset(self, **kwargs):
61
+ self.env.reset(**kwargs)
62
+ obs, _, done, _ = self.env.step(1)
63
+ if done:
64
+ self.env.reset(**kwargs)
65
+ obs, _, done, _ = self.env.step(2)
66
+ if done:
67
+ self.env.reset(**kwargs)
68
+ return obs
69
+
70
+ def step(self, ac):
71
+ return self.env.step(ac)
72
+
73
+
74
+ class EpisodicLifeEnv(gym.Wrapper):
75
+ def __init__(self, env):
76
+ """Make end-of-life == end-of-episode, but only reset on true game over.
77
+ Done by DeepMind for the DQN and co. since it helps value estimation.
78
+ """
79
+ gym.Wrapper.__init__(self, env)
80
+ self.lives = 0
81
+ self.was_real_done = True
82
+
83
+ def step(self, action):
84
+ obs, reward, done, info = self.env.step(action)
85
+ self.was_real_done = done
86
+ # check current lives, make loss of life terminal,
87
+ # then update lives to handle bonus lives
88
+ lives = self.env.unwrapped.ale.lives()
89
+ if lives < self.lives and lives > 0:
90
+ # for Qbert sometimes we stay in lives == 0 condition for a few frames
91
+ # so it's important to keep lives > 0, so that we only reset once
92
+ # the environment advertises done.
93
+ done = True
94
+ self.lives = lives
95
+ return obs, reward, done, info
96
+
97
+ def reset(self, **kwargs):
98
+ """Reset only when lives are exhausted.
99
+ This way all states are still reachable even though lives are episodic,
100
+ and the learner need not know about any of this behind-the-scenes.
101
+ """
102
+ if self.was_real_done:
103
+ obs = self.env.reset(**kwargs)
104
+ else:
105
+ # no-op step to advance from terminal/lost life state
106
+ obs, _, _, _ = self.env.step(0)
107
+ self.lives = self.env.unwrapped.ale.lives()
108
+ return obs
109
+
110
+
111
+ class MaxAndSkipEnv(gym.Wrapper):
112
+ def __init__(self, env, skip=4):
113
+ """Return only every `skip`-th frame"""
114
+ gym.Wrapper.__init__(self, env)
115
+ # most recent raw observations (for max pooling across time steps)
116
+ self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
117
+ self._skip = skip
118
+
119
+ def step(self, action):
120
+ """Repeat action, sum reward, and max over last observations."""
121
+ total_reward = 0.0
122
+ done = None
123
+ for i in range(self._skip):
124
+ obs, reward, done, info = self.env.step(action)
125
+ if i == self._skip - 2:
126
+ self._obs_buffer[0] = obs
127
+ if i == self._skip - 1:
128
+ self._obs_buffer[1] = obs
129
+ total_reward += reward
130
+ if done:
131
+ break
132
+ # Note that the observation on the done=True frame
133
+ # doesn't matter
134
+ max_frame = self._obs_buffer.max(axis=0)
135
+
136
+ return max_frame, total_reward, done, info
137
+
138
+ def reset(self, **kwargs):
139
+ return self.env.reset(**kwargs)
140
+
141
+
142
+ class ClipRewardEnv(gym.RewardWrapper):
143
+ def __init__(self, env):
144
+ gym.RewardWrapper.__init__(self, env)
145
+
146
+ def reward(self, reward):
147
+ """Bin reward to {+1, 0, -1} by its sign."""
148
+ return np.sign(reward)
149
+
150
+
151
+ class WarpFrame(gym.ObservationWrapper):
152
+ def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
153
+ """
154
+ Warp frames to 84x84 as done in the Nature paper and later work.
155
+ If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
156
+ observation should be warped.
157
+ """
158
+ super().__init__(env)
159
+ self._width = width
160
+ self._height = height
161
+ self._grayscale = grayscale
162
+ self._key = dict_space_key
163
+ if self._grayscale:
164
+ num_colors = 1
165
+ else:
166
+ num_colors = 3
167
+
168
+ new_space = gym.spaces.Box(
169
+ low=0,
170
+ high=255,
171
+ shape=(self._height, self._width, num_colors),
172
+ dtype=np.uint8,
173
+ )
174
+ if self._key is None:
175
+ original_space = self.observation_space
176
+ self.observation_space = new_space
177
+ else:
178
+ original_space = self.observation_space.spaces[self._key]
179
+ self.observation_space.spaces[self._key] = new_space
180
+ assert original_space.dtype == np.uint8 and len(original_space.shape) == 3
181
+
182
+ def observation(self, obs):
183
+ if self._key is None:
184
+ frame = obs
185
+ else:
186
+ frame = obs[self._key]
187
+
188
+ if self._grayscale:
189
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
190
+ frame = cv2.resize(
191
+ frame, (self._width, self._height), interpolation=cv2.INTER_AREA
192
+ )
193
+ if self._grayscale:
194
+ frame = np.expand_dims(frame, -1)
195
+
196
+ if self._key is None:
197
+ obs = frame
198
+ else:
199
+ obs = obs.copy()
200
+ obs[self._key] = frame
201
+ return obs
202
+
203
+
204
+ class FrameStack(gym.Wrapper):
205
+ def __init__(self, env, k):
206
+ """Stack k last frames.
207
+ Returns lazy array, which is much more memory efficient.
208
+ See Also
209
+ --------
210
+ baselines.common.atari_wrappers.LazyFrames
211
+ """
212
+ gym.Wrapper.__init__(self, env)
213
+ self.k = k
214
+ self.frames = deque([], maxlen=k)
215
+ shp = env.observation_space.shape
216
+ self.observation_space = spaces.Box(
217
+ low=0,
218
+ high=255,
219
+ shape=(shp[:-1] + (shp[-1] * k,)),
220
+ dtype=env.observation_space.dtype,
221
+ )
222
+
223
+ def reset(self):
224
+ ob = self.env.reset()
225
+ for _ in range(self.k):
226
+ self.frames.append(ob)
227
+ return self._get_ob()
228
+
229
+ def step(self, action):
230
+ ob, reward, done, info = self.env.step(action)
231
+ self.frames.append(ob)
232
+ return self._get_ob(), reward, done, info
233
+
234
+ def _get_ob(self):
235
+ assert len(self.frames) == self.k
236
+ return LazyFrames(list(self.frames))
237
+
238
+
239
+ class ScaledFloatFrame(gym.ObservationWrapper):
240
+ def __init__(self, env):
241
+ gym.ObservationWrapper.__init__(self, env)
242
+ self.observation_space = gym.spaces.Box(
243
+ low=0, high=1, shape=env.observation_space.shape, dtype=np.float32
244
+ )
245
+
246
+ def observation(self, observation):
247
+ # careful! This undoes the memory optimization, use
248
+ # with smaller replay buffers only.
249
+ return np.array(observation).astype(np.float32) / 255.0
250
+
251
+
252
+ class LazyFrames(object):
253
+ def __init__(self, frames):
254
+ """This object ensures that common frames between the observations are only stored once.
255
+ It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
256
+ buffers.
257
+ This object should only be converted to numpy array before being passed to the model.
258
+ You'd not believe how complex the previous solution was."""
259
+ self._frames = frames
260
+ self._out = None
261
+
262
+ def _force(self):
263
+ if self._out is None:
264
+ self._out = np.concatenate(self._frames, axis=-1)
265
+ self._frames = None
266
+ return self._out
267
+
268
+ def __array__(self, dtype=None):
269
+ out = self._force()
270
+ if dtype is not None:
271
+ out = out.astype(dtype)
272
+ return out
273
+
274
+ def __len__(self):
275
+ return len(self._force())
276
+
277
+ def __getitem__(self, i):
278
+ return self._force()[i]
279
+
280
+ def count(self):
281
+ frames = self._force()
282
+ return frames.shape[frames.ndim - 1]
283
+
284
+ def frame(self, i):
285
+ return self._force()[..., i]
286
+
287
+
288
+ def make_atari(env_id, max_episode_steps=None):
289
+ env = gym.make(env_id)
290
+ assert "NoFrameskip" in env.spec.id
291
+ env = NoopResetEnv(env, noop_max=30)
292
+ env = MaxAndSkipEnv(env, skip=4)
293
+ if max_episode_steps is not None:
294
+ env = TimeLimit(env, max_episode_steps=max_episode_steps)
295
+ return env
296
+
297
+
298
+ def wrap_deepmind(
299
+ env, episode_life=True, clip_rewards=True, frame_stack=True, scale=True
300
+ ):
301
+ """Configure environment for DeepMind-style Atari."""
302
+ if episode_life:
303
+ env = EpisodicLifeEnv(env)
304
+ if "FIRE" in env.unwrapped.get_action_meanings():
305
+ env = FireResetEnv(env)
306
+ env = WarpFrame(env)
307
+ if scale:
308
+ env = ScaledFloatFrame(env)
309
+ if clip_rewards:
310
+ env = ClipRewardEnv(env)
311
+ if frame_stack:
312
+ env = FrameStack(env, 4)
313
+ return env
314
+
315
+
316
+ class EpisodicLifeMario(gym.Wrapper):
317
+ def __init__(self, env):
318
+ """Make end-of-life == end-of-episode, but only reset on true game over.
319
+ Done by DeepMind for the DQN and co. since it helps value estimation.
320
+ """
321
+ gym.Wrapper.__init__(self, env)
322
+ self.lives = 0
323
+ self.was_real_done = True
324
+
325
+ def step(self, action):
326
+ obs, reward, done, info = self.env.step(action)
327
+ self.was_real_done = done
328
+ # check current lives, make loss of life terminal,
329
+ # then update lives to handle bonus lives
330
+ lives = self.env.unwrapped._life
331
+ if lives < self.lives and lives > 0:
332
+ # for Qbert sometimes we stay in lives == 0 condition for a few frames
333
+ # so it's important to keep lives > 0, so that we only reset once
334
+ # the environment advertises done.
335
+ done = True
336
+ self.lives = lives
337
+ return obs, reward, done, info
338
+
339
+ def reset(self, **kwargs):
340
+ """Reset only when lives are exhausted.
341
+ This way all states are still reachable even though lives are episodic,
342
+ and the learner need not know about any of this behind-the-scenes.
343
+ """
344
+ if self.was_real_done:
345
+ obs = self.env.reset(**kwargs)
346
+ else:
347
+ # no-op step to advance from terminal/lost life state
348
+ obs, _, _, _ = self.env.step(0)
349
+ self.lives = self.env.unwrapped._life
350
+ return obs
351
+
352
+
353
+ def wrap_mario(env):
354
+ env = NoopResetEnv(env, noop_max=30)
355
+ env = MaxAndSkipEnv(env, skip=4)
356
+ env = EpisodicLifeMario(env)
357
+ env = WarpFrame(env)
358
+ env = ScaledFloatFrame(env)
359
+ # env = custom_reward(env)
360
+ env = FrameStack(env, 4)
361
+ return env
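A quick sanity check of the wrapper stack defined above, assuming gym-super-mario-bros and nes-py from requirements.txt are installed: after wrap_mario, observations should be 84x84 grayscale frames stacked 4 deep (returned as LazyFrames).

import numpy as np
import gym_super_mario_bros
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace
from wrappers import wrap_mario

env = gym_super_mario_bros.make("SuperMarioBros-v0")
env = JoypadSpace(env, COMPLEX_MOVEMENT)
env = wrap_mario(env)
obs = env.reset()
print(np.array(obs).shape, env.observation_space.shape)  # expected: (84, 84, 4) (84, 84, 4)
env.close()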
ale_pyqt5/app.py ADDED
@@ -0,0 +1,514 @@
1
+ import sys
2
+ import os
3
+ import numpy as np
4
+ import random
5
+ from collections import deque
6
+ import gymnasium as gym
7
+ import ale_py
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ import torch.nn.functional as F
13
+ from torch.distributions import Categorical
14
+
15
+ from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
16
+ QHBoxLayout, QPushButton, QLabel, QComboBox,
17
+ QTextEdit, QProgressBar, QTabWidget, QFrame)
18
+ from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
19
+ from PyQt5.QtGui import QImage, QPixmap, QFont
20
+
21
+ # Register ALE environments
22
+ gym.register_envs(ale_py)
23
+
24
+ # Environment setup
25
+ def create_env(env_name='ALE/Breakout-v5'):
26
+ """
27
+ Create ALE environment with Gymnasium API
28
+ Available environments:
29
+ - ALE/Breakout-v5, ALE/Pong-v5, ALE/SpaceInvaders-v5,
30
+ - ALE/Assault-v5, ALE/BeamRider-v5, ALE/Enduro-v5
31
+ """
32
+ env = gym.make(env_name, render_mode='rgb_array')
33
+ return env
34
+
35
+ # Neural Network for Dueling DQN
36
+ class DuelingDQN(nn.Module):
37
+ def __init__(self, input_shape, n_actions):
38
+ super(DuelingDQN, self).__init__()
39
+ self.conv = nn.Sequential(
40
+ nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
41
+ nn.ReLU(),
42
+ nn.Conv2d(32, 64, kernel_size=4, stride=2),
43
+ nn.ReLU(),
44
+ nn.Conv2d(64, 64, kernel_size=3, stride=1),
45
+ nn.ReLU()
46
+ )
47
+
48
+ conv_out_size = self._get_conv_out(input_shape)
49
+
50
+ self.fc_advantage = nn.Sequential(
51
+ nn.Linear(conv_out_size, 512),
52
+ nn.ReLU(),
53
+ nn.Linear(512, n_actions)
54
+ )
55
+
56
+ self.fc_value = nn.Sequential(
57
+ nn.Linear(conv_out_size, 512),
58
+ nn.ReLU(),
59
+ nn.Linear(512, 1)
60
+ )
61
+
62
+ def _get_conv_out(self, shape):
63
+ o = self.conv(torch.zeros(1, *shape))
64
+ return int(np.prod(o.size()))
65
+
66
+ def forward(self, x):
67
+ conv_out = self.conv(x).view(x.size()[0], -1)
68
+ advantage = self.fc_advantage(conv_out)
69
+ value = self.fc_value(conv_out)
70
+ return value + advantage - advantage.mean()
71
+
72
+ # Neural Network for PPO
73
+ class PPONetwork(nn.Module):
74
+ def __init__(self, input_shape, n_actions):
75
+ super(PPONetwork, self).__init__()
76
+ self.conv = nn.Sequential(
77
+ nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
78
+ nn.ReLU(),
79
+ nn.Conv2d(32, 64, kernel_size=4, stride=2),
80
+ nn.ReLU(),
81
+ nn.Conv2d(64, 64, kernel_size=3, stride=1),
82
+ nn.ReLU()
83
+ )
84
+
85
+ conv_out_size = self._get_conv_out(input_shape)
86
+
87
+ self.actor = nn.Sequential(
88
+ nn.Linear(conv_out_size, 512),
89
+ nn.ReLU(),
90
+ nn.Linear(512, n_actions),
91
+ nn.Softmax(dim=-1)
92
+ )
93
+
94
+ self.critic = nn.Sequential(
95
+ nn.Linear(conv_out_size, 512),
96
+ nn.ReLU(),
97
+ nn.Linear(512, 1)
98
+ )
99
+
100
+ def _get_conv_out(self, shape):
101
+ o = self.conv(torch.zeros(1, *shape))
102
+ return int(np.prod(o.size()))
103
+
104
+ def forward(self, x):
105
+ conv_out = self.conv(x).view(x.size()[0], -1)
106
+ return self.actor(conv_out), self.critic(conv_out)
107
+
108
+ # Dueling DQN Agent
109
+ class DuelingDQNAgent:
110
+ def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, epsilon=1.0,
111
+ epsilon_min=0.01, epsilon_decay=0.995, memory_size=10000, batch_size=32):
112
+ self.state_dim = state_dim
113
+ self.action_dim = action_dim
114
+ self.lr = lr
115
+ self.gamma = gamma
116
+ self.epsilon = epsilon
117
+ self.epsilon_min = epsilon_min
118
+ self.epsilon_decay = epsilon_decay
119
+ self.batch_size = batch_size
120
+
121
+ self.memory = deque(maxlen=memory_size)
122
+ self.model = DuelingDQN(state_dim, action_dim)
123
+ self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
124
+ self.criterion = nn.MSELoss()
125
+
126
+ def remember(self, state, action, reward, next_state, done):
127
+ self.memory.append((state, action, reward, next_state, done))
128
+
129
+ def act(self, state):
130
+ if np.random.random() <= self.epsilon:
131
+ return random.randrange(self.action_dim)
132
+
133
+ state = torch.FloatTensor(state).unsqueeze(0)
134
+ with torch.no_grad():
135
+ q_values = self.model(state)
136
+ return np.argmax(q_values.detach().numpy())
137
+
138
+ def replay(self):
139
+ if len(self.memory) < self.batch_size:
140
+ return
141
+
142
+ batch = random.sample(self.memory, self.batch_size)
143
+ states = torch.FloatTensor(np.array([e[0] for e in batch]))
144
+ actions = torch.LongTensor([e[1] for e in batch])
145
+ rewards = torch.FloatTensor([e[2] for e in batch])
146
+ next_states = torch.FloatTensor(np.array([e[3] for e in batch]))
147
+ dones = torch.BoolTensor([e[4] for e in batch])
148
+
149
+ current_q_values = self.model(states).gather(1, actions.unsqueeze(1))
150
+ with torch.no_grad():
151
+ next_q_values = self.model(next_states).max(1)[0]
152
+ target_q_values = rewards + (self.gamma * next_q_values * ~dones)
153
+
154
+ loss = self.criterion(current_q_values.squeeze(), target_q_values)
155
+
156
+ self.optimizer.zero_grad()
157
+ loss.backward()
158
+ self.optimizer.step()
159
+
160
+ if self.epsilon > self.epsilon_min:
161
+ self.epsilon *= self.epsilon_decay
162
+
163
+ # PPO Agent
164
+ class PPOAgent:
165
+ def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2,
166
+ entropy_coef=0.01, value_coef=0.5):
167
+ self.state_dim = state_dim
168
+ self.action_dim = action_dim
169
+ self.gamma = gamma
170
+ self.epsilon = epsilon
171
+ self.entropy_coef = entropy_coef
172
+ self.value_coef = value_coef
173
+
174
+ self.model = PPONetwork(state_dim, action_dim)
175
+ self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
176
+
177
+ self.memory = []
178
+
179
+ def remember(self, state, action, reward, value, log_prob):
180
+ self.memory.append((state, action, reward, value, log_prob))
181
+
182
+ def act(self, state):
183
+ state = torch.FloatTensor(state).unsqueeze(0)
184
+ with torch.no_grad():
185
+ probs, value = self.model(state)
186
+ dist = Categorical(probs)
187
+ action = dist.sample()
188
+ return action.item(), dist.log_prob(action), value.squeeze()
189
+
190
+ def train(self):
191
+ if not self.memory:
192
+ return
193
+
194
+ states, actions, rewards, values, log_probs = zip(*self.memory)
195
+
196
+ # Calculate returns and advantages
197
+ returns = []
198
+ R = 0
199
+
200
+ for r in reversed(rewards):
201
+ R = r + self.gamma * R
202
+ returns.insert(0, R)
203
+
204
+ returns = torch.FloatTensor(returns)
205
+ values = torch.FloatTensor(values)
206
+ advantages = returns - values
207
+
208
+ # Normalize advantages
209
+ advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
210
+
211
+ # Convert to tensors
212
+ states = torch.FloatTensor(np.array(states))
213
+ actions = torch.LongTensor(actions)
214
+ old_log_probs = torch.FloatTensor(log_probs)
215
+
216
+ # Get new probabilities
217
+ new_probs, new_values = self.model(states)
218
+ dist = Categorical(new_probs)
219
+ new_log_probs = dist.log_prob(actions)
220
+ entropy = dist.entropy().mean()
221
+
222
+ # PPO loss
223
+ ratio = (new_log_probs - old_log_probs).exp()
224
+ surr1 = ratio * advantages
225
+ surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
226
+ actor_loss = -torch.min(surr1, surr2).mean()
227
+
228
+ critic_loss = F.mse_loss(new_values.squeeze(), returns)
229
+
230
+ total_loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy
231
+
232
+ self.optimizer.zero_grad()
233
+ total_loss.backward()
234
+ self.optimizer.step()
235
+
236
+ self.memory = []
237
+
238
+ # Training Thread
239
+ class TrainingThread(QThread):
240
+ update_signal = pyqtSignal(dict)
241
+ frame_signal = pyqtSignal(np.ndarray)
242
+
243
+ def __init__(self, algorithm='dqn', env_name='ALE/Breakout-v5'):
244
+ super().__init__()
245
+ self.algorithm = algorithm
246
+ self.env_name = env_name
247
+ self.running = False
248
+ self.env = None
249
+ self.agent = None
250
+
251
+ def preprocess_state(self, state):
252
+ # Convert to CHW format and normalize
253
+ state = state.transpose((2, 0, 1))
254
+ state = state / 255.0
255
+ return state
256
+
257
+ def run(self):
258
+ self.running = True
259
+ try:
260
+ self.env = create_env(self.env_name)
261
+ state, info = self.env.reset()
262
+ state = self.preprocess_state(state)
263
+
264
+ n_actions = self.env.action_space.n
265
+ state_dim = state.shape
266
+
267
+ print(f"Environment: {self.env_name}")
268
+ print(f"State shape: {state_dim}, Actions: {n_actions}")
269
+
270
+ if self.algorithm == 'dqn':
271
+ self.agent = DuelingDQNAgent(state_dim, n_actions)
272
+ else:
273
+ self.agent = PPOAgent(state_dim, n_actions)
274
+
275
+ episode = 0
276
+ total_reward = 0
277
+ steps = 0
278
+ episode_rewards = []
279
+
280
+ while self.running:
281
+ try:
282
+ if self.algorithm == 'dqn':
283
+ action = self.agent.act(state)
284
+ next_state, reward, terminated, truncated, info = self.env.step(action)
285
+ done = terminated or truncated
286
+ next_state = self.preprocess_state(next_state)
287
+ self.agent.remember(state, action, reward, next_state, done)
288
+ self.agent.replay()
289
+ else:
290
+ action, log_prob, value = self.agent.act(state)
291
+ next_state, reward, terminated, truncated, info = self.env.step(action)
292
+ done = terminated or truncated
293
+ next_state = self.preprocess_state(next_state)
294
+ self.agent.remember(state, action, reward, value, log_prob)
295
+ if done:
296
+ self.agent.train()
297
+
298
+ state = next_state
299
+ total_reward += reward
300
+ steps += 1
301
+
302
+ # Emit frame for display
303
+ try:
304
+ frame = self.env.render()
305
+ if frame is not None:
306
+ self.frame_signal.emit(frame)
307
+ except Exception as e:
308
+ # Create a placeholder frame if rendering fails
309
+ frame = np.zeros((210, 160, 3), dtype=np.uint8)
310
+ self.frame_signal.emit(frame)
311
+
312
+ # Emit training progress
313
+ if steps % 10 == 0:
314
+ progress_data = {
315
+ 'episode': episode,
316
+ 'total_reward': total_reward,
317
+ 'steps': steps,
318
+ 'epsilon': self.agent.epsilon if self.algorithm == 'dqn' else 0.2,
319
+ 'env_name': self.env_name,
320
+ 'lives': info.get('lives', 0) if isinstance(info, dict) else 0
321
+ }
322
+ self.update_signal.emit(progress_data)
323
+
324
+ if terminated or truncated:
325
+ episode_rewards.append(total_reward)
326
+ avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward
327
+
328
+ print(f"Episode {episode}: Total Reward: {total_reward:.2f}, "
329
+ f"Steps: {steps}, Avg Reward (last 10): {avg_reward:.2f}")
330
+
331
+ episode += 1
332
+ state, info = self.env.reset()
333
+ state = self.preprocess_state(state)
334
+ total_reward = 0
335
+ steps = 0
336
+
337
+ except Exception as e:
338
+ print(f"Error in training loop: {e}")
339
+ import traceback
340
+ traceback.print_exc()
341
+ break
342
+
343
+ except Exception as e:
344
+ print(f"Error setting up environment: {e}")
345
+ import traceback
346
+ traceback.print_exc()
347
+
348
+ def stop(self):
349
+ self.running = False
350
+ if self.env:
351
+ self.env.close()
352
+
353
+ # Main Application Window
354
+ class ALE_RLApp(QMainWindow):
355
+ def __init__(self):
356
+ super().__init__()
357
+ self.training_thread = None
358
+ self.init_ui()
359
+
360
+ def init_ui(self):
361
+ self.setWindowTitle('🎮 ALE Arcade RL Training')
362
+ self.setGeometry(100, 100, 1200, 800)
363
+
364
+ central_widget = QWidget()
365
+ self.setCentralWidget(central_widget)
366
+ layout = QVBoxLayout(central_widget)
367
+
368
+ # Title
369
+ title = QLabel('🎮 Arcade Reinforcement Learning (ALE)')
370
+ title.setFont(QFont('Arial', 16, QFont.Bold))
371
+ title.setAlignment(Qt.AlignCenter)
372
+ layout.addWidget(title)
373
+
374
+ # Control Panel
375
+ control_layout = QHBoxLayout()
376
+
377
+ self.algorithm_combo = QComboBox()
378
+ self.algorithm_combo.addItems(['Dueling DQN', 'PPO'])
379
+
380
+ self.env_combo = QComboBox()
381
+ self.env_combo.addItems([
382
+ 'ALE/Breakout-v5',
383
+ 'ALE/Pong-v5',
384
+ 'ALE/SpaceInvaders-v5',
385
+ 'ALE/Assault-v5',
386
+ 'ALE/BeamRider-v5',
387
+ 'ALE/Enduro-v5',
388
+ 'ALE/Seaquest-v5',
389
+ 'ALE/Qbert-v5'
390
+ ])
391
+
392
+ self.start_btn = QPushButton('Start Training')
393
+ self.start_btn.clicked.connect(self.start_training)
394
+
395
+ self.stop_btn = QPushButton('Stop Training')
396
+ self.stop_btn.clicked.connect(self.stop_training)
397
+ self.stop_btn.setEnabled(False)
398
+
399
+ control_layout.addWidget(QLabel('Algorithm:'))
400
+ control_layout.addWidget(self.algorithm_combo)
401
+ control_layout.addWidget(QLabel('Environment:'))
402
+ control_layout.addWidget(self.env_combo)
403
+ control_layout.addWidget(self.start_btn)
404
+ control_layout.addWidget(self.stop_btn)
405
+ control_layout.addStretch()
406
+
407
+ layout.addLayout(control_layout)
408
+
409
+ # Content Area
410
+ content_layout = QHBoxLayout()
411
+
412
+ # Left side - Game Display
413
+ left_frame = QFrame()
414
+ left_frame.setFrameStyle(QFrame.Box)
415
+ left_layout = QVBoxLayout(left_frame)
416
+
417
+ self.game_display = QLabel()
418
+ self.game_display.setMinimumSize(400, 300)
419
+ self.game_display.setAlignment(Qt.AlignCenter)
420
+ self.game_display.setText('Game display will appear here\nPress "Start Training" to begin')
421
+ self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white;')
422
+
423
+ left_layout.addWidget(QLabel('Game Display:'))
424
+ left_layout.addWidget(self.game_display)
425
+
426
+ # Right side - Training Info
427
+ right_frame = QFrame()
428
+ right_frame.setFrameStyle(QFrame.Box)
429
+ right_layout = QVBoxLayout(right_frame)
430
+
431
+ # Progress bars
432
+ self.env_label = QLabel('Environment: Not started')
433
+ self.episode_label = QLabel('Episode: 0')
434
+ self.reward_label = QLabel('Total Reward: 0')
435
+ self.steps_label = QLabel('Steps: 0')
436
+ self.epsilon_label = QLabel('Epsilon: 0')
437
+ self.lives_label = QLabel('Lives: 0')
438
+
439
+ right_layout.addWidget(self.env_label)
440
+ right_layout.addWidget(self.episode_label)
441
+ right_layout.addWidget(self.reward_label)
442
+ right_layout.addWidget(self.steps_label)
443
+ right_layout.addWidget(self.epsilon_label)
444
+ right_layout.addWidget(self.lives_label)
445
+
446
+ # Training log
447
+ right_layout.addWidget(QLabel('Training Log:'))
448
+ self.log_text = QTextEdit()
449
+ self.log_text.setMaximumHeight(200)
450
+ right_layout.addWidget(self.log_text)
451
+
452
+ content_layout.addWidget(left_frame)
453
+ content_layout.addWidget(right_frame)
454
+ layout.addLayout(content_layout)
455
+
456
+ def start_training(self):
457
+ algorithm = 'dqn' if self.algorithm_combo.currentText() == 'Dueling DQN' else 'ppo'
458
+ env_name = self.env_combo.currentText()
459
+
460
+ self.training_thread = TrainingThread(algorithm, env_name)
461
+ self.training_thread.update_signal.connect(self.update_training_info)
462
+ self.training_thread.frame_signal.connect(self.update_game_display)
463
+ self.training_thread.start()
464
+
465
+ self.start_btn.setEnabled(False)
466
+ self.stop_btn.setEnabled(True)
467
+
468
+ self.log_text.append(f'Started {self.algorithm_combo.currentText()} training on {env_name}...')
469
+
470
+ def stop_training(self):
471
+ if self.training_thread:
472
+ self.training_thread.stop()
473
+ self.training_thread.wait()
474
+
475
+ self.start_btn.setEnabled(True)
476
+ self.stop_btn.setEnabled(False)
477
+ self.log_text.append('Training stopped.')
478
+
479
+ def update_training_info(self, data):
480
+ self.env_label.setText(f'Environment: {data.get("env_name", "Unknown")}')
481
+ self.episode_label.setText(f'Episode: {data["episode"]}')
482
+ self.reward_label.setText(f'Total Reward: {data["total_reward"]:.2f}')
483
+ self.steps_label.setText(f'Steps: {data["steps"]}')
484
+ self.epsilon_label.setText(f'Epsilon: {data["epsilon"]:.3f}')
485
+ self.lives_label.setText(f'Lives: {data.get("lives", 0)}')
486
+
487
+ def update_game_display(self, frame):
488
+ if frame is not None:
489
+ try:
490
+ h, w, ch = frame.shape
491
+ bytes_per_line = ch * w
492
+ q_img = QImage(frame.data, w, h, bytes_per_line, QImage.Format_RGB888)
493
+ pixmap = QPixmap.fromImage(q_img)
494
+ self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
495
+ except Exception as e:
496
+ print(f"Error updating display: {e}")
497
+
498
+ def closeEvent(self, event):
499
+ self.stop_training()
500
+ event.accept()
501
+
502
+ def main():
503
+ # Set random seeds for reproducibility
504
+ torch.manual_seed(42)
505
+ np.random.seed(42)
506
+ random.seed(42)
507
+
508
+ app = QApplication(sys.argv)
509
+ window = ALE_RLApp()
510
+ window.show()
511
+ sys.exit(app.exec_())
512
+
513
+ if __name__ == '__main__':
514
+ main()
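For reference, the frame-to-pixmap conversion used in `update_game_display` above can be exercised on its own. This is a minimal sketch, not part of the committed files, assuming only PyQt5 and NumPy are installed; the dummy green frame stands in for the array returned by `env.render()`.

import sys
import numpy as np
from PyQt5.QtWidgets import QApplication, QLabel
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtCore import Qt

app = QApplication(sys.argv)
label = QLabel()

# Dummy 210x160 RGB frame standing in for env.render(); solid green so it is visible.
frame = np.zeros((210, 160, 3), dtype=np.uint8)
frame[:, :, 1] = 255
frame = np.ascontiguousarray(frame)  # QImage expects contiguous row-major memory

h, w, ch = frame.shape
q_img = QImage(frame.data, w, h, ch * w, QImage.Format_RGB888)
label.setPixmap(QPixmap.fromImage(q_img).scaled(400, 300, Qt.KeepAspectRatio))
label.show()
sys.exit(app.exec_())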
ale_pyqt5/app_2.py ADDED
@@ -0,0 +1,559 @@
import sys
import os
import numpy as np
import random
from collections import deque
import gymnasium as gym
import ale_py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                             QHBoxLayout, QPushButton, QLabel, QComboBox,
                             QTextEdit, QProgressBar, QTabWidget, QFrame)
from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap, QFont

# Register ALE environments
gym.register_envs(ale_py)

# Environment setup
def create_env(env_name='ALE/SpaceInvaders-v5'):
    """
    Create ALE environment with Gymnasium API
    """
    env = gym.make(env_name, render_mode='rgb_array')
    return env

# Enhanced Neural Network for Dueling DQN
class DuelingDQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DuelingDQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)

        self.fc_advantage = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions)
        )

        self.fc_value = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        advantage = self.fc_advantage(conv_out)
        value = self.fc_value(conv_out)
        return value + advantage - advantage.mean()

# Enhanced Neural Network for PPO
class PPONetwork(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(PPONetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)

        self.actor = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions),
            nn.Softmax(dim=-1)
        )

        self.critic = nn.Sequential(
            nn.Linear(conv_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.actor(conv_out), self.critic(conv_out)

# Enhanced Dueling DQN Agent with better training
class DuelingDQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, epsilon=1.0,
                 epsilon_min=0.01, epsilon_decay=0.999, memory_size=50000, batch_size=32):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size

        self.memory = deque(maxlen=memory_size)
        self.model = DuelingDQN(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-5)
        self.criterion = nn.SmoothL1Loss()  # Huber loss for better stability

        # Target network for stable training
        self.target_model = DuelingDQN(state_dim, action_dim)
        self.update_target_network()
        self.target_update_frequency = 1000
        self.train_step = 0

    def update_target_network(self):
        self.target_model.load_state_dict(self.model.state_dict())

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_dim)

        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.model(state)
        return np.argmax(q_values.detach().numpy())

    def replay(self):
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states = torch.FloatTensor(np.array([e[0] for e in batch]))
        actions = torch.LongTensor([e[1] for e in batch])
        rewards = torch.FloatTensor([e[2] for e in batch])
        next_states = torch.FloatTensor(np.array([e[3] for e in batch]))
        dones = torch.BoolTensor([e[4] for e in batch])

        current_q_values = self.model(states).gather(1, actions.unsqueeze(1))

        with torch.no_grad():
            next_actions = self.model(next_states).max(1)[1]
            next_q_values = self.target_model(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()

        target_q_values = rewards + (self.gamma * next_q_values * ~dones)

        loss = self.criterion(current_q_values.squeeze(), target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        # Update target network periodically
        self.train_step += 1
        if self.train_step % self.target_update_frequency == 0:
            self.update_target_network()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

# Enhanced PPO Agent
class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2,
                 entropy_coef=0.01, value_coef=0.5, ppo_epochs=4, batch_size=64):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.ppo_epochs = ppo_epochs
        self.batch_size = batch_size

        self.model = PPONetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.memory = []

    def remember(self, state, action, reward, value, log_prob):
        self.memory.append((state, action, reward, value, log_prob))

    def act(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            probs, value = self.model(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action), value.squeeze()

    def train(self):
        if len(self.memory) < self.batch_size:
            return

        states, actions, rewards, values, log_probs = zip(*self.memory)

        # Calculate returns and advantages
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + self.gamma * R
            returns.insert(0, R)

        returns = torch.FloatTensor(returns)
        old_values = torch.FloatTensor(values)
        advantages = returns - old_values

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Convert to tensors
        states_tensor = torch.FloatTensor(np.array(states))
        actions_tensor = torch.LongTensor(actions)
        old_log_probs = torch.FloatTensor(log_probs)

        # PPO epochs
        for _ in range(self.ppo_epochs):
            # Get new probabilities
            new_probs, new_values = self.model(states_tensor)
            dist = Categorical(new_probs)
            new_log_probs = dist.log_prob(actions_tensor)
            entropy = dist.entropy().mean()

            # PPO loss
            ratio = (new_log_probs - old_log_probs).exp()
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_loss = F.mse_loss(new_values.squeeze(), returns)

            total_loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy

            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            self.optimizer.step()

        self.memory = []

# Enhanced Training Thread with better state processing
class TrainingThread(QThread):
    update_signal = pyqtSignal(dict)
    frame_signal = pyqtSignal(np.ndarray)

    def __init__(self, algorithm='dqn', env_name='ALE/Breakout-v5'):
        super().__init__()
        self.algorithm = algorithm
        self.env_name = env_name
        self.running = False
        self.env = None
        self.agent = None

    def preprocess_state(self, state):
        # Convert to CHW format, normalize, and convert to grayscale
        if len(state.shape) == 3:
            # Convert to grayscale and resize for faster processing
            state = state.mean(axis=2, keepdims=True)  # Convert to grayscale
            state = state.transpose((2, 0, 1))
        state = state / 255.0
        return state

    def run(self):
        self.running = True
        try:
            self.env = create_env(self.env_name)
            state, info = self.env.reset()
            state = self.preprocess_state(state)

            n_actions = self.env.action_space.n
            state_dim = state.shape

            print(f"🎮 Training on: {self.env_name}")
            print(f"📊 State shape: {state_dim}, Actions: {n_actions}")
            print(f"🤖 Algorithm: {self.algorithm}")

            if self.algorithm == 'dqn':
                self.agent = DuelingDQNAgent(state_dim, n_actions)
            else:
                self.agent = PPOAgent(state_dim, n_actions)

            episode = 0
            total_reward = 0
            steps = 0
            episode_rewards = []
            best_reward = -float('inf')

            while self.running:
                try:
                    if self.algorithm == 'dqn':
                        action = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, next_state, done)
                        self.agent.replay()
                    else:
                        action, log_prob, value = self.agent.act(state)
                        next_state, reward, terminated, truncated, info = self.env.step(action)
                        done = terminated or truncated
                        next_state = self.preprocess_state(next_state)
                        self.agent.remember(state, action, reward, value, log_prob)
                        if done:
                            self.agent.train()

                    state = next_state
                    total_reward += reward
                    steps += 1

                    # Emit frame for display
                    try:
                        frame = self.env.render()
                        if frame is not None:
                            self.frame_signal.emit(frame)
                    except Exception as e:
                        # Create a placeholder frame if rendering fails
                        frame = np.zeros((210, 160, 3), dtype=np.uint8)
                        self.frame_signal.emit(frame)

                    # Emit training progress more frequently for better feedback
                    if steps % 5 == 0:
                        avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward
                        progress_data = {
                            'episode': episode,
                            'total_reward': total_reward,
                            'steps': steps,
                            'epsilon': self.agent.epsilon if self.algorithm == 'dqn' else 0.2,
                            'env_name': self.env_name,
                            'lives': info.get('lives', 0) if isinstance(info, dict) else 0,
                            'avg_reward': avg_reward,
                            'best_reward': best_reward
                        }
                        self.update_signal.emit(progress_data)

                    if terminated or truncated:
                        episode_rewards.append(total_reward)
                        if total_reward > best_reward:
                            best_reward = total_reward

                        avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else total_reward

                        print(f"🎯 Episode {episode}: Reward: {total_reward:.1f}, "
                              f"Steps: {steps}, Avg (last 10): {avg_reward:.1f}, "
                              f"Best: {best_reward:.1f}, Epsilon: {self.agent.epsilon:.3f}")

                        episode += 1
                        state, info = self.env.reset()
                        state = self.preprocess_state(state)
                        total_reward = 0
                        steps = 0

                except Exception as e:
                    print(f"❌ Error in training loop: {e}")
                    import traceback
                    traceback.print_exc()
                    break

        except Exception as e:
            print(f"❌ Error setting up environment: {e}")
            import traceback
            traceback.print_exc()

    def stop(self):
        self.running = False
        if self.env:
            self.env.close()

# Enhanced Main Application Window
class ALE_RLApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.training_thread = None
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle('🎮 ALE Arcade RL Training - Enhanced')
        self.setGeometry(100, 100, 1200, 800)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        # Title
        title = QLabel('🎮 Arcade Reinforcement Learning (ALE) - Enhanced Training')
        title.setFont(QFont('Arial', 16, QFont.Bold))
        title.setAlignment(Qt.AlignCenter)
        layout.addWidget(title)

        # Control Panel
        control_layout = QHBoxLayout()

        self.algorithm_combo = QComboBox()
        self.algorithm_combo.addItems(['Dueling DQN', 'PPO'])

        self.env_combo = QComboBox()
        self.env_combo.addItems([
            'ALE/Breakout-v5',
            'ALE/Pong-v5',
            'ALE/SpaceInvaders-v5',
            'ALE/Assault-v5',
            'ALE/BeamRider-v5',
            'ALE/Enduro-v5',
            'ALE/Seaquest-v5',
            'ALE/Qbert-v5'
        ])

        self.start_btn = QPushButton('🚀 Start Training')
        self.start_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton('⏹️ Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        control_layout.addWidget(QLabel('🤖 Algorithm:'))
        control_layout.addWidget(self.algorithm_combo)
        control_layout.addWidget(QLabel('🎮 Environment:'))
        control_layout.addWidget(self.env_combo)
        control_layout.addWidget(self.start_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addStretch()

        layout.addLayout(control_layout)

        # Content Area
        content_layout = QHBoxLayout()

        # Left side - Game Display
        left_frame = QFrame()
        left_frame.setFrameStyle(QFrame.Box)
        left_layout = QVBoxLayout(left_frame)

        self.game_display = QLabel()
        self.game_display.setMinimumSize(400, 300)
        self.game_display.setAlignment(Qt.AlignCenter)
        self.game_display.setText('Game display will appear here\nPress "🚀 Start Training" to begin')
        self.game_display.setStyleSheet('border: 1px solid gray; background-color: black; color: white; font-size: 14px;')

        left_layout.addWidget(QLabel('🎮 Game Display:'))
        left_layout.addWidget(self.game_display)

        # Right side - Training Info
        right_frame = QFrame()
        right_frame.setFrameStyle(QFrame.Box)
        right_layout = QVBoxLayout(right_frame)

        # Progress bars with better styling
        self.env_label = QLabel('🎯 Environment: Not started')
        self.episode_label = QLabel('📈 Episode: 0')
        self.reward_label = QLabel('🏆 Total Reward: 0')
        self.avg_reward_label = QLabel('📊 Avg Reward (last 10): 0')
        self.best_reward_label = QLabel('⭐ Best Reward: 0')
        self.steps_label = QLabel('⏱️ Steps: 0')
        self.epsilon_label = QLabel('🎲 Epsilon: 0')
        self.lives_label = QLabel('❤️ Lives: 0')

        # Style the labels
        for label in [self.env_label, self.episode_label, self.reward_label,
                      self.avg_reward_label, self.best_reward_label, self.steps_label,
                      self.epsilon_label, self.lives_label]:
            label.setStyleSheet('font-weight: bold; font-size: 12px;')

        right_layout.addWidget(self.env_label)
        right_layout.addWidget(self.episode_label)
        right_layout.addWidget(self.reward_label)
        right_layout.addWidget(self.avg_reward_label)
        right_layout.addWidget(self.best_reward_label)
        right_layout.addWidget(self.steps_label)
        right_layout.addWidget(self.epsilon_label)
        right_layout.addWidget(self.lives_label)

        # Training log
        right_layout.addWidget(QLabel('📝 Training Log:'))
        self.log_text = QTextEdit()
        self.log_text.setMaximumHeight(200)
        self.log_text.setStyleSheet('font-family: monospace; font-size: 10px;')
        right_layout.addWidget(self.log_text)

        content_layout.addWidget(left_frame)
        content_layout.addWidget(right_frame)
        layout.addLayout(content_layout)

    def start_training(self):
        algorithm = 'dqn' if self.algorithm_combo.currentText() == 'Dueling DQN' else 'ppo'
        env_name = self.env_combo.currentText()

        self.training_thread = TrainingThread(algorithm, env_name)
        self.training_thread.update_signal.connect(self.update_training_info)
        self.training_thread.frame_signal.connect(self.update_game_display)
        self.training_thread.start()

        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

        self.log_text.append(f'🚀 Started {self.algorithm_combo.currentText()} training on {env_name}...')

    def stop_training(self):
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()

        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.log_text.append('⏹️ Training stopped.')

    def update_training_info(self, data):
        self.env_label.setText(f'🎯 Environment: {data.get("env_name", "Unknown")}')
        self.episode_label.setText(f'📈 Episode: {data["episode"]}')
        self.reward_label.setText(f'🏆 Total Reward: {data["total_reward"]:.1f}')
        self.avg_reward_label.setText(f'📊 Avg Reward (last 10): {data.get("avg_reward", 0):.1f}')
        self.best_reward_label.setText(f'⭐ Best Reward: {data.get("best_reward", 0):.1f}')
        self.steps_label.setText(f'⏱️ Steps: {data["steps"]}')
        self.epsilon_label.setText(f'🎲 Epsilon: {data["epsilon"]:.3f}')
        self.lives_label.setText(f'❤️ Lives: {data.get("lives", 0)}')

    def update_game_display(self, frame):
        if frame is not None:
            try:
                h, w, ch = frame.shape
                bytes_per_line = ch * w
                q_img = QImage(frame.data, w, h, bytes_per_line, QImage.Format_RGB888)
                pixmap = QPixmap.fromImage(q_img)
                self.game_display.setPixmap(pixmap.scaled(400, 300, Qt.KeepAspectRatio))
            except Exception as e:
                print(f"Error updating display: {e}")

    def closeEvent(self, event):
        self.stop_training()
        event.accept()

def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    app = QApplication(sys.argv)
    window = ALE_RLApp()
    window.show()
    sys.exit(app.exec_())

if __name__ == '__main__':
    main()
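The return calculation in PPOAgent.train() above is a plain backwards discounted sum with no bootstrapped terminal value. A small worked example of that loop in isolation (the reward values here are illustrative, not taken from the commit):

# R_t = r_t + gamma * R_{t+1}, accumulated from the last step backwards,
# mirroring the loop in PPOAgent.train().
gamma = 0.99
rewards = [1.0, 0.0, 2.0]  # illustrative per-step rewards

returns = []
R = 0.0
for r in reversed(rewards):
    R = r + gamma * R
    returns.insert(0, R)

print(returns)  # [2.9602, 1.98, 2.0] up to floating-point rounding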
ale_pyqt5/installed_packages_ale_py.txt ADDED
@@ -0,0 +1,30 @@
ale-py==0.11.2
cloudpickle==3.1.2
contourpy==1.3.3
cycler==0.12.1
Farama-Notifications==0.0.4
filelock==3.20.0
fonttools==4.60.1
fsspec==2025.10.0
gym==0.26.2
gym-notices==0.1.0
gymnasium==1.2.2
Jinja2==3.1.6
kiwisolver==1.4.9
MarkupSafe==3.0.3
matplotlib==3.10.7
mpmath==1.3.0
networkx==3.5
numpy==2.2.6
opencv-python==4.12.0.88
packaging==25.0
pillow==12.0.0
pyglet==1.5.11
pyparsing==3.2.5
python-dateutil==2.9.0.post0
setuptools==80.9.0
six==1.17.0
sympy==1.14.0
torch==2.9.0
tqdm==4.67.1
typing_extensions==4.15.0
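A quick smoke test that the pinned packages above can create an ALE environment the same way app_2.py does; the environment id is one of those offered in the GUI drop-down, and this snippet is not part of the commit:

import gymnasium as gym
import ale_py

gym.register_envs(ale_py)  # same registration call used in app_2.py
env = gym.make("ALE/Breakout-v5", render_mode="rgb_array")

obs, info = env.reset(seed=0)
print(obs.shape, env.action_space.n)  # expected: (210, 160, 3) 4

obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()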