Spaces:
Build error
Build error
| from flask import Flask, request, jsonify, render_template | |
| from backend_utils import initialize_all_components, make_predictions | |
| from config import classifier_class_mapping, config | |
| import subprocess | |
| # from flask_cors import CORS, cross_origin | |
| import json | |
| # todo: downgrade version sklearn to 1.0.2 | |
app = Flask(__name__)
# CORS(app)

# Build every model, tokenizer and database artifact once at import time so
# that all requests served by this process reuse the same objects.
components = initialize_all_components(config)

# The first ten entries of `components` are positional; unpack them into the
# module-level names the request handlers pass along to make_predictions.
(
    db_metadata,
    db_constructor,
    db_params,
    ex_list,
    model_retrieval,
    model_generative,
    tokenizer_generative,
    model_classifier,
    classifier_head,
    tokenizer_classifier,
) = components[:10]
def call_predict_api(
    input_query,
    model_retrieval,
    model_generative,
    model_classifier,
    classifier_head,
    tokenizer_generative,
    tokenizer_classifier,
    db_metadata,
    db_constructor,
    db_params,
    ex_list,
    config,
):
    """Forward a user query and the loaded components to ``make_predictions``.

    This is a thin wrapper: every argument is passed through positionally, in
    the same order it was received, and the result of ``make_predictions`` is
    returned unchanged.

    Args:
        input_query: The raw query text supplied by the client.
        model_retrieval, model_generative, model_classifier, classifier_head,
        tokenizer_generative, tokenizer_classifier, db_metadata,
        db_constructor, db_params, ex_list: Components produced by
            ``initialize_all_components``.
        config: The application configuration object.

    Returns:
        Whatever ``make_predictions`` returns.
    """
    return make_predictions(
        input_query,
        model_retrieval,
        model_generative,
        model_classifier,
        classifier_head,
        tokenizer_generative,
        tokenizer_classifier,
        db_metadata,
        db_constructor,
        db_params,
        ex_list,
        config,
    )
# Without a route registration this view is unreachable and GET / returns 404.
@app.route("/")
def hello_world():
    """Serve the frontend page.

    Renders ``index.html`` from Flask's template folder (``templates/`` by
    default, since the app is created as ``Flask(__name__)``).
    """
    return render_template("index.html")
# NOTE(review): no route decorator is visible for this view, so it was never
# registered with the app. The "/predict" path is assumed — confirm it against
# the URL that templates/index.html requests.
@app.route("/predict")
def predict():
    """Run the prediction pipeline for the ``user_query`` query-string parameter.

    Query parameters:
        user_query: The question text to predict on.

    Returns:
        A JSON response ``{"predictions": <result>}`` holding whatever
        ``make_predictions`` produced (the string ``'null'`` is passed through
        unchanged). If ``user_query`` is missing or blank, returns HTTP 400
        with ``{"error": ...}``.
    """
    user_query = request.args.get("user_query")
    print(f"user_query: {user_query}")

    # Previously a missing query made the view return None (or reference an
    # unbound `predictions`), and Flask turns either into a 500.
    if user_query is None or not user_query.strip():
        return jsonify({"error": "missing or empty 'user_query' parameter"}), 400

    print("predicting")
    predictions = call_predict_api(
        user_query,
        model_retrieval,
        model_generative,
        model_classifier,
        classifier_head,
        tokenizer_generative,
        tokenizer_classifier,
        db_metadata,
        db_constructor,
        db_params,
        ex_list,
        config,
    )
    # The old special case for the string 'null' emitted exactly the same
    # payload as this general path, so no separate branch is needed.
    return jsonify({"predictions": predictions})
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0) rather than only localhost.
    # NOTE(review): port 7860 presumably matches what the hosting platform
    # expects — confirm before changing it.
    app.run(port=7860, host="0.0.0.0")