---
base_model:
- Qwen/Qwen2-VL-7B-Instruct
datasets:
- JosephZ/vg150_train_sgg_prompt
library_name: transformers
license: apache-2.0
metrics:
- recall
tags:
- image
- scene-graph
- scene-graph-generation
pipeline_tag: image-text-to-text
---

# Model Description

An end-to-end multimodal LLM for Scene Graph Generation (SGG), introduced in [Compile Scene Graphs with Reinforcement Learning](https://huggingface.co/papers/2504.13617).

# R1-SGG: Compile Scene Graphs with Reinforcement Learning

## **Structured Visual Reasoning with Multimodal LLMs and Reinforcement Learning**

[![Paper](https://img.shields.io/badge/arXiv-2504.13617-b31b1b.svg)](https://arxiv.org/abs/2504.13617)
[![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](LICENSE)
[![Hugging Face](https://img.shields.io/badge/HuggingFace-Demo-orange?logo=huggingface)](https://huggingface.co/spaces/JosephZ/R1-SGG)

---

## 🚀 Update

- ✅ ![Hugging Face](https://img.shields.io/badge/HuggingFace-Model-orange?logo=huggingface) [R1-SGG-7B](https://huggingface.co/JosephZ/R1-SGG-7B), [R1-SGG-Zero-7B](https://huggingface.co/JosephZ/R1-SGG-Zero-7B)
- ✅ Support for the [PSG](https://github.com/Jingkang50/OpenPSG) dataset (bbox format only, not panoptic)
- ✅ Updated loss implementation
- ✅ Always use `custom_per_device_train_batch_size` instead of `per_device_train_batch_size` for faster sampling under gradient accumulation
- ⚠️ The current loss implementation might still be affected by gradient accumulation: [trl issue #3021](https://github.com/huggingface/trl/issues/3021)

---

## 🛠️ Setup Environment

```bash
bash install.sh
```

Main dependencies:

```bash
- torch == 2.5.0 or 2.5.1 (cu124, optional)
- transformers (supports Qwen2VL, Qwen2.5VL)
- trl
- vLLM
```

---

## 📚 Dataset

Load the preprocessed datasets via:

```python
from datasets import load_dataset

db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]
db_val = load_dataset("JosephZ/vg150_val_sgg_prompt")["train"]
```

or, for PSG:

```python
db_train = load_dataset("JosephZ/psg_train_sg")["train"]  # keys: image_id, image, objects, relationships
db_val = load_dataset("JosephZ/psg_test_sg")["train"]
```

We transformed VG150 into HuggingFace Datasets format with keys:
- `image_id`
- `image`
- `prompt_open`
- `prompt_close`
- `objects`
- `relationships`

---

## 🔥 Supported Models

- [x] Qwen/Qwen2-VL-2B-Instruct
- [x] Qwen/Qwen2-VL-7B-Instruct
- [x] Qwen/Qwen2.5-VL-3B-Instruct
- [x] Qwen/Qwen2.5-VL-7B-Instruct

---

## 🏋️‍♂️ Training

### Training with Supervised Fine-Tuning (SFT)

For **SLURM users**:
```bash
sbatch scripts/sft/7B_sgg.sh
```

For **local machines**:
```bash
bash scripts/sft_local/7B_sgg.sh
```

⏱️ Approximate training time:
- 2B models: ~4 hours (4×A100 SXM4 GPUs)
- 7B models: ~10 hours (4×A100 SXM4 GPUs)

---

### Training with Reinforcement Learning (GRPO)

**Update (11/05/2025):** to use "Hard Recall", pass:
```
--reward_funcs format_reward edge_hard_reward
```

For **A100 GPUs**:
```bash
sbatch scripts/grpo/train_a100_2B.sh
```
(12 hours on 16×A100 GPUs)

For **GH200 GPUs**:
```bash
sbatch scripts/grpo/train_gh200.sh
```
(16 hours on 16×GH200 GPUs)

For clusters with many RTX_3090/4090 GPUs:
```bash
sbatch scripts/grpo/train_fused.sh
```

- Training 7B models on 24 GB cards is possible with ZeRO-3, but slow due to communication bottlenecks.
- (Fun fact: training with 120×RTX_4090 GPUs is crazy, but severely limited by communication latency.)

💡 **Recommended learning rate**: `6e-7`.
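After training (or with the released checkpoints listed in the Update section), you can sanity-check a model on a single image directly with `transformers` before running the batched scripts in the next section. The snippet below is only a minimal sketch, not the project's inference pipeline: the checkpoint name is the released `JosephZ/R1-SGG-7B`, the prompt is taken from the dataset's `prompt_open` field, and the decoding settings are illustrative.

```python
# Minimal single-image sanity check (sketch; the official pipeline is scripts/inference/run_sgg_inference.sh).
from datasets import load_dataset
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_id = "JosephZ/R1-SGG-7B"  # released checkpoint listed in the Update section
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

# Take one validation sample; `image` and `prompt_open` are dataset fields described above.
sample = load_dataset("JosephZ/vg150_val_sgg_prompt")["train"][0]
messages = [{"role": "user",
             "content": [{"type": "image"},
                         {"type": "text", "text": sample["prompt_open"]}]}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[sample["image"]], return_tensors="pt").to(model.device)

out = model.generate(**inputs, max_new_tokens=1024, do_sample=False)  # illustrative settings
generated = out[:, inputs["input_ids"].shape[1]:]  # drop the prompt tokens before decoding
print(processor.batch_decode(generated, skip_special_tokens=True)[0])
```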
---

## 🧪 Inference and Evaluation

### Inference with SFT-trained models:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR
```

For models trained **with predefined categories**, add `true`:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR true
```

### Inference with GRPO-trained models:
Pass `false` or `true` as the fourth argument (predefined categories, as above), followed by `true`:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR false/true true
```

### Evaluation:
```bash
DATASET_TYPE=vg  # or psg
python src/sgg_gather_preds.py $DATASET_TYPE $OUTPUT_DIR sgg_pred_results.json
python src/vg150_eval.py $DATASET sgg_pred_results.json
```

---

## 🤝 Acknowledgement

The `GRPOTrainer` used in this project is based on [trl's GRPOTrainer](https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py), extended to support multimodal inputs.

---

## 📖 Citation

If you find this work helpful, please cite:

```bibtex
@article{chen2025compile,
  title={Compile Scene Graphs with Reinforcement Learning},
  author={Chen, Zuyao and Wu, Jinlin and Lei, Zhen and Pollefeys, Marc and Chen, Chang Wen},
  journal={arXiv preprint arXiv:2504.13617},
  year={2025}
}
```

---

# ✨ Happy Compiling!
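---

*Appendix.* As a toy illustration of the Recall metric listed in the model card metadata, the snippet below counts how many ground-truth triplets appear among the top-k predictions. It is a simplified sketch, not the official evaluator in `src/vg150_eval.py`: triplets are compared only as (subject, predicate, object) label tuples, whereas standard VG150 evaluation typically also requires the predicted boxes to match ground-truth boxes.

```python
# Toy relationship recall (illustration only; use src/vg150_eval.py for real evaluation).
def triplet_recall(pred_triplets, gt_triplets, k=50):
    """Fraction of ground-truth triplets recovered among the top-k predictions.

    pred_triplets: list of (subject, predicate, object) tuples, sorted by confidence.
    gt_triplets:   list of ground-truth (subject, predicate, object) tuples.
    """
    top_k = set(pred_triplets[:k])
    hits = sum(1 for t in gt_triplets if t in top_k)
    return hits / max(len(gt_triplets), 1)


# Example with toy data:
pred = [("man", "riding", "horse"), ("man", "wearing", "hat"), ("horse", "on", "grass")]
gt = [("man", "riding", "horse"), ("horse", "on", "grass")]
print(triplet_recall(pred, gt, k=50))  # -> 1.0
```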