"""Download utilities."""
from __future__ import annotations
import logging
import shutil
import urllib.error
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypedDict
from urllib.request import urlretrieve
from tqdm import tqdm
from typing_extensions import NotRequired, Unpack
from .hashing import raise_on_digest_mismatch
from ..constants import TimeoutHint
from ..version import VERSION
if TYPE_CHECKING:
import botocore.client
import requests
import requests.exceptions
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
"DownloadBackend",
"DownloadError",
"DownloadKwargs",
"RequestKwargs",
"UnexpectedDirectoryError",
"download",
"download_from_google",
"download_from_s3",
]
logger = logging.getLogger(__file__)
#: Represents an available backend for downloading
DownloadBackend: TypeAlias = Literal["urllib", "requests"]
class TqdmReportHook(tqdm):  # type:ignore
    """A progress bar whose :meth:`update_to` can be used as a urllib report hook.

    Based on https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5
    """

    def update_to(
        self,
        blocks: int = 1,
        block_size: int = 1,
        total_size: int | None = None,
    ) -> None:
        """Synchronize the bar with the values reported by a urllib report hook.

        :param blocks: Number of blocks transferred so far
        :param block_size: Size of each block (in tqdm units)
        :param total_size: Total size (in tqdm units). If ``None`` (the default),
            the current total is left unchanged.
        """
        if total_size is not None:
            self.total = total_size
        transferred = blocks * block_size
        # ``update`` expects an increment, so pass the difference between the
        # absolute amount transferred and what the bar has counted so far
        self.update(transferred - self.n)
class RequestKwargs(TypedDict):
    """Optional keyword arguments that are handed to :func:`requests.get`.

    Every key may be omitted. See the ``requests`` documentation for the
    authoritative meaning of each argument.
    """

    #: Auth tuple, e.g. ``(username, password)``
    auth: NotRequired[tuple[str, str]]
    #: How long to wait for the server; accepted forms are given by ``TimeoutHint``
    timeout: NotRequired[TimeoutHint]
    #: Whether redirects should be followed
    allow_redirects: NotRequired[bool]
    #: Mapping from protocol to proxy URL
    proxies: NotRequired[dict[str, str]]
    #: Whether the server's TLS certificate is verified
    verify: NotRequired[bool]
    #: Whether the body is streamed; :func:`download` sets this to ``True`` when unset
    stream: NotRequired[bool]
    #: Client certificate: one path, or a pair of paths
    cert: NotRequired[str | tuple[str, str]]
    #: Query-string parameters
    params: NotRequired[dict[str, Any]]
    #: HTTP headers; :func:`download` adds a default ``User-Agent`` when none is given
    headers: NotRequired[dict[str, str | bytes | None] | None]
class DownloadKwargs(RequestKwargs):
    """Keyword arguments accepted by :func:`download`, extending :class:`RequestKwargs`."""

    # note: `force` is intentionally omitted here because
    # it is passed through from other signature components

    #: Delete the target file if the download raises
    clean_on_failure: NotRequired[bool]
    #: Which downloader to use
    backend: NotRequired[DownloadBackend | None]
    #: Expected hexdigests keyed by algorithm name
    hexdigests: NotRequired[Mapping[str, str] | None]
    #: URLs of files holding the expected hexdigests, keyed by algorithm name
    hexdigests_remote: NotRequired[Mapping[str, str] | None]
    #: Passed through to the digest check as ``hexdigests_strict``
    hexdigests_strict: NotRequired[bool]
    #: Whether a progress bar is shown
    progress_bar: NotRequired[bool]
    #: Overrides for the default progress bar arguments
    tqdm_kwargs: NotRequired[Mapping[str, Any] | None]
