maahi2412 committed on
Commit eb27057 · verified · 1 Parent(s): 0343868

Update app.py

Files changed (1)
  1. app.py +408 -408
app.py CHANGED
@@ -1,467 +1,467 @@
- # from flask import Flask, request, jsonify
- # import os
- # import pdfplumber
- # import pytesseract
- # from PIL import Image
- # from transformers import PegasusForConditionalGeneration, PegasusTokenizer
- # import torch
- # import logging
-
- # app = Flask(__name__)
-
- # # Set up logging
- # logging.basicConfig(level=logging.INFO)
- # logger = logging.getLogger(__name__)
-
- # # Load Pegasus Model (load once globally)
- # logger.info("Loading Pegasus model and tokenizer...")
- # tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
- # model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum").to("cpu")  # Force CPU to manage memory
- # logger.info("Model loaded successfully.")
-
- # # Extract text from PDF with page limit
- # def extract_text_from_pdf(file_path, max_pages=5):
- #     text = ""
- #     try:
- #         with pdfplumber.open(file_path) as pdf:
- #             total_pages = len(pdf.pages)
- #             pages_to_process = min(total_pages, max_pages)
- #             logger.info(f"Extracting text from {pages_to_process} of {total_pages} pages in {file_path}")
- #             for i, page in enumerate(pdf.pages[:pages_to_process]):
- #                 try:
- #                     extracted = page.extract_text()
- #                     if extracted:
- #                         text += extracted + "\n"
- #                     else:
- #                         logger.info(f"No text on page {i+1}, attempting OCR...")
- #                         image = page.to_image().original
- #                         text += pytesseract.image_to_string(image) + "\n"
- #                 except Exception as e:
- #                     logger.warning(f"Error processing page {i+1}: {e}")
- #                     continue
- #     except Exception as e:
- #         logger.error(f"Failed to process PDF {file_path}: {e}")
- #         return ""
- #     return text.strip()
-
- # # Extract text from image (OCR)
- # def extract_text_from_image(file_path):
- #     try:
- #         logger.info(f"Extracting text from image {file_path} using OCR...")
- #         image = Image.open(file_path)
- #         text = pytesseract.image_to_string(image)
- #         return text.strip()
- #     except Exception as e:
- #         logger.error(f"Failed to process image {file_path}: {e}")
- #         return ""
-
- # # Summarize text with chunking for large inputs
- # def summarize_text(text, max_input_length=512, max_output_length=150):
- #     try:
- #         logger.info("Summarizing text...")
- #         # Tokenize and truncate to max_input_length
- #         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_input_length, padding=True)
- #         input_length = inputs["input_ids"].shape[1]
- #         logger.info(f"Input length: {input_length} tokens")
-
- #         # Adjust generation params for efficiency
- #         summary_ids = model.generate(
- #             inputs["input_ids"],
- #             max_length=max_output_length,
- #             min_length=30,
- #             num_beams=2,  # Reduce beams for speedup
- #             early_stopping=True,
- #             length_penalty=1.0,  # Encourage shorter outputs
- #         )
- #         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
- #         logger.info("Summarization completed.")
- #         return summary
- #     except Exception as e:
- #         logger.error(f"Error during summarization: {e}")
- #         return ""
-
- # @app.route('/summarize', methods=['POST'])
- # def summarize_document():
- #     if 'file' not in request.files:
- #         logger.error("No file uploaded in request.")
- #         return jsonify({"error": "No file uploaded"}), 400
-
- #     file = request.files['file']
- #     filename = file.filename
- #     if not filename:
- #         logger.error("Empty filename in request.")
- #         return jsonify({"error": "No file uploaded"}), 400
-
- #     file_path = os.path.join("/tmp", filename)
- #     try:
- #         file.save(file_path)
- #         logger.info(f"File saved to {file_path}")
-
- #         if filename.lower().endswith('.pdf'):
- #             text = extract_text_from_pdf(file_path, max_pages=2)  # Reduce to 2 pages
- #         elif filename.lower().endswith(('.png', '.jpeg', '.jpg')):
- #             text = extract_text_from_image(file_path)
- #         else:
- #             logger.error(f"Unsupported file format: {filename}")
- #             return jsonify({"error": "Unsupported file format. Use PDF, PNG, JPEG, or JPG"}), 400
-
- #         if not text:
- #             logger.warning(f"No text extracted from {filename}")
- #             return jsonify({"error": "No text extracted from the file"}), 400
-
- #         summary = summarize_text(text)
- #         if not summary:
- #             logger.warning("Summarization failed to produce output.")
- #             return jsonify({"error": "Failed to generate summary"}), 500
-
- #         logger.info(f"Summary generated for {filename}")
- #         return jsonify({"summary": summary})
-
- #     except Exception as e:
- #         logger.error(f"Unexpected error processing {filename}: {e}")
- #         return jsonify({"error": str(e)}), 500
-
- #     finally:
- #         if os.path.exists(file_path):
- #             try:
- #                 os.remove(file_path)
- #                 logger.info(f"Cleaned up file: {file_path}")
- #             except Exception as e:
- #                 logger.warning(f"Failed to delete {file_path}: {e}")
-
- # if __name__ == '__main__':
- #     logger.info("Starting Flask app...")
- #     app.run(host='0.0.0.0', port=7860)
-
-
- import os
- import pdfplumber
- from PIL import Image
- import pytesseract
- import transformers
- from transformers import logging
- logging.set_verbosity_error()
- import numpy as np
- from flask import Flask, request, jsonify
- from flask_cors import CORS
- from transformers import PegasusForConditionalGeneration, PegasusTokenizer, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
- from datasets import load_dataset, concatenate_datasets
- import torch
- from sklearn.feature_extraction.text import TfidfVectorizer
- from sklearn.metrics.pairwise import cosine_similarity
-
- app = Flask(__name__)
- CORS(app)
- UPLOAD_FOLDER = 'uploads'
- PEGASUS_MODEL_DIR = 'fine_tuned_pegasus'
- BERT_MODEL_DIR = 'fine_tuned_bert'
- LEGALBERT_MODEL_DIR = 'fine_tuned_legalbert'
- MAX_FILE_SIZE = 100 * 1024 * 1024
- os.makedirs(UPLOAD_FOLDER, exist_ok=True)
-
- transformers.logging.set_verbosity_error()
- os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
-
- # Pegasus Fine-Tuning
- def load_or_finetune_pegasus():
-     if os.path.exists(PEGASUS_MODEL_DIR):
-         print("Loading fine-tuned Pegasus model...")
-         tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL_DIR)
-         model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL_DIR)
-     else:
-         print("Fine-tuning Pegasus on CNN/Daily Mail and XSUM...")
-         tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
-         model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
-
-         # Load and combine datasets
-         cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]")  # 5K samples
-         xsum = load_dataset("xsum", split="train[:5000]")  # 5K samples
-         combined_dataset = concatenate_datasets([cnn_dm, xsum])
-
-         def preprocess_function(examples):
-             inputs = tokenizer(examples["article"] if "article" in examples else examples["document"],
-                                max_length=512, truncation=True, padding="max_length")
-             targets = tokenizer(examples["highlights"] if "highlights" in examples else examples["summary"],
-                                 max_length=400, truncation=True, padding="max_length")
-             inputs["labels"] = targets["input_ids"]
-             return inputs
-
-         tokenized_dataset = combined_dataset.map(preprocess_function, batched=True)
-         train_dataset = tokenized_dataset.select(range(8000))  # 80%
-         eval_dataset = tokenized_dataset.select(range(8000, 10000))  # 20%
-
-         training_args = TrainingArguments(
-             output_dir="./pegasus_finetune",
-             num_train_epochs=3,  # Increased for better fine-tuning
-             per_device_train_batch_size=1,
-             per_device_eval_batch_size=1,
-             warmup_steps=500,
-             weight_decay=0.01,
-             logging_dir="./logs",
-             logging_steps=10,
-             eval_strategy="epoch",
-             save_strategy="epoch",
-             load_best_model_at_end=True,
-         )
-
-         trainer = Trainer(
-             model=model,
-             args=training_args,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-         )
-
-         trainer.train()
-         trainer.save_model(PEGASUS_MODEL_DIR)
-         tokenizer.save_pretrained(PEGASUS_MODEL_DIR)
-         print(f"Fine-tuned Pegasus saved to {PEGASUS_MODEL_DIR}")
-
-     return tokenizer, model
-
- # BERT Fine-Tuning
- def load_or_finetune_bert():
-     if os.path.exists(BERT_MODEL_DIR):
-         print("Loading fine-tuned BERT model...")
-         tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_DIR)
-         model = BertForSequenceClassification.from_pretrained(BERT_MODEL_DIR, num_labels=2)
-     else:
-         print("Fine-tuning BERT on CNN/Daily Mail for extractive summarization...")
-         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-         model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
-
-         # Load dataset and preprocess for sentence classification
-         cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]")
-
-         def preprocess_for_extractive(examples):
-             sentences = []
-             labels = []
-             for article, highlights in zip(examples["article"], examples["highlights"]):
-                 article_sents = article.split(". ")
-                 highlight_sents = highlights.split(". ")
-                 for sent in article_sents:
-                     if sent.strip():
-                         # Label as 1 if sentence is similar to any highlight, else 0
-                         is_summary = any(sent.strip() in h for h in highlight_sents)
-                         sentences.append(sent)
-                         labels.append(1 if is_summary else 0)
-             return {"sentence": sentences, "label": labels}
-
-         dataset = cnn_dm.map(preprocess_for_extractive, batched=True, remove_columns=["article", "highlights", "id"])
-         tokenized_dataset = dataset.map(
-             lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
-             batched=True
-         )
-         tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
-         train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
-         eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
-
-         training_args = TrainingArguments(
-             output_dir="./bert_finetune",
-             num_train_epochs=3,
-             per_device_train_batch_size=8,
-             per_device_eval_batch_size=8,
-             warmup_steps=500,
-             weight_decay=0.01,
-             logging_dir="./logs",
-             logging_steps=10,
-             eval_strategy="epoch",
-             save_strategy="epoch",
-             load_best_model_at_end=True,
-         )
-
-         trainer = Trainer(
-             model=model,
-             args=training_args,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-         )
-
-         trainer.train()
-         trainer.save_model(BERT_MODEL_DIR)
-         tokenizer.save_pretrained(BERT_MODEL_DIR)
-         print(f"Fine-tuned BERT saved to {BERT_MODEL_DIR}")
-
-     return tokenizer, model
-
- # LegalBERT Fine-Tuning
- def load_or_finetune_legalbert():
-     if os.path.exists(LEGALBERT_MODEL_DIR):
-         print("Loading fine-tuned LegalBERT model...")
-         tokenizer = BertTokenizer.from_pretrained(LEGALBERT_MODEL_DIR)
-         model = BertForSequenceClassification.from_pretrained(LEGALBERT_MODEL_DIR, num_labels=2)
-     else:
-         print("Fine-tuning LegalBERT on Billsum for extractive summarization...")
-         tokenizer = BertTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
-         model = BertForSequenceClassification.from_pretrained("nlpaueb/legal-bert-base-uncased", num_labels=2)
-
-         # Load dataset
-         billsum = load_dataset("billsum", split="train[:5000]")
-
-         def preprocess_for_extractive(examples):
-             sentences = []
-             labels = []
-             for text, summary in zip(examples["text"], examples["summary"]):
-                 text_sents = text.split(". ")
-                 summary_sents = summary.split(". ")
-                 for sent in text_sents:
-                     if sent.strip():
-                         is_summary = any(sent.strip() in s for s in summary_sents)
-                         sentences.append(sent)
-                         labels.append(1 if is_summary else 0)
-             return {"sentence": sentences, "label": labels}
-
-         dataset = billsum.map(preprocess_for_extractive, batched=True, remove_columns=["text", "summary", "title"])
-         tokenized_dataset = dataset.map(
-             lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
-             batched=True
-         )
-         tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
-         train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
-         eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
-
-         training_args = TrainingArguments(
-             output_dir="./legalbert_finetune",
-             num_train_epochs=3,
-             per_device_train_batch_size=8,
-             per_device_eval_batch_size=8,
-             warmup_steps=500,
-             weight_decay=0.01,
-             logging_dir="./logs",
-             logging_steps=10,
-             eval_strategy="epoch",
-             save_strategy="epoch",
-             load_best_model_at_end=True,
-         )
-
-         trainer = Trainer(
-             model=model,
-             args=training_args,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-         )
-
-         trainer.train()
-         trainer.save_model(LEGALBERT_MODEL_DIR)
-         tokenizer.save_pretrained(LEGALBERT_MODEL_DIR)
-         print(f"Fine-tuned LegalBERT saved to {LEGALBERT_MODEL_DIR}")
-
-     return tokenizer, model
-
- # Load models
- # pegasus_tokenizer, pegasus_model = load_or_finetune_pegasus()
- # bert_tokenizer, bert_model = load_or_finetune_bert()
- # legalbert_tokenizer, legalbert_model = load_or_finetune_legalbert()
-
- def extract_text_from_pdf(file_path):
-     text = ""
-     with pdfplumber.open(file_path) as pdf:
-         for page in pdf.pages:
-             text += page.extract_text() or ""
-     return text
-
- def extract_text_from_image(file_path):
-     image = Image.open(file_path)
-     text = pytesseract.image_to_string(image)
-     return text
-
- def choose_model(text):
-     legal_keywords = ["court", "legal", "law", "judgment", "contract", "statute", "case"]
-     tfidf = TfidfVectorizer(vocabulary=legal_keywords)
-     tfidf_matrix = tfidf.fit_transform([text.lower()])
-     score = np.sum(tfidf_matrix.toarray())
-     if score > 0.1:
-         return "legalbert"
-     elif len(text.split()) > 50:
-         return "pegasus"
-     else:
-         return "bert"
-
- def summarize_with_pegasus(text):
-     inputs = pegasus_tokenizer(text, truncation=True, padding="longest", return_tensors="pt", max_length=512)
-     summary_ids = pegasus_model.generate(
-         inputs["input_ids"],
-         max_length=400, min_length=80, length_penalty=1.5, num_beams=4
-     )
-     return pegasus_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
- def summarize_with_bert(text):
-     sentences = text.split(". ")
-     if len(sentences) < 6:  # Ensure enough for 5 sentences
-         return text
-     inputs = bert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
-     with torch.no_grad():
-         outputs = bert_model(**inputs)
-         logits = outputs.logits
-         probs = torch.softmax(logits, dim=1)[:, 1]  # Probability of being a summary sentence
-     key_sentence_idx = probs.argsort(descending=True)[:5]  # Top 5 sentences
-     return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])
-
- def summarize_with_legalbert(text):
-     sentences = text.split(". ")
-     if len(sentences) < 6:
-         return text
-     inputs = legalbert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
-     with torch.no_grad():
-         outputs = legalbert_model(**inputs)
-         logits = outputs.logits
-         probs = torch.softmax(logits, dim=1)[:, 1]
-     key_sentence_idx = probs.argsort(descending=True)[:5]
-     return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])
-
- # Load Models
- pegasus_tokenizer, pegasus_model = load_or_finetune_pegasus()
- bert_tokenizer, bert_model = load_or_finetune_bert()
- legalbert_tokenizer, legalbert_model = load_or_finetune_legalbert()
-
- @app.route('/summarize', methods=['POST'])
- def summarize_document():
-     if 'file' not in request.files:
-         return jsonify({"error": "No file uploaded"}), 400
-
-     file = request.files['file']
-     filename = file.filename
-     file.seek(0, os.SEEK_END)
-     file_size = file.tell()
-     if file_size > MAX_FILE_SIZE:
-         return jsonify({"error": f"File size exceeds {MAX_FILE_SIZE // (1024 * 1024)} MB"}), 413
-     file.seek(0)
-     file_path = os.path.join(UPLOAD_FOLDER, filename)
-     try:
-         file.save(file_path)
-     except Exception as e:
-         return jsonify({"error": f"Failed to save file: {str(e)}"}), 500
-
-     try:
-         if filename.endswith('.pdf'):
-             text = extract_text_from_pdf(file_path)
-         elif filename.endswith(('.png', '.jpeg', '.jpg')):
-             text = extract_text_from_image(file_path)
-         else:
-             os.remove(file_path)
-             return jsonify({"error": "Unsupported file format."}), 400
-     except Exception as e:
-         os.remove(file_path)
-         return jsonify({"error": f"Text extraction failed: {str(e)}"}), 500
-
-     if not text.strip():
-         os.remove(file_path)
-         return jsonify({"error": "No text extracted"}), 400
-
-     try:
-         model = choose_model(text)
-         if model == "pegasus":
-             summary = summarize_with_pegasus(text)
-         elif model == "bert":
-             summary = summarize_with_bert(text)
-         elif model == "legalbert":
-             summary = summarize_with_legalbert(text)
-     except Exception as e:
-         os.remove(file_path)
-         return jsonify({"error": f"Summarization failed: {str(e)}"}), 500
-
-     os.remove(file_path)
-     return jsonify({"model_used": model, "summary": summary})
-
- if __name__ == '__main__':
-     port = int(os.environ.get("PORT", 5000))
-     app.run(debug=False, host='0.0.0.0', port=port)
+ from flask import Flask, request, jsonify
+ import os
+ import pdfplumber
+ import pytesseract
+ from PIL import Image
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
+ import torch
+ import logging
+
+ app = Flask(__name__)
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Load Pegasus Model (load once globally)
+ logger.info("Loading Pegasus model and tokenizer...")
+ tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
+ model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum").to("cpu")  # Force CPU to manage memory
+ logger.info("Model loaded successfully.")
+
+ # Extract text from PDF with page limit
+ def extract_text_from_pdf(file_path, max_pages=5):
+     text = ""
+     try:
+         with pdfplumber.open(file_path) as pdf:
+             total_pages = len(pdf.pages)
+             pages_to_process = min(total_pages, max_pages)
+             logger.info(f"Extracting text from {pages_to_process} of {total_pages} pages in {file_path}")
+             for i, page in enumerate(pdf.pages[:pages_to_process]):
+                 try:
+                     extracted = page.extract_text()
+                     if extracted:
+                         text += extracted + "\n"
+                     else:
+                         logger.info(f"No text on page {i+1}, attempting OCR...")
+                         image = page.to_image().original
+                         text += pytesseract.image_to_string(image) + "\n"
+                 except Exception as e:
+                     logger.warning(f"Error processing page {i+1}: {e}")
+                     continue
+     except Exception as e:
+         logger.error(f"Failed to process PDF {file_path}: {e}")
+         return ""
+     return text.strip()
+
+ # Extract text from image (OCR)
+ def extract_text_from_image(file_path):
+     try:
+         logger.info(f"Extracting text from image {file_path} using OCR...")
+         image = Image.open(file_path)
+         text = pytesseract.image_to_string(image)
+         return text.strip()
+     except Exception as e:
+         logger.error(f"Failed to process image {file_path}: {e}")
+         return ""
+
+ # Summarize text with chunking for large inputs
+ def summarize_text(text, max_input_length=512, max_output_length=150):
+     try:
+         logger.info("Summarizing text...")
+         # Tokenize and truncate to max_input_length
+         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_input_length, padding=True)
+         input_length = inputs["input_ids"].shape[1]
+         logger.info(f"Input length: {input_length} tokens")
+
+         # Adjust generation params for efficiency
+         summary_ids = model.generate(
+             inputs["input_ids"],
+             max_length=max_output_length,
+             min_length=30,
+             num_beams=2,  # Reduce beams for speedup
+             early_stopping=True,
+             length_penalty=1.0,  # Encourage shorter outputs
+         )
+         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+         logger.info("Summarization completed.")
+         return summary
+     except Exception as e:
+         logger.error(f"Error during summarization: {e}")
+         return ""
+
+ @app.route('/summarize', methods=['POST'])
+ def summarize_document():
+     if 'file' not in request.files:
+         logger.error("No file uploaded in request.")
+         return jsonify({"error": "No file uploaded"}), 400
+
+     file = request.files['file']
+     filename = file.filename
+     if not filename:
+         logger.error("Empty filename in request.")
+         return jsonify({"error": "No file uploaded"}), 400
+
+     file_path = os.path.join("/tmp", filename)
+     try:
+         file.save(file_path)
+         logger.info(f"File saved to {file_path}")
+
+         if filename.lower().endswith('.pdf'):
+             text = extract_text_from_pdf(file_path, max_pages=2)  # Reduce to 2 pages
+         elif filename.lower().endswith(('.png', '.jpeg', '.jpg')):
+             text = extract_text_from_image(file_path)
+         else:
+             logger.error(f"Unsupported file format: {filename}")
+             return jsonify({"error": "Unsupported file format. Use PDF, PNG, JPEG, or JPG"}), 400
+
+         if not text:
+             logger.warning(f"No text extracted from {filename}")
+             return jsonify({"error": "No text extracted from the file"}), 400
+
+         summary = summarize_text(text)
+         if not summary:
+             logger.warning("Summarization failed to produce output.")
+             return jsonify({"error": "Failed to generate summary"}), 500
+
+         logger.info(f"Summary generated for {filename}")
+         return jsonify({"summary": summary})
+
+     except Exception as e:
+         logger.error(f"Unexpected error processing {filename}: {e}")
+         return jsonify({"error": str(e)}), 500
+
+     finally:
+         if os.path.exists(file_path):
+             try:
+                 os.remove(file_path)
+                 logger.info(f"Cleaned up file: {file_path}")
+             except Exception as e:
+                 logger.warning(f"Failed to delete {file_path}: {e}")
+
+ if __name__ == '__main__':
+     logger.info("Starting Flask app...")
+     app.run(host='0.0.0.0', port=7860)
+
+ # ---------------------------------
+ # import os
+ # import pdfplumber
+ # from PIL import Image
+ # import pytesseract
+ # import transformers
+ # from transformers import logging
+ # logging.set_verbosity_error()
+ # import numpy as np
+ # from flask import Flask, request, jsonify
+ # from flask_cors import CORS
+ # from transformers import PegasusForConditionalGeneration, PegasusTokenizer, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
+ # from datasets import load_dataset, concatenate_datasets
+ # import torch
+ # from sklearn.feature_extraction.text import TfidfVectorizer
+ # from sklearn.metrics.pairwise import cosine_similarity
+
+ # app = Flask(__name__)
+ # CORS(app)
+ # UPLOAD_FOLDER = 'uploads'
+ # PEGASUS_MODEL_DIR = 'fine_tuned_pegasus'
+ # BERT_MODEL_DIR = 'fine_tuned_bert'
+ # LEGALBERT_MODEL_DIR = 'fine_tuned_legalbert'
+ # MAX_FILE_SIZE = 100 * 1024 * 1024
+ # os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+
+ # transformers.logging.set_verbosity_error()
+ # os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
+
+ # # Pegasus Fine-Tuning
+ # def load_or_finetune_pegasus():
+ #     if os.path.exists(PEGASUS_MODEL_DIR):
+ #         print("Loading fine-tuned Pegasus model...")
+ #         tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL_DIR)
+ #         model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL_DIR)
+ #     else:
+ #         print("Fine-tuning Pegasus on CNN/Daily Mail and XSUM...")
+ #         tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
+ #         model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+
+ #         # Load and combine datasets
+ #         cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]")  # 5K samples
+ #         xsum = load_dataset("xsum", split="train[:5000]")  # 5K samples
+ #         combined_dataset = concatenate_datasets([cnn_dm, xsum])
+
+ #         def preprocess_function(examples):
+ #             inputs = tokenizer(examples["article"] if "article" in examples else examples["document"],
+ #                                max_length=512, truncation=True, padding="max_length")
+ #             targets = tokenizer(examples["highlights"] if "highlights" in examples else examples["summary"],
+ #                                 max_length=400, truncation=True, padding="max_length")
+ #             inputs["labels"] = targets["input_ids"]
+ #             return inputs
+
+ #         tokenized_dataset = combined_dataset.map(preprocess_function, batched=True)
+ #         train_dataset = tokenized_dataset.select(range(8000))  # 80%
+ #         eval_dataset = tokenized_dataset.select(range(8000, 10000))  # 20%
+
+ #         training_args = TrainingArguments(
+ #             output_dir="./pegasus_finetune",
+ #             num_train_epochs=3,  # Increased for better fine-tuning
+ #             per_device_train_batch_size=1,
+ #             per_device_eval_batch_size=1,
+ #             warmup_steps=500,
+ #             weight_decay=0.01,
+ #             logging_dir="./logs",
+ #             logging_steps=10,
+ #             eval_strategy="epoch",
+ #             save_strategy="epoch",
+ #             load_best_model_at_end=True,
+ #         )
+
+ #         trainer = Trainer(
+ #             model=model,
+ #             args=training_args,
+ #             train_dataset=train_dataset,
+ #             eval_dataset=eval_dataset,
+ #         )
+
+ #         trainer.train()
+ #         trainer.save_model(PEGASUS_MODEL_DIR)
+ #         tokenizer.save_pretrained(PEGASUS_MODEL_DIR)
+ #         print(f"Fine-tuned Pegasus saved to {PEGASUS_MODEL_DIR}")
+
+ #     return tokenizer, model
+
+ # # BERT Fine-Tuning
+ # def load_or_finetune_bert():
+ #     if os.path.exists(BERT_MODEL_DIR):
+ #         print("Loading fine-tuned BERT model...")
+ #         tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_DIR)
+ #         model = BertForSequenceClassification.from_pretrained(BERT_MODEL_DIR, num_labels=2)
+ #     else:
+ #         print("Fine-tuning BERT on CNN/Daily Mail for extractive summarization...")
+ #         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ #         model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
+
+ #         # Load dataset and preprocess for sentence classification
+ #         cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]")
+
+ #         def preprocess_for_extractive(examples):
+ #             sentences = []
+ #             labels = []
+ #             for article, highlights in zip(examples["article"], examples["highlights"]):
+ #                 article_sents = article.split(". ")
+ #                 highlight_sents = highlights.split(". ")
+ #                 for sent in article_sents:
+ #                     if sent.strip():
+ #                         # Label as 1 if sentence is similar to any highlight, else 0
+ #                         is_summary = any(sent.strip() in h for h in highlight_sents)
+ #                         sentences.append(sent)
+ #                         labels.append(1 if is_summary else 0)
+ #             return {"sentence": sentences, "label": labels}
+
+ #         dataset = cnn_dm.map(preprocess_for_extractive, batched=True, remove_columns=["article", "highlights", "id"])
+ #         tokenized_dataset = dataset.map(
+ #             lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
+ #             batched=True
+ #         )
+ #         tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
+ #         train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
+ #         eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
+
+ #         training_args = TrainingArguments(
+ #             output_dir="./bert_finetune",
+ #             num_train_epochs=3,
+ #             per_device_train_batch_size=8,
+ #             per_device_eval_batch_size=8,
+ #             warmup_steps=500,
+ #             weight_decay=0.01,
+ #             logging_dir="./logs",
+ #             logging_steps=10,
+ #             eval_strategy="epoch",
+ #             save_strategy="epoch",
+ #             load_best_model_at_end=True,
+ #         )
+
+ #         trainer = Trainer(
+ #             model=model,
+ #             args=training_args,
+ #             train_dataset=train_dataset,
+ #             eval_dataset=eval_dataset,
+ #         )
+
+ #         trainer.train()
+ #         trainer.save_model(BERT_MODEL_DIR)
+ #         tokenizer.save_pretrained(BERT_MODEL_DIR)
+ #         print(f"Fine-tuned BERT saved to {BERT_MODEL_DIR}")
+
+ #     return tokenizer, model
+
+ # # LegalBERT Fine-Tuning
+ # def load_or_finetune_legalbert():
+ #     if os.path.exists(LEGALBERT_MODEL_DIR):
+ #         print("Loading fine-tuned LegalBERT model...")
+ #         tokenizer = BertTokenizer.from_pretrained(LEGALBERT_MODEL_DIR)
+ #         model = BertForSequenceClassification.from_pretrained(LEGALBERT_MODEL_DIR, num_labels=2)
+ #     else:
+ #         print("Fine-tuning LegalBERT on Billsum for extractive summarization...")
+ #         tokenizer = BertTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
+ #         model = BertForSequenceClassification.from_pretrained("nlpaueb/legal-bert-base-uncased", num_labels=2)
+
+ #         # Load dataset
+ #         billsum = load_dataset("billsum", split="train[:5000]")
+
+ #         def preprocess_for_extractive(examples):
+ #             sentences = []
+ #             labels = []
+ #             for text, summary in zip(examples["text"], examples["summary"]):
+ #                 text_sents = text.split(". ")
+ #                 summary_sents = summary.split(". ")
+ #                 for sent in text_sents:
+ #                     if sent.strip():
+ #                         is_summary = any(sent.strip() in s for s in summary_sents)
+ #                         sentences.append(sent)
+ #                         labels.append(1 if is_summary else 0)
+ #             return {"sentence": sentences, "label": labels}
+
+ #         dataset = billsum.map(preprocess_for_extractive, batched=True, remove_columns=["text", "summary", "title"])
+ #         tokenized_dataset = dataset.map(
+ #             lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
+ #             batched=True
+ #         )
+ #         tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
+ #         train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
+ #         eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
+
+ #         training_args = TrainingArguments(
+ #             output_dir="./legalbert_finetune",
+ #             num_train_epochs=3,
+ #             per_device_train_batch_size=8,
+ #             per_device_eval_batch_size=8,
+ #             warmup_steps=500,
+ #             weight_decay=0.01,
+ #             logging_dir="./logs",
+ #             logging_steps=10,
+ #             eval_strategy="epoch",
+ #             save_strategy="epoch",
+ #             load_best_model_at_end=True,
+ #         )
+
+ #         trainer = Trainer(
+ #             model=model,
+ #             args=training_args,
+ #             train_dataset=train_dataset,
+ #             eval_dataset=eval_dataset,
+ #         )
+
+ #         trainer.train()
+ #         trainer.save_model(LEGALBERT_MODEL_DIR)
+ #         tokenizer.save_pretrained(LEGALBERT_MODEL_DIR)
+ #         print(f"Fine-tuned LegalBERT saved to {LEGALBERT_MODEL_DIR}")
+
+ #     return tokenizer, model
+
+ # # Load models
+ # # pegasus_tokenizer, pegasus_model = load_or_finetune_pegasus()
+ # # bert_tokenizer, bert_model = load_or_finetune_bert()
+ # # legalbert_tokenizer, legalbert_model = load_or_finetune_legalbert()
+
+ # def extract_text_from_pdf(file_path):
+ #     text = ""
+ #     with pdfplumber.open(file_path) as pdf:
+ #         for page in pdf.pages:
+ #             text += page.extract_text() or ""
+ #     return text
+
+ # def extract_text_from_image(file_path):
+ #     image = Image.open(file_path)
+ #     text = pytesseract.image_to_string(image)
+ #     return text
+
+ # def choose_model(text):
+ #     legal_keywords = ["court", "legal", "law", "judgment", "contract", "statute", "case"]
+ #     tfidf = TfidfVectorizer(vocabulary=legal_keywords)
+ #     tfidf_matrix = tfidf.fit_transform([text.lower()])
+ #     score = np.sum(tfidf_matrix.toarray())
+ #     if score > 0.1:
+ #         return "legalbert"
+ #     elif len(text.split()) > 50:
+ #         return "pegasus"
+ #     else:
+ #         return "bert"
+
+ # def summarize_with_pegasus(text):
+ #     inputs = pegasus_tokenizer(text, truncation=True, padding="longest", return_tensors="pt", max_length=512)
+ #     summary_ids = pegasus_model.generate(
+ #         inputs["input_ids"],
+ #         max_length=400, min_length=80, length_penalty=1.5, num_beams=4
+ #     )
+ #     return pegasus_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+ # def summarize_with_bert(text):
+ #     sentences = text.split(". ")
+ #     if len(sentences) < 6:  # Ensure enough for 5 sentences
+ #         return text
+ #     inputs = bert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
+ #     with torch.no_grad():
+ #         outputs = bert_model(**inputs)
+ #         logits = outputs.logits
+ #         probs = torch.softmax(logits, dim=1)[:, 1]  # Probability of being a summary sentence
+ #     key_sentence_idx = probs.argsort(descending=True)[:5]  # Top 5 sentences
+ #     return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])
+
+ # def summarize_with_legalbert(text):
+ #     sentences = text.split(". ")
+ #     if len(sentences) < 6:
+ #         return text
+ #     inputs = legalbert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
+ #     with torch.no_grad():
+ #         outputs = legalbert_model(**inputs)
+ #         logits = outputs.logits
+ #         probs = torch.softmax(logits, dim=1)[:, 1]
+ #     key_sentence_idx = probs.argsort(descending=True)[:5]
+ #     return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])
+
+ # # Load Models
+ # pegasus_tokenizer, pegasus_model = load_or_finetune_pegasus()
+ # bert_tokenizer, bert_model = load_or_finetune_bert()
+ # legalbert_tokenizer, legalbert_model = load_or_finetune_legalbert()
+
+ # @app.route('/summarize', methods=['POST'])
+ # def summarize_document():
+ #     if 'file' not in request.files:
+ #         return jsonify({"error": "No file uploaded"}), 400
+
+ #     file = request.files['file']
+ #     filename = file.filename
+ #     file.seek(0, os.SEEK_END)
+ #     file_size = file.tell()
+ #     if file_size > MAX_FILE_SIZE:
+ #         return jsonify({"error": f"File size exceeds {MAX_FILE_SIZE // (1024 * 1024)} MB"}), 413
+ #     file.seek(0)
+ #     file_path = os.path.join(UPLOAD_FOLDER, filename)
+ #     try:
+ #         file.save(file_path)
+ #     except Exception as e:
+ #         return jsonify({"error": f"Failed to save file: {str(e)}"}), 500
+
+ #     try:
+ #         if filename.endswith('.pdf'):
+ #             text = extract_text_from_pdf(file_path)
+ #         elif filename.endswith(('.png', '.jpeg', '.jpg')):
+ #             text = extract_text_from_image(file_path)
+ #         else:
+ #             os.remove(file_path)
+ #             return jsonify({"error": "Unsupported file format."}), 400
+ #     except Exception as e:
+ #         os.remove(file_path)
+ #         return jsonify({"error": f"Text extraction failed: {str(e)}"}), 500
+
+ #     if not text.strip():
+ #         os.remove(file_path)
+ #         return jsonify({"error": "No text extracted"}), 400
+
+ #     try:
+ #         model = choose_model(text)
+ #         if model == "pegasus":
+ #             summary = summarize_with_pegasus(text)
+ #         elif model == "bert":
+ #             summary = summarize_with_bert(text)
+ #         elif model == "legalbert":
+ #             summary = summarize_with_legalbert(text)
+ #     except Exception as e:
+ #         os.remove(file_path)
+ #         return jsonify({"error": f"Summarization failed: {str(e)}"}), 500
+
+ #     os.remove(file_path)
+ #     return jsonify({"model_used": model, "summary": summary})
+
+ # if __name__ == '__main__':
+ #     port = int(os.environ.get("PORT", 5000))
+ #     app.run(debug=False, host='0.0.0.0', port=port)
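
A minimal client-side sketch for exercising the /summarize endpoint that this commit activates. It assumes the updated app.py is running locally on port 7860 (per its __main__ block) and that the requests package is installed; the file name sample.pdf is a placeholder.

import requests

# Hypothetical smoke test: POST a PDF to the /summarize endpoint.
url = "http://localhost:7860/summarize"
with open("sample.pdf", "rb") as f:  # placeholder file path
    resp = requests.post(url, files={"file": ("sample.pdf", f, "application/pdf")})

print(resp.status_code)
print(resp.json())  # expected: {"summary": "..."} on success, {"error": "..."} otherwise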