from flask import Flask, render_template, request, jsonify
import os

import requests
import spacy
from gliner import GLiNER

app = Flask(__name__)

# Load a blank English spaCy pipeline for tokenization
nlp = spacy.blank("en")

# GLiNER spaCy pipeline (configured on first use)
gliner_nlp = None

# GLiNER multitask model for relationship extraction
gliner_multitask = None
def get_or_create_multitask_model():
    """
    Get or create the GLiNER multitask model for relationship extraction.
    """
    global gliner_multitask
    if gliner_multitask is None:
        try:
            gliner_multitask = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
        except Exception as e:
            print(f"Error loading GLiNER multitask model: {e}")
            return None
    return gliner_multitask
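
# A minimal sketch of calling the multitask model directly; the example text
# and labels are illustrative only. The first call downloads the model weights
# from the Hugging Face Hub, so expect a cold-start delay:
#
#   model = get_or_create_multitask_model()
#   if model is not None:
#       ents = model.predict_entities("Bob works at Acme.", ["person", "organization"])
#       # -> [{'text': 'Bob', 'label': 'person', ...}, ...]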
# NOTE: the @app.route decorators were missing from the extracted source;
# the paths below are plausible reconstructions, not confirmed originals.
@app.route('/')
def index():
    return render_template('index.html')
@app.route('/tokenize', methods=['POST'])
def tokenize_text():
    """
    Tokenize the input text and return token boundaries.
    """
    data = request.get_json()
    text = data.get('text', '')

    if not text:
        return jsonify({'error': 'No text provided'}), 400

    # Process text with spaCy
    doc = nlp(text)

    # Extract character offsets for each token
    tokens = []
    for token in doc:
        tokens.append({
            'text': token.text,
            'start': token.idx,
            'end': token.idx + len(token.text)
        })

    return jsonify({
        'tokens': tokens,
        'text': text
    })
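
# Example request/response for this endpoint (the /tokenize path is an
# assumption, as noted above):
#
#   POST /tokenize  {"text": "Ada visited London."}
#   -> {"tokens": [{"text": "Ada", "start": 0, "end": 3}, ...],
#       "text": "Ada visited London."}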
@app.route('/token-boundaries', methods=['POST'])
def find_token_boundaries():
    """
    Given a text selection, find the token boundaries that encompass it.
    """
    data = request.get_json()
    text = data.get('text', '')
    start = data.get('start', 0)
    end = data.get('end', 0)
    label = data.get('label', 'UNLABELED')

    if not text:
        return jsonify({'error': 'No text provided'}), 400

    # Process text with spaCy
    doc = nlp(text)

    # Find tokens that overlap with the selection
    token_start = None
    token_end = None
    for token in doc:
        # A token overlaps if it starts before the selection ends
        # and ends after the selection starts
        if token.idx < end and token.idx + len(token.text) > start:
            if token_start is None:
                token_start = token.idx
            token_end = token.idx + len(token.text)

    # If no tokens overlap, fall back to the original boundaries
    if token_start is None:
        token_start = start
        token_end = end

    return jsonify({
        'start': token_start,
        'end': token_end,
        'selected_text': text[token_start:token_end],
        'label': label
    })
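
# Worked example of the snapping behaviour: in the text "New York City", a raw
# selection covering characters 2..7 ("w Yor") overlaps the tokens "New" and
# "York", so the endpoint widens the span to 0..8 ("New York").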
@app.route('/labels', methods=['GET'])
def get_default_labels():
    """
    Return the default annotation labels with their colors.
    """
    default_labels = [
        {'name': 'PERSON', 'color': '#fef3c7', 'border': '#f59e0b'},
        {'name': 'LOCATION', 'color': '#dbeafe', 'border': '#3b82f6'},
        {'name': 'ORGANIZATION', 'color': '#dcfce7', 'border': '#10b981'}
    ]
    return jsonify({'labels': default_labels})
@app.route('/relationship-labels', methods=['GET'])
def get_default_relationship_labels():
    """
    Return the default relationship labels with their colors.
    """
    default_relationship_labels = [
        {'name': 'worked at', 'color': '#fce7f3', 'border': '#ec4899'},
        {'name': 'visited', 'color': '#f3e8ff', 'border': '#a855f7'}
    ]
    return jsonify({'relationship_labels': default_relationship_labels})
def get_or_create_gliner_pipeline(labels):
    """
    Create a GLiNER spaCy pipeline configured with the specified labels.
    """
    global gliner_nlp

    # GLiNER expects lowercase labels
    gliner_labels = [label.lower() for label in labels]

    try:
        # Rebuild the pipeline on every call so a changed label set
        # always takes effect
        custom_spacy_config = {
            "gliner_model": "gliner-community/gliner_small-v2.5",
            "chunk_size": 250,
            "labels": gliner_labels,
            "style": "ent"
        }
        gliner_nlp = spacy.blank("en")
        gliner_nlp.add_pipe("gliner_spacy", config=custom_spacy_config)
        return gliner_nlp
    except Exception as e:
        print(f"Error creating GLiNER pipeline: {e}")
        return None
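
# Direct-use sketch of the helper above (example text and labels only):
#
#   pipe = get_or_create_gliner_pipeline(["PERSON", "LOCATION"])
#   if pipe is not None:
#       doc = pipe("Marie Curie worked in Paris.")
#       for ent in doc.ents:
#           print(ent.text, ent.label_)  # e.g. "Marie Curie person"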
@app.route('/run-gliner', methods=['POST'])
def run_gliner():
    """
    Run GLiNER entity extraction on the provided text with the specified labels.
    """
    data = request.get_json()
    text = data.get('text', '')
    labels = data.get('labels', [])

    if not text:
        return jsonify({'error': 'No text provided'}), 400
    if not labels:
        return jsonify({'error': 'No labels provided'}), 400

    try:
        # Get or create the GLiNER pipeline
        pipeline = get_or_create_gliner_pipeline(labels)
        if pipeline is None:
            return jsonify({'error': 'Failed to initialize GLiNER pipeline'}), 500

        # Process text with GLiNER
        doc = pipeline(text)

        # Extract entities with character boundaries
        entities = []
        for ent in doc.ents:
            # Map the lowercased GLiNER label back to the user's label
            original_label = None
            for label in labels:
                if label.lower() == ent.label_.lower():
                    original_label = label
                    break

            if original_label:
                entities.append({
                    'text': ent.text,
                    'start': ent.start_char,
                    'end': ent.end_char,
                    'label': original_label,
                    'confidence': getattr(ent, 'score', 1.0)
                })

        return jsonify({
            'entities': entities,
            'total_found': len(entities)
        })
    except Exception as e:
        print(f"GLiNER processing error: {e}")
        return jsonify({'error': f'GLiNER processing failed: {str(e)}'}), 500
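
# Example client call (the /run-gliner path is an assumption; text and labels
# are illustrative):
#
#   import requests
#   resp = requests.post('http://localhost:7860/run-gliner',
#                        json={'text': 'Tim Cook leads Apple.',
#                              'labels': ['PERSON', 'ORGANIZATION']})
#   # -> {'entities': [{'text': 'Tim Cook', 'label': 'PERSON', ...}, ...],
#   #     'total_found': 2}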
@app.route('/run-gliner-relationships', methods=['POST'])
def run_gliner_relationships():
    """
    Run GLiNER relationship extraction on the provided text with the
    specified relationship labels.
    """
    data = request.get_json()
    text = data.get('text', '')
    relationship_labels = data.get('relationship_labels', [])
    entity_labels = data.get('entity_labels', ["person", "organization", "location", "date", "place"])

    if not text:
        return jsonify({'error': 'No text provided'}), 400
    if not relationship_labels:
        return jsonify({'error': 'No relationship labels provided'}), 400

    try:
        # Get the GLiNER multitask model
        model = get_or_create_multitask_model()
        if model is None:
            return jsonify({'error': 'Failed to initialize GLiNER multitask model'}), 500

        # First, extract entities using the provided entity labels
        print(f"Using entity labels: {entity_labels}")
        entities = model.predict_entities(text, entity_labels, threshold=0.3)
        print(entities)

        # Then extract relation mentions using the "<entity> <> <relation>"
        # label format expected by the multitask model
        formatted_labels = []
        for label in relationship_labels:
            for entity_label in entity_labels:
                formatted_labels.append(f"{entity_label} <> {label}")
        print(f"Formatted relationship labels: {formatted_labels}")

        relation_entities = model.predict_entities(text, formatted_labels, threshold=0.3)

        # Assemble the results into relationship triplets: for each relation
        # mention, try to find a subject/object entity pair nearby
        relationships = []
        for rel_entity in relation_entities:
            print(rel_entity)
            label_parts = rel_entity['label'].split(' <> ')
            if len(label_parts) == 2:
                entity_type, relation_type = label_parts

                # Character span of the relation mention
                rel_start = rel_entity['start']
                rel_end = rel_entity['end']

                # Look for entities within 100 characters before and after
                # the relation mention
                subject_candidates = [e for e in entities if e['end'] <= rel_start and abs(e['end'] - rel_start) < 100]
                object_candidates = [e for e in entities if e['start'] >= rel_end and abs(e['start'] - rel_end) < 100]

                # Also collect entities that contain, or are contained by,
                # the relation mention
                overlapping_entities = [e for e in entities if
                    (e['start'] <= rel_start and e['end'] >= rel_end) or  # entity contains relation
                    (rel_start <= e['start'] and rel_end >= e['end'])     # relation contains entity
                ]

                if subject_candidates and object_candidates:
                    # Take the entities closest to the relation mention
                    subject = max(subject_candidates, key=lambda x: x['end'])
                    object_entity = min(object_candidates, key=lambda x: x['start'])
                    relationships.append({
                        'subject': subject['text'],
                        'subject_start': subject['start'],
                        'subject_end': subject['end'],
                        'relation_type': relation_type,
                        'relation_text': rel_entity['text'],
                        'relation_start': rel_entity['start'],
                        'relation_end': rel_entity['end'],
                        'object': object_entity['text'],
                        'object_start': object_entity['start'],
                        'object_end': object_entity['end'],
                        'confidence': rel_entity['score'],
                        'full_text': f"{subject['text']} {relation_type} {object_entity['text']}"
                    })
                elif overlapping_entities:
                    # Handle cases where the relation mention spans or
                    # overlaps entities; the object is left empty for the
                    # user (or further processing) to fill in
                    for ent in overlapping_entities:
                        relationships.append({
                            'subject': ent['text'],
                            'subject_start': ent['start'],
                            'subject_end': ent['end'],
                            'relation_type': relation_type,
                            'relation_text': rel_entity['text'],
                            'relation_start': rel_entity['start'],
                            'relation_end': rel_entity['end'],
                            'object': '',
                            'object_start': -1,
                            'object_end': -1,
                            'confidence': rel_entity['score'],
                            'full_text': f"{ent['text']} {relation_type} [object]"
                        })

        return jsonify({
            'relationships': relationships,
            'total_found': len(relationships)
        })
    except Exception as e:
        print(f"GLiNER relationship processing error: {e}")
        return jsonify({'error': f'GLiNER relationship processing failed: {str(e)}'}), 500
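
# Example payload for this endpoint (path and values are illustrative):
#
#   POST /run-gliner-relationships
#   {"text": "Alice worked at Acme Corp.",
#    "relationship_labels": ["worked at"],
#    "entity_labels": ["person", "organization"]}
#   -> {"relationships": [{"subject": "Alice", "relation_type": "worked at",
#                          "object": "Acme Corp", ...}], "total_found": 1}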
@app.route('/search-wikidata', methods=['POST'])
def search_wikidata():
    """
    Search Wikidata for entities matching the query.
    """
    data = request.get_json()
    query = data.get('query', '').strip()
    limit = data.get('limit', 10)

    if not query:
        return jsonify({'error': 'No query provided'}), 400

    try:
        # Wikidata search API endpoint
        url = 'https://www.wikidata.org/w/api.php'
        params = {
            'action': 'wbsearchentities',
            'search': query,
            'language': 'en',
            'format': 'json',
            'limit': limit,
            'type': 'item'
        }
        headers = {
            'User-Agent': 'AnnotationTool/1.0 (https://github.com/user/annotation-tool) Python/requests'
        }

        response = requests.get(url, params=params, headers=headers, timeout=10)
        response.raise_for_status()
        api_data = response.json()  # renamed from `data` to avoid shadowing the request body

        # Extract the relevant fields from each hit
        results = []
        if 'search' in api_data:
            for item in api_data['search']:
                results.append({
                    'id': item.get('id', ''),
                    'label': item.get('label', ''),
                    'description': item.get('description', ''),
                    'url': f"https://www.wikidata.org/wiki/{item.get('id', '')}"
                })

        return jsonify({
            'results': results,
            'total': len(results)
        })
    except requests.exceptions.RequestException as e:
        print(f"Wikidata API error: {e}")
        return jsonify({'error': 'Failed to search Wikidata'}), 500
    except Exception as e:
        print(f"Wikidata search error: {e}")
        return jsonify({'error': f'Search failed: {str(e)}'}), 500
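
# Example call (the /search-wikidata path is an assumption):
#
#   POST /search-wikidata  {"query": "Douglas Adams", "limit": 5}
#   -> {"results": [{"id": "Q42", "label": "Douglas Adams",
#                    "description": "...",
#                    "url": "https://www.wikidata.org/wiki/Q42"}, ...],
#       "total": 5}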
@app.route('/wikidata-entity', methods=['POST'])
def get_wikidata_entity():
    """
    Get Wikidata entity information by Q-code.
    """
    data = request.get_json()
    qcode = data.get('qcode', '').strip()

    if not qcode:
        return jsonify({'error': 'No Q-code provided'}), 400

    # Normalize to Q-code format (e.g. "42" -> "Q42")
    if not qcode.startswith('Q'):
        qcode = 'Q' + qcode.lstrip('Q')

    try:
        # Wikidata entity API endpoint
        url = 'https://www.wikidata.org/w/api.php'
        params = {
            'action': 'wbgetentities',
            'ids': qcode,
            'languages': 'en',
            'format': 'json'
        }
        headers = {
            'User-Agent': 'AnnotationTool/1.0 (https://github.com/user/annotation-tool) Python/requests'
        }

        response = requests.get(url, params=params, headers=headers, timeout=10)
        response.raise_for_status()
        api_data = response.json()

        if 'entities' in api_data and qcode in api_data['entities']:
            entity = api_data['entities'][qcode]
            if 'missing' in entity:
                return jsonify({'error': f'Entity {qcode} not found'}), 404

            # Extract the English label and description
            result = {
                'id': qcode,
                'label': entity.get('labels', {}).get('en', {}).get('value', ''),
                'description': entity.get('descriptions', {}).get('en', {}).get('value', ''),
                'url': f"https://www.wikidata.org/wiki/{qcode}"
            }
            return jsonify({'entity': result})
        else:
            return jsonify({'error': f'Entity {qcode} not found'}), 404
    except requests.exceptions.RequestException as e:
        print(f"Wikidata API error: {e}")
        return jsonify({'error': 'Failed to get Wikidata entity'}), 500
    except Exception as e:
        print(f"Wikidata entity error: {e}")
        return jsonify({'error': f'Request failed: {str(e)}'}), 500
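
# Example call (the /wikidata-entity path is an assumption):
#
#   POST /wikidata-entity  {"qcode": "Q42"}
#   -> {"entity": {"id": "Q42", "label": "Douglas Adams", ...}}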
if __name__ == '__main__':
    # PORT defaults to 7860, the conventional Hugging Face Spaces port
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)