aziac committed on
Commit
caa7998
·
1 Parent(s): 805ac97

pretrained method

Browse files
app/main.py CHANGED
@@ -3,51 +3,57 @@ import random
3
  from pathlib import Path
4
  import numpy as np
5
  import tensorflow as tf
6
- import keras
7
- from PIL import Image
8
  from fastapi import FastAPI, HTTPException
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel
11
  from contextlib import asynccontextmanager
12
 
 
 
 
13
  # --- Pydantic Models for Request Body ---
14
  class CaptchaRequest(BaseModel):
15
  filename: str
16
 
17
  # --- Global Variables ---
18
- # This will hold our loaded prediction model
19
  prediction_model = None
 
 
20
 
21
- # --- Configuration based on your Training Notebook ---
22
-
23
- # 1. CHARACTER SET
24
- data_dir = Path("./static/images/")
25
- images = sorted(list(map(str, list(data_dir.glob("*.png")))))
26
- labels = [img.split(os.path.sep)[-1].split(".png")[0] for img in images]
27
- characters = set(char for label in labels for char in label)
28
- CHARACTERS = sorted(list(characters))
29
-
30
- # 2. IMAGE DIMENSIONS
31
- # These dimensions are taken directly from your notebook.
32
  IMG_WIDTH = 200
33
  IMG_HEIGHT = 50
34
 
35
  # --- App Lifespan Management (Model Loading) ---
36
  @asynccontextmanager
37
  async def lifespan(app: FastAPI):
38
- # Code to run on startup
39
- print("INFO: Loading TensorFlow prediction model...")
40
- global prediction_model
41
  try:
42
- # NOTE: Ensure you save the `prediction_model` from your notebook,
43
- # not the multi-input training `model`.
44
- prediction_model = keras.layers.TFSMLayer('model', call_endpoint='serving_default')
45
- print("INFO: TensorFlow model loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  except Exception as e:
47
- print(f"ERROR: Failed to load model: {e}")
48
  prediction_model = None
49
  yield
50
- # Code to run on shutdown
51
  print("INFO: Application shutting down.")
52
 
53
 
@@ -55,67 +61,23 @@ async def lifespan(app: FastAPI):
55
  app = FastAPI(lifespan=lifespan)
56
 
57
  # --- CORS Middleware ---
58
- app.add_middleware(
59
- CORSMiddleware,
60
- allow_origins=["*"],
61
- allow_credentials=True,
62
- allow_methods=["*"],
63
- allow_headers=["*"],
64
- )
65
 
66
  # --- Constants ---
67
  IMAGE_DIR = Path("static/images")
68
 
69
- # --- Helper Functions based on your Notebook ---
70
 
71
- def preprocess_image(image_path):
72
- """
73
- Loads and preprocesses an image for model prediction based on the notebook's
74
- `encode_single_sample` function.
75
- """
76
- try:
77
- # 1. Read image, convert to grayscale
78
- img = Image.open(image_path).convert('L') #
79
- # 2. Resize to the desired size (width, height)
80
- img = img.resize((IMG_WIDTH, IMG_HEIGHT)) #
81
- # 3. Convert to numpy array of float32 in [0, 1] range
82
- img = np.array(img, dtype=np.float32) / 255.0 #
83
-
84
- # 4. Transpose the image because the RNN part of the model expects the time
85
- # dimension to correspond to the width of the image.
86
- # The notebook does this with `ops.transpose(img, axes=[1, 0, 2])`.
87
- # Here, a numpy array of shape (height, width) becomes (width, height).
88
- img = img.T
89
-
90
- # 5. Add channel and batch dimensions
91
- img = np.expand_dims(img, axis=-1) # Add channel -> (width, height, 1)
92
- img = np.expand_dims(img, axis=0) # Add batch -> (1, width, height, 1)
93
-
94
- return img
95
- except Exception as e:
96
- print(f"Error preprocessing image {image_path}: {e}")
97
- return None
98
-
99
- def decode_prediction(pred):
100
- """
101
- Decodes the raw model output into a human-readable string using CTC decoding,
102
- mirroring the notebook's `decode_batch_predictions` function.
103
- """
104
- # 1. Get the input length (number of timesteps)
105
  input_len = np.ones(pred.shape[0]) * pred.shape[1]
106
-
107
- # 2. Use Keras's CTC decoder (greedy search is sufficient and fast)
108
- # This is equivalent to `tf.nn.ctc_greedy_decoder` used in the notebook.
109
- results = tf.keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
110
-
111
- # 3. Iterate over the results and convert back to text
112
- output_text = ""
113
- for res in results.numpy():
114
- # The `CHARACTERS` list maps indices to characters.
115
- # -1 is the default padding value from ctc_decode.
116
- if res != -1 and res < len(CHARACTERS):
117
- output_text += CHARACTERS[res]
118
-
119
  return output_text
120
 
121
  # --- API Endpoints ---
@@ -134,27 +96,38 @@ async def get_captcha():
134
 
135
  @app.post("/solve_captcha")
136
  async def solve_captcha(request: CaptchaRequest):
137
- if prediction_model is None:
138
- raise HTTPException(status_code=503, detail="Model is not loaded or failed to load.")
139
 
140
  image_path = IMAGE_DIR / request.filename
141
  if not image_path.is_file():
142
  raise HTTPException(status_code=404, detail=f"File '{request.filename}' not found.")
143
 
144
- # Preprocess the image according to the notebook's logic
145
- processed_image = preprocess_image(image_path)
146
- if processed_image is None:
147
- raise HTTPException(status_code=500, detail="Failed to process the image.")
148
-
149
  try:
150
- # A TFSMLayer is a callable Keras layer.
151
- # We can call it directly with our input numpy array.
152
- preds = prediction_model(processed_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- # Decode the prediction
155
- predicted_label = decode_prediction(preds)
156
 
157
- return {"prediction": predicted_label}
158
  except Exception as e:
159
  print(f"Error during prediction: {e}")
160
  raise HTTPException(status_code=500, detail=f"An error occurred during model inference: {e}")
 
3
  from pathlib import Path
4
  import numpy as np
5
  import tensorflow as tf
6
+ from tensorflow import keras
7
+ from tensorflow.keras import layers
8
  from fastapi import FastAPI, HTTPException
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel
11
  from contextlib import asynccontextmanager
12
 
13
+ # New import for the pre-trained model
14
+ from huggingface_hub import from_pretrained_keras
15
+
16
  # --- Pydantic Models for Request Body ---
17
  class CaptchaRequest(BaseModel):
18
  filename: str
19
 
20
  # --- Global Variables ---
 
21
  prediction_model = None
22
+ num_to_char = None
23
+ max_length = 5 # From your Gradio script
24
 
25
+ # --- Configuration for the pre-trained "keras-io/ocr-for-captcha" model ---
 
 
 
 
 
 
 
 
 
 
26
  IMG_WIDTH = 200
27
  IMG_HEIGHT = 50
28
 
29
  # --- App Lifespan Management (Model Loading) ---
30
  @asynccontextmanager
31
  async def lifespan(app: FastAPI):
32
+ global prediction_model, num_to_char
 
 
33
  try:
34
+ print("INFO: Loading pre-trained Keras model and vocab...")
35
+
36
+ # 1. Load the base model from Hugging Face Hub
37
+ base_model = from_pretrained_keras("keras-io/ocr-for-captcha", compile=False)
38
+
39
+ # 2. Create the inference-only prediction_model (from your Gradio script)
40
+ prediction_model = keras.models.Model(
41
+ base_model.get_layer(name="image").input, base_model.get_layer(name="dense2").output
42
+ )
43
+
44
+ # 3. Load the vocabulary from the file
45
+ with open("vocab.txt", "r") as f:
46
+ vocab = f.read().splitlines()
47
+
48
+ # 4. Create the character mapping layer (from your Gradio script)
49
+ num_to_char = layers.StringLookup(vocabulary=vocab, mask_token=None, invert=True)
50
+
51
+ print("INFO: Model and vocab loaded successfully.")
52
+
53
  except Exception as e:
54
+ print(f"ERROR: Failed to load pre-trained model or vocab: {e}")
55
  prediction_model = None
56
  yield
 
57
  print("INFO: Application shutting down.")
58
 
59
 
 
61
  app = FastAPI(lifespan=lifespan)
62
 
63
  # --- CORS Middleware ---
64
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
 
 
 
 
 
 
65
 
66
  # --- Constants ---
67
  IMAGE_DIR = Path("static/images")
68
 
69
+ # --- Helper Functions (from your Gradio script) ---
70
 
71
+ def decode_batch_predictions(pred):
72
+ # This function is directly from your Gradio script
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  input_len = np.ones(pred.shape[0]) * pred.shape[1]
74
+ results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
75
+ :, :max_length
76
+ ]
77
+ output_text = []
78
+ for res in results:
79
+ res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
80
+ output_text.append(res)
 
 
 
 
 
 
81
  return output_text
82
 
83
  # --- API Endpoints ---
 
96
 
97
  @app.post("/solve_captcha")
98
  async def solve_captcha(request: CaptchaRequest):
99
+ if prediction_model is None or num_to_char is None:
100
+ raise HTTPException(status_code=503, detail="Model or vocab is not loaded.")
101
 
102
  image_path = IMAGE_DIR / request.filename
103
  if not image_path.is_file():
104
  raise HTTPException(status_code=404, detail=f"File '{request.filename}' not found.")
105
 
 
 
 
 
 
106
  try:
107
+ # This core logic is taken directly from your `classify_image` function
108
+
109
+ # 1. Read image
110
+ img = tf.io.read_file(str(image_path)) # Convert Path object to string for tf.io
111
+ # 2. Decode and convert to grayscale
112
+ img = tf.io.decode_png(img, channels=1)
113
+ # 3. Convert to float32 in [0, 1] range
114
+ img = tf.image.convert_image_dtype(img, tf.float32)
115
+ # 4. Resize to the desired size
116
+ img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
117
+ # 5. Transpose the image
118
+ img = tf.transpose(img, perm=[1, 0, 2])
119
+ # 6. Add a batch dimension
120
+ img = tf.expand_dims(img, axis=0)
121
+
122
+ # 7. Get predictions
123
+ preds = prediction_model.predict(img)
124
+
125
+ # 8. Decode the predictions
126
+ pred_text = decode_batch_predictions(preds)
127
 
128
+ # Return the first (and only) prediction
129
+ return {"prediction": pred_text[0]}
130
 
 
131
  except Exception as e:
132
  print(f"Error during prediction: {e}")
133
  raise HTTPException(status_code=500, detail=f"An error occurred during model inference: {e}")
app/vocab.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [UNK]
2
+ 8
3
+ 6
4
+ m
5
+ x
6
+ d
7
+ y
8
+ w
9
+ 2
10
+ 7
11
+ n
12
+ g
13
+ 5
14
+ c
15
+ f
16
+ p
17
+ e
18
+ 3
19
+ 4
20
+ b
model/.DS_Store DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ce35a183b313defdf28e0e5a7cfb29468a17bb0d9b42f1ef75f4e366851478f7
3
- size 6148
 
 
 
 
model/fingerprint.pb DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:97e29e2ce27e4c2d1d1273f0cdb069a094ecdbea21a6559d82fbd34ed9c17b4b
3
- size 78
 
 
 
 
model/saved_model.pb DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cd3fd69880e1b68390152c8236dd3000a7abdde05867d6f2074d120dbfdd6c17
3
- size 269319
 
 
 
 
model/variables/variables.data-00000-of-00001 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b5bd622ed42679af9c2142e8c79621c1fe209608c3061414e1869c207df6b609
3
- size 3467858
 
 
 
 
model/variables/variables.index DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a326c821608827e2a07f3ccb56d4b74ac4b172245d5b97a9d23832a6fd87ea37
3
- size 2907
 
 
 
 
requirements.txt CHANGED
@@ -1,6 +1,8 @@
1
  fastapi
2
  uvicorn[standard]
3
  python-multipart
4
- tensorflow
5
- numpy
 
 
6
  Pillow
 
1
  fastapi
2
  uvicorn[standard]
3
  python-multipart
4
+ tensorflow>=2.6,<2.15
5
+ keras<3.0.0
6
+ huggingface_hub
7
+ numpy<2
8
  Pillow