Hugging Face Spaces status: Build error
| import streamlit as st | |
| from PIL import Image | |
| import os | |
| import TDTSR | |
| import pytesseract | |
| from pytesseract import Output | |
| import postprocess as pp | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import numpy as np | |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
| from cv2 import dnn_superres | |
# Module setup: OCR binary location, Streamlit page layout and the three
# display columns used by the main loop below.
from streamlit.errors import StreamlitAPIException

# Only pin the Windows default install path when it actually exists; on other
# hosts (e.g. a Linux container) pytesseract falls back to `tesseract` on PATH.
_WINDOWS_TESSERACT_CMD = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
if os.path.isfile(_WINDOWS_TESSERACT_CMD):
    pytesseract.pytesseract.tesseract_cmd = _WINDOWS_TESSERACT_CMD

# set_page_config is issued before any other Streamlit call.
st.set_page_config(layout='wide')
try:
    # Newer Streamlit releases no longer recognise this option and raise;
    # it is only a warning toggle, so ignoring the failure is safe.
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass

st.title("Table Detection and Table Structure Recognition")
# c1: detection overlay, c2: cropped tables, c3: structure plot + OCR result.
c1, c2, c3 = st.columns((1,1,1))
def PIL_to_cv(pil_img):
    """Convert a PIL image (RGB channel order) to an OpenCV BGR numpy array."""
    rgb_array = np.array(pil_img)
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
def cv_to_PIL(cv_img):
    """Convert an OpenCV BGR array back into a PIL image in RGB channel order."""
    rgb_array = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_array)
def pytess(cell_pil_img):
    """OCR a single cell image with Tesseract and return its text.

    ``image_to_data`` reports one entry per layout level (page, block,
    paragraph, line, word); non-word levels carry empty text. Those empty
    entries are dropped so the recognised words are joined by exactly one
    space. Returns ``''`` when nothing was recognised.
    """
    ocr_data = pytesseract.image_to_data(
        cell_pil_img,
        output_type=Output.DICT,
        # Tesseract config variables must be passed as `-c name=value`; a bare
        # name would be treated as a config-file path and ignored.
        config='-c preserve_interword_spaces=1',
    )
    words = [word for word in ocr_data['text'] if word.strip()]
    return ' '.join(words).strip()
# Loaded (processor, model) pairs keyed by model name, so repeated per-cell
# calls within one script run do not reload the weights every time.
_TROCR_CACHE = {}


