from flask import Flask, render_template, request, jsonify
import os
import spacy
import requests
from gliner import GLiNER

app = Flask(__name__)

# Load a blank English spaCy pipeline for tokenization
nlp = spacy.blank("en")

# GLiNER spaCy pipeline (configured on first use)
gliner_nlp = None

# GLiNER multitask model for relationship extraction
gliner_multitask = None


def get_or_create_multitask_model():
    """Get or create the GLiNER multitask model for relationship extraction."""
    global gliner_multitask
    if gliner_multitask is None:
        try:
            gliner_multitask = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
        except Exception as e:
            print(f"Error loading GLiNER multitask model: {e}")
            return None
    return gliner_multitask


@app.route('/')
def index():
    return render_template('index.html')


@app.route('/tokenize', methods=['POST'])
def tokenize_text():
    """Tokenize the input text and return token boundaries."""
    data = request.get_json()
    text = data.get('text', '')

    if not text:
        return jsonify({'error': 'No text provided'}), 400

    # Process text with spaCy
    doc = nlp(text)

    # Extract token information
    tokens = []
    for token in doc:
        tokens.append({
            'text': token.text,
            'start': token.idx,
            'end': token.idx + len(token.text)
        })

    return jsonify({
        'tokens': tokens,
        'text': text
    })


@app.route('/find_token_boundaries', methods=['POST'])
def find_token_boundaries():
    """Given a text selection, find the token boundaries that encompass it."""
    data = request.get_json()
    text = data.get('text', '')
    start = data.get('start', 0)
    end = data.get('end', 0)
    label = data.get('label', 'UNLABELED')

    if not text:
        return jsonify({'error': 'No text provided'}), 400

    # Process text with spaCy
    doc = nlp(text)

    # Find tokens that overlap with the selection
    token_start = None
    token_end = None

    for token in doc:
        # Check whether the token overlaps with the selection
        if token.idx < end and token.idx + len(token.text) > start:
            if token_start is None:
                token_start = token.idx
            # Keep extending the end to cover the last overlapping token
            token_end = token.idx + len(token.text)

    # If no tokens were found, fall back to the original boundaries
    if token_start is None:
        token_start = start
        token_end = end

    return jsonify({
        'start': token_start,
        'end': token_end,
        'selected_text': text[token_start:token_end],
        'label': label
    })


@app.route('/get_default_labels', methods=['GET'])
def get_default_labels():
    """Return the default annotation labels with their colors."""
    default_labels = [
        {'name': 'PERSON', 'color': '#fef3c7', 'border': '#f59e0b'},
        {'name': 'LOCATION', 'color': '#dbeafe', 'border': '#3b82f6'},
        {'name': 'ORGANIZATION', 'color': '#dcfce7', 'border': '#10b981'}
    ]
    return jsonify({'labels': default_labels})


@app.route('/get_default_relationship_labels', methods=['GET'])
def get_default_relationship_labels():
    """Return the default relationship labels with their colors."""
    default_relationship_labels = [
        {'name': 'worked at', 'color': '#fce7f3', 'border': '#ec4899'},
        {'name': 'visited', 'color': '#f3e8ff', 'border': '#a855f7'}
    ]
    return jsonify({'relationship_labels': default_relationship_labels})


def get_or_create_gliner_pipeline(labels):
    """Create a GLiNER spaCy pipeline configured with the specified labels."""
    global gliner_nlp

    # GLiNER expects lowercase labels
    gliner_labels = [label.lower() for label in labels]

    try:
        # Rebuild the pipeline on every call so label changes take effect
        custom_spacy_config = {
            "gliner_model": "gliner-community/gliner_small-v2.5",
            "chunk_size": 250,
            "labels": gliner_labels,
            "style": "ent"
        }
        gliner_nlp = spacy.blank("en")
        gliner_nlp.add_pipe("gliner_spacy", config=custom_spacy_config)
        return gliner_nlp
    except Exception as e:
        print(f"Error creating GLiNER pipeline: {e}")
        return None
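
# Quick sanity check for the pipeline factory above (a sketch, assuming the
# gliner-spacy package is installed; the example text, labels, and predicted
# output are illustrative, not guaranteed):
#
#   pipeline = get_or_create_gliner_pipeline(["PERSON", "LOCATION"])
#   doc = pipeline("Ada Lovelace was born in London.")
#   print([(ent.text, ent.label_) for ent in doc.ents])
#   # e.g. [('Ada Lovelace', 'person'), ('London', 'location')]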

@app.route('/run_gliner', methods=['POST'])
def run_gliner():
    """Run GLiNER entity extraction on the provided text with the specified labels."""
    data = request.get_json()
    text = data.get('text', '')
    labels = data.get('labels', [])

    if not text:
        return jsonify({'error': 'No text provided'}), 400

    if not labels:
        return jsonify({'error': 'No labels provided'}), 400

    try:
        # Get or create the GLiNER pipeline
        pipeline = get_or_create_gliner_pipeline(labels)
        if pipeline is None:
            return jsonify({'error': 'Failed to initialize GLiNER pipeline'}), 500

        # Process text with GLiNER
        doc = pipeline(text)

        # Extract entities with character offsets
        entities = []
        for ent in doc.ents:
            # Map the lowercase GLiNER label back to the user's label format
            original_label = None
            for label in labels:
                if label.lower() == ent.label_.lower():
                    original_label = label
                    break

            if original_label:
                entities.append({
                    'text': ent.text,
                    'start': ent.start_char,
                    'end': ent.end_char,
                    'label': original_label,
                    # Use the span's score if present, otherwise default to 1.0
                    'confidence': getattr(ent, 'score', 1.0)
                })

        return jsonify({
            'entities': entities,
            'total_found': len(entities)
        })

    except Exception as e:
        print(f"GLiNER processing error: {e}")
        return jsonify({'error': f'GLiNER processing failed: {str(e)}'}), 500
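
# Example request against a local server (assumed usage; host and port follow
# the app.run() defaults at the bottom of this file):
#
#   curl -X POST http://localhost:7860/run_gliner \
#        -H 'Content-Type: application/json' \
#        -d '{"text": "Marie Curie worked at the Sorbonne.", "labels": ["PERSON", "ORGANIZATION"]}'
#
# Returns JSON shaped like {"entities": [...], "total_found": N}; the exact
# spans and confidence values depend on the GLiNER model.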

@app.route('/run_gliner_relationships', methods=['POST'])
def run_gliner_relationships():
    """Run GLiNER relationship extraction on the provided text with the specified relationship labels."""
    data = request.get_json()
    text = data.get('text', '')
    relationship_labels = data.get('relationship_labels', [])
    entity_labels = data.get('entity_labels', ["person", "organization", "location", "date", "place"])

    if not text:
        return jsonify({'error': 'No text provided'}), 400

    if not relationship_labels:
        return jsonify({'error': 'No relationship labels provided'}), 400

    try:
        # Get the GLiNER multitask model
        model = get_or_create_multitask_model()
        if model is None:
            return jsonify({'error': 'Failed to initialize GLiNER multitask model'}), 500

        # First extract entities using the provided entity labels
        print(f"Using entity labels: {entity_labels}")
        entities = model.predict_entities(text, entity_labels, threshold=0.3)
        print(entities)

        # Then extract relationships using the "<entity> <> <relation>" label format
        formatted_labels = []
        for label in relationship_labels:
            for entity_label in entity_labels:
                formatted_labels.append(f"{entity_label} <> {label}")

        print(f"Formatted relationship labels: {formatted_labels}")
        relation_entities = model.predict_entities(text, formatted_labels, threshold=0.3)

        # Process results into relationship triplets
        relationships = []

        # Group relation mentions by their relation type and try to pair them with entities
        for rel_entity in relation_entities:
            print(rel_entity)
            label_parts = rel_entity['label'].split(' <> ')
            if len(label_parts) == 2:
                entity_type, relation_type = label_parts

                # Find potential subject and object entities near this relation mention
                rel_start = rel_entity['start']
                rel_end = rel_entity['end']

                # Look for entities within 100 characters before and after the mention
                subject_candidates = [e for e in entities
                                      if e['end'] <= rel_start and abs(e['end'] - rel_start) < 100]
                object_candidates = [e for e in entities
                                     if e['start'] >= rel_end and abs(e['start'] - rel_end) < 100]

                # Also look for entities that contain or are contained by the relation text
                overlapping_entities = [
                    e for e in entities
                    if (e['start'] <= rel_start and e['end'] >= rel_end) or  # entity contains relation
                       (rel_start <= e['start'] and rel_end >= e['end'])     # relation contains entity
                ]

                if subject_candidates and object_candidates:
                    # Take the closest entity on each side
                    subject = max(subject_candidates, key=lambda x: x['end'])
                    object_entity = min(object_candidates, key=lambda x: x['start'])

                    relationships.append({
                        'subject': subject['text'],
                        'subject_start': subject['start'],
                        'subject_end': subject['end'],
                        'relation_type': relation_type,
                        'relation_text': rel_entity['text'],
                        'relation_start': rel_entity['start'],
                        'relation_end': rel_entity['end'],
                        'object': object_entity['text'],
                        'object_start': object_entity['start'],
                        'object_end': object_entity['end'],
                        'confidence': rel_entity['score'],
                        'full_text': f"{subject['text']} {relation_type} {object_entity['text']}"
                    })
                elif overlapping_entities:
                    # Handle cases where the relation text spans or overlaps with entities
                    for ent in overlapping_entities:
                        relationships.append({
                            'subject': ent['text'],
                            'subject_start': ent['start'],
                            'subject_end': ent['end'],
                            'relation_type': relation_type,
                            'relation_text': rel_entity['text'],
                            'relation_start': rel_entity['start'],
                            'relation_end': rel_entity['end'],
                            'object': '',  # to be filled by the user or further processing
                            'object_start': -1,
                            'object_end': -1,
                            'confidence': rel_entity['score'],
                            'full_text': f"{ent['text']} {relation_type} [object]"
                        })

        return jsonify({
            'relationships': relationships,
            'total_found': len(relationships)
        })

    except Exception as e:
        print(f"GLiNER relationship processing error: {e}")
        return jsonify({'error': f'GLiNER relationship processing failed: {str(e)}'}), 500


@app.route('/search_wikidata', methods=['POST'])
def search_wikidata():
    """Search Wikidata for entities matching the query."""
    data = request.get_json()
    query = data.get('query', '').strip()
    limit = data.get('limit', 10)

    if not query:
        return jsonify({'error': 'No query provided'}), 400

    try:
        # Wikidata search API endpoint
        url = 'https://www.wikidata.org/w/api.php'
        params = {
            'action': 'wbsearchentities',
            'search': query,
            'language': 'en',
            'format': 'json',
            'limit': limit,
            'type': 'item'
        }
        headers = {
            'User-Agent': 'AnnotationTool/1.0 (https://github.com/user/annotation-tool) Python/requests'
        }

        response = requests.get(url, params=params, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()

        # Extract the relevant fields from each search hit
        results = []
        if 'search' in data:
            for item in data['search']:
                result = {
                    'id': item.get('id', ''),
                    'label': item.get('label', ''),
                    'description': item.get('description', ''),
                    'url': f"https://www.wikidata.org/wiki/{item.get('id', '')}"
                }
                results.append(result)

        return jsonify({
            'results': results,
            'total': len(results)
        })

    except requests.exceptions.RequestException as e:
        print(f"Wikidata API error: {e}")
        return jsonify({'error': 'Failed to search Wikidata'}), 500
    except Exception as e:
        print(f"Wikidata search error: {e}")
        return jsonify({'error': f'Search failed: {str(e)}'}), 500
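
# Example request (assumed usage; the endpoint proxies Wikidata's public
# wbsearchentities API, so results depend on Wikidata at query time):
#
#   curl -X POST http://localhost:7860/search_wikidata \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "Douglas Adams", "limit": 5}'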

@app.route('/get_wikidata_entity', methods=['POST'])
def get_wikidata_entity():
    """Get Wikidata entity information by Q-code."""
    data = request.get_json()
    qcode = data.get('qcode', '').strip()

    if not qcode:
        return jsonify({'error': 'No Q-code provided'}), 400

    # Normalize the Q-code format (e.g. "42" -> "Q42")
    if not qcode.startswith('Q'):
        qcode = 'Q' + qcode.lstrip('Q')

    try:
        # Wikidata entity API endpoint
        url = 'https://www.wikidata.org/w/api.php'
        params = {
            'action': 'wbgetentities',
            'ids': qcode,
            'languages': 'en',
            'format': 'json'
        }
        headers = {
            'User-Agent': 'AnnotationTool/1.0 (https://github.com/user/annotation-tool) Python/requests'
        }

        response = requests.get(url, params=params, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()

        if 'entities' in data and qcode in data['entities']:
            entity = data['entities'][qcode]

            if 'missing' in entity:
                return jsonify({'error': f'Entity {qcode} not found'}), 404

            # Extract the English label and description
            result = {
                'id': qcode,
                'label': entity.get('labels', {}).get('en', {}).get('value', ''),
                'description': entity.get('descriptions', {}).get('en', {}).get('value', ''),
                'url': f"https://www.wikidata.org/wiki/{qcode}"
            }

            return jsonify({'entity': result})
        else:
            return jsonify({'error': f'Entity {qcode} not found'}), 404

    except requests.exceptions.RequestException as e:
        print(f"Wikidata API error: {e}")
        return jsonify({'error': 'Failed to get Wikidata entity'}), 500
    except Exception as e:
        print(f"Wikidata entity error: {e}")
        return jsonify({'error': f'Request failed: {str(e)}'}), 500


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)
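
# Example session (assumed usage; assumes this file is saved as app.py, and
# PORT defaults to 7860 unless overridden in the environment):
#
#   PORT=7860 python app.py
#   curl -X POST http://localhost:7860/get_wikidata_entity \
#        -H 'Content-Type: application/json' \
#        -d '{"qcode": "Q42"}'
#
# A bare numeric code such as "42" is normalized to "Q42" by the handler above.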