| Viewing file:  context.py (9.33 KB)      -rw-r--r-- Select action/file-type:
 
  (+) |  (+) |  (+) | Code (+) | Session (+) |  (+) | SDB (+) |  (+) |  (+) |  (+) |  (+) |  (+) | 
 
from __future__ import annotations
 import contextlib
 import functools
 import operator
 import os
 import shutil
 import subprocess
 import sys
 import tempfile
 import urllib.request
 import warnings
 from typing import Iterator
 
 
 if sys.version_info < (3, 12):
 from backports import tarfile
 else:
 import tarfile
 
 
 @contextlib.contextmanager
 def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]:
 """
 >>> tmp_path = getfixture('tmp_path')
 >>> with pushd(tmp_path):
 ...     assert os.getcwd() == os.fspath(tmp_path)
 >>> assert os.getcwd() != os.fspath(tmp_path)
 """
 
 orig = os.getcwd()
 os.chdir(dir)
 try:
 yield dir
 finally:
 os.chdir(orig)
 
 
 @contextlib.contextmanager
 def tarball(
 url, target_dir: str | os.PathLike | None = None
 ) -> Iterator[str | os.PathLike]:
 """
 Get a tarball, extract it, yield, then clean up.
 
 >>> import urllib.request
 >>> url = getfixture('tarfile_served')
 >>> target = getfixture('tmp_path') / 'out'
 >>> tb = tarball(url, target_dir=target)
 >>> import pathlib
 >>> with tb as extracted:
 ...     contents = pathlib.Path(extracted, 'contents.txt').read_text(encoding='utf-8')
 >>> assert not os.path.exists(extracted)
 """
 if target_dir is None:
 target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
 # In the tar command, use --strip-components=1 to strip the first path and
 #  then
 #  use -C to cause the files to be extracted to {target_dir}. This ensures
 #  that we always know where the files were extracted.
 os.mkdir(target_dir)
 try:
 req = urllib.request.urlopen(url)
 with tarfile.open(fileobj=req, mode='r|*') as tf:
 tf.extractall(path=target_dir, filter=strip_first_component)
 yield target_dir
 finally:
 shutil.rmtree(target_dir)
 
 
 def strip_first_component(
 member: tarfile.TarInfo,
 path,
 ) -> tarfile.TarInfo:
 _, member.name = member.name.split('/', 1)
 return member
 
 
 def _compose(*cmgrs):
 """
 Compose any number of dependent context managers into a single one.
 
 The last, innermost context manager may take arbitrary arguments, but
 each successive context manager should accept the result from the
 previous as a single parameter.
 
 Like :func:`jaraco.functools.compose`, behavior works from right to
 left, so the context manager should be indicated from outermost to
 innermost.
 
 Example, to create a context manager to change to a temporary
 directory:
 
 >>> temp_dir_as_cwd = _compose(pushd, temp_dir)
 >>> with temp_dir_as_cwd() as dir:
 ...     assert os.path.samefile(os.getcwd(), dir)
 """
 
 def compose_two(inner, outer):
 def composed(*args, **kwargs):
 with inner(*args, **kwargs) as saved, outer(saved) as res:
 yield res
 
 return contextlib.contextmanager(composed)
 
 return functools.reduce(compose_two, reversed(cmgrs))
 
 
 tarball_cwd = _compose(pushd, tarball)
 
 
 @contextlib.contextmanager
 def tarball_context(*args, **kwargs):
 warnings.warn(
 "tarball_context is deprecated. Use tarball or tarball_cwd instead.",
 DeprecationWarning,
 stacklevel=2,
 )
 pushd_ctx = kwargs.pop('pushd', pushd)
 with tarball(*args, **kwargs) as tball, pushd_ctx(tball) as dir:
 yield dir
 
 
 def infer_compression(url):
 """
 Given a URL or filename, infer the compression code for tar.
 
 >>> infer_compression('http://foo/bar.tar.gz')
 'z'
 >>> infer_compression('http://foo/bar.tgz')
 'z'
 >>> infer_compression('file.bz')
 'j'
 >>> infer_compression('file.xz')
 'J'
 """
 warnings.warn(
 "infer_compression is deprecated with no replacement",
 DeprecationWarning,
 stacklevel=2,
 )
 # cheat and just assume it's the last two characters
 compression_indicator = url[-2:]
 mapping = dict(gz='z', bz='j', xz='J')
 # Assume 'z' (gzip) if no match
 return mapping.get(compression_indicator, 'z')
 
 
 @contextlib.contextmanager
 def temp_dir(remover=shutil.rmtree):
 """
 Create a temporary directory context. Pass a custom remover
 to override the removal behavior.
 
 >>> import pathlib
 >>> with temp_dir() as the_dir:
 ...     assert os.path.isdir(the_dir)
 ...     _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents', encoding='utf-8')
 >>> assert not os.path.exists(the_dir)
 """
 temp_dir = tempfile.mkdtemp()
 try:
 yield temp_dir
 finally:
 remover(temp_dir)
 
 
 @contextlib.contextmanager
 def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
 """
 Check out the repo indicated by url.
 
 If dest_ctx is supplied, it should be a context manager
 to yield the target directory for the check out.
 """
 exe = 'git' if 'git' in url else 'hg'
 with dest_ctx() as repo_dir:
 cmd = [exe, 'clone', url, repo_dir]
 if branch:
 cmd.extend(['--branch', branch])
 devnull = open(os.path.devnull, 'w')
 stdout = devnull if quiet else None
 subprocess.check_call(cmd, stdout=stdout)
 yield repo_dir
 
 
 def null():
 """
 A null context suitable to stand in for a meaningful context.
 
 >>> with null() as value:
 ...     assert value is None
 
 This context is most useful when dealing with two or more code
 branches but only some need a context. Wrap the others in a null
 context to provide symmetry across all options.
 """
 warnings.warn(
 "null is deprecated. Use contextlib.nullcontext",
 DeprecationWarning,
 stacklevel=2,
 )
 return contextlib.nullcontext()
 
 
 class ExceptionTrap:
 """
 A context manager that will catch certain exceptions and provide an
 indication they occurred.
 
 >>> with ExceptionTrap() as trap:
 ...     raise Exception()
 >>> bool(trap)
 True
 
 >>> with ExceptionTrap() as trap:
 ...     pass
 >>> bool(trap)
 False
 
 >>> with ExceptionTrap(ValueError) as trap:
 ...     raise ValueError("1 + 1 is not 3")
 >>> bool(trap)
 True
 >>> trap.value
 ValueError('1 + 1 is not 3')
 >>> trap.tb
 <traceback object at ...>
 
 >>> with ExceptionTrap(ValueError) as trap:
 ...     raise Exception()
 Traceback (most recent call last):
 ...
 Exception
 
 >>> bool(trap)
 False
 """
 
 exc_info = None, None, None
 
 def __init__(self, exceptions=(Exception,)):
 self.exceptions = exceptions
 
 def __enter__(self):
 return self
 
 @property
 def type(self):
 return self.exc_info[0]
 
 @property
 def value(self):
 return self.exc_info[1]
 
 @property
 def tb(self):
 return self.exc_info[2]
 
 def __exit__(self, *exc_info):
 type = exc_info[0]
 matches = type and issubclass(type, self.exceptions)
 if matches:
 self.exc_info = exc_info
 return matches
 
 def __bool__(self):
 return bool(self.type)
 
 def raises(self, func, *, _test=bool):
 """
 Wrap func and replace the result with the truth
 value of the trap (True if an exception occurred).
 
 First, give the decorator an alias to support Python 3.8
 Syntax.
 
 >>> raises = ExceptionTrap(ValueError).raises
 
 Now decorate a function that always fails.
 
 >>> @raises
 ... def fail():
 ...     raise ValueError('failed')
 >>> fail()
 True
 """
 
 @functools.wraps(func)
 def wrapper(*args, **kwargs):
 with ExceptionTrap(self.exceptions) as trap:
 func(*args, **kwargs)
 return _test(trap)
 
 return wrapper
 
 def passes(self, func):
 """
 Wrap func and replace the result with the truth
 value of the trap (True if no exception).
 
 First, give the decorator an alias to support Python 3.8
 Syntax.
 
 >>> passes = ExceptionTrap(ValueError).passes
 
 Now decorate a function that always fails.
 
 >>> @passes
 ... def fail():
 ...     raise ValueError('failed')
 
 >>> fail()
 False
 """
 return self.raises(func, _test=operator.not_)
 
 
 class suppress(contextlib.suppress, contextlib.ContextDecorator):
 """
 A version of contextlib.suppress with decorator support.
 
 >>> @suppress(KeyError)
 ... def key_error():
 ...     {}['']
 >>> key_error()
 """
 
 
 class on_interrupt(contextlib.ContextDecorator):
 """
 Replace a KeyboardInterrupt with SystemExit(1)
 
 >>> def do_interrupt():
 ...     raise KeyboardInterrupt()
 >>> on_interrupt('error')(do_interrupt)()
 Traceback (most recent call last):
 ...
 SystemExit: 1
 >>> on_interrupt('error', code=255)(do_interrupt)()
 Traceback (most recent call last):
 ...
 SystemExit: 255
 >>> on_interrupt('suppress')(do_interrupt)()
 >>> with __import__('pytest').raises(KeyboardInterrupt):
 ...     on_interrupt('ignore')(do_interrupt)()
 """
 
 def __init__(self, action='error', /, code=1):
 self.action = action
 self.code = code
 
 def __enter__(self):
 return self
 
 def __exit__(self, exctype, excinst, exctb):
 if exctype is not KeyboardInterrupt or self.action == 'ignore':
 return
 elif self.action == 'error':
 raise SystemExit(self.code) from excinst
 return self.action == 'suppress'
 
 |