Source code for pystow.utils

"""Utilities."""

from __future__ import annotations

import contextlib
import csv
import gzip
import io
import logging
import lzma
import pickle
import shutil
import tarfile
import typing
import zipfile
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from io import BytesIO
from pathlib import Path, PurePosixPath
from subprocess import check_output
from typing import (
    TYPE_CHECKING,
    Any,
    BinaryIO,
    Literal,
    Protocol,
    TextIO,
    TypeAlias,
    TypeVar,
    cast,
    overload,
)
from urllib.parse import urlparse
from uuid import uuid4

from tqdm.auto import tqdm

from .download import (
    DownloadBackend,
    DownloadError,
    UnexpectedDirectoryError,
    download,
    download_from_google,
    download_from_s3,
)
from .env import (
    get_base,
    get_home,
    get_name,
    getenv_path,
    mkdir,
    mock_envvar,
    mock_home,
    use_appdirs,
)
from .hashing import (
    Hash,
    HexDigestError,
    HexDigestMismatch,
    get_hash_hexdigest,
    get_hashes,
    get_hexdigests_remote,
    get_offending_hexdigests,
    raise_on_digest_mismatch,
)
from .io_typing import (
    _MODE_TO_SIMPLE,
    MODE_MAP,
    OPERATION_VALUES,
    REPRESENTATION_VALUES,
    REVERSE_MODE_MAP,
    InvalidOperationError,
    InvalidRepresentationError,
    Operation,
    Reader,
    Representation,
    Writer,
    ensure_sensible_default_encoding,
    ensure_sensible_newline,
    get_mode_pair,
)
from .iter import reyield
from .pydantic_utils import (
    ModelValidateFailureAction,
    iter_pydantic_jsonl,
    iter_pydantic_tsv,
    model_dump_yaml,
    read_pydantic_json,
    read_pydantic_jsonl,
    read_pydantic_tsv,
    read_pydantic_yaml,
    stream_write_pydantic_jsonl,
    write_pydantic_json,
    write_pydantic_jsonl,
    write_pydantic_yaml,
)
from .safe_open import (
    is_url,
    open_inner_zipfile,
    open_url,
    safe_open,
    safe_open_dict_reader,
    safe_open_json,
    safe_open_yaml,
    safe_read_text,
    safe_write_text,
    write_json,
    write_yaml,
)
from ..constants import README_TEXT, TimeoutHint

if TYPE_CHECKING:
    import bs4
    import lxml.etree
    import numpy.typing
    import pandas
    import rdflib

__all__ = [
    "MODE_MAP",
    "OPERATION_VALUES",
    "REPRESENTATION_VALUES",
    "REVERSE_MODE_MAP",
    "DownloadBackend",
    "DownloadError",
    "Hash",
    "HeaderMismatchError",
    "HexDigestError",
    "HexDigestMismatch",
    "InvalidOperationError",
    "InvalidRepresentationError",
    "ModelValidateFailureAction",
    "Operation",
    "Representation",
    "UnexpectedDirectory",
    "UnexpectedDirectoryError",
    "download",
    "download_from_google",
    "download_from_s3",
    "get_base",
    "get_commit",
    "get_df_io",
    "get_hash_hexdigest",
    "get_hashes",
    "get_hexdigests_remote",
    "get_home",
    "get_mode_pair",
    "get_name",
    "get_np_io",
    "get_offending_hexdigests",
    "get_soup",
    "getenv_path",
    "gunzip",
    "gzip_compress",
    "is_url",
    "iter_pydantic_jsonl",
    "iter_pydantic_tsv",
    "iter_tarred_csvs",
    "iter_tarred_files",
    "iter_zipped_csvs",
    "iter_zipped_files",
    "mkdir",
    "mock_envvar",
    "mock_home",
    "model_dump_yaml",
    "n",
    "name_from_s3_key",
    "name_from_url",
    "open_inner_zipfile",
    "open_tarfile",
    "open_url",
    "open_zip_dict_reader",
    "open_zip_reader",
    "open_zip_writer",
    "open_zipfile",
    "path_to_sqlite",
    "raise_on_digest_mismatch",
    "read_lzma_csv",
    "read_pydantic_json",
    "read_pydantic_jsonl",
    "read_pydantic_tsv",
    "read_pydantic_yaml",
    "read_rdf",
    "read_tarfile_csv",
    "read_tarfile_xml",
    "read_zip_np",
    "read_zipfile_csv",
    "read_zipfile_rdf",
    "read_zipfile_xml",
    "reyield",
    "safe_open",
    "safe_open_dict_reader",
    "safe_open_dict_writer",
    "safe_open_json",
    "safe_open_reader",
    "safe_open_writer",
    "safe_open_yaml",
    "safe_read_text",
    "safe_tarfile_open",
    "safe_write_text",
    "safe_zipfile_open",
    "stream_write_pydantic_jsonl",
    "tarfile_writestr",
    "use_appdirs",
    "write_json",
    "write_lzma_csv",
    "write_pickle_gz",
    "write_pydantic_json",
    "write_pydantic_jsonl",
    "write_pydantic_yaml",
    "write_tarfile_csv",
    "write_tarfile_xml",
    "write_yaml",
    "write_zipfile_csv",
    "write_zipfile_np",
    "write_zipfile_rdf",
    "write_zipfile_xml",
]

