TroglodyteDerivations commited on
Commit
55a1670
·
verified ·
1 Parent(s): f61094e

Upload 6 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ output.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ Screenshot[[:space:]]2025-11-20[[:space:]]at[[:space:]]10.45.11 AM.png filter=lfs diff=lfs merge=lfs -text
38
+ Screenshot[[:space:]]2025-11-20[[:space:]]at[[:space:]]11.04.13 AM.png filter=lfs diff=lfs merge=lfs -text
Screenshot 2025-11-20 at 10.45.11 AM.png ADDED

Git LFS Details

  • SHA256: d7e13f247b0232eb78b629b121422d4f38043c592d84b8c0a9b973a7a71d932d
  • Pointer size: 131 Bytes
  • Size of remote file: 504 kB
Screenshot 2025-11-20 at 11.04.13 AM.png ADDED

Git LFS Details

  • SHA256: 2878ed3d32585691dffa5cc20bb99fca0444a4bb5e42979efa8f1273eb119b99
  • Pointer size: 131 Bytes
  • Size of remote file: 663 kB
mario_ai_app.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import random
3
+ import numpy as np
4
+ from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
5
+ QHBoxLayout, QLabel, QFrame, QGridLayout, QProgressBar)
6
+ from PyQt5.QtCore import QTimer, Qt, pyqtSignal
7
+ from PyQt5.QtGui import QFont, QPainter, QColor, QPen, QBrush
8
+ import math
9
+
10
+ class NeuralNetworkWidget(QWidget):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.setMinimumSize(400, 300)
14
+ self.layers = [80, 9, 6] # Input, Hidden, Output layers
15
+ self.activations = [random.random() for _ in range(sum(self.layers))]
16
+ self.connection_strengths = {}
17
+
18
+ def update_activations(self, new_activations=None):
19
+ if new_activations:
20
+ self.activations = new_activations
21
+ else:
22
+ # Simulate some neural activity
23
+ self.activations = [max(0, min(1, x + random.uniform(-0.2, 0.2)))
24
+ for x in self.activations]
25
+ self.update()
26
+
27
+ def paintEvent(self, event):
28
+ painter = QPainter(self)
29
+ painter.setRenderHint(QPainter.Antialiasing)
30
+
31
+ # Set up colors
32
+ bg_color = QColor(30, 30, 40)
33
+ neuron_color = QColor(100, 150, 255)
34
+ active_neuron_color = QColor(255, 100, 100)
35
+ connection_color = QColor(100, 100, 150, 100)
36
+
37
+ # Fill background
38
+ painter.fillRect(self.rect(), bg_color)
39
+
40
+ width = self.width()
41
+ height = self.height()
42
+
43
+ # Calculate positions for neurons
44
+ neuron_positions = []
45
+ activation_index = 0
46
+
47
+ for layer_idx, neuron_count in enumerate(self.layers):
48
+ layer_positions = []
49
+ layer_x = (layer_idx + 1) * width / (len(self.layers) + 1)
50
+
51
+ for neuron_idx in range(neuron_count):
52
+ neuron_y = (neuron_idx + 1) * height / (neuron_count + 1)
53
+ layer_positions.append((layer_x, neuron_y))
54
+
55
+ # Draw connections to next layer
56
+ if layer_idx < len(self.layers) - 1:
57
+ next_layer_count = self.layers[layer_idx + 1]
58
+ for next_neuron_idx in range(next_layer_count):
59
+ next_x = (layer_idx + 2) * width / (len(self.layers) + 1)
60
+ next_y = (next_neuron_idx + 1) * height / (next_layer_count + 1)
61
+
62
+ # Vary connection strength and color
63
+ strength = random.random()
64
+ alpha = int(50 + strength * 100)
65
+ pen_color = QColor(connection_color)
66
+ pen_color.setAlpha(alpha)
67
+
68
+ painter.setPen(QPen(pen_color, 1 + strength * 2))
69
+ painter.drawLine(int(layer_x), int(neuron_y),
70
+ int(next_x), int(next_y))
71
+
72
+ activation_index += 1
73
+ neuron_positions.append(layer_positions)
74
+
75
+ # Draw neurons
76
+ activation_index = 0
77
+ for layer_idx, positions in enumerate(neuron_positions):
78
+ for pos_idx, (x, y) in enumerate(positions):
79
+ activation = self.activations[activation_index]
80
+ activation_index += 1
81
+
82
+ # Determine neuron color based on activation
83
+ if activation > 0.7:
84
+ color = active_neuron_color
85
+ elif activation > 0.3:
86
+ color = QColor(255, 200, 100) # Orange for medium activation
87
+ else:
88
+ color = neuron_color
89
+
90
+ # Draw neuron
91
+ radius = 8 + activation * 8
92
+ painter.setBrush(QBrush(color))
93
+ painter.setPen(QPen(QColor(200, 200, 255), 2))
94
+ painter.drawEllipse(int(x - radius/2), int(y - radius/2),
95
+ int(radius), int(radius))
96
+
97
+ class ControlButtonWidget(QWidget):
98
+ def __init__(self):
99
+ super().__init__()
100
+ self.setup_ui()
101
+
102
+ def setup_ui(self):
103
+ layout = QGridLayout()
104
+ layout.setSpacing(10)
105
+ layout.setContentsMargins(10, 10, 10, 10)
106
+
107
+ # Button labels and their positions
108
+ buttons = [
109
+ ('U', 0, 1), # Up
110
+ ('L', 1, 0), # Left
111
+ ('D', 1, 1), # Down
112
+ ('R', 1, 2), # Right
113
+ ('A', 0, 3), # A button
114
+ ('B', 1, 3), # B button
115
+ ]
116
+
117
+ self.labels = {}
118
+
119
+ for text, row, col in buttons:
120
+ label = QLabel(text)
121
+ label.setAlignment(Qt.AlignCenter)
122
+ label.setStyleSheet("""
123
+ QLabel {
124
+ background-color: #2d2d2d;
125
+ color: #cccccc;
126
+ border: 2px solid #555555;
127
+ border-radius: 10px;
128
+ font-weight: bold;
129
+ font-size: 14px;
130
+ min-width: 40px;
131
+ min-height: 40px;
132
+ }
133
+ """)
134
+ label.setMinimumSize(50, 50)
135
+ layout.addWidget(label, row, col)
136
+ self.labels[text] = label
137
+
138
+ self.setLayout(layout)
139
+
140
+ def activate_button(self, button, active=True):
141
+ if button in self.labels:
142
+ if active:
143
+ self.labels[button].setStyleSheet("""
144
+ QLabel {
145
+ background-color: #ff4444;
146
+ color: white;
147
+ border: 2px solid #ff6666;
148
+ border-radius: 10px;
149
+ font-weight: bold;
150
+ font-size: 14px;
151
+ min-width: 40px;
152
+ min-height: 40px;
153
+ }
154
+ """)
155
+ else:
156
+ self.labels[button].setStyleSheet("""
157
+ QLabel {
158
+ background-color: #2d2d2d;
159
+ color: #cccccc;
160
+ border: 2px solid #555555;
161
+ border-radius: 10px;
162
+ font-weight: bold;
163
+ font-size: 14px;
164
+ min-width: 40px;
165
+ min-height: 40px;
166
+ }
167
+ """)
168
+
169
+ class MetricWidget(QWidget):
170
+ def __init__(self, title, value, unit=""):
171
+ super().__init__()
172
+ self.title = title
173
+ self.value = value
174
+ self.unit = unit
175
+ self.setup_ui()
176
+
177
+ def setup_ui(self):
178
+ layout = QVBoxLayout()
179
+ layout.setSpacing(2)
180
+ layout.setContentsMargins(5, 5, 5, 5)
181
+
182
+ self.title_label = QLabel(self.title)
183
+ self.title_label.setAlignment(Qt.AlignCenter)
184
+ self.title_label.setStyleSheet("color: #888888; font-size: 10px;")
185
+
186
+ self.value_label = QLabel(f"{self.value}{self.unit}")
187
+ self.value_label.setAlignment(Qt.AlignCenter)
188
+ self.value_label.setStyleSheet("color: #ffffff; font-size: 12px; font-weight: bold;")
189
+
190
+ layout.addWidget(self.title_label)
191
+ layout.addWidget(self.value_label)
192
+
193
+ self.setLayout(layout)
194
+
195
+ def update_value(self, new_value):
196
+ self.value_label.setText(f"{new_value}{self.unit}")
197
+
198
+ class MarioAIApp(QMainWindow):
199
+ def __init__(self):
200
+ super().__init__()
201
+ self.generation = 1214
202
+ self.individual = "Replay"
203
+ self.best_fitness = 0
204
+ self.max_distance = 3161
205
+ self.num_inputs = 80
206
+ self.trainable_params = 789
207
+ self.offspring = "10, 90"
208
+ self.lifespan = "Infinite"
209
+ self.mutation = "Static 5.0%"
210
+ self.crossover = "Roulette"
211
+ self.sbx_eta = 100.0
212
+ self.layers = "[80, 9, 6]"
213
+
214
+ self.setup_ui()
215
+ self.setup_timers()
216
+
217
+ def setup_ui(self):
218
+ self.setWindowTitle("MARIO 000500 - AI Learns to Play Super Mario Bros!")
219
+ self.setGeometry(100, 100, 900, 700)
220
+
221
+ # Central widget
222
+ central_widget = QWidget()
223
+ self.setCentralWidget(central_widget)
224
+ main_layout = QVBoxLayout(central_widget)
225
+
226
+ # Header
227
+ header_layout = QVBoxLayout()
228
+
229
+ title_label = QLabel("MARIO 000500")
230
+ title_label.setAlignment(Qt.AlignCenter)
231
+ title_label.setStyleSheet("""
232
+ QLabel {
233
+ color: #ff4444;
234
+ font-size: 24px;
235
+ font-weight: bold;
236
+ margin: 10px;
237
+ }
238
+ """)
239
+
240
+ world_label = QLabel("WORLD 1-1")
241
+ world_label.setAlignment(Qt.AlignCenter)
242
+ world_label.setStyleSheet("""
243
+ QLabel {
244
+ color: #ffffff;
245
+ font-size: 18px;
246
+ font-weight: bold;
247
+ margin: 5px;
248
+ }
249
+ """)
250
+
251
+ time_label = QLabel("TIME 344")
252
+ time_label.setAlignment(Qt.AlignCenter)
253
+ time_label.setStyleSheet("""
254
+ QLabel {
255
+ color: #ffff44;
256
+ font-size: 16px;
257
+ font-weight: bold;
258
+ margin: 5px;
259
+ }
260
+ """)
261
+
262
+ header_layout.addWidget(title_label)
263
+ header_layout.addWidget(world_label)
264
+ header_layout.addWidget(time_label)
265
+
266
+ # Separator
267
+ separator = QFrame()
268
+ separator.setFrameShape(QFrame.HLine)
269
+ separator.setFrameShadow(QFrame.Sunken)
270
+ separator.setStyleSheet("background-color: #555555;")
271
+
272
+ # Main content area
273
+ content_layout = QHBoxLayout()
274
+
275
+ # Left panel - Metrics
276
+ left_panel = QWidget()
277
+ left_layout = QVBoxLayout(left_panel)
278
+
279
+ # Metrics grid
280
+ metrics_grid = QGridLayout()
281
+ metrics_grid.setSpacing(10)
282
+
283
+ # First column
284
+ metrics_grid.addWidget(MetricWidget("Generation", self.generation), 0, 0)
285
+ metrics_grid.addWidget(MetricWidget("Individual", self.individual), 1, 0)
286
+ metrics_grid.addWidget(MetricWidget("Best Fitness", self.best_fitness), 2, 0)
287
+ metrics_grid.addWidget(MetricWidget("Max Distance", self.max_distance), 3, 0)
288
+ metrics_grid.addWidget(MetricWidget("Num Inputs", self.num_inputs), 4, 0)
289
+ metrics_grid.addWidget(MetricWidget("Trainable Params", self.trainable_params), 5, 0)
290
+
291
+ # Second column
292
+ metrics_grid.addWidget(MetricWidget("Offspring", self.offspring), 0, 1)
293
+ metrics_grid.addWidget(MetricWidget("Lifespan", self.lifespan), 1, 1)
294
+ metrics_grid.addWidget(MetricWidget("Mutation", self.mutation), 2, 1)
295
+ metrics_grid.addWidget(MetricWidget("Crossover", self.crossover), 3, 1)
296
+ metrics_grid.addWidget(MetricWidget("SBX Eta", self.sbx_eta), 4, 1)
297
+ metrics_grid.addWidget(MetricWidget("Layers", self.layers), 5, 1)
298
+
299
+ left_layout.addLayout(metrics_grid)
300
+
301
+ # Control buttons
302
+ left_layout.addWidget(QLabel("Controller:"))
303
+ self.control_widget = ControlButtonWidget()
304
+ left_layout.addWidget(self.control_widget)
305
+
306
+ # Right panel - Neural Network
307
+ right_panel = QWidget()
308
+ right_layout = QVBoxLayout(right_panel)
309
+
310
+ right_layout.addWidget(QLabel("Neural Network Visualization:"))
311
+ self.nn_widget = NeuralNetworkWidget()
312
+ right_layout.addWidget(self.nn_widget)
313
+
314
+ # Add panels to content layout
315
+ content_layout.addWidget(left_panel, 1)
316
+ content_layout.addWidget(right_panel, 2)
317
+
318
+ # Footer
319
+ footer_label = QLabel("AI Learns to Play Super Mario Bros!\nUsing a Genetic Algorithm and Neural Network, a population of AI were able to learn to play different levels of Super Mario Bros for the NES.")
320
+ footer_label.setAlignment(Qt.AlignCenter)
321
+ footer_label.setStyleSheet("""
322
+ QLabel {
323
+ color: #cccccc;
324
+ font-size: 12px;
325
+ margin: 10px;
326
+ padding: 10px;
327
+ background-color: #2a2a2a;
328
+ border-radius: 5px;
329
+ }
330
+ """)
331
+ footer_label.setWordWrap(True)
332
+
333
+ # Assemble main layout
334
+ main_layout.addLayout(header_layout)
335
+ main_layout.addWidget(separator)
336
+ main_layout.addLayout(content_layout)
337
+ main_layout.addWidget(footer_label)
338
+
339
+ # Set dark theme
340
+ self.setStyleSheet("""
341
+ QMainWindow {
342
+ background-color: #1a1a1a;
343
+ }
344
+ QWidget {
345
+ background-color: #1a1a1a;
346
+ color: #ffffff;
347
+ }
348
+ """)
349
+
350
+ def setup_timers(self):
351
+ # Timer for neural network updates
352
+ self.nn_timer = QTimer()
353
+ self.nn_timer.timeout.connect(self.update_neural_network)
354
+ self.nn_timer.start(100) # Update every 100ms
355
+
356
+ # Timer for button activations
357
+ self.button_timer = QTimer()
358
+ self.button_timer.timeout.connect(self.update_buttons)
359
+ self.button_timer.start(200) # Update every 200ms
360
+
361
+ # Timer for metrics updates
362
+ self.metrics_timer = QTimer()
363
+ self.metrics_timer.timeout.connect(self.update_metrics)
364
+ self.metrics_timer.start(1000) # Update every second
365
+
366
+ def update_neural_network(self):
367
+ self.nn_widget.update_activations()
368
+
369
+ def update_buttons(self):
370
+ # Randomly activate buttons to simulate gameplay
371
+ buttons = ['U', 'D', 'L', 'R', 'A', 'B']
372
+ for button in buttons:
373
+ if random.random() < 0.3: # 30% chance to activate each button
374
+ self.control_widget.activate_button(button, True)
375
+ else:
376
+ self.control_widget.activate_button(button, False)
377
+
378
+ def update_metrics(self):
379
+ # Simulate metric updates
380
+ self.max_distance += random.randint(1, 10)
381
+ self.best_fitness += random.randint(0, 5)
382
+
383
+ # Update UI (in a real app, you'd update the actual metric widgets)
384
+ # For now, we'll just store the updated values
385
+
386
+ def main():
387
+ app = QApplication(sys.argv)
388
+
389
+ # Set application-wide font
390
+ font = QFont("Courier New", 10)
391
+ app.setFont(font)
392
+
393
+ window = MarioAIApp()
394
+ window.show()
395
+
396
+ sys.exit(app.exec_())
397
+
398
+ if __name__ == "__main__":
399
+ main()
mario_ai_app_2.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import random
3
+ import numpy as np
4
+ from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
5
+ QHBoxLayout, QLabel, QFrame, QGridLayout,
6
+ QPushButton, QProgressBar)
7
+ from PyQt5.QtCore import QTimer, Qt, QThread, pyqtSignal
8
+ from PyQt5.QtGui import QFont, QPainter, QColor, QPen, QBrush, QPixmap, QImage
9
+ import gym_super_mario_bros
10
+ from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
11
+ from nes_py.wrappers import JoypadSpace
12
+ import cv2
13
+ from collections import deque
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.optim as optim
17
+ import numpy as np
18
+
19
+ class GeneticNetwork(nn.Module):
20
+ def __init__(self, input_size, hidden_size, output_size):
21
+ super(GeneticNetwork, self).__init__()
22
+ self.fc1 = nn.Linear(input_size, hidden_size)
23
+ self.fc2 = nn.Linear(hidden_size, output_size)
24
+ self.relu = nn.ReLU()
25
+
26
+ def forward(self, x):
27
+ x = self.relu(self.fc1(x))
28
+ x = self.fc2(x)
29
+ return x
30
+
31
+ class MarioAIWorker(QThread):
32
+ update_signal = pyqtSignal(dict)
33
+ frame_signal = pyqtSignal(np.ndarray)
34
+
35
+ def __init__(self):
36
+ super().__init__()
37
+ self.running = False
38
+ self.generation = 1
39
+ self.population_size = 50
40
+ self.current_individual = 0
41
+ self.best_fitness = 0
42
+ self.max_distance = 0
43
+ self.env = None
44
+ self.population = []
45
+ self.fitness_scores = []
46
+ self.setup_environment()
47
+ self.setup_population()
48
+
49
+ def setup_environment(self):
50
+ """Initialize the Mario environment"""
51
+ self.env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
52
+ self.env = JoypadSpace(self.env, SIMPLE_MOVEMENT)
53
+
54
+ def setup_population(self):
55
+ """Initialize the population of neural networks"""
56
+ self.population = []
57
+ self.fitness_scores = [0] * self.population_size
58
+
59
+ for i in range(self.population_size):
60
+ network = GeneticNetwork(80, 9, 6) # Match the layers from screenshot
61
+ # Initialize with random weights
62
+ for param in network.parameters():
63
+ nn.init.normal_(param, mean=0.0, std=0.1)
64
+ self.population.append(network)
65
+
66
+ def preprocess_state(self, state):
67
+ """Preprocess the game state for the neural network"""
68
+ # Convert to grayscale and resize
69
+ gray = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
70
+ resized = cv2.resize(gray, (10, 8)) # 80 pixels total
71
+ flattened = resized.flatten()
72
+ normalized = flattened / 255.0 # Normalize to [0, 1]
73
+ return normalized
74
+
75
+ def run(self):
76
+ """Main training loop"""
77
+ self.running = True
78
+ while self.running:
79
+ # Evaluate current individual
80
+ fitness, distance, frame = self.evaluate_individual(self.current_individual)
81
+ self.fitness_scores[self.current_individual] = fitness
82
+ self.max_distance = max(self.max_distance, distance)
83
+ self.best_fitness = max(self.best_fitness, fitness)
84
+
85
+ # Emit update signals
86
+ self.update_signal.emit({
87
+ 'generation': self.generation,
88
+ 'individual': f"Individual {self.current_individual + 1}",
89
+ 'best_fitness': int(self.best_fitness),
90
+ 'max_distance': int(self.max_distance),
91
+ 'current_fitness': int(fitness),
92
+ 'current_distance': int(distance)
93
+ })
94
+
95
+ if frame is not None:
96
+ self.frame_signal.emit(frame)
97
+
98
+ # Move to next individual
99
+ self.current_individual += 1
100
+ if self.current_individual >= self.population_size:
101
+ self.evolve_population()
102
+ self.current_individual = 0
103
+ self.generation += 1
104
+
105
+ # Small delay to prevent UI freezing
106
+ self.msleep(50)
107
+
108
+ def evaluate_individual(self, individual_idx):
109
+ """Evaluate one individual in the environment"""
110
+ network = self.population[individual_idx]
111
+ state = self.env.reset()
112
+ total_reward = 0
113
+ max_distance = 0
114
+ last_frame = None
115
+
116
+ for step in range(1000): # Limit steps per evaluation
117
+ # Preprocess state
118
+ processed_state = self.preprocess_state(state)
119
+ state_tensor = torch.FloatTensor(processed_state)
120
+
121
+ # Get action from neural network
122
+ with torch.no_grad():
123
+ output = network(state_tensor)
124
+ action = torch.argmax(output).item()
125
+
126
+ # Take action
127
+ state, reward, done, info = self.env.step(action)
128
+ total_reward += reward
129
+ max_distance = max(max_distance, info['x_pos'])
130
+ last_frame = state
131
+
132
+ if done:
133
+ break
134
+
135
+ # Calculate fitness (reward + distance bonus)
136
+ fitness = total_reward + (max_distance / 10)
137
+ return fitness, max_distance, last_frame
138
+
139
+ def evolve_population(self):
140
+ """Evolve the population using genetic algorithm"""
141
+ # Select parents based on fitness (tournament selection)
142
+ new_population = []
143
+
144
+ # Keep best individual
145
+ best_idx = np.argmax(self.fitness_scores)
146
+ new_population.append(self.population[best_idx])
147
+
148
+ # Create offspring through mutation and crossover
149
+ while len(new_population) < self.population_size:
150
+ # Tournament selection
151
+ parent1 = self.tournament_select()
152
+ parent2 = self.tournament_select()
153
+
154
+ # Crossover and mutation
155
+ child = self.crossover(parent1, parent2)
156
+ child = self.mutate(child)
157
+ new_population.append(child)
158
+
159
+ self.population = new_population
160
+ self.fitness_scores = [0] * self.population_size
161
+
162
+ def tournament_select(self, tournament_size=3):
163
+ """Tournament selection"""
164
+ candidates = random.sample(range(self.population_size), tournament_size)
165
+ best_candidate = max(candidates, key=lambda x: self.fitness_scores[x])
166
+ return self.population[best_candidate]
167
+
168
+ def crossover(self, parent1, parent2):
169
+ """Single-point crossover"""
170
+ child = GeneticNetwork(80, 9, 6)
171
+ child_state = child.state_dict()
172
+ parent1_state = parent1.state_dict()
173
+ parent2_state = parent2.state_dict()
174
+
175
+ for key in child_state.keys():
176
+ # Randomly choose weights from parents
177
+ mask = torch.rand_like(parent1_state[key]) > 0.5
178
+ child_state[key] = torch.where(mask, parent1_state[key], parent2_state[key])
179
+
180
+ child.load_state_dict(child_state)
181
+ return child
182
+
183
+ def mutate(self, network, mutation_rate=0.05):
184
+ """Apply random mutations"""
185
+ mutated_state = network.state_dict()
186
+
187
+ for key in mutated_state.keys():
188
+ mask = torch.rand_like(mutated_state[key]) < mutation_rate
189
+ mutation = torch.randn_like(mutated_state[key]) * 0.1
190
+ mutated_state[key] = torch.where(mask, mutated_state[key] + mutation, mutated_state[key])
191
+
192
+ network.load_state_dict(mutated_state)
193
+ return network
194
+
195
+ def stop(self):
196
+ """Stop the training thread"""
197
+ self.running = False
198
+ if self.env:
199
+ self.env.close()
200
+
201
+ class NeuralNetworkWidget(QWidget):
202
+ def __init__(self):
203
+ super().__init__()
204
+ self.setMinimumSize(400, 300)
205
+ self.layers = [80, 9, 6]
206
+ self.activations = [0.1] * sum(self.layers)
207
+ self.connection_strengths = {}
208
+
209
+ def update_activations(self, activations=None):
210
+ if activations is not None:
211
+ self.activations = activations
212
+ self.update()
213
+
214
+ def paintEvent(self, event):
215
+ painter = QPainter(self)
216
+ painter.setRenderHint(QPainter.Antialiasing)
217
+
218
+ bg_color = QColor(30, 30, 40)
219
+ neuron_color = QColor(100, 150, 255)
220
+ active_neuron_color = QColor(255, 100, 100)
221
+
222
+ painter.fillRect(self.rect(), bg_color)
223
+ width = self.width()
224
+ height = self.height()
225
+
226
+ neuron_positions = []
227
+ activation_index = 0
228
+
229
+ for layer_idx, neuron_count in enumerate(self.layers):
230
+ layer_positions = []
231
+ layer_x = (layer_idx + 1) * width / (len(self.layers) + 1)
232
+
233
+ for neuron_idx in range(neuron_count):
234
+ neuron_y = (neuron_idx + 1) * height / (neuron_count + 1)
235
+ layer_positions.append((layer_x, neuron_y))
236
+
237
+ if layer_idx < len(self.layers) - 1:
238
+ next_layer_count = self.layers[layer_idx + 1]
239
+ for next_neuron_idx in range(next_layer_count):
240
+ next_x = (layer_idx + 2) * width / (len(self.layers) + 1)
241
+ next_y = (next_neuron_idx + 1) * height / (next_layer_count + 1)
242
+
243
+ strength = random.random()
244
+ alpha = int(50 + strength * 100)
245
+ pen_color = QColor(100, 100, 150, alpha)
246
+
247
+ painter.setPen(QPen(pen_color, 1 + strength * 2))
248
+ painter.drawLine(int(layer_x), int(neuron_y),
249
+ int(next_x), int(next_y))
250
+
251
+ activation_index += 1
252
+ neuron_positions.append(layer_positions)
253
+
254
+ activation_index = 0
255
+ for layer_idx, positions in enumerate(neuron_positions):
256
+ for pos_idx, (x, y) in enumerate(positions):
257
+ activation = self.activations[activation_index]
258
+ activation_index += 1
259
+
260
+ if activation > 0.7:
261
+ color = active_neuron_color
262
+ elif activation > 0.3:
263
+ color = QColor(255, 200, 100)
264
+ else:
265
+ color = neuron_color
266
+
267
+ radius = 8 + activation * 8
268
+ painter.setBrush(QBrush(color))
269
+ painter.setPen(QPen(QColor(200, 200, 255), 2))
270
+ painter.drawEllipse(int(x - radius/2), int(y - radius/2),
271
+ int(radius), int(radius))
272
+
273
+ class ControlButtonWidget(QWidget):
274
+ def __init__(self):
275
+ super().__init__()
276
+ self.setup_ui()
277
+
278
+ def setup_ui(self):
279
+ layout = QGridLayout()
280
+ layout.setSpacing(10)
281
+ layout.setContentsMargins(10, 10, 10, 10)
282
+
283
+ buttons = [
284
+ ('U', 0, 1),
285
+ ('L', 1, 0),
286
+ ('D', 1, 1),
287
+ ('R', 1, 2),
288
+ ('A', 0, 3),
289
+ ('B', 1, 3),
290
+ ]
291
+
292
+ self.labels = {}
293
+
294
+ for text, row, col in buttons:
295
+ label = QLabel(text)
296
+ label.setAlignment(Qt.AlignCenter)
297
+ label.setStyleSheet("""
298
+ QLabel {
299
+ background-color: #2d2d2d;
300
+ color: #cccccc;
301
+ border: 2px solid #555555;
302
+ border-radius: 10px;
303
+ font-weight: bold;
304
+ font-size: 14px;
305
+ min-width: 40px;
306
+ min-height: 40px;
307
+ }
308
+ """)
309
+ label.setMinimumSize(50, 50)
310
+ layout.addWidget(label, row, col)
311
+ self.labels[text] = label
312
+
313
+ self.setLayout(layout)
314
+
315
+ def activate_button(self, button, active=True):
316
+ if button in self.labels:
317
+ if active:
318
+ self.labels[button].setStyleSheet("""
319
+ QLabel {
320
+ background-color: #ff4444;
321
+ color: white;
322
+ border: 2px solid #ff6666;
323
+ border-radius: 10px;
324
+ font-weight: bold;
325
+ font-size: 14px;
326
+ min-width: 40px;
327
+ min-height: 40px;
328
+ }
329
+ """)
330
+ else:
331
+ self.labels[button].setStyleSheet("""
332
+ QLabel {
333
+ background-color: #2d2d2d;
334
+ color: #cccccc;
335
+ border: 2px solid #555555;
336
+ border-radius: 10px;
337
+ font-weight: bold;
338
+ font-size: 14px;
339
+ min-width: 40px;
340
+ min-height: 40px;
341
+ }
342
+ """)
343
+
344
+ class MetricWidget(QWidget):
345
+ def __init__(self, title, value, unit=""):
346
+ super().__init__()
347
+ self.title = title
348
+ self.value = str(value)
349
+ self.unit = unit
350
+ self.setup_ui()
351
+
352
+ def setup_ui(self):
353
+ layout = QVBoxLayout()
354
+ layout.setSpacing(2)
355
+ layout.setContentsMargins(5, 5, 5, 5)
356
+
357
+ self.title_label = QLabel(self.title)
358
+ self.title_label.setAlignment(Qt.AlignCenter)
359
+ self.title_label.setStyleSheet("color: #888888; font-size: 10px;")
360
+
361
+ self.value_label = QLabel(f"{self.value}{self.unit}")
362
+ self.value_label.setAlignment(Qt.AlignCenter)
363
+ self.value_label.setStyleSheet("color: #ffffff; font-size: 12px; font-weight: bold;")
364
+
365
+ layout.addWidget(self.title_label)
366
+ layout.addWidget(self.value_label)
367
+
368
+ self.setLayout(layout)
369
+
370
+ def update_value(self, new_value):
371
+ self.value_label.setText(f"{new_value}{self.unit}")
372
+
373
+ class GameDisplayWidget(QWidget):
374
+ def __init__(self):
375
+ super().__init__()
376
+ self.setMinimumSize(320, 240)
377
+ self.current_frame = None
378
+
379
+ def update_frame(self, frame):
380
+ self.current_frame = frame
381
+ self.update()
382
+
383
+ def paintEvent(self, event):
384
+ if self.current_frame is not None:
385
+ painter = QPainter(self)
386
+
387
+ # Convert BGR to RGB
388
+ rgb_frame = cv2.cvtColor(self.current_frame, cv2.COLOR_BGR2RGB)
389
+
390
+ # Resize frame to fit widget
391
+ h, w = rgb_frame.shape[:2]
392
+ q_image = QImage(rgb_frame.data, w, h, QImage.Format_RGB888)
393
+ pixmap = QPixmap.fromImage(q_image)
394
+
395
+ # Scale pixmap to fit widget while maintaining aspect ratio
396
+ scaled_pixmap = pixmap.scaled(self.width(), self.height(),
397
+ Qt.KeepAspectRatio, Qt.FastTransformation)
398
+
399
+ # Center the pixmap
400
+ x = (self.width() - scaled_pixmap.width()) // 2
401
+ y = (self.height() - scaled_pixmap.height()) // 2
402
+ painter.drawPixmap(x, y, scaled_pixmap)
403
+ else:
404
+ # Show placeholder when no frame is available
405
+ painter = QPainter(self)
406
+ painter.fillRect(self.rect(), QColor(50, 50, 50))
407
+ painter.setPen(QColor(200, 200, 200))
408
+ painter.drawText(self.rect(), Qt.AlignCenter, "Game Display\n(Mario will appear here)")
409
+
410
+ class MarioAITrainer(QMainWindow):
411
+ def __init__(self):
412
+ super().__init__()
413
+ self.generation = 1
414
+ self.individual = "Individual 1"
415
+ self.best_fitness = 0
416
+ self.max_distance = 0
417
+ self.num_inputs = 80
418
+ self.trainable_params = 789
419
+ self.offspring = "10, 90"
420
+ self.lifespan = "Infinite"
421
+ self.mutation = "Static 5.0%"
422
+ self.crossover = "Roulette"
423
+ self.sbx_eta = 100.0
424
+ self.layers = "[80, 9, 6]"
425
+
426
+ self.ai_worker = MarioAIWorker()
427
+ self.setup_ui()
428
+ self.connect_signals()
429
+
430
+ def setup_ui(self):
431
+ self.setWindowTitle("MARIO 000500 - AI Learns to Play Super Mario Bros!")
432
+ self.setGeometry(100, 100, 1200, 800)
433
+
434
+ central_widget = QWidget()
435
+ self.setCentralWidget(central_widget)
436
+ main_layout = QVBoxLayout(central_widget)
437
+
438
+ # Header
439
+ header_layout = QVBoxLayout()
440
+
441
+ title_label = QLabel("MARIO 000500")
442
+ title_label.setAlignment(Qt.AlignCenter)
443
+ title_label.setStyleSheet("""
444
+ QLabel {
445
+ color: #ff4444;
446
+ font-size: 24px;
447
+ font-weight: bold;
448
+ margin: 10px;
449
+ }
450
+ """)
451
+
452
+ world_label = QLabel("WORLD 1-1")
453
+ world_label.setAlignment(Qt.AlignCenter)
454
+ world_label.setStyleSheet("""
455
+ QLabel {
456
+ color: #ffffff;
457
+ font-size: 18px;
458
+ font-weight: bold;
459
+ margin: 5px;
460
+ }
461
+ """)
462
+
463
+ time_label = QLabel("TIME 344")
464
+ time_label.setAlignment(Qt.AlignCenter)
465
+ time_label.setStyleSheet("""
466
+ QLabel {
467
+ color: #ffff44;
468
+ font-size: 16px;
469
+ font-weight: bold;
470
+ margin: 5px;
471
+ }
472
+ """)
473
+
474
+ header_layout.addWidget(title_label)
475
+ header_layout.addWidget(world_label)
476
+ header_layout.addWidget(time_label)
477
+
478
+ # Control buttons
479
+ control_buttons_layout = QHBoxLayout()
480
+ self.start_button = QPushButton("Start Training")
481
+ self.stop_button = QPushButton("Stop Training")
482
+ self.reset_button = QPushButton("Reset")
483
+
484
+ self.start_button.setStyleSheet("QPushButton { background-color: #4CAF50; color: white; font-weight: bold; }")
485
+ self.stop_button.setStyleSheet("QPushButton { background-color: #f44336; color: white; font-weight: bold; }")
486
+ self.reset_button.setStyleSheet("QPushButton { background-color: #ff9800; color: white; font-weight: bold; }")
487
+
488
+ control_buttons_layout.addWidget(self.start_button)
489
+ control_buttons_layout.addWidget(self.stop_button)
490
+ control_buttons_layout.addWidget(self.reset_button)
491
+ control_buttons_layout.addStretch()
492
+
493
+ # Main content
494
+ content_layout = QHBoxLayout()
495
+
496
+ # Left panel - Metrics and Controls
497
+ left_panel = QWidget()
498
+ left_layout = QVBoxLayout(left_panel)
499
+
500
+ # Metrics grid
501
+ metrics_grid = QGridLayout()
502
+ metrics_grid.setSpacing(10)
503
+
504
+ # First column
505
+ self.gen_widget = MetricWidget("Generation", self.generation)
506
+ self.ind_widget = MetricWidget("Individual", self.individual)
507
+ self.fit_widget = MetricWidget("Best Fitness", self.best_fitness)
508
+ self.dist_widget = MetricWidget("Max Distance", self.max_distance)
509
+ self.inputs_widget = MetricWidget("Num Inputs", self.num_inputs)
510
+ self.params_widget = MetricWidget("Trainable Params", self.trainable_params)
511
+
512
+ # Second column
513
+ self.offspring_widget = MetricWidget("Offspring", self.offspring)
514
+ self.lifespan_widget = MetricWidget("Lifespan", self.lifespan)
515
+ self.mutation_widget = MetricWidget("Mutation", self.mutation)
516
+ self.crossover_widget = MetricWidget("Crossover", self.crossover)
517
+ self.sbx_widget = MetricWidget("SBX Eta", self.sbx_eta)
518
+ self.layers_widget = MetricWidget("Layers", self.layers)
519
+
520
+ metrics_grid.addWidget(self.gen_widget, 0, 0)
521
+ metrics_grid.addWidget(self.ind_widget, 1, 0)
522
+ metrics_grid.addWidget(self.fit_widget, 2, 0)
523
+ metrics_grid.addWidget(self.dist_widget, 3, 0)
524
+ metrics_grid.addWidget(self.inputs_widget, 4, 0)
525
+ metrics_grid.addWidget(self.params_widget, 5, 0)
526
+
527
+ metrics_grid.addWidget(self.offspring_widget, 0, 1)
528
+ metrics_grid.addWidget(self.lifespan_widget, 1, 1)
529
+ metrics_grid.addWidget(self.mutation_widget, 2, 1)
530
+ metrics_grid.addWidget(self.crossover_widget, 3, 1)
531
+ metrics_grid.addWidget(self.sbx_widget, 4, 1)
532
+ metrics_grid.addWidget(self.layers_widget, 5, 1)
533
+
534
+ left_layout.addLayout(metrics_grid)
535
+
536
+ # Controller
537
+ left_layout.addWidget(QLabel("Controller:"))
538
+ self.control_widget = ControlButtonWidget()
539
+ left_layout.addWidget(self.control_widget)
540
+
541
+ # Right panel - Game and Neural Network
542
+ right_panel = QWidget()
543
+ right_layout = QVBoxLayout(right_panel)
544
+
545
+ # Game display
546
+ right_layout.addWidget(QLabel("Game Display:"))
547
+ self.game_widget = GameDisplayWidget()
548
+ right_layout.addWidget(self.game_widget)
549
+
550
+ # Neural Network
551
+ right_layout.addWidget(QLabel("Neural Network Visualization:"))
552
+ self.nn_widget = NeuralNetworkWidget()
553
+ right_layout.addWidget(self.nn_widget)
554
+
555
+ content_layout.addWidget(left_panel, 1)
556
+ content_layout.addWidget(right_panel, 2)
557
+
558
+ # Footer
559
+ footer_label = QLabel("AI Learns to Play Super Mario Bros!\nUsing a Genetic Algorithm and Neural Network, a population of AI were able to learn to play different levels of Super Mario Bros for the NES.")
560
+ footer_label.setAlignment(Qt.AlignCenter)
561
+ footer_label.setStyleSheet("""
562
+ QLabel {
563
+ color: #cccccc;
564
+ font-size: 12px;
565
+ margin: 10px;
566
+ padding: 10px;
567
+ background-color: #2a2a2a;
568
+ border-radius: 5px;
569
+ }
570
+ """)
571
+ footer_label.setWordWrap(True)
572
+
573
+ # Assemble main layout
574
+ main_layout.addLayout(header_layout)
575
+ main_layout.addLayout(control_buttons_layout)
576
+ main_layout.addLayout(content_layout)
577
+ main_layout.addWidget(footer_label)
578
+
579
+ self.setStyleSheet("""
580
+ QMainWindow, QWidget {
581
+ background-color: #1a1a1a;
582
+ color: #ffffff;
583
+ }
584
+ """)
585
+
586
+ def connect_signals(self):
587
+ """Connect signals from AI worker to UI updates"""
588
+ self.start_button.clicked.connect(self.start_training)
589
+ self.stop_button.clicked.connect(self.stop_training)
590
+ self.reset_button.clicked.connect(self.reset_training)
591
+
592
+ self.ai_worker.update_signal.connect(self.update_metrics)
593
+ self.ai_worker.frame_signal.connect(self.update_game_display)
594
+
595
+ def start_training(self):
596
+ """Start the AI training"""
597
+ self.ai_worker.start()
598
+ self.start_button.setEnabled(False)
599
+ self.stop_button.setEnabled(True)
600
+
601
+ def stop_training(self):
602
+ """Stop the AI training"""
603
+ self.ai_worker.stop()
604
+ self.ai_worker.wait()
605
+ self.start_button.setEnabled(True)
606
+ self.stop_button.setEnabled(False)
607
+
608
+ def reset_training(self):
609
+ """Reset the training"""
610
+ self.stop_training()
611
+ self.ai_worker = MarioAIWorker()
612
+ self.connect_signals()
613
+ self.generation = 1
614
+ self.best_fitness = 0
615
+ self.max_distance = 0
616
+ self.update_ui_metrics()
617
+
618
+ def update_metrics(self, data):
619
+ """Update metrics from AI worker"""
620
+ self.generation = data['generation']
621
+ self.individual = data['individual']
622
+ self.best_fitness = data['best_fitness']
623
+ self.max_distance = data['max_distance']
624
+
625
+ self.update_ui_metrics()
626
+
627
+ # Update neural network visualization with random activations
628
+ random_activations = [random.random() for _ in range(sum([80, 9, 6]))]
629
+ self.nn_widget.update_activations(random_activations)
630
+
631
+ # Update controller buttons based on random actions
632
+ buttons = ['U', 'D', 'L', 'R', 'A', 'B']
633
+ for button in buttons:
634
+ self.control_widget.activate_button(button, random.random() > 0.7)
635
+
636
+ def update_ui_metrics(self):
637
+ """Update the UI metric widgets"""
638
+ self.gen_widget.update_value(self.generation)
639
+ self.ind_widget.update_value(self.individual)
640
+ self.fit_widget.update_value(self.best_fitness)
641
+ self.dist_widget.update_value(self.max_distance)
642
+
643
+ def update_game_display(self, frame):
644
+ """Update the game display with new frame"""
645
+ self.game_widget.update_frame(frame)
646
+
647
+ def closeEvent(self, event):
648
+ """Ensure clean shutdown"""
649
+ self.stop_training()
650
+ event.accept()
651
+
652
+ def main():
653
+ app = QApplication(sys.argv)
654
+
655
+ font = QFont("Courier New", 10)
656
+ app.setFont(font)
657
+
658
+ window = MarioAITrainer()
659
+ window.show()
660
+
661
+ sys.exit(app.exec_())
662
+
663
+ if __name__ == "__main__":
664
+ main()
output.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03553c2d38c2055be02b0a484bc0f332dcddbd010f9a248b660575338661c8ef
3
+ size 55795271
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==1.26.4
2
+ torch>=1.6.0
3
+ torchvision
4
+ gym==0.23
5
+ nes-py
6
+ gym-super-mario-bros==7.2.3
7
+ opencv-python
8
+ matplotlib
9
+ PyQt5
10
+ pygame