Spaces:
Running
Running
pretrained method
Browse files- app/main.py +64 -91
- app/vocab.txt +20 -0
- model/.DS_Store +0 -3
- model/fingerprint.pb +0 -3
- model/saved_model.pb +0 -3
- model/variables/variables.data-00000-of-00001 +0 -3
- model/variables/variables.index +0 -3
- requirements.txt +4 -2
app/main.py
CHANGED
|
@@ -3,51 +3,57 @@ import random
|
|
| 3 |
from pathlib import Path
|
| 4 |
import numpy as np
|
| 5 |
import tensorflow as tf
|
| 6 |
-
import keras
|
| 7 |
-
from
|
| 8 |
from fastapi import FastAPI, HTTPException
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
from pydantic import BaseModel
|
| 11 |
from contextlib import asynccontextmanager
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
# --- Pydantic Models for Request Body ---
|
| 14 |
class CaptchaRequest(BaseModel):
|
| 15 |
filename: str
|
| 16 |
|
| 17 |
# --- Global Variables ---
|
| 18 |
-
# This will hold our loaded prediction model
|
| 19 |
prediction_model = None
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
# --- Configuration
|
| 22 |
-
|
| 23 |
-
# 1. CHARACTER SET
|
| 24 |
-
data_dir = Path("./static/images/")
|
| 25 |
-
images = sorted(list(map(str, list(data_dir.glob("*.png")))))
|
| 26 |
-
labels = [img.split(os.path.sep)[-1].split(".png")[0] for img in images]
|
| 27 |
-
characters = set(char for label in labels for char in label)
|
| 28 |
-
CHARACTERS = sorted(list(characters))
|
| 29 |
-
|
| 30 |
-
# 2. IMAGE DIMENSIONS
|
| 31 |
-
# These dimensions are taken directly from your notebook.
|
| 32 |
IMG_WIDTH = 200
|
| 33 |
IMG_HEIGHT = 50
|
| 34 |
|
| 35 |
# --- App Lifespan Management (Model Loading) ---
|
| 36 |
@asynccontextmanager
|
| 37 |
async def lifespan(app: FastAPI):
|
| 38 |
-
|
| 39 |
-
print("INFO: Loading TensorFlow prediction model...")
|
| 40 |
-
global prediction_model
|
| 41 |
try:
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
except Exception as e:
|
| 47 |
-
print(f"ERROR: Failed to load model: {e}")
|
| 48 |
prediction_model = None
|
| 49 |
yield
|
| 50 |
-
# Code to run on shutdown
|
| 51 |
print("INFO: Application shutting down.")
|
| 52 |
|
| 53 |
|
|
@@ -55,67 +61,23 @@ async def lifespan(app: FastAPI):
|
|
| 55 |
app = FastAPI(lifespan=lifespan)
|
| 56 |
|
| 57 |
# --- CORS Middleware ---
|
| 58 |
-
app.add_middleware(
|
| 59 |
-
CORSMiddleware,
|
| 60 |
-
allow_origins=["*"],
|
| 61 |
-
allow_credentials=True,
|
| 62 |
-
allow_methods=["*"],
|
| 63 |
-
allow_headers=["*"],
|
| 64 |
-
)
|
| 65 |
|
| 66 |
# --- Constants ---
|
| 67 |
IMAGE_DIR = Path("static/images")
|
| 68 |
|
| 69 |
-
# --- Helper Functions
|
| 70 |
|
| 71 |
-
def
|
| 72 |
-
|
| 73 |
-
Loads and preprocesses an image for model prediction based on the notebook's
|
| 74 |
-
`encode_single_sample` function.
|
| 75 |
-
"""
|
| 76 |
-
try:
|
| 77 |
-
# 1. Read image, convert to grayscale
|
| 78 |
-
img = Image.open(image_path).convert('L') #
|
| 79 |
-
# 2. Resize to the desired size (width, height)
|
| 80 |
-
img = img.resize((IMG_WIDTH, IMG_HEIGHT)) #
|
| 81 |
-
# 3. Convert to numpy array of float32 in [0, 1] range
|
| 82 |
-
img = np.array(img, dtype=np.float32) / 255.0 #
|
| 83 |
-
|
| 84 |
-
# 4. Transpose the image because the RNN part of the model expects the time
|
| 85 |
-
# dimension to correspond to the width of the image.
|
| 86 |
-
# The notebook does this with `ops.transpose(img, axes=[1, 0, 2])`.
|
| 87 |
-
# Here, a numpy array of shape (height, width) becomes (width, height).
|
| 88 |
-
img = img.T
|
| 89 |
-
|
| 90 |
-
# 5. Add channel and batch dimensions
|
| 91 |
-
img = np.expand_dims(img, axis=-1) # Add channel -> (width, height, 1)
|
| 92 |
-
img = np.expand_dims(img, axis=0) # Add batch -> (1, width, height, 1)
|
| 93 |
-
|
| 94 |
-
return img
|
| 95 |
-
except Exception as e:
|
| 96 |
-
print(f"Error preprocessing image {image_path}: {e}")
|
| 97 |
-
return None
|
| 98 |
-
|
| 99 |
-
def decode_prediction(pred):
|
| 100 |
-
"""
|
| 101 |
-
Decodes the raw model output into a human-readable string using CTC decoding,
|
| 102 |
-
mirroring the notebook's `decode_batch_predictions` function.
|
| 103 |
-
"""
|
| 104 |
-
# 1. Get the input length (number of timesteps)
|
| 105 |
input_len = np.ones(pred.shape[0]) * pred.shape[1]
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
for res in results.numpy():
|
| 114 |
-
# The `CHARACTERS` list maps indices to characters.
|
| 115 |
-
# -1 is the default padding value from ctc_decode.
|
| 116 |
-
if res != -1 and res < len(CHARACTERS):
|
| 117 |
-
output_text += CHARACTERS[res]
|
| 118 |
-
|
| 119 |
return output_text
|
| 120 |
|
| 121 |
# --- API Endpoints ---
|
|
@@ -134,27 +96,38 @@ async def get_captcha():
|
|
| 134 |
|
| 135 |
@app.post("/solve_captcha")
|
| 136 |
async def solve_captcha(request: CaptchaRequest):
|
| 137 |
-
if prediction_model is None:
|
| 138 |
-
raise HTTPException(status_code=503, detail="Model is not loaded
|
| 139 |
|
| 140 |
image_path = IMAGE_DIR / request.filename
|
| 141 |
if not image_path.is_file():
|
| 142 |
raise HTTPException(status_code=404, detail=f"File '{request.filename}' not found.")
|
| 143 |
|
| 144 |
-
# Preprocess the image according to the notebook's logic
|
| 145 |
-
processed_image = preprocess_image(image_path)
|
| 146 |
-
if processed_image is None:
|
| 147 |
-
raise HTTPException(status_code=500, detail="Failed to process the image.")
|
| 148 |
-
|
| 149 |
try:
|
| 150 |
-
#
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
#
|
| 155 |
-
|
| 156 |
|
| 157 |
-
return {"prediction": predicted_label}
|
| 158 |
except Exception as e:
|
| 159 |
print(f"Error during prediction: {e}")
|
| 160 |
raise HTTPException(status_code=500, detail=f"An error occurred during model inference: {e}")
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
import numpy as np
|
| 5 |
import tensorflow as tf
|
| 6 |
+
from tensorflow import keras
|
| 7 |
+
from tensorflow.keras import layers
|
| 8 |
from fastapi import FastAPI, HTTPException
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
from pydantic import BaseModel
|
| 11 |
from contextlib import asynccontextmanager
|
| 12 |
|
| 13 |
+
# New import for the pre-trained model
|
| 14 |
+
from huggingface_hub import from_pretrained_keras
|
| 15 |
+
|
| 16 |
# --- Pydantic Models for Request Body ---
|
| 17 |
class CaptchaRequest(BaseModel):
    """Request body accepted by the `/solve_captcha` endpoint."""

    # Name of the CAPTCHA image; the endpoint joins it onto IMAGE_DIR.
    filename: str
|
| 19 |
|
| 20 |
# --- Global Variables ---
|
|
|
|
| 21 |
# Inference state. Both names stay None until the lifespan hook has loaded
# the model and the vocabulary; the endpoint answers 503 while either is None.
prediction_model = num_to_char = None

# Decoded predictions are truncated to this many characters.
max_length = 5  # From your Gradio script

# --- Configuration for the pre-trained "keras-io/ocr-for-captcha" model ---
# Target size every image is resized to before it is fed to the model.
IMG_WIDTH, IMG_HEIGHT = 200, 50
|
| 28 |
|
| 29 |
# --- App Lifespan Management (Model Loading) ---
|
| 30 |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the OCR model and vocabulary on startup; log on shutdown.

    On success the module globals ``prediction_model`` and ``num_to_char``
    are populated. On any failure both are reset to ``None`` and the
    application still starts, so the endpoint can answer 503 instead of the
    process crashing.

    Args:
        app: The FastAPI application this lifespan handler is attached to.

    Yields:
        Nothing; control returns to the application between startup and
        shutdown.
    """
    global prediction_model, num_to_char
    try:
        print("INFO: Loading pre-trained Keras model and vocab...")

        # 1. Load the base model from Hugging Face Hub
        base_model = from_pretrained_keras("keras-io/ocr-for-captcha", compile=False)

        # 2. Create the inference-only model: from the "image" input layer to
        #    the output of the layer named "dense2".
        prediction_model = keras.models.Model(
            base_model.get_layer(name="image").input,
            base_model.get_layer(name="dense2").output,
        )

        # 3. Load the vocabulary. Prefer the copy that ships next to this
        #    module so startup does not depend on the process working
        #    directory; fall back to the working directory for deployments
        #    that relied on the previous behaviour.
        vocab_path = Path(__file__).resolve().parent / "vocab.txt"
        if not vocab_path.is_file():
            vocab_path = Path("vocab.txt")
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab = f.read().splitlines()

        # 4. Create the index -> character mapping layer
        num_to_char = layers.StringLookup(vocabulary=vocab, mask_token=None, invert=True)

        print("INFO: Model and vocab loaded successfully.")

    except Exception as e:
        # Top-level boundary: report the failure and leave the service in a
        # consistent "not loaded" state rather than half-initialised.
        print(f"ERROR: Failed to load pre-trained model or vocab: {e}")
        prediction_model = None
        num_to_char = None
    yield
    print("INFO: Application shutting down.")
