Spaces:
Runtime error
Runtime error
| import logging.config | |
| import os | |
| import pathlib | |
| import shutil | |
| import sys | |
| import tempfile | |
| import zipfile | |
| import filelock | |
| import requests | |
| import tqdm | |
| # Hide an error message from `tokenizers` if this process is forked. | |
| os.environ["TOKENIZERS_PARALLELISM"] = "True" | |
| def path_in_cache(file_path): | |
| try: | |
| os.makedirs(TEXTATTACK_CACHE_DIR) | |
| except FileExistsError: # cache path exists | |
| pass | |
| return os.path.join(TEXTATTACK_CACHE_DIR, file_path) | |
| def s3_url(uri): | |
| return "https://textattack.s3.amazonaws.com/" + uri | |
| def download_from_s3(folder_name, skip_if_cached=True): | |
| """Folder name will be saved as `<cache_dir>/textattack/<folder_name>`. If | |
| it doesn't exist on disk, the zip file will be downloaded and extracted. | |
| Args: | |
| folder_name (str): path to folder or file in cache | |
| skip_if_cached (bool): If `True`, skip downloading if content is already cached. | |
| Returns: | |
| str: path to the downloaded folder or file on disk | |
| """ | |
| cache_dest_path = path_in_cache(folder_name) | |
| os.makedirs(os.path.dirname(cache_dest_path), exist_ok=True) | |
| # Use a lock to prevent concurrent downloads. | |
| cache_dest_lock_path = cache_dest_path + ".lock" | |
| cache_file_lock = filelock.FileLock(cache_dest_lock_path) | |
| cache_file_lock.acquire() | |
| # Check if already downloaded. | |
| if skip_if_cached and os.path.exists(cache_dest_path): | |
| cache_file_lock.release() | |
| return cache_dest_path | |
| # If the file isn't found yet, download the zip file to the cache. | |
| downloaded_file = tempfile.NamedTemporaryFile( | |
| dir=TEXTATTACK_CACHE_DIR, suffix=".zip", delete=False | |
| ) | |
| folder_s3_url = s3_url(folder_name) | |
| http_get(folder_s3_url, downloaded_file) | |
| # Move or unzip the file. | |
| downloaded_file.close() | |
| if zipfile.is_zipfile(downloaded_file.name): | |
| unzip_file(downloaded_file.name, cache_dest_path) | |
| else: | |
| logger.info(f"Copying {downloaded_file.name} to {cache_dest_path}.") | |
| shutil.copyfile(downloaded_file.name, cache_dest_path) | |
| cache_file_lock.release() | |
| # Remove the temporary file. | |
| os.remove(downloaded_file.name) | |
| logger.info(f"Successfully saved {folder_name} to cache.") | |
| return cache_dest_path | |
| def download_from_url(url, save_path, skip_if_cached=True): | |
| """Downloaded file will be saved under | |
| `<cache_dir>/textattack/<save_path>`. If it doesn't exist on disk, the zip | |
| file will be downloaded and extracted. | |
| Args: | |
| url (str): URL path from which to download. | |
| save_path (str): path to which to save the downloaded content. | |
| skip_if_cached (bool): If `True`, skip downloading if content is already cached. | |
| Returns: | |
| str: path to the downloaded folder or file on disk | |
| """ | |
| cache_dest_path = path_in_cache(save_path) | |
| os.makedirs(os.path.dirname(cache_dest_path), exist_ok=True) | |
| # Use a lock to prevent concurrent downloads. | |
| cache_dest_lock_path = cache_dest_path + ".lock" | |
| cache_file_lock = filelock.FileLock(cache_dest_lock_path) | |
| cache_file_lock.acquire() | |
| # Check if already downloaded. | |
| if skip_if_cached and os.path.exists(cache_dest_path): | |
| cache_file_lock.release() | |
| return cache_dest_path | |
| # If the file isn't found yet, download the zip file to the cache. | |
| downloaded_file = tempfile.NamedTemporaryFile( | |
| dir=TEXTATTACK_CACHE_DIR, suffix=".zip", delete=False | |
| ) | |
| http_get(url, downloaded_file) | |
| # Move or unzip the file. | |
| downloaded_file.close() | |
| if zipfile.is_zipfile(downloaded_file.name): | |
| unzip_file(downloaded_file.name, cache_dest_path) | |
| else: | |
| logger.info(f"Copying {downloaded_file.name} to {cache_dest_path}.") | |
| shutil.copyfile(downloaded_file.name, cache_dest_path) | |
| cache_file_lock.release() | |
| # Remove the temporary file. | |
| os.remove(downloaded_file.name) | |
| logger.info(f"Successfully saved {url} to cache.") | |
| return cache_dest_path | |
| def unzip_file(path_to_zip_file, unzipped_folder_path): | |
| """Unzips a .zip file to folder path.""" | |
| logger.info(f"Unzipping file {path_to_zip_file} to {unzipped_folder_path}.") | |
| enclosing_unzipped_path = pathlib.Path(unzipped_folder_path).parent | |
| with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref: | |
| zip_ref.extractall(enclosing_unzipped_path) | |
| def http_get(url, out_file, proxies=None): | |
| """Get contents of a URL and save to a file. | |
| https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py | |
| """ | |
| logger.info(f"Downloading {url}.") | |
| req = requests.get(url, stream=True, proxies=proxies) | |
| content_length = req.headers.get("Content-Length") | |
| total = int(content_length) if content_length is not None else None | |
| if req.status_code == 403 or req.status_code == 404: | |
| raise Exception(f"Could not reach {url}.") | |
| progress = tqdm.tqdm(unit="B", unit_scale=True, total=total) | |
| for chunk in req.iter_content(chunk_size=1024): | |
| if chunk: # filter out keep-alive new chunks | |
| progress.update(len(chunk)) | |
| out_file.write(chunk) | |
| progress.close() | |
| if sys.stdout.isatty(): | |
| LOG_STRING = "\033[34;1mtextattack\033[0m" | |
| else: | |
| LOG_STRING = "textattack" | |
| logger = logging.getLogger(__name__) | |
| logging.config.dictConfig( | |
| {"version": 1, "loggers": {__name__: {"level": logging.INFO}}} | |
| ) | |
| formatter = logging.Formatter(f"{LOG_STRING}: %(message)s") | |
| stream_handler = logging.StreamHandler() | |
| stream_handler.setFormatter(formatter) | |
| logger.addHandler(stream_handler) | |
| logger.propagate = False | |
| def _post_install(): | |
| logger.info("Updating TextAttack package dependencies.") | |
| logger.info("Downloading NLTK required packages.") | |
| import nltk | |
| nltk.download("averaged_perceptron_tagger") | |
| nltk.download("stopwords") | |
| nltk.download("omw") | |
| nltk.download("universal_tagset") | |
| nltk.download("wordnet") | |
| nltk.download("punkt") | |
| try: | |
| import stanza | |
| stanza.download("en") | |
| except Exception: | |
| pass | |
| def set_cache_dir(cache_dir): | |
| """Sets all relevant cache directories to ``TA_CACHE_DIR``.""" | |
| # Tensorflow Hub cache directory | |
| os.environ["TFHUB_CACHE_DIR"] = cache_dir | |
| # HuggingFace `transformers` cache directory | |
| os.environ["PYTORCH_TRANSFORMERS_CACHE"] = cache_dir | |
| # HuggingFace `datasets` cache directory | |
| os.environ["HF_HOME"] = cache_dir | |
| # Basic directory for Linux user-specific non-data files | |
| os.environ["XDG_CACHE_HOME"] = cache_dir | |
| def _post_install_if_needed(): | |
| """Runs _post_install if hasn't been run since install.""" | |
| # Check for post-install file. | |
| post_install_file_path = path_in_cache("post_install_check_3") | |
| post_install_file_lock_path = post_install_file_path + ".lock" | |
| post_install_file_lock = filelock.FileLock(post_install_file_lock_path) | |
| post_install_file_lock.acquire() | |
| if os.path.exists(post_install_file_path): | |
| post_install_file_lock.release() | |
| return | |
| # Run post-install. | |
| _post_install() | |
| # Create file that indicates post-install completed. | |
| open(post_install_file_path, "w").close() | |
| post_install_file_lock.release() | |
| TEXTATTACK_CACHE_DIR = os.environ.get( | |
| "TA_CACHE_DIR", os.path.expanduser("~/.cache/textattack") | |
| ) | |
| if "TA_CACHE_DIR" in os.environ: | |
| set_cache_dir(os.environ["TA_CACHE_DIR"]) | |
| _post_install_if_needed() | |