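# MarioGPT Spaces demo: generate Super Mario Bros levels from text prompts
# and make them playable in the browser via a CheerpJ-run mario.jar.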
import gradio as gr
import torch
import uuid
import spaces
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import os
import uvicorn
from pathlib import Path
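# Load the pretrained MarioGPT model and move it to the GPU.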
mario_lm = MarioLM()
device = torch.device('cuda')
mario_lm = mario_lm.to(device)
TILE_DIR = "data/tiles"
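# Expose ./static so Gradio can serve the generated level pages and mario.jar.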
gr.set_static_paths(paths=[Path("static").absolute()])
app = FastAPI()
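# Write a standalone HTML page that feeds the generated level to the
# CheerpJ-based Mario player (mario.jar) and return the page's filename.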
def make_html_file(generated_level):
    level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
    unique_id = uuid.uuid1()
    html_filename = f"demo-{unique_id}.html"
    with open(Path("static") / html_filename, 'w', encoding='utf-8') as f:
        f.write(f'''<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="utf-8">
    <title>Mario Game</title>
    <script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>
<body>
</body>
<script>
    cheerpjInit().then(function () {{
        cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
    }});
    cheerpjCreateDisplay(512, 500);
    cheerpjRunJar("/app/gradio_api/file=static/mario.jar");
</script>
</html>''')
    return html_filename  # Return just the filename
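# ZeroGPU entry point: sample a level for the prompt, then return a PNG
# preview and an <iframe> that embeds the playable level.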
@spaces.GPU
def generate(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399, prompt="", progress=gr.Progress(track_tqdm=True)):
    if prompt == "":
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    print(f"Using prompt: {prompt}")
    print(f"Using temperature: {temperature}")
    prompts = [prompt]
    generated_level = mario_lm.sample(
        prompts=prompts,
        num_steps=level_size,
        temperature=float(temperature),
        use_tqdm=True
    )
    filename = make_html_file(generated_level)
    img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
    gradio_html = f'''<div>
    <iframe width=512 height=512 style="margin: 0 auto" src="/gradio_api/file=static/{filename}"></iframe>
    <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
</div>'''
    return [img, gradio_html]
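# Gradio UI: two prompt tabs (radio-button composer vs. free text), advanced
# sampling settings, and paired image / playable-level outputs.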
with gr.Blocks().queue() as demo:
    gr.Markdown('''# MarioGPT
### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
''')
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], value="some", label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], value="some", label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], value="some", label="How many blocks?")
                elevation = gr.Radio(["low", "high"], value="low", label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(value=2.0, minimum=0.1, maximum=2.0, step=0.1, label="temperature: Increase this for more diverse, but lower quality, generations")
        level_size = gr.Slider(value=1399, minimum=100, maximum=2799, step=1, label="level_size")
    btn = gr.Button("Generate level")
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()
            level_image = gr.Image()
    btn.click(fn=generate, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=[level_image, level_play])
    gr.Examples(
        examples=[
            ["many", "many", "some", "high"],
            ["no", "some", "many", "high"],
            ["many", "many", "little", "low"],
            ["no", "no", "many", "high"],
        ],
        inputs=[pipes, enemies, blocks, elevation],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=True,
    )
demo.launch()
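# Alternative serving path: mount the demo on a FastAPI app and run it with
# uvicorn. Note that demo.launch() above blocks, so these lines are only
# reached after the Gradio server shuts down.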
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
app = gr.mount_gradio_app(app, demo, "/")
uvicorn.run(app, host="0.0.0.0", port=7860)