aziac commited on
Commit
6bc7bcc
·
1 Parent(s): 7017186

added solve captcha endpoint

Browse files
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model/saved_model.pb filter=lfs diff=lfs merge=lfs -text
37
+ model/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
38
+ model/** filter=lfs diff=lfs merge=lfs -text
app/main.py CHANGED
@@ -1,11 +1,57 @@
1
  import os
2
  import random
3
  from pathlib import Path
 
 
 
4
  from fastapi import FastAPI, HTTPException
5
  from fastapi.middleware.cors import CORSMiddleware
 
 
6
 
7
- # Initialize the FastAPI app
8
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # --- CORS Middleware ---
11
  app.add_middleware(
@@ -17,30 +63,97 @@ app.add_middleware(
17
  )
18
 
19
  # --- Constants ---
20
- # Define the path to the directory containing captcha images
21
  IMAGE_DIR = Path("static/images")
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # --- API Endpoints ---
25
  @app.get("/")
26
  async def read_root():
27
- """A simple root endpoint to check if the API is running."""
28
  return {"message": "Welcome to the Captcha Solver API!"}
29
 
30
-
31
  @app.get("/get_captcha")
32
  async def get_captcha():
33
- """
34
- Returns the filename of a random captcha image from the static/images directory.
35
- """
36
  if not IMAGE_DIR.is_dir():
37
- raise HTTPException(status_code=500, detail="Image directory not found on server.")
38
-
39
- image_files = [f for f in os.listdir(IMAGE_DIR) if f.endswith(('.png'))]
40
-
41
  if not image_files:
42
  raise HTTPException(status_code=404, detail="No captcha images found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- random_image_filename = random.choice(image_files)
 
 
 
 
 
 
45
 
46
- return {"filename": random_image_filename}
 
 
 
 
1
  import os
2
  import random
3
  from pathlib import Path
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from PIL import Image
7
  from fastapi import FastAPI, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel
10
+ from contextlib import asynccontextmanager
11
 
12
+ # --- Pydantic Models for Request Body ---
13
+ class CaptchaRequest(BaseModel):
14
+ filename: str
15
+
16
+ # --- Global Variables ---
17
+ # This will hold our loaded prediction model
18
+ prediction_model = None
19
+
20
+ # --- Configuration based on your Training Notebook ---
21
+
22
+ # 1. CHARACTER SET
23
+ data_dir = Path("./static/images/")
24
+ images = sorted(list(map(str, list(data_dir.glob("*.png")))))
25
+ labels = [img.split(os.path.sep)[-1].split(".png")[0] for img in images]
26
+ characters = set(char for label in labels for char in label)
27
+ CHARACTERS = sorted(list(characters))
28
+
29
+ # 2. IMAGE DIMENSIONS
30
+ # These dimensions are taken directly from your notebook.
31
+ IMG_WIDTH = 200
32
+ IMG_HEIGHT = 50
33
+
34
+ # --- App Lifespan Management (Model Loading) ---
35
+ @asynccontextmanager
36
+ async def lifespan(app: FastAPI):
37
+ # Code to run on startup
38
+ print("INFO: Loading TensorFlow prediction model...")
39
+ global prediction_model
40
+ try:
41
+ # NOTE: Ensure you save the `prediction_model` from your notebook,
42
+ # not the multi-input training `model`.
43
+ prediction_model = tf.saved_model.load('model')
44
+ print("INFO: TensorFlow model loaded successfully.")
45
+ except Exception as e:
46
+ print(f"ERROR: Failed to load model: {e}")
47
+ prediction_model = None
48
+ yield
49
+ # Code to run on shutdown
50
+ print("INFO: Application shutting down.")
51
+
52
+
53
+ # Initialize the FastAPI app with the lifespan manager
54
+ app = FastAPI(lifespan=lifespan)
55
 
56
  # --- CORS Middleware ---
57
  app.add_middleware(
 
63
  )
64
 
65
  # --- Constants ---
 
66
  IMAGE_DIR = Path("static/images")
67
 
68
+ # --- Helper Functions based on your Notebook ---
69
+
70
+ def preprocess_image(image_path):
71
+ """
72
+ Loads and preprocesses an image for model prediction based on the notebook's
73
+ `encode_single_sample` function.
74
+ """
75
+ try:
76
+ # 1. Read image, convert to grayscale
77
+ img = Image.open(image_path).convert('L') #
78
+ # 2. Resize to the desired size (width, height)
79
+ img = img.resize((IMG_WIDTH, IMG_HEIGHT)) #
80
+ # 3. Convert to numpy array of float32 in [0, 1] range
81
+ img = np.array(img, dtype=np.float32) / 255.0 #
82
+
83
+ # 4. Transpose the image because the RNN part of the model expects the time
84
+ # dimension to correspond to the width of the image.
85
+ # The notebook does this with `ops.transpose(img, axes=[1, 0, 2])`.
86
+ # Here, a numpy array of shape (height, width) becomes (width, height).
87
+ img = img.T
88
+
89
+ # 5. Add channel and batch dimensions
90
+ img = np.expand_dims(img, axis=-1) # Add channel -> (width, height, 1)
91
+ img = np.expand_dims(img, axis=0) # Add batch -> (1, width, height, 1)
92
+
93
+ return img
94
+ except Exception as e:
95
+ print(f"Error preprocessing image {image_path}: {e}")
96
+ return None
97
+
98
+ def decode_prediction(pred):
99
+ """
100
+ Decodes the raw model output into a human-readable string using CTC decoding,
101
+ mirroring the notebook's `decode_batch_predictions` function.
102
+ """
103
+ # 1. Get the input length (number of timesteps)
104
+ input_len = np.ones(pred.shape[0]) * pred.shape[1]
105
+
106
+ # 2. Use Keras's CTC decoder (greedy search is sufficient and fast)
107
+ # This is equivalent to `tf.nn.ctc_greedy_decoder` used in the notebook.
108
+ results = tf.keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
109
+
110
+ # 3. Iterate over the results and convert back to text
111
+ output_text = ""
112
+ for res in results.numpy():
113
+ # The `CHARACTERS` list maps indices to characters.
114
+ # -1 is the default padding value from ctc_decode.
115
+ if res != -1 and res < len(CHARACTERS):
116
+ output_text += CHARACTERS[res]
117
+
118
+ return output_text
119
 
120
  # --- API Endpoints ---
121
  @app.get("/")
122
  async def read_root():
 
123
  return {"message": "Welcome to the Captcha Solver API!"}
124
 
 
125
  @app.get("/get_captcha")
126
  async def get_captcha():
 
 
 
127
  if not IMAGE_DIR.is_dir():
128
+ raise HTTPException(status_code=500, detail="Image directory not found.")
129
+ image_files = [f for f in os.listdir(IMAGE_DIR) if f.endswith(('.png', '.jpg', '.jpeg'))]
 
 
130
  if not image_files:
131
  raise HTTPException(status_code=404, detail="No captcha images found.")
132
+ return {"filename": random.choice(image_files)}
133
+
134
+ @app.post("/solve_captcha")
135
+ async def solve_captcha(request: CaptchaRequest):
136
+ if prediction_model is None:
137
+ raise HTTPException(status_code=503, detail="Model is not loaded or failed to load.")
138
+
139
+ image_path = IMAGE_DIR / request.filename
140
+ if not image_path.is_file():
141
+ raise HTTPException(status_code=404, detail=f"File '{request.filename}' not found.")
142
+
143
+ # Preprocess the image according to the notebook's logic
144
+ processed_image = preprocess_image(image_path)
145
+ if processed_image is None:
146
+ raise HTTPException(status_code=500, detail="Failed to process the image.")
147
 
148
+ try:
149
+ # Get model prediction by calling the loaded model directly
150
+ # The `prediction_model` from the notebook expects only the image as input.
151
+ preds = prediction_model(tf.constant(processed_image))
152
+
153
+ # Decode the prediction
154
+ predicted_label = decode_prediction(preds)
155
 
156
+ return {"prediction": predicted_label}
157
+ except Exception as e:
158
+ print(f"Error during prediction: {e}")
159
+ raise HTTPException(status_code=500, detail=f"An error occurred during model inference: {e}")
model/.DS_Store ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce35a183b313defdf28e0e5a7cfb29468a17bb0d9b42f1ef75f4e366851478f7
3
+ size 6148
model/fingerprint.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97e29e2ce27e4c2d1d1273f0cdb069a094ecdbea21a6559d82fbd34ed9c17b4b
3
+ size 78
model/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd3fd69880e1b68390152c8236dd3000a7abdde05867d6f2074d120dbfdd6c17
3
+ size 269319
model/variables/variables.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5bd622ed42679af9c2142e8c79621c1fe209608c3061414e1869c207df6b609
3
+ size 3467858
model/variables/variables.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a326c821608827e2a07f3ccb56d4b74ac4b172245d5b97a9d23832a6fd87ea37
3
+ size 2907
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  fastapi
2
  uvicorn[standard]
3
- python-multipart
 
 
 
 
1
  fastapi
2
  uvicorn[standard]
3
+ python-multipart
4
+ tensorflow
5
+ numpy
6
+ Pillow