File size: 4,513 Bytes
850b0e4
 
c8c7b71
766dc1d
850b0e4
 
 
 
 
ee2e0b7
 
c8c7b71
ee2e0b7
 
151dc74
850b0e4
c8c7b71
850b0e4
 
 
 
151dc74
 
ee2e0b7
c8c7b71
 
 
 
 
151dc74
9dd6110
 
c8c7b71
9dd6110
c8c7b71
 
 
72414fe
c8c7b71
9dd6110
c8c7b71
 
 
9dd6110
 
 
a28ce33
cfb971e
c8c7b71
9dd6110
 
1ec2ec6
766dc1d
101729b
1ec2ec6
 
 
7cc62e7
850b0e4
 
 
1ec2ec6
7cc62e7
850b0e4
 
c8c7b71
850b0e4
c8c7b71
ee2e0b7
151dc74
ee2e0b7
c8c7b71
 
850b0e4
2c782ec
0150d1c
 
 
c8c7b71
 
 
 
0150d1c
 
c8c7b71
0150d1c
 
ee2e0b7
 
 
c8c7b71
9f829c5
5235656
226c5b4
850b0e4
c8c7b71
040c712
c8c7b71
 
 
d59d1e6
1ec2ec6
 
9f829c5
 
 
1ec2ec6
 
c8c7b71
 
141b1fb
d59d1e6
6a39a28
6a94190
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import torch
import uuid
import spaces
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

import os
import uvicorn
from pathlib import Path

# Build the device handle first, then load the MarioGPT language model once at
# import time and move it onto that CUDA device.
device = torch.device('cuda')
mario_lm = MarioLM().to(device)

# Directory with the tile images that convert_level_to_png uses for previews.
TILE_DIR = "data/tiles"

# Allow Gradio to serve files from ./static (generated demo-*.html pages, mario.jar).
gr.set_static_paths(paths=[Path("static").absolute()])

# Standalone FastAPI application object.
app = FastAPI()

def make_html_file(generated_level):
    """Write a self-contained HTML page that plays ``generated_level`` via CheerpJ.

    The page loads the CheerpJ runtime and registers the level rows as the
    virtual file ``/str/mylevel.txt``. It then runs ``static/mario.jar``,
    fetched through Gradio's ``/gradio_api/file=`` route.

    Args:
        generated_level: Sampled level as returned by ``mario_lm.sample``; it is
            turned into text rows with ``view_level``.

    Returns:
        The bare file name (not a path) of the page written into ``static/``.
    """
    level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
    # The rows are injected into a JavaScript template literal inside <script>.
    # Escape everything that could end the literal (`), start an interpolation
    # (${), be read as an escape sequence (\) or close the <script> element (</).
    # All other characters pass through unchanged.
    level_text = (
        level_text.replace("\\", "\\\\")
        .replace("`", "\\`")
        .replace("${", "\\${")
        .replace("</", "<\\/")
    )
    # uuid4 is purely random; uuid1 would embed the host's MAC address and a
    # timestamp in a publicly served file name.
    html_filename = f"demo-{uuid.uuid4()}.html"
    page = f'''<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="utf-8">
    <title>Mario Game</title>
    <script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>

<body>
</body>
<script>
    cheerpjInit().then(function () {{
        cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
    }});
    cheerpjCreateDisplay(512, 500);
    cheerpjRunJar("/app/gradio_api/file=static/mario.jar");
</script>
</html>'''
    (Path("static") / html_filename).write_text(page, encoding="utf-8")
    return html_filename  # Return just the filename

@spaces.GPU
def generate(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399, prompt="", progress=gr.Progress(track_tqdm=True)):
    """Sample a Mario level and return its preview image and a playable embed.

    Args:
        pipes, enemies, blocks, elevation: Amount words from the "Compose prompt"
            radios; used to build the prompt when ``prompt`` is empty.
        temperature: Sampling temperature forwarded to ``mario_lm.sample``.
        level_size: Number of sampling steps (``num_steps``).
        prompt: Free-text prompt. If it is empty, None or whitespace-only, a
            prompt is composed from the four amount arguments instead.
        progress: Not read in the body; Gradio inspects this parameter and,
            with ``track_tqdm=True``, mirrors the tqdm bar shown while sampling.

    Returns:
        ``[img, html]``: the first image from ``convert_level_to_png`` and an
        HTML snippet embedding the playable page in an iframe.
    """
    # A blank or whitespace-only prompt means "use the composed prompt".
    if not prompt or not prompt.strip():
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    print(f"Using prompt: {prompt}")
    print(f"Using temperature: {temperature}")
    generated_level = mario_lm.sample(
        prompts=[prompt],
        # Coerce like temperature below, in case the slider delivers a float.
        num_steps=int(level_size),
        temperature=float(temperature),
        use_tqdm=True,
    )
    filename = make_html_file(generated_level)
    img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]

    gradio_html = f'''<div>
        <iframe width=512 height=512 style="margin: 0 auto" src="/gradio_api/file=static/{filename}"></iframe>
        <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
    </div>'''
    return [img, gradio_html]

with gr.Blocks().queue() as demo:
    # Header: title plus links to the repository and the paper.
    gr.Markdown('''# MarioGPT
### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
    ''')

    # Shared option list for the "how many" radios that allow "no".
    _amount_choices = ["no", "little", "some", "many"]

    # A level can be described either by picking amounts or by typing a prompt.
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(list(_amount_choices), label="How many pipes?", value="some")
                enemies = gr.Radio(list(_amount_choices), label="How many enemies?", value="some")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], label="How many blocks?", value="some")
                elevation = gr.Radio(["low", "high"], label="Elevation?", value="low")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(
                label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'",
                value="",
            )

    # Sampling knobs, collapsed by default.
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=2.0,
            label="temperature: Increase these for more diverse, but lower quality, generations",
        )
        level_size = gr.Slider(minimum=100, maximum=2799, step=1, value=1399, label="level_size")

    btn = gr.Button("Generate level")

    # Output area: playable iframe on the left, static preview image on the right.
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()
        level_image = gr.Image()

    btn.click(
        fn=generate,
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt],
        outputs=[level_image, level_play],
    )

    # Preset amount combinations (pipes, enemies, blocks, elevation); outputs are
    # produced by `generate` and cached (cache_examples=True).
    _example_rows = [
        ["many", "many", "some", "high"],
        ["no", "some", "many", "high"],
        ["many", "many", "little", "low"],
        ["no", "no", "many", "high"],
    ]
    gr.Examples(
        fn=generate,
        examples=_example_rows,
        inputs=[pipes, enemies, blocks, elevation],
        outputs=[level_image, level_play],
        cache_examples=True,
    )
if __name__ == "__main__":
    # Serve the UI with Gradio's own server; generated pages and mario.jar are
    # reachable through /gradio_api/file=static/... thanks to set_static_paths.
    #
    # launch() blocks the main thread while the server runs when executed as a
    # script. Any code placed after it (the former FastAPI mount +
    # uvicorn.run(port=7860)) therefore never ran during normal operation. Once
    # the server had been shut down, it would have started a second server.
    # To host inside FastAPI instead, use gr.mount_gradio_app(app, demo, "/")
    # together with uvicorn.run(...) *in place of* demo.launch(), never both.
    demo.launch()