Spaces:
Sleeping
Sleeping
| import time | |
| from glob import glob | |
| from pathlib import Path | |
| from typing import List | |
| from fastapi import FastAPI, File, Form, UploadFile | |
| from fastapi.responses import JSONResponse, Response | |
| from tqdm import tqdm | |
| from utils import * | |
| from concrete.ml.deployment import FHEModelClient, FHEModelServer | |
# Initialize the FastAPI application exposing the server-side steps of the
# encrypted DNA testing demo.
app = FastAPI()
# Define the default route.
# NOTE(review): no `@app.get("/")` decorator is visible on this function, so as
# written it is not registered with `app`; confirm the route registration.
def root():
    """Root endpoint of the encrypted DNA testing API.

    Returns:
        dict: The welcome message.
    """
    return {"message": "Welcome to your encrypted DNA testing use-case with FHE!"}
def send_input(
    user_id: str = Form(...), root_dir: str = Form(...), files: List[UploadFile] = File(...)
):
    """Store the client's evaluation keys and encrypted input windows on the server.

    The uploaded ``files`` are consumed in this order:
      * ``files[0]``: evaluation key of the base modules,
      * ``files[1]``: evaluation key of the smoother module,
      * ``files[2:]``: one encrypted input window per file.

    Args:
        user_id (str): Identifier of the user; selects the ``{user_id}/server`` folder.
        root_dir (str): Root directory under which the per-user folders live.
        files (List[UploadFile]): The uploaded files, ordered as described above.

    Raises:
        ValueError: If fewer than two files (the two evaluation keys) are uploaded.
    """
    print("------------ Step 3.2: Send the data to the server")
    print(f"{user_id=}, {root_dir=}, {len(files)=}")

    # SECURITY NOTE(review): `root_dir` and `user_id` come straight from the
    # request and are joined into filesystem paths without validation, so the
    # client chooses where these bytes are written. Consider confining them to
    # a fixed server-side base directory.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR

    # Both evaluation keys are mandatory; fail with a clear message instead of
    # an IndexError on `files[0]` / `files[1]`.
    if len(files) < 2:
        raise ValueError(
            "Expected at least 2 files (base and smoother evaluation keys), "
            f"got {len(files)}."
        )

    # Save the two evaluation keys.
    with (SERVER_KEY_BASE_MODULE_DIR / "eval_key").open("wb") as base_key_file:
        base_key_file.write(files[0].file.read())
    with (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").open("wb") as smoother_key_file:
        smoother_key_file.write(files[1].file.read())

    print(f"{len(files)=}")
    # Window files keep their upload index as suffix, so numbering starts at 2.
    # Readers sort them with `extract_model_number` (presumably by this suffix),
    # so only the relative order matters.
    for i in tqdm(range(2, len(files))):
        with (SERVER_ENCRYPTED_INPUT_DIR / f"encrypted_window_{i}").open("wb") as window_file:
            window_file.write(files[i].file.read())
def run_fhe(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Run the two-stage FHE inference on the server, reporting progress.

    Stage 1 runs every shared base module on its matching encrypted input
    window. Stage 2 feeds sliding windows of the stage-1 outputs to the shared
    smoother module and keeps its outputs encrypted. After every iteration the
    elapsed time and completion percentage are written to
    ``FHE_COMPUTATION_TIMELINE``. The encrypted smoother outputs are pickled to
    ``{root_dir}/{user_id}/server/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``.

    Args:
        user_id (str): Identifier of the user; selects the ``{user_id}/server`` folder.
        root_dir (str): Root directory under which the per-user folders live.

    Returns:
        JSONResponse: Total execution time in seconds, rounded to 2 decimals.

    Raises:
        AssertionError: If the number of base modules or of encrypted input
            windows differs from ``META["NW"]``.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")

    # SECURITY NOTE(review): `root_dir`/`user_id` are request inputs joined into
    # filesystem paths without validation.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR

    # Evaluation keys previously stored on the server for this user.
    with (SERVER_KEY_BASE_MODULE_DIR / "eval_key").open("rb") as eval_key_1:
        eval_key_base_module = eval_key_1.read()
    assert isinstance(eval_key_base_module, bytes)
    with (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").open("rb") as eval_key_2:
        eval_key_smoother_module = eval_key_2.read()
    assert isinstance(eval_key_smoother_module, bytes)

    shared_base_modules_path = sorted(
        glob(f"{SHARED_BASE_MODULE_DIR}/model_*"), key=extract_model_number
    )
    print(f"{len(shared_base_modules_path)=}")
    assert len(shared_base_modules_path) == META["NW"]

    client_encrypted_input_path = sorted(
        glob(f"{SERVER_ENCRYPTED_INPUT_DIR}/encrypted_window_*"), key=extract_model_number
    )
    print(f"{len(client_encrypted_input_path)=}")
    # Check the uploaded windows too: `zip` below would otherwise silently skip
    # base modules when fewer windows than models are present.
    assert len(client_encrypted_input_path) == META["NW"]

    # Initial progress estimate: NW base-module runs followed by the smoother
    # runs (assumed to be NW as well until the exact count is known).
    nb_total_iterations = META["NW"] * 2
    start_time = time.time()

    # Stage 1: run each base module on its encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path))
    ):
        server = FHEModelServer(model_path)
        with open(encrypted_window_path, "rb") as f:
            encrypted_window = f.read()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)
        # NOTE(review): the server decrypts the stage-1 output itself, using a
        # client whose keys are loaded from the shared model directory. Confirm
        # this is intended, since it reveals intermediate results to the server.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)
        # NOTE(review): this stores the *input* ciphertext under a
        # "decrypted_window_*" name; confirm which payload is meant to be saved.
        with (SERVER_ENCRYPTED_OUTPUT_DIR / f"decrypted_window_{i}").open("wb") as f:
            f.write(encrypted_window)
        y_proba.append(decrypted_output)
        progress = (i + 1) / nb_total_iterations
        with open(FHE_COMPUTATION_TIMELINE, "w", encoding="utf-8") as f:
            f.write(f"{time.time() - start_time:.2f} seconds ({progress:.0%})")
    nb_base_iterations = len(y_proba)

    client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)

    # Stack the per-model outputs (axis 0 = base module) and swap the first two
    # axes so the base-module index becomes axis 1.
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2))
    y_proba = y_proba.astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")
    X_slide, _ = slide_window(y_proba, META["SS"])

    # The exact amount of remaining work is known now; refine the total so the
    # reported progress ends at 100%.
    nb_total_iterations = nb_base_iterations + len(X_slide)

    # Stage 2: encrypt each sliding window and run the smoother on it. The
    # smoother outputs are kept encrypted; they are not decrypted here.
    yhat_encrypted = []
    for i in tqdm(range(len(X_slide))):
        window = X_slide[i].reshape(1, -1)
        encrypted_input = client.quantize_encrypt_serialize(window)
        encrypted_output = server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)
        # Offset by the stage-1 iterations so the progress keeps increasing
        # instead of restarting from 0%.
        progress = (nb_base_iterations + i + 1) / nb_total_iterations
        with open(FHE_COMPUTATION_TIMELINE, "w", encoding="utf-8") as f:
            f.write(f"{time.time() - start_time:.2f} seconds ({progress:.0%})")

    write_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)
    fhe_execution_time = round(time.time() - start_time, 2)
    return JSONResponse(content=fhe_execution_time)