def TrOCR(cell_pil_img, model_name="SalML/trocr-base-printed"):
    """Recognise the text in one cell image with a TrOCR model.

    Args:
        cell_pil_img: PIL image of a single table cell.
        model_name: Hugging Face model id used for both the processor and the
            encoder-decoder model (defaults to the original hard-coded id).

    Returns:
        The decoded text of the first (only) generated sequence.
    """
    if model_name not in _TROCR_CACHE:
        _TROCR_CACHE[model_name] = (
            TrOCRProcessor.from_pretrained(model_name),
            VisionEncoderDecoderModel.from_pretrained(model_name),
        )
    processor, model = _TROCR_CACHE[model_name]
    pixel_values = processor(images=cell_pil_img, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Configured DnnSuperResImpl instances keyed by model path, so the .pb file is
# read from disk once per script run instead of once per cell.
_SR_CACHE = {}


def super_res(pil_img, model_path="./LapSRN_x8.pb"):
    """Upscale a PIL image with an OpenCV dnn_superres model.

    Requires opencv-contrib-python installed without opencv-python.

    The algorithm name and scale are parsed from the model file name, which
    must follow the ``<Algorithm>_x<scale>.pb`` convention
    (e.g. ``LapSRN_x8.pb`` -> ``("lapsrn", 8)``). Multi-digit scales are
    supported.

    Args:
        pil_img: RGB PIL image to upscale.
        model_path: path to the super-resolution model file.

    Returns:
        The upscaled image as an RGB PIL image.
    """
    sr = _SR_CACHE.get(model_path)
    if sr is None:
        stem = os.path.splitext(os.path.basename(model_path))[0]  # "LapSRN_x8"
        algo, scale = stem.split('_')[:2]
        sr = dnn_superres.DnnSuperResImpl_create()
        sr.readModel(model_path)
        sr.setModel(algo.lower(), int(scale.lstrip('xX')))
        _SR_CACHE[model_path] = sr
    return cv_to_PIL(sr.upsample(PIL_to_cv(pil_img)))
def sharpen_image(pil_img):
    """Sharpen a PIL image with a 3x3 8-neighbour sharpening kernel.

    The centre pixel is weighted 9 and each of its 8 neighbours -1. The
    weights sum to 1, so flat regions keep their intensity while edges are
    amplified. A milder 4-neighbour alternative would be
    [[0, -1, 0], [-1, 5, -1], [0, -1, 0]].
    """
    kernel = -np.ones((3, 3), dtype=int)
    kernel[1, 1] = 9
    filtered = cv2.filter2D(PIL_to_cv(pil_img), -1, kernel)
    return cv_to_PIL(filtered)
def preprocess_magic(pil_img):
    """Binarise an image so that it ends up as dark text on a light background.

    The image is converted to grayscale and thresholded with Otsu's method.
    If black pixels outnumber white ones, the majority is assumed to be the
    background and the mask is inverted.

    Returns:
        An RGB PIL image whose three channels all hold the 0/255 mask.
    """
    cv_img = PIL_to_cv(pil_img)
    grayscale_image = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
    # With THRESH_OTSU the threshold argument (0) is ignored and computed
    # automatically.
    _, binary_image = cv2.threshold(
        grayscale_image, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    count_white = np.sum(binary_image > 0)
    count_black = np.sum(binary_image == 0)
    if count_black > count_white:
        binary_image = 255 - binary_image
    # binary_image is single-channel, but cv_to_PIL applies BGR2RGB, which
    # needs 3 channels -- expand the mask first.
    black_text_white_background_image = cv2.cvtColor(binary_image, cv2.COLOR_GRAY2BGR)
    return cv_to_PIL(black_text_white_background_image)
### main code:
# Directory holding the sample page images. Defaults to the original author's
# local path; override it with the TD_SAMPLES_DIR environment variable.
TD_SAMPLES_DIR = os.environ.get('TD_SAMPLES_DIR', 'D:/Jupyter/Multi-Type-TD-TSR/TD_samples/')


def _ocr_cell(cell_pil_img):
    """Upscale one cell crop with super_res() and OCR it with pytess()."""
    return pytess(super_res(cell_pil_img))


def _cells_to_dataframe(cells_img):
    """Build a DataFrame from the per-row cell crops of one table.

    The first value of ``cells_img`` is treated as the header row; blank
    header cells are named ``empty_col<idx>``. Every later value becomes one
    body row, so the frame has exactly one row per non-header table row.
    Assumes each row holds one crop per column -- TODO confirm against
    TDTSR.object_to_cellsv2.
    """
    row_iter = iter(cells_img.values())
    headers = []
    for idx, header in enumerate(next(row_iter, [])):
        header_text = _ocr_cell(header)
        if header_text == '':
            header_text = 'empty_col'+str(idx)
        headers.append(header_text)
    body = [[_ocr_cell(cell) for cell in row_images] for row_images in row_iter]
    return pd.DataFrame(body, columns=headers)


# Sorted so the UI renders samples in a stable order.
for td_sample in sorted(os.listdir(TD_SAMPLES_DIR)):
    image = Image.open(os.path.join(TD_SAMPLES_DIR, td_sample)).convert("RGB")
    # Stage 1: detect tables on the page and show them in c1/c2.
    model, image, probas, bboxes_scaled = TDTSR.table_detector(image, THRESHOLD_PROBA=0.6)
    TDTSR.plot_results_detection(c1, model, image, probas, bboxes_scaled)
    cropped_img_list = TDTSR.plot_table_detection(c2, model, image, probas, bboxes_scaled)
    for unpadded_table in cropped_img_list:
        # Stage 2: recognise rows/columns inside each cropped (padded) table.
        table = TDTSR.add_margin(unpadded_table)
        model, image, probas, bboxes_scaled = TDTSR.table_struct_recog(table, THRESHOLD_PROBA=0.6)
        try:
            rows, cols = TDTSR.plot_structure(c3, model, image, probas, bboxes_scaled, class_to_show=0)
            rows, cols = TDTSR.sort_table_featuresv2(rows, cols)
            rows, cols = TDTSR.individual_table_featuresv2(table, rows, cols)
        except Exception as printableException:
            st.write(td_sample, ' terminated with exception:', printableException)
            # rows/cols are undefined (or stale from a previous table) here.
            continue
        # Stage 3: cut the table into cell crops, OCR them and show the result.
        cells_img = TDTSR.object_to_cellsv2(rows, cols)
        c3.dataframe(_cells_to_dataframe(cells_img))