Spaces:
Build error
Build error
| import os | |
| import cv2 | |
| from transformers import DetrFeatureExtractor | |
| from transformers import DetrForObjectDetection | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from matplotlib.patches import Circle, Wedge, Rectangle | |
| import streamlit as st | |
| from PIL import Image | |
| import math | |
# Outline colours for drawn bounding boxes; the plotting helpers index this
# list with the predicted class id.
colors = [
    "red",
    "blue",
    "green",
    "yellow",
    "orange",
    "violet",
]
def table_detector(image, THRESHOLD_PROBA):
    '''
    Detect tables on an image with the "SalML/DETR-table-detection" checkpoint.

    image: input image; it must expose ``.size`` and be accepted by
        DetrFeatureExtractor (presumably a PIL image -- confirm with caller).
    THRESHOLD_PROBA: a query is kept only when its highest class probability
        (the trailing "no object" class excluded) is strictly above this value.

    Returns a tuple (model, image, kept class probabilities, kept boxes
    rescaled by the feature extractor's post-processing).
    '''
    extractor = DetrFeatureExtractor(do_resize=True, size=800, max_size=800)
    model_inputs = extractor(image, return_tensors="pt")

    # NOTE: the checkpoint is loaded again on every call.
    model = DetrForObjectDetection.from_pretrained("SalML/DETR-table-detection")

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(**model_inputs)

    # Class probabilities of the first batch item, without the last
    # ("no object") column.
    class_probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = class_probas.max(-1).values > THRESHOLD_PROBA

    # Rescale the boxes; the target size is image.size reversed, with a
    # leading batch dimension added.
    reversed_size = image.size[::-1]
    target_sizes = torch.tensor(reversed_size).unsqueeze(0)
    rescaled = extractor.post_process(outputs, target_sizes)
    kept_boxes = rescaled[0]['boxes'][keep]

    return (model, image, class_probas[keep], kept_boxes)
def table_struct_recog(image, THRESHOLD_PROBA):
    '''
    Recognise table structure with the "SalML/DETR-table-structure-recognition" checkpoint.

    image: input image; it must expose ``.size`` and be accepted by
        DetrFeatureExtractor (presumably a PIL image of one table -- confirm
        with caller).
    THRESHOLD_PROBA: a query is kept only when its highest class probability
        (the trailing "no object" class excluded) is strictly above this value.

    Returns a tuple (model, image, kept class probabilities, kept boxes
    rescaled by the feature extractor's post-processing).
    '''
    preprocessor = DetrFeatureExtractor(do_resize=True, size=1000, max_size=1000)
    batch = preprocessor(image, return_tensors="pt")

    # NOTE: the checkpoint is loaded again on every call.
    model = DetrForObjectDetection.from_pretrained("SalML/DETR-table-structure-recognition")

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(**batch)

    # Softmax over classes, take the first batch item and drop the last
    # ("no object") column.
    all_probas = outputs.logits.softmax(-1)
    probas = all_probas[0, :, :-1]
    best_per_query = probas.max(-1).values
    keep = best_per_query > THRESHOLD_PROBA

    # Rescale the boxes; the target size is image.size reversed, with a
    # leading batch dimension added.
    target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
    first_result = preprocessor.post_process(outputs, target_sizes)[0]
    bboxes_scaled = first_result['boxes'][keep]

    return (model, image, probas[keep], bboxes_scaled)
def add_margin(pil_img, top=20, right=20, bottom=20, left=20, color=(255,255,255)):
    '''
    Return a new image holding pil_img surrounded by a solid-colour margin.

    Used as part of TSR pre-processing to prevent missing table edges.

    pil_img: source image; read through ``.size`` and ``.mode``.
    top, right, bottom, left: margin added on each side, in pixels.
    color: fill colour of the new canvas.
    '''
    src_width, src_height = pil_img.size
    padded_size = (src_width + right + left, src_height + top + bottom)

    canvas = Image.new(pil_img.mode, padded_size, color)
    # The source is pasted with its top-left corner at (left, top), so the
    # requested margins end up around it.
    canvas.paste(pil_img, (left, top))
    return canvas