#: ``User-Agent`` header value used by the ``requests`` backend of
#: :func:`download` when the caller does not supply one
DEFAULT_AGENT = f"pystow v{VERSION}"
def download(
    url: str,
    path: str | Path,
    *,
    force: bool = True,
    clean_on_failure: bool = True,
    backend: DownloadBackend | None = None,
    hexdigests: Mapping[str, str] | None = None,
    hexdigests_remote: Mapping[str, str] | None = None,
    hexdigests_strict: bool = False,
    progress_bar: bool = True,
    tqdm_kwargs: Mapping[str, Any] | None = None,
    _version: str | None = None,
    **kwargs: Unpack[RequestKwargs],
) -> None:
    """Download a file from a given URL.

    :param url: URL to download
    :param path: Path to download the file to
    :param force: If false and the file already exists, will not re-download.
    :param clean_on_failure: If true, will delete the file on any exception raised
        during download
    :param backend: The downloader to use. Choose 'urllib' or 'requests'. Defaults
        to 'urllib' when not given.
    :param hexdigests: The expected hexdigests as (algorithm_name, expected_hex_digest)
        pairs.
    :param hexdigests_remote: The expected hexdigests as (algorithm_name, url to file
        with expected hexdigest) pairs.
    :param hexdigests_strict: Set this to ``False`` to stop automatically checking for
        the `algorithm(filename)=hash` format
    :param progress_bar: Set to true to show a progress bar while downloading
    :param tqdm_kwargs: Override the default arguments passed to :class:`tqdm.tqdm` when
        progress_bar is True.
    :param _version: If given, it is appended to the progress bar description.
    :param kwargs: If using :func:`urllib.request.urlretrieve`, there are no keyword
        arguments available. If using ``requests`` as a backend, passes these to
        :func:`requests.get`. If using ``requests`` as a backend, ``stream`` is set to
        True by default.

    :raises Exception: Thrown if an error besides a keyboard interrupt is thrown during
        download
    :raises KeyboardInterrupt: If a keyboard interrupt is thrown during download
    :raises UnexpectedDirectoryError: If a directory is given for the ``path`` argument
    :raises ValueError: If an invalid backend is chosen
    :raises DownloadError: If an error occurs during download. This includes HTTP
        error responses for both backends.
    """
    path = Path(path).resolve()

    if path.is_dir():
        raise UnexpectedDirectoryError(path)
    if path.is_file() and not force:
        raise_on_digest_mismatch(
            path=path,
            hexdigests=hexdigests,
            hexdigests_remote=hexdigests_remote,
            hexdigests_strict=hexdigests_strict,
        )
        logger.debug("did not re-download %s from %s", path, url)
        return

    if backend is None:
        backend = "urllib"
    # Validate *before* entering the cleanup-guarded section below: a bad
    # argument must never delete a file that already exists at ``path``.
    if backend not in ("urllib", "requests"):
        raise ValueError(f'Invalid backend: {backend}. Use "requests" or "urllib".')

    desc = f"Downloading {path.name}"
    if _version:
        desc += f" (v{_version})"
    _tqdm_kwargs: dict[str, Any] = {
        "unit": "B",
        "unit_scale": True,
        "unit_divisor": 1024,
        "miniters": 1,
        "disable": not progress_bar,
        "desc": desc,
        "leave": False,
    }
    if tqdm_kwargs:
        _tqdm_kwargs.update(tqdm_kwargs)

    try:
        if backend == "urllib":
            if kwargs:
                logger.warning(
                    "no kwargs should be supplied when using urllib, skipping: %s", kwargs
                )
            _download_with_urllib(url, path, _tqdm_kwargs)
        else:
            _download_with_requests(url, path, _tqdm_kwargs, kwargs)
    except (Exception, KeyboardInterrupt):
        # never leave a partially written file behind
        if clean_on_failure:
            path.unlink(missing_ok=True)
        raise

    raise_on_digest_mismatch(
        path=path,
        hexdigests=hexdigests,
        hexdigests_remote=hexdigests_remote,
        hexdigests_strict=hexdigests_strict,
    )


def _download_with_urllib(url: str, path: Path, tqdm_kwargs: dict[str, Any]) -> None:
    """Fetch ``url`` into ``path`` with :func:`urllib.request.urlretrieve`."""
    logger.info("downloading with urllib from %s to %s", url, path)
    with TqdmReportHook(**tqdm_kwargs) as t:
        try:
            urlretrieve(url, path, reporthook=t.update_to)  # noqa:S310
        except urllib.error.URLError as e:
            logger.info("download from %s failed: %s", url, e)
            raise DownloadError("urllib", url, path, e) from e


