Spaces:
Sleeping
Sleeping
"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py.

1. Disable DataParallel.
"""
| import torch | |
| import torch.backends.cudnn as cudnn | |
| from torch.autograd import Variable | |
| from PIL import Image | |
| from collections import OrderedDict | |
| import cv2 | |
| import numpy as np | |
| from .craft_utils import getDetBoxes, adjustResultCoordinates | |
| from .imgproc import resize_aspect_ratio, normalizeMeanVariance | |
| from .craft import CRAFT | |
def copyStateDict(state_dict):
    """Strip a leading ``module.`` prefix (added by DataParallel) from every key.

    Args:
        state_dict: Mapping of parameter names to tensors, as returned by
            ``torch.load``. Either all keys carry the ``module.`` prefix or none do.

    Returns:
        OrderedDict with the same values under prefix-free key names (insertion
        order preserved).
    """
    # Empty checkpoints produce an empty result instead of an IndexError.
    if not state_dict:
        return OrderedDict()
    # Peek at the first key only — avoids materializing the full key list.
    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith("module") else 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Drop the first dotted component when the DataParallel prefix is present.
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
    """Run the CRAFT network on an image (or batch) and post-process into boxes.

    Args:
        canvas_size: Maximum side length the input is resized to.
        mag_ratio: Magnification ratio applied during resizing.
        net: CRAFT model, already moved to `device` and in eval mode.
        image: Single HxWxC numpy image, or a 4-D numpy batch (NxHxWxC).
        text_threshold: Text-region confidence threshold for getDetBoxes.
        link_threshold: Link confidence threshold for getDetBoxes.
        low_text: Low-bound text score for getDetBoxes.
        poly: If True, request free-form polygons instead of quad boxes.
        device: torch device for the forward pass.
        estimate_num_chars: If True, pair each box with its character-count
            estimate taken from the mapper returned by getDetBoxes.

    Returns:
        (boxes_list, polys_list): one list of boxes / polygons per input image.
    """
    if isinstance(image, np.ndarray) and len(image.shape) == 4:  # image is batch of np arrays
        image_arrs = image
    else:  # image is single numpy array
        image_arrs = [image]

    img_resized_list = []
    # resize every image to fit canvas_size while keeping aspect ratio
    for img in image_arrs:
        img_resized, target_ratio, size_heatmap = resize_aspect_ratio(img, canvas_size,
                                                                      interpolation=cv2.INTER_LINEAR,
                                                                      mag_ratio=mag_ratio)
        img_resized_list.append(img_resized)
    # NOTE(review): ratio comes from the LAST image's target_ratio — assumes all
    # batch images share the same size; `size_heatmap` is unused. TODO confirm.
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing: mean/variance normalization, then HWC -> CHW for torch
    x = [np.transpose(normalizeMeanVariance(n_img), (2, 0, 1))
         for n_img in img_resized_list]
    x = torch.from_numpy(np.array(x))
    x = x.to(device)

    # forward pass (inference only, no gradients)
    with torch.no_grad():
        y, feature = net(x)

    boxes_list, polys_list = [], []
    for out in y:
        # make score and link map (channels 0 and 1 of the network output)
        score_text = out[:, :, 0].cpu().data.numpy()
        score_link = out[:, :, 1].cpu().data.numpy()

        # Post-processing: raw score maps -> candidate boxes/polygons
        boxes, polys, mapper = getDetBoxes(
            score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)

        # coordinate adjustment: map detections back to original image scale
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
        if estimate_num_chars:
            boxes = list(boxes)
            polys = list(polys)
        for k in range(len(polys)):
            if estimate_num_chars:
                # attach the per-box character-count estimate
                boxes[k] = (boxes[k], mapper[k])
            if polys[k] is None:
                # fall back to the quad box when no polygon was produced
                polys[k] = boxes[k]
        boxes_list.append(boxes)
        polys_list.append(polys)

    return boxes_list, polys_list
def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
    """Load a CRAFT text detector from a checkpoint and prepare it for inference.

    Args:
        trained_model: Path to the saved state dict.
        device: Target device string ('cpu' or a CUDA device).
        quantize: On CPU only, attempt dynamic int8 quantization (best-effort).
        cudnn_benchmark: Value assigned to ``cudnn.benchmark`` for GPU runs.

    Returns:
        The CRAFT model on `device`, in eval mode.
    """
    net = CRAFT()
    # Same load in both branches; copyStateDict strips any DataParallel prefix.
    net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))

    if device == 'cpu':
        if quantize:
            try:
                torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
            except Exception:
                # Best-effort: quantization is unsupported on some builds/platforms;
                # keep the float model rather than failing. (Was a bare `except:`,
                # which also swallowed KeyboardInterrupt/SystemExit.)
                pass
    else:
        # net = torch.nn.DataParallel(net).to(device)  # intentionally disabled
        net = net.to(device)
        cudnn.benchmark = cudnn_benchmark

    net.eval()
    return net
def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs):
    """Detect text regions and return them as flattened polygon coordinate arrays.

    Args:
        detector: CRAFT model as returned by ``get_detector``.
        image: Single HxWxC numpy image, or a 4-D numpy batch (NxHxWxC).
        canvas_size: Max side length for resizing (forwarded to test_net).
        mag_ratio: Magnification ratio (forwarded to test_net).
        text_threshold: Text-region confidence threshold.
        link_threshold: Link confidence threshold.
        low_text: Low-bound text score.
        poly: If True, request free-form polygons instead of quad boxes.
        device: torch device for inference.
        optimal_num_chars: If given, sort each image's polygons by how close
            their estimated character count is to this value.
        **kwargs: Ignored (kept for interface compatibility with callers).

    Returns:
        One list per input image of 1-D int32 numpy arrays, each the flattened
        x,y coordinates of a detected polygon.
    """
    estimate_num_chars = optimal_num_chars is not None
    bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector,
                                       image, text_threshold,
                                       link_threshold, low_text, poly,
                                       device, estimate_num_chars)
    if estimate_num_chars:
        # Each entry is (polygon, estimated_char_count); keep only the polygons,
        # ordered by closeness to the requested character count.
        polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
                      for polys in polys_list]

    result = []
    for polys in polys_list:
        # Flatten each polygon to a 1-D int32 coordinate array.
        # (Loop variable renamed from `poly`, which shadowed the parameter.)
        single_img_result = [np.array(box).astype(np.int32).reshape((-1)) for box in polys]
        result.append(single_img_result)
    return result