def run_fhe_stage1(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Run the FHE inference on the server without progress reporting.

    Runs every shared base module on its matching encrypted input window, then
    feeds sliding windows of those outputs to the shared smoother module. The
    encrypted smoother outputs are pickled to
    ``{root_dir}/{user_id}/server/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``.

    NOTE(review): despite the ``_stage1`` suffix, this body runs both the
    base-module stage and the smoother stage; confirm the intended split.

    Args:
        user_id (str): Identifier of the user; selects the ``{user_id}/server`` folder.
        root_dir (str): Root directory under which the per-user folders live.

    Returns:
        JSONResponse: Total execution time in seconds, rounded to 2 decimals.

    Raises:
        AssertionError: If the number of base modules or of encrypted input
            windows differs from ``META["NW"]``.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")

    # SECURITY NOTE(review): `root_dir`/`user_id` are request inputs joined into
    # filesystem paths without validation.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR

    with (SERVER_KEY_BASE_MODULE_DIR / "eval_key").open("rb") as eval_key_1:
        eval_key_base_module = eval_key_1.read()
    assert isinstance(eval_key_base_module, bytes)
    with (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").open("rb") as eval_key_2:
        eval_key_smoother_module = eval_key_2.read()
    assert isinstance(eval_key_smoother_module, bytes)

    shared_base_modules_path = sorted(
        glob(f"{SHARED_BASE_MODULE_DIR}/model_*"), key=extract_model_number
    )
    print(f"{len(shared_base_modules_path)=}")
    assert len(shared_base_modules_path) == META["NW"]

    client_encrypted_input_path = sorted(
        glob(f"{SERVER_ENCRYPTED_INPUT_DIR}/encrypted_window_*"), key=extract_model_number
    )
    print(f"{len(client_encrypted_input_path)=}")
    # Check the uploaded windows too: `zip` below would otherwise silently skip
    # base modules when fewer windows than models are present.
    assert len(client_encrypted_input_path) == META["NW"]

    start = time.time()

    # Base modules: run each model on its encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path))
    ):
        server = FHEModelServer(model_path)
        with open(encrypted_window_path, "rb") as f:
            encrypted_window = f.read()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)
        # NOTE(review): the server decrypts this output itself using keys from
        # the shared model directory; confirm this is intended.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)
        # NOTE(review): stores the *input* ciphertext under a "decrypted_window_*"
        # name; confirm which payload is meant to be saved.
        with (SERVER_ENCRYPTED_OUTPUT_DIR / f"decrypted_window_{i}").open("wb") as f:
            f.write(encrypted_window)
        y_proba.append(decrypted_output)

    client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)

    # Stack the per-model outputs (axis 0 = base module) and swap the first two
    # axes so the base-module index becomes axis 1.
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2))
    y_proba = y_proba.astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")
    X_slide, _ = slide_window(y_proba, META["SS"])

    # Smoother: encrypt each sliding window and run it; outputs stay encrypted.
    yhat_encrypted = []
    for i in tqdm(range(len(X_slide))):
        window = X_slide[i].reshape(1, -1)
        encrypted_input = client.quantize_encrypt_serialize(window)
        encrypted_output = server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)

    write_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)
    fhe_execution_time = round(time.time() - start, 2)
    return JSONResponse(content=fhe_execution_time)
