GRPO Counting Model

A Stable Diffusion 3.5-M model fine-tuned with GRPO (Group Relative Policy Optimization), designed to generate images that contain a specified number of objects.

Model Description

  • Base Model: Stable Diffusion 3.5-M
  • Training Method: GRPO (Group Relative Policy Optimization)
  • Training Data: COCO80 dataset
  • Key Feature: Precise control over object quantities in generated images
  • Supported Range: 1-10 objects

Model Variants

This repository contains four variants of the model, each trained with a different combination of reward function (strict or relative) and timestep selection (first 50 steps or random steps):

  1. Strict First (strict_first/)

    • Uses strict reward function
    • Timestep selection: First 50 steps
    • Best for: Most accurate object counting
  2. Relative First (relative_first/)

    • Uses relative reward function
    • Timestep selection: First 50 steps
    • Best for: Balance between accuracy and image quality
  3. Strict Random (strict_random/)

    • Uses strict reward function
    • Timestep selection: Random steps
    • Best for: Diverse image generation with accurate counting
  4. Relative Random (relative_random/)

    • Uses relative reward function
    • Timestep selection: Random steps
    • Best for: Maximum diversity in generation
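
The four variants above can be summarized as a small lookup table, which is handy when scripting downloads or comparisons. This is a minimal sketch: the `Variant` class and `find_variant` helper are hypothetical names introduced for illustration, and only the folder names and descriptions come from the list above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Variant:
    folder: str     # subfolder name in the repository
    reward: str     # "strict" or "relative"
    timesteps: str  # "first" (first 50 steps) or "random"
    best_for: str   # description taken from the variant list


# Hypothetical registry mirroring the four variants listed above.
VARIANTS = {
    v.folder: v
    for v in (
        Variant("strict_first", "strict", "first",
                "Most accurate object counting"),
        Variant("relative_first", "relative", "first",
                "Balance between accuracy and image quality"),
        Variant("strict_random", "strict", "random",
                "Diverse image generation with accurate counting"),
        Variant("relative_random", "relative", "random",
                "Maximum diversity in generation"),
    )
}


def find_variant(reward: str, timesteps: str) -> Variant:
    """Return the variant trained with the given reward and timestep strategy."""
    for variant in VARIANTS.values():
        if variant.reward == reward and variant.timesteps == timesteps:
            return variant
    raise ValueError(f"No variant for reward={reward!r}, timesteps={timesteps!r}")


print(find_variant("relative", "first").folder)
```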

Model Usage

1. Download the Model

import os

from huggingface_hub import snapshot_download

# Choose one of: strict_first, relative_first, strict_random, relative_random
variant = "strict_first"

# Download only the chosen variant's folder; allow_patterns filters repository files by path
local_root = snapshot_download(
    repo_id="MiaTiancai/grpo-counting-model",
    local_dir=f"./grpo_counting_model_{variant}",  # replace with your local path
    allow_patterns=[f"{variant}/*"],
)

# The variant's files are placed in a subfolder of local_dir
model_path = os.path.join(local_root, variant)

2. Model Inference

For inference, please refer to the Flow-GRPO repository. The repository contains all necessary code and instructions for running inference with this model.

Usage Tips

  1. Prompt Format:

    • Always include a specific number (1-10) in your prompt
    • Use clear object descriptions
    • Examples: "3 cats sitting on a couch", "5 red balloons floating in the sky"
  2. Best Practices:

    • Keep numbers within the supported range (1-10)
    • Use simple and clear scene descriptions
    • Avoid overly complex compositions
    • Choose the appropriate model variant based on your needs:
      • For highest counting accuracy: use strict_first
      • For best image quality: use relative_random
      • For balanced results: use relative_first
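
The prompt guidelines above can be enforced with a small helper that rejects counts outside the supported range before any image is generated. This is only a sketch: `build_prompt` and its parameters are hypothetical names, not part of the model or the Flow-GRPO repository, and pluralizing the object name is left to the caller.

```python
SUPPORTED_MIN, SUPPORTED_MAX = 1, 10  # supported object count range from this card


def build_prompt(count: int, objects: str, scene: str = "") -> str:
    """Build a '<count> <objects> <scene>' prompt, validating the count."""
    if not SUPPORTED_MIN <= count <= SUPPORTED_MAX:
        raise ValueError(
            f"count must be between {SUPPORTED_MIN} and {SUPPORTED_MAX}, got {count}"
        )
    objects = objects.strip()
    if not objects:
        raise ValueError("objects must be a non-empty description")
    prompt = f"{count} {objects}"
    scene = scene.strip()
    if scene:
        prompt += f" {scene}"
    return prompt


print(build_prompt(3, "cats", "sitting on a couch"))
print(build_prompt(5, "red balloons", "floating in the sky"))
```

These two calls reproduce the example prompts given under Prompt Format.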

Limitations

  • Performs best on scenes with 1-10 objects, the supported range
  • May have reduced effectiveness with complex scenes
  • Results depend on prompt quality and clarity
  • Each variant has its own strengths and trade-offs

License

MIT License
