Spaces:
Sleeping
Sleeping
| import torch | |
| from PIL import Image | |
| import pytesseract | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
| from io import BytesIO | |
| import base64 | |
| from typing import Union | |
| import whisper # Must be `openai-whisper` installed | |
# === DEVICE SETUP ===
# Use the GPU when torch reports CUDA as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# === LOAD WHISPER MODEL FOR AUDIO TRANSCRIPTION ===
# The model is loaded once, at import time. Any failure is re-reported as a
# RuntimeError, so importing this module fails if the model cannot be loaded.
try:
    # Options: "tiny", "base", "small", "medium", "large"
    whisper_model = whisper.load_model("base")
except Exception as load_error:
    failure_message = f"β Failed to load Whisper model: {str(load_error)}"
    raise RuntimeError(failure_message)
# === LOAD BLIP FOR IMAGE CAPTIONING ===
# The processor and the captioning model are loaded once, at import time, and
# the model is moved to `device`. Any failure is re-reported as a RuntimeError,
# so importing this module fails if BLIP cannot be loaded.
try:
    blip_processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    blip_model = blip_model.to(device)
except Exception as load_error:
    failure_message = f"β Failed to load BLIP model: {str(load_error)}"
    raise RuntimeError(failure_message)
| # === TEXT EXTRACTION (OCR) === | |
def extract_text_from_image_base64(image_base64: str) -> str:
    """Extract text from a base64-encoded image using Tesseract OCR.

    Both a bare base64 payload and a data URI of the form
    ``data:<mime>;base64,<payload>`` are accepted.

    Args:
        image_base64: Base64 text of the encoded image, optionally prefixed
            with a ``data:...,`` header.

    Returns:
        The recognized text with surrounding whitespace stripped, or a
        message starting with ``"β OCR Error (base64):"`` when decoding,
        image parsing, or OCR raises an ``Exception``.
    """
    try:
        payload = image_base64
        # A data-URI header must be removed before decoding: b64decode in its
        # default (non-validating) mode silently drops ':' ';' ',' but keeps
        # the header's letters as if they were payload, corrupting the image.
        # Non-str inputs are passed through untouched, as before.
        if isinstance(payload, str) and payload.lstrip().startswith("data:"):
            _header, separator, body = payload.partition(",")
            if separator:
                payload = body
        image_data = base64.b64decode(payload)
        image = Image.open(BytesIO(image_data))
        return pytesseract.image_to_string(image).strip()
    except Exception as e:
        # NOTE(review): the "β" prefix looks like a mis-encoded symbol; it is
        # kept byte-for-byte because callers may match on it.
        return f"β OCR Error (base64): {str(e)}"
def extract_text_from_image_path(image_path: str) -> str:
    """Extract text from an image file using Tesseract OCR.

    Args:
        image_path: Filesystem path of the image to read.

    Returns:
        The recognized text with surrounding whitespace stripped, or a
        message starting with ``"β OCR Error (path):"`` when opening the
        file or running OCR raises an ``Exception``.
    """
    try:
        # The context manager closes the image's file handle as soon as OCR
        # finishes instead of leaving it open until garbage collection.
        with Image.open(image_path) as image:
            return pytesseract.image_to_string(image).strip()
    except Exception as e:
        return f"β OCR Error (path): {str(e)}"
def extract_text_from_image_bytes(image_bytes: bytes) -> str:
    """Run Tesseract OCR on an in-memory encoded image (e.g. a file upload).

    Args:
        image_bytes: Raw bytes of the encoded image.

    Returns:
        The recognized text with surrounding whitespace stripped, or a
        message starting with ``"β OCR Error (bytes):"`` when image parsing
        or OCR raises an ``Exception``.
    """
    try:
        buffer = BytesIO(image_bytes)
        picture = Image.open(buffer)
        recognized = pytesseract.image_to_string(picture)
        return recognized.strip()
    except Exception as err:
        return f"β OCR Error (bytes): {str(err)}"
def extract_text_from_image(image_base64: str) -> str:
    """Default OCR entry point; delegates to the base64 implementation.

    Args:
        image_base64: Base64 text of the encoded image.

    Returns:
        Whatever ``extract_text_from_image_base64`` returns for the input.
    """
    ocr_result = extract_text_from_image_base64(image_base64)
    return ocr_result
