Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,7 +33,7 @@ table_engine_list = [
|
|
| 33 |
# 示例图片路径
|
| 34 |
example_images = [
|
| 35 |
"images/wired1.png",
|
| 36 |
-
"images/wired2.
|
| 37 |
"images/wired3.png",
|
| 38 |
"images/lineless1.png",
|
| 39 |
"images/wired4.jpg",
|
|
@@ -67,6 +67,17 @@ for det_model in det_model_dir.keys():
|
|
| 67 |
rec_model_dir=rec_model_path
|
| 68 |
)
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
def select_ocr_model(det_model, rec_model):
|
| 72 |
return ocr_engine_dict[f"{det_model}_{rec_model}"]
|
|
@@ -94,8 +105,10 @@ def select_table_model(img, table_engine_type, det_model, rec_model):
|
|
| 94 |
return lineless_table_engine, "lineless_table"
|
| 95 |
|
| 96 |
|
| 97 |
-
def process_image(
|
| 98 |
-
|
|
|
|
|
|
|
| 99 |
start = time.time()
|
| 100 |
table_engine, talbe_type = select_table_model(img, table_engine_type, det_model, rec_model)
|
| 101 |
ocr_engine = select_ocr_model(det_model, rec_model)
|
|
@@ -108,24 +121,20 @@ def process_image(img, table_engine_type, det_model, rec_model, small_box_cut_en
|
|
| 108 |
ocr_boxes = result[0]['res']['boxes']
|
| 109 |
all_elapse = f"- `table all cost: {time.time() - start:.5f}`"
|
| 110 |
else:
|
| 111 |
-
ocr_res, ocr_infer_elapse = ocr_engine(img)
|
| 112 |
det_cost, cls_cost, rec_cost = ocr_infer_elapse
|
|
|
|
|
|
|
| 113 |
ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
|
| 114 |
if isinstance(table_engine, RapidTable):
|
| 115 |
html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
|
| 116 |
polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
|
| 117 |
elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
|
| 118 |
-
html, table_rec_elapse, polygons,
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
)
|
| 124 |
-
else:
|
| 125 |
-
html, table_rec_elapse, polygons, logic_points, ocr_res = table_engine(
|
| 126 |
-
img, ocr_result=ocr_res
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
sum_elapse = time.time() - start
|
| 130 |
all_elapse = f"- table_type: {talbe_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
|
| 131 |
|
|
@@ -191,10 +200,33 @@ def main():
|
|
| 191 |
label="Box Cutting Enhancement (Disable to avoid excessive cutting, Enable to reduce missed cutting)",
|
| 192 |
value=True
|
| 193 |
)
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
run_button = gr.Button("Run")
|
| 200 |
gr.Markdown("# Elapsed Time")
|
|
@@ -210,7 +242,7 @@ def main():
|
|
| 210 |
|
| 211 |
run_button.click(
|
| 212 |
fn=process_image,
|
| 213 |
-
inputs=[img_input, table_engine_type,
|
| 214 |
outputs=[html_output, table_boxes_output, ocr_boxes_output, elapse_text]
|
| 215 |
)
|
| 216 |
|
|
|
|
| 33 |
# 示例图片路径
|
| 34 |
example_images = [
|
| 35 |
"images/wired1.png",
|
| 36 |
+
"images/wired2.jpg",
|
| 37 |
"images/wired3.png",
|
| 38 |
"images/lineless1.png",
|
| 39 |
"images/wired4.jpg",
|
|
|
|
| 67 |
rec_model_dir=rec_model_path
|
| 68 |
)
|
| 69 |
|
| 70 |
+
def trans_char_ocr_res(ocr_res):
|
| 71 |
+
word_result = []
|
| 72 |
+
for res in ocr_res:
|
| 73 |
+
score = res[2]
|
| 74 |
+
for word_box, word in zip(res[3], res[4]):
|
| 75 |
+
word_res = []
|
| 76 |
+
word_res.append(word_box)
|
| 77 |
+
word_res.append(word)
|
| 78 |
+
word_res.append(score)
|
| 79 |
+
word_result.append(word_res)
|
| 80 |
+
return word_result
|
| 81 |
|
| 82 |
def select_ocr_model(det_model, rec_model):
|
| 83 |
return ocr_engine_dict[f"{det_model}_{rec_model}"]
|
|
|
|
| 105 |
return lineless_table_engine, "lineless_table"
|
| 106 |
|
| 107 |
|
| 108 |
+
def process_image(img_input, small_box_cut_enhance, table_engine_type, char_ocr, rotated_fix, col_threshold, row_threshold):
|
| 109 |
+
det_model="mobile_det"
|
| 110 |
+
rec_model="mobile_rec"
|
| 111 |
+
img = img_loader(img_input)
|
| 112 |
start = time.time()
|
| 113 |
table_engine, talbe_type = select_table_model(img, table_engine_type, det_model, rec_model)
|
| 114 |
ocr_engine = select_ocr_model(det_model, rec_model)
|
|
|
|
| 121 |
ocr_boxes = result[0]['res']['boxes']
|
| 122 |
all_elapse = f"- `table all cost: {time.time() - start:.5f}`"
|
| 123 |
else:
|
| 124 |
+
ocr_res, ocr_infer_elapse = ocr_engine(img, return_word_box=char_ocr)
|
| 125 |
det_cost, cls_cost, rec_cost = ocr_infer_elapse
|
| 126 |
+
if char_ocr:
|
| 127 |
+
ocr_res = trans_char_ocr_res(ocr_res)
|
| 128 |
ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
|
| 129 |
if isinstance(table_engine, RapidTable):
|
| 130 |
html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
|
| 131 |
polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
|
| 132 |
elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
|
| 133 |
+
html, table_rec_elapse, polygons, logic_points, ocr_res = table_engine(img, ocr_result=ocr_res,
|
| 134 |
+
enhance_box_line=small_box_cut_enhance,
|
| 135 |
+
rotated_fix=rotated_fix,
|
| 136 |
+
col_threshold=col_threshold,
|
| 137 |
+
row_threshold=row_threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
sum_elapse = time.time() - start
|
| 139 |
all_elapse = f"- table_type: {talbe_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
|
| 140 |
|
|
|
|
| 200 |
label="Box Cutting Enhancement (Disable to avoid excessive cutting, Enable to reduce missed cutting)",
|
| 201 |
value=True
|
| 202 |
)
|
| 203 |
+
char_ocr = gr.Checkbox(
|
| 204 |
+
label="char rec ocr",
|
| 205 |
+
value=False
|
| 206 |
+
)
|
| 207 |
+
rotate_adapt = gr.Checkbox(
|
| 208 |
+
label="Table Rotate Rec Enhancement",
|
| 209 |
+
value=False
|
| 210 |
+
)
|
| 211 |
+
col_threshold = gr.Slider(
|
| 212 |
+
label="col threshold(determine same col)",
|
| 213 |
+
minimum=5,
|
| 214 |
+
maximum=100,
|
| 215 |
+
value=15,
|
| 216 |
+
step=5
|
| 217 |
+
)
|
| 218 |
+
row_threshold = gr.Slider(
|
| 219 |
+
label="row threshold(determine same row)",
|
| 220 |
+
minimum=5,
|
| 221 |
+
maximum=100,
|
| 222 |
+
value=10,
|
| 223 |
+
step=5
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# det_model = gr.Dropdown(det_models_labels, label="Select OCR Detection Model",
|
| 227 |
+
# value=det_models_labels[0])
|
| 228 |
+
# rec_model = gr.Dropdown(rec_models_labels, label="Select OCR Recognition Model",
|
| 229 |
+
# value=rec_models_labels[0])
|
| 230 |
|
| 231 |
run_button = gr.Button("Run")
|
| 232 |
gr.Markdown("# Elapsed Time")
|
|
|
|
| 242 |
|
| 243 |
run_button.click(
|
| 244 |
fn=process_image,
|
| 245 |
+
inputs=[img_input, small_box_cut_enhance, table_engine_type, char_ocr, rotate_adapt, col_threshold, row_threshold],
|
| 246 |
outputs=[html_output, table_boxes_output, ocr_boxes_output, elapse_text]
|
| 247 |
)
|
| 248 |
|