Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,11 +13,11 @@ from utils import plot_rec_box, LoadImage, format_html, box_4_2_poly_to_box_4_1
|
|
| 13 |
img_loader = LoadImage()
|
| 14 |
table_rec_path = "models/table_rec/ch_ppstructure_mobile_v2_SLANet.onnx"
|
| 15 |
det_model_dir = {
|
| 16 |
-
"mobile_det": "models/ocr/ch_PP-OCRv4_det_infer.onnx"
|
| 17 |
}
|
| 18 |
|
| 19 |
rec_model_dir = {
|
| 20 |
-
"mobile_rec": "models/ocr/ch_PP-OCRv4_rec_infer.onnx"
|
| 21 |
}
|
| 22 |
table_engine_list = [
|
| 23 |
"auto",
|
|
@@ -67,29 +67,28 @@ def select_ocr_model(det_model, rec_model):
|
|
| 67 |
|
| 68 |
def select_table_model(img, table_engine_type, det_model, rec_model):
|
| 69 |
if table_engine_type == "rapid_table":
|
| 70 |
-
return rapid_table_engine,
|
| 71 |
elif table_engine_type == "wired_table_v1":
|
| 72 |
-
return wired_table_engine_v1,
|
| 73 |
elif table_engine_type == "wired_table_v2":
|
| 74 |
print("使用v2 wired table")
|
| 75 |
-
return wired_table_engine_v2,
|
| 76 |
elif table_engine_type == "lineless_table":
|
| 77 |
-
return lineless_table_engine,
|
| 78 |
elif table_engine_type == "pp_table":
|
| 79 |
return pp_engine_dict[f"{det_model}_{rec_model}"], 0
|
| 80 |
elif table_engine_type == "auto":
|
| 81 |
cls, elasp = table_cls(img)
|
| 82 |
if cls == 'wired':
|
| 83 |
table_engine = wired_table_engine_v2
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
return table_engine, elasp
|
| 87 |
|
| 88 |
|
| 89 |
def process_image(img, table_engine_type, det_model, rec_model):
|
| 90 |
img = img_loader(img)
|
| 91 |
start = time.time()
|
| 92 |
-
table_engine,
|
| 93 |
ocr_engine = select_ocr_model(det_model, rec_model)
|
| 94 |
|
| 95 |
if isinstance(table_engine, PPStructure):
|
|
@@ -106,11 +105,12 @@ def process_image(img, table_engine_type, det_model, rec_model):
|
|
| 106 |
|
| 107 |
if isinstance(table_engine, RapidTable):
|
| 108 |
html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
|
|
|
|
| 109 |
elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
|
| 110 |
html, table_rec_elapse, polygons, _, _ = table_engine(img, ocr_result=ocr_res)
|
| 111 |
|
| 112 |
sum_elapse = time.time() - start
|
| 113 |
-
all_elapse = f"- table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
|
| 114 |
|
| 115 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 116 |
table_boxes_img = plot_rec_box(img.copy(), polygons)
|
|
@@ -126,7 +126,7 @@ def main():
|
|
| 126 |
|
| 127 |
with gr.Blocks() as demo:
|
| 128 |
with gr.Row(): # 两列布局
|
| 129 |
-
with gr.Column(): # 左边列
|
| 130 |
img_input = gr.Image(label="Upload or Select Image", sources="upload")
|
| 131 |
|
| 132 |
# 示例图片选择器
|
|
@@ -148,7 +148,7 @@ def main():
|
|
| 148 |
run_button = gr.Button("Run")
|
| 149 |
gr.Markdown("# Elapsed Time")
|
| 150 |
elapse_text = gr.Text(label="") # 使用 `gr.Text` 组件展示字符串
|
| 151 |
-
with gr.Column(): # 右边列
|
| 152 |
# 使用 Markdown 标题分隔各个组件
|
| 153 |
gr.Markdown("# Html Render")
|
| 154 |
html_output = gr.HTML(label="", elem_classes="scrollable-container")
|
|
|
|
| 13 |
img_loader = LoadImage()
|
| 14 |
table_rec_path = "models/table_rec/ch_ppstructure_mobile_v2_SLANet.onnx"
|
| 15 |
det_model_dir = {
|
| 16 |
+
"mobile_det": "models/ocr/ch_PP-OCRv4_det_infer.onnx",
|
| 17 |
}
|
| 18 |
|
| 19 |
rec_model_dir = {
|
| 20 |
+
"mobile_rec": "models/ocr/ch_PP-OCRv4_rec_infer.onnx",
|
| 21 |
}
|
| 22 |
table_engine_list = [
|
| 23 |
"auto",
|
|
|
|
| 67 |
|
| 68 |
def select_table_model(img, table_engine_type, det_model, rec_model):
|
| 69 |
if table_engine_type == "rapid_table":
|
| 70 |
+
return rapid_table_engine, table_engine_type
|
| 71 |
elif table_engine_type == "wired_table_v1":
|
| 72 |
+
return wired_table_engine_v1, table_engine_type
|
| 73 |
elif table_engine_type == "wired_table_v2":
|
| 74 |
print("使用v2 wired table")
|
| 75 |
+
return wired_table_engine_v2, table_engine_type
|
| 76 |
elif table_engine_type == "lineless_table":
|
| 77 |
+
return lineless_table_engine, table_engine_type
|
| 78 |
elif table_engine_type == "pp_table":
|
| 79 |
return pp_engine_dict[f"{det_model}_{rec_model}"], 0
|
| 80 |
elif table_engine_type == "auto":
|
| 81 |
cls, elasp = table_cls(img)
|
| 82 |
if cls == 'wired':
|
| 83 |
table_engine = wired_table_engine_v2
|
| 84 |
+
return table_engine, "wired_table_v2"
|
| 85 |
+
return lineless_table_engine, "lineless_table"
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
def process_image(img, table_engine_type, det_model, rec_model):
|
| 89 |
img = img_loader(img)
|
| 90 |
start = time.time()
|
| 91 |
+
table_engine, talbe_type = select_table_model(img, table_engine_type, det_model, rec_model)
|
| 92 |
ocr_engine = select_ocr_model(det_model, rec_model)
|
| 93 |
|
| 94 |
if isinstance(table_engine, PPStructure):
|
|
|
|
| 105 |
|
| 106 |
if isinstance(table_engine, RapidTable):
|
| 107 |
html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
|
| 108 |
+
polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
|
| 109 |
elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
|
| 110 |
html, table_rec_elapse, polygons, _, _ = table_engine(img, ocr_result=ocr_res)
|
| 111 |
|
| 112 |
sum_elapse = time.time() - start
|
| 113 |
+
all_elapse = f"- table_type: {talbe_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
|
| 114 |
|
| 115 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 116 |
table_boxes_img = plot_rec_box(img.copy(), polygons)
|
|
|
|
| 126 |
|
| 127 |
with gr.Blocks() as demo:
|
| 128 |
with gr.Row(): # 两列布局
|
| 129 |
+
with gr.Column(variant="panel"): # 左边列
|
| 130 |
img_input = gr.Image(label="Upload or Select Image", sources="upload")
|
| 131 |
|
| 132 |
# 示例图片选择器
|
|
|
|
| 148 |
run_button = gr.Button("Run")
|
| 149 |
gr.Markdown("# Elapsed Time")
|
| 150 |
elapse_text = gr.Text(label="") # 使用 `gr.Text` 组件展示字符串
|
| 151 |
+
with gr.Column(scale=2): # 右边列
|
| 152 |
# 使用 Markdown 标题分隔各个组件
|
| 153 |
gr.Markdown("# Html Render")
|
| 154 |
html_output = gr.HTML(label="", elem_classes="scrollable-container")
|