|
| 58 |
|
| 59 |
|
|
|
|
| 61 |
app = FastAPI(lifespan=lifespan)

# --- CORS Middleware ---
# NOTE(review): wildcard origins are combined with allow_credentials=True.
# The FastAPI CORS documentation says origins, methods and headers must be
# listed explicitly when credentials are allowed; confirm whether credentialed
# cross-origin requests are really needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Constants ---
# Directory the CAPTCHA images are served from and read by /solve_captcha.
IMAGE_DIR = Path("static/images")
|
| 68 |
|
| 69 |
+
# --- Helper Functions (from your Gradio script) ---
|
| 70 |
|
| 71 |
+
def decode_batch_predictions(pred):
    """Decode raw model output into one text string per batch item.

    Uses greedy CTC decoding, keeps at most ``max_length`` symbols per item
    and maps the remaining indices to characters with ``num_to_char``.

    Args:
        pred: Model output; ``pred.shape[0]`` is used as the batch size and
            ``pred.shape[1]`` as the number of time steps.

    Returns:
        A list of decoded strings, one per batch item.
    """
    # Every sample is decoded over the full number of time steps.
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    decoded = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
    decoded = decoded[:, :max_length]

    output_text = []
    for indices in decoded:
        # ctc_decode pads sequences shorter than the longest one with -1.
        # Drop that padding before the lookup; otherwise the inverted
        # StringLookup turns each -1 into an "[UNK]" token in the result.
        indices = tf.boolean_mask(indices, tf.not_equal(indices, -1))
        text = tf.strings.reduce_join(num_to_char(indices)).numpy().decode("utf-8")
        output_text.append(text)
    return output_text
|
| 82 |
|
| 83 |
# --- API Endpoints ---
|
|
|
|
| 96 |
|
| 97 |
@app.post("/solve_captcha")
async def solve_captcha(request: CaptchaRequest):
    """Run the OCR model on a stored CAPTCHA image and return its text.

    Args:
        request: Body carrying the ``filename`` of an image in ``IMAGE_DIR``.

    Returns:
        ``{"prediction": <decoded text>}`` for the requested image.

    Raises:
        HTTPException: 503 if the model or vocabulary is not loaded, 404 if
            the file does not exist inside ``IMAGE_DIR``, 500 if preprocessing
            or inference fails.
    """
    if prediction_model is None or num_to_char is None:
        raise HTTPException(status_code=503, detail="Model or vocab is not loaded.")

    # The filename comes from the client: resolve it and make sure it still
    # points inside IMAGE_DIR, so "../" sequences or absolute paths cannot be
    # used to read arbitrary files. Rejected names get the same 404 as a
    # missing file.
    image_root = IMAGE_DIR.resolve()
    try:
        image_path = (image_root / request.filename).resolve()
        is_allowed = image_root in image_path.parents and image_path.is_file()
    except (OSError, ValueError):
        # e.g. embedded NUL bytes or names the OS cannot resolve
        is_allowed = False
    if not is_allowed:
        raise HTTPException(status_code=404, detail=f"File '{request.filename}' not found.")

    try:
        # 1. Read image (tf.io expects a string path, not a Path object)
        img = tf.io.read_file(str(image_path))
        # 2. Decode and convert to grayscale
        img = tf.io.decode_png(img, channels=1)
        # 3. Convert to float32 in [0, 1] range
        img = tf.image.convert_image_dtype(img, tf.float32)
        # 4. Resize to the desired size
        img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
        # 5. Transpose the image (swap the first two axes)
        img = tf.transpose(img, perm=[1, 0, 2])
        # 6. Add a batch dimension
        img = tf.expand_dims(img, axis=0)

        # 7. Get predictions
        # NOTE(review): predict() is a blocking call inside an async handler;
        # consider running it in a worker thread if latency becomes an issue.
        preds = prediction_model.predict(img)

        # 8. Decode the predictions
        pred_text = decode_batch_predictions(preds)

        # Return the first (and only) prediction
        return {"prediction": pred_text[0]}

    except Exception as e:
        print(f"Error during prediction: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during model inference: {e}",
        ) from e
