# Spaces status: Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from janus.models import VLChatProcessor | |
| from PIL import Image | |
| import spaces | |
| # Suppress specific warnings | |
| import warnings | |
| warnings.filterwarnings("ignore", category=FutureWarning) | |
# Medical Imaging Analysis Configuration
# Guideline labels are interpolated into the analysis prompt; the two
# "*_params" lists are not referenced by the code visible in this file.
MEDICAL_CONFIG = {
    "echo_guidelines": "ASE 2023 Standards",
    "histo_guidelines": "CAP Protocols 2024",
    "cardiac_params": ["LVEF", "E/A Ratio", "Wall Motion"],
    "histo_params": ["Nuclear Atypia", "Mitotic Count", "Stromal Invasion"],
}

# Initialize Medical Imaging Model
# Hub identifier used for the model, the chat processor and the tokenizer.
model_path = "deepseek-ai/Janus-Pro-1B"
class MedicalImagingAdapter(torch.nn.Module):
    """Wrap a language model and attach modality-specific projection layers.

    The wrapper is a transparent pass-through: ``forward`` hands every
    argument to ``base_model`` unchanged, and any attribute that the adapter
    does not define itself (for example ``generate``,
    ``get_input_embeddings`` or ``config``) is looked up on ``base_model``.
    Without that delegation, code that expects the wrapped model's API on
    this object fails with ``AttributeError``.

    Args:
        base_model: The module to wrap; it is registered as a submodule.
        hidden_size: Input/output width of the two projection layers.
            Defaults to 2048, the value that was previously hard-coded.
    """

    def __init__(self, base_model, hidden_size=2048):
        super().__init__()
        self.base_model = base_model
        # Cardiac-specific projection (not applied in forward()).
        self.cardiac_proj = torch.nn.Linear(hidden_size, hidden_size)
        # Histopathology-specific projection (not applied in forward()).
        self.histo_proj = torch.nn.Linear(hidden_size, hidden_size)

    def __getattr__(self, name):
        """Fall back to ``base_model`` for attributes the adapter lacks."""
        try:
            # Regular nn.Module lookup: parameters, buffers and submodules.
            return super().__getattr__(name)
        except AttributeError:
            # Fetch the wrapped model through the parent lookup so that a
            # missing ``base_model`` raises instead of recursing.
            wrapped = super().__getattr__("base_model")
            return getattr(wrapped, name)

    def forward(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its output."""
        return self.base_model(*args, **kwargs)
# Load the multimodal checkpoint; trust_remote_code lets the hub repository
# supply its own model class.
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

# Wrap the language model before any dtype/device change so the adapter's own
# layers are cast and moved together with the rest of the network.
vl_gpt.language_model = MedicalImagingAdapter(base_model=vl_gpt.language_model)

if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16)
    vl_gpt = vl_gpt.cuda()

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)

# Replace the processor's tokenizer with one loaded using legacy=False so the
# newer tokenization behavior is used.
vl_chat_processor.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)
| # Medical Image Processing Pipelines | |
def preprocess_echo(image):
    """Convert an echocardiography frame to a 512x512 grayscale array.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        NumPy array of the image after conversion to mode ``'L'`` and
        resizing to 512x512.
    """
    target_size = (512, 512)
    grayscale = Image.fromarray(image).convert('L')
    resized = grayscale.resize(target_size)
    return np.array(resized)
def preprocess_histo(image):
    """Resize a histopathology slide image to 1024x1024.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        NumPy array of the image resized to 1024x1024; the color mode is
        left as it came in.
    """
    target_size = (1024, 1024)
    slide = Image.fromarray(image)
    resized = slide.resize(target_size)
    return np.array(resized)
@torch.inference_mode()
def analyze_medical_case(image, clinical_context, modality):
    """Generate a structured Markdown report for one uploaded image.

    Args:
        image: Image array from the UI, or ``None`` when nothing was uploaded.
        clinical_context: Free-text clinical question inserted into the prompt.
        modality: ``"Echo"`` or ``"Histo"``; any other value (including
            ``None`` when no radio option is selected) is rejected.

    Returns:
        A Markdown string: the formatted report, or a short error message
        when the image or modality is missing.
    """
    # Fail with a readable message instead of a traceback in the UI.
    if image is None:
        return "**Error:** Please upload an image before running the analysis."
    if modality not in ("Echo", "Histo"):
        return "**Error:** Please select an image modality ('Echo' or 'Histo')."

    is_echo = modality == "Echo"

    # Preprocess based on modality
    processed_img = preprocess_echo(image) if is_echo else preprocess_histo(image)
    # The echo pipeline yields a single-channel array; convert to RGB before
    # handing the image to the vision processor (Janus's own image loader
    # also converts every image to RGB).
    pil_image = Image.fromarray(processed_img).convert("RGB")

    # Create modality-specific prompt. The image tag marks where the image
    # embeddings are spliced into the prompt; without it the picture is never
    # attached to the conversation.
    guidelines = MEDICAL_CONFIG["echo_guidelines" if is_echo else "histo_guidelines"]
    prompt = (
        "<image_placeholder>\n"
        f"Analyze this {modality} image following {guidelines}.\n"
        f"Clinical Context: {clinical_context}"
    )

    # NOTE(review): Janus's published examples use the <|User|> / <|Assistant|>
    # roles; the custom role tags below are kept as-is - confirm they are
    # intentional.
    conversation = [
        {
            "role": "<|Radiologist|>" if is_echo else "<|Pathologist|>",
            "content": prompt,
            "images": [processed_img],
        },
        {"role": "<|AI_Assistant|>", "content": ""},
    ]

    # Match the pixel-value dtype to the model's so CPU (float32) and CUDA
    # (bfloat16) runs both avoid dtype-mismatch errors.
    inputs = vl_chat_processor(
        conversations=conversation,
        images=[pil_image],
        force_batchify=True,
    ).to(vl_gpt.device, dtype=vl_gpt.dtype)
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)

    # Text generation lives on the language model, not on the multimodal
    # wrapper. Unwrap the adapter explicitly: its forward() is a pure
    # pass-through, so generating from the wrapped model is equivalent.
    language_model = vl_gpt.language_model
    if isinstance(language_model, MedicalImagingAdapter):
        language_model = language_model.base_model

    tokenizer = vl_chat_processor.tokenizer
    outputs = language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        # temperature/top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        repetition_penalty=1.5,
        use_cache=True,
    )
    report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return format_medical_report(report, modality)
def format_medical_report(text, modality):
    """Arrange free-text model output into a sectioned Markdown report.

    For every known field label of the chosen modality that occurs in
    ``text``, the value is taken from just after the label (skipping one
    delimiter character such as ``:``) up to the next blank line, or to the
    end of ``text`` when no blank line follows.

    Args:
        text: Raw report text produced by the model.
        modality: ``"Echo"`` or ``"Histo"``.

    Returns:
        Markdown string with a title, one ``###`` heading per section and a
        bullet for each field found in ``text``.

    Raises:
        KeyError: If ``modality`` is not a known modality.
    """
    # Structure report based on modality: (section header, field, field)
    sections = {
        "Echo": [
            ("Chamber Dimensions", "LVEDD", "LVESD"),
            ("Valvular Function", "Aortic Valve", "Mitral Valve"),
            ("Hemodynamics", "E/A Ratio", "LVEF"),
        ],
        "Histo": [
            ("Architecture", "Gland Formation", "Stromal Pattern"),
            ("Cellular Features", "Nuclear Atypia", "Mitotic Count"),
            ("Diagnostic Impression", "Tumor Grade", "Margin Status"),
        ],
    }

    parts = [f"**{modality} Analysis Report**\n\n"]
    for header, *fields in sections[modality]:
        parts.append(f"### {header}\n")
        for field in fields:
            start = text.find(field)
            if start == -1:
                continue
            end = text.find("\n\n", start)
            if end == -1:
                # No blank line after this field: read to the end of the
                # text. Slicing with -1 would drop the last character.
                end = len(text)
            value = text[start + len(field) + 1:end].strip()
            parts.append(f"- **{field}:** {value}\n")
    return "".join(parts)
# Medical Imaging Interface
# Left column collects the inputs; right column shows the generated report.
with gr.Blocks(title="Cardiac & Histopathology AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
## Medical Imaging Analysis Platform
*Analyzes echocardiograms and histopathology slides - Research Use Only*
""")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Medical Image")
            modality_select = gr.Radio(
                choices=["Echo", "Histo"],
                label="Image Modality",
                info="Select 'Echo' for cardiac ultrasound, 'Histo' for biopsy slides",
            )
            clinical_input = gr.Textbox(
                label="Clinical Context",
                placeholder="e.g., 'Assess LV function' or 'Evaluate for malignancy'",
            )
            analyze_btn = gr.Button("Analyze Case", variant="primary")

        with gr.Column():
            report_output = gr.Markdown(label="AI Clinical Report")

    # Preloaded examples; each row is ordered as (clinical context, image
    # file, modality) to match the `inputs` list below.
    gr.Examples(
        examples=[
            ["Evaluate LV systolic function", "case1.png", "Echo"],
            ["Assess mitral valve function", "case2.jpg", "Echo"],
            ["Analyze for malignant features", "case3.png", "Histo"],
            ["Evaluate tumor margins", "case4.png", "Histo"],
        ],
        inputs=[clinical_input, image_input, modality_select],
        label="Example Medical Cases",
    )

    # Wire the button to the analysis function; the inputs are listed in the
    # order of that function's parameters.
    analyze_btn.click(
        fn=analyze_medical_case,
        inputs=[image_input, clinical_input, modality_select],
        outputs=report_output,
    )

demo.launch(share=True)