# LLM-LOD / app.py
from flask import Flask, render_template, request, jsonify
import os
import spacy
import requests
from gliner import GLiNER
app = Flask(__name__)
# Load a blank English spaCy pipeline for tokenization
nlp = spacy.blank("en")
# GLiNER spaCy pipeline and its active label set (configured on first use)
gliner_nlp = None
gliner_current_labels = None
# GLiNER multitask model for relationships
gliner_multitask = None
def get_or_create_multitask_model():
"""
Get or create GLiNER multitask model for relationship extraction
"""
global gliner_multitask
if gliner_multitask is None:
try:
gliner_multitask = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
except Exception as e:
print(f"Error loading GLiNER multitask model: {e}")
return None
return gliner_multitask
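# Usage sketch (hypothetical text and labels) showing the raw output shape of
# GLiNER's predict_entities, which the relationship route below relies on:
#   model = get_or_create_multitask_model()
#   spans = model.predict_entities("Ada worked at IBM.", ["person", "organization"], threshold=0.3)
#   # -> [{'text': 'Ada', 'start': 0, 'end': 3, 'label': 'person', 'score': 0.97}, ...]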
@app.route('/')
def index():
return render_template('index.html')
@app.route('/tokenize', methods=['POST'])
def tokenize_text():
"""
Tokenize the input text and return token boundaries
"""
    # silent=True tolerates missing/invalid JSON so we can return our own 400
    data = request.get_json(silent=True) or {}
text = data.get('text', '')
if not text:
return jsonify({'error': 'No text provided'}), 400
# Process text with spaCy
doc = nlp(text)
# Extract token information
tokens = []
for token in doc:
tokens.append({
'text': token.text,
'start': token.idx,
'end': token.idx + len(token.text)
})
return jsonify({
'tokens': tokens,
'text': text
})
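# Example request/response for /tokenize (hypothetical payload):
#   curl -X POST http://localhost:7860/tokenize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Ada Lovelace visited London."}'
#   # -> {"tokens": [{"text": "Ada", "start": 0, "end": 3},
#   #                {"text": "Lovelace", "start": 4, "end": 12}, ...],
#   #     "text": "Ada Lovelace visited London."}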
@app.route('/find_token_boundaries', methods=['POST'])
def find_token_boundaries():
"""
Given a text selection, find the token boundaries that encompass it
"""
    data = request.get_json(silent=True) or {}
text = data.get('text', '')
start = data.get('start', 0)
end = data.get('end', 0)
label = data.get('label', 'UNLABELED')
if not text:
return jsonify({'error': 'No text provided'}), 400
# Process text with spaCy
doc = nlp(text)
# Find tokens that overlap with the selection
token_start = None
token_end = None
for token in doc:
# Check if token overlaps with selection
if token.idx < end and token.idx + len(token.text) > start:
if token_start is None:
token_start = token.idx
token_end = token.idx + len(token.text)
# If no tokens found, return original boundaries
if token_start is None:
token_start = start
token_end = end
return jsonify({
'start': token_start,
'end': token_end,
'selected_text': text[token_start:token_end],
'label': label
})
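# Example: a selection that starts or ends mid-token is snapped outward to full
# token boundaries (hypothetical payload). Chars 1-6 of "Ada Lovelace" touch
# both tokens, so the route returns start=0, end=12 ("Ada Lovelace"):
#   curl -X POST http://localhost:7860/find_token_boundaries \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Ada Lovelace", "start": 1, "end": 6, "label": "PERSON"}'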
@app.route('/get_default_labels', methods=['GET'])
def get_default_labels():
"""
Return the default annotation labels with their colors
"""
default_labels = [
{'name': 'PERSON', 'color': '#fef3c7', 'border': '#f59e0b'},
{'name': 'LOCATION', 'color': '#dbeafe', 'border': '#3b82f6'},
{'name': 'ORGANIZATION', 'color': '#dcfce7', 'border': '#10b981'}
]
return jsonify({'labels': default_labels})
@app.route('/get_default_relationship_labels', methods=['GET'])
def get_default_relationship_labels():
"""
Return the default relationship labels with their colors
"""
default_relationship_labels = [
{'name': 'worked at', 'color': '#fce7f3', 'border': '#ec4899'},
{'name': 'visited', 'color': '#f3e8ff', 'border': '#a855f7'}
]
return jsonify({'relationship_labels': default_relationship_labels})
def get_or_create_gliner_pipeline(labels):
    """
    Get or create a GLiNER spaCy pipeline configured with the specified labels.
    The pipeline is rebuilt only when the label set changes.
    """
    global gliner_nlp, gliner_current_labels
    # GLiNER expects lowercase labels
    gliner_labels = [label.lower() for label in labels]
    # Reuse the cached pipeline when the labels are unchanged
    if gliner_nlp is not None and gliner_current_labels == gliner_labels:
        return gliner_nlp
    try:
        custom_spacy_config = {
            "gliner_model": "gliner-community/gliner_small-v2.5",
            "chunk_size": 250,
            "labels": gliner_labels,
            "style": "ent"
        }
        gliner_nlp = spacy.blank("en")
        gliner_nlp.add_pipe("gliner_spacy", config=custom_spacy_config)
        gliner_current_labels = gliner_labels
        return gliner_nlp
    except Exception as e:
        print(f"Error creating GLiNER pipeline: {e}")
        return None
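# Usage sketch (hypothetical labels): the returned pipeline behaves like any
# spaCy nlp object, so predicted entities surface on doc.ents:
#   pipeline = get_or_create_gliner_pipeline(["PERSON", "LOCATION"])
#   doc = pipeline("Marie Curie was born in Warsaw.")
#   # -> doc.ents holds spans labeled "person" / "location" (lowercased)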
@app.route('/run_gliner', methods=['POST'])
def run_gliner():
"""
Run GLiNER entity extraction on the provided text with specified labels
"""
    data = request.get_json(silent=True) or {}
text = data.get('text', '')
labels = data.get('labels', [])
if not text:
return jsonify({'error': 'No text provided'}), 400
if not labels:
return jsonify({'error': 'No labels provided'}), 400
try:
# Get or create GLiNER pipeline
pipeline = get_or_create_gliner_pipeline(labels)
if pipeline is None:
return jsonify({'error': 'Failed to initialize GLiNER pipeline'}), 500
# Process text with GLiNER
doc = pipeline(text)
# Extract entities with token boundaries
entities = []
for ent in doc.ents:
# Map GLiNER label back to user's label format
original_label = None
for label in labels:
if label.lower() == ent.label_.lower():
original_label = label
break
if original_label:
entities.append({
'text': ent.text,
'start': ent.start_char,
'end': ent.end_char,
'label': original_label,
                    # gliner-spacy stores the prediction score on the span's
                    # custom extension (ent._.score) when available
                    'confidence': ent._.score if ent.has_extension('score') else 1.0
})
return jsonify({
'entities': entities,
'total_found': len(entities)
})
except Exception as e:
print(f"GLiNER processing error: {e}")
return jsonify({'error': f'GLiNER processing failed: {str(e)}'}), 500
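# Example request/response for /run_gliner (hypothetical payload; confidence
# values are illustrative):
#   curl -X POST http://localhost:7860/run_gliner \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Marie Curie was born in Warsaw.", "labels": ["PERSON", "LOCATION"]}'
#   # -> {"entities": [{"text": "Marie Curie", "start": 0, "end": 11,
#   #                   "label": "PERSON", "confidence": 0.95}, ...],
#   #     "total_found": 2}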
@app.route('/run_gliner_relationships', methods=['POST'])
def run_gliner_relationships():
"""
Run GLiNER relationship extraction on the provided text with specified relationship labels
"""
    data = request.get_json(silent=True) or {}
text = data.get('text', '')
relationship_labels = data.get('relationship_labels', [])
entity_labels = data.get('entity_labels', ["person", "organization", "location", "date", "place"])
if not text:
return jsonify({'error': 'No text provided'}), 400
if not relationship_labels:
return jsonify({'error': 'No relationship labels provided'}), 400
try:
# Get GLiNER multitask model
model = get_or_create_multitask_model()
if model is None:
return jsonify({'error': 'Failed to initialize GLiNER multitask model'}), 500
# First extract entities using the provided entity labels
print(f"Using entity labels: {entity_labels}")
entities = model.predict_entities(text, entity_labels, threshold=0.3)
        print(f"Extracted entities: {entities}")
# Then extract relationships using the specific format
formatted_labels = []
for label in relationship_labels:
for entity_label in entity_labels:
formatted_labels.append(f"{entity_label} <> {label}")
print(f"Formatted relationship labels: {formatted_labels}")
relation_entities = model.predict_entities(text, formatted_labels, threshold=0.3)
# Process results into relationship triplets
relationships = []
# Group relation entities by their relation type and try to find entity pairs
for rel_entity in relation_entities:
            print(f"Relation candidate: {rel_entity}")
label_parts = rel_entity['label'].split(' <> ')
if len(label_parts) == 2:
entity_type, relation_type = label_parts
# Find potential subject and object entities near this relation
rel_start = rel_entity['start']
rel_end = rel_entity['end']
                # Look for entities within 100 characters before and after the relation mention
                subject_candidates = [e for e in entities if e['end'] <= rel_start and rel_start - e['end'] < 100]
                object_candidates = [e for e in entities if e['start'] >= rel_end and e['start'] - rel_end < 100]
# Also look for entities that contain or are contained by the relation text
overlapping_entities = [e for e in entities if
(e['start'] <= rel_start and e['end'] >= rel_end) or # entity contains relation
(rel_start <= e['start'] and rel_end >= e['end']) # relation contains entity
]
if subject_candidates and object_candidates:
# Take the closest entities
subject = max(subject_candidates, key=lambda x: x['end'])
object_entity = min(object_candidates, key=lambda x: x['start'])
relationships.append({
'subject': subject['text'],
'subject_start': subject['start'],
'subject_end': subject['end'],
'relation_type': relation_type,
'relation_text': rel_entity['text'],
'relation_start': rel_entity['start'],
'relation_end': rel_entity['end'],
'object': object_entity['text'],
'object_start': object_entity['start'],
'object_end': object_entity['end'],
'confidence': rel_entity['score'],
'full_text': f"{subject['text']} {relation_type} {object_entity['text']}"
})
elif overlapping_entities:
# Handle cases where the relation text spans or overlaps with entities
for ent in overlapping_entities:
relationships.append({
'subject': ent['text'],
'subject_start': ent['start'],
'subject_end': ent['end'],
'relation_type': relation_type,
'relation_text': rel_entity['text'],
'relation_start': rel_entity['start'],
'relation_end': rel_entity['end'],
'object': '', # Will be filled by user or further processing
'object_start': -1,
'object_end': -1,
'confidence': rel_entity['score'],
'full_text': f"{ent['text']} {relation_type} [object]"
})
return jsonify({
'relationships': relationships,
'total_found': len(relationships)
})
except Exception as e:
print(f"GLiNER relationship processing error: {e}")
return jsonify({'error': f'GLiNER relationship processing failed: {str(e)}'}), 500
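# Example request for /run_gliner_relationships (hypothetical payload). Each
# returned triplet pairs the closest entity before the relation mention
# (subject) with the closest entity after it (object):
#   curl -X POST http://localhost:7860/run_gliner_relationships \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Ada Lovelace worked at the University of London.",
#             "relationship_labels": ["worked at"],
#             "entity_labels": ["person", "organization"]}'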
@app.route('/search_wikidata', methods=['POST'])
def search_wikidata():
"""
Search Wikidata for entities matching the query
"""
    data = request.get_json(silent=True) or {}
query = data.get('query', '').strip()
limit = data.get('limit', 10)
if not query:
return jsonify({'error': 'No query provided'}), 400
try:
# Wikidata search API endpoint
url = 'https://www.wikidata.org/w/api.php'
params = {
'action': 'wbsearchentities',
'search': query,
'language': 'en',
'format': 'json',
'limit': limit,
'type': 'item'
}
headers = {
'User-Agent': 'AnnotationTool/1.0 (https://github.com/user/annotation-tool) Python/requests'
}
response = requests.get(url, params=params, headers=headers, timeout=10)
response.raise_for_status()
data = response.json()
# Extract relevant information
results = []
if 'search' in data:
for item in data['search']:
result = {
'id': item.get('id', ''),
'label': item.get('label', ''),
'description': item.get('description', ''),
'url': f"https://www.wikidata.org/wiki/{item.get('id', '')}"
}
results.append(result)
return jsonify({
'results': results,
'total': len(results)
})
except requests.exceptions.RequestException as e:
print(f"Wikidata API error: {e}")
return jsonify({'error': 'Failed to search Wikidata'}), 500
except Exception as e:
print(f"Wikidata search error: {e}")
return jsonify({'error': f'Search failed: {str(e)}'}), 500
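# Example request/response for /search_wikidata (hypothetical payload; Q7259 is
# the Wikidata item for Ada Lovelace):
#   curl -X POST http://localhost:7860/search_wikidata \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Ada Lovelace", "limit": 5}'
#   # -> {"results": [{"id": "Q7259", "label": "Ada Lovelace",
#   #                  "description": "...",
#   #                  "url": "https://www.wikidata.org/wiki/Q7259"}, ...],
#   #     "total": 5}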
@app.route('/get_wikidata_entity', methods=['POST'])
def get_wikidata_entity():
"""
Get Wikidata entity information by Q-code
"""
    data = request.get_json(silent=True) or {}
qcode = data.get('qcode', '').strip()
if not qcode:
return jsonify({'error': 'No Q-code provided'}), 400
    # Normalize the Q-code format (e.g. "42" or "q42" -> "Q42")
    qcode = qcode.upper()
    if not qcode.startswith('Q'):
        qcode = 'Q' + qcode
try:
# Wikidata entity API endpoint
url = 'https://www.wikidata.org/w/api.php'
params = {
'action': 'wbgetentities',
'ids': qcode,
'languages': 'en',
'format': 'json'
}
headers = {
'User-Agent': 'AnnotationTool/1.0 (https://github.com/user/annotation-tool) Python/requests'
}
response = requests.get(url, params=params, headers=headers, timeout=10)
response.raise_for_status()
data = response.json()
if 'entities' in data and qcode in data['entities']:
entity = data['entities'][qcode]
if 'missing' in entity:
return jsonify({'error': f'Entity {qcode} not found'}), 404
# Extract information
result = {
'id': qcode,
'label': entity.get('labels', {}).get('en', {}).get('value', ''),
'description': entity.get('descriptions', {}).get('en', {}).get('value', ''),
'url': f"https://www.wikidata.org/wiki/{qcode}"
}
return jsonify({'entity': result})
else:
return jsonify({'error': f'Entity {qcode} not found'}), 404
except requests.exceptions.RequestException as e:
print(f"Wikidata API error: {e}")
return jsonify({'error': 'Failed to get Wikidata entity'}), 500
except Exception as e:
print(f"Wikidata entity error: {e}")
return jsonify({'error': f'Request failed: {str(e)}'}), 500
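# Example request/response for /get_wikidata_entity (hypothetical payload);
# bare numbers and lowercase q-codes are normalized to "Q..." first:
#   curl -X POST http://localhost:7860/get_wikidata_entity \
#        -H "Content-Type: application/json" \
#        -d '{"qcode": "Q7259"}'
#   # -> {"entity": {"id": "Q7259", "label": "Ada Lovelace", "description": "...",
#   #                "url": "https://www.wikidata.org/wiki/Q7259"}}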
if __name__ == '__main__':
    # Default to port 7860, the standard Hugging Face Spaces port
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)