|
app/vocab.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[UNK]
|
| 2 |
+
8
|
| 3 |
+
6
|
| 4 |
+
m
|
| 5 |
+
x
|
| 6 |
+
d
|
| 7 |
+
y
|
| 8 |
+
w
|
| 9 |
+
2
|
| 10 |
+
7
|
| 11 |
+
n
|
| 12 |
+
g
|
| 13 |
+
5
|
| 14 |
+
c
|
| 15 |
+
f
|
| 16 |
+
p
|
| 17 |
+
e
|
| 18 |
+
3
|
| 19 |
+
4
|
| 20 |
+
b
|
model/.DS_Store
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:ce35a183b313defdf28e0e5a7cfb29468a17bb0d9b42f1ef75f4e366851478f7
|
| 3 |
-
size 6148
|
|
|
|
|
|
|
|
|
|
|
|
model/fingerprint.pb
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:97e29e2ce27e4c2d1d1273f0cdb069a094ecdbea21a6559d82fbd34ed9c17b4b
|
| 3 |
-
size 78
|
|
|
|
|
|
|
|
|
|
|
|
model/saved_model.pb
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:cd3fd69880e1b68390152c8236dd3000a7abdde05867d6f2074d120dbfdd6c17
|
| 3 |
-
size 269319
|
|
|
|
|
|
|
|
|
|
|
|
model/variables/variables.data-00000-of-00001
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:b5bd622ed42679af9c2142e8c79621c1fe209608c3061414e1869c207df6b609
|
| 3 |
-
size 3467858
|
|
|
|
|
|
|
|
|
|
|
|
model/variables/variables.index
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:a326c821608827e2a07f3ccb56d4b74ac4b172245d5b97a9d23832a6fd87ea37
|
| 3 |
-
size 2907
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
fastapi
|
| 2 |
uvicorn[standard]
|
| 3 |
python-multipart
|
| 4 |
-
tensorflow
|
| 5 |
-
|
|
|
|
|
|
|
| 6 |
Pillow
|
|
|
|
| 1 |
fastapi
|
| 2 |
uvicorn[standard]
|
| 3 |
python-multipart
|
| 4 |
+
tensorflow>=2.6,<2.15
|
| 5 |
+
keras<3.0.0
|
| 6 |
+
huggingface_hub
|
| 7 |
+
numpy<2
|
| 8 |
Pillow
|