import asyncio
from typing import Any, Callable, Dict, List, Optional

import httpx

BATCH_SIZE = 1000    # hard upper bound on SMILES strings per /predict request
TIMEOUT_S = 600      # generous per-request timeout: model inference can be slow
MAX_RETRIES = 3      # attempts per HTTP call before giving up
RETRY_DELAY = 1      # base delay in seconds between retry attempts (linear backoff)


def chunks(xs: List[str], n: int):
    """Yield successive slices of *xs*, each of length at most *n*."""
    for i in range(0, len(xs), n):
        yield xs[i:i + n]


async def _with_retries(make_request: Callable[[], Any]) -> Dict[str, Any]:
    """Run *make_request* (a zero-arg callable returning an awaitable
    ``httpx.Response``), retrying on HTTP-status and transport errors.

    Waits ``RETRY_DELAY * attempt`` seconds between attempts (linear backoff).
    Returns the parsed JSON body on success; re-raises the last error after
    ``MAX_RETRIES`` failed attempts.

    NOTE(review): the original code looped over ``range(MAX_RETRIES)`` but
    returned or raised on the first iteration, so no retry ever happened and
    ``RETRY_DELAY`` was unused — this helper implements the evident intent.
    """
    last_exc: Optional[Exception] = None
    for attempt in range(MAX_RETRIES):
        try:
            r = await make_request()
            r.raise_for_status()
            return r.json()
        except (httpx.HTTPStatusError, httpx.TransportError) as exc:
            last_exc = exc
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(RETRY_DELAY * (attempt + 1))
    # All attempts failed; surface the most recent error to the caller.
    assert last_exc is not None
    raise last_exc


async def fetch_metadata(client: httpx.AsyncClient, base_url: str) -> Dict[str, Any]:
    """GET ``{base_url}/metadata`` and return its JSON body, with retries.

    Uses a short 30 s timeout: metadata should be cheap to serve.
    """
    return await _with_retries(
        lambda: client.get(f"{base_url}/metadata", timeout=30)
    )


async def call_predict(
    client: httpx.AsyncClient, base_url: str, smiles_batch: List[str]
) -> Dict[str, Any]:
    """POST one batch of SMILES strings to ``{base_url}/predict``, with retries.

    The request body is ``{"smiles": [...]}``; returns the parsed JSON response.
    """
    return await _with_retries(
        lambda: client.post(
            f"{base_url}/predict",
            json={"smiles": smiles_batch},
            timeout=TIMEOUT_S,
        )
    )


async def evaluate_model(hf_space_tag: str, smiles_list: List[str]) -> Dict[str, Any]:
    """Evaluate a Hugging Face Space model over *smiles_list*.

    Derives the Space URL from *hf_space_tag* (``username/space-name`` →
    ``https://username-space-name.hf.space``), fetches the model's metadata to
    learn its batch-size limit, then submits the SMILES in batches.

    Returns ``{"results": [...], "metadata": {...}}`` where each result entry
    is ``{"smiles", "raw_predictions"}`` plus an ``"error"`` key when the
    service returned no prediction for that SMILES.

    Raises the underlying ``httpx`` error if any call still fails after
    ``MAX_RETRIES`` attempts.
    """
    # Convert username/space-name to username-space-name.hf.space
    base_url = f"https://{hf_space_tag.replace('/', '-').replace('_', '-').lower()}.hf.space"
    results: List[Dict[str, Any]] = []
    async with httpx.AsyncClient() as client:
        meta = await fetch_metadata(client, base_url)
        # Respect the service's advertised limit, capped by our own BATCH_SIZE.
        max_bs = min(meta.get("max_batch_size", BATCH_SIZE), BATCH_SIZE)
        for batch in chunks(smiles_list, max_bs):
            resp = await call_predict(client, base_url, batch)
            # assumes the response maps each input SMILES to its prediction
            # under the "predictions" key — TODO confirm against the service API
            predictions_dict = resp["predictions"]
            for smiles in batch:
                if smiles in predictions_dict:
                    results.append(
                        {"smiles": smiles, "raw_predictions": predictions_dict[smiles]}
                    )
                else:
                    # Keep one entry per input even when the service dropped it,
                    # so results stays aligned with smiles_list.
                    results.append(
                        {
                            "smiles": smiles,
                            "raw_predictions": {},
                            "error": "No prediction found",
                        }
                    )
    return {"results": results, "metadata": meta}