logger = logging.getLogger(__name__)


#: Backwards compatible name
UnexpectedDirectory = UnexpectedDirectoryError


[docs] def name_from_url(url: str) -> str: """Get the filename from the end of the URL. :param url: A URL :returns: The name of the file at the end of the URL """ parse_result = urlparse(url) path = PurePosixPath(parse_result.path) name = path.name return name
def base_from_gzip_name(name: str) -> str: """Get the base name for a file after stripping the gz ending. :param name: The name of the gz file :returns: The cleaned name of the file, with no gz ending :raises ValueError: if the file does not end with ".gz" """ if not name.endswith(".gz"): raise ValueError(f"Name does not end with .gz: {name}") return name[: -len(".gz")]
[docs] def name_from_s3_key(key: str) -> str: """Get the filename from the S3 key. :param key: A S3 path :returns: The name of the file """ return key.split("/")[-1]
[docs] def n() -> str: """Get a random string for testing. :returns: A random string for testing purposes. """ return str(uuid4())
[docs] def get_df_io(df: pandas.DataFrame, sep: str = "\t", index: bool = False, **kwargs: Any) -> BytesIO: """Get the dataframe as bytes. :param df: A dataframe :param sep: The separator in the dataframe. Overrides Pandas default to use a tab. :param index: Should the index be output? Overrides the Pandas default to be false. :param kwargs: Additional kwargs to pass to :func:`pandas.DataFrame.to_csv`. :returns: A bytes object that can be used as a file. """ return io.BytesIO(df.to_csv(sep=sep, index=index, **kwargs).encode("utf-8"))
[docs] def get_np_io(arr: numpy.typing.ArrayLike, **kwargs: Any) -> BytesIO: """Get the numpy object as bytes. :param arr: Array-like :param kwargs: Additional kwargs to pass to :func:`numpy.save`. :returns: A bytes object that can be used as a file. """ import numpy as np bio = BytesIO() np.save(bio, arr, **kwargs) bio.seek(0) return bio
[docs] def write_pickle_gz( obj: Any, path: str | Path, **kwargs: Any, ) -> None: """Write an object to a gzipped pickle. :param obj: The object to write :param path: The path of the file to write to :param kwargs: Additional kwargs to pass to :func:`pickle.dump` """ with safe_open(path, representation="binary", operation="write") as file: pickle.dump(obj, file, **kwargs)
[docs] def write_lzma_csv( df: pandas.DataFrame, path: str | Path, sep: str = "\t", index: bool = False, **kwargs: Any, ) -> None: """Write a dataframe as an lzma-compressed file. :param df: A dataframe :param path: The path to the resulting LZMA compressed dataframe file :param sep: The separator in the dataframe. Overrides Pandas default to use a tab. :param index: Should the index be output? Overrides the Pandas default to be false. :param kwargs: Additional kwargs to pass to :func:`get_df_io` and transitively to :func:`pandas.DataFrame.to_csv`. """ with lzma.open(path, "wb") as file: df.to_csv(file, sep=sep, index=index, **kwargs)
[docs] def read_lzma_csv( path: str | Path, sep: str = "\t", **kwargs: Any, ) -> pandas.DataFrame: """Read a dataframe from a lzma-compressed file. :param path: The path to the resulting LZMA compressed dataframe file :param sep: The separator in the dataframe. Overrides Pandas default to use a tab. :param kwargs: Additional kwargs to pass to :func:`get_df_io` and transitively to :func:`pandas.DataFrame.to_csv`. """ import pandas as pd with lzma.open(path, "rb") as file: return pd.read_csv(file, sep=sep, **kwargs)
[docs] def write_zipfile_csv( df: pandas.DataFrame, path: str | Path, inner_path: str, sep: str = "\t", index: bool = False, **kwargs: Any, ) -> None: """Write a dataframe to an inner CSV file to a zip archive. :param df: A dataframe :param path: The path to the resulting zip archive :param inner_path: The path inside the zip archive to write the dataframe :param sep: The separator in the dataframe. Overrides Pandas default to use a tab. :param index: Should the index be output? Overrides the Pandas default to be false. :param kwargs: Additional kwargs to pass to :func:`get_df_io` and transitively to :func:`pandas.DataFrame.to_csv`. """ with open_zipfile(path, inner_path, operation="write", representation="binary") as file: df.to_csv(file, sep=sep, index=index, **kwargs)
[docs] def read_zipfile_csv( path: str | Path, inner_path: str, sep: str = "\t", **kwargs: Any ) -> pandas.DataFrame: """Read an inner CSV file from a zip archive. :param path: The path to the zip archive :param inner_path: The path inside the zip archive to the dataframe :param sep: The separator in the dataframe. Overrides Pandas default to use a tab. :param kwargs: Additional kwargs to pass to :func:`pandas.read_csv`. :returns: A dataframe """ import pandas as pd with open_zipfile(path, inner_path, representation="text", operation="read") as file: return pd.read_csv(file, sep=sep, **kwargs)
# docstr-coverage:excused `overload` @typing.overload @contextlib.contextmanager def open_zipfile( path: str | Path, inner_path: str, *, operation: Operation = ..., representation: Literal["text"] = ..., zipfile_kwargs: Mapping[str, Any] | None = ..., open_kwargs: Mapping[str, Any] | None = ..., encoding: str | None = ..., ) -> Generator[typing.TextIO, None, None]: ... # docstr-coverage:excused `overload` @typing.overload @contextlib.contextmanager def open_zipfile( path: str | Path, inner_path: str, *, operation: Operation = ..., representation: Literal["binary"] = ..., zipfile_kwargs: Mapping[str, Any] | None = ..., open_kwargs: Mapping[str, Any] | None = ..., encoding: str | None = ..., ) -> Generator[typing.BinaryIO, None, None]: ...
[docs] @contextlib.contextmanager def open_zipfile( path: str | Path, inner_path: str, *, operation: Operation = "read", representation: Representation = "text", zipfile_kwargs: Mapping[str, Any] | None = None, open_kwargs: Mapping[str, Any] | None = None, encoding: str | None = None, ) -> Generator[typing.TextIO, None, None] | Generator[typing.BinaryIO, None, None]: """Open a zipfile.""" mode = _MODE_TO_SIMPLE[operation] with ( zipfile.ZipFile(file=path, mode=mode, **(zipfile_kwargs or {})) as zip_file, open_inner_zipfile( zip_file, inner_path, operation=operation, representation=representation, open_kwargs=open_kwargs, encoding=encoding, ) as file, ): yield file
[docs] @contextlib.contextmanager def open_tarfile( path: str | Path, inner_path: str, *, operation: Operation = "read", representation: Representation = "binary", open_kwargs: Mapping[str, Any] | None = None, ) -> Generator[typing.IO[bytes], None, None]: """Open a tar file.""" if representation != "binary": raise NotImplementedError("tarfile must use binary representation") if operation == "read": with tarfile.open(path, "r", **(open_kwargs or {})) as tar: member = tar.getmember(inner_path) file = tar.extractfile(member) if file is None: raise FileNotFoundError(f"could not find {inner_path} in tarfile {path}") yield file elif operation == "write": file = BytesIO() yield file file.seek(0) tarinfo = tarfile.TarInfo(name=inner_path) tarinfo.size = len(file.getbuffer()) with tarfile.TarFile(path, mode="w") as tar_file: tar_file.addfile(tarinfo, file) else: raise InvalidOperationError(operation)
[docs] @contextlib.contextmanager def open_zip_reader( path: str | Path, inner_path: str, delimiter: str = "\t", **kwargs: Any ) -> Generator[Reader, None, None]: """Read an inner CSV file from a zip archive. :param path: The path to the zip archive :param inner_path: The path inside the zip archive to the CSV :param delimiter: The separator in the CSV. Defaults to tab. :param kwargs: Additional kwargs to pass to :func:`csv.reader`. :returns: A reader over the file """ with open_zipfile(path, inner_path, representation="text") as file: yield csv.reader(file, delimiter=delimiter, **kwargs)
[docs] @contextlib.contextmanager def open_zip_dict_reader( path: str | Path, inner_path: str, delimiter: str = "\t", **kwargs: Any ) -> Generator[csv.DictReader[str], None, None]: """Read an inner CSV file from a zip archive. :param path: The path to the zip archive :param inner_path: The path inside the zip archive to the CSV :param delimiter: The separator in the CSV. Defaults to tab. :param kwargs: Additional kwargs to pass to :class:`csv.DictReader`. :returns: A dictionary reader over the file """ with open_zipfile(path, inner_path, representation="text") as file: yield csv.DictReader(file, delimiter=delimiter, **kwargs)
[docs] @contextlib.contextmanager def open_zip_writer( path: str | Path, inner_path: str, delimiter: str = "\t", **kwargs: Any ) -> Generator[Writer, None, None]: """Open a writer for an inner CSV file from a zip archive. :param path: The path to the zip archive :param inner_path: The path inside the zip archive to the CSV :param delimiter: The separator in the CSV. Defaults to tab. :param kwargs: Additional kwargs to pass to :func:`csv.writer`. :returns: A writer over the file """ with open_zipfile(path, inner_path, operation="write", representation="text") as file: yield csv.writer(file, delimiter=delimiter, **kwargs)
[docs] def write_zipfile_xml( element_tree: lxml.etree.ElementTree, path: str | Path, inner_path: str, **kwargs: Any, ) -> None: """Write an XML element tree to an inner XML file to a zip archive. :param element_tree: An XML element tree :param path: The path to the resulting zip archive :param inner_path: The path inside the zip archive to write the XML element :param kwargs: Additional kwargs to pass to :func:`lxml.etree.tostring` """ from lxml import etree kwargs.setdefault("pretty_print", True) with open_zipfile(path, inner_path, operation="write", representation="binary") as file: file.write(etree.tostring(element_tree, **kwargs))
[docs] def read_zipfile_xml(path: str | Path, inner_path: str, **kwargs: Any) -> lxml.etree.ElementTree: """Read an inner XML file from a zip archive. :param path: The path to the zip archive :param inner_path: The path inside the zip archive to the xml file :param kwargs: Additional kwargs to pass to :func:`lxml.etree.parse` :returns: An XML element tree """ from lxml import etree with open_zipfile(path, inner_path, operation="read", representation="binary") as file: return etree.parse(file, **kwargs)
[docs] def write_zipfile_np( arr: numpy.typing.ArrayLike, path: str | Path, inner_path: str, **kwargs: Any, ) -> None: """Write a dataframe to an inner CSV file to a zip archive. :param arr: Array-like :param path: The path to the resulting zip archive :param inner_path: The path inside the zip archive to write the dataframe :param kwargs: Additional kwargs to pass to :func:`get_np_io` and transitively to :func:`numpy.save`. """ import numpy as np with open_zipfile(path, inner_path, operation="write", representation="binary") as file: np.save(file, arr, **kwargs)
[docs] def read_zip_np(path: str | Path, inner_path: str, **kwargs: Any) -> numpy.typing.ArrayLike: """Read an inner numpy array-like from a zip archive. :param path: The path to the zip archive :param inner_path: The path inside the zip archive to the dataframe :param kwargs: Additional kwargs to pass to :func:`numpy.load`. :returns: A numpy array or other object """ import numpy as np with open_zipfile(path, inner_path, operation="read", representation="binary") as file: return cast(np.typing.ArrayLike, np.load(file, **kwargs))
[docs] def read_zipfile_rdf(path: str | Path, inner_path: str, **kwargs: Any) -> rdflib.Graph: """Read an inner RDF file from a zip archive. :param path: The path to the zip archive :param inner_path: The path inside the zip archive to the dataframe :param kwargs: Additional kwargs to pass to :meth:`rdflib.Graph.parse`. :returns: A graph """ import rdflib graph = rdflib.Graph() with open_zipfile(path, inner_path, operation="read", representation="binary") as file: graph.parse(file, **kwargs) return graph
[docs] def write_zipfile_rdf( graph: rdflib.Graph, path: str | Path, inner_path: str, **kwargs: Any ) -> None: """Read an inner RDF file from a zip archive. :param graph: The graph to write :param path: The path to the zip archive :param inner_path: The path inside the zip archive to the dataframe :param kwargs: Additional kwargs to pass to :meth:`rdflib.Graph.parse`. """ with open_zipfile(path, inner_path, operation="write", representation="binary") as file: graph.serialize(file, **kwargs)
[docs] def write_tarfile_csv( df: pandas.DataFrame, path: str | Path, inner_path: str, sep: str = "\t", index: bool = False, **kwargs: Any, ) -> None: """Write a dataframe to an inner CSV file from a tar archive. :param df: A dataframe :param path: The path to the resulting tar archive :param inner_path: The path inside the tar archive to write the dataframe :param sep: The separator in the dataframe. Overrides Pandas default to use a tab. :param index: Should the index be output? Overrides the Pandas default to be false. :param kwargs: Additional kwargs to pass to :func:`get_df_io` and transitively to :func:`pandas.DataFrame.to_csv`. """ with open_tarfile(path, inner_path, operation="write") as file: df.to_csv(file, sep=sep, index=index, **kwargs)
[docs] def write_tarfile_xml( element_tree: lxml.etree.ElementTree, path: str | Path, inner_path: str, **kwargs: Any, ) -> None: """Write an XML document a tar archive. :param element_tree: An element :param path: The path to the resulting tar archive :param inner_path: The path inside the tar archive to write the dataframe :param kwargs: Additional kwargs to pass to :func:`lxml.etree.tostring` """ from lxml import etree kwargs.setdefault("pretty_print", True) with open_tarfile(path, inner_path, operation="write") as file: file.write(etree.tostring(element_tree, **kwargs))
[docs] def read_tarfile_csv( path: str | Path, inner_path: str, sep: str = "\t", **kwargs: Any ) -> pandas.DataFrame: """Read an inner CSV file from a tar archive. :param path: The path to the tar archive :param inner_path: The path inside the tar archive to the dataframe :param sep: The separator in the dataframe. Overrides Pandas default to use a tab. :param kwargs: Additional kwargs to pass to :func:`pandas.read_csv`. :returns: A dataframe """ import pandas as pd with open_tarfile(path, inner_path) as file: return pd.read_csv(file, sep=sep, **kwargs)
[docs] def read_tarfile_xml(path: str | Path, inner_path: str, **kwargs: Any) -> lxml.etree.ElementTree: """Read an inner XML file from a tar archive. :param path: The path to the tar archive :param inner_path: The path inside the tar archive to the xml file :param kwargs: Additional kwargs to pass to :func:`lxml.etree.parse` :returns: An XML element tree """ from lxml import etree with open_tarfile(path, inner_path) as file: return etree.parse(file, **kwargs)
[docs] def read_rdf(path: str | Path, **kwargs: Any) -> rdflib.Graph: """Read an RDF file with :mod:`rdflib`. :param path: The path to the RDF file :param kwargs: Additional kwargs to pass to :func:`rdflib.Graph.parse` :returns: A parsed RDF graph """ import rdflib graph = rdflib.Graph() with safe_open(path, representation="binary", operation="read") as file: graph.parse(file, **kwargs) return graph
def write_sql(df: pandas.DataFrame, name: str, path: str | Path, **kwargs: Any) -> None: """Write a dataframe as a SQL table. :param df: A dataframe :param name: The table the database to write to :param path: The path to the resulting tar archive :param kwargs: Additional keyword arguments to pass to :meth:`pandas.DataFrame.to_sql` """ import sqlite3 with contextlib.closing(sqlite3.connect(path)) as conn: df.to_sql(name, conn, **kwargs)
[docs] def get_commit(org: str, repo: str, provider: str = "git") -> str: """Get last commit hash for the given repo. :param org: The GitHub organization or owner :param repo: The GitHub repository name :param provider: The method for getting the most recent commit :returns: A commit hash's hex digest as a string :raises ValueError: if an invalid provider is given """ if provider == "git": output = check_output(["git", "ls-remote", f"https://github.com/{org}/{repo}"]) # noqa lines = (line.strip().split("\t") for line in output.decode("utf8").splitlines()) rv = next(line[0] for line in lines if line[1] == "HEAD") elif provider == "github": url = f"https://api.github.com/repos/{org}/{repo}/branches/master" res_json = safe_open_json(url) rv = res_json["commit"]["sha"] else: raise ValueError(f"invalid implementation: {provider}") return rv
def ensure_readme() -> None: """Ensure there's a README in the PyStow data directory. :raises PermissionError: If the script calling this function does not have adequate permissions to write a file into the PyStow home directory. """ try: readme_path = get_home(ensure_exists=True).joinpath("README.md") except PermissionError as e: raise PermissionError( "PyStow was not able to create its home directory in due to a lack of " "permissions. This can happen, e.g., if you're working on a server where you don't " "have full rights. See https://pystow.readthedocs.io/en/latest/installation.html#" "configuration for instructions on choosing a different home folder location for " "PyStow to somewhere where you have write permissions." ) from e if readme_path.is_file(): return with readme_path.open("w", encoding="utf8") as file: print(README_TEXT, file=file)
[docs] def path_to_sqlite(path: str | Path) -> str: """Convert a path to a SQLite connection string. :param path: A path to a SQLite database file :returns: A standard connection string to the database """ path = Path(path).expanduser().resolve() return f"sqlite:///{path.as_posix()}"
[docs] def gunzip(source: str | Path, target: str | Path | None = None, *, cleanup: bool = False) -> Path: """Unzip a file in the source to the target. :param source: The path to an input file :param target: The path to an output file :param cleanup: Whether to clean the output file """ source = Path(source).expanduser().resolve() if target is None: raise NotImplementedError else: target = Path(target).expanduser().resolve() with gzip.open(source, "rb") as in_file, open(target, "wb") as out_file: shutil.copyfileobj(in_file, out_file) if cleanup: source.unlink() return target
[docs] def gzip_compress( source: str | Path, *, target: str | Path | None = None, cleanup: bool = False ) -> Path: """Compress a file, then delete the original.""" source = Path(source).expanduser().resolve() if target is None: target = source.with_suffix(source.suffix + ".gz") else: target = Path(target).expanduser().resolve() with open(source, "rb") as in_file, gzip.open(target, "wb") as out_file: shutil.copyfileobj(in_file, out_file) if cleanup: source.unlink() return target
[docs] @contextlib.contextmanager def safe_open_writer( f: str | Path | TextIO, *, delimiter: str = "\t", **kwargs: Any ) -> Generator[Writer, None, None]: """Open a CSV writer, wrapping :func:`csv.writer`. :param f: A path to a file, or an already open text-based IO object :param delimiter: The delimiter for writing to CSV :param kwargs: Keyword arguments to pass to :func:`csv.writer` :yields: A CSV writer object, constructed from :func:`csv.writer` """ with safe_open(f, operation="write", representation="text") as file: yield csv.writer(file, delimiter=delimiter, **kwargs)
[docs] @contextlib.contextmanager def safe_open_dict_writer( f: str | Path | TextIO, fieldnames: typing.Sequence[str], *, delimiter: str = "\t", **kwargs: Any, ) -> Generator[csv.DictWriter[str], None, None]: """Open a CSV dictionary writer, wrapping :func:`csv.DictWriter`. :param f: A path to a file, or an already open text-based IO object :param fieldnames: A path to a file, or an already open text-based IO object :param delimiter: The delimiter for writing to CSV :param kwargs: Keyword arguments to pass to :func:`csv.DictWriter` :yields: A CSV dictionary writer object, constructed from :func:`csv.DictWriter` """ with safe_open(f, operation="write", representation="text", newline="") as file: yield csv.DictWriter(file, fieldnames, delimiter=delimiter, **kwargs)
[docs] @contextlib.contextmanager def safe_open_reader( f: str | Path | TextIO, *, delimiter: str = "\t", **kwargs: Any ) -> Generator[Reader, None, None]: """Open a CSV reader, wrapping :func:`csv.reader`. :param f: A path to a file, or an already open text-based IO object :param delimiter: The delimiter for writing to CSV :param kwargs: Keyword arguments to pass to :func:`csv.reader` :yields: A CSV reader object, constructed from :func:`csv.reader` """ with safe_open(f, operation="read", representation="text", newline="") as file: yield csv.reader(file, delimiter=delimiter, **kwargs)
[docs] def get_soup( url: str, *, verify: bool = True, timeout: TimeoutHint | None = None, user_agent: str | None = None, ) -> bs4.BeautifulSoup: """Get a beautiful soup parsed version of the given web page. :param url: The URL to download and parse with BeautifulSoup :param verify: Should SSL be used? This is almost always true, except for Ensembl, which makes a big pain :param timeout: How many integer seconds to wait for a response? Defaults to 15 if none given. :param user_agent: A custom user-agent to set, e.g., to avoid anti-crawling mechanisms :returns: A BeautifulSoup object """ import requests from bs4 import BeautifulSoup headers = {} if user_agent: headers["User-Agent"] = user_agent res = requests.get(url, verify=verify, timeout=timeout or 15, headers=headers) soup = BeautifulSoup(res.text, features="html.parser") return soup
ArchiveType = TypeVar("ArchiveType", contravariant=True) ArchiveInfo = TypeVar("ArchiveInfo", covariant=True) Predicate: TypeAlias = Callable[[ArchiveInfo], bool] class ArchivedFileIterator(Protocol[ArchiveType, ArchiveInfo]): """A protocol for opening files in an archive.""" # docstr-coverage:excused `overload` @overload def __call__( self, path: str | Path | ArchiveType, *, representation: Literal["binary"] = ..., progress: bool = ..., tqdm_kwargs: Mapping[str, Any] | None = ..., keep: Predicate[ArchiveInfo] | None = ..., open_kwargs: Mapping[str, Any] | None = ..., encoding: str | None = ..., newline: str | None = ..., ) -> Iterable[BinaryIO]: ... # docstr-coverage:excused `overload` @overload def __call__( self, path: str | Path | ArchiveType, *, representation: Literal["text"] = ..., progress: bool = ..., tqdm_kwargs: Mapping[str, Any] | None = ..., keep: Predicate[ArchiveInfo] | None = ..., open_kwargs: Mapping[str, Any] | None = ..., encoding: str | None = ..., newline: str | None = ..., ) -> Iterable[TextIO]: ... def __call__( self, path: str | Path | ArchiveType, *, representation: Representation = ..., progress: bool = True, tqdm_kwargs: Mapping[str, Any] | None = ..., keep: Predicate[ArchiveInfo] | None = ..., open_kwargs: Mapping[str, Any] | None = None, encoding: str | None = ..., newline: str | None = ..., ) -> Iterable[TextIO] | Iterable[BinaryIO]: ... # docstr-coverage:excused `overload` @overload def iter_tarred_files( path: str | Path | tarfile.TarFile, *, representation: Literal["binary"] = ..., progress: bool = ..., tqdm_kwargs: Mapping[str, Any] | None = ..., keep: Predicate[tarfile.TarInfo] | None = ..., open_kwargs: Mapping[str, Any] | None = ..., encoding: str | None = ..., newline: str | None = ..., ) -> Iterable[BinaryIO]: ... # docstr-coverage:excused `overload` @overload def iter_tarred_files( path: str | Path | tarfile.TarFile, *, representation: Literal["text"] = ..., progress: bool = ..., tqdm_kwargs: Mapping[str, Any] | None = ..., keep: Predicate[tarfile.TarInfo] | None = ..., open_kwargs: Mapping[str, Any] | None = ..., encoding: str | None = ..., newline: str | None = ..., ) -> Iterable[TextIO]: ...
[docs] def iter_tarred_files( path: str | Path | tarfile.TarFile, *, representation: Representation = "text", progress: bool = True, tqdm_kwargs: Mapping[str, Any] | None = None, keep: Predicate[tarfile.TarInfo] | None = None, open_kwargs: Mapping[str, Any] | None = None, encoding: str | None = None, newline: str | None = None, ) -> Iterable[TextIO] | Iterable[BinaryIO]: """Iterate over opened files in a tar archive in read mode.""" encoding = ensure_sensible_default_encoding(encoding, representation=representation) newline = ensure_sensible_newline(newline, representation=representation) with safe_tarfile_open(path) as tar_file: _tqdm_kwargs: dict[str, Any] = { "unit": "file", "unit_scale": True, } if isinstance(tar_file.name, str | Path): _tqdm_kwargs["desc"] = f"reading {Path(tar_file.name).name}" if tqdm_kwargs is not None: _tqdm_kwargs.update(tqdm_kwargs) for member in tqdm(tar_file.getmembers(), disable=not progress, **_tqdm_kwargs): if keep is not None and not keep(member): continue file = tar_file.extractfile(member, **(open_kwargs or {})) if file is None: continue if representation == "text": yield io.TextIOWrapper(file, encoding=encoding, newline=newline) else: yield cast(BinaryIO, file) # FIXME
[docs] @contextlib.contextmanager def safe_tarfile_open( tar_file: str | Path | tarfile.TarFile, ) -> Generator[tarfile.TarFile, None, None]: """Open a tar archive safely.""" if isinstance(tar_file, str | Path): with tarfile.open(Path(tar_file).expanduser().resolve(), mode="r") as tar_file: yield tar_file else: yield tar_file
ReturnType: TypeAlias = Literal["sequence", "record"] # docstr-coverage:excused `overload` @overload def iter_tarred_csvs( path: str | Path | tarfile.TarFile, *, progress: bool = ..., return_type: Literal["sequence"] = ..., max_line_length: int | None = ..., ) -> Iterable[Sequence[str]]: ... # docstr-coverage:excused `overload` @overload def iter_tarred_csvs( path: str | Path | tarfile.TarFile, *, progress: bool = ..., return_type: Literal["record"] = ..., max_line_length: int | None = ..., ) -> Iterable[dict[str, Any]]: ...
[docs] def iter_tarred_csvs( path: str | Path | tarfile.TarFile, *, progress: bool = True, return_type: ReturnType = "sequence", tqdm_kwargs: Mapping[str, Any] | None = None, max_line_length: int | None = None, encoding: str | None = None, ) -> Iterable[Sequence[str]] | Iterable[dict[str, Any]]: """Iterate over the lines from tarred CSV files.""" yield from _iter_archived_csvs( path, progress=progress, return_type=return_type, iter_files=iter_tarred_files, keep=_keep_tar_info_csv, tqdm_kwargs=tqdm_kwargs, max_line_length=max_line_length, encoding=encoding, )
def _keep_tar_info_csv(tar_info: tarfile.TarInfo) -> bool: return tar_info.name.endswith(".csv") # docstr-coverage:excused `overload` @overload def iter_zipped_files( path: str | Path | zipfile.ZipFile, *, representation: Literal["binary"] = ..., progress: bool = ..., tqdm_kwargs: Mapping[str, Any] | None = ..., keep: Predicate[zipfile.ZipInfo] | None = ..., open_kwargs: Mapping[str, Any] | None = ..., encoding: str | None = ..., newline: str | None = ..., ) -> Iterable[typing.BinaryIO]: ... # docstr-coverage:excused `overload` @overload def iter_zipped_files( path: str | Path | zipfile.ZipFile, *, representation: Literal["text"] = ..., progress: bool = ..., tqdm_kwargs: Mapping[str, Any] | None = ..., keep: Predicate[zipfile.ZipInfo] | None = ..., open_kwargs: Mapping[str, Any] | None = ..., encoding: str | None = ..., newline: str | None = ..., ) -> Iterable[typing.TextIO]: ...
[docs] def iter_zipped_files( path: str | Path | zipfile.ZipFile, *, representation: Representation = "text", progress: bool = True, tqdm_kwargs: Mapping[str, Any] | None = None, keep: Predicate[zipfile.ZipInfo] | None = None, open_kwargs: Mapping[str, Any] | None = None, encoding: str | None = None, newline: str | None = None, ) -> Iterable[typing.TextIO] | Iterable[typing.BinaryIO]: """Iterate over opened files in a zip file in read mode.""" with safe_zipfile_open(path) as zip_file: _tqdm_kwargs = { "desc": f"reading {zip_file.filename}", "unit": "file", "unit_scale": True, } if tqdm_kwargs is not None: _tqdm_kwargs.update(tqdm_kwargs) for info in tqdm(zip_file.infolist(), disable=not progress, **_tqdm_kwargs): if keep is not None and not keep(info): continue with open_inner_zipfile( zip_file, info.filename, operation="read", representation=representation, open_kwargs=open_kwargs, encoding=encoding, newline=newline, ) as file: yield file
[docs] @contextlib.contextmanager def safe_zipfile_open( zip_file: str | Path | zipfile.ZipFile, ) -> Generator[zipfile.ZipFile, None, None]: """Open a zip archive safely.""" if isinstance(zip_file, str | Path): with zipfile.ZipFile(Path(zip_file).expanduser().resolve(), mode="r") as zip_file: yield zip_file else: yield zip_file
# docstr-coverage:excused `overload` @overload def iter_zipped_csvs( path: str | Path | zipfile.ZipFile, *, progress: bool = ..., return_type: Literal["sequence"] = ..., tqdm_kwargs: Mapping[str, Any] | None = ..., max_line_length: int | None = ..., ) -> Iterable[Sequence[str]]: ... # docstr-coverage:excused `overload` @overload def iter_zipped_csvs( path: str | Path | zipfile.ZipFile, *, progress: bool = ..., return_type: Literal["record"] = ..., tqdm_kwargs: Mapping[str, Any] | None = ..., max_line_length: int | None = ..., ) -> Iterable[dict[str, Any]]: ...
[docs] def iter_zipped_csvs( path: str | Path | zipfile.ZipFile, *, progress: bool = True, return_type: ReturnType = "sequence", tqdm_kwargs: Mapping[str, Any] | None = None, max_line_length: int | None = None, encoding: str | None = None, ) -> Iterable[Sequence[str]] | Iterable[dict[str, Any]]: """Iterate over the lines from zipped CSV files.""" yield from _iter_archived_csvs( path, progress=progress, return_type=return_type, iter_files=iter_zipped_files, keep=_keep_zip_info_csv, tqdm_kwargs=tqdm_kwargs, max_line_length=max_line_length, encoding=encoding, )
def _keep_zip_info_csv(zip_info: zipfile.ZipInfo) -> bool: return zip_info.filename.endswith(".csv") def _iter_archived_csvs( path: str | Path | ArchiveType, *, progress: bool = True, tqdm_kwargs: Mapping[str, Any] | None = None, keep: Predicate[ArchiveInfo] | None = None, return_type: ReturnType = "sequence", iter_files: ArchivedFileIterator[ArchiveType, ArchiveInfo], max_line_length: int | None = None, encoding: str | None = None, ) -> Iterable[Sequence[str]] | Iterable[dict[str, Any]]: """Iterate over the lines from zipped CSV files.""" header: Sequence[str] | None = None for file in iter_files( path, representation="text", progress=progress, tqdm_kwargs=tqdm_kwargs, keep=keep, encoding=encoding, newline="", ): filename = file.name if max_line_length is not None: # this will break everything if there's an issue in the # header, but we aren't going to consider that case it = _cut_long_lines(file, max_line_length, filename) else: it = file reader: csv.DictReader[str] | Reader match return_type: case "sequence": reader = csv.reader(it) case "record": reader = csv.DictReader(it) case _: raise ValueError(f"unrecognized return type {return_type}") if header is None: header = _get_header(reader) elif (current_header := _get_header(reader)) != header: raise HeaderMismatchError(header, current_header) rv = tqdm( reader, disable=not progress, leave=False, desc=f"reading {filename}", unit="row", unit_scale=True, ) yield from rv def _cut_long_lines(it: Iterable[str], max_length: int, name: str) -> Iterable[str]: for i, line in enumerate(it): if len(line) > max_length: tqdm.write(f"[{name}:{i:,}] line of length {len(line):,} is too long: {line[:100]}") continue yield line def _get_header(reader: csv.DictReader[str] | Reader) -> Sequence[str]: if isinstance(reader, csv.DictReader): return cast(Sequence[str], reader.fieldnames) else: return next(reader)
[docs] class HeaderMismatchError(ValueError): """Raised when the current header in an archive of CSVs is different than the original.""" def __init__(self, original: Sequence[str], current: Sequence[str]) -> None: """Instantiate the error.""" self.original = original self.current = current def __str__(self) -> str: return ( f"header mismatch. first header was {self.original} " f"and current header is {self.current}" )
[docs] def tarfile_writestr(tar_file: tarfile.TarFile, filename: str, data: str) -> None: """Write to a tarfile.""" # TODO later, combine with other tarfile writing data_bytes = data.encode("utf-8") tar_info = tarfile.TarInfo(name=filename) tar_info.size = len(data_bytes) tar_file.addfile(tar_info, io.BytesIO(data_bytes))