Shortcuts

Source code for torchtext.utils

import requests
import csv
import hashlib
from tqdm import tqdm
import os
import tarfile
import logging
import re
import sys
import zipfile
import gzip


def reporthook(t):
    """Build a download progress callback that drives the progress bar ``t``.

    Adapted from https://github.com/tqdm/tqdm.

    Arguments:
        t: progress-bar-like object exposing a writable ``total`` attribute
            and an ``update(n)`` method.

    Returns:
        A callable ``inner(b=1, bsize=1, tsize=None)`` that advances ``t``.
    """
    # A mutable container keeps the previous block count between calls
    # without needing ``nonlocal``.
    progress = {'last_b': 0}

    def inner(b=1, bsize=1, tsize=None):
        """
        b: int, optional
            Number of blocks just transferred [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            t.total = tsize
        # Only the blocks received since the previous call are reported.
        delta_blocks = b - progress['last_b']
        t.update(delta_blocks * bsize)
        progress['last_b'] = b

    return inner
def download_from_url(url, path=None, root='.data', overwrite=False, hash_value=None,
                      hash_type="sha256"):
    """Download file, with logic (from tensor2tensor) for Google Drive.

    Returns the path to the downloaded file.

    Arguments:
        url: the url of the file to download.
        path: explicit destination path. When given, its directory part
            replaces ``root`` and its last component is used as the file
            name. When ``None`` the file name is the last component of
            ``url``. (None)
        root: download folder used to store the file in (.data)
        overwrite: overwrite existing files (False)
        hash_value (str, optional): hash for url (Default: ``None``).
        hash_type (str, optional): hash type, among "sha256" and "md5"
            (Default: ``"sha256"``).

    Raises:
        RuntimeError: if ``hash_value`` is given and the file's hash differs,
            or if no file name can be determined from the URL/path or from
            the response's ``content-disposition`` header.
        OSError: if the download directory cannot be created.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> torchtext.utils.download_from_url(url)
        >>> '.data/validation.tar.gz'

    """
    def _check_hash(path):
        # Verification is skipped entirely when no expected hash was supplied.
        if hash_value:
            with open(path, "rb") as file_obj:
                if not validate_file(file_obj, hash_value, hash_type):
                    raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(path))

    def _process_response(r, root, filename):
        chunk_size = 16 * 1024
        total_size = int(r.headers.get('Content-length', 0))
        if filename is None:
            # ``re.findall`` returns a (possibly empty) list, never None, and
            # the header itself may be absent; both cases mean the name
            # cannot be autodetected.
            disposition = r.headers.get('content-disposition', '')
            matches = re.findall("filename=\"(.+)\"", disposition)
            # The header is server-controlled: keep only the final path
            # component so it cannot point outside ``root``.
            filename = os.path.basename(matches[0]) if matches else ''
            if not filename:
                raise RuntimeError("Filename could not be autodetected")

        path = os.path.join(root, filename)
        if os.path.exists(path):
            logging.info('File %s already exists.' % path)
            if not overwrite:
                _check_hash(path)
                return path
            logging.info('Overwriting file %s.' % path)
        logging.info('Downloading file {} to {}.'.format(filename, path))

        with open(path, "wb") as file:
            with tqdm(total=total_size, unit='B',
                      unit_scale=1, desc=os.path.basename(path)) as t:
                for chunk in r.iter_content(chunk_size):
                    if chunk:  # skip keep-alive chunks
                        file.write(chunk)
                        t.update(len(chunk))
        logging.info('File {} downloaded.'.format(path))

        _check_hash(path)
        return path

    if path is None:
        _, filename = os.path.split(url)
    else:
        root, filename = os.path.split(path)

    # An empty last component (URL or path ending in a separator) is not a
    # usable file name; fall back to header-based detection instead of
    # treating the download directory itself as the downloaded file.
    if not filename:
        filename = None

    # ``root`` is empty when ``path`` is a bare file name; that means the
    # current directory, which needs no creation.
    if root and not os.path.exists(root):
        try:
            os.makedirs(root)
        except OSError:
            print("Can't create the download directory {}.".format(root))
            raise

    if filename is not None:
        path = os.path.join(root, filename)
        # skip requests.get if path exists and not overwrite.
        if os.path.exists(path):
            logging.info('File %s already exists.' % path)
            if not overwrite:
                _check_hash(path)
                return path

    if 'drive.google.com' not in url:
        response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
        return _process_response(response, root, filename)
    else:
        # google drive links get filename from google drive
        filename = None

    logging.info('Downloading from Google Drive; may take a few minutes')
    confirm_token = None
    session = requests.Session()
    response = session.get(url, stream=True)
    # Large files trigger a virus-scan warning page; the confirmation token
    # is delivered through a "download_warning*" cookie.
    for k, v in response.cookies.items():
        if k.startswith("download_warning"):
            confirm_token = v

    if confirm_token:
        url = url + "&confirm=" + confirm_token
        response = session.get(url, stream=True)

    return _process_response(response, root, filename)
def unicode_csv_reader(unicode_csv_data, **kwargs):
    r"""Yield parsed rows from ``unicode_csv_data`` using the ``csv`` module.

    Since the standard csv library does not handle unicode in Python 2, we
    need a wrapper. Borrowed and slightly modified from the Python docs:
    https://docs.python.org/2/library/csv.html#csv-examples

    Before reading, the module-wide csv field size limit is raised as far as
    the platform accepts, to avoid "field larger than field limit" errors.

    Arguments:
        unicode_csv_data: unicode csv data (see example below)
        **kwargs: forwarded unchanged to ``csv.reader``.

    Examples:
        >>> from torchtext.utils import unicode_csv_reader
        >>> import io
        >>> with io.open(data_path, encoding="utf8") as f:
        >>>     reader = unicode_csv_reader(f)

    """
    # Start from the largest value and shrink it tenfold for as long as the
    # csv module rejects it with OverflowError.
    limit = sys.maxsize
    accepted = False
    while not accepted:
        try:
            csv.field_size_limit(limit)
        except OverflowError:
            limit = int(limit / 10)
        else:
            accepted = True
    csv.field_size_limit(limit)

    rows = csv.reader(unicode_csv_data, **kwargs)
    for row in rows:
        yield row
def utf_8_encoder(unicode_csv_data):
    """Lazily yield each item of ``unicode_csv_data`` encoded as UTF-8 bytes."""
    for text_line in unicode_csv_data:
        encoded_line = text_line.encode('utf-8')
        yield encoded_line
def extract_archive(from_path, to_path=None, overwrite=False):
    """Extract archive.

    Arguments:
        from_path: the path of the archive.
        to_path: the root path of the extracted files (directory of from_path)
        overwrite: overwrite existing files (False)

    Returns:
        List of paths to extracted files even if not overwritten.

    Raises:
        NotImplementedError: if ``from_path`` does not end in ``.tar.gz``,
            ``.tgz``, ``.zip`` or ``.gz``.

    Note:
        For a plain ``.gz`` file the decompressed output is always written
        next to ``from_path`` (same name without the ``.gz`` suffix);
        ``to_path`` and ``overwrite`` are not consulted in that case.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> from_path = './validation.tar.gz'
        >>> to_path = './'
        >>> torchtext.utils.download_from_url(url, from_path)
        >>> torchtext.utils.extract_archive(from_path, to_path)
        >>> ['.data/val.de', '.data/val.en']

    """

    if to_path is None:
        to_path = os.path.dirname(from_path)

    if from_path.endswith(('.tar.gz', '.tgz')):
        logging.info('Opening tar file {}.'.format(from_path))
        with tarfile.open(from_path, 'r') as tar:
            files = []
            # NOTE(review): member names come from the archive and are not
            # checked against path traversal; only extract trusted archives.
            for file_ in tar:
                file_path = os.path.join(to_path, file_.name)
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info('{} already extracted.'.format(file_path))
                        if not overwrite:
                            continue
                # Non-file members (e.g. directories) are always extracted.
                tar.extract(file_, to_path)
            return files

    elif from_path.endswith('.zip'):
        assert zipfile.is_zipfile(from_path), from_path
        logging.info('Opening zip file {}.'.format(from_path))
        with zipfile.ZipFile(from_path, 'r') as zfile:
            files = []
            for file_ in zfile.namelist():
                file_path = os.path.join(to_path, file_)
                files.append(file_path)
                if os.path.exists(file_path):
                    logging.info('{} already extracted.'.format(file_path))
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
        # namelist() also contains directory entries; report regular files only.
        files = [f for f in files if os.path.isfile(f)]
        return files

    elif from_path.endswith('.gz'):
        default_block_size = 65536
        filename = from_path[:-3]
        files = [filename]
        with gzip.open(from_path, 'rb') as gzfile, \
                open(filename, 'wb') as d_file:
            # Stream block by block so large files never sit fully in memory;
            # the loop ends on the first empty read.
            for block in iter(lambda: gzfile.read(default_block_size), b''):
                d_file.write(block)
        return files

    else:
        raise NotImplementedError(
            "We currently only support tar.gz, .tgz, .gz and zip archives.")
def validate_file(file_obj, hash_value, hash_type="sha256"):
    """Check whether the content read from ``file_obj`` matches a hash.

    Args:
        file_obj: File object to read from.
        hash_value (str): Expected hex digest.
        hash_type (str, optional): Hash type, among "sha256" and "md5"
            (Default: ``"sha256"``).

    Returns:
        bool: True if the computed hex digest equals ``hash_value``, else False.

    Raises:
        ValueError: if ``hash_type`` is neither "sha256" nor "md5".

    """
    if hash_type not in ("sha256", "md5"):
        raise ValueError
    digest = hashlib.sha256() if hash_type == "sha256" else hashlib.md5()

    # Read by chunk to avoid filling memory
    block_size = 1024 ** 2
    block = file_obj.read(block_size)
    while block:
        digest.update(block)
        block = file_obj.read(block_size)
    return digest.hexdigest() == hash_value

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources