# EasyOCR-model / unit_test / unit_test.py
# Uploaded by anjali2002 in commit b4959be ("Initial commit of EasyOCR model").
import importlib
import importlib.util
import lzma
import os
import pickle
import sys

import numpy as np
import PIL.Image
import torch
# %%
class Attributes:
    """Empty attribute container; UnitTest.dict2attr fills instances from nested dicts."""
class UnitTest:
    """Regression harness that checks an EasyOCR checkout against a pickled solution book.

    The solution book (an lzma-compressed pickle) holds the test inputs (image
    crops and an EasyOCR config) and, per module, a dict of tests. Each test
    names a method to call, its inputs, the expected output, a description and
    a severity ("Error" or "Warning").
    """

    # Strings in a test spec that start with this prefix name attributes of this
    # UnitTest instance, e.g. 'unit_test.inputs.images.<...>'.
    _REFERENCE_PREFIX = 'unit_test.'

    def __init__(self,
                 easyocr_module,
                 test_data="./data/EasyOcrUnitTestPackage.pickle",
                 image_data_dir="../examples",
                 verbose=0,
                 numeric_acceptance_error=0.1):
        """Load the EasyOCR package under test, the test data and a default Reader.

        Args:
            easyocr_module: Path to the EasyOCR package directory (must contain __init__.py).
            test_data: Path to the lzma-compressed pickle with inputs and expected outputs.
            image_data_dir: Directory holding the image files named in the test data.
            verbose: Verbosity used by do_test (>0 module headers, >1 per-test status,
                >2 expected/received outputs, >3 test inputs).
            numeric_acceptance_error: Maximum relative error accepted for numeric results.

        Raises:
            FileNotFoundError: If easyocr_module has no __init__.py or an input image is missing.
        """
        self.verbose = verbose
        self.numeric_acceptance_error = numeric_acceptance_error
        easy_ocr_init = os.path.join(easyocr_module, "__init__.py")
        if not os.path.isfile(easy_ocr_init):
            raise FileNotFoundError("Invalid easyocr_module. The directory should contain __init__.py.")
        # Import the package from an explicit path. Register it in sys.modules before
        # executing it, so that 'easyocr' / 'easyocr.*' imports resolve to this copy.
        spec = importlib.util.spec_from_file_location("easyocr", easy_ocr_init)
        easyocr = importlib.util.module_from_spec(spec)
        sys.modules["easyocr"] = easyocr
        spec.loader.exec_module(easyocr)
        self.easyocr = easyocr
        # The tests reach these submodules as attributes of the package.
        for submodule in ('utils', 'detection', 'recognition'):
            if not hasattr(self.easyocr, submodule):
                setattr(self.easyocr, submodule, importlib.import_module('easyocr.' + submodule))
        self.easyocr_dir = os.path.dirname(easyocr.__file__)
        print("Unit test is set for EasyOCR at {}".format(os.path.abspath(self.easyocr_dir)))
        self.image_data_dir = image_data_dir
        self.set_data(test_data)
        self.set_easyocr()

    def set_data(self, test_data):
        """Load the solution book and prepare the input images used by the tests.

        Sets self.test_book to the per-module tests. Sets self.inputs to an
        attribute-style view of the cropped images and the EasyOCR config.

        Raises:
            FileNotFoundError: If an image named in the test data is not in self.image_data_dir.
        """
        self.inputs = Attributes()
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted test packages.
        with lzma.open(test_data, 'rb') as fid:
            solution_book = pickle.load(fid)
        self.test_book = solution_book['tests']
        image_specs = solution_book['inputs']['images']
        available = set(os.listdir(self.image_data_dir))
        missing = [file for file in image_specs if file not in available]
        if missing:
            raise FileNotFoundError("Cannot find {} in {}.".format(', '.join(missing), self.image_data_dir))
        # Keyed by file name without extension, then by crop name.
        images = {os.path.splitext(file)[0]: self._load_crops(os.path.join(self.image_data_dir, file), page)
                  for file, page in image_specs.items()}
        english = images['english']
        mini_bgr, mini_gray = self.easyocr.utils.reformat_input(english['mini'])
        small_bgr, small_gray = self.easyocr.utils.reformat_input(english['small'])
        english.update({'mini_bgr': mini_bgr,
                        'mini_gray': mini_gray,
                        'small_bgr': small_bgr,
                        'small_gray': small_gray,
                        })
        self.inputs.images = self.dict2attr(images)
        self.inputs.easyocr_config = self.dict2attr(solution_book['inputs']['easyocr_config'])

    def _load_crops(self, path, crop_boxes):
        """Return {name: ndarray} for each PIL crop box in `crop_boxes`, cut from the image at `path`.

        The file is opened once and closed afterwards. The last (channel) axis is
        reversed, which turns RGB pixel data into BGR.
        """
        with PIL.Image.open(path) as image:
            return {key: np.asarray(image.crop(crop_box))[:, :, ::-1]
                    for key, crop_box in crop_boxes.items()}

    def dict2attr(self, dict_):
        """Recursively convert a dict into an Attributes object (nested dicts become nested objects)."""
        attr = Attributes()
        for key, value in dict_.items():
            setattr(attr, key, self.dict2attr(value) if isinstance(value, dict) else value)
        return attr

    def count_parameters(self, model):
        """Return the total number of elements across model.parameters()."""
        return sum(param.numel() for param in model.parameters())

    def get_weight_norm(self, model):
        """Return the sum of the norms of model.parameters() as a Python float (0.0 if there are none)."""
        with torch.no_grad():
            norms = [param.norm() for param in model.parameters()]
            if not norms:
                return 0.0
            return sum(norms).cpu().item()

    def get_nested_attr(self, parent, attr):
        """Resolve a dotted attribute path such as 'a.b.c' starting from `parent`."""
        for name in attr.split("."):
            parent = getattr(parent, name)
        return parent

    def easyocr_read_as(self, image, language):
        """Read `image` with a fresh Reader for `language` and return (text, confidence) of the first result.

        `language` may be a single language code or a list of them.
        """
        languages = language if isinstance(language, list) else [language]
        reader = self.easyocr.Reader(languages)
        _, pred, confidence = reader.readtext(image)[0]
        # Drop the reader before asking torch to release its cached GPU memory.
        del reader
        torch.cuda.empty_cache()
        return pred, confidence

    def set_easyocr(self):
        """Create the default Reader for the configured main language and expose it as easyocr.ocr."""
        self.easyocr.ocr = self.easyocr.Reader([self.inputs.easyocr_config.main_language])

    def validate(self, test, solution, dtype):
        """Return whether `test` matches the expected `solution`.

        `dtype` (callers pass type(solution)) selects the comparison:
        - str: exact equality.
        - dict, list, tuple: recursive element-wise validation.
        - np.ndarray, torch.Tensor and numeric scalars: relative error below
          self.numeric_acceptance_error (see _is_close).

        Raises:
            TypeError: If `dtype` is not one of the supported types.
        """
        if dtype == str:
            return test == solution
        if dtype == dict:
            return self.are_dicts_equal(test, solution)
        if dtype == list or dtype == tuple:
            return self.are_lists_equal(test, solution)
        if dtype == np.ndarray or dtype == torch.Tensor:
            return self._is_close(test, solution).all()
        if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.inexact):
            return self._is_close(test, solution)
        raise TypeError("Unsupport data type ({}) to validate. Supporting types are str, int, float, dict, list, np.ndarray, or torch.Tensor".format(dtype))

    def _is_close(self, test, solution):
        """Element-wise check that `test` is within the relative tolerance of `solution`.

        Passes where |1 - test/solution| < self.numeric_acceptance_error. Where
        `solution` is 0 the relative error is undefined, so |test| is held to the
        same tolerance instead. Scalars give a bool; arrays and tensors give an
        element-wise boolean result.
        """
        tol = self.numeric_acceptance_error
        if isinstance(solution, torch.Tensor):
            test = torch.as_tensor(test)
            # Division by zero yields inf/nan here, which fails `< tol`; the zero mask covers those.
            relative = (1 - test / solution).abs()
            return (relative < tol) | ((solution == 0) & (test.abs() < tol))
        if isinstance(solution, np.ndarray):
            test = np.asarray(test)
            with np.errstate(divide='ignore', invalid='ignore'):
                relative = np.abs(1 - test / solution)
            return (relative < tol) | ((solution == 0) & (np.abs(test) < tol))
        if solution == 0:
            return abs(test) < tol
        return abs(1 - test / solution) < tol

    def are_dicts_equal(self, test, solution):
        """Return True if both dicts have the same keys and every value validates."""
        if test.keys() != solution.keys():
            return False
        return all([self.validate(test[key], solution[key], type(solution[key])) for key in solution.keys()])

    def are_lists_equal(self, test, solution):
        """Return True if both sequences have the same length and every element validates."""
        if len(test) != len(solution):
            return False
        return all([self.validate(tt, ss, type(ss)) for (tt, ss) in zip(test, solution)])

    def is_list_or_tuple(self, test):
        """Return True if `test` is a list or a tuple."""
        return isinstance(test, (list, tuple))

    def validate_all(self, results, solutions, dtypes):
        """Validate parallel, possibly nested, sequences of results against solutions.

        Non-list arguments are treated as one-element lists. A single dtype is
        broadcast to every position. Otherwise `dtypes` must be as long as
        `solutions`. Entries where both result and solution are lists/tuples are
        compared recursively. Leaf pairs go through validate() using the
        solution's own type. Any length or nesting mismatch returns False.

        Raises:
            ValueError: If several dtypes are given and their count differs from
                the number of solutions.
        """
        if not isinstance(results, list):
            results = [results]
        if not isinstance(solutions, list):
            solutions = [solutions]
        if not isinstance(dtypes, list):
            dtypes = [dtypes]
        if len(results) != len(solutions):
            return False
        if len(dtypes) == 1:
            dtypes = dtypes * len(solutions)
        elif len(dtypes) != len(solutions):
            raise ValueError("Expected 1 or {:d} dtypes, got {:d}.".format(len(solutions), len(dtypes)))
        validation = []
        for result, solution, dtype in zip(results, solutions, dtypes):
            result_nested = self.is_list_or_tuple(result)
            solution_nested = self.is_list_or_tuple(solution)
            if result_nested and solution_nested:
                nested_dtypes = list(dtype) if self.is_list_or_tuple(dtype) else dtype
                validation.append(self.validate_all(list(result), list(solution), nested_dtypes))
            elif not result_nested and not solution_nested:
                validation.append(self.validate(result, solution, type(solution)))
            else:
                # One side is a sequence and the other is not: the structures differ.
                validation.append(False)
        return all(validation)

    def _resolve_reference(self, value):
        """Return the attribute named by a 'unit_test.<path>' string; return anything else unchanged."""
        if isinstance(value, str) and value.startswith(self._REFERENCE_PREFIX):
            return self.get_nested_attr(self, value[len(self._REFERENCE_PREFIX):])
        return value

    def do_test(self, verbose=None):
        """Run every test in the test book and print per-module and overall verdicts.

        For each test:
        - The method path, with an optional 'unit_test.' prefix removed, is
          resolved on this object.
        - Any 'unit_test.'-prefixed inputs are resolved the same way.
        - The method is called and its result validated against the expected output.

        A failing "Warning"-severity test is still counted as passed. A module
        passes only if none of its "Error"-severity tests fail. The test book
        itself is not modified.

        Args:
            verbose: If given, replaces self.verbose for this and later runs.
        """
        if verbose is not None:
            self.verbose = verbose
        num_module_to_test = len(self.test_book)
        num_module_pass = 0
        print("Testing EasyOCR: {:d} modules will be tested.\n".format(num_module_to_test))
        for name, tests in self.test_book.items():
            num_test = len(tests)
            num_passed = 0
            num_error_failed = 0
            if self.verbose > 0:
                print("##Testing module {}: {:d} tests will be performed.".format(name, num_test))
            for test_id, test in tests.items():
                if self.verbose > 1:
                    print("#### {}: {}".format(test_id, test['description']))
                method_path = test['method']
                if method_path.startswith(self._REFERENCE_PREFIX):
                    method_path = method_path[len(self._REFERENCE_PREFIX):]
                test_method = self.get_nested_attr(self, method_path)
                test_input = [self._resolve_reference(input_) for input_ in test['input']]
                if self.verbose > 3:
                    print("###### Input: {}".format(test_input))
                results = test_method(*test_input)
                if self.verbose > 2:
                    print("###### Expected output: {}".format(test['output']))
                    print("###### Received output: {}".format(results))
                if self.validate(results, test['output'], type(test['output'])):
                    num_passed += 1
                    if self.verbose > 1:
                        print("#### Passed. [{:d}/{:d}]".format(num_passed, num_test))
                elif test['severity'] == "Warning":
                    num_passed += 1
                    if self.verbose > 1:
                        print("#### Passed. [{:d}/{:d}]".format(num_passed, num_test))
                    if self.verbose > 2:
                        print("##### Warning: While the result is considered as passed, the test yields results ({}) "
                              "that are different from the expected values ({}). It is strongly recommended to make sure "
                              "that this is expected.".format(results, test['output']))
                else:
                    if test['severity'] == 'Error':
                        num_error_failed += 1
                    if self.verbose > 1:
                        print("#### Failed")
                    if self.verbose > 2:
                        print("##### The test yields results ({}) which are different from the expected values ({})."
                              .format(results, test['output']))
            # Count failed Error tests directly, so a tolerated Warning failure cannot mask them.
            if num_error_failed == 0:
                num_module_pass += 1
                if self.verbose > 0:
                    print("##Module {}: Passed.\n".format(name))
            else:
                print("##Module {}: Failed.\n".format(name))
            print("#" * 50)
        if num_module_pass >= num_module_to_test:
            print("Testing completed:\n Final result: Passed.")
        else:
            print("Testing completed:\n Final result: Failed.")