Source code for farm.file_utils

"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import os
import tempfile
import tarfile
import zipfile

from functools import wraps
from hashlib import sha256
from io import open
from pathlib import Path

import boto3
import numpy as np
import requests
from botocore.exceptions import ClientError
from dotmap import DotMap
from tqdm import tqdm
from transformers.file_utils import cached_path

try:
    from torch.hub import _get_torch_home

    torch_cache_home = Path(_get_torch_home())
except ImportError:
    torch_cache_home = Path(os.path.expanduser(
        os.getenv(
            "TORCH_HOME", Path(os.getenv("XDG_CACHE_HOME", "~/.cache")) / "torch"
        )
    ))
default_cache_path = torch_cache_home / "farm"

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path

    FARM_CACHE = Path(os.getenv("FARM_CACHE", default_cache_path))
except (AttributeError, ImportError):
    FARM_CACHE = os.getenv("FARM_CACHE", default_cache_path)


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name



def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode("utf-8")
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    return filename

def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = FARM_CACHE

    cache_path = cache_dir / filename
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    # cache_path may be a pathlib.Path, so build the sidecar path from its string form
    meta_path = str(cache_path) + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag

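# --- Usage sketch (illustrative, not part of the original module) ---
# Shows how url_to_filename() and filename_to_url() cooperate: the hashed name
# produced for a URL/ETag pair is the key under which a cached file and its
# "<filename>.json" metadata sidecar are stored. The URL and ETag values below
# are hypothetical examples.
def _example_cache_naming():
    url = "https://example.com/models/vocab.txt"  # hypothetical URL
    etag = "686897696a7c876b7e"                   # hypothetical ETag
    hashed = url_to_filename(url, etag=etag)
    # filename_to_url(hashed) would recover (url, etag) once "<hashed>" and
    # "<hashed>.json" exist inside FARM_CACHE.
    return hashed
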
def download_from_s3(s3_url: str, cache_dir: str = None, access_key: str = None,
                     secret_access_key: str = None, region_name: str = None):
    """
    Download a "folder" from s3 to local. Skip already existing files. Useful for downloading all files of one model.

    The default and recommended authentication follows boto3's trajectory of checking for ENV variables,
    .aws/credentials etc. (see https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html).
    However, there's also the option to pass `access_key`, `secret_access_key` and `region_name` directly
    as this is needed in some enterprise environments with local s3 deployments.

    :param s3_url: Url of the "folder" in s3 (e.g. s3://mybucket/my_modelname)
    :param cache_dir: Optional local directory where the files shall be stored.
                      If not supplied, we'll use a subfolder in torch's cache dir (~/.cache/torch/farm)
    :param access_key: Optional S3 Access Key
    :param secret_access_key: Optional S3 Secret Access Key
    :param region_name: Optional Region Name
    :return: local path of the folder
    """
    if cache_dir is None:
        cache_dir = FARM_CACHE

    logger.info(f"Downloading from {s3_url} to {cache_dir}")

    if access_key or secret_access_key:
        assert secret_access_key and access_key, "You only supplied one of secret_access_key and access_key. We need both."

        session = boto3.Session(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_access_key,
            region_name=region_name,
        )
        s3_resource = session.resource("s3")
    else:
        s3_resource = boto3.resource("s3")

    bucket_name, s3_path = split_s3_path(s3_url)
    bucket = s3_resource.Bucket(bucket_name)
    objects = bucket.objects.filter(Prefix=s3_path)
    if not objects:
        raise ValueError(f"Could not find s3_url: {s3_url}")

    for obj in objects:
        path, filename = os.path.split(obj.key)
        path = os.path.join(cache_dir, path)
        # Create local folder
        if not os.path.exists(path):
            os.makedirs(path)
        # Download file if not present locally
        if filename:
            filepath = os.path.join(path, filename)
            if os.path.exists(filepath):
                logger.info(f"Skipping {obj.key} (exists locally)")
            else:
                logger.info(f"Downloading {obj.key} to {filepath} (size: {obj.size / 1000000} MB)")
                bucket.download_file(obj.key, filepath)
    return path

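# --- Usage sketch (illustrative, not part of the original module) ---
# Minimal example of fetching a whole model "folder" with download_from_s3().
# The bucket, prefix and local directory are hypothetical; credentials are
# resolved by boto3 (env variables, ~/.aws/credentials, ...) unless the keys
# are passed explicitly.
def _example_download_from_s3():
    local_path = download_from_s3(
        "s3://mybucket/my_modelname",  # hypothetical s3 "folder"
        cache_dir="/tmp/farm_models",  # hypothetical local cache dir
    )
    return local_path
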
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path

def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper

@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag

@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)

def http_get(url, temp_file, proxies=None):
    req = requests.get(url, stream=True, proxies=proxies)
    content_length = req.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()

def fetch_archive_from_http(url, output_dir, proxies=None):
    """
    Fetch an archive (zip or tar.gz) from a url via http and extract content to an output directory.

    :param url: http address
    :type url: str
    :param output_dir: local path
    :type output_dir: str
    :param proxies: proxies details as required by requests library
    :type proxies: dict
    :return: bool if anything got fetched
    """
    # verify & prepare local directory
    path = Path(output_dir)
    if not path.exists():
        path.mkdir(parents=True)

    is_not_empty = len(list(Path(path).rglob("*"))) > 0
    if is_not_empty:
        logger.info(
            f"Found data stored in `{output_dir}`. Delete this first if you really want to fetch new data."
        )
        return False
    else:
        logger.info(f"Fetching from {url} to `{output_dir}`")

        # download & extract
        with tempfile.NamedTemporaryFile() as temp_file:
            http_get(url, temp_file, proxies=proxies)
            temp_file.flush()
            temp_file.seek(0)  # making tempfile accessible
            # extract
            if url[-4:] == ".zip":
                archive = zipfile.ZipFile(temp_file.name)
                archive.extractall(output_dir)
            elif url[-7:] == ".tar.gz":
                archive = tarfile.open(temp_file.name)
                archive.extractall(output_dir)
            # temp_file gets deleted here
        return True

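# --- Usage sketch (illustrative, not part of the original module) ---
# fetch_archive_from_http() only downloads when the target directory is empty,
# so repeated calls are cheap. The URL below is a hypothetical .tar.gz archive.
def _example_fetch_archive():
    got_new_data = fetch_archive_from_http(
        url="https://example.com/datasets/sample.tar.gz",  # hypothetical archive
        output_dir="data/sample",
    )
    return got_new_data  # False if data/sample already contained files
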
def load_from_cache(pretrained_model_name_or_path, s3_dict, **kwargs):
    # Adjusted from HF Transformers to fit loading WordEmbeddings from deepset's s3
    # Load from URL or cache if already cached
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)

    s3_file = s3_dict[pretrained_model_name_or_path]
    try:
        resolved_file = cached_path(
            s3_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
        )

        if resolved_file is None:
            raise EnvironmentError

    except EnvironmentError:
        if pretrained_model_name_or_path in s3_dict:
            msg = "Couldn't reach server at '{}' to download data.".format(s3_file)
        else:
            msg = (
                "Model name '{}' was not found in model name list. "
                "We assumed '{}' was a path, a model identifier, or url to a configuration file or "
                "a directory containing such a file but couldn't find any such file at this path or url.".format(
                    pretrained_model_name_or_path, s3_file
                )
            )
        raise EnvironmentError(msg)

    if resolved_file == s3_file:
        logger.info("loading file {}".format(s3_file))
    else:
        logger.info("loading file {} from cache at {}".format(s3_file, resolved_file))

    return resolved_file

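# --- Usage sketch (illustrative, not part of the original module) ---
# load_from_cache() maps a model name to its remote file via `s3_dict` and
# defers downloading/caching to transformers' cached_path(). The mapping name
# and URL below are hypothetical.
def _example_load_from_cache():
    embedding_files = {
        "glove-english-uncased-6B": "https://example.com/embeddings/vectors.txt"  # hypothetical entry
    }
    local_file = load_from_cache(
        "glove-english-uncased-6B", embedding_files, cache_dir=None
    )
    return local_file
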
def read_set_from_file(filename):
    """
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    """
    collection = set()
    with open(filename, "r", encoding="utf-8") as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection

def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext

def read_config(path):
    if path:
        with open(path) as json_data_file:
            conf_args = json.load(json_data_file)
    else:
        raise ValueError("No config provided for classifier")

    # flatten last part of config, take either value or default as value
    for gk, gv in conf_args.items():
        for k, v in gv.items():
            conf_args[gk][k] = v["value"] if (v["value"] is not None) else v["default"]

    # DotMap for making nested dictionary accessible through dot notation
    args = DotMap(conf_args, _dynamic=False)

    return args

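# --- Usage sketch (illustrative, not part of the original module) ---
# read_config() expects a two-level JSON where every leaf is a dict with
# "value" and "default" keys; each leaf is flattened to value-or-default and
# the result is wrapped in a DotMap. The config content below is a
# hypothetical example written to a temporary file.
def _example_read_config():
    example_conf = {
        "general": {
            "cache_dir": {"value": None, "default": "cache/"},
            "batch_size": {"value": 32, "default": 16},
        }
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(example_conf, f)
        config_path = f.name
    args = read_config(config_path)
    # args.general.batch_size == 32, args.general.cache_dir == "cache/"
    return args
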
def unnestConfig(config):
    """
    Create a list of configs for evaluating parameters with different values. If a config parameter
    is of type list, that list is expanded so that each returned config contains only single values.
    Can handle lists inside any number of parameters.

    Can handle nested (one level) configs.
    """
    nestedKeys = []
    nestedVals = []

    for gk, gv in config.items():
        if gk != "task":
            for k, v in gv.items():
                if isinstance(v, list):
                    if k != "layer_dims":  # exclude layer dims, since it is already a list
                        nestedKeys.append([gk, k])
                        nestedVals.append(v)
                elif isinstance(v, dict):
                    logger.warning("Config too deep! Working on %s" % (str(v)))

    if len(nestedKeys) == 0:
        unnestedConfig = [config]
    else:
        logger.info(
            "Nested config at parameters: %s"
            % (", ".join(".".join(x) for x in nestedKeys))
        )
        unnestedConfig = []
        mesh = np.meshgrid(
            *nestedVals
        )  # get all combinations, each dimension corresponds to one parameter type
        # flatten mesh into shape: [num_parameters, num_combinations] so we can iterate in 2d over any parameter combinations
        mesh = [x.flatten() for x in mesh]

        # loop over all combinations
        for i in range(len(mesh[0])):
            tempconfig = config.copy()
            for j, k in enumerate(nestedKeys):
                if isinstance(k, str):
                    tempconfig[k] = mesh[j][i]  # get ith val of correct param value and overwrite original config
                elif len(k) == 2:
                    tempconfig[k[0]][k[1]] = mesh[j][i]  # set nested dictionary keys
                else:
                    logger.warning("Config too deep! Working on %s" % (str(k)))
            unnestedConfig.append(tempconfig)

    return unnestedConfig

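# --- Usage sketch (illustrative, not part of the original module) ---
# unnestConfig() turns list-valued parameters into a grid of configs: with two
# lists of lengths 2 and 3, six configs are produced (one per combination).
# The parameter names and values below are hypothetical.
def _example_unnest_config():
    config = DotMap(
        {
            "training": {"learning_rate": [1e-5, 3e-5], "batch_size": [16, 32, 64]},
            "task": {"name": "text_classification"},  # the "task" group is never unnested
        },
        _dynamic=False,
    )
    configs = unnestConfig(config)
    return configs  # list of 6 configs, one per combination of the two lists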