from time import sleep
import logging
import sys
import re

import httpx
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse, FileResponse
from fastapi.exceptions import RequestValidationError
from transformers import pipeline

from phishing_datasets import submit_entry
from url_tools import extract_urls, resolve_short_url, extract_domain_from_url
from urlscan_client import UrlscanClient
import requests
from mnemonic_attack import find_confusable_brand
from models.models import MessageModel, QueryModel, AppModel, InputModel, OutputModel, ReportMessagesModel, ReportInputModel
from models.enums import ActionModel, SubActionModel

app = FastAPI()
urlscan = UrlscanClient()

# Remove all handlers associated with the root logger object
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(asctime)s %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)

pipe = pipeline(task="text-classification", model="mrm8488/bert-tiny-finetuned-sms-spam-detection")
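# Binary SMS spam classifier: it returns a LABEL_0/LABEL_1 label with a score, and
# predict() below treats LABEL_0 as the ham label and inverts its score so that
# `score` is always a spam probability.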


@app.get("/.well-known/apple-app-site-association")
def get_well_known_aasa():
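    # apple-app-site-association payload: the "messagefilter" and "classificationreport"
    # entries list the app's SMS-filter and unwanted-communication reporting extensions
    # associated with this domain.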
    return JSONResponse(
        content={
            "messagefilter": {
                "apps": [
                    "X9NN3FSS3T.com.lela.Serenity.SerenityMessageFilterExtension",
                    "X9NN3FSS3T.com.lela.Serenity"
                ]
            },
            "classificationreport": {
                "apps": [
                    "X9NN3FSS3T.com.lela.Serenity.SerenityUnwantedCommunicationExtension",
                    "X9NN3FSS3T.com.lela.Serenity"
                ]
            }
        },
        media_type="application/json"
    )


@app.get("/robots.txt")
def get_robots_txt():
    return FileResponse("robots.txt")


@app.post("/predict")  # NOTE: route path assumed; align with the URL configured in the iOS filter extension
def predict(model: InputModel) -> OutputModel:
    sender = model.query.sender
    text = model.query.message.text
    logging.info(f"[{sender}] {text}")

    # Debug sleep
    pattern = r"^Sent from your Twilio trial account - sleep (\d+)$"
    match = re.search(pattern, text)
    if match:
        number_str = match.group(1)
        sleep_duration = int(number_str)
        logging.debug(f"[DEBUG SLEEP] Sleeping for {sleep_duration} seconds for sender {sender}")
        sleep(sleep_duration)
        return OutputModel(action=ActionModel.JUNK, sub_action=SubActionModel.NONE)

    # Debug category
    pattern = r"^Sent from your Twilio trial account - (junk|transaction|promotion)$"
    match = re.search(pattern, text)
    if match:
        category_str = match.group(1)
        logging.info(f"[DEBUG CATEGORY] Forced category: {category_str} for sender {sender}")
        match category_str:
            case 'junk':
                return OutputModel(action=ActionModel.JUNK, sub_action=SubActionModel.NONE)
            case 'transaction':
                return OutputModel(action=ActionModel.TRANSACTION, sub_action=SubActionModel.NONE)
            case 'promotion':
                return OutputModel(action=ActionModel.PROMOTION, sub_action=SubActionModel.NONE)

    # Brand usurpation detection using confusables
    confusable_brand = find_confusable_brand(text)
    if confusable_brand:
        logging.warning(f"[BRAND USURPATION] Confusable/homoglyph variant of brand '{confusable_brand}' detected in message. Classified as JUNK.")
        return OutputModel(action=ActionModel.JUNK, sub_action=SubActionModel.NONE)

    result = pipe(text)
    label = result[0]['label']
    score = result[0]['score']
    logging.info(f"[CLASSIFICATION] label={label} score={score}")
    if label == 'LABEL_0':
        score = 1 - score

    # Pattern for detecting an alphanumeric SenderID
    alphanumeric_sender_pattern = r'^[A-Za-z][A-Za-z0-9\-\.]{2,14}$'
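    # i.e. 3 to 15 characters starting with a letter, e.g. "AMAZON" or "My-Shop"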
    # Pattern for detecting a short code
    shorten_sender_pattern = r'^(?:3\d{4}|[4-8]\d{4})$'
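    # i.e. any five-digit number from 30000 to 89999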
    # Flag used below to classify mid-score commercial messages as PROMOTION rather than JUNK
    commercial_stop = False
    # Detection of commercial senders (short code or alphanumeric)
    if re.search(shorten_sender_pattern, sender):
        logging.info("[COMMERCIAL] Commercial sender detected (short code)")
        commercial_stop = True
        score = score * 0.7
    elif re.match(alphanumeric_sender_pattern, sender):
        logging.info("[COMMERCIAL] Alphanumeric SenderID detected")
        commercial_stop = True
        score = score * 0.7

    urls = extract_urls(text)
    if urls:
        logging.info(f"[URL] URLs found: {urls}")
        logging.info("[URL] Searching for previous scans")
        search_results = [urlscan.search(f"domain:{extract_domain_from_url(url)}") for url in urls]
        scan_results = []
        for search_result in search_results:
            results = search_result.get('results', [])
            for result in results:
                result_uuid = result.get('_id')
                scan_result = urlscan.get_result(result_uuid)
                scan_results.append(scan_result)
        if not scan_results:
            logging.info("[URL] No previous scan found, launching a new scan...")
            scan_results = [urlscan.scan(url) for url in urls]
        for result in scan_results:
            overall = result.get('verdicts', {}).get('overall', {})
            logging.info(f"[URLSCAN] Overall verdict: {overall}")
            if overall.get('hasVerdicts'):
                # Keep the urlscan verdict in its own variable so it does not overwrite the classifier score
                verdict_score = overall.get('score')
                logging.info(f"[URLSCAN] Verdict score: {verdict_score}")
                if verdict_score > 0:
                    # A malicious verdict marks the message as junk outright
                    score = 1.0
                    break
                elif verdict_score < 0:
                    # A benign verdict lowers the suspicion score
                    score = score * 0.9
    else:
        logging.info("[URL] No URL found")
        score = score * 0.9

    logging.info(f"[FINAL SCORE] {score}")
    action = ActionModel.NONE
    if score > 0.7:
        action = ActionModel.JUNK
    elif score > 0.5:
        if commercial_stop:
            action = ActionModel.PROMOTION
        else:
            action = ActionModel.JUNK
    logging.info(f"[FINAL ACTION] {action}")
    return OutputModel(action=action, sub_action=SubActionModel.NONE)


@app.post("/report")  # NOTE: route path assumed; align with the URL configured in the iOS reporting extension
def report(model: ReportInputModel):
    logging.info(f"[REPORT] {model.classification.messages}")
    for msg in model.classification.messages:
        logging.info(
            f"[REPORT] {msg.timestamp=} {msg.sender=} {msg.message=}"
        )
    return {"status": "received", "count": len(model.classification.messages)}


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    exc_str = f'{exc}'.replace('\n', ' ').replace('   ', ' ')
    logging.error(f"{request}: {exc_str}")
    content = {'status_code': 10422, 'message': exc_str, 'data': None}
    return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
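

# Optional local entry point, a minimal sketch: assumes uvicorn is installed and uses
# port 7860 (the Hugging Face Spaces default) as an assumption; adjust for your deployment.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)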