# InfiGUI-R1-3B
This repository contains the model from the [InfiGUI-R1](https://arxiv.org/abs/2504.14239) paper. The model is based on `Qwen2.5-VL-3B-Instruct` and trained with the Actor2Reasoner framework proposed in the paper, which uses reinforcement learning to improve the model's planning and reflection capabilities for GUI tasks.
## Quick Start

### Installation

First, install the required dependencies:

```bash
pip install transformers qwen-vl-utils
```
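
The example below also uses `torch`, `opencv-python` (`cv2`), `Pillow`, and `requests`; install any of these that are missing from your environment. It loads the model with `attn_implementation="flash_attention_2"`, which additionally requires the `flash-attn` package. If that package is not available, one option is to fall back to PyTorch's built-in SDPA attention when loading the model, for example:

```python
# Fallback sketch if flash-attn is not installed: load with the default SDPA attention
# (slower than FlashAttention-2, but needs no extra dependency).
import torch
from transformers import Qwen2_5_VLForConditionalGeneration

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Reallm-Labs/InfiGUI-R1-3B",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # instead of "flash_attention_2" used in the example below
    device_map="auto",
)
```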

### An Example of GUI Grounding & Trajectory Task

```python
import cv2
import json
import torch
import requests
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info, smart_resize

MAX_IMAGE_PIXELS = 5600*28*28

# Load model and processor
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Reallm-Labs/InfiGUI-R1-3B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto"
)
processor = AutoProcessor.from_pretrained("Reallm-Labs/InfiGUI-R1-3B", max_pixels=MAX_IMAGE_PIXELS, padding_side="left")

# Prepare image
img_url = "https://raw.githubusercontent.com/Reallm-Labs/InfiGUI-R1/main/images/test_img.png"
response = requests.get(img_url)
with open("test_img.png", "wb") as f:
    f.write(response.content)
image = Image.open("test_img.png")
width, height = image.size
new_height, new_width = smart_resize(height, width, max_pixels=MAX_IMAGE_PIXELS)
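# smart_resize returns the resolution that the processor will actually feed to the model
# (dimensions rounded to multiples of 28, capped at max_pixels), so the prompts below
# describe the screen as new_width x new_height and the model predicts coordinates in
# that resized space.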

# Prepare inputs
instruction = "View detailed storage space usage"

system_prompt = "You FIRST think about the reasoning process as an internal monologue and then provide the final answer.\nThe reasoning process MUST BE enclosed within <think> </think> tags."
tool_prompt = "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"mobile_use\", \"description\": \"Use a touchscreen to interact with a mobile device, and take screenshots.\\n* This is an interface to a mobile device with touchscreen. You can perform actions like clicking, typing, swiping, etc.\\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions.\\n* The screen's resolution is " + str(new_width) + "x" + str(new_height) + ".\\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\", \"parameters\": {\"properties\": {\"action\": {\"description\": \"The action to perform. The available actions are:\\n* `key`: Perform a key event on the mobile device.\\n - This supports adb's `keyevent` syntax.\\n - Examples: \\\"volume_up\\\", \\\"volume_down\\\", \\\"power\\\", \\\"camera\\\", \\\"clear\\\".\\n* `click`: Click the point on the screen with coordinate (x, y).\\n* `long_press`: Press the point on the screen with coordinate (x, y) for specified seconds.\\n* `swipe`: Swipe from the starting point with coordinate (x, y) to the end point with coordinates2 (x2, y2).\\n* `type`: Input the specified text into the activated input box.\\n* `system_button`: Press the system button.\\n* `open`: Open an app on the device.\\n* `wait`: Wait specified seconds for the change to happen.\\n* `terminate`: Terminate the current task and report its completion status.\", \"enum\": [\"key\", \"click\", \"long_press\", \"swipe\", \"type\", \"system_button\", \"open\", \"wait\", \"terminate\"], \"type\": \"string\"}, \"coordinate\": {\"description\": \"(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=click`, `action=long_press`, and `action=swipe`.\", \"type\": \"array\"}, \"coordinate2\": {\"description\": \"(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=swipe`.\", \"type\": \"array\"}, \"text\": {\"description\": \"Required only by `action=key`, `action=type`, and `action=open`.\", \"type\": \"string\"}, \"time\": {\"description\": \"The seconds to wait. Required only by `action=long_press` and `action=wait`.\", \"type\": \"number\"}, \"button\": {\"description\": \"Back means returning to the previous interface, Home means returning to the desktop, Menu means opening the application background menu, and Enter means pressing the enter. Required only by `action=system_button`\", \"enum\": [\"Back\", \"Home\", \"Menu\", \"Enter\"], \"type\": \"string\"}, \"status\": {\"description\": \"The status of the task. Required only by `action=terminate`.\", \"type\": \"string\", \"enum\": [\"success\", \"failure\"]}}, \"required\": [\"action\"], \"type\": \"object\"}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
grounding_prompt = f'''The screen's resolution is {new_width}x{new_height}.\nPoint to the UI element most relevant to "{instruction}", output its coordinates using JSON format:\n```json\n[\n {{"point_2d": [x, y], "label": "object name/description"}}\n]```'''
trajectory_prompt = f"The user query: {instruction}\nTask progress (You have done the following operation on the current device): "

# Build messages
grounding_messages = [
    {"role": "system", "content": system_prompt},
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "test_img.png"},
            {"type": "text", "text": grounding_prompt}
        ]
    }
]
trajectory_messages = [
    {"role": "system", "content": system_prompt + tool_prompt},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": trajectory_prompt},
            {"type": "image", "image": "test_img.png"}
        ],
    },
]
messages = [grounding_messages, trajectory_messages]
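# The grounding task and the trajectory task are batched as two separate conversations,
# so the single generate() call below returns one answer for each.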

# Process and generate
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to("cuda")
generated_ids = model.generate(**inputs, max_new_tokens=512)
output_text = processor.batch_decode(
    [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)
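
# The decoded outputs contain the model's <think> ... </think> reasoning (requested by the
# system prompt) followed by the final answer, so only the text after </think> is kept.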
# Visualize results
output_text = [ot.split("</think>")[-1] for ot in output_text]

grounding_output = output_text[0].replace("```json", "").replace("```", "").strip()
trajectory_output = output_text[1].replace("<tool_call>", "").replace("</tool_call>", "").strip()

try:
    grounding_output = json.loads(grounding_output)
    trajectory_output = json.loads(trajectory_output)

    grounding_coords = grounding_output[0]['point_2d']
    trajectory_coords = trajectory_output["arguments"]['coordinate'] if "coordinate" in trajectory_output["arguments"] else None

    grounding_label = grounding_output[0]['label']
    trajectory_label = json.dumps(trajectory_output["arguments"])
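
    # Both point_2d and the tool call's coordinate are expressed in the resized
    # new_width x new_height space, so they are scaled back to the original screenshot
    # resolution before drawing.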

    # Load the original image
    img = cv2.imread("test_img.png")
    if img is None:
        raise ValueError("Could not load the image")

    height, width = img.shape[:2]

    # Create copies for each visualization
    grounding_img = img.copy()
    trajectory_img = img.copy()

    # Visualize grounding coordinates
    if grounding_coords:
        x = int(grounding_coords[0] / new_width * width)
        y = int(grounding_coords[1] / new_height * height)

        cv2.circle(grounding_img, (x, y), 10, (0, 0, 255), -1)
        cv2.putText(grounding_img, grounding_label, (x+10, y-10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
        cv2.imwrite("grounding_output.png", grounding_img)
        print("Predicted coordinates:", grounding_coords)
        print("Grounding visualization saved to grounding_output.png")

    # Visualize trajectory coordinates
    if trajectory_coords:
        x = int(trajectory_coords[0] / new_width * width)
        y = int(trajectory_coords[1] / new_height * height)

        cv2.circle(trajectory_img, (x, y), 10, (0, 0, 255), -1)
        cv2.putText(trajectory_img, trajectory_label, (x+10, y-10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
        cv2.imwrite("trajectory_output.png", trajectory_img)
        print("Predicted action:", trajectory_label)
        print("Trajectory visualization saved to trajectory_output.png")

except Exception as e:
    print(f"Error: Failed to parse coordinates or process image ({e})")
```
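
When parsing succeeds, the script saves `grounding_output.png` and `trajectory_output.png` with the predicted point and action drawn on the screenshot. The two answers follow the formats requested in the prompts; the snippet below only sketches their general shape with placeholder values, not actual model output:

```python
# Shape of the parsed answers (placeholder values for illustration only):
example_grounding_output = [{"point_2d": [540, 960], "label": "example element"}]
example_trajectory_output = {"name": "mobile_use", "arguments": {"action": "click", "coordinate": [540, 960]}}
```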

For more information, please refer to our [repo](https://github.com/Reallm-Labs/InfiGUI-R1).