|
|
import importlib
import importlib.util
import lzma
import os
import pickle
import sys

import numpy as np
import PIL.Image
import torch
|
|
|
|
|
|
|
|
class Attributes:
    """Empty container whose instances only carry dynamically assigned attributes."""
|
|
|
|
|
class UnitTest:
    """Run the tests described in a pickled test package against an EasyOCR source tree."""

    # Prefix used inside the test package to refer to attributes of this object.
    _SELF_PREFIX = 'unit_test.'

    def __init__(self,
                 easyocr_module,
                 test_data="./data/EasyOcrUnitTestPackage.pickle",
                 image_data_dir="../examples",
                 verbose=0,
                 numeric_acceptance_error=0.1):
        """Import EasyOCR from ``easyocr_module`` and load the test package.

        Args:
            easyocr_module: Directory that contains EasyOCR's ``__init__.py``.
            test_data: Path to the lzma-compressed pickle holding the tests.
            image_data_dir: Directory holding the images named in the package.
            verbose: Verbosity level consulted by ``do_test``.
            numeric_acceptance_error: Relative tolerance used by ``validate``
                for numeric comparisons.

        Raises:
            FileNotFoundError: If ``easyocr_module`` has no ``__init__.py`` or
                an image listed in the test package is missing.
        """
        self.verbose = verbose
        self.numeric_acceptance_error = numeric_acceptance_error

        easy_ocr_init = os.path.join(easyocr_module, "__init__.py")
        if not os.path.isfile(easy_ocr_init):
            raise FileNotFoundError("Invalid easyocr_module. The directory should contain __init__.py.")

        # Load the package from the given path and register it under the name
        # "easyocr" so that the `easyocr.*` submodule imports below resolve to it.
        spec = importlib.util.spec_from_file_location("easyocr", easy_ocr_init)
        easyocr = importlib.util.module_from_spec(spec)
        sys.modules["easyocr"] = easyocr
        spec.loader.exec_module(easyocr)

        self.easyocr = easyocr
        for submodule in ('utils', 'detection', 'recognition'):
            if not hasattr(self.easyocr, submodule):
                setattr(self.easyocr, submodule, importlib.import_module('easyocr.' + submodule))

        self.easyocr_dir = os.path.dirname(easyocr.__file__)

        print("Unit test is set for EasyOCR at {}".format(os.path.abspath(self.easyocr_dir)))

        self.image_data_dir = image_data_dir

        self.set_data(test_data)
        self.set_easyocr()

    def set_data(self, test_data):
        """Load the test package and prepare the image inputs it refers to.

        Sets ``self.test_book`` (the ``'tests'`` entry of the package) and
        ``self.inputs`` with ``images`` and ``easyocr_config`` attributes.

        Args:
            test_data: Path to the lzma-compressed pickle file.

        Raises:
            FileNotFoundError: If an image named in the package is not present
                in ``self.image_data_dir``.
        """
        self.inputs = Attributes()

        # NOTE(security): pickle.load can execute arbitrary code. Only use test
        # packages that come from a trusted source.
        with lzma.open(test_data, 'rb') as fid:
            solution_book = pickle.load(fid)
        self.test_book = solution_book['tests']

        image_specs = solution_book['inputs']['images']
        available = set(os.listdir(self.image_data_dir))
        missing = [file for file in image_specs if file not in available]
        if missing:
            raise FileNotFoundError(
                "Cannot find {} in {}.".format(', '.join(missing), self.image_data_dir))

        images = {}
        for file, page in image_specs.items():
            # Open each image once and close it as soon as all crops are copied out.
            with PIL.Image.open(os.path.join(self.image_data_dir, file)) as picture:
                images[os.path.splitext(file)[0]] = {
                    # [:, :, ::-1] reverses the last axis of the cropped array
                    # (presumably RGB -> BGR, given the `*_bgr` names below).
                    key: np.asarray(picture.crop(crop_box))[:, :, ::-1]
                    for key, crop_box in page.items()
                }

        english_mini_bgr, english_mini_gray = self.easyocr.utils.reformat_input(images['english']['mini'])
        english_small_bgr, english_small_gray = self.easyocr.utils.reformat_input(images['english']['small'])
        images['english'].update({
            'mini_bgr': english_mini_bgr,
            'mini_gray': english_mini_gray,
            'small_bgr': english_small_bgr,
            'small_gray': english_small_gray,
        })

        setattr(self.inputs, 'images', self.dict2attr(images))
        setattr(self.inputs, 'easyocr_config', self.dict2attr(solution_book['inputs']['easyocr_config']))

    def dict2attr(self, dict_):
        """Convert a (possibly nested) dict into nested ``Attributes`` objects."""
        attr = Attributes()
        for key, value in dict_.items():
            setattr(attr, key, self.dict2attr(value) if isinstance(value, dict) else value)
        return attr

    def count_parameters(self, model):
        """Return the total number of elements over ``model.parameters()``."""
        return sum(param.numel() for param in model.parameters())

    def get_weight_norm(self, model):
        """Return the sum of the norms of all parameters of ``model`` as a float."""
        with torch.no_grad():
            norms = [param.norm() for param in model.parameters()]
            if not norms:
                # sum([]) would be the int 0, which has no .cpu().
                return 0.0
            return sum(norms).cpu().item()

    def get_nested_attr(self, parent, attr):
        """Resolve a dotted attribute path such as ``"a.b.c"`` starting at ``parent``."""
        target = parent
        for name in attr.split("."):
            target = getattr(target, name)
        return target

    def easyocr_read_as(self, image, language):
        """Read ``image`` with a fresh reader for ``language``.

        Args:
            image: Image passed unchanged to ``Reader.readtext``.
            language: A language code or a list of language codes.

        Returns:
            ``(text, confidence)`` of the first entry returned by ``readtext``.
            An empty ``readtext`` result raises ``IndexError``.
        """
        if not isinstance(language, list):
            language = [language]
        reader = self.easyocr.Reader(language)
        _, pred, confidence = reader.readtext(image)[0]
        # Drop the reader before asking torch to release cached GPU memory.
        del reader
        torch.cuda.empty_cache()
        return pred, confidence

    def set_easyocr(self):
        """Create a reader for the configured main language and store it as ``self.easyocr.ocr``."""
        ocr = self.easyocr.Reader([self.inputs.easyocr_config.main_language])
        setattr(self.easyocr, 'ocr', ocr)

    def _scalar_within_tolerance(self, test, solution):
        """Relative comparison of two scalars; an expected 0 requires an exact 0."""
        if solution == 0:
            return bool(test == 0)
        return bool(abs(1 - test / solution) < self.numeric_acceptance_error)

    def _elementwise_within_tolerance(self, test, solution):
        """Element-wise relative comparison; expected zeros require exact zeros."""
        expected_zero = solution == 0
        # Where the expected value is 0 the ratio is inf/nan; those positions
        # are masked out below and decided by exact equality instead.
        relative_ok = abs(1 - test / solution) < self.numeric_acceptance_error
        verdict = (expected_zero & (test == 0)) | (~expected_zero & relative_ok)
        return bool(verdict.all())

    def validate(self, test, solution, dtype):
        """Compare ``test`` against the expected ``solution``.

        Args:
            test: Value produced by the code under test.
            solution: Expected value.
            dtype: Type that selects the comparison: ``str`` (equality),
                ``dict``/``list``/``tuple`` (recursive), ``np.ndarray`` or
                ``torch.Tensor`` (element-wise relative tolerance), or an
                integer/inexact numeric type (relative tolerance).

        Returns:
            Whether ``test`` matches ``solution``.

        Raises:
            TypeError: If ``dtype`` is none of the supported types.
        """
        if dtype == str:
            return test == solution
        # Exact container/array types are checked before np.issubdtype so that
        # only candidate numeric types are handed to NumPy's dtype conversion.
        elif dtype == dict:
            return self.are_dicts_equal(test, solution)
        elif dtype == list or dtype == tuple:
            return self.are_lists_equal(test, solution)
        elif dtype == np.ndarray:
            with np.errstate(divide='ignore', invalid='ignore'):
                return self._elementwise_within_tolerance(np.asarray(test), solution)
        elif dtype == torch.Tensor:
            return self._elementwise_within_tolerance(test, solution)
        elif np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.inexact):
            return self._scalar_within_tolerance(test, solution)
        else:
            raise TypeError("Unsupport data type ({}) to validate. Supporting types are str, int, float, dict, list, np.ndarray, or torch.Tensor".format(dtype))

    def are_dicts_equal(self, test, solution):
        """Return True when both dicts have the same keys and every value validates."""
        if test.keys() != solution.keys():
            return False
        return all(self.validate(test[key], solution[key], type(solution[key])) for key in solution.keys())

    def are_lists_equal(self, test, solution):
        """Return True when both sequences have equal length and every pair validates."""
        if len(test) != len(solution):
            return False
        return all(self.validate(tt, ss, type(ss)) for (tt, ss) in zip(test, solution))

    def is_list_or_tuple(self, test):
        """Return True if ``test`` is a list or a tuple."""
        return isinstance(test, (list, tuple))

    def validate_all(self, results, solutions, dtypes):
        """Validate several (possibly nested) results against their solutions.

        Args:
            results: A value or a list/tuple of values (lists/tuples may nest).
            solutions: Expected values with the same structure as ``results``.
            dtypes: Accepted for backward compatibility and not consulted;
                the type of each expected value selects the comparison.

        Returns:
            True if every result validates; False on any failed comparison or
            when ``results`` and ``solutions`` differ in length.

        Raises:
            TypeError: If a result is a list/tuple while its solution is not,
                or the other way round.
        """
        results = list(results) if self.is_list_or_tuple(results) else [results]
        solutions = list(solutions) if self.is_list_or_tuple(solutions) else [solutions]

        # zip() would silently ignore the surplus items of the longer sequence.
        if len(results) != len(solutions):
            return False

        for result, solution in zip(results, solutions):
            result_nested = self.is_list_or_tuple(result)
            solution_nested = self.is_list_or_tuple(solution)
            if result_nested != solution_nested:
                raise TypeError("Cannot validate a result of type {} against a solution of type {}.".format(
                    type(result), type(solution)))
            if solution_nested:
                passed = self.validate_all(result, solution, dtypes)
            else:
                passed = self.validate(result, solution, type(solution))
            if not passed:
                return False
        return True

    def _strip_self_prefix(self, name):
        """Remove a leading ``unit_test.`` from ``name`` if it is present."""
        if name.startswith(self._SELF_PREFIX):
            return name[len(self._SELF_PREFIX):]
        return name

    def _resolve_input(self, input_):
        """Turn a ``"unit_test.<path>"`` string into the attribute it names; pass anything else through."""
        if isinstance(input_, str) and input_.startswith(self._SELF_PREFIX):
            return self.get_nested_attr(self, self._strip_self_prefix(input_))
        return input_

    def _run_module(self, tests):
        """Run every test of one module; return True if no 'Error'-severity test failed."""
        num_test = len(tests)
        num_passed = 0
        error_failures = 0
        for test_id, test in tests.items():
            if self.verbose > 1:
                print("#### {}: {}".format(test_id, test['description']))

            # Resolve into locals instead of rewriting the test book, so that
            # do_test can be run more than once on the same object.
            test_method = self.get_nested_attr(self, self._strip_self_prefix(test['method']))
            inputs = [self._resolve_input(input_) for input_ in test['input']]
            if self.verbose > 3:
                print("###### Input: {}".format(inputs))
            results = test_method(*inputs)
            if self.verbose > 2:
                print("###### Expected output: {}".format(test['output']))
                print("###### Received output: {}".format(results))

            if self.validate(results, test['output'], type(test['output'])):
                num_passed += 1
                if self.verbose > 1:
                    print("#### Passed. [{:d}/{:d}]".format(num_passed, num_test))
            elif test['severity'] == "Warning":
                # A mismatch on a warning-level test is reported but counts as passed.
                num_passed += 1
                if self.verbose > 1:
                    print("#### Passed. [{:d}/{:d}]".format(num_passed, num_test))
                if self.verbose > 2:
                    print("##### Warning: While the result is considered as passed, the test yields results ({}) "
                          "that are different from the expected values ({}). It is strongly recommended to make sure "
                          "that this is expected.".format(results, test['output']))
            else:
                if test['severity'] == 'Error':
                    error_failures += 1
                if self.verbose > 1:
                    print("#### Failed")
                if self.verbose > 2:
                    print("##### The test yields results ({}) which are different from the expected values ({}).".format(
                        results, test['output']))

        # Counting failed 'Error' tests directly keeps passed or warning-level
        # tests from offsetting a failed 'Error' test.
        return error_failures == 0

    def do_test(self, verbose=None):
        """Run every module of the test book and print a report.

        Args:
            verbose: If not None, replaces ``self.verbose`` before running.
                Higher values print more detail (thresholds used: >0, >1, >2, >3).

        Returns:
            True if every module passed, otherwise False.
        """
        if verbose is not None:
            self.verbose = verbose

        num_module_to_test = len(self.test_book)
        num_module_pass = 0
        print("Testing EasyOCR: {:d} modules will be tested.\n".format(num_module_to_test))
        for name, tests in self.test_book.items():
            if self.verbose > 0:
                print("##Testing module {}: {:d} tests will be performed.".format(name, len(tests)))

            if self._run_module(tests):
                num_module_pass += 1
                if self.verbose > 0:
                    print("##Module {}: Passed.\n".format(name))
            else:
                print("##Module {}: Failed.\n".format(name))

        print("#"*50)
        all_passed = num_module_pass >= num_module_to_test
        if all_passed:
            print("Testing completed:\n Final result: Passed.")
        else:
            print("Testing completed:\n Final result: Failed.")
        return all_passed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|