|
|
import os
import traceback

import gradio as gr
import torch
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from huggingface_hub import InferenceClient, list_models
from PIL import Image
|
|
|
|
|
|
|
|
# API token handed to every InferenceClient; read from the "ohgoddamn"
# environment variable and None when that variable is not set.
HF_TOKEN = os.getenv("ohgoddamn")

# Initial state of the "Use Local Diffusers Instead of API" checkbox.
USE_LOCAL = False

# InferenceClient instances keyed by model id, filled lazily by generate().
client_cache = {}

# Model ids offered in the dropdown and probed by check_model_access().
all_models = [
    "stabilityai/stable-diffusion-xl-base-1.0",
    "runwayml/stable-diffusion-v1-5",
    "Uthar/John6666_epicrealism-xl-v8kiss-sdxl",
]
|
|
|
|
|
|
|
|
def load_local_pipeline(model_id):
    """Load a diffusers pipeline for *model_id* and move it to the available device.

    Args:
        model_id: Model identifier passed straight to ``from_pretrained``.

    Returns:
        The loaded pipeline, placed on ``"cuda"`` when CUDA is available and
        on ``"cpu"`` otherwise.
    """
    has_cuda = torch.cuda.is_available()
    # DiffusionPipeline picks the concrete pipeline class from the checkpoint,
    # so non-SDXL entries in the model list (e.g. SD 1.5) load as well instead
    # of being forced through StableDiffusionXLPipeline.
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        # Half precision only on GPU; fall back to float32 on CPU.
        torch_dtype=torch.float16 if has_cuda else torch.float32,
    )
    # NOTE: the pipeline is rebuilt on every call; nothing is cached here.
    return pipe.to("cuda" if has_cuda else "cpu")
|
|
|
|
|
|
|
|
def generate(model_id, prompt, use_local):
    """Generate one image for *prompt* with *model_id* and report what happened.

    Args:
        model_id: Model identifier to run.
        prompt: Text prompt forwarded unchanged to the pipeline or client.
        use_local: When truthy, run ``load_local_pipeline(model_id)``;
            otherwise call ``text_to_image`` on an ``InferenceClient``.

    Returns:
        A ``(image, log)`` tuple. On success ``image`` is the generated image
        and ``log`` ends with a success marker. On any exception ``image`` is
        ``None`` and ``log`` contains the full traceback; the exception is
        not re-raised so the UI can display it.
    """
    debug_log = ""
    try:
        if use_local:
            debug_log += f"🔧 Using local pipeline for: {model_id}\n"
            pipe = load_local_pipeline(model_id)
            image = pipe(prompt).images[0]
        else:
            debug_log += f"🌐 Using InferenceClient for: {model_id}\n"
            # Reuse one client per model; the dict is only mutated, never
            # rebound, so no ``global`` declaration is needed.
            if model_id not in client_cache:
                client_cache[model_id] = InferenceClient(model=model_id, token=HF_TOKEN)
            image = client_cache[model_id].text_to_image(prompt)
        return image, debug_log + "\n✅ Success."
    except Exception:
        # Top-level UI boundary: surface the traceback in the debug textbox.
        error_msg = traceback.format_exc()
        return None, debug_log + f"\n❌ Error:\n{error_msg}"
|
|
|
|
|
|
|
|
def check_model_access(models):
    """Probe each model through the inference API and summarise the outcome.

    For every entry a fresh ``InferenceClient`` is created and asked to render
    ``"test prompt"``; the generated image is discarded.

    Args:
        models: Iterable of model identifiers to probe.

    Returns:
        A string with one newline-terminated line per model, either
        ``"✅ <model> is working."`` or ``"❌ <model> failed: <reason>"``,
        where the reason is the first non-blank line of the exception message
        (or the exception class name when the message is empty). Returns
        ``""`` when *models* is empty.
    """
    report = []
    for model in models:
        try:
            client = InferenceClient(model=model, token=HF_TOKEN)
            client.text_to_image("test prompt")
        except Exception as e:
            # An exception with an empty message yields no lines at all;
            # fall back to the class name instead of indexing into [].
            reason = next(
                (line for line in str(e).splitlines() if line.strip()),
                type(e).__name__,
            )
            report.append(f"❌ {model} failed: {reason}\n")
        else:
            report.append(f"✅ {model} is working.\n")
    return "".join(report)
|
|
|
|
|
|
|
|
# Build the UI: model/prompt inputs, image + debug outputs, and a collapsed
# self-check panel that probes every entry in all_models.
with gr.Blocks() as demo:
    gr.Markdown("# 🧪 Stable Diffusion API Tester")
    with gr.Row():
        model = gr.Dropdown(choices=all_models, label="Model", value=all_models[0])
        use_local = gr.Checkbox(label="Use Local Diffusers Instead of API", value=USE_LOCAL)
    prompt = gr.Textbox(label="Prompt", value="a cyberpunk cat playing guitar in Tokyo")
    generate_btn = gr.Button("Generate")
    image_out = gr.Image(label="Generated Image")
    debug_out = gr.Textbox(label="Debug Output", lines=10)
    with gr.Accordion("Self-Check: API Model Access", open=False):
        check_btn = gr.Button("Check All Models")
        check_results = gr.Textbox(label="Model API Status", lines=10)

    # Event wiring must happen inside the Blocks context.
    generate_btn.click(generate, inputs=[model, prompt, use_local], outputs=[image_out, debug_out])
    # The model list is passed to the handler through a State component.
    check_btn.click(check_model_access, inputs=[gr.State(all_models)], outputs=[check_results])

demo.launch()