def run_fhe_stage2(
    user_id: str = Form(),
    root_dir: str = Form(...),
):
    """Run the FHE inference on the server without progress reporting.

    Runs every shared base module on its matching encrypted input window, then
    feeds sliding windows of those outputs to the shared smoother module. The
    encrypted smoother outputs are pickled to
    ``{root_dir}/{user_id}/server/<ENCRYPTED_OUTPUT_DIR>/encrypted_final_output.pkl``.

    NOTE(review): despite the ``_stage2`` suffix, this body runs both the
    base-module stage and the smoother stage; confirm the intended split.

    Args:
        user_id (str): Identifier of the user; selects the ``{user_id}/server`` folder.
        root_dir (str): Root directory under which the per-user folders live.

    Returns:
        JSONResponse: Total execution time in seconds, rounded to 2 decimals.

    Raises:
        AssertionError: If the number of base modules or of encrypted input
            windows differs from ``META["NW"]``.
    """
    print("------------ Step 4.2: Run in FHE on the Server Side")
    print(f"{user_id=}, {root_dir=}")

    # SECURITY NOTE(review): `root_dir`/`user_id` are request inputs joined into
    # filesystem paths without validation.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_KEY_SMOOTHER_MODULE_DIR = SERVER_DIR / KEY_SMOOTHER_MODULE_DIR
    SERVER_KEY_BASE_MODULE_DIR = SERVER_DIR / KEY_BASE_MODULE_DIR
    SERVER_ENCRYPTED_INPUT_DIR = SERVER_DIR / ENCRYPTED_INPUT_DIR
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR

    with (SERVER_KEY_BASE_MODULE_DIR / "eval_key").open("rb") as eval_key_1:
        eval_key_base_module = eval_key_1.read()
    assert isinstance(eval_key_base_module, bytes)
    with (SERVER_KEY_SMOOTHER_MODULE_DIR / "eval_key").open("rb") as eval_key_2:
        eval_key_smoother_module = eval_key_2.read()
    assert isinstance(eval_key_smoother_module, bytes)

    shared_base_modules_path = sorted(
        glob(f"{SHARED_BASE_MODULE_DIR}/model_*"), key=extract_model_number
    )
    print(f"{len(shared_base_modules_path)=}")
    assert len(shared_base_modules_path) == META["NW"]

    client_encrypted_input_path = sorted(
        glob(f"{SERVER_ENCRYPTED_INPUT_DIR}/encrypted_window_*"), key=extract_model_number
    )
    print(f"{len(client_encrypted_input_path)=}")
    # Check the uploaded windows too: `zip` below would otherwise silently skip
    # base modules when fewer windows than models are present.
    assert len(client_encrypted_input_path) == META["NW"]

    start = time.time()

    # Base modules: run each model on its encrypted window.
    y_proba = []
    for i, (model_path, encrypted_window_path) in tqdm(
        enumerate(zip(shared_base_modules_path, client_encrypted_input_path))
    ):
        server = FHEModelServer(model_path)
        with open(encrypted_window_path, "rb") as f:
            encrypted_window = f.read()
        encrypted_output = server.run(
            encrypted_window, serialized_evaluation_keys=eval_key_base_module
        )
        assert isinstance(encrypted_output, bytes)
        # NOTE(review): the server decrypts this output itself using keys from
        # the shared model directory; confirm this is intended.
        client = FHEModelClient(model_path, key_dir=model_path)
        decrypted_output = client.deserialize_decrypt_dequantize(encrypted_output)
        # NOTE(review): stores the *input* ciphertext under a "decrypted_window_*"
        # name; confirm which payload is meant to be saved.
        with (SERVER_ENCRYPTED_OUTPUT_DIR / f"decrypted_window_{i}").open("wb") as f:
            f.write(encrypted_window)
        y_proba.append(decrypted_output)

    client = FHEModelClient(SHARED_SMOOTHER_MODULE_DIR, key_dir=SHARED_SMOOTHER_MODULE_DIR)
    server = FHEModelServer(SHARED_SMOOTHER_MODULE_DIR)

    # Stack the per-model outputs (axis 0 = base module) and swap the first two
    # axes so the base-module index becomes axis 1.
    y_proba = numpy.transpose(numpy.array(y_proba), (1, 0, 2))
    y_proba = y_proba.astype(numpy.int8)
    print(f"{y_proba.shape=}, {type(y_proba)}")
    X_slide, _ = slide_window(y_proba, META["SS"])

    # Smoother: encrypt each sliding window and run it; outputs stay encrypted.
    yhat_encrypted = []
    for i in tqdm(range(len(X_slide))):
        window = X_slide[i].reshape(1, -1)
        encrypted_input = client.quantize_encrypt_serialize(window)
        encrypted_output = server.run(
            encrypted_input, serialized_evaluation_keys=eval_key_smoother_module
        )
        yhat_encrypted.append(encrypted_output)

    write_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)
    fhe_execution_time = round(time.time() - start, 2)
    return JSONResponse(content=fhe_execution_time)
def get_output(user_id: str = Form(), root_dir: str = Form()):
    """Copy the encrypted final output from the server folder to the client folder.

    Loads ``encrypted_final_output.pkl`` (written by the FHE run) from
    ``{root_dir}/{user_id}/server/<ENCRYPTED_OUTPUT_DIR>``, checks it, and writes
    it to ``{root_dir}/{user_id}/client/<ENCRYPTED_OUTPUT_DIR>``. The ciphertexts
    are not part of the HTTP response.

    Args:
        user_id (str): Identifier of the user; selects the per-user folders.
        root_dir (str): Root directory under which the per-user folders live.

    Returns:
        Response: A plain ``"OK"`` acknowledgement.

    Raises:
        AssertionError: If the number of encrypted outputs differs from ``META["NW"]``.
    """
    print("\nStep 5.2: Get the output from the server ............\n")

    # SECURITY NOTE(review): `root_dir`/`user_id` are request inputs, so the file
    # loaded below sits at a client-chosen path. If `load_pickle` unpickles it
    # (as its name suggests), untrusted data could execute code; confine these
    # paths to a trusted base directory.
    SERVER_DIR = Path(root_dir) / f"{user_id}/server"
    SERVER_ENCRYPTED_OUTPUT_DIR = SERVER_DIR / ENCRYPTED_OUTPUT_DIR
    yhat_encrypted = load_pickle(SERVER_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl")

    # Validate before copying so a malformed result is never handed to the client.
    assert len(yhat_encrypted) == META["NW"]

    CLIENT_DIR = Path(root_dir) / f"{user_id}/client"
    CLIENT_ENCRYPTED_OUTPUT_DIR = CLIENT_DIR / ENCRYPTED_OUTPUT_DIR
    write_pickle(CLIENT_ENCRYPTED_OUTPUT_DIR / "encrypted_final_output.pkl", yhat_encrypted)

    # NOTE(review): the purpose of this fixed 1 s delay is not visible here.
    time.sleep(1)
    # The result is delivered through the client folder; the response only
    # acknowledges the copy.
    return Response("OK")