| # === IMAGE CAPTIONING === | |
def caption_image(image: Image.Image) -> str:
    """Produce a caption for a PIL image with the module-level BLIP model.

    Args:
        image: The image to describe; it is converted to RGB before being
            handed to the processor.

    Returns:
        The decoded caption (special tokens skipped), or a message starting
        with ``"β Captioning Error:"`` when any step raises an ``Exception``.
    """
    try:
        rgb_image = image.convert("RGB")
        encoded = blip_processor(rgb_image, return_tensors="pt")
        # Inputs must live on the same device the model was moved to.
        encoded = encoded.to(device)
        generated = blip_model.generate(**encoded)
        first_sequence = generated[0]
        return blip_processor.decode(first_sequence, skip_special_tokens=True)
    except Exception as err:
        return f"β Captioning Error: {str(err)}"
def caption_image_path(image_path: str) -> str:
    """Generate a caption for the image stored at a file path.

    Args:
        image_path: Filesystem path of the image to caption.

    Returns:
        The result of ``caption_image`` for the opened image, or a message
        starting with ``"β Captioning Error (path):"`` when opening the file
        raises an ``Exception``.
    """
    try:
        # The context manager closes the image's file handle once captioning
        # has finished instead of leaving it open until garbage collection.
        with Image.open(image_path) as image:
            return caption_image(image)
    except Exception as e:
        return f"β Captioning Error (path): {str(e)}"
def caption_image_bytes(image_bytes: bytes) -> str:
    """Generate a caption for an in-memory encoded image.

    Args:
        image_bytes: Raw bytes of the encoded image.

    Returns:
        The result of ``caption_image`` for the parsed image, or a message
        starting with ``"β Captioning Error (bytes):"`` when parsing the
        bytes raises an ``Exception``.
    """
    try:
        buffer = BytesIO(image_bytes)
        picture = Image.open(buffer)
        caption = caption_image(picture)
        return caption
    except Exception as err:
        return f"β Captioning Error (bytes): {str(err)}"
def describe_image(input_data: Union[str, bytes]) -> str:
    """Unified captioning API — accepts either a file path or raw bytes.

    Args:
        input_data: ``bytes`` are captioned via ``caption_image_bytes``;
            a ``str`` is treated as a path and captioned via
            ``caption_image_path``.

    Returns:
        The caption (or the error message) produced by the selected helper,
        a fixed "unsupported input type" message for any other type, or a
        message starting with ``"β Description Error:"`` if dispatching
        itself raises an ``Exception``.
    """
    try:
        # Reject anything that is neither bytes nor str up front.
        if not isinstance(input_data, (bytes, str)):
            return "β Unsupported input type for describe_image"
        if isinstance(input_data, bytes):
            return caption_image_bytes(input_data)
        return caption_image_path(input_data)
    except Exception as err:
        return f"β Description Error: {str(err)}"
| # === AUDIO TRANSCRIPTION === | |
def transcribe_audio_bytes(audio_bytes: bytes) -> str:
    """Transcribe raw audio bytes using the module-level Whisper model.

    The bytes are written to a uniquely named temporary ``.wav`` file, the
    file's path is passed to ``whisper_model.transcribe``, and the file is
    removed afterwards whether or not transcription succeeded.

    Args:
        audio_bytes: Raw contents of the audio file to transcribe.

    Returns:
        The ``"text"`` entry of the transcription result with surrounding
        whitespace stripped (empty string if the entry is missing), or a
        message starting with ``"β Transcription Error:"`` when writing the
        file or transcribing raises an ``Exception``.
    """
    import os
    import tempfile

    temp_path = None
    try:
        # A unique file per call: a fixed shared path would let concurrent
        # calls overwrite each other's audio and is predictable to others.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_path = tmp.name
            tmp.write(audio_bytes)
        # The file is closed before transcription so the path can be
        # reopened by name.
        result = whisper_model.transcribe(temp_path)
        return result.get("text", "").strip()
    except Exception as e:
        return f"β Transcription Error: {str(e)}"
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort cleanup: a failed delete must not replace the
                # transcription result or error message being returned.
                pass
def transcribe_audio_path(audio_path: str) -> str:
    """Transcribe the audio file at a path with the module-level Whisper model.

    Args:
        audio_path: Path of the audio file, passed to
            ``whisper_model.transcribe`` as-is.

    Returns:
        The ``"text"`` entry of the transcription result with surrounding
        whitespace stripped (empty string if the entry is missing), or a
        message starting with ``"β Transcription Error (path):"`` when
        transcription raises an ``Exception``.
    """
    try:
        transcription = whisper_model.transcribe(audio_path)
        raw_text = transcription.get("text", "")
        return raw_text.strip()
    except Exception as err:
        return f"β Transcription Error (path): {str(err)}"