def _download_with_requests(
    url: str, path: Path, tqdm_kwargs: dict[str, Any], kwargs: RequestKwargs
) -> None:
    """Fetch ``url`` into ``path`` by streaming the body of :func:`requests.get`."""
    import requests
    import requests.exceptions

    kwargs.setdefault("stream", True)
    # Work on a copy so that a headers dict owned by the caller is never mutated
    headers = dict(kwargs.get("headers") or {})
    headers.setdefault("User-Agent", DEFAULT_AGENT)
    kwargs["headers"] = headers

    try:
        # see https://requests.readthedocs.io/en/master/user/quickstart/#raw-response-content
        # pattern from https://stackoverflow.com/a/39217788/5775947
        with requests.get(url, **kwargs) as response:  # noqa:S113
            # Fail on 4xx/5xx instead of saving the server's error page as if it
            # were the requested file (urllib already raises in that situation).
            # This is checked before the target file is opened for writing.
            response.raise_for_status()
            logger.info(
                "downloading (stream=%s) with requests from %s to %s",
                kwargs["stream"],
                url,
                path,
            )
            # Solution for progress bar from https://stackoverflow.com/a/63831344/5775947
            total_size = int(response.headers.get("Content-Length", 0))
            # Decompress if needed
            response.raw.read = partial(  # type:ignore[method-assign]
                response.raw.read, decode_content=True
            )
            with (
                path.open("wb") as file,
                tqdm.wrapattr(response.raw, "read", total=total_size, **tqdm_kwargs) as fsrc,
            ):
                shutil.copyfileobj(fsrc, file)
    except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError) as e:
        raise DownloadError("requests", url, path, e) from e


class DownloadError(OSError):
    """An error that wraps information from a requests or urllib download failure."""

    def __init__(
        self,
        backend: DownloadBackend,
        url: str,
        path: Path,
        exc: urllib.error.URLError | requests.exceptions.RequestException,
    ) -> None:
        """Initialize the error.

        :param backend: The backend used
        :param url: The url that failed to download
        :param path: The path that was supposed to be downloaded to
        :param exc: The exception raised
        """
        # OSError gives 2-5 positional arguments a special meaning (errno,
        # strerror, filename, ...), so hand it a single message instead of the
        # raw fields. Calling it at all is what populates ``args``.
        super().__init__(f"Failed with {backend} to download {url} to {path}")
        self.backend = backend
        self.url = url
        self.path = path
        self.exc = exc
        # TODO parse out HTTP error code, if possible

    def __reduce__(self) -> tuple[Any, ...]:
        """Rebuild from the constructor arguments so pickling and copying work."""
        return type(self), (self.backend, self.url, self.path, self.exc)

    def __str__(self) -> str:
        return f"Failed with {self.backend} to download {self.url} to {self.path}"
class UnexpectedDirectoryError(FileExistsError):
    """Thrown if a directory path is given where file path should have been."""

    def __init__(self, path: Path):
        """Instantiate the exception.

        :param path: The path to a directory that should have been a file.
        """
        # Forward the path so that ``args`` is populated. Without this call the
        # exception cannot be pickled or copied, because re-creating it would
        # invoke ``__init__`` without any argument.
        super().__init__(path)
        self.path = path

    def __str__(self) -> str:
        return f"got directory instead of file: {self.path}"
#: Chunk size passed to ``iter_content`` when writing a Google Drive download
CHUNK_SIZE = 32768
#: Endpoint that Google Drive download requests are sent to
DOWNLOAD_URL = "https://docs.google.com/uc?export=download"
#: Prefix of the cookie name whose value is used as the confirmation token
TOKEN_KEY = "download_warning" # noqa:S105
def download_from_google(
    file_id: str,
    path: str | Path,
    force: bool = True,
    clean_on_failure: bool = True,
    hexdigests: Mapping[str, str] | None = None,
) -> None:
    """Download a file from google drive.

    Implementation inspired by https://github.com/ndrplz/google-drive-downloader.

    :param file_id: The google file identifier
    :param path: The place to write the file
    :param force: If false and the file already exists, will not re-download.
    :param clean_on_failure: If true, will delete the file on any exception raised
        during download
    :param hexdigests: The expected hexdigests as (algorithm_name, expected_hex_digest)
        pairs.

    :raises Exception: Thrown if an error besides a keyboard interrupt is thrown during
        download. This includes an HTTP error status on the download response.
    :raises KeyboardInterrupt: If a keyboard interrupt is thrown during download
    :raises UnexpectedDirectoryError: If a directory is given for the ``path`` argument
    """
    import requests

    path = Path(path).resolve()

    if path.is_dir():
        raise UnexpectedDirectoryError(path)
    if path.is_file() and not force:
        raise_on_digest_mismatch(path=path, hexdigests=hexdigests)
        logger.debug("did not re-download %s from Google ID %s", path, file_id)
        return

    try:
        with requests.Session() as sess:
            # The first request is only needed for its cookies, so the streamed
            # response is closed as soon as the token has been extracted
            with sess.get(DOWNLOAD_URL, params={"id": file_id}, stream=True) as res:
                token = _get_confirm_token(res)
            with sess.get(
                DOWNLOAD_URL, params={"id": file_id, "confirm": token}, stream=True
            ) as res:
                # Fail on 4xx/5xx instead of writing an error page to ``path``
                res.raise_for_status()
                with path.open("wb") as file:
                    for chunk in tqdm(res.iter_content(CHUNK_SIZE), desc="writing", unit="chunk"):
                        if chunk:  # filter out keep-alive new chunks
                            file.write(chunk)
    except (Exception, KeyboardInterrupt):
        if clean_on_failure:
            path.unlink(missing_ok=True)
        raise

    raise_on_digest_mismatch(path=path, hexdigests=hexdigests)
def _get_confirm_token(res: requests.Response) -> str:
    """Return the value of the first cookie whose name starts with ``TOKEN_KEY``.

    :param res: The response whose cookies are searched
    :returns: The value of the first matching cookie
    :raises ValueError: If no cookie name starts with ``TOKEN_KEY``
    """
    candidates = [
        cookie_value
        for cookie_name, cookie_value in res.cookies.items()
        if cookie_name.startswith(TOKEN_KEY)
    ]
    if not candidates:
        raise ValueError(f"no token found with key {TOKEN_KEY} in cookies: {res.cookies}")
    return candidates[0]
def download_from_s3(
    s3_bucket: str,
    s3_key: str,
    path: str | Path,
    client: None | botocore.client.BaseClient = None,
    client_kwargs: Mapping[str, Any] | None = None,
    download_file_kwargs: Mapping[str, Any] | None = None,
    force: bool = True,
    clean_on_failure: bool = True,
) -> None:
    """Download a file from S3.

    :param s3_bucket: The S3 bucket name
    :param s3_key: The key inside the S3 bucket
    :param path: The place to write the file
    :param client: A botocore client. If none given, one will be created automatically
    :param client_kwargs: Keyword arguments to be passed to the client on instantiation.
    :param download_file_kwargs: Keyword arguments to be passed to
        :func:`boto3.s3.transfer.S3Transfer.download_file`
    :param force: If false and the file already exists, will not re-download.
    :param clean_on_failure: If true, will delete the file on any exception raised
        during download

    :raises Exception: Thrown if an error besides a keyboard interrupt is thrown during
        download
    :raises KeyboardInterrupt: If a keyboard interrupt is thrown during download
    :raises UnexpectedDirectoryError: If a directory is given for the ``path`` argument
    """
    path = Path(path).resolve()

    if path.is_dir():
        raise UnexpectedDirectoryError(path)
    if path.is_file() and not force:
        logger.debug("did not re-download %s from %s %s", path, s3_bucket, s3_key)
        return

    # Import the optional dependency and prepare the client *outside* of the
    # cleanup-guarded section: a missing package or a bad client configuration
    # must not delete a file that already exists at ``path``.
    import boto3.s3.transfer

    if client is None:
        import boto3
        import botocore.client

        client_kwargs = {} if client_kwargs is None else dict(client_kwargs)
        # default to anonymous (unsigned) requests unless the caller configured otherwise
        client_kwargs.setdefault(
            "config", botocore.client.Config(signature_version=botocore.UNSIGNED)
        )
        client = boto3.client("s3", **client_kwargs)

    download_file_kwargs = {} if download_file_kwargs is None else dict(download_file_kwargs)
    download_file_kwargs.setdefault("Config", boto3.s3.transfer.TransferConfig(use_threads=False))

    try:
        client.download_file(s3_bucket, s3_key, path.as_posix(), **download_file_kwargs)
    except (Exception, KeyboardInterrupt):
        if clean_on_failure:
            path.unlink(missing_ok=True)
        raise