def plot_results_detection(c1, model, pil_img, prob, boxes, show_only_cropped=False):
    '''
    Plot the full page image with one labelled rectangle per detection.

    c1: object whose ``pyplot()`` method is called to render the figure
        (presumably a Streamlit column -- confirm with caller).
    model: supplies ``config.id2label`` for the caption text.
    pil_img: image shown as the figure background.
    prob, boxes: per-detection class probabilities and boxes, consumed in
        lockstep; ``boxes.tolist()`` must yield 4-element boxes.
    show_only_cropped: unused.
    '''
    plt.figure(figsize=(32,20))
    plt.imshow(pil_img)
    axes = plt.gca()

    for scores, box in zip(prob, boxes.tolist()):
        x0, y0, x1, y1 = box
        best = scores.argmax()
        # Grow the box by 3 px on every side before drawing it.
        left, top, right, bottom = x0-3, y0-3, x1+3, y1+3
        outline = plt.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            fill=False,
            color=colors[best.item()],
            linewidth=3,
        )
        axes.add_patch(outline)
        caption = f'{model.config.id2label[best.item()]}: {scores[best]:0.2f}'
        axes.text(left, top, caption, fontsize=15,bbox=dict(facecolor='yellow', alpha=0.5))

    plt.axis('off')
    plt.show()
    c1.pyplot()
def plot_table_detection(c2, model, pil_img, prob, boxes):
    '''
    Crop every detected table out of pil_img, plot each crop and return the crops.

    c2: object whose ``pyplot()`` method is called once per crop
        (presumably a Streamlit column -- confirm with caller).
    model: unused.
    prob, boxes: consumed in lockstep; only the boxes are used, each grown
        by 3 px on every side before cropping.

    Returns the list of cropped images in detection order.
    '''
    plt.figure(figsize=(32,20))
    plt.gca()  # call retained from the original; its result is not used

    cropped_img_list = [
        pil_img.crop((x0-3, y0-3, x1+3, y1+3))
        for _, (x0, y0, x1, y1) in zip(prob, boxes.tolist())
    ]

    for table_img in cropped_img_list:
        plt.imshow(table_img)
        plt.axis('off')
        plt.show()
        c2.pyplot()

    return cropped_img_list
def plot_structure(c3, model, pil_img, prob, boxes, class_to_show=0):
    '''
    Plot the table image with its TSR bounding boxes and collect row/column boxes.

    Every detection whose label is not 'table' is outlined and captioned with
    its label and probability. Boxes labelled 'table row' and 'table column'
    are also collected, each grown by 3 px on every side.

    c3: object whose ``pyplot()`` method is called to render the figure
        (presumably a Streamlit column -- confirm with caller).
    model: supplies ``config.id2label``.
    class_to_show: unused.

    Returns (rows, cols): dicts keyed 'table row <i>' / 'table column <i>',
    where <i> is the position of the detection in the zipped prob/boxes.
    '''
    plt.figure(figsize=(32,20))
    plt.imshow(pil_img)
    axes = plt.gca()

    rows, cols = {}, {}
    for position, (scores, box) in enumerate(zip(prob, boxes.tolist())):
        x0, y0, x1, y1 = box
        xmin, ymin, xmax, ymax = x0-3, y0-3, x1+3, y1+3
        best = scores.argmax()
        label = model.config.id2label[best.item()]
        caption = f'{label}: {scores[best]:0.2f}'

        # The 'table' box itself is not drawn, only the structure inside it.
        if label != 'table':
            outline = plt.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color=colors[best.item()],
                linewidth=3,
            )
            axes.add_patch(outline)
            axes.text(xmin, ymin, caption, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))

        if label == 'table row':
            rows['table row '+str(position)] = (xmin, ymin, xmax, ymax)
        if label == 'table column':
            cols['table column '+str(position)] = (xmin, ymin, xmax, ymax)

    plt.show()
    c3.pyplot()
    return rows, cols
def sort_table_features(header, row_header, rows, cols):
    '''
    Order table features by position and drop rows that duplicate a header.

    header: dict that may hold a 'header' entry (xmin, ymin, xmax, ymax).
        When present, rows whose top edge is not below ``header bottom - 10``
        are discarded. When absent, no row is discarded for that reason
        (previously this raised KeyError).
    row_header: dict of row-header boxes (xmin, ymin, xmax, ymax).
    rows: dict of row boxes (xmin, ymin, xmax, ymax).
    cols: dict of column boxes (xmin, ymin, xmax, ymax).

    Returns (row_header_, new_row_, cols_):
        row_header_: row headers sorted top to bottom (by ymin).
        new_row_: surviving rows sorted top to bottom and re-keyed
            'table row.0', 'table row.1', ... in that order.
        cols_: columns sorted left to right (by xmin).
    '''
    # Sometimes the header and first row overlap, and we need the header bbox
    # not to have the first row's bbox inside it; hence the 10 px tolerance.
    header_box = header.get('header')
    y_header = None if header_box is None else header_box[3] - 10

    rows_ = {
        name: (xmin, ymin, xmax, ymax)
        for name, (xmin, ymin, xmax, ymax) in sorted(rows.items(), key=lambda item: item[1][1])
        if y_header is None or ymin > y_header
    }
    cols_ = {
        name: (xmin, ymin, xmax, ymax)
        for name, (xmin, ymin, xmax, ymax) in sorted(cols.items(), key=lambda item: item[1][0])
    }
    row_header_ = {
        name: (xmin, ymin, xmax, ymax)
        for name, (xmin, ymin, xmax, ymax) in sorted(row_header.items(), key=lambda item: item[1][1])
    }

    header_tops = [box[1] for box in row_header_.values()]

    # rows_ is already in top-to-bottom order, so inserting in that order
    # keeps the result sorted without a second sort.
    new_row_ = {}
    for xmin, ymin, xmax, ymax in rows_.values():
        # A row whose top edge is within 2 px of a row header's top edge is
        # treated as that row header and skipped.
        if any(math.isclose(ymin, top, abs_tol=2) for top in header_tops):
            continue
        new_row_['table row.' + str(len(new_row_))] = (xmin, ymin, xmax, ymax)

    return row_header_, new_row_, cols_
def sort_table_featuresv2(rows, cols):
    '''
    Return rows ordered top to bottom and columns ordered left to right.

    rows, cols: dicts mapping a feature name to (xmin, ymin, xmax, ymax).
    Rows are sorted by ymin and columns by xmin; each box is rebuilt as a
    4-tuple. The inputs are not modified.
    '''
    def _ordered(features, coord_index):
        # Rebuild the dict in ascending order of the chosen coordinate.
        ordered = {}
        for name, (xmin, ymin, xmax, ymax) in sorted(
            features.items(), key=lambda item: item[1][coord_index]
        ):
            ordered[name] = (xmin, ymin, xmax, ymax)
        return ordered

    rows_ = _ordered(rows, 1)
    cols_ = _ordered(cols, 0)
    return rows_, cols_
def individual_table_features(pil_img, header, row_header, rows, cols):
    '''
    Attach to every feature box the region it cuts out of pil_img.

    pil_img: image exposing ``crop((xmin, ymin, xmax, ymax))``.
    header, row_header, rows, cols: dicts mapping a feature name to
        (xmin, ymin, xmax, ymax).

    Each dict is updated in place so that every value becomes
    (xmin, ymin, xmax, ymax, cropped_img). The same four dict objects are
    returned as a tuple, in the order they were passed.
    '''
    for features in (header, row_header, rows, cols):
        # Only existing keys are reassigned, so the dict keeps its size while
        # it is being iterated.
        for name, box in features.items():
            xmin, ymin, xmax, ymax = box
            region = pil_img.crop((xmin, ymin, xmax, ymax))
            features[name] = (xmin, ymin, xmax, ymax, region)
    return header, row_header, rows, cols
def individual_table_featuresv2(pil_img, rows, cols):
    '''
    Attach to every row and column box the region it cuts out of pil_img.

    pil_img: image exposing ``crop((xmin, ymin, xmax, ymax))``.
    rows, cols: dicts mapping a feature name to (xmin, ymin, xmax, ymax).

    Both dicts are updated in place so that every value becomes
    (xmin, ymin, xmax, ymax, cropped_img); the same two dicts are returned.
    '''
    def _attach_crops(features):
        # Values of existing keys are replaced; no key is added or removed.
        for name, (xmin, ymin, xmax, ymax) in features.items():
            region = pil_img.crop((xmin, ymin, xmax, ymax))
            features[name] = xmin, ymin, xmax, ymax, region

    _attach_crops(rows)
    _attach_crops(cols)
    return rows, cols
def plot_table_features(c2, header, row_header, rows, cols):
    '''
    Walk over every feature entry and unpack its image; nothing is drawn.

    c2: unused.
    header, row_header, rows, cols: dicts whose values are 5-element entries
        (xmin, ymin, xmax, ymax, image). An entry of any other length makes
        the unpacking fail.

    Returns None.
    NOTE(review): the unpacked image is never used, so this currently has no
    visible effect -- the plotting code appears to be missing.
    '''
    for features in (header, row_header, rows, cols):
        for entry in features.values():
            _, _, _, _, pil_img = entry
def master_row_set(header, row_header, rows, cols):
    '''
    Merge header, row headers and rows into one dict ordered top to bottom.

    header, row_header, rows: dicts mapping a feature name to
        (xmin, ymin, xmax, ymax, img). On a duplicate name the later dict
        wins (rows over row_header over header).
    cols: unused.

    Returns a new dict sorted by ymin.
    '''
    combined = {**header, **row_header, **rows}
    top_to_bottom = sorted(combined.items(), key=lambda item: item[1][1])

    master_row_ = {}
    for name, (xmin, ymin, xmax, ymax, img) in top_to_bottom:
        master_row_[name] = (xmin, ymin, xmax, ymax, img)
    return master_row_
def object_to_cells(master_row, cols):
    '''
    Cut every row of master_row into per-column cell images.

    master_row: dict mapping a row name to (xmin, ymin, xmax, ymax, img).
    cols: dict mapping a column name to (xmin, ymin, xmax, ymax, img); only
        xmin and xmax are used.

    The row name decides what is stored in the returned dict:
      * names starting with 'header table row': the row image itself, under
        '<name>.<row index>';
      * the name 'header': one cropped cell per column, each under
        'header.<header cell index>';
      * names starting with 'table row': a list with one cropped cell per
        column, under '<name>.<row index>'.
    Rows with any other name are ignored. The row index is shared by the
    first and third kinds.
    '''
    def _cut_into_cells(strip_img, strip_height):
        # Crop coordinates are relative to the strip, so y spans 0..height.
        # NOTE(review): the -19 / -20 x offsets look tied to a 20 px padding
        # applied before recognition -- confirm against the caller.
        cells = []
        for column in cols.values():
            xmin_col, _, xmax_col, _, col_img = column
            cells.append(strip_img.crop((xmin_col-19, 0, xmax_col-20, strip_height)))
        return cells

    cells_img = {}
    header_idx = 0
    row_idx = 0
    for name, entry in master_row.items():
        if name.startswith('header table row'):
            _, _, _, _, row_header_img = entry
            cells_img[name+'.'+str(row_idx)] = row_header_img
            row_idx += 1
        elif name == 'header':
            _, ymin, _, ymax, header_img = entry
            for cell in _cut_into_cells(header_img, ymax-ymin):
                cells_img[name+'.'+str(header_idx)] = cell
                header_idx += 1
        elif name.startswith('table row'):
            _, ymin, _, ymax, row_img = entry
            cells_img[name+'.'+str(row_idx)] = _cut_into_cells(row_img, ymax-ymin)
            row_idx += 1
    return cells_img
def object_to_cellsv2(master_row, cols):
    '''
    Cut every row of master_row into a list of per-column cell images.

    master_row: dict mapping a row name to (xmin, ymin, xmax, ymax, img).
    cols: dict mapping a column name to (xmin, ymin, xmax, ymax, img); only
        xmin and xmax are used.

    Returns a dict with one entry per row, keyed '<row name>.<row position>',
    whose value is the list of cells cropped from that row's image in column
    order.
    '''
    cells_img = {}
    for position, (name, entry) in enumerate(master_row.items()):
        xmin, ymin, xmax, ymax, row_img = entry
        row_height = ymax-ymin

        row_cells = []
        for column in cols.values():
            xmin_col, _, xmax_col, _, col_img = column
            # Crop coordinates are relative to the row image.
            # NOTE(review): the -19 / -20 x offsets look tied to a 20 px
            # padding applied before recognition -- confirm against the caller.
            crop_box = (xmin_col-19, 0, xmax_col-20, row_height)
            row_cells.append(row_img.crop(crop_box))

        cells_img[name+'.'+str(position)] = row_cells
    return cells_img