diff --git a/.gitignore b/.gitignore
index 301efbf..24c5ae3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 *.pyo
-.DS_Store
+*.pyc
+.DS_Store
diff --git a/README.md b/README.md
index fa97fa2..260c547 100644
--- a/README.md
+++ b/README.md
@@ -3,19 +3,7 @@
 Unofficial spotify plugin for Kodi, (for now) not yet available in the official
 Based on the opensource Librespot client. Special thanks to mherger for building the special spotty binaries, based on librespot.
-
-## Install with repository
-Install the add-on from my Kodi repo:
-https://github.com/kodi-community-addons/repository.marcelveldt/raw/master/repository.marcelveldt/repository.marcelveldt-1.0.1.zip
-
+This is a fork of the [marcelveldt](https://github.com/marcelveldt) version, modified to work with Kodi 19.
 ## Support
-Support is provided on the Kodi forums:
-http://forum.kodi.tv/showthread.php?tid=265356
-Or create issue in Github
-
-
-## Help needed with maintaining !
-I am very busy currently so I do not have a lot of time to work on this project or watch the forums.
-Be aware that this is a community driven project, so feel free to submit PR's yourself to improve the code and/or help others with support on the forums etc. If you're willing to really participate in the development, please contact me so I can give you write access to the repo. I do my best to maintain the project every once in a while, when I have some spare time left.
-Thanks for understanding!
+Create an issue on GitHub.
diff --git a/addon.xml b/addon.xml
index 83f2517..2299855 100644
--- a/addon.xml
+++ b/addon.xml
@@ -1,13 +1,12 @@
-
+
-
+
-
-
-
-
-
+
+
+
+
   audio
@@ -18,5 +17,9 @@
 Unofficial Spotify music plugin for Kodi
 Requires a Spotify premium account.
 This product uses SPOTIFY(R) CORE but is not endorsed, certified or otherwise approved in any way by Spotify. Spotify is the registered trade mark of the Spotify Group.
+
+  resources/icon.png
+  resources/fanart.jpg
+
+
diff --git a/fanart.jpg b/resources/fanart.jpg
similarity index 100%
rename from fanart.jpg
rename to resources/fanart.jpg
diff --git a/icon.png b/resources/icon.png
similarity index 100%
rename from icon.png
rename to resources/icon.png
diff --git a/resources/lib/backports/__init__.py b/resources/lib/backports/__init__.py
new file mode 100644
index 0000000..69e3be5
--- /dev/null
+++ b/resources/lib/backports/__init__.py
@@ -0,0 +1 @@
+__path__ = __import__('pkgutil').extend_path(__path__, __name__)
diff --git a/resources/lib/backports/functools_lru_cache.py b/resources/lib/backports/functools_lru_cache.py
new file mode 100644
index 0000000..707c6c7
--- /dev/null
+++ b/resources/lib/backports/functools_lru_cache.py
@@ -0,0 +1,184 @@
+from __future__ import absolute_import
+
+import functools
+from collections import namedtuple
+from threading import RLock
+
+_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
+
+
+@functools.wraps(functools.update_wrapper)
+def update_wrapper(wrapper,
+                   wrapped,
+                   assigned = functools.WRAPPER_ASSIGNMENTS,
+                   updated = functools.WRAPPER_UPDATES):
+    """
+    Patch two bugs in functools.update_wrapper.
+ """ + # workaround for http://bugs.python.org/issue3445 + assigned = tuple(attr for attr in assigned if hasattr(wrapped, attr)) + wrapper = functools.update_wrapper(wrapper, wrapped, assigned, updated) + # workaround for https://bugs.python.org/issue17482 + wrapper.__wrapped__ = wrapped + return wrapper + + +class _HashedSeq(list): + __slots__ = 'hashvalue' + + def __init__(self, tup, hash=hash): + self[:] = tup + self.hashvalue = hash(tup) + + def __hash__(self): + return self.hashvalue + + +def _make_key(args, kwds, typed, + kwd_mark=(object(),), + fasttypes=set([int, str, frozenset, type(None)]), + sorted=sorted, tuple=tuple, type=type, len=len): + 'Make a cache key from optionally typed positional and keyword arguments' + key = args + if kwds: + sorted_items = sorted(kwds.items()) + key += kwd_mark + for item in sorted_items: + key += item + if typed: + key += tuple(type(v) for v in args) + if kwds: + key += tuple(type(v) for k, v in sorted_items) + elif len(key) == 1 and type(key[0]) in fasttypes: + return key[0] + return _HashedSeq(key) + + +def lru_cache(maxsize=100, typed=False): + """Least-recently-used cache decorator. + + If *maxsize* is set to None, the LRU features are disabled and the cache + can grow without bound. + + If *typed* is True, arguments of different types will be cached separately. + For example, f(3.0) and f(3) will be treated as distinct calls with + distinct results. + + Arguments to the cached function must be hashable. + + View the cache statistics named tuple (hits, misses, maxsize, currsize) with + f.cache_info(). Clear the cache and statistics with f.cache_clear(). + Access the underlying function with f.__wrapped__. + + See: http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used + + """ + + # Users should only access the lru_cache through its public API: + # cache_info, cache_clear, and f.__wrapped__ + # The internals of the lru_cache are encapsulated for thread safety and + # to allow the implementation to change (including a possible C version). 
+
+    def decorating_function(user_function):
+
+        cache = dict()
+        stats = [0, 0]                  # make statistics updateable non-locally
+        HITS, MISSES = 0, 1             # names for the stats fields
+        make_key = _make_key
+        cache_get = cache.get           # bound method to lookup key or return None
+        _len = len                      # localize the global len() function
+        lock = RLock()                  # because linkedlist updates aren't threadsafe
+        root = []                       # root of the circular doubly linked list
+        root[:] = [root, root, None, None]      # initialize by pointing to self
+        nonlocal_root = [root]          # make updateable non-locally
+        PREV, NEXT, KEY, RESULT = 0, 1, 2, 3    # names for the link fields
+
+        if maxsize == 0:
+
+            def wrapper(*args, **kwds):
+                # no caching, just do a statistics update after a successful call
+                result = user_function(*args, **kwds)
+                stats[MISSES] += 1
+                return result
+
+        elif maxsize is None:
+
+            def wrapper(*args, **kwds):
+                # simple caching without ordering or size limit
+                key = make_key(args, kwds, typed)
+                result = cache_get(key, root)   # root used here as a unique not-found sentinel
+                if result is not root:
+                    stats[HITS] += 1
+                    return result
+                result = user_function(*args, **kwds)
+                cache[key] = result
+                stats[MISSES] += 1
+                return result
+
+        else:
+
+            def wrapper(*args, **kwds):
+                # size limited caching that tracks accesses by recency
+                key = make_key(args, kwds, typed) if kwds or typed else args
+                with lock:
+                    link = cache_get(key)
+                    if link is not None:
+                        # record recent use of the key by moving it to the front of the list
+                        root, = nonlocal_root
+                        link_prev, link_next, key, result = link
+                        link_prev[NEXT] = link_next
+                        link_next[PREV] = link_prev
+                        last = root[PREV]
+                        last[NEXT] = root[PREV] = link
+                        link[PREV] = last
+                        link[NEXT] = root
+                        stats[HITS] += 1
+                        return result
+                result = user_function(*args, **kwds)
+                with lock:
+                    root, = nonlocal_root
+                    if key in cache:
+                        # getting here means that this same key was added to the
+                        # cache while the lock was released.  since the link
+                        # update is already done, we need only return the
+                        # computed result and update the count of misses.
+                        pass
+                    elif _len(cache) >= maxsize:
+                        # use the old root to store the new key and result
+                        oldroot = root
+                        oldroot[KEY] = key
+                        oldroot[RESULT] = result
+                        # empty the oldest link and make it the new root
+                        root = nonlocal_root[0] = oldroot[NEXT]
+                        oldkey = root[KEY]
+                        root[KEY] = root[RESULT] = None
+                        # now update the cache dictionary for the new links
+                        del cache[oldkey]
+                        cache[key] = oldroot
+                    else:
+                        # put result in a new link at the front of the list
+                        last = root[PREV]
+                        link = [last, root, key, result]
+                        last[NEXT] = root[PREV] = cache[key] = link
+                    stats[MISSES] += 1
+                return result
+
+        def cache_info():
+            """Report cache statistics"""
+            with lock:
+                return _CacheInfo(stats[HITS], stats[MISSES], maxsize, len(cache))
+
+        def cache_clear():
+            """Clear the cache and cache statistics"""
+            with lock:
+                cache.clear()
+                root = nonlocal_root[0]
+                root[:] = [root, root, None, None]
+                stats[:] = [0, 0]
+
+        wrapper.__wrapped__ = user_function
+        wrapper.cache_info = cache_info
+        wrapper.cache_clear = cache_clear
+        return update_wrapper(wrapper, user_function)
+
+    return decorating_function
diff --git a/resources/lib/cheroot/__init__.py b/resources/lib/cheroot/__init__.py
new file mode 100644
index 0000000..640234e
--- /dev/null
+++ b/resources/lib/cheroot/__init__.py
@@ -0,0 +1,15 @@
+"""High-performance, pure-Python HTTP server used by CherryPy."""
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+try:
+    import pkg_resources
+except ImportError:
+    pass
+
+
+try:
+    __version__ = pkg_resources.get_distribution('cheroot').version
+except Exception:
+    __version__ = '8.5.2'
diff --git a/resources/lib/cheroot/__main__.py b/resources/lib/cheroot/__main__.py
new file mode 100644
index 0000000..d2e27c1
--- /dev/null
+++ b/resources/lib/cheroot/__main__.py
@@ -0,0 +1,6 @@
+"""Stub for accessing the Cheroot CLI tool."""
+
+from .cli import main
+
+if __name__ == '__main__':
+    main()
diff --git a/resources/lib/cheroot/_compat.py b/resources/lib/cheroot/_compat.py
new file mode 100644
index 0000000..10dcdef
--- /dev/null
+++ b/resources/lib/cheroot/_compat.py
@@ -0,0 +1,148 @@
+# pylint: disable=unused-import
+"""Compatibility code for using Cheroot with various versions of Python."""
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+import os
+import platform
+import re
+
+import six
+
+try:
+    import selectors  # lgtm [py/unused-import]
+except ImportError:
+    import selectors2 as selectors  # noqa: F401  # lgtm [py/unused-import]
+
+try:
+    import ssl
+    IS_ABOVE_OPENSSL10 = ssl.OPENSSL_VERSION_INFO >= (1, 1)
+    del ssl
+except ImportError:
+    IS_ABOVE_OPENSSL10 = None
+
+# contextlib.suppress was added in Python 3.4
+try:
+    from contextlib import suppress
+except ImportError:
+    from contextlib import contextmanager
+
+    @contextmanager
+    def suppress(*exceptions):
+        """Return a context manager that suppresses the `exceptions`."""
+        try:
+            yield
+        except exceptions:
+            pass
+
+
+IS_CI = bool(os.getenv('CI'))
+IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW'))
+
+
+IS_PYPY = platform.python_implementation() == 'PyPy'
+
+
+SYS_PLATFORM = platform.system()
+IS_WINDOWS = SYS_PLATFORM == 'Windows'
+IS_LINUX = SYS_PLATFORM == 'Linux'
+IS_MACOS = SYS_PLATFORM == 'Darwin'
+
+PLATFORM_ARCH = platform.machine()
+IS_PPC = PLATFORM_ARCH.startswith('ppc')
+
+
+if not six.PY2:
+    def ntob(n, encoding='ISO-8859-1'):
+        """Return the native string as bytes in the given encoding."""
+        assert_native(n)
+        # In Python 3, the native
string type is unicode + return n.encode(encoding) + + def ntou(n, encoding='ISO-8859-1'): + """Return the native string as Unicode with the given encoding.""" + assert_native(n) + # In Python 3, the native string type is unicode + return n + + def bton(b, encoding='ISO-8859-1'): + """Return the byte string as native string in the given encoding.""" + return b.decode(encoding) +else: + # Python 2 + def ntob(n, encoding='ISO-8859-1'): + """Return the native string as bytes in the given encoding.""" + assert_native(n) + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. + return n + + def ntou(n, encoding='ISO-8859-1'): + """Return the native string as Unicode with the given encoding.""" + assert_native(n) + # In Python 2, the native string type is bytes. + # First, check for the special encoding 'escape'. The test suite uses + # this to signal that it wants to pass a string with embedded \uXXXX + # escapes, but without having to prefix it with u'' for Python 2, + # but no prefix for Python 3. + if encoding == 'escape': + return re.sub( + r'\\u([0-9a-zA-Z]{4})', + lambda m: six.unichr(int(m.group(1), 16)), + n.decode('ISO-8859-1'), + ) + # Assume it's already in the given encoding, which for ISO-8859-1 + # is almost always what was intended. + return n.decode(encoding) + + def bton(b, encoding='ISO-8859-1'): + """Return the byte string as native string in the given encoding.""" + return b + + +def assert_native(n): + """Check whether the input is of native :py:class:`str` type. + + Raises: + TypeError: in case of failed check + + """ + if not isinstance(n, str): + raise TypeError('n must be a native str (got %s)' % type(n).__name__) + + +if not six.PY2: + """Python 3 has :py:class:`memoryview` builtin.""" + # Python 2.7 has it backported, but socket.write() does + # str(memoryview(b'0' * 100)) -> + # instead of accessing it correctly. + memoryview = memoryview +else: + """Link :py:class:`memoryview` to buffer under Python 2.""" + memoryview = buffer # noqa: F821 + + +def extract_bytes(mv): + r"""Retrieve bytes out of the given input buffer. + + :param mv: input :py:func:`buffer` + :type mv: memoryview or bytes + + :return: unwrapped bytes + :rtype: bytes + + :raises ValueError: if the input is not one of \ + :py:class:`memoryview`/:py:func:`buffer` \ + or :py:class:`bytes` + """ + if isinstance(mv, memoryview): + return bytes(mv) if six.PY2 else mv.tobytes() + + if isinstance(mv, bytes): + return mv + + raise ValueError( + 'extract_bytes() only accepts bytes and memoryview/buffer', + ) diff --git a/resources/lib/cheroot/cli.py b/resources/lib/cheroot/cli.py new file mode 100644 index 0000000..4607e22 --- /dev/null +++ b/resources/lib/cheroot/cli.py @@ -0,0 +1,247 @@ +"""Command line tool for starting a Cheroot WSGI/HTTP server instance. + +Basic usage: + +.. 
code-block:: shell-session + + $ # Start a server on 127.0.0.1:8000 with the default settings + $ # for the WSGI app myapp/wsgi.py:application() + $ cheroot myapp.wsgi + + $ # Start a server on 0.0.0.0:9000 with 8 threads + $ # for the WSGI app myapp/wsgi.py:main_app() + $ cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8 + + $ # Start a server for the cheroot.server.Gateway subclass + $ # myapp/gateway.py:HTTPGateway + $ cheroot myapp.gateway:HTTPGateway + + $ # Start a server on the UNIX socket /var/spool/myapp.sock + $ cheroot myapp.wsgi --bind /var/spool/myapp.sock + + $ # Start a server on the abstract UNIX socket CherootServer + $ cheroot myapp.wsgi --bind @CherootServer + +.. spelling:: + + cli +""" + +import argparse +from importlib import import_module +import os +import sys + +import six + +from . import server +from . import wsgi +from ._compat import suppress + + +__metaclass__ = type + + +class BindLocation: + """A class for storing the bind location for a Cheroot instance.""" + + +class TCPSocket(BindLocation): + """TCPSocket.""" + + def __init__(self, address, port): + """Initialize. + + Args: + address (str): Host name or IP address + port (int): TCP port number + + """ + self.bind_addr = address, port + + +class UnixSocket(BindLocation): + """UnixSocket.""" + + def __init__(self, path): + """Initialize.""" + self.bind_addr = path + + +class AbstractSocket(BindLocation): + """AbstractSocket.""" + + def __init__(self, abstract_socket): + """Initialize.""" + self.bind_addr = '\x00{sock_path}'.format(sock_path=abstract_socket) + + +class Application: + """Application.""" + + @classmethod + def resolve(cls, full_path): + """Read WSGI app/Gateway path string and import application module.""" + mod_path, _, app_path = full_path.partition(':') + app = getattr(import_module(mod_path), app_path or 'application') + # suppress the `TypeError` exception, just in case `app` is not a class + with suppress(TypeError): + if issubclass(app, server.Gateway): + return GatewayYo(app) + + return cls(app) + + def __init__(self, wsgi_app): + """Initialize.""" + if not callable(wsgi_app): + raise TypeError( + 'Application must be a callable object or ' + 'cheroot.server.Gateway subclass', + ) + self.wsgi_app = wsgi_app + + def server_args(self, parsed_args): + """Return keyword args for Server class.""" + args = { + arg: value + for arg, value in vars(parsed_args).items() + if not arg.startswith('_') and value is not None + } + args.update(vars(self)) + return args + + def server(self, parsed_args): + """Server.""" + return wsgi.Server(**self.server_args(parsed_args)) + + +class GatewayYo: + """Gateway.""" + + def __init__(self, gateway): + """Init.""" + self.gateway = gateway + + def server(self, parsed_args): + """Server.""" + server_args = vars(self) + server_args['bind_addr'] = parsed_args['bind_addr'] + if parsed_args.max is not None: + server_args['maxthreads'] = parsed_args.max + if parsed_args.numthreads is not None: + server_args['minthreads'] = parsed_args.numthreads + return server.HTTPServer(**server_args) + + +def parse_wsgi_bind_location(bind_addr_string): + """Convert bind address string to a BindLocation.""" + # if the string begins with an @ symbol, use an abstract socket, + # this is the first condition to verify, otherwise the urlparse + # validation would detect //@ as a valid url with a hostname + # with value: "" and port: None + if bind_addr_string.startswith('@'): + return AbstractSocket(bind_addr_string[1:]) + + # try and match for an IP/hostname and port + match 
= six.moves.urllib.parse.urlparse( + '//{addr}'.format(addr=bind_addr_string), + ) + try: + addr = match.hostname + port = match.port + if addr is not None or port is not None: + return TCPSocket(addr, port) + except ValueError: + pass + + # else, assume a UNIX socket path + return UnixSocket(path=bind_addr_string) + + +def parse_wsgi_bind_addr(bind_addr_string): + """Convert bind address string to bind address parameter.""" + return parse_wsgi_bind_location(bind_addr_string).bind_addr + + +_arg_spec = { + '_wsgi_app': { + 'metavar': 'APP_MODULE', + 'type': Application.resolve, + 'help': 'WSGI application callable or cheroot.server.Gateway subclass', + }, + '--bind': { + 'metavar': 'ADDRESS', + 'dest': 'bind_addr', + 'type': parse_wsgi_bind_addr, + 'default': '[::1]:8000', + 'help': 'Network interface to listen on (default: [::1]:8000)', + }, + '--chdir': { + 'metavar': 'PATH', + 'type': os.chdir, + 'help': 'Set the working directory', + }, + '--server-name': { + 'dest': 'server_name', + 'type': str, + 'help': 'Web server name to be advertised via Server HTTP header', + }, + '--threads': { + 'metavar': 'INT', + 'dest': 'numthreads', + 'type': int, + 'help': 'Minimum number of worker threads', + }, + '--max-threads': { + 'metavar': 'INT', + 'dest': 'max', + 'type': int, + 'help': 'Maximum number of worker threads', + }, + '--timeout': { + 'metavar': 'INT', + 'dest': 'timeout', + 'type': int, + 'help': 'Timeout in seconds for accepted connections', + }, + '--shutdown-timeout': { + 'metavar': 'INT', + 'dest': 'shutdown_timeout', + 'type': int, + 'help': 'Time in seconds to wait for worker threads to cleanly exit', + }, + '--request-queue-size': { + 'metavar': 'INT', + 'dest': 'request_queue_size', + 'type': int, + 'help': 'Maximum number of queued connections', + }, + '--accepted-queue-size': { + 'metavar': 'INT', + 'dest': 'accepted_queue_size', + 'type': int, + 'help': 'Maximum number of active requests in queue', + }, + '--accepted-queue-timeout': { + 'metavar': 'INT', + 'dest': 'accepted_queue_timeout', + 'type': int, + 'help': 'Timeout in seconds for putting requests into queue', + }, +} + + +def main(): + """Create a new Cheroot instance with arguments from the command line.""" + parser = argparse.ArgumentParser( + description='Start an instance of the Cheroot WSGI/HTTP server.', + ) + for arg, spec in _arg_spec.items(): + parser.add_argument(arg, **spec) + raw_args = parser.parse_args() + + # ensure cwd in sys.path + '' in sys.path or sys.path.insert(0, '') + + # create a server based on the arguments provided + raw_args._wsgi_app.server(raw_args).safe_start() diff --git a/resources/lib/cheroot/connections.py b/resources/lib/cheroot/connections.py new file mode 100644 index 0000000..7debcbf --- /dev/null +++ b/resources/lib/cheroot/connections.py @@ -0,0 +1,369 @@ +"""Utilities to manage open connections.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import io +import os +import socket +import threading +import time + +from . 
import errors +from ._compat import selectors +from ._compat import suppress +from .makefile import MakeFile + +import six + +try: + import fcntl +except ImportError: + try: + from ctypes import windll, WinError + import ctypes.wintypes + _SetHandleInformation = windll.kernel32.SetHandleInformation + _SetHandleInformation.argtypes = [ + ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ] + _SetHandleInformation.restype = ctypes.wintypes.BOOL + except ImportError: + def prevent_socket_inheritance(sock): + """Stub inheritance prevention. + + Dummy function, since neither fcntl nor ctypes are available. + """ + pass + else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (Windows).""" + if not _SetHandleInformation(sock.fileno(), 1, 0): + raise WinError() +else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (POSIX).""" + fd = sock.fileno() + old_flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) + + +class _ThreadsafeSelector: + """Thread-safe wrapper around a DefaultSelector. + + There are 2 thread contexts in which it may be accessed: + * the selector thread + * one of the worker threads in workers/threadpool.py + + The expected read/write patterns are: + * :py:func:`~iter`: selector thread + * :py:meth:`register`: selector thread and threadpool, + via :py:meth:`~cheroot.workers.threadpool.ThreadPool.put` + * :py:meth:`unregister`: selector thread only + + Notably, this means :py:class:`_ThreadsafeSelector` never needs to worry + that connections will be removed behind its back. + + The lock is held when iterating or modifying the selector but is not + required when :py:meth:`select()ing ` on it. + """ + + def __init__(self): + self._selector = selectors.DefaultSelector() + self._lock = threading.Lock() + + def __len__(self): + with self._lock: + return len(self._selector.get_map() or {}) + + @property + def connections(self): + """Retrieve connections registered with the selector.""" + with self._lock: + mapping = self._selector.get_map() or {} + for _, (_, sock_fd, _, conn) in mapping.items(): + yield (sock_fd, conn) + + def register(self, fileobj, events, data=None): + """Register ``fileobj`` with the selector.""" + with self._lock: + return self._selector.register(fileobj, events, data) + + def unregister(self, fileobj): + """Unregister ``fileobj`` from the selector.""" + with self._lock: + return self._selector.unregister(fileobj) + + def select(self, timeout=None): + """Return socket fd and data pairs from selectors.select call. + + Returns entries ready to read in the form: + (socket_file_descriptor, connection) + """ + return ( + (key.fd, key.data) + for key, _ in self._selector.select(timeout=timeout) + ) + + def close(self): + """Close the selector.""" + with self._lock: + self._selector.close() + + +class ConnectionManager: + """Class which manages HTTPConnection objects. + + This is for connections which are being kept-alive for follow-up requests. + """ + + def __init__(self, server): + """Initialize ConnectionManager object. + + Args: + server (cheroot.server.HTTPServer): web server object + that uses this ConnectionManager instance. 
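
Looping back to cheroot/cli.py above: parse_wsgi_bind_location() distinguishes TCP, abstract and UNIX bind strings, which a quick sketch makes concrete (assuming the vendored package imports as `cheroot`):

from cheroot.cli import parse_wsgi_bind_addr

print(parse_wsgi_bind_addr('0.0.0.0:9000'))           # ('0.0.0.0', 9000) -- TCPSocket
print(parse_wsgi_bind_addr('@CherootServer'))         # '\x00CherootServer' -- AbstractSocket
print(parse_wsgi_bind_addr('/var/spool/myapp.sock'))  # path returned unchanged -- UnixSocket
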
+ """ + self._serving = False + self._stop_requested = False + + self.server = server + self._selector = _ThreadsafeSelector() + + self._selector.register( + server.socket.fileno(), + selectors.EVENT_READ, data=server, + ) + + def put(self, conn): + """Put idle connection into the ConnectionManager to be managed. + + :param conn: HTTP connection to be managed + :type conn: cheroot.server.HTTPConnection + """ + conn.last_used = time.time() + # if this conn doesn't have any more data waiting to be read, + # register it with the selector. + if conn.rfile.has_data(): + self.server.process_conn(conn) + else: + self._selector.register( + conn.socket.fileno(), selectors.EVENT_READ, data=conn, + ) + + def _expire(self): + """Expire least recently used connections. + + This happens if there are either too many open connections, or if the + connections have been timed out. + + This should be called periodically. + """ + # find any connections still registered with the selector + # that have not been active recently enough. + threshold = time.time() - self.server.timeout + timed_out_connections = [ + (sock_fd, conn) + for (sock_fd, conn) in self._selector.connections + if conn != self.server and conn.last_used < threshold + ] + for sock_fd, conn in timed_out_connections: + self._selector.unregister(sock_fd) + conn.close() + + def stop(self): + """Stop the selector loop in run() synchronously. + + May take up to half a second. + """ + self._stop_requested = True + while self._serving: + time.sleep(0.01) + + def run(self, expiration_interval): + """Run the connections selector indefinitely. + + Args: + expiration_interval (float): Interval, in seconds, at which + connections will be checked for expiration. + + Connections that are ready to process are submitted via + self.server.process_conn() + + Connections submitted for processing must be `put()` + back if they should be examined again for another request. + + Can be shut down by calling `stop()`. + """ + self._serving = True + try: + self._run(expiration_interval) + finally: + self._serving = False + + def _run(self, expiration_interval): + last_expiration_check = time.time() + + while not self._stop_requested: + try: + active_list = self._selector.select(timeout=0.01) + except OSError: + self._remove_invalid_sockets() + continue + + for (sock_fd, conn) in active_list: + if conn is self.server: + # New connection + new_conn = self._from_server_socket(self.server.socket) + if new_conn is not None: + self.server.process_conn(new_conn) + else: + # unregister connection from the selector until the server + # has read from it and returned it via put() + self._selector.unregister(sock_fd) + self.server.process_conn(conn) + + now = time.time() + if (now - last_expiration_check) > expiration_interval: + self._expire() + last_expiration_check = now + + def _remove_invalid_sockets(self): + """Clean up the resources of any broken connections. + + This method attempts to detect any connections in an invalid state, + unregisters them from the selector and closes the file descriptors of + the corresponding network sockets where possible. 
+ """ + invalid_conns = [] + for sock_fd, conn in self._selector.connections: + if conn is self.server: + continue + + try: + os.fstat(sock_fd) + except OSError: + invalid_conns.append((sock_fd, conn)) + + for sock_fd, conn in invalid_conns: + self._selector.unregister(sock_fd) + # One of the reason on why a socket could cause an error + # is that the socket is already closed, ignore the + # socket error if we try to close it at this point. + # This is equivalent to OSError in Py3 + with suppress(socket.error): + conn.close() + + def _from_server_socket(self, server_socket): # noqa: C901 # FIXME + try: + s, addr = server_socket.accept() + if self.server.stats['Enabled']: + self.server.stats['Accepts'] += 1 + prevent_socket_inheritance(s) + if hasattr(s, 'settimeout'): + s.settimeout(self.server.timeout) + + mf = MakeFile + ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server + if self.server.ssl_adapter is not None: + try: + s, ssl_env = self.server.ssl_adapter.wrap(s) + except errors.NoSSLError: + msg = ( + 'The client sent a plain HTTP request, but ' + 'this server only speaks HTTPS on this port.' + ) + buf = [ + '%s 400 Bad Request\r\n' % self.server.protocol, + 'Content-Length: %s\r\n' % len(msg), + 'Content-Type: text/plain\r\n\r\n', + msg, + ] + + sock_to_make = s if not six.PY2 else s._sock + wfile = mf(sock_to_make, 'wb', io.DEFAULT_BUFFER_SIZE) + try: + wfile.write(''.join(buf).encode('ISO-8859-1')) + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + return + if not s: + return + mf = self.server.ssl_adapter.makefile + # Re-apply our timeout since we may have a new socket object + if hasattr(s, 'settimeout'): + s.settimeout(self.server.timeout) + + conn = self.server.ConnectionClass(self.server, s, mf) + + if not isinstance( + self.server.bind_addr, + (six.text_type, six.binary_type), + ): + # optional values + # Until we do DNS lookups, omit REMOTE_HOST + if addr is None: # sometimes this can happen + # figure out if AF_INET or AF_INET6. + if len(s.getsockname()) == 2: + # AF_INET + addr = ('0.0.0.0', 0) + else: + # AF_INET6 + addr = ('::', 0) + conn.remote_addr = addr[0] + conn.remote_port = addr[1] + + conn.ssl_env = ssl_env + return conn + + except socket.timeout: + # The only reason for the timeout in start() is so we can + # notice keyboard interrupts on Win32, which don't interrupt + # accept() by default + return + except socket.error as ex: + if self.server.stats['Enabled']: + self.server.stats['Socket Errors'] += 1 + if ex.args[0] in errors.socket_error_eintr: + # I *think* this is right. EINTR should occur when a signal + # is received during the accept() call; all docs say retry + # the call, and I *think* I'm reading it right that Python + # will then go ahead and poll for and handle the signal + # elsewhere. See + # https://github.com/cherrypy/cherrypy/issues/707. + return + if ex.args[0] in errors.socket_errors_nonblocking: + # Just try again. See + # https://github.com/cherrypy/cherrypy/issues/479. + return + if ex.args[0] in errors.socket_errors_to_ignore: + # Our socket was closed. + # See https://github.com/cherrypy/cherrypy/issues/686. + return + raise + + def close(self): + """Close all monitored connections.""" + for (_, conn) in self._selector.connections: + if conn is not self.server: # server closes its own socket + conn.close() + self._selector.close() + + @property + def _num_connections(self): + """Return the current number of connections. 
+ + Includes all connections registered with the selector, + minus one for the server socket, which is always registered + with the selector. + """ + return len(self._selector) - 1 + + @property + def can_add_keepalive_connection(self): + """Flag whether it is allowed to add a new keep-alive connection.""" + ka_limit = self.server.keep_alive_conn_limit + return ka_limit is None or self._num_connections < ka_limit diff --git a/resources/lib/cheroot/errors.py b/resources/lib/cheroot/errors.py new file mode 100644 index 0000000..e00629f --- /dev/null +++ b/resources/lib/cheroot/errors.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +"""Collection of exceptions raised and/or processed by Cheroot.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import errno +import sys + + +class MaxSizeExceeded(Exception): + """Exception raised when a client sends more data then acceptable within limit. + + Depends on ``request.body.maxbytes`` config option if used within CherryPy + """ + + +class NoSSLError(Exception): + """Exception raised when a client speaks HTTP to an HTTPS socket.""" + + +class FatalSSLAlert(Exception): + """Exception raised when the SSL implementation signals a fatal alert.""" + + +def plat_specific_errors(*errnames): + """Return error numbers for all errors in ``errnames`` on this platform. + + The :py:mod:`errno` module contains different global constants + depending on the specific platform (OS). This function will return + the list of numeric values for a given list of potential names. + """ + missing_attr = {None} + unique_nums = {getattr(errno, k, None) for k in errnames} + return list(unique_nums - missing_attr) + + +socket_error_eintr = plat_specific_errors('EINTR', 'WSAEINTR') + +socket_errors_to_ignore = plat_specific_errors( + 'EPIPE', + 'EBADF', 'WSAEBADF', + 'ENOTSOCK', 'WSAENOTSOCK', + 'ETIMEDOUT', 'WSAETIMEDOUT', + 'ECONNREFUSED', 'WSAECONNREFUSED', + 'ECONNRESET', 'WSAECONNRESET', + 'ECONNABORTED', 'WSAECONNABORTED', + 'ENETRESET', 'WSAENETRESET', + 'EHOSTDOWN', 'EHOSTUNREACH', +) +socket_errors_to_ignore.append('timed out') +socket_errors_to_ignore.append('The read operation timed out') +socket_errors_nonblocking = plat_specific_errors( + 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK', +) + +if sys.platform == 'darwin': + socket_errors_to_ignore.extend(plat_specific_errors('EPROTOTYPE')) + socket_errors_nonblocking.extend(plat_specific_errors('EPROTOTYPE')) + + +acceptable_sock_shutdown_error_codes = { + errno.ENOTCONN, + errno.EPIPE, errno.ESHUTDOWN, # corresponds to BrokenPipeError in Python 3 + errno.ECONNRESET, # corresponds to ConnectionResetError in Python 3 +} +"""Errors that may happen during the connection close sequence. 
+ +* ENOTCONN — client is no longer connected +* EPIPE — write on a pipe while the other end has been closed +* ESHUTDOWN — write on a socket which has been shutdown for writing +* ECONNRESET — connection is reset by the peer, we received a TCP RST packet + +Refs: +* https://github.com/cherrypy/cheroot/issues/341#issuecomment-735884889 +* https://bugs.python.org/issue30319 +* https://bugs.python.org/issue30329 +* https://github.com/python/cpython/commit/83a2c28 +* https://github.com/python/cpython/blob/c39b52f/Lib/poplib.py#L297-L302 +* https://docs.microsoft.com/windows/win32/api/winsock/nf-winsock-shutdown +""" + +try: # py3 + acceptable_sock_shutdown_exceptions = ( + BrokenPipeError, ConnectionResetError, + ) +except NameError: # py2 + acceptable_sock_shutdown_exceptions = () diff --git a/resources/lib/cheroot/makefile.py b/resources/lib/cheroot/makefile.py new file mode 100644 index 0000000..1383c65 --- /dev/null +++ b/resources/lib/cheroot/makefile.py @@ -0,0 +1,447 @@ +"""Socket file object.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import socket + +try: + # prefer slower Python-based io module + import _pyio as io +except ImportError: + # Python 2.6 + import io + +import six + +from . import errors +from ._compat import extract_bytes, memoryview + + +# Write only 16K at a time to sockets +SOCK_WRITE_BLOCKSIZE = 16384 + + +class BufferedWriter(io.BufferedWriter): + """Faux file object attached to a socket object.""" + + def write(self, b): + """Write bytes to buffer.""" + self._checkClosed() + if isinstance(b, str): + raise TypeError("can't write str to binary stream") + + with self._write_lock: + self._write_buf.extend(b) + self._flush_unlocked() + return len(b) + + def _flush_unlocked(self): + self._checkClosed('flush of closed file') + while self._write_buf: + try: + # ssl sockets only except 'bytes', not bytearrays + # so perhaps we should conditionally wrap this for perf? 
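
The plat_specific_errors() helper in cheroot/errors.py above keeps only the errno names that exist on the running platform; a small sketch (output shown for Linux, where EPIPE is 32 and the WSA* names are absent):

from cheroot.errors import plat_specific_errors

print(plat_specific_errors('EPIPE', 'WSAEBADF'))  # [32] -- the Windows-only name drops out
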
+ n = self.raw.write(bytes(self._write_buf)) + except io.BlockingIOError as e: + n = e.characters_written + del self._write_buf[:n] + + +class MakeFile_PY2(getattr(socket, '_fileobject', object)): + """Faux file object attached to a socket object.""" + + def __init__(self, *args, **kwargs): + """Initialize faux file object.""" + self.bytes_read = 0 + self.bytes_written = 0 + socket._fileobject.__init__(self, *args, **kwargs) + self._refcount = 0 + + def _reuse(self): + self._refcount += 1 + + def _drop(self): + if self._refcount < 0: + self.close() + else: + self._refcount -= 1 + + def write(self, data): + """Send entire data contents for non-blocking sockets.""" + bytes_sent = 0 + data_mv = memoryview(data) + payload_size = len(data_mv) + while bytes_sent < payload_size: + try: + bytes_sent += self.send( + data_mv[bytes_sent:bytes_sent + SOCK_WRITE_BLOCKSIZE], + ) + except socket.error as e: + if e.args[0] not in errors.socket_errors_nonblocking: + raise + + def send(self, data): + """Send some part of message to the socket.""" + bytes_sent = self._sock.send(extract_bytes(data)) + self.bytes_written += bytes_sent + return bytes_sent + + def flush(self): + """Write all data from buffer to socket and reset write buffer.""" + if self._wbuf: + buffer = ''.join(self._wbuf) + self._wbuf = [] + self.write(buffer) + + def recv(self, size): + """Receive message of a size from the socket.""" + while True: + try: + data = self._sock.recv(size) + self.bytes_read += len(data) + return data + except socket.error as e: + what = ( + e.args[0] not in errors.socket_errors_nonblocking + and e.args[0] not in errors.socket_error_eintr + ) + if what: + raise + + class FauxSocket: + """Faux socket with the minimal interface required by pypy.""" + + def _reuse(self): + pass + + _fileobject_uses_str_type = six.PY2 and isinstance( + socket._fileobject(FauxSocket())._rbuf, six.string_types, + ) + + # FauxSocket is no longer needed + del FauxSocket + + if not _fileobject_uses_str_type: # noqa: C901 # FIXME + def read(self, size=-1): + """Read data from the socket to buffer.""" + # Use max, disallow tiny reads in a loop as they are very + # inefficient. + # We never leave read() with any leftover data from a new recv() + # call in our internal buffer. + rbufsize = max(self._rbufsize, self.default_bufsize) + # Our use of StringIO rather than lists of string objects returned + # by recv() minimizes memory usage and fragmentation that occurs + # when rbufsize is large compared to the typical return value of + # recv(). + buf = self._rbuf + buf.seek(0, 2) # seek end + if size < 0: + # Read until EOF + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + data = self.recv(rbufsize) + if not data: + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = buf.tell() + if buf_len >= size: + # Already have size bytes in our buffer? Extract and + # return. + buf.seek(0) + rv = buf.read(size) + self._rbuf = io.BytesIO() + self._rbuf.write(buf.read()) + return rv + + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + left = size - buf_len + # recv() will malloc the amount of memory given as its + # parameter even though it often returns much less data + # than that. The returned data string is short lived + # as we copy it into a StringIO and free it. This avoids + # fragmentation issues on many platforms. 
+ data = self.recv(left) + if not data: + break + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid buffer data copies when: + # - We have no data in our buffer. + # AND + # - Our call to recv returned exactly the + # number of bytes we were asked to read. + return data + if n == left: + buf.write(data) + del data # explicit free + break + assert n <= left, 'recv(%d) returned %d bytes' % (left, n) + buf.write(data) + buf_len += n + del data # explicit free + # assert buf_len == buf.tell() + return buf.getvalue() + + def readline(self, size=-1): + """Read line from the socket to buffer.""" + buf = self._rbuf + buf.seek(0, 2) # seek end + if buf.tell() > 0: + # check if we already have it in our buffer + buf.seek(0) + bline = buf.readline(size) + if bline.endswith('\n') or len(bline) == size: + self._rbuf = io.BytesIO() + self._rbuf.write(buf.read()) + return bline + del bline + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + buf.seek(0) + buffers = [buf.read()] + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + data = None + recv = self.recv + while data != '\n': + data = recv(1) + if not data: + break + buffers.append(data) + return ''.join(buffers) + + buf.seek(0, 2) # seek end + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + data = self.recv(self._rbufsize) + if not data: + break + nl = data.find('\n') + if nl >= 0: + nl += 1 + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) + del data + break + buf.write(data) + return buf.getvalue() + + else: + # Read until size bytes or \n or EOF seen, whichever comes + # first + buf.seek(0, 2) # seek end + buf_len = buf.tell() + if buf_len >= size: + buf.seek(0) + rv = buf.read(size) + self._rbuf = io.BytesIO() + self._rbuf.write(buf.read()) + return rv + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + data = self.recv(self._rbufsize) + if not data: + break + left = size - buf_len + # did we just receive a newline? + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + # save the excess data to _rbuf + self._rbuf.write(data[nl:]) + if buf_len: + buf.write(data[:nl]) + break + else: + # Shortcut. Avoid data copy through buf when + # returning a substring of our first recv(). + return data[:nl] + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid data copy through buf when + # returning exactly all of our first recv(). 
+ return data + if n >= left: + buf.write(data[:left]) + self._rbuf.write(data[left:]) + break + buf.write(data) + buf_len += n + # assert buf_len == buf.tell() + return buf.getvalue() + + def has_data(self): + """Return true if there is buffered data to read.""" + return bool(self._rbuf.getvalue()) + + else: + def read(self, size=-1): + """Read data from the socket to buffer.""" + if size < 0: + # Read until EOF + buffers = [self._rbuf] + self._rbuf = '' + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + + while True: + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + return ''.join(buffers) + else: + # Read until size bytes or EOF seen, whichever comes first + data = self._rbuf + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = '' + while True: + left = size - buf_len + recv_size = max(self._rbufsize, left) + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return ''.join(buffers) + + def readline(self, size=-1): + """Read line from the socket to buffer.""" + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == '' + buffers = [] + while data != '\n': + data = self.recv(1) + if not data: + break + buffers.append(data) + return ''.join(buffers) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = '' + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return ''.join(buffers) + else: + # Read until size bytes or \n or EOF seen, whichever comes + # first + nl = data.find('\n', 0, size) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = '' + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return ''.join(buffers) + + def has_data(self): + """Return true if there is buffered data to read.""" + return bool(self._rbuf) + + +if not six.PY2: + class StreamReader(io.BufferedReader): + """Socket stream reader.""" + + def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): + """Initialize socket stream reader.""" + super().__init__(socket.SocketIO(sock, mode), bufsize) + self.bytes_read = 0 + + def read(self, *args, **kwargs): + """Capture bytes read.""" + val = super().read(*args, **kwargs) + self.bytes_read += len(val) + return val + + def has_data(self): + """Return true if there is buffered data to read.""" + return len(self._read_buf) > self._read_pos + + class StreamWriter(BufferedWriter): + """Socket stream writer.""" + + def __init__(self, sock, mode='w', bufsize=io.DEFAULT_BUFFER_SIZE): + """Initialize socket stream writer.""" + 
super().__init__(socket.SocketIO(sock, mode), bufsize) + self.bytes_written = 0 + + def write(self, val, *args, **kwargs): + """Capture bytes written.""" + res = super().write(val, *args, **kwargs) + self.bytes_written += len(val) + return res + + def MakeFile(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): + """File object attached to a socket object.""" + cls = StreamReader if 'r' in mode else StreamWriter + return cls(sock, mode, bufsize) +else: + StreamReader = StreamWriter = MakeFile = MakeFile_PY2 diff --git a/resources/lib/cheroot/server.py b/resources/lib/cheroot/server.py new file mode 100644 index 0000000..8b59a33 --- /dev/null +++ b/resources/lib/cheroot/server.py @@ -0,0 +1,2191 @@ +""" +A high-speed, production ready, thread pooled, generic HTTP server. + +For those of you wanting to understand internals of this module, here's the +basic call flow. The server's listening thread runs a very tight loop, +sticking incoming connections onto a Queue:: + + server = HTTPServer(...) + server.start() + -> serve() + while ready: + _connections.run() + while not stop_requested: + child = socket.accept() # blocks until a request comes in + conn = HTTPConnection(child, ...) + server.process_conn(conn) # adds conn to threadpool + +Worker threads are kept in a pool and poll the Queue, popping off and then +handling each connection in turn. Each connection can consist of an arbitrary +number of requests and their responses, so we run a nested loop:: + + while True: + conn = server.requests.get() + conn.communicate() + -> while True: + req = HTTPRequest(...) + req.parse_request() + -> # Read the Request-Line, e.g. "GET /page HTTP/1.1" + req.rfile.readline() + read_headers(req.rfile, req.inheaders) + req.respond() + -> response = app(...) + try: + for chunk in response: + if chunk: + req.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + if req.close_connection: + return + +For running a server you can invoke :func:`start() ` (it +will run the server forever) or use invoking :func:`prepare() +` and :func:`serve() ` like this:: + + server = HTTPServer(...) + server.prepare() + try: + threading.Thread(target=server.serve).start() + + # waiting/detecting some appropriate stop condition here + ... + + finally: + server.stop() + +And now for a trivial doctest to exercise the test suite + +>>> 'HTTPServer' in globals() +True +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import os +import io +import re +import email.utils +import socket +import sys +import time +import traceback as traceback_ +import logging +import platform +import contextlib +import threading + +try: + from functools import lru_cache +except ImportError: + from backports.functools_lru_cache import lru_cache + +import six +from six.moves import queue +from six.moves import urllib + +from . import connections, errors, __version__ +from ._compat import bton, ntou +from ._compat import IS_PPC +from .workers import threadpool +from .makefile import MakeFile, StreamWriter + + +__all__ = ( + 'HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'HeaderReader', 'DropUnderscoreHeaderReader', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'Gateway', 'get_ssl_adapter_class', +) + + +IS_WINDOWS = platform.system() == 'Windows' +"""Flag indicating whether the app is running under Windows.""" + + +IS_GAE = os.getenv('SERVER_SOFTWARE', '').startswith('Google App Engine/') +"""Flag indicating whether the app is running in GAE env. 
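
The server.py module docstring above recommends the prepare()/serve()/stop() pattern over start() when embedding the server; a minimal sketch of that pattern (the WSGI app is a stand-in, and the vendored package is assumed importable as `cheroot`):

import threading
from cheroot import wsgi

def app(environ, start_response):
    start_response('204 No Content', [])
    return []

server = wsgi.Server(('127.0.0.1', 0), app)  # port 0: let the OS pick a free port
server.prepare()
threading.Thread(target=server.serve).start()
# ... wait for or detect an appropriate stop condition here ...
server.stop()
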
+ +Ref: +https://cloud.google.com/appengine/docs/standard/python/tools +/using-local-server#detecting_application_runtime_environment +""" + + +IS_UID_GID_RESOLVABLE = not IS_WINDOWS and not IS_GAE +"""Indicates whether UID/GID resolution's available under current platform.""" + + +if IS_UID_GID_RESOLVABLE: + try: + import grp + import pwd + except ImportError: + """Unavailable in the current env. + + This shouldn't be happening normally. + All of the known cases are excluded via the if clause. + """ + IS_UID_GID_RESOLVABLE = False + grp, pwd = None, None + import struct + + +if IS_WINDOWS and hasattr(socket, 'AF_INET6'): + if not hasattr(socket, 'IPPROTO_IPV6'): + socket.IPPROTO_IPV6 = 41 + if not hasattr(socket, 'IPV6_V6ONLY'): + socket.IPV6_V6ONLY = 27 + + +if not hasattr(socket, 'SO_PEERCRED'): + """ + NOTE: the value for SO_PEERCRED can be architecture specific, in + which case the getsockopt() will hopefully fail. The arch + specific value could be derived from platform.processor() + """ + socket.SO_PEERCRED = 21 if IS_PPC else 17 + + +LF = b'\n' +CRLF = b'\r\n' +TAB = b'\t' +SPACE = b' ' +COLON = b':' +SEMICOLON = b';' +EMPTY = b'' +ASTERISK = b'*' +FORWARD_SLASH = b'/' +QUOTED_SLASH = b'%2F' +QUOTED_SLASH_REGEX = re.compile(b''.join((b'(?i)', QUOTED_SLASH))) + + +_STOPPING_FOR_INTERRUPT = object() # sentinel used during shutdown + + +comma_separated_headers = [ + b'Accept', b'Accept-Charset', b'Accept-Encoding', + b'Accept-Language', b'Accept-Ranges', b'Allow', b'Cache-Control', + b'Connection', b'Content-Encoding', b'Content-Language', b'Expect', + b'If-Match', b'If-None-Match', b'Pragma', b'Proxy-Authenticate', b'TE', + b'Trailer', b'Transfer-Encoding', b'Upgrade', b'Vary', b'Via', b'Warning', + b'WWW-Authenticate', +] + + +if not hasattr(logging, 'statistics'): + logging.statistics = {} + + +class HeaderReader: + """Object for reading headers from an HTTP request. + + Interface and default implementation. + """ + + def __call__(self, rfile, hdict=None): # noqa: C901 # FIXME + """ + Read headers from the given stream into the given header dict. + + If hdict is None, a new header dict is created. Returns the populated + header dict. + + Headers which are repeated are folded together using a comma if their + specification so dictates. + + This function raises ValueError when the read bytes violate the HTTP + spec. + You should probably return "400 Bad Request" if this happens. + """ + if hdict is None: + hdict = {} + + while True: + line = rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError('Illegal end of headers.') + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError('HTTP requires CRLF terminators') + + if line[0] in (SPACE, TAB): + # It's a continuation line. + v = line.strip() + else: + try: + k, v = line.split(COLON, 1) + except ValueError: + raise ValueError('Illegal header line.') + v = v.strip() + k = self._transform_key(k) + hname = k + + if not self._allow_header(k): + continue + + if k in comma_separated_headers: + existing = hdict.get(hname) + if existing: + v = b', '.join((existing, v)) + hdict[hname] = v + + return hdict + + def _allow_header(self, key_name): + return True + + def _transform_key(self, key_name): + # TODO: what about TE and WWW-Authenticate? 
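
HeaderReader above only needs a readline() that yields CRLF-terminated byte lines, so its folding of comma-separated headers can be observed with io.BytesIO (a sketch, vendored import path assumed):

import io
from cheroot.server import HeaderReader

rfile = io.BytesIO(b'Accept: text/html\r\nAccept: application/json\r\nX-One: 1\r\n\r\n')
print(HeaderReader()(rfile))
# {b'Accept': b'text/html, application/json', b'X-One': b'1'}
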
+ return key_name.strip().title() + + +class DropUnderscoreHeaderReader(HeaderReader): + """Custom HeaderReader to exclude any headers with underscores in them.""" + + def _allow_header(self, key_name): + orig = super(DropUnderscoreHeaderReader, self)._allow_header(key_name) + return orig and '_' not in key_name + + +class SizeCheckWrapper: + """Wraps a file-like object, raising MaxSizeExceeded if too large. + + :param rfile: ``file`` of a limited size + :param int maxlen: maximum length of the file being read + """ + + def __init__(self, rfile, maxlen): + """Initialize SizeCheckWrapper instance.""" + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + + def _check_length(self): + if self.maxlen and self.bytes_read > self.maxlen: + raise errors.MaxSizeExceeded() + + def read(self, size=None): + """Read a chunk from ``rfile`` buffer and return it. + + :param int size: amount of data to read + + :returns: chunk from ``rfile``, limited by size if specified + :rtype: bytes + """ + data = self.rfile.read(size) + self.bytes_read += len(data) + self._check_length() + return data + + def readline(self, size=None): + """Read a single line from ``rfile`` buffer and return it. + + :param int size: minimum amount of data to read + + :returns: one line from ``rfile`` + :rtype: bytes + """ + if size is not None: + data = self.rfile.readline(size) + self.bytes_read += len(data) + self._check_length() + return data + + # User didn't specify a size ... + # We read the line in chunks to make sure it's not a 100MB line ! + res = [] + while True: + data = self.rfile.readline(256) + self.bytes_read += len(data) + self._check_length() + res.append(data) + # See https://github.com/cherrypy/cherrypy/issues/421 + if len(data) < 256 or data[-1:] == LF: + return EMPTY.join(res) + + def readlines(self, sizehint=0): + """Read all lines from ``rfile`` buffer and return them. + + :param int sizehint: hint of minimum amount of data to read + + :returns: lines of bytes read from ``rfile`` + :rtype: list[bytes] + """ + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + """Release resources allocated for ``rfile``.""" + self.rfile.close() + + def __iter__(self): + """Return file iterator.""" + return self + + def __next__(self): + """Generate next file chunk.""" + data = next(self.rfile) + self.bytes_read += len(data) + self._check_length() + return data + + next = __next__ + + +class KnownLengthRFile: + """Wraps a file-like object, returning an empty string when exhausted. + + :param rfile: ``file`` of a known size + :param int content_length: length of the file being read + """ + + def __init__(self, rfile, content_length): + """Initialize KnownLengthRFile instance.""" + self.rfile = rfile + self.remaining = content_length + + def read(self, size=None): + """Read a chunk from ``rfile`` buffer and return it. + + :param int size: amount of data to read + + :rtype: bytes + :returns: chunk from ``rfile``, limited by size if specified + """ + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.read(size) + self.remaining -= len(data) + return data + + def readline(self, size=None): + """Read a single line from ``rfile`` buffer and return it. 
+ + :param int size: minimum amount of data to read + + :returns: one line from ``rfile`` + :rtype: bytes + """ + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.readline(size) + self.remaining -= len(data) + return data + + def readlines(self, sizehint=0): + """Read all lines from ``rfile`` buffer and return them. + + :param int sizehint: hint of minimum amount of data to read + + :returns: lines of bytes read from ``rfile`` + :rtype: list[bytes] + """ + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + """Release resources allocated for ``rfile``.""" + self.rfile.close() + + def __iter__(self): + """Return file iterator.""" + return self + + def __next__(self): + """Generate next file chunk.""" + data = next(self.rfile) + self.remaining -= len(data) + return data + + next = __next__ + + +class ChunkedRFile: + """Wraps a file-like object, returning an empty string when exhausted. + + This class is intended to provide a conforming wsgi.input value for + request entities that have been encoded with the 'chunked' transfer + encoding. + + :param rfile: file encoded with the 'chunked' transfer encoding + :param int maxlen: maximum length of the file being read + :param int bufsize: size of the buffer used to read the file + """ + + def __init__(self, rfile, maxlen, bufsize=8192): + """Initialize ChunkedRFile instance.""" + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + self.buffer = EMPTY + self.bufsize = bufsize + self.closed = False + + def _fetch(self): + if self.closed: + return + + line = self.rfile.readline() + self.bytes_read += len(line) + + if self.maxlen and self.bytes_read > self.maxlen: + raise errors.MaxSizeExceeded( + 'Request Entity Too Large', self.maxlen, + ) + + line = line.strip().split(SEMICOLON, 1) + + try: + chunk_size = line.pop(0) + chunk_size = int(chunk_size, 16) + except ValueError: + raise ValueError( + 'Bad chunked transfer size: {chunk_size!r}'. + format(chunk_size=chunk_size), + ) + + if chunk_size <= 0: + self.closed = True + return + +# if line: chunk_extension = line[0] + + if self.maxlen and self.bytes_read + chunk_size > self.maxlen: + raise IOError('Request Entity Too Large') + + chunk = self.rfile.read(chunk_size) + self.bytes_read += len(chunk) + self.buffer += chunk + + crlf = self.rfile.read(2) + if crlf != CRLF: + raise ValueError( + "Bad chunked transfer coding (expected '\\r\\n', " + 'got ' + repr(crlf) + ')', + ) + + def read(self, size=None): + """Read a chunk from ``rfile`` buffer and return it. + + :param int size: amount of data to read + + :returns: chunk from ``rfile``, limited by size if specified + :rtype: bytes + """ + data = EMPTY + + if size == 0: + return data + + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + if size: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + data += self.buffer + self.buffer = EMPTY + + def readline(self, size=None): + """Read a single line from ``rfile`` buffer and return it. 
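
KnownLengthRFile above wraps any file-like object and, per its docstring, returns empty bytes once content_length bytes are consumed instead of reading on; a stand-alone sketch (vendored import path assumed):

import io
from cheroot.server import KnownLengthRFile

body = KnownLengthRFile(io.BytesIO(b'hello world...trailing junk'), 11)
print(body.read())   # b'hello world'
print(body.read(5))  # b'' -- exhausted; the wrapper never reads past its length
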
+
+        :param int size: minimum amount of data to read
+
+        :returns: one line from ``rfile``
+        :rtype: bytes
+        """
+        data = EMPTY
+
+        if size == 0:
+            return data
+
+        while True:
+            if size and len(data) >= size:
+                return data
+
+            if not self.buffer:
+                self._fetch()
+                if not self.buffer:
+                    # EOF
+                    return data
+
+            newline_pos = self.buffer.find(LF)
+            if size:
+                if newline_pos == -1:
+                    remaining = size - len(data)
+                    data += self.buffer[:remaining]
+                    self.buffer = self.buffer[remaining:]
+                else:
+                    # Consume up to and including the newline (within the
+                    # caller's size limit) and return the completed line.
+                    remaining = min(size - len(data), newline_pos + 1)
+                    data += self.buffer[:remaining]
+                    self.buffer = self.buffer[remaining:]
+                    if remaining == newline_pos + 1:
+                        return data
+            else:
+                if newline_pos == -1:
+                    data += self.buffer
+                    self.buffer = EMPTY
+                else:
+                    # Consume up to and including the newline and return
+                    # the completed line.
+                    data += self.buffer[:newline_pos + 1]
+                    self.buffer = self.buffer[newline_pos + 1:]
+                    return data
+
+    def readlines(self, sizehint=0):
+        """Read all lines from ``rfile`` buffer and return them.
+
+        :param int sizehint: hint of minimum amount of data to read
+
+        :returns: lines of bytes read from ``rfile``
+        :rtype: list[bytes]
+        """
+        # Shamelessly stolen from StringIO
+        total = 0
+        lines = []
+        line = self.readline(sizehint)
+        while line:
+            lines.append(line)
+            total += len(line)
+            if 0 < sizehint <= total:
+                break
+            line = self.readline(sizehint)
+        return lines
+
+    def read_trailer_lines(self):
+        """Read HTTP headers and yield them.
+
+        Returns:
+            Generator: yields CRLF separated lines.
+
+        """
+        if not self.closed:
+            raise ValueError(
+                'Cannot read trailers until the request body has been read.',
+            )
+
+        while True:
+            line = self.rfile.readline()
+            if not line:
+                # No more data--illegal end of headers
+                raise ValueError('Illegal end of headers.')
+
+            self.bytes_read += len(line)
+            if self.maxlen and self.bytes_read > self.maxlen:
+                raise IOError('Request Entity Too Large')
+
+            if line == CRLF:
+                # Normal end of headers
+                break
+            if not line.endswith(CRLF):
+                raise ValueError('HTTP requires CRLF terminators')
+
+            yield line
+
+    def close(self):
+        """Release resources allocated for ``rfile``."""
+        self.rfile.close()
+
+
+class HTTPRequest:
+    """An HTTP Request (and response).
+
+    A single HTTP connection may consist of multiple request/response pairs.
+    """
+
+    server = None
+    """The HTTPServer object which is receiving this request."""
+
+    conn = None
+    """The HTTPConnection object on which this request connected."""
+
+    inheaders = {}
+    """A dict of request headers."""
+
+    outheaders = []
+    """A list of header tuples to write in the response."""
+
+    ready = False
+    """When True, the request has been parsed and is ready to begin generating
+    the response. When False, signals the calling Connection that the response
+    should not be generated and the connection should close."""
+
+    close_connection = False
+    """Signals the calling Connection that the request should close. This does
+    not imply an error! The client and/or server may each request that the
+    connection be closed."""
+
+    chunked_write = False
+    """If True, output will be encoded with the "chunked" transfer-coding.
+
+    This value is set automatically inside send_headers."""
+
+    header_reader = HeaderReader()
+    """
+    A HeaderReader instance or compatible reader.
+    """
+
+    def __init__(self, server, conn, proxy_mode=False, strict_mode=True):
+        """Initialize HTTP request container instance.
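+
+        The new instance is inert until ``parse_request()`` is called to
+        read the request line and headers from the connection.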
+ + Args: + server (HTTPServer): web server object receiving this request + conn (HTTPConnection): HTTP connection object for this request + proxy_mode (bool): whether this HTTPServer should behave as a PROXY + server for certain requests + strict_mode (bool): whether we should return a 400 Bad Request when + we encounter a request that a HTTP compliant client should not be + making + """ + self.server = server + self.conn = conn + + self.ready = False + self.started_request = False + self.scheme = b'http' + if self.server.ssl_adapter is not None: + self.scheme = b'https' + # Use the lowest-common protocol in case read_request_line errors. + self.response_protocol = 'HTTP/1.0' + self.inheaders = {} + + self.status = '' + self.outheaders = [] + self.sent_headers = False + self.close_connection = self.__class__.close_connection + self.chunked_read = False + self.chunked_write = self.__class__.chunked_write + self.proxy_mode = proxy_mode + self.strict_mode = strict_mode + + def parse_request(self): + """Parse the next HTTP request start-line and message-headers.""" + self.rfile = SizeCheckWrapper( + self.conn.rfile, + self.server.max_request_header_size, + ) + try: + success = self.read_request_line() + except errors.MaxSizeExceeded: + self.simple_response( + '414 Request-URI Too Long', + 'The Request-URI sent with the request exceeds the maximum ' + 'allowed bytes.', + ) + return + else: + if not success: + return + + try: + success = self.read_request_headers() + except errors.MaxSizeExceeded: + self.simple_response( + '413 Request Entity Too Large', + 'The headers sent with the request exceed the maximum ' + 'allowed bytes.', + ) + return + else: + if not success: + return + + self.ready = True + + def read_request_line(self): # noqa: C901 # FIXME + """Read and parse first line of the HTTP request. + + Returns: + bool: True if the request line is valid or False if it's malformed. + + """ + # HTTP/1.1 connections are persistent by default. If a client + # requests a page, then idles (leaves the connection open), + # then rfile.readline() will raise socket.error("timed out"). + # Note that it does this based on the value given to settimeout(), + # and doesn't need the client to request or acknowledge the close + # (although your TCP stack might suffer for it: cf Apache's history + # with FIN_WAIT_2). + request_line = self.rfile.readline() + + # Set started_request to True so communicate() knows to send 408 + # from here on out. + self.started_request = True + if not request_line: + return False + + if request_line == CRLF: + # RFC 2616 sec 4.1: "...if the server is reading the protocol + # stream at the beginning of a message and receives a CRLF + # first, it should ignore the CRLF." + # But only ignore one leading line! else we enable a DoS. 
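+            # (A client could otherwise stall a worker indefinitely by
+            # streaming nothing but bare CRLFs.)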
+            request_line = self.rfile.readline()
+            if not request_line:
+                return False
+
+        if not request_line.endswith(CRLF):
+            self.simple_response(
+                '400 Bad Request', 'HTTP requires CRLF terminators',
+            )
+            return False
+
+        try:
+            method, uri, req_protocol = request_line.strip().split(SPACE, 2)
+            if not req_protocol.startswith(b'HTTP/'):
+                self.simple_response(
+                    '400 Bad Request', 'Malformed Request-Line: bad protocol',
+                )
+                return False
+            rp = req_protocol[5:].split(b'.', 1)
+            if len(rp) != 2:
+                self.simple_response(
+                    '400 Bad Request', 'Malformed Request-Line: bad version',
+                )
+                return False
+            rp = tuple(map(int, rp))  # Major.Minor must be treated as integers
+            if rp > (1, 1):
+                self.simple_response(
+                    '505 HTTP Version Not Supported', 'Cannot fulfill request',
+                )
+                return False
+        except (ValueError, IndexError):
+            self.simple_response('400 Bad Request', 'Malformed Request-Line')
+            return False
+
+        self.uri = uri
+        self.method = method.upper()
+
+        if self.strict_mode and method != self.method:
+            resp = (
+                'Malformed method name: According to RFC 2616 '
+                '(section 5.1.1) and its successors '
+                'RFC 7230 (section 3.1.1) and RFC 7231 (section 4.1) '
+                'method names are case-sensitive and uppercase.'
+            )
+            self.simple_response('400 Bad Request', resp)
+            return False
+
+        try:
+            if six.PY2:  # FIXME: Figure out better way to do this
+                # Ref: https://stackoverflow.com/a/196392/595220 (like this?)
+                """This is a dummy check for unicode in URI."""
+                ntou(bton(uri, 'ascii'), 'ascii')
+            scheme, authority, path, qs, fragment = urllib.parse.urlsplit(uri)
+        except UnicodeError:
+            self.simple_response('400 Bad Request', 'Malformed Request-URI')
+            return False
+
+        uri_is_absolute_form = (scheme or authority)
+
+        if self.method == b'OPTIONS':
+            # TODO: cover this branch with tests
+            path = (
+                uri
+                # https://tools.ietf.org/html/rfc7230#section-5.3.4
+                if (self.proxy_mode and uri_is_absolute_form)
+                else path
+            )
+        elif self.method == b'CONNECT':
+            # TODO: cover this branch with tests
+            if not self.proxy_mode:
+                self.simple_response('405 Method Not Allowed')
+                return False
+
+            # `urlsplit()` above parses "example.com:3128" as path part of URI.
+ # this is a workaround, which makes it detect netloc correctly + uri_split = urllib.parse.urlsplit(b''.join((b'//', uri))) + _scheme, _authority, _path, _qs, _fragment = uri_split + _port = EMPTY + try: + _port = uri_split.port + except ValueError: + pass + + # FIXME: use third-party validation to make checks against RFC + # the validation doesn't take into account, that urllib parses + # invalid URIs without raising errors + # https://tools.ietf.org/html/rfc7230#section-5.3.3 + invalid_path = ( + _authority != uri + or not _port + or any((_scheme, _path, _qs, _fragment)) + ) + if invalid_path: + self.simple_response( + '400 Bad Request', + 'Invalid path in Request-URI: request-' + 'target must match authority-form.', + ) + return False + + authority = path = _authority + scheme = qs = fragment = EMPTY + else: + disallowed_absolute = ( + self.strict_mode + and not self.proxy_mode + and uri_is_absolute_form + ) + if disallowed_absolute: + # https://tools.ietf.org/html/rfc7230#section-5.3.2 + # (absolute form) + """Absolute URI is only allowed within proxies.""" + self.simple_response( + '400 Bad Request', + 'Absolute URI not allowed if server is not a proxy.', + ) + return False + + invalid_path = ( + self.strict_mode + and not uri.startswith(FORWARD_SLASH) + and not uri_is_absolute_form + ) + if invalid_path: + # https://tools.ietf.org/html/rfc7230#section-5.3.1 + # (origin_form) and + """Path should start with a forward slash.""" + resp = ( + 'Invalid path in Request-URI: request-target must contain ' + 'origin-form which starts with absolute-path (URI ' + 'starting with a slash "/").' + ) + self.simple_response('400 Bad Request', resp) + return False + + if fragment: + self.simple_response( + '400 Bad Request', + 'Illegal #fragment in Request-URI.', + ) + return False + + if path is None: + # FIXME: It looks like this case cannot happen + self.simple_response( + '400 Bad Request', + 'Invalid path in Request-URI.', + ) + return False + + # Unquote the path+params (e.g. "/this%20path" -> "/this path"). + # https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 + # + # But note that "...a URI must be separated into its components + # before the escaped characters within those components can be + # safely decoded." https://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 + # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not + # "/this/path". + try: + # TODO: Figure out whether exception can really happen here. + # It looks like it's caught on urlsplit() call above. + atoms = [ + urllib.parse.unquote_to_bytes(x) + for x in QUOTED_SLASH_REGEX.split(path) + ] + except ValueError as ex: + self.simple_response('400 Bad Request', ex.args[0]) + return False + path = QUOTED_SLASH.join(atoms) + + if not path.startswith(FORWARD_SLASH): + path = FORWARD_SLASH + path + + if scheme is not EMPTY: + self.scheme = scheme + self.authority = authority + self.path = path + + # Note that, like wsgiref and most other HTTP servers, + # we "% HEX HEX"-unquote the path but not the query string. + self.qs = qs + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. 
RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + sp = int(self.server.protocol[5]), int(self.server.protocol[7]) + + if sp[0] != rp[0]: + self.simple_response('505 HTTP Version Not Supported') + return False + + self.request_protocol = req_protocol + self.response_protocol = 'HTTP/%s.%s' % min(rp, sp) + + return True + + def read_request_headers(self): # noqa: C901 # FIXME + """Read ``self.rfile`` into ``self.inheaders``. + + Ref: :py:attr:`self.inheaders `. + + :returns: success status + :rtype: bool + """ + # then all the http headers + try: + self.header_reader(self.rfile, self.inheaders) + except ValueError as ex: + self.simple_response('400 Bad Request', ex.args[0]) + return False + + mrbs = self.server.max_request_body_size + + try: + cl = int(self.inheaders.get(b'Content-Length', 0)) + except ValueError: + self.simple_response( + '400 Bad Request', + 'Malformed Content-Length Header.', + ) + return False + + if mrbs and cl > mrbs: + self.simple_response( + '413 Request Entity Too Large', + 'The entity sent with the request exceeds the maximum ' + 'allowed bytes.', + ) + return False + + # Persistent connection support + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 + if self.inheaders.get(b'Connection', b'') == b'close': + self.close_connection = True + else: + # Either the server or client (or both) are HTTP/1.0 + if self.inheaders.get(b'Connection', b'') != b'Keep-Alive': + self.close_connection = True + + # Transfer-Encoding support + te = None + if self.response_protocol == 'HTTP/1.1': + te = self.inheaders.get(b'Transfer-Encoding') + if te: + te = [x.strip().lower() for x in te.split(b',') if x.strip()] + + self.chunked_read = False + + if te: + for enc in te: + if enc == b'chunked': + self.chunked_read = True + else: + # Note that, even if we see "chunked", we must reject + # if there is an extension we don't recognize. + self.simple_response('501 Unimplemented') + self.close_connection = True + return False + + # From PEP 333: + # "Servers and gateways that implement HTTP 1.1 must provide + # transparent support for HTTP 1.1's "expect/continue" mechanism. + # This may be done in any of several ways: + # 1. Respond to requests containing an Expect: 100-continue request + # with an immediate "100 Continue" response, and proceed normally. + # 2. Proceed with the request normally, but provide the application + # with a wsgi.input stream that will send the "100 Continue" + # response if/when the application first attempts to read from + # the input stream. The read request must then remain blocked + # until the client responds. + # 3. Wait until the client decides that the server does not support + # expect/continue, and sends the request body on its own. + # (This is suboptimal, and is not recommended.) + # + # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, + # but it seems like it would be a big slowdown for such a rare case. + if self.inheaders.get(b'Expect', b'') == b'100-continue': + # Don't use simple_response here, because it emits headers + # we don't want. 
See + # https://github.com/cherrypy/cherrypy/issues/951 + msg = b''.join(( + self.server.protocol.encode('ascii'), SPACE, b'100 Continue', + CRLF, CRLF, + )) + try: + self.conn.wfile.write(msg) + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + return True + + def respond(self): + """Call the gateway and write its iterable output.""" + mrbs = self.server.max_request_body_size + if self.chunked_read: + self.rfile = ChunkedRFile(self.conn.rfile, mrbs) + else: + cl = int(self.inheaders.get(b'Content-Length', 0)) + if mrbs and mrbs < cl: + if not self.sent_headers: + self.simple_response( + '413 Request Entity Too Large', + 'The entity sent with the request exceeds the ' + 'maximum allowed bytes.', + ) + return + self.rfile = KnownLengthRFile(self.conn.rfile, cl) + + self.server.gateway(self).respond() + self.ready and self.ensure_headers_sent() + + if self.chunked_write: + self.conn.wfile.write(b'0\r\n\r\n') + + def simple_response(self, status, msg=''): + """Write a simple response back to the client.""" + status = str(status) + proto_status = '%s %s\r\n' % (self.server.protocol, status) + content_length = 'Content-Length: %s\r\n' % len(msg) + content_type = 'Content-Type: text/plain\r\n' + buf = [ + proto_status.encode('ISO-8859-1'), + content_length.encode('ISO-8859-1'), + content_type.encode('ISO-8859-1'), + ] + + if status[:3] in ('413', '414'): + # Request Entity Too Large / Request-URI Too Long + self.close_connection = True + if self.response_protocol == 'HTTP/1.1': + # This will not be true for 414, since read_request_line + # usually raises 414 before reading the whole line, and we + # therefore cannot know the proper response_protocol. + buf.append(b'Connection: close\r\n') + else: + # HTTP/1.0 had no 413/414 status nor Connection header. + # Emit 400 instead and trust the message body is enough. + status = '400 Bad Request' + + buf.append(CRLF) + if msg: + if isinstance(msg, six.text_type): + msg = msg.encode('ISO-8859-1') + buf.append(msg) + + try: + self.conn.wfile.write(EMPTY.join(buf)) + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + + def ensure_headers_sent(self): + """Ensure headers are sent to the client if not already sent.""" + if not self.sent_headers: + self.sent_headers = True + self.send_headers() + + def write(self, chunk): + """Write unbuffered data to the client.""" + if self.chunked_write and chunk: + chunk_size_hex = hex(len(chunk))[2:].encode('ascii') + buf = [chunk_size_hex, CRLF, chunk, CRLF] + self.conn.wfile.write(EMPTY.join(buf)) + else: + self.conn.wfile.write(chunk) + + def send_headers(self): # noqa: C901 # FIXME + """Assert, process, and send the HTTP response message-headers. + + You must set ``self.status``, and :py:attr:`self.outheaders + ` before calling this. + """ + hkeys = [key.lower() for key, value in self.outheaders] + status = int(self.status[:3]) + + if status == 413: + # Request Entity Too Large. Close conn to avoid garbage. + self.close_connection = True + elif b'content-length' not in hkeys: + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." So no point chunking. 
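+            # (205 Reset Content is excluded below as well; its payload
+            # is likewise required to be empty.)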
+ if status < 200 or status in (204, 205, 304): + pass + else: + needs_chunked = ( + self.response_protocol == 'HTTP/1.1' + and self.method != b'HEAD' + ) + if needs_chunked: + # Use the chunked transfer-coding + self.chunked_write = True + self.outheaders.append((b'Transfer-Encoding', b'chunked')) + else: + # Closing the conn is the only way to determine len. + self.close_connection = True + + # Override the decision to not close the connection if the connection + # manager doesn't have space for it. + if not self.close_connection: + can_keep = self.server.can_add_keepalive_connection + self.close_connection = not can_keep + + if b'connection' not in hkeys: + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 or better + if self.close_connection: + self.outheaders.append((b'Connection', b'close')) + else: + # Server and/or client are HTTP/1.0 + if not self.close_connection: + self.outheaders.append((b'Connection', b'Keep-Alive')) + + if (b'Connection', b'Keep-Alive') in self.outheaders: + self.outheaders.append(( + b'Keep-Alive', + u'timeout={connection_timeout}'. + format(connection_timeout=self.server.timeout). + encode('ISO-8859-1'), + )) + + if (not self.close_connection) and (not self.chunked_read): + # Read any remaining request body data on the socket. + # "If an origin server receives a request that does not include an + # Expect request-header field with the "100-continue" expectation, + # the request includes a request body, and the server responds + # with a final status code before reading the entire request body + # from the transport connection, then the server SHOULD NOT close + # the transport connection until it has read the entire request, + # or until the client closes the connection. Otherwise, the client + # might not reliably receive the response message. However, this + # requirement is not be construed as preventing a server from + # defending itself against denial-of-service attacks, or from + # badly broken client implementations." + remaining = getattr(self.rfile, 'remaining', 0) + if remaining > 0: + self.rfile.read(remaining) + + if b'date' not in hkeys: + self.outheaders.append(( + b'Date', + email.utils.formatdate(usegmt=True).encode('ISO-8859-1'), + )) + + if b'server' not in hkeys: + self.outheaders.append(( + b'Server', + self.server.server_name.encode('ISO-8859-1'), + )) + + proto = self.server.protocol.encode('ascii') + buf = [proto + SPACE + self.status + CRLF] + for k, v in self.outheaders: + buf.append(k + COLON + SPACE + v + CRLF) + buf.append(CRLF) + self.conn.wfile.write(EMPTY.join(buf)) + + +class HTTPConnection: + """An HTTP connection (active socket).""" + + remote_addr = None + remote_port = None + ssl_env = None + rbufsize = io.DEFAULT_BUFFER_SIZE + wbufsize = io.DEFAULT_BUFFER_SIZE + RequestHandlerClass = HTTPRequest + peercreds_enabled = False + peercreds_resolve_enabled = False + + # Fields set by ConnectionManager. + last_used = None + + def __init__(self, server, sock, makefile=MakeFile): + """Initialize HTTPConnection instance. 
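+
+        The connection wraps the raw socket in buffered reader/writer
+        file objects and keeps a per-connection request counter.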
+ + Args: + server (HTTPServer): web server object receiving this request + sock (socket._socketobject): the raw socket object (usually + TCP) for this connection + makefile (file): a fileobject class for reading from the socket + """ + self.server = server + self.socket = sock + self.rfile = makefile(sock, 'rb', self.rbufsize) + self.wfile = makefile(sock, 'wb', self.wbufsize) + self.requests_seen = 0 + + self.peercreds_enabled = self.server.peercreds_enabled + self.peercreds_resolve_enabled = self.server.peercreds_resolve_enabled + + # LRU cached methods: + # Ref: https://stackoverflow.com/a/14946506/595220 + self.resolve_peer_creds = ( + lru_cache(maxsize=1)(self.resolve_peer_creds) + ) + self.get_peer_creds = ( + lru_cache(maxsize=1)(self.get_peer_creds) + ) + + def communicate(self): # noqa: C901 # FIXME + """Read each request and respond appropriately. + + Returns true if the connection should be kept open. + """ + request_seen = False + try: + req = self.RequestHandlerClass(self.server, self) + req.parse_request() + if self.server.stats['Enabled']: + self.requests_seen += 1 + if not req.ready: + # Something went wrong in the parsing (and the server has + # probably already made a simple_response). Return and + # let the conn close. + return False + + request_seen = True + req.respond() + if not req.close_connection: + return True + except socket.error as ex: + errnum = ex.args[0] + # sadly SSL sockets return a different (longer) time out string + timeout_errs = 'timed out', 'The read operation timed out' + if errnum in timeout_errs: + # Don't error if we're between requests; only error + # if 1) no request has been started at all, or 2) we're + # in the middle of a request. + # See https://github.com/cherrypy/cherrypy/issues/853 + if (not request_seen) or (req and req.started_request): + self._conditional_error(req, '408 Request Timeout') + elif errnum not in errors.socket_errors_to_ignore: + self.server.error_log( + 'socket.error %s' % repr(errnum), + level=logging.WARNING, traceback=True, + ) + self._conditional_error(req, '500 Internal Server Error') + except (KeyboardInterrupt, SystemExit): + raise + except errors.FatalSSLAlert: + pass + except errors.NoSSLError: + self._handle_no_ssl(req) + except Exception as ex: + self.server.error_log( + repr(ex), level=logging.ERROR, traceback=True, + ) + self._conditional_error(req, '500 Internal Server Error') + return False + + linger = False + + def _handle_no_ssl(self, req): + if not req or req.sent_headers: + return + # Unwrap wfile + try: + resp_sock = self.socket._sock + except AttributeError: + # self.socket is of OpenSSL.SSL.Connection type + resp_sock = self.socket._socket + self.wfile = StreamWriter(resp_sock, 'wb', self.wbufsize) + msg = ( + 'The client sent a plain HTTP request, but ' + 'this server only speaks HTTPS on this port.' + ) + req.simple_response('400 Bad Request', msg) + self.linger = True + + def _conditional_error(self, req, response): + """Respond with an error. + + Don't bother writing if a response + has already started being written. 
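+
+        SSL-layer failures raised while writing the error response are
+        handled here as well (FatalSSLAlert, NoSSLError), mirroring the
+        main communicate() flow.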
+ """ + if not req or req.sent_headers: + return + + try: + req.simple_response(response) + except errors.FatalSSLAlert: + pass + except errors.NoSSLError: + self._handle_no_ssl(req) + + def close(self): + """Close the socket underlying this connection.""" + self.rfile.close() + + if not self.linger: + self._close_kernel_socket() + # close the socket file descriptor + # (will be closed in the OS if there is no + # other reference to the underlying socket) + self.socket.close() + else: + # On the other hand, sometimes we want to hang around for a bit + # to make sure the client has a chance to read our entire + # response. Skipping the close() calls here delays the FIN + # packet until the socket object is garbage-collected later. + # Someday, perhaps, we'll do the full lingering_close that + # Apache does, but not today. + pass + + def get_peer_creds(self): # LRU cached on per-instance basis, see __init__ + """Return the PID/UID/GID tuple of the peer socket for UNIX sockets. + + This function uses SO_PEERCRED to query the UNIX PID, UID, GID + of the peer, which is only available if the bind address is + a UNIX domain socket. + + Raises: + NotImplementedError: in case of unsupported socket type + RuntimeError: in case of SO_PEERCRED lookup unsupported or disabled + + """ + PEERCRED_STRUCT_DEF = '3i' + + if IS_WINDOWS or self.socket.family != socket.AF_UNIX: + raise NotImplementedError( + 'SO_PEERCRED is only supported in Linux kernel and WSL', + ) + elif not self.peercreds_enabled: + raise RuntimeError( + 'Peer creds lookup is disabled within this server', + ) + + try: + peer_creds = self.socket.getsockopt( + # FIXME: Use LOCAL_CREDS for BSD-like OSs + # Ref: https://gist.github.com/LucaFilipozzi/e4f1e118202aff27af6aadebda1b5d91 # noqa + socket.SOL_SOCKET, socket.SO_PEERCRED, + struct.calcsize(PEERCRED_STRUCT_DEF), + ) + except socket.error as socket_err: + """Non-Linux kernels don't support SO_PEERCRED. + + Refs: + http://welz.org.za/notes/on-peer-cred.html + https://github.com/daveti/tcpSockHack + msdn.microsoft.com/en-us/commandline/wsl/release_notes#build-15025 + """ + six.raise_from( # 3.6+: raise RuntimeError from socket_err + RuntimeError, + socket_err, + ) + else: + pid, uid, gid = struct.unpack(PEERCRED_STRUCT_DEF, peer_creds) + return pid, uid, gid + + @property + def peer_pid(self): + """Return the id of the connected peer process.""" + pid, _, _ = self.get_peer_creds() + return pid + + @property + def peer_uid(self): + """Return the user id of the connected peer process.""" + _, uid, _ = self.get_peer_creds() + return uid + + @property + def peer_gid(self): + """Return the group id of the connected peer process.""" + _, _, gid = self.get_peer_creds() + return gid + + def resolve_peer_creds(self): # LRU cached on per-instance basis + """Look up the username and group tuple of the ``PEERCREDS``. + + :returns: the username and group tuple of the ``PEERCREDS`` + + :raises NotImplementedError: if the OS is unsupported + :raises RuntimeError: if UID/GID lookup is unsupported or disabled + """ + if not IS_UID_GID_RESOLVABLE: + raise NotImplementedError( + 'UID/GID lookup is unavailable under current platform. 
' + 'It can only be done under UNIX-like OS ' + 'but not under the Google App Engine', + ) + elif not self.peercreds_resolve_enabled: + raise RuntimeError( + 'UID/GID lookup is disabled within this server', + ) + + user = pwd.getpwuid(self.peer_uid).pw_name # [0] + group = grp.getgrgid(self.peer_gid).gr_name # [0] + + return user, group + + @property + def peer_user(self): + """Return the username of the connected peer process.""" + user, _ = self.resolve_peer_creds() + return user + + @property + def peer_group(self): + """Return the group of the connected peer process.""" + _, group = self.resolve_peer_creds() + return group + + def _close_kernel_socket(self): + """Terminate the connection at the transport level.""" + # Honor ``sock_shutdown`` for PyOpenSSL connections. + shutdown = getattr( + self.socket, 'sock_shutdown', + self.socket.shutdown, + ) + + try: + shutdown(socket.SHUT_RDWR) # actually send a TCP FIN + except errors.acceptable_sock_shutdown_exceptions: + pass + except socket.error as e: + if e.errno not in errors.acceptable_sock_shutdown_error_codes: + raise + + +class HTTPServer: + """An HTTP server.""" + + _bind_addr = '127.0.0.1' + _interrupt = None + + gateway = None + """A Gateway instance.""" + + minthreads = None + """The minimum number of worker threads to create (default 10).""" + + maxthreads = None + """The maximum number of worker threads to create. + + (default -1 = no limit)""" + + server_name = None + """The name of the server; defaults to ``self.version``.""" + + protocol = 'HTTP/1.1' + """The version string to write in the Status-Line of all HTTP responses. + + For example, "HTTP/1.1" is the default. This also limits the supported + features used in the response.""" + + request_queue_size = 5 + """The 'backlog' arg to socket.listen(); max queued connections. + + (default 5).""" + + shutdown_timeout = 5 + """The total time to wait for worker threads to cleanly exit. + + Specified in seconds.""" + + timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + expiration_interval = 0.5 + """The interval, in seconds, at which the server checks for + expired connections (default 0.5). + """ + + version = 'Cheroot/{version!s}'.format(version=__version__) + """A version string for the HTTPServer.""" + + software = None + """The value to set for the SERVER_SOFTWARE entry in the WSGI environ. + + If None, this defaults to ``'%s Server' % self.version``. + """ + + ready = False + """Internal flag which indicating the socket is accepting connections.""" + + max_request_header_size = 0 + """The maximum size, in bytes, for request headers, or 0 for no limit.""" + + max_request_body_size = 0 + """The maximum size, in bytes, for request bodies, or 0 for no limit.""" + + nodelay = True + """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" + + ConnectionClass = HTTPConnection + """The class to use for handling HTTP connections.""" + + ssl_adapter = None + """An instance of ``ssl.Adapter`` (or a subclass). + + Ref: :py:class:`ssl.Adapter `. + + You must have the corresponding TLS driver library installed. + """ + + peercreds_enabled = False + """ + If :py:data:`True`, peer creds will be looked up via UNIX domain socket. + """ + + peercreds_resolve_enabled = False + """ + If :py:data:`True`, username/group will be looked up in the OS from + ``PEERCREDS``-provided IDs. + """ + + keep_alive_conn_limit = 10 + """The maximum number of waiting keep-alive connections that will be kept open. + + Default is 10. 
Set to None to have unlimited connections."""
+
+    def __init__(
+        self, bind_addr, gateway,
+        minthreads=10, maxthreads=-1, server_name=None,
+        peercreds_enabled=False, peercreds_resolve_enabled=False,
+    ):
+        """Initialize HTTPServer instance.
+
+        Args:
+            bind_addr (tuple): network interface to listen to
+            gateway (Gateway): gateway for processing HTTP requests
+            minthreads (int): minimum number of threads for HTTP thread pool
+            maxthreads (int): maximum number of threads for HTTP thread pool
+            server_name (str): web server name to be advertised via Server
+                HTTP header
+            peercreds_enabled (bool): whether to query peer credentials on
+                UNIX domain sockets
+            peercreds_resolve_enabled (bool): whether to resolve peer
+                credentials to a username/group; only honored when
+                peercreds_enabled is set
+        """
+        self.bind_addr = bind_addr
+        self.gateway = gateway
+
+        self.requests = threadpool.ThreadPool(
+            self, min=minthreads or 1, max=maxthreads,
+        )
+
+        if not server_name:
+            server_name = self.version
+        self.server_name = server_name
+        self.peercreds_enabled = peercreds_enabled
+        self.peercreds_resolve_enabled = (
+            peercreds_resolve_enabled and peercreds_enabled
+        )
+        self.clear_stats()
+
+    def clear_stats(self):
+        """Reset server stat counters."""
+        self._start_time = None
+        self._run_time = 0
+        self.stats = {
+            'Enabled': False,
+            'Bind Address': lambda s: repr(self.bind_addr),
+            'Run time': lambda s: (not s['Enabled']) and -1 or self.runtime(),
+            'Accepts': 0,
+            'Accepts/sec': lambda s: s['Accepts'] / self.runtime(),
+            'Queue': lambda s: getattr(self.requests, 'qsize', None),
+            'Threads': lambda s: len(getattr(self.requests, '_threads', [])),
+            'Threads Idle': lambda s: getattr(self.requests, 'idle', None),
+            'Socket Errors': 0,
+            'Requests': lambda s: (not s['Enabled']) and -1 or sum(
+                (w['Requests'](w) for w in s['Worker Threads'].values()), 0,
+            ),
+            'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum(
+                (w['Bytes Read'](w) for w in s['Worker Threads'].values()), 0,
+            ),
+            'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum(
+                (w['Bytes Written'](w) for w in s['Worker Threads'].values()),
+                0,
+            ),
+            'Work Time': lambda s: (not s['Enabled']) and -1 or sum(
+                (w['Work Time'](w) for w in s['Worker Threads'].values()), 0,
+            ),
+            'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum(
+                (
+                    w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6)
+                    for w in s['Worker Threads'].values()
+                ), 0,
+            ),
+            'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum(
+                (
+                    w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6)
+                    for w in s['Worker Threads'].values()
+                ), 0,
+            ),
+            'Worker Threads': {},
+        }
+        logging.statistics['Cheroot HTTPServer %d' % id(self)] = self.stats
+
+    def runtime(self):
+        """Return server uptime."""
+        if self._start_time is None:
+            return self._run_time
+        else:
+            return self._run_time + (time.time() - self._start_time)
+
+    def __str__(self):
+        """Render Server instance representing bind address."""
+        return '%s.%s(%r)' % (
+            self.__module__, self.__class__.__name__,
+            self.bind_addr,
+        )
+
+    @property
+    def bind_addr(self):
+        """Return the interface on which to listen for connections.
+
+        For TCP sockets, a (host, port) tuple. Host values may be any
+        :term:`IPv4` or :term:`IPv6` address, or any valid hostname.
+        The string 'localhost' is a synonym for '127.0.0.1' (or '::1',
+        if your hosts file prefers :term:`IPv6`).
+        The string '0.0.0.0' is a special :term:`IPv4` entry meaning
+        "any active interface" (INADDR_ANY), and '::' is the similar
+        IN6ADDR_ANY for :term:`IPv6`.
+        The empty string or :py:data:`None` are not allowed.
+
+        For UNIX sockets, supply the file name as a string.
+
+        Systemd socket activation is automatic and doesn't require tampering
+        with this variable.
+
+        .. glossary::
+
+           :abbr:`IPv4 (Internet Protocol version 4)`
+              Internet Protocol version 4
+
+           :abbr:`IPv6 (Internet Protocol version 6)`
+              Internet Protocol version 6
+        """
+        return self._bind_addr
+
+    @bind_addr.setter
+    def bind_addr(self, value):
+        """Set the interface on which to listen for connections."""
+        if isinstance(value, tuple) and value[0] in ('', None):
+            # Despite the socket module docs, using '' does not
+            # allow AI_PASSIVE to work. Passing None instead
+            # returns '0.0.0.0' like we want. In other words:
+            #     host    AI_PASSIVE     result
+            #      ''         Y          192.168.x.y
+            #      ''         N          192.168.x.y
+            #     None        Y          0.0.0.0
+            #     None        N          127.0.0.1
+            # But since you can get the same effect with an explicit
+            # '0.0.0.0', we deny both the empty string and None as values.
+            raise ValueError(
+                "Host values of '' or None are not allowed. "
+                "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead "
+                'to listen on all active interfaces.',
+            )
+        self._bind_addr = value
+
+    def safe_start(self):
+        """Run the server forever, and stop it cleanly on exit."""
+        try:
+            self.start()
+        except (KeyboardInterrupt, IOError):
+            # The time.sleep call might raise
+            # "IOError: [Errno 4] Interrupted function call" on KBInt.
+            self.error_log('Keyboard Interrupt: shutting down')
+            self.stop()
+            raise
+        except SystemExit:
+            self.error_log('SystemExit raised: shutting down')
+            self.stop()
+            raise
+
+    def prepare(self):  # noqa: C901  # FIXME
+        """Prepare the server for serving requests.
+
+        It binds the socket, sets it up to ``listen()``, and performs the
+        other preparatory steps.
+        """
+        self._interrupt = None
+
+        if self.software is None:
+            self.software = '%s Server' % self.version
+
+        # Select the appropriate socket
+        self.socket = None
+        msg = 'No socket could be created'
+        if os.getenv('LISTEN_PID', None):
+            # systemd socket activation
+            self.socket = socket.fromfd(3, socket.AF_INET, socket.SOCK_STREAM)
+        elif isinstance(self.bind_addr, (six.text_type, six.binary_type)):
+            # AF_UNIX socket
+            try:
+                self.bind_unix_socket(self.bind_addr)
+            except socket.error as serr:
+                msg = '%s -- (%s: %s)' % (msg, self.bind_addr, serr)
+                six.raise_from(socket.error(msg), serr)
+        else:
+            # AF_INET or AF_INET6 socket
+            # Get the correct address family for our host (allows IPv6
+            # addresses)
+            host, port = self.bind_addr
+            try:
+                info = socket.getaddrinfo(
+                    host, port, socket.AF_UNSPEC,
+                    socket.SOCK_STREAM, 0, socket.AI_PASSIVE,
+                )
+            except socket.gaierror:
+                sock_type = socket.AF_INET
+                bind_addr = self.bind_addr
+
+                if ':' in host:
+                    sock_type = socket.AF_INET6
+                    bind_addr = bind_addr + (0, 0)
+
+                info = [(sock_type, socket.SOCK_STREAM, 0, '', bind_addr)]
+
+            for res in info:
+                af, socktype, proto, canonname, sa = res
+                try:
+                    self.bind(af, socktype, proto)
+                    break
+                except socket.error as serr:
+                    msg = '%s -- (%s: %s)' % (msg, sa, serr)
+                    if self.socket:
+                        self.socket.close()
+                    self.socket = None
+
+        if not self.socket:
+            raise socket.error(msg)
+
+        # Timeout so KeyboardInterrupt can be caught on Win32
+        self.socket.settimeout(1)
+        self.socket.listen(self.request_queue_size)
+
+        # must not be accessed once stop() has been called
+        self._connections = connections.ConnectionManager(self)
+
+        # Create worker threads
+        self.requests.start()
+
+        self.ready = True
+        self._start_time = time.time()
+
+    def serve(self):
+        """Serve requests, after invoking :func:`prepare()`."""
+        while self.ready and not self.interrupt:
+            try:
+                self._connections.run(self.expiration_interval)
+            except (KeyboardInterrupt, SystemExit):
+                raise
+            except Exception:
+                self.error_log(
+                    'Error in HTTPServer.serve', level=logging.ERROR,
+                    traceback=True,
+                )
+
+        # raise exceptions reported by any worker threads,
+        # such that the exception is raised from the serve() thread.
+        if self.interrupt:
+            while self._stopping_for_interrupt:
+                time.sleep(0.1)
+            if self.interrupt:
+                raise self.interrupt
+
+    def start(self):
+        """Run the server forever.
+
+        It is a shortcut for invoking :func:`prepare()` then :func:`serve()`.
+        """
+        # We don't have to trap KeyboardInterrupt or SystemExit here,
+        # because cherrypy.server already does so, calling self.stop() for us.
+        # If you're using this server with another framework, you should
+        # trap those exceptions in whatever code block calls start().
+        self.prepare()
+        self.serve()
+
+    @contextlib.contextmanager
+    def _run_in_thread(self):
+        """Context manager for running this server in a thread."""
+        self.prepare()
+        thread = threading.Thread(target=self.serve)
+        thread.setDaemon(True)
+        thread.start()
+        try:
+            yield thread
+        finally:
+            self.stop()
+
+    @property
+    def can_add_keepalive_connection(self):
+        """Flag whether it is allowed to add a new keep-alive connection."""
+        return self.ready and self._connections.can_add_keepalive_connection
+
+    def put_conn(self, conn):
+        """Put an idle connection back into the ConnectionManager."""
+        if self.ready:
+            self._connections.put(conn)
+        else:
+            # server is shutting down, just close it
+            conn.close()
+
+    def error_log(self, msg='', level=20, traceback=False):
+        """Write error message to log.
+
+        Args:
+            msg (str): error message
+            level (int): logging level
+            traceback (bool): add traceback to output or not
+        """
+        # Override this in subclasses as desired
+        sys.stderr.write('{msg!s}\n'.format(msg=msg))
+        sys.stderr.flush()
+        if traceback:
+            tblines = traceback_.format_exc()
+            sys.stderr.write(tblines)
+            sys.stderr.flush()
+
+    def bind(self, family, type, proto=0):
+        """Create (or recreate) the actual socket object."""
+        sock = self.prepare_socket(
+            self.bind_addr,
+            family, type, proto,
+            self.nodelay, self.ssl_adapter,
+        )
+        sock = self.socket = self.bind_socket(sock, self.bind_addr)
+        self.bind_addr = self.resolve_real_bind_addr(sock)
+        return sock
+
+    def bind_unix_socket(self, bind_addr):  # noqa: C901  # FIXME
+        """Create (or recreate) a UNIX socket object."""
+        if IS_WINDOWS:
+            """
+            Trying to access socket.AF_UNIX under Windows
+            causes an AttributeError.
+            """
+            raise ValueError(  # or RuntimeError?
+                'AF_UNIX sockets are not supported under Windows.',
+            )
+
+        fs_permissions = 0o777  # TODO: allow changing mode
+
+        try:
+            # Make it possible to reuse the socket...
+            os.unlink(self.bind_addr)
+        except OSError:
+            """
+            File does not exist, which is the primary goal anyway.
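+            (A stale socket file left over from a previous run would
+            otherwise make the subsequent bind() fail.)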
+ """ + except TypeError as typ_err: + err_msg = str(typ_err) + if ( + 'remove() argument 1 must be encoded ' + 'string without null bytes, not unicode' + not in err_msg + and 'embedded NUL character' not in err_msg # py34 + and 'argument must be a ' + 'string without NUL characters' not in err_msg # pypy2 + ): + raise + except ValueError as val_err: + err_msg = str(val_err) + if ( + 'unlink: embedded null ' + 'character in path' not in err_msg + and 'embedded null byte' not in err_msg + and 'argument must be a ' + 'string without NUL characters' not in err_msg # pypy3 + ): + raise + + sock = self.prepare_socket( + bind_addr=bind_addr, + family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0, + nodelay=self.nodelay, ssl_adapter=self.ssl_adapter, + ) + + try: + """Linux way of pre-populating fs mode permissions.""" + # Allow everyone access the socket... + os.fchmod(sock.fileno(), fs_permissions) + FS_PERMS_SET = True + except OSError: + FS_PERMS_SET = False + + try: + sock = self.bind_socket(sock, bind_addr) + except socket.error: + sock.close() + raise + + bind_addr = self.resolve_real_bind_addr(sock) + + try: + """FreeBSD/macOS pre-populating fs mode permissions.""" + if not FS_PERMS_SET: + try: + os.lchmod(bind_addr, fs_permissions) + except AttributeError: + os.chmod(bind_addr, fs_permissions, follow_symlinks=False) + FS_PERMS_SET = True + except OSError: + pass + + if not FS_PERMS_SET: + self.error_log( + 'Failed to set socket fs mode permissions', + level=logging.WARNING, + ) + + self.bind_addr = bind_addr + self.socket = sock + return sock + + @staticmethod + def prepare_socket(bind_addr, family, type, proto, nodelay, ssl_adapter): + """Create and prepare the socket object.""" + sock = socket.socket(family, type, proto) + connections.prevent_socket_inheritance(sock) + + host, port = bind_addr[:2] + IS_EPHEMERAL_PORT = port == 0 + + if not (IS_WINDOWS or IS_EPHEMERAL_PORT): + """Enable SO_REUSEADDR for the current socket. + + Skip for Windows (has different semantics) + or ephemeral ports (can steal ports from others). + + Refs: + * https://msdn.microsoft.com/en-us/library/ms740621(v=vs.85).aspx + * https://github.com/cherrypy/cheroot/issues/114 + * https://gavv.github.io/blog/ephemeral-port-reuse/ + """ + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if nodelay and not isinstance( + bind_addr, + (six.text_type, six.binary_type), + ): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + if ssl_adapter is not None: + sock = ssl_adapter.bind(sock) + + # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), + # activate dual-stack. See + # https://github.com/cherrypy/cherrypy/issues/871. 
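+        # (With IPV6_V6ONLY switched off, the one IPv6 socket also accepts
+        # IPv4 clients, which then appear as IPv4-mapped addresses such as
+        # ``::ffff:127.0.0.1``.)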
+ listening_ipv6 = ( + hasattr(socket, 'AF_INET6') + and family == socket.AF_INET6 + and host in ('::', '::0', '::0.0.0.0') + ) + if listening_ipv6: + try: + sock.setsockopt( + socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0, + ) + except (AttributeError, socket.error): + # Apparently, the socket option is not available in + # this machine's TCP stack + pass + + return sock + + @staticmethod + def bind_socket(socket_, bind_addr): + """Bind the socket to given interface.""" + socket_.bind(bind_addr) + return socket_ + + @staticmethod + def resolve_real_bind_addr(socket_): + """Retrieve actual bind address from bound socket.""" + # FIXME: keep requested bind_addr separate real bound_addr (port + # is different in case of ephemeral port 0) + bind_addr = socket_.getsockname() + if socket_.family in ( + # Windows doesn't have socket.AF_UNIX, so not using it in check + socket.AF_INET, + socket.AF_INET6, + ): + """UNIX domain sockets are strings or bytes. + + In case of bytes with a leading null-byte it's an abstract socket. + """ + return bind_addr[:2] + + if isinstance(bind_addr, six.binary_type): + bind_addr = bton(bind_addr) + + return bind_addr + + def process_conn(self, conn): + """Process an incoming HTTPConnection.""" + try: + self.requests.put(conn) + except queue.Full: + # Just drop the conn. TODO: write 503 back? + conn.close() + + @property + def interrupt(self): + """Flag interrupt of the server.""" + return self._interrupt + + @property + def _stopping_for_interrupt(self): + """Return whether the server is responding to an interrupt.""" + return self._interrupt is _STOPPING_FOR_INTERRUPT + + @interrupt.setter + def interrupt(self, interrupt): + """Perform the shutdown of this server and save the exception. + + Typically invoked by a worker thread in + :py:mod:`~cheroot.workers.threadpool`, the exception is raised + from the thread running :py:meth:`serve` once :py:meth:`stop` + has completed. + """ + self._interrupt = _STOPPING_FOR_INTERRUPT + self.stop() + self._interrupt = interrupt + + def stop(self): # noqa: C901 # FIXME + """Gracefully shutdown a server that is serving forever.""" + if not self.ready: + return # already stopped + + self.ready = False + if self._start_time is not None: + self._run_time += (time.time() - self._start_time) + self._start_time = None + + self._connections.stop() + + sock = getattr(self, 'socket', None) + if sock: + if not isinstance( + self.bind_addr, + (six.text_type, six.binary_type), + ): + # Touch our own socket to make accept() return immediately. + try: + host, port = sock.getsockname()[:2] + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + # Changed to use error code and not message + # See + # https://github.com/cherrypy/cherrypy/issues/860. + raise + else: + # Note that we're explicitly NOT using AI_PASSIVE, + # here, because we want an actual IP to touch. + # localhost won't work if we've bound to a public IP, + # but it will if we bound to '0.0.0.0' (INADDR_ANY). 
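+                    # (The connect() below is just a self-ping: it wakes the
+                    # blocking accept() in the serving thread so shutdown can
+                    # proceed without waiting for the accept timeout.)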
+ for res in socket.getaddrinfo( + host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, + ): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See + # https://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + if hasattr(sock, 'close'): + sock.close() + self.socket = None + + self._connections.close() + self.requests.stop(self.shutdown_timeout) + + +class Gateway: + """Base class to interface HTTPServer with other systems, such as WSGI.""" + + def __init__(self, req): + """Initialize Gateway instance with request. + + Args: + req (HTTPRequest): current HTTP request + """ + self.req = req + + def respond(self): + """Process the current request. Must be overridden in a subclass.""" + raise NotImplementedError # pragma: no cover + + +# These may either be ssl.Adapter subclasses or the string names +# of such classes (in which case they will be lazily loaded). +ssl_adapters = { + 'builtin': 'cheroot.ssl.builtin.BuiltinSSLAdapter', + 'pyopenssl': 'cheroot.ssl.pyopenssl.pyOpenSSLAdapter', +} + + +def get_ssl_adapter_class(name='builtin'): + """Return an SSL adapter class for the given name.""" + adapter = ssl_adapters[name.lower()] + if isinstance(adapter, six.string_types): + last_dot = adapter.rfind('.') + attr_name = adapter[last_dot + 1:] + mod_path = adapter[:last_dot] + + try: + mod = sys.modules[mod_path] + if mod is None: + raise KeyError() + except KeyError: + # The last [''] is important. + mod = __import__(mod_path, globals(), locals(), ['']) + + # Let an AttributeError propagate outward. + try: + adapter = getattr(mod, attr_name) + except AttributeError: + raise AttributeError( + "'%s' object has no attribute '%s'" + % (mod_path, attr_name), + ) + + return adapter diff --git a/resources/lib/cheroot/ssl/__init__.py b/resources/lib/cheroot/ssl/__init__.py new file mode 100644 index 0000000..d45fd7f --- /dev/null +++ b/resources/lib/cheroot/ssl/__init__.py @@ -0,0 +1,52 @@ +"""Implementation of the SSL adapter base interface.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from abc import ABCMeta, abstractmethod + +from six import add_metaclass + + +@add_metaclass(ABCMeta) +class Adapter: + """Base class for SSL driver library adapters. 
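+
+    Concrete implementations ship with Cheroot for the stdlib ``ssl``
+    module and for pyOpenSSL; see ``cheroot.ssl.builtin`` and
+    ``cheroot.ssl.pyopenssl``.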
+ + Required methods: + + * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` + * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> + socket file object`` + """ + + @abstractmethod + def __init__( + self, certificate, private_key, certificate_chain=None, + ciphers=None, + ): + """Set up certificates, private key ciphers and reset context.""" + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + self.ciphers = ciphers + self.context = None + + @abstractmethod + def bind(self, sock): + """Wrap and return the given socket.""" + return sock + + @abstractmethod + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + raise NotImplementedError # pragma: no cover + + @abstractmethod + def get_environ(self): + """Return WSGI environ entries to be merged into each request.""" + raise NotImplementedError # pragma: no cover + + @abstractmethod + def makefile(self, sock, mode='r', bufsize=-1): + """Return socket file object.""" + raise NotImplementedError # pragma: no cover diff --git a/resources/lib/cheroot/ssl/builtin.py b/resources/lib/cheroot/ssl/builtin.py new file mode 100644 index 0000000..ff987a7 --- /dev/null +++ b/resources/lib/cheroot/ssl/builtin.py @@ -0,0 +1,485 @@ +""" +A library for integrating Python's builtin :py:mod:`ssl` library with Cheroot. + +The :py:mod:`ssl` module must be importable for SSL functionality. + +To use this module, set ``HTTPServer.ssl_adapter`` to an instance of +``BuiltinSSLAdapter``. +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import socket +import sys +import threading + +try: + import ssl +except ImportError: + ssl = None + +try: + from _pyio import DEFAULT_BUFFER_SIZE +except ImportError: + try: + from io import DEFAULT_BUFFER_SIZE + except ImportError: + DEFAULT_BUFFER_SIZE = -1 + +import six + +from . import Adapter +from .. import errors +from .._compat import IS_ABOVE_OPENSSL10, suppress +from ..makefile import StreamReader, StreamWriter +from ..server import HTTPServer + +if six.PY2: + generic_socket_error = socket.error +else: + generic_socket_error = OSError + + +def _assert_ssl_exc_contains(exc, *msgs): + """Check whether SSL exception contains either of messages provided.""" + if len(msgs) < 1: + raise TypeError( + '_assert_ssl_exc_contains() requires ' + 'at least one message to be passed.', + ) + err_msg_lower = str(exc).lower() + return any(m.lower() in err_msg_lower for m in msgs) + + +def _loopback_for_cert_thread(context, server): + """Wrap a socket in ssl and perform the server-side handshake.""" + # As we only care about parsing the certificate, the failure of + # which will cause an exception in ``_loopback_for_cert``, + # we can safely ignore connection and ssl related exceptions. Ref: + # https://github.com/cherrypy/cheroot/issues/302#issuecomment-662592030 + with suppress(ssl.SSLError, OSError): + with context.wrap_socket( + server, do_handshake_on_connect=True, server_side=True, + ) as ssl_sock: + # in TLS 1.3 (Python 3.7+, OpenSSL 1.1.1+), the server + # sends the client session tickets that can be used to + # resume the TLS session on a new connection without + # performing the full handshake again. session tickets are + # sent as a post-handshake message at some _unspecified_ + # time and thus a successful connection may be closed + # without the client having received the tickets. 
+            # Unfortunately, on Windows (Python 3.8+), this is treated
+            # as an incomplete handshake on the server side and a
+            # ``ConnectionAbortedError`` is raised.
+            # TLS 1.3 support is still incomplete in Python 3.8;
+            # there is no way for the client to wait for tickets.
+            # While not necessary for retrieving the parsed certificate,
+            # we send a tiny bit of data over the connection in an
+            # attempt to give the server a chance to send the session
+            # tickets and close the connection cleanly.
+            # Note that, as this is essentially a race condition,
+            # the error may still occur occasionally.
+            ssl_sock.send(b'0000')
+
+
+def _loopback_for_cert(certificate, private_key, certificate_chain):
+    """Create a loopback connection to parse a cert with a private key."""
+    context = ssl.create_default_context(cafile=certificate_chain)
+    context.load_cert_chain(certificate, private_key)
+    context.check_hostname = False
+    context.verify_mode = ssl.CERT_NONE
+
+    # Python 3+ Unix, Python 3.5+ Windows
+    client, server = socket.socketpair()
+    try:
+        # `wrap_socket` will block until the ssl handshake is complete.
+        # it must be called on both ends at the same time -> thread
+        # openssl will cache the peer's cert during a successful handshake
+        # and return it via `getpeercert` even after the socket is closed.
+        # when `close` is called, the SSL shutdown notice will be sent
+        # and then python will wait to receive the corollary shutdown.
+        thread = threading.Thread(
+            target=_loopback_for_cert_thread, args=(context, server),
+        )
+        try:
+            thread.start()
+            with context.wrap_socket(
+                client, do_handshake_on_connect=True,
+                server_side=False,
+            ) as ssl_sock:
+                ssl_sock.recv(4)
+                return ssl_sock.getpeercert()
+        finally:
+            thread.join()
+    finally:
+        client.close()
+        server.close()
+
+
+def _parse_cert(certificate, private_key, certificate_chain):
+    """Parse a certificate."""
+    # loopback_for_cert uses socket.socketpair which was only
+    # introduced in Python 3.0 for *nix and 3.5 for Windows
+    # and requires OS support (AttributeError, OSError)
+    # it also requires a private key either in its own file
+    # or combined with the cert (SSLError)
+    with suppress(AttributeError, ssl.SSLError, OSError):
+        return _loopback_for_cert(certificate, private_key, certificate_chain)
+
+    # KLUDGE: using an undocumented, private, test method to parse a cert
+    # unfortunately, it is the only built-in way without a connection
+    # as a private, undocumented method, it may change at any time
+    # so be tolerant of *any* possible errors it may raise
+    with suppress(Exception):
+        return ssl._ssl._test_decode_cert(certificate)
+
+    return {}
+
+
+def _sni_callback(sock, sni, context):
+    """Handle the SNI callback to tag the socket with the SNI."""
+    sock.sni = sni
+    # return None to allow the TLS negotiation to continue
+
+
+class BuiltinSSLAdapter(Adapter):
+    """Wrapper for integrating Python's builtin :py:mod:`ssl` with Cheroot."""
+
+    certificate = None
+    """The file name of the server SSL certificate."""
+
+    private_key = None
+    """The file name of the server's private key file."""
+
+    certificate_chain = None
+    """The file name of the certificate chain file."""
+
+    ciphers = None
+    """The ciphers list of SSL."""
+
+    # from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert
+    CERT_KEY_TO_ENV = {
+        'version': 'M_VERSION',
+        'serialNumber': 'M_SERIAL',
+        'notBefore': 'V_START',
+        'notAfter': 'V_END',
+        'subject': 'S_DN',
+        'issuer': 'I_DN',
+        'subjectAltName': 'SAN',
+        # not parsed by the Python standard library
+        # - A_SIG
+        # - 
A_KEY + # not provided by mod_ssl + # - OCSP + # - caIssuers + # - crlDistributionPoints + } + + # from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert_dn_rec + CERT_KEY_TO_LDAP_CODE = { + 'countryName': 'C', + 'stateOrProvinceName': 'ST', + # NOTE: mod_ssl also provides 'stateOrProvinceName' as 'SP' + # for compatibility with SSLeay + 'localityName': 'L', + 'organizationName': 'O', + 'organizationalUnitName': 'OU', + 'commonName': 'CN', + 'title': 'T', + 'initials': 'I', + 'givenName': 'G', + 'surname': 'S', + 'description': 'D', + 'userid': 'UID', + 'emailAddress': 'Email', + # not provided by mod_ssl + # - dnQualifier: DNQ + # - domainComponent: DC + # - postalCode: PC + # - streetAddress: STREET + # - serialNumber + # - generationQualifier + # - pseudonym + # - jurisdictionCountryName + # - jurisdictionLocalityName + # - jurisdictionStateOrProvince + # - businessCategory + } + + def __init__( + self, certificate, private_key, certificate_chain=None, + ciphers=None, + ): + """Set up context in addition to base class properties if available.""" + if ssl is None: + raise ImportError('You must install the ssl module to use HTTPS.') + + super(BuiltinSSLAdapter, self).__init__( + certificate, private_key, certificate_chain, ciphers, + ) + + self.context = ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH, + cafile=certificate_chain, + ) + self.context.load_cert_chain(certificate, private_key) + if self.ciphers is not None: + self.context.set_ciphers(ciphers) + + self._server_env = self._make_env_cert_dict( + 'SSL_SERVER', + _parse_cert(certificate, private_key, self.certificate_chain), + ) + if not self._server_env: + return + cert = None + with open(certificate, mode='rt') as f: + cert = f.read() + + # strip off any keys by only taking the first certificate + cert_start = cert.find(ssl.PEM_HEADER) + if cert_start == -1: + return + cert_end = cert.find(ssl.PEM_FOOTER, cert_start) + if cert_end == -1: + return + cert_end += len(ssl.PEM_FOOTER) + self._server_env['SSL_SERVER_CERT'] = cert[cert_start:cert_end] + + @property + def context(self): + """:py:class:`~ssl.SSLContext` that will be used to wrap sockets.""" + return self._context + + @context.setter + def context(self, context): + """Set the ssl ``context`` to use.""" + self._context = context + # Python 3.7+ + # if a context is provided via `cherrypy.config.update` then + # `self.context` will be set after `__init__` + # use a property to intercept it to add an SNI callback + # but don't override the user's callback + # TODO: chain callbacks + with suppress(AttributeError): + if ssl.HAS_SNI and context.sni_callback is None: + context.sni_callback = _sni_callback + + def bind(self, sock): + """Wrap and return the given socket.""" + return super(BuiltinSSLAdapter, self).bind(sock) + + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + EMPTY_RESULT = None, {} + try: + s = self.context.wrap_socket( + sock, do_handshake_on_connect=True, server_side=True, + ) + except ssl.SSLError as ex: + if ex.errno == ssl.SSL_ERROR_EOF: + # This is almost certainly due to the cherrypy engine + # 'pinging' the socket to assert it's connectable; + # the 'ping' isn't SSL. + return EMPTY_RESULT + elif ex.errno == ssl.SSL_ERROR_SSL: + if _assert_ssl_exc_contains(ex, 'http request'): + # The client is speaking HTTP to an HTTPS server. 
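+                    # (NoSSLError is translated by
+                    # HTTPConnection._handle_no_ssl into a plain-HTTP
+                    # 400 response telling the client to use HTTPS.)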
+                    raise errors.NoSSLError
+
+                # Check if it's one of the known errors
+                # Errors that are caught by PyOpenSSL, but thrown by
+                # built-in ssl
+                _block_errors = (
+                    'unknown protocol', 'unknown ca', 'unknown_ca',
+                    'unknown error',
+                    'https proxy request', 'inappropriate fallback',
+                    'wrong version number',
+                    'no shared cipher', 'certificate unknown',
+                    'ccs received early',
+                    'certificate verify failed',  # client cert w/o trusted CA
+                    'version too low',  # caused by SSL3 connections
+                    'unsupported protocol',  # caused by TLS1 connections
+                )
+                if _assert_ssl_exc_contains(ex, *_block_errors):
+                    # Accepted error, let's pass
+                    return EMPTY_RESULT
+            elif _assert_ssl_exc_contains(ex, 'handshake operation timed out'):
+                # This error is thrown by builtin SSL after a timeout
+                # when client is speaking HTTP to an HTTPS server.
+                # The connection can safely be dropped.
+                return EMPTY_RESULT
+            raise
+        except generic_socket_error as exc:
+            """It is unclear why exactly this happens.
+
+            It's reproducible only with openssl>1.0 and stdlib
+            :py:mod:`ssl` wrapper.
+            In CherryPy it's triggered by Checker plugin, which connects
+            to the app listening to the socket port in TLS mode via plain
+            HTTP during startup (from the same process).
+
+
+            Ref: https://github.com/cherrypy/cherrypy/issues/1618
+            """
+            is_error0 = exc.args == (0, 'Error')
+
+            if is_error0 and IS_ABOVE_OPENSSL10:
+                return EMPTY_RESULT
+            raise
+        return s, self.get_environ(s)
+
+    def get_environ(self, sock):
+        """Create WSGI environ entries to be merged into each request."""
+        cipher = sock.cipher()
+        ssl_environ = {
+            'wsgi.url_scheme': 'https',
+            'HTTPS': 'on',
+            'SSL_PROTOCOL': cipher[1],
+            'SSL_CIPHER': cipher[0],
+            'SSL_CIPHER_EXPORT': '',
+            'SSL_CIPHER_USEKEYSIZE': cipher[2],
+            'SSL_VERSION_INTERFACE': '%s Python/%s' % (
+                HTTPServer.version, sys.version,
+            ),
+            'SSL_VERSION_LIBRARY': ssl.OPENSSL_VERSION,
+            'SSL_CLIENT_VERIFY': 'NONE',
+            # 'NONE' - client did not provide a cert (overridden below)
+        }
+
+        # Python 3.3+
+        with suppress(AttributeError):
+            compression = sock.compression()
+            if compression is not None:
+                ssl_environ['SSL_COMPRESS_METHOD'] = compression
+
+        # Python 3.6+
+        with suppress(AttributeError):
+            ssl_environ['SSL_SESSION_ID'] = sock.session.id.hex()
+        with suppress(AttributeError):
+            target_cipher = cipher[:2]
+            for cip in sock.context.get_ciphers():
+                if target_cipher == (cip['name'], cip['protocol']):
+                    ssl_environ['SSL_CIPHER_ALGKEYSIZE'] = cip['alg_bits']
+                    break
+
+        # Python 3.7+ sni_callback
+        with suppress(AttributeError):
+            ssl_environ['SSL_TLS_SNI'] = sock.sni
+
+        if self.context and self.context.verify_mode != ssl.CERT_NONE:
+            client_cert = sock.getpeercert()
+            if client_cert:
+                # builtin ssl **ALWAYS** validates client certificates
+                # and terminates the connection on failure
+                ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS'
+                ssl_environ.update(
+                    self._make_env_cert_dict('SSL_CLIENT', client_cert),
+                )
+                ssl_environ['SSL_CLIENT_CERT'] = ssl.DER_cert_to_PEM_cert(
+                    sock.getpeercert(binary_form=True),
+                ).strip()
+
+        ssl_environ.update(self._server_env)
+
+        # not supplied by the Python standard library (as of 3.8)
+        # - SSL_SESSION_RESUMED
+        # - SSL_SECURE_RENEG
+        # - SSL_CLIENT_CERT_CHAIN_n
+        # - SRP_USER
+        # - SRP_USERINFO
+
+        return ssl_environ
+
+    def _make_env_cert_dict(self, env_prefix, parsed_cert):
+        """Return a dict of WSGI environment variables for a certificate.
+
+        E.g. SSL_CLIENT_M_VERSION, SSL_CLIENT_M_SERIAL, etc.
+        See https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars.
+        """
+        if not parsed_cert:
+            return {}
+
+        env = {}
+        for cert_key, env_var in self.CERT_KEY_TO_ENV.items():
+            key = '%s_%s' % (env_prefix, env_var)
+            value = parsed_cert.get(cert_key)
+            if env_var == 'SAN':
+                env.update(self._make_env_san_dict(key, value))
+            elif env_var.endswith('_DN'):
+                env.update(self._make_env_dn_dict(key, value))
+            else:
+                env[key] = str(value)
+
+        # mod_ssl 2.1+; Python 3.2+
+        # the length of the certificate's validity window, in days
+        # (mod_ssl itself reports the days remaining until expiry here)
+        if 'notBefore' in parsed_cert:
+            remain = ssl.cert_time_to_seconds(parsed_cert['notAfter'])
+            remain -= ssl.cert_time_to_seconds(parsed_cert['notBefore'])
+            remain /= 60 * 60 * 24
+            env['%s_V_REMAIN' % (env_prefix,)] = str(int(remain))
+
+        return env
+
+    def _make_env_san_dict(self, env_prefix, cert_value):
+        """Return a dict of WSGI environment variables for a certificate SAN.
+
+        E.g. SSL_CLIENT_SAN_Email_0, SSL_CLIENT_SAN_DNS_0, etc.
+        See SSL_CLIENT_SAN_* at
+        https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars.
+        """
+        if not cert_value:
+            return {}
+
+        env = {}
+        dns_count = 0
+        email_count = 0
+        for attr_name, val in cert_value:
+            if attr_name == 'DNS':
+                env['%s_DNS_%i' % (env_prefix, dns_count)] = val
+                dns_count += 1
+            elif attr_name == 'Email':
+                env['%s_Email_%i' % (env_prefix, email_count)] = val
+                email_count += 1
+
+        # other mod_ssl SAN vars:
+        # - SAN_OTHER_msUPN_n
+        return env
+
+    def _make_env_dn_dict(self, env_prefix, cert_value):
+        """Return a dict of WSGI environment variables for a certificate DN.
+
+        E.g. SSL_CLIENT_S_DN_CN, SSL_CLIENT_S_DN_C, etc.
+        See SSL_CLIENT_S_DN_x509 at
+        https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars.
+        """
+        if not cert_value:
+            return {}
+
+        dn = []
+        dn_attrs = {}
+        for rdn in cert_value:
+            for attr_name, val in rdn:
+                attr_code = self.CERT_KEY_TO_LDAP_CODE.get(attr_name)
+                dn.append('%s=%s' % (attr_code or attr_name, val))
+                if not attr_code:
+                    continue
+                dn_attrs.setdefault(attr_code, [])
+                dn_attrs[attr_code].append(val)
+
+        env = {
+            env_prefix: ','.join(dn),
+        }
+        for attr_code, values in dn_attrs.items():
+            env['%s_%s' % (env_prefix, attr_code)] = ','.join(values)
+            if len(values) == 1:
+                continue
+            for i, val in enumerate(values):
+                env['%s_%s_%i' % (env_prefix, attr_code, i)] = val
+        return env
+
+    def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE):
+        """Return socket file object."""
+        cls = StreamReader if 'r' in mode else StreamWriter
+        return cls(sock, mode, bufsize)
diff --git a/resources/lib/cheroot/ssl/pyopenssl.py b/resources/lib/cheroot/ssl/pyopenssl.py
new file mode 100644
index 0000000..adc9a1b
--- /dev/null
+++ b/resources/lib/cheroot/ssl/pyopenssl.py
@@ -0,0 +1,382 @@
+"""
+A library for integrating :doc:`pyOpenSSL ` with Cheroot.
+
+The :py:mod:`OpenSSL ` module must be importable
+for SSL/TLS/HTTPS functionality.
+You can obtain it from `here `_.
+
+To use this module, set :py:attr:`HTTPServer.ssl_adapter
+` to an instance of
+:py:class:`ssl.Adapter `.
+There are two ways to use :abbr:`TLS (Transport-Level Security)`:
+
+Method One
+----------
+
+ * :py:attr:`ssl_adapter.context
+   `: an instance of
+   :py:class:`SSL.Context `.
+
+If this is not None, it is assumed to be an :py:class:`SSL.Context
+` instance, and will be passed to
+:py:class:`SSL.Connection ` on bind().
+The developer is responsible for forming a valid :py:class:`Context
+` object. This
+approach is to be preferred for more flexibility, e.g.
if the cert and +key are streams instead of files, or need decryption, or +:py:data:`SSL.SSLv3_METHOD ` +is desired instead of the default :py:data:`SSL.SSLv23_METHOD +`, etc. Consult +the :doc:`pyOpenSSL ` documentation for +complete options. + +Method Two (shortcut) +--------------------- + + * :py:attr:`ssl_adapter.certificate + `: the file name + of the server's TLS certificate. + * :py:attr:`ssl_adapter.private_key + `: the file name + of the server's private key file. + +Both are :py:data:`None` by default. If :py:attr:`ssl_adapter.context +` is :py:data:`None`, +but ``.private_key`` and ``.certificate`` are both given and valid, they +will be read, and the context will be automatically created from them. + +.. spelling:: + + pyopenssl +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import socket +import sys +import threading +import time + +import six + +try: + import OpenSSL.version + from OpenSSL import SSL + from OpenSSL import crypto + + try: + ssl_conn_type = SSL.Connection + except AttributeError: + ssl_conn_type = SSL.ConnectionType +except ImportError: + SSL = None + +from . import Adapter +from .. import errors, server as cheroot_server +from ..makefile import StreamReader, StreamWriter + + +class SSLFileobjectMixin: + """Base mixin for a TLS socket stream.""" + + ssl_timeout = 3 + ssl_retry = .01 + + # FIXME: + def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901 + """Wrap the given call with TLS error-trapping. + + is_reader: if False EOF errors will be raised. If True, EOF errors + will return "" (to emulate normal sockets). + """ + start = time.time() + while True: + try: + return call(*args, **kwargs) + except SSL.WantReadError: + # Sleep and try again. This is dangerous, because it means + # the rest of the stack has no way of differentiating + # between a "new handshake" error and "client dropped". + # Note this isn't an endless loop: there's a timeout below. + # Ref: https://stackoverflow.com/a/5133568/595220 + time.sleep(self.ssl_retry) + except SSL.WantWriteError: + time.sleep(self.ssl_retry) + except SSL.SysCallError as e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return b'' + + errnum = e.args[0] + if is_reader and errnum in errors.socket_errors_to_ignore: + return b'' + raise socket.error(errnum) + except SSL.Error as e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return b'' + + thirdarg = None + try: + thirdarg = e.args[0][0][2] + except IndexError: + pass + + if thirdarg == 'http request': + # The client is talking HTTP to an HTTPS server. + raise errors.NoSSLError() + + raise errors.FatalSSLAlert(*e.args) + + if time.time() - start > self.ssl_timeout: + raise socket.timeout('timed out') + + def recv(self, size): + """Receive message of a size from the socket.""" + return self._safe_call( + True, + super(SSLFileobjectMixin, self).recv, + size, + ) + + def readline(self, size=-1): + """Receive message of a size from the socket. 
+
+        Matches the following interface:
+        https://docs.python.org/3/library/io.html#io.IOBase.readline
+        """
+        return self._safe_call(
+            True,
+            super(SSLFileobjectMixin, self).readline,
+            size,
+        )
+
+    def sendall(self, *args, **kwargs):
+        """Send whole message to the socket."""
+        return self._safe_call(
+            False,
+            super(SSLFileobjectMixin, self).sendall,
+            *args, **kwargs
+        )
+
+    def send(self, *args, **kwargs):
+        """Send some part of message to the socket."""
+        return self._safe_call(
+            False,
+            super(SSLFileobjectMixin, self).send,
+            *args, **kwargs
+        )
+
+
+class SSLFileobjectStreamReader(SSLFileobjectMixin, StreamReader):
+    """SSL file object attached to a socket object."""
+
+
+class SSLFileobjectStreamWriter(SSLFileobjectMixin, StreamWriter):
+    """SSL file object attached to a socket object."""
+
+
+class SSLConnectionProxyMeta:
+    """Metaclass for generating a bunch of proxy methods."""
+
+    def __new__(mcl, name, bases, nmspc):
+        """Attach a list of proxy methods to a new class."""
+        proxy_methods = (
+            'get_context', 'pending', 'send', 'write', 'recv', 'read',
+            'renegotiate', 'bind', 'listen', 'connect', 'accept',
+            'setblocking', 'fileno', 'close', 'get_cipher_list',
+            'getpeername', 'getsockname', 'getsockopt', 'setsockopt',
+            'makefile', 'get_app_data', 'set_app_data', 'state_string',
+            'sock_shutdown', 'get_peer_certificate', 'want_read',
+            'want_write', 'set_connect_state', 'set_accept_state',
+            'connect_ex', 'sendall', 'settimeout', 'gettimeout',
+            'shutdown',
+        )
+        proxy_methods_no_args = (
+            'shutdown',
+        )
+
+        proxy_props = (
+            'family',
+        )
+
+        def lock_decorator(method):
+            """Create a proxy method for a new class."""
+            def proxy_wrapper(self, *args):
+                self._lock.acquire()
+                try:
+                    new_args = (
+                        args[:] if method not in proxy_methods_no_args else []
+                    )
+                    return getattr(self._ssl_conn, method)(*new_args)
+                finally:
+                    self._lock.release()
+            return proxy_wrapper
+        for m in proxy_methods:
+            nmspc[m] = lock_decorator(m)
+            nmspc[m].__name__ = m
+
+        def make_property(property_):
+            """Create a proxy property for a new class."""
+            def proxy_prop_wrapper(self):
+                return getattr(self._ssl_conn, property_)
+            proxy_prop_wrapper.__name__ = property_
+            return property(proxy_prop_wrapper)
+        for p in proxy_props:
+            nmspc[p] = make_property(p)
+
+        # Doesn't work via super() for some reason.
+        # Falling back to type() instead:
+        return type(name, bases, nmspc)
+
+
+@six.add_metaclass(SSLConnectionProxyMeta)
+class SSLConnection:
+    r"""A thread-safe wrapper for an ``SSL.Connection``.
+
+    :param tuple args: the arguments to create the wrapped \
+        :py:class:`SSL.Connection(*args) \
+        `
+    """
+
+    def __init__(self, *args):
+        """Initialize SSLConnection instance."""
+        self._ssl_conn = SSL.Connection(*args)
+        self._lock = threading.RLock()
+
+
+class pyOpenSSLAdapter(Adapter):
+    """A wrapper for integrating pyOpenSSL with Cheroot."""
+
+    certificate = None
+    """The file name of the server's TLS certificate."""
+
+    private_key = None
+    """The file name of the server's private key file."""
+
+    certificate_chain = None
+    """Optional. The file name of CA's intermediate certificate bundle.
+
+    This is needed for cheaper "chained root" TLS certificates,
+    and should be left as :py:data:`None` if not required."""
+
+    context = None
+    """
+    An instance of :py:class:`SSL.Context `.
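+    If this is left as :py:data:`None`, ``bind()`` builds a context from
+    ``certificate`` and ``private_key`` via ``get_context()`` (see below).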
+    """
+
+    ciphers = None
+    """The ciphers list of TLS."""
+
+    def __init__(
+            self, certificate, private_key, certificate_chain=None,
+            ciphers=None,
+    ):
+        """Initialize OpenSSL Adapter instance."""
+        if SSL is None:
+            raise ImportError('You must install pyOpenSSL to use HTTPS.')
+
+        super(pyOpenSSLAdapter, self).__init__(
+            certificate, private_key, certificate_chain, ciphers,
+        )
+
+        self._environ = None
+
+    def bind(self, sock):
+        """Wrap and return the given socket."""
+        if self.context is None:
+            self.context = self.get_context()
+        conn = SSLConnection(self.context, sock)
+        self._environ = self.get_environ()
+        return conn
+
+    def wrap(self, sock):
+        """Wrap and return the given socket, plus WSGI environ entries."""
+        # pyOpenSSL doesn't perform the handshake until the first read/write
+        # forcing the handshake to complete tends to result in the connection
+        # closing so we can't reliably access protocol/client cert for the env
+        return sock, self._environ.copy()
+
+    def get_context(self):
+        """Return an ``SSL.Context`` from self attributes.
+
+        Ref: :py:class:`SSL.Context `
+        """
+        # See https://code.activestate.com/recipes/442473/
+        c = SSL.Context(SSL.SSLv23_METHOD)
+        c.use_privatekey_file(self.private_key)
+        if self.certificate_chain:
+            c.load_verify_locations(self.certificate_chain)
+        c.use_certificate_file(self.certificate)
+        return c
+
+    def get_environ(self):
+        """Return WSGI environ entries to be merged into each request."""
+        ssl_environ = {
+            'wsgi.url_scheme': 'https',
+            'HTTPS': 'on',
+            'SSL_VERSION_INTERFACE': '%s %s/%s Python/%s' % (
+                cheroot_server.HTTPServer.version,
+                OpenSSL.version.__title__, OpenSSL.version.__version__,
+                sys.version,
+            ),
+            'SSL_VERSION_LIBRARY': SSL.SSLeay_version(
+                SSL.SSLEAY_VERSION,
+            ).decode(),
+        }
+
+        if self.certificate:
+            # Server certificate attributes
+            with open(self.certificate, 'rb') as cert_file:
+                cert = crypto.load_certificate(
+                    crypto.FILETYPE_PEM, cert_file.read(),
+                )
+
+            ssl_environ.update({
+                'SSL_SERVER_M_VERSION': cert.get_version(),
+                'SSL_SERVER_M_SERIAL': cert.get_serial_number(),
+                # 'SSL_SERVER_V_START':
+                #   Validity of server's certificate (start time),
+                # 'SSL_SERVER_V_END':
+                #   Validity of server's certificate (end time),
+            })
+
+            for prefix, dn in [
+                    ('I', cert.get_issuer()),
+                    ('S', cert.get_subject()),
+            ]:
+                # X509Name objects don't seem to have a way to get the
+                # complete DN string. Use str() and slice it instead,
+                # because str(dn) == "<X509Name object '/C=US/ST=...'>"
+                dnstr = str(dn)[18:-2]
+
+                wsgikey = 'SSL_SERVER_%s_DN' % prefix
+                ssl_environ[wsgikey] = dnstr
+
+                # The DN should be of the form: /k1=v1/k2=v2, but we must allow
+                # for any value to contain slashes itself (in a URL).
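+                # e.g. dnstr == '/C=US/ST=CA/O=Example/CN=www.example.com';
+                # scanning right-to-left with rfind() keeps any '/' that
+                # appears inside a value attached to that value.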
+ while dnstr: + pos = dnstr.rfind('=') + dnstr, value = dnstr[:pos], dnstr[pos + 1:] + pos = dnstr.rfind('/') + dnstr, key = dnstr[:pos], dnstr[pos + 1:] + if key and value: + wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key) + ssl_environ[wsgikey] = value + + return ssl_environ + + def makefile(self, sock, mode='r', bufsize=-1): + """Return socket file object.""" + cls = ( + SSLFileobjectStreamReader + if 'r' in mode else + SSLFileobjectStreamWriter + ) + if SSL and isinstance(sock, ssl_conn_type): + wrapped_socket = cls(sock, mode, bufsize) + wrapped_socket.ssl_timeout = sock.gettimeout() + return wrapped_socket + # This is from past: + # TODO: figure out what it's meant for + else: + return cheroot_server.CP_fileobject(sock, mode, bufsize) diff --git a/resources/lib/cheroot/testing.py b/resources/lib/cheroot/testing.py new file mode 100644 index 0000000..c9a6ac9 --- /dev/null +++ b/resources/lib/cheroot/testing.py @@ -0,0 +1,153 @@ +"""Pytest fixtures and other helpers for doing testing by end-users.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from contextlib import closing +import errno +import socket +import threading +import time + +import pytest +from six.moves import http_client + +import cheroot.server +from cheroot.test import webtest +import cheroot.wsgi + +EPHEMERAL_PORT = 0 +NO_INTERFACE = None # Using this or '' will cause an exception +ANY_INTERFACE_IPV4 = '0.0.0.0' +ANY_INTERFACE_IPV6 = '::' + +config = { + cheroot.wsgi.Server: { + 'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT), + 'wsgi_app': None, + }, + cheroot.server.HTTPServer: { + 'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT), + 'gateway': cheroot.server.Gateway, + }, +} + + +def cheroot_server(server_factory): + """Set up and tear down a Cheroot server instance.""" + conf = config[server_factory].copy() + bind_port = conf.pop('bind_addr')[-1] + + for interface in ANY_INTERFACE_IPV6, ANY_INTERFACE_IPV4: + try: + actual_bind_addr = (interface, bind_port) + httpserver = server_factory( # create it + bind_addr=actual_bind_addr, + **conf + ) + except OSError: + pass + else: + break + + httpserver.shutdown_timeout = 0 # Speed-up tests teardown + + threading.Thread(target=httpserver.safe_start).start() # spawn it + while not httpserver.ready: # wait until fully initialized and bound + time.sleep(0.1) + + yield httpserver + + httpserver.stop() # destroy it + + +@pytest.fixture +def wsgi_server(): + """Set up and tear down a Cheroot WSGI server instance.""" + for srv in cheroot_server(cheroot.wsgi.Server): + yield srv + + +@pytest.fixture +def native_server(): + """Set up and tear down a Cheroot HTTP server instance.""" + for srv in cheroot_server(cheroot.server.HTTPServer): + yield srv + + +class _TestClient: + def __init__(self, server): + self._interface, self._host, self._port = _get_conn_data( + server.bind_addr, + ) + self.server_instance = server + self._http_connection = self.get_connection() + + def get_connection(self): + name = '{interface}:{port}'.format( + interface=self._interface, + port=self._port, + ) + conn_cls = ( + http_client.HTTPConnection + if self.server_instance.ssl_adapter is None else + http_client.HTTPSConnection + ) + return conn_cls(name) + + def request( + self, uri, method='GET', headers=None, http_conn=None, + protocol='HTTP/1.1', + ): + return webtest.openURL( + uri, method=method, + headers=headers, + host=self._host, port=self._port, + http_conn=http_conn or self._http_connection, + protocol=protocol, + ) + + def __getattr__(self, attr_name): + 
def _wrapper(uri, **kwargs): + http_method = attr_name.upper() + return self.request(uri, method=http_method, **kwargs) + + return _wrapper + + +def _probe_ipv6_sock(interface): + # Alternate way is to check IPs on interfaces using glibc, like: + # github.com/Gautier/minifail/blob/master/minifail/getifaddrs.py + try: + with closing(socket.socket(family=socket.AF_INET6)) as sock: + sock.bind((interface, 0)) + except (OSError, socket.error) as sock_err: + # In Python 3 socket.error is an alias for OSError + # In Python 2 socket.error is a subclass of IOError + if sock_err.errno != errno.EADDRNOTAVAIL: + raise + else: + return True + + return False + + +def _get_conn_data(bind_addr): + if isinstance(bind_addr, tuple): + host, port = bind_addr + else: + host, port = bind_addr, 0 + + interface = webtest.interface(host) + + if ':' in interface and not _probe_ipv6_sock(interface): + interface = '127.0.0.1' + if ':' in host: + host = interface + + return interface, host, port + + +def get_server_client(server): + """Create and return a test client for the given server.""" + return _TestClient(server) diff --git a/resources/lib/cheroot/workers/__init__.py b/resources/lib/cheroot/workers/__init__.py new file mode 100644 index 0000000..098b8f2 --- /dev/null +++ b/resources/lib/cheroot/workers/__init__.py @@ -0,0 +1 @@ +"""HTTP workers pool.""" diff --git a/resources/lib/cheroot/workers/threadpool.py b/resources/lib/cheroot/workers/threadpool.py new file mode 100644 index 0000000..915934c --- /dev/null +++ b/resources/lib/cheroot/workers/threadpool.py @@ -0,0 +1,329 @@ +"""A thread-based worker pool. + +.. spelling:: + + joinable +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +import collections +import threading +import time +import socket +import warnings + +from six.moves import queue + +from jaraco.functools import pass_none + + +__all__ = ('WorkerThread', 'ThreadPool') + + +class TrueyZero: + """Object which equals and does math like the integer 0 but evals True.""" + + def __add__(self, other): + return other + + def __radd__(self, other): + return other + + +trueyzero = TrueyZero() + +_SHUTDOWNREQUEST = None + + +class WorkerThread(threading.Thread): + """Thread which continuously polls a Queue for Connection objects. + + Due to the timing issues of polling a Queue, a WorkerThread does not + check its own 'ready' flag after it has started. To stop the thread, + it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue + (one for each running WorkerThread). + """ + + conn = None + """The current connection pulled off the Queue, or None.""" + + server = None + """The HTTP Server which spawned this thread, and which owns the + Queue and is placing active connections into it.""" + + ready = False + """A simple flag for the calling server to know when this thread + has begun polling the Queue.""" + + def __init__(self, server): + """Initialize WorkerThread instance. 
+ + Args: + server (cheroot.server.HTTPServer): web server object + receiving this request + """ + self.ready = False + self.server = server + + self.requests_seen = 0 + self.bytes_read = 0 + self.bytes_written = 0 + self.start_time = None + self.work_time = 0 + self.stats = { + 'Requests': lambda s: self.requests_seen + ( + self.start_time is None + and trueyzero + or self.conn.requests_seen + ), + 'Bytes Read': lambda s: self.bytes_read + ( + self.start_time is None + and trueyzero + or self.conn.rfile.bytes_read + ), + 'Bytes Written': lambda s: self.bytes_written + ( + self.start_time is None + and trueyzero + or self.conn.wfile.bytes_written + ), + 'Work Time': lambda s: self.work_time + ( + self.start_time is None + and trueyzero + or time.time() - self.start_time + ), + 'Read Throughput': lambda s: s['Bytes Read'](s) / ( + s['Work Time'](s) or 1e-6 + ), + 'Write Throughput': lambda s: s['Bytes Written'](s) / ( + s['Work Time'](s) or 1e-6 + ), + } + threading.Thread.__init__(self) + + def run(self): + """Process incoming HTTP connections. + + Retrieves incoming connections from thread pool. + """ + self.server.stats['Worker Threads'][self.getName()] = self.stats + try: + self.ready = True + while True: + conn = self.server.requests.get() + if conn is _SHUTDOWNREQUEST: + return + + self.conn = conn + is_stats_enabled = self.server.stats['Enabled'] + if is_stats_enabled: + self.start_time = time.time() + keep_conn_open = False + try: + keep_conn_open = conn.communicate() + finally: + if keep_conn_open: + self.server.put_conn(conn) + else: + conn.close() + if is_stats_enabled: + self.requests_seen += self.conn.requests_seen + self.bytes_read += self.conn.rfile.bytes_read + self.bytes_written += self.conn.wfile.bytes_written + self.work_time += time.time() - self.start_time + self.start_time = None + self.conn = None + except (KeyboardInterrupt, SystemExit) as ex: + self.server.interrupt = ex + + +class ThreadPool: + """A Request Queue for an HTTPServer which pools threads. + + ThreadPool objects must provide min, get(), put(obj), start() + and stop(timeout) attributes. + """ + + def __init__( + self, server, min=10, max=-1, accepted_queue_size=-1, + accepted_queue_timeout=10, + ): + """Initialize HTTP requests queue instance. + + Args: + server (cheroot.server.HTTPServer): web server object + receiving this request + min (int): minimum number of worker threads + max (int): maximum number of worker threads + accepted_queue_size (int): maximum number of active + requests in queue + accepted_queue_timeout (int): timeout for putting request + into queue + """ + self.server = server + self.min = min + self.max = max + self._threads = [] + self._queue = queue.Queue(maxsize=accepted_queue_size) + self._queue_put_timeout = accepted_queue_timeout + self.get = self._queue.get + self._pending_shutdowns = collections.deque() + + def start(self): + """Start the pool of threads.""" + for i in range(self.min): + self._threads.append(WorkerThread(self.server)) + for worker in self._threads: + worker.setName( + 'CP Server {worker_name!s}'. + format(worker_name=worker.getName()), + ) + worker.start() + for worker in self._threads: + while not worker.ready: + time.sleep(.1) + + @property + def idle(self): # noqa: D401; irrelevant for properties + """Number of worker threads which are idle. Read-only.""" + idles = len([t for t in self._threads if t.conn is None]) + return max(idles - len(self._pending_shutdowns), 0) + + def put(self, obj): + """Put request into queue. 
+
+        Args:
+            obj (:py:class:`~cheroot.server.HTTPConnection`): HTTP connection
+                waiting to be processed
+        """
+        self._queue.put(obj, block=True, timeout=self._queue_put_timeout)
+
+    def _clear_dead_threads(self):
+        # Remove any dead threads from our list
+        for t in [t for t in self._threads if not t.is_alive()]:
+            self._threads.remove(t)
+            try:
+                self._pending_shutdowns.popleft()
+            except IndexError:
+                pass
+
+    def grow(self, amount):
+        """Spawn new worker threads (not above self.max)."""
+        if self.max > 0:
+            budget = max(self.max - len(self._threads), 0)
+        else:
+            # self.max <= 0 indicates no maximum
+            budget = float('inf')
+
+        n_new = min(amount, budget)
+
+        workers = [self._spawn_worker() for i in range(n_new)]
+        while not all(worker.ready for worker in workers):
+            time.sleep(.1)
+        self._threads.extend(workers)
+
+    def _spawn_worker(self):
+        worker = WorkerThread(self.server)
+        worker.setName(
+            'CP Server {worker_name!s}'.
+            format(worker_name=worker.getName()),
+        )
+        worker.start()
+        return worker
+
+    def shrink(self, amount):
+        """Kill off worker threads (not below self.min)."""
+        # Grow/shrink the pool if necessary.
+        # Remove any dead threads from our list
+        amount -= len(self._pending_shutdowns)
+        self._clear_dead_threads()
+        if amount <= 0:
+            return
+
+        # calculate the number of threads above the minimum
+        n_extra = max(len(self._threads) - self.min, 0)
+
+        # don't remove more than amount
+        n_to_remove = min(amount, n_extra)
+
+        # put shutdown requests on the queue equal to the number of threads
+        # to remove. As each request is processed by a worker, that worker
+        # will terminate and be culled from the list.
+        for n in range(n_to_remove):
+            self._pending_shutdowns.append(None)
+            self._queue.put(_SHUTDOWNREQUEST)
+
+    def stop(self, timeout=5):
+        """Terminate all worker threads.
+
+        Args:
+            timeout (int): time to wait for threads to stop gracefully
+        """
+        # for compatibility, negative timeouts are treated like None
+        # TODO: treat negative timeouts like already expired timeouts
+        if timeout is not None and timeout < 0:
+            timeout = None
+            warnings.warn(
+                'In the future, negative timeouts to Server.stop() '
+                'will be equivalent to a timeout of zero.',
+                stacklevel=2,
+            )
+
+        if timeout is not None:
+            endtime = time.time() + timeout
+
+        # Must shut down threads here so the code that calls
+        # this method can know when all threads are stopped.
+        for worker in self._threads:
+            self._queue.put(_SHUTDOWNREQUEST)
+
+        ignored_errors = (
+            # TODO: explain this exception.
+            AssertionError,
+            # Ignore repeated Ctrl-C. See cherrypy#691.
+            KeyboardInterrupt,
+        )
+
+        for worker in self._clear_threads():
+            remaining_time = timeout and endtime - time.time()
+            try:
+                worker.join(remaining_time)
+                if worker.is_alive():
+                    # Timeout exhausted; forcibly shut down the socket.
+                    self._force_close(worker.conn)
+                    worker.join()
+            except ignored_errors:
+                pass
+
+    @staticmethod
+    @pass_none
+    def _force_close(conn):
+        if conn.rfile.closed:
+            return
+        try:
+            try:
+                conn.socket.shutdown(socket.SHUT_RD)
+            except TypeError:
+                # pyOpenSSL sockets don't take an arg
+                conn.socket.shutdown()
+        except OSError:
+            # shutdown sometimes fails (race with 'closed' check?)
+ # ref #238 + pass + + def _clear_threads(self): + """Clear self._threads and yield all joinable threads.""" + # threads = pop_all(self._threads) + threads, self._threads[:] = self._threads[:], [] + return ( + thread + for thread in threads + if thread is not threading.currentThread() + ) + + @property + def qsize(self): + """Return the queue size.""" + return self._queue.qsize() diff --git a/resources/lib/cheroot/wsgi.py b/resources/lib/cheroot/wsgi.py new file mode 100644 index 0000000..6635f52 --- /dev/null +++ b/resources/lib/cheroot/wsgi.py @@ -0,0 +1,435 @@ +"""This class holds Cheroot WSGI server implementation. + +Simplest example on how to use this server:: + + from cheroot import wsgi + + def my_crazy_app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type','text/plain')] + start_response(status, response_headers) + return [b'Hello world!'] + + addr = '0.0.0.0', 8070 + server = wsgi.Server(addr, my_crazy_app) + server.start() + +The Cheroot WSGI server can serve as many WSGI applications +as you want in one instance by using a PathInfoDispatcher:: + + path_map = { + '/': my_crazy_app, + '/blog': my_blog_app, + } + d = wsgi.PathInfoDispatcher(path_map) + server = wsgi.Server(addr, d) +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import sys + +import six +from six.moves import filter + +from . import server +from .workers import threadpool +from ._compat import ntob, bton + + +class Server(server.HTTPServer): + """A subclass of HTTPServer which calls a WSGI application.""" + + wsgi_version = (1, 0) + """The version of WSGI to produce.""" + + def __init__( + self, bind_addr, wsgi_app, numthreads=10, server_name=None, + max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5, + accepted_queue_size=-1, accepted_queue_timeout=10, + peercreds_enabled=False, peercreds_resolve_enabled=False, + ): + """Initialize WSGI Server instance. + + Args: + bind_addr (tuple): network interface to listen to + wsgi_app (callable): WSGI application callable + numthreads (int): number of threads for WSGI thread pool + server_name (str): web server name to be advertised via + Server HTTP header + max (int): maximum number of worker threads + request_queue_size (int): the 'backlog' arg to + socket.listen(); max queued connections + timeout (int): the timeout in seconds for accepted connections + shutdown_timeout (int): the total time, in seconds, to + wait for worker threads to cleanly exit + accepted_queue_size (int): maximum number of active + requests in queue + accepted_queue_timeout (int): timeout for putting request + into queue + """ + super(Server, self).__init__( + bind_addr, + gateway=wsgi_gateways[self.wsgi_version], + server_name=server_name, + peercreds_enabled=peercreds_enabled, + peercreds_resolve_enabled=peercreds_resolve_enabled, + ) + self.wsgi_app = wsgi_app + self.request_queue_size = request_queue_size + self.timeout = timeout + self.shutdown_timeout = shutdown_timeout + self.requests = threadpool.ThreadPool( + self, min=numthreads or 1, max=max, + accepted_queue_size=accepted_queue_size, + accepted_queue_timeout=accepted_queue_timeout, + ) + + @property + def numthreads(self): + """Set minimum number of threads.""" + return self.requests.min + + @numthreads.setter + def numthreads(self, value): + self.requests.min = value + + +class Gateway(server.Gateway): + """A base class to interface HTTPServer with WSGI.""" + + def __init__(self, req): + """Initialize WSGI Gateway instance with request. 
+ + Args: + req (HTTPRequest): current HTTP request + """ + super(Gateway, self).__init__(req) + self.started_response = False + self.env = self.get_environ() + self.remaining_bytes_out = None + + @classmethod + def gateway_map(cls): + """Create a mapping of gateways and their versions. + + Returns: + dict[tuple[int,int],class]: map of gateway version and + corresponding class + + """ + return {gw.version: gw for gw in cls.__subclasses__()} + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version.""" + raise NotImplementedError # pragma: no cover + + def respond(self): + """Process the current request. + + From :pep:`333`: + + The start_response callable must not actually transmit + the response headers. Instead, it must store them for the + server or gateway to transmit only after the first + iteration of the application return value that yields + a NON-EMPTY string, or upon the application's first + invocation of the write() callable. + """ + response = self.req.server.wsgi_app(self.env, self.start_response) + try: + for chunk in filter(None, response): + if not isinstance(chunk, six.binary_type): + raise ValueError('WSGI Applications must yield bytes') + self.write(chunk) + finally: + # Send headers if not already sent + self.req.ensure_headers_sent() + if hasattr(response, 'close'): + response.close() + + def start_response(self, status, headers, exc_info=None): + """WSGI callable to begin the HTTP response.""" + # "The application may call start_response more than once, + # if and only if the exc_info argument is provided." + if self.started_response and not exc_info: + raise AssertionError( + 'WSGI start_response called a second ' + 'time with no exc_info.', + ) + self.started_response = True + + # "if exc_info is provided, and the HTTP headers have already been + # sent, start_response must raise an error, and should raise the + # exc_info tuple." + if self.req.sent_headers: + try: + six.reraise(*exc_info) + finally: + exc_info = None + + self.req.status = self._encode_status(status) + + for k, v in headers: + if not isinstance(k, str): + raise TypeError( + 'WSGI response header key %r is not of type str.' % k, + ) + if not isinstance(v, str): + raise TypeError( + 'WSGI response header value %r is not of type str.' % v, + ) + if k.lower() == 'content-length': + self.remaining_bytes_out = int(v) + out_header = ntob(k), ntob(v) + self.req.outheaders.append(out_header) + + return self.write + + @staticmethod + def _encode_status(status): + """Cast status to bytes representation of current Python version. + + According to :pep:`3333`, when using Python 3, the response status + and headers must be bytes masquerading as Unicode; that is, they + must be of type "str" but are restricted to code points in the + "Latin-1" set. + """ + if six.PY2: + return status + if not isinstance(status, str): + raise TypeError('WSGI response status is not of type str.') + return status.encode('ISO-8859-1') + + def write(self, chunk): + """WSGI callable to write unbuffered data to the client. + + This method is also used internally by start_response (to write + data from the iterable returned by the WSGI application). + """ + if not self.started_response: + raise AssertionError('WSGI write called before start_response.') + + chunklen = len(chunk) + rbo = self.remaining_bytes_out + if rbo is not None and chunklen > rbo: + if not self.req.sent_headers: + # Whew. We can send a 500 to the client. 
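+                # Nothing has been written to the wire yet, so the
+                # declared response can still be replaced wholesale
+                # with a well-formed error response.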
+ self.req.simple_response( + '500 Internal Server Error', + 'The requested resource returned more bytes than the ' + 'declared Content-Length.', + ) + else: + # Dang. We have probably already sent data. Truncate the chunk + # to fit (so the client doesn't hang) and raise an error later. + chunk = chunk[:rbo] + + self.req.ensure_headers_sent() + + self.req.write(chunk) + + if rbo is not None: + rbo -= chunklen + if rbo < 0: + raise ValueError( + 'Response body exceeds the declared Content-Length.', + ) + + +class Gateway_10(Gateway): + """A Gateway class to interface HTTPServer with WSGI 1.0.x.""" + + version = 1, 0 + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version.""" + req = self.req + req_conn = req.conn + env = { + # set a non-standard environ entry so the WSGI app can know what + # the *real* server protocol is (and what features to support). + # See http://www.faqs.org/rfcs/rfc2145.html. + 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, + 'PATH_INFO': bton(req.path), + 'QUERY_STRING': bton(req.qs), + 'REMOTE_ADDR': req_conn.remote_addr or '', + 'REMOTE_PORT': str(req_conn.remote_port or ''), + 'REQUEST_METHOD': bton(req.method), + 'REQUEST_URI': bton(req.uri), + 'SCRIPT_NAME': '', + 'SERVER_NAME': req.server.server_name, + # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. + 'SERVER_PROTOCOL': bton(req.request_protocol), + 'SERVER_SOFTWARE': req.server.software, + 'wsgi.errors': sys.stderr, + 'wsgi.input': req.rfile, + 'wsgi.input_terminated': bool(req.chunked_read), + 'wsgi.multiprocess': False, + 'wsgi.multithread': True, + 'wsgi.run_once': False, + 'wsgi.url_scheme': bton(req.scheme), + 'wsgi.version': self.version, + } + + if isinstance(req.server.bind_addr, six.string_types): + # AF_UNIX. This isn't really allowed by WSGI, which doesn't + # address unix domain sockets. But it's better than nothing. + env['SERVER_PORT'] = '' + try: + env['X_REMOTE_PID'] = str(req_conn.peer_pid) + env['X_REMOTE_UID'] = str(req_conn.peer_uid) + env['X_REMOTE_GID'] = str(req_conn.peer_gid) + + env['X_REMOTE_USER'] = str(req_conn.peer_user) + env['X_REMOTE_GROUP'] = str(req_conn.peer_group) + + env['REMOTE_USER'] = env['X_REMOTE_USER'] + except RuntimeError: + """Unable to retrieve peer creds data. + + Unsupported by current kernel or socket error happened, or + unsupported socket type, or disabled. + """ + else: + env['SERVER_PORT'] = str(req.server.bind_addr[1]) + + # Request headers + env.update( + ( + 'HTTP_{header_name!s}'. + format(header_name=bton(k).upper().replace('-', '_')), + bton(v), + ) + for k, v in req.inheaders.items() + ) + + # CONTENT_TYPE/CONTENT_LENGTH + ct = env.pop('HTTP_CONTENT_TYPE', None) + if ct is not None: + env['CONTENT_TYPE'] = ct + cl = env.pop('HTTP_CONTENT_LENGTH', None) + if cl is not None: + env['CONTENT_LENGTH'] = cl + + if req.conn.ssl_env: + env.update(req.conn.ssl_env) + + return env + + +class Gateway_u0(Gateway_10): + """A Gateway class to interface HTTPServer with WSGI u.0. + + WSGI u.0 is an experimental protocol, which uses Unicode for keys + and values in both Python 2 and Python 3. 
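+    On Python 2, byte keys and values are decoded to Unicode using
+    ISO-8859-1 (see ``_decode_key`` and ``_decode_value`` below).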
+ """ + + version = 'u', 0 + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version.""" + req = self.req + env_10 = super(Gateway_u0, self).get_environ() + env = dict(map(self._decode_key, env_10.items())) + + # Request-URI + enc = env.setdefault(six.u('wsgi.url_encoding'), six.u('utf-8')) + try: + env['PATH_INFO'] = req.path.decode(enc) + env['QUERY_STRING'] = req.qs.decode(enc) + except UnicodeDecodeError: + # Fall back to latin 1 so apps can transcode if needed. + env['wsgi.url_encoding'] = 'ISO-8859-1' + env['PATH_INFO'] = env_10['PATH_INFO'] + env['QUERY_STRING'] = env_10['QUERY_STRING'] + + env.update(map(self._decode_value, env.items())) + + return env + + @staticmethod + def _decode_key(item): + k, v = item + if six.PY2: + k = k.decode('ISO-8859-1') + return k, v + + @staticmethod + def _decode_value(item): + k, v = item + skip_keys = 'REQUEST_URI', 'wsgi.input' + if not six.PY2 or not isinstance(v, bytes) or k in skip_keys: + return k, v + return k, v.decode('ISO-8859-1') + + +wsgi_gateways = Gateway.gateway_map() + + +class PathInfoDispatcher: + """A WSGI dispatcher for dispatch based on the PATH_INFO.""" + + def __init__(self, apps): + """Initialize path info WSGI app dispatcher. + + Args: + apps (dict[str,object]|list[tuple[str,object]]): URI prefix + and WSGI app pairs + """ + try: + apps = list(apps.items()) + except AttributeError: + pass + + # Sort the apps by len(path), descending + def by_path_len(app): + return len(app[0]) + apps.sort(key=by_path_len, reverse=True) + + # The path_prefix strings must start, but not end, with a slash. + # Use "" instead of "/". + self.apps = [(p.rstrip('/'), a) for p, a in apps] + + def __call__(self, environ, start_response): + """Process incoming WSGI request. + + Ref: :pep:`3333` + + Args: + environ (Mapping): a dict containing WSGI environment variables + start_response (callable): function, which sets response + status and headers + + Returns: + list[bytes]: iterable containing bytes to be returned in + HTTP response body + + """ + path = environ['PATH_INFO'] or '/' + for p, app in self.apps: + # The apps list should be sorted by length, descending. + if path.startswith('{path!s}/'.format(path=p)) or path == p: + environ = environ.copy() + environ['SCRIPT_NAME'] = environ.get('SCRIPT_NAME', '') + p + environ['PATH_INFO'] = path[len(p):] + return app(environ, start_response) + + start_response( + '404 Not Found', [ + ('Content-Type', 'text/plain'), + ('Content-Length', '0'), + ], + ) + return [''] + + +# compatibility aliases +globals().update( + WSGIServer=Server, + WSGIGateway=Gateway, + WSGIGateway_u0=Gateway_u0, + WSGIGateway_10=Gateway_10, + WSGIPathInfoDispatcher=PathInfoDispatcher, +) diff --git a/resources/lib/cherrypy/__init__.py b/resources/lib/cherrypy/__init__.py new file mode 100644 index 0000000..84481d5 --- /dev/null +++ b/resources/lib/cherrypy/__init__.py @@ -0,0 +1,370 @@ +"""CherryPy is a pythonic, object-oriented HTTP framework. + +CherryPy consists of not one, but four separate API layers. + +The APPLICATION LAYER is the simplest. CherryPy applications are written as +a tree of classes and methods, where each branch in the tree corresponds to +a branch in the URL path. Each method is a 'page handler', which receives +GET and POST params as keyword arguments, and returns or yields the (HTML) +body of the response. The special method name 'index' is used for paths +that end in a slash, and the special method name 'default' is used to +handle multiple paths via a single handler. 
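+For instance, a minimal application along these lines (a sketch of the
+pattern just described) maps '/' to Root.index()::
+
+    import cherrypy
+
+    class Root:
+        @cherrypy.expose
+        def index(self):
+            return 'Hello world!'
+
+    cherrypy.quickstart(Root())
+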
+This layer also includes:
+
+ * the 'exposed' attribute (and cherrypy.expose)
+ * cherrypy.quickstart()
+ * _cp_config attributes
+ * cherrypy.tools (including cherrypy.session)
+ * cherrypy.url()
+
+The ENVIRONMENT LAYER is used by developers at all levels. It provides
+information about the current request and response, plus the application
+and server environment, via a (default) set of top-level objects:
+
+ * cherrypy.request
+ * cherrypy.response
+ * cherrypy.engine
+ * cherrypy.server
+ * cherrypy.tree
+ * cherrypy.config
+ * cherrypy.thread_data
+ * cherrypy.log
+ * cherrypy.HTTPError, NotFound, and HTTPRedirect
+ * cherrypy.lib
+
+The EXTENSION LAYER allows advanced users to construct and share their own
+plugins. It consists of:
+
+ * Hook API
+ * Tool API
+ * Toolbox API
+ * Dispatch API
+ * Config Namespace API
+
+Finally, there is the CORE LAYER, which uses the core APIs to construct
+the default components which are available at higher layers. You can think
+of the default components as the 'reference implementation' for CherryPy.
+Megaframeworks (and advanced users) may replace the default components
+with customized or extended components. The core APIs are:
+
+ * Application API
+ * Engine API
+ * Request API
+ * Server API
+ * WSGI API
+
+These APIs are described in the `CherryPy specification
+`_.
+"""
+
+try:
+    import pkg_resources
+except ImportError:
+    pass
+
+from threading import local as _local
+
+from ._cperror import (
+    HTTPError, HTTPRedirect, InternalRedirect,
+    NotFound, CherryPyException,
+)
+
+from . import _cpdispatch as dispatch
+
+from ._cptools import default_toolbox as tools, Tool
+from ._helper import expose, popargs, url
+
+from . import _cprequest, _cpserver, _cptree, _cplogging, _cpconfig
+
+import cherrypy.lib.httputil as _httputil
+
+from ._cptree import Application
+from . import _cpwsgi as wsgi
+
+from . import process
+try:
+    from .process import win32
+    engine = win32.Win32Bus()
+    engine.console_control_handler = win32.ConsoleCtrlHandler(engine)
+    del win32
+except ImportError:
+    engine = process.bus
+
+from . import _cpchecker
+
+__all__ = (
+    'HTTPError', 'HTTPRedirect', 'InternalRedirect',
+    'NotFound', 'CherryPyException',
+    'dispatch', 'tools', 'Tool', 'Application',
+    'wsgi', 'process', 'tree', 'engine',
+    'quickstart', 'serving', 'request', 'response', 'thread_data',
+    'log', 'expose', 'popargs', 'url', 'config',
)
+
+
+__import__('cherrypy._cptools')
+__import__('cherrypy._cprequest')
+
+
+tree = _cptree.Tree()
+
+
+try:
+    __version__ = pkg_resources.require('cherrypy')[0].version
+except Exception:
+    __version__ = '18.6.0'
+
+
+engine.listeners['before_request'] = set()
+engine.listeners['after_request'] = set()
+
+
+engine.autoreload = process.plugins.Autoreloader(engine)
+engine.autoreload.subscribe()
+
+engine.thread_manager = process.plugins.ThreadManager(engine)
+engine.thread_manager.subscribe()
+
+engine.signal_handler = process.plugins.SignalHandler(engine)
+
+
+class _HandleSignalsPlugin(object):
+    """Handle signals from other processes.
+
+    Based on the configured platform handlers above.
+ """ + + def __init__(self, bus): + self.bus = bus + + def subscribe(self): + """Add the handlers based on the platform.""" + if hasattr(self.bus, 'signal_handler'): + self.bus.signal_handler.subscribe() + if hasattr(self.bus, 'console_control_handler'): + self.bus.console_control_handler.subscribe() + + +engine.signals = _HandleSignalsPlugin(engine) + + +server = _cpserver.Server() +server.subscribe() + + +def quickstart(root=None, script_name='', config=None): + """Mount the given root, start the builtin server (and engine), then block. + + root: an instance of a "controller class" (a collection of page handler + methods) which represents the root of the application. + script_name: a string containing the "mount point" of the application. + This should start with a slash, and be the path portion of the URL + at which to mount the given root. For example, if root.index() will + handle requests to "http://www.example.com:8080/dept/app1/", then + the script_name argument would be "/dept/app1". + + It MUST NOT end in a slash. If the script_name refers to the root + of the URI, it MUST be an empty string (not "/"). + config: a file or dict containing application config. If this contains + a [global] section, those entries will be used in the global + (site-wide) config. + """ + if config: + _global_conf_alias.update(config) + + tree.mount(root, script_name, config) + + engine.signals.subscribe() + engine.start() + engine.block() + + +class _Serving(_local): + """An interface for registering request and response objects. + + Rather than have a separate "thread local" object for the request and + the response, this class works as a single threadlocal container for + both objects (and any others which developers wish to define). In this + way, we can easily dump those objects when we stop/start a new HTTP + conversation, yet still refer to them as module-level globals in a + thread-safe way. + """ + + request = _cprequest.Request(_httputil.Host('127.0.0.1', 80), + _httputil.Host('127.0.0.1', 1111)) + """ + The request object for the current thread. In the main thread, + and any threads which are not receiving HTTP requests, this is None.""" + + response = _cprequest.Response() + """ + The response object for the current thread. 
In the main thread, + and any threads which are not receiving HTTP requests, this is None.""" + + def load(self, request, response): + self.request = request + self.response = response + + def clear(self): + """Remove all attributes of self.""" + self.__dict__.clear() + + +serving = _Serving() + + +class _ThreadLocalProxy(object): + + __slots__ = ['__attrname__', '__dict__'] + + def __init__(self, attrname): + self.__attrname__ = attrname + + def __getattr__(self, name): + child = getattr(serving, self.__attrname__) + return getattr(child, name) + + def __setattr__(self, name, value): + if name in ('__attrname__', ): + object.__setattr__(self, name, value) + else: + child = getattr(serving, self.__attrname__) + setattr(child, name, value) + + def __delattr__(self, name): + child = getattr(serving, self.__attrname__) + delattr(child, name) + + @property + def __dict__(self): + child = getattr(serving, self.__attrname__) + d = child.__class__.__dict__.copy() + d.update(child.__dict__) + return d + + def __getitem__(self, key): + child = getattr(serving, self.__attrname__) + return child[key] + + def __setitem__(self, key, value): + child = getattr(serving, self.__attrname__) + child[key] = value + + def __delitem__(self, key): + child = getattr(serving, self.__attrname__) + del child[key] + + def __contains__(self, key): + child = getattr(serving, self.__attrname__) + return key in child + + def __len__(self): + child = getattr(serving, self.__attrname__) + return len(child) + + def __nonzero__(self): + child = getattr(serving, self.__attrname__) + return bool(child) + # Python 3 + __bool__ = __nonzero__ + + +# Create request and response object (the same objects will be used +# throughout the entire life of the webserver, but will redirect +# to the "serving" object) +request = _ThreadLocalProxy('request') +response = _ThreadLocalProxy('response') + +# Create thread_data object as a thread-specific all-purpose storage + + +class _ThreadData(_local): + """A container for thread-specific data.""" + + +thread_data = _ThreadData() + + +# Monkeypatch pydoc to allow help() to go through the threadlocal proxy. +# Jan 2007: no Googleable examples of anyone else replacing pydoc.resolve. +# The only other way would be to change what is returned from type(request) +# and that's not possible in pure Python (you'd have to fake ob_type). +def _cherrypy_pydoc_resolve(thing, forceload=0): + """Given an object or a path to an object, get the object and its name.""" + if isinstance(thing, _ThreadLocalProxy): + thing = getattr(serving, thing.__attrname__) + return _pydoc._builtin_resolve(thing, forceload) + + +try: + import pydoc as _pydoc + _pydoc._builtin_resolve = _pydoc.resolve + _pydoc.resolve = _cherrypy_pydoc_resolve +except ImportError: + pass + + +class _GlobalLogManager(_cplogging.LogManager): + """A site-wide LogManager; routes to app.log or global log as appropriate. + + This :class:`LogManager` implements + cherrypy.log() and cherrypy.log.access(). If either + function is called during a request, the message will be sent to the + logger for the current Application. If they are called outside of a + request, the message will be sent to the site-wide logger. + """ + + def __call__(self, *args, **kwargs): + """Log the given message to the app.log or global log. + + Log the given message to the app.log or global + log as appropriate. + """ + # Do NOT use try/except here. 
See + # https://github.com/cherrypy/cherrypy/issues/945 + if hasattr(request, 'app') and hasattr(request.app, 'log'): + log = request.app.log + else: + log = self + return log.error(*args, **kwargs) + + def access(self): + """Log an access message to the app.log or global log. + + Log the given message to the app.log or global + log as appropriate. + """ + try: + return request.app.log.access() + except AttributeError: + return _cplogging.LogManager.access(self) + + +log = _GlobalLogManager() +# Set a default screen handler on the global log. +log.screen = True +log.error_file = '' +# Using an access file makes CP about 10% slower. Leave off by default. +log.access_file = '' + + +@engine.subscribe('log') +def _buslog(msg, level): + log.error(msg, 'ENGINE', severity=level) + + +# Use _global_conf_alias so quickstart can use 'config' as an arg +# without shadowing cherrypy.config. +config = _global_conf_alias = _cpconfig.Config() +config.defaults = { + 'tools.log_tracebacks.on': True, + 'tools.log_headers.on': True, + 'tools.trailing_slash.on': True, + 'tools.encode.on': True +} +config.namespaces['log'] = lambda k, v: setattr(log, k, v) +config.namespaces['checker'] = lambda k, v: setattr(checker, k, v) +# Must reset to get our defaults applied. +config.reset() + +checker = _cpchecker.Checker() +engine.subscribe('start', checker) diff --git a/resources/lib/cherrypy/__main__.py b/resources/lib/cherrypy/__main__.py new file mode 100644 index 0000000..6674f7c --- /dev/null +++ b/resources/lib/cherrypy/__main__.py @@ -0,0 +1,5 @@ +"""CherryPy'd cherryd daemon runner.""" +from cherrypy.daemon import run + + +__name__ == '__main__' and run() diff --git a/resources/lib/cherrypy/_cpchecker.py b/resources/lib/cherrypy/_cpchecker.py new file mode 100644 index 0000000..f26f319 --- /dev/null +++ b/resources/lib/cherrypy/_cpchecker.py @@ -0,0 +1,323 @@ +"""Checker for CherryPy sites and mounted apps.""" +import os +import warnings +import builtins + +import cherrypy + + +class Checker(object): + """A checker for CherryPy sites and their mounted applications. + + When this object is called at engine startup, it executes each + of its own methods whose names start with ``check_``. If you wish + to disable selected checks, simply add a line in your global + config which sets the appropriate method to False:: + + [global] + checker.check_skipped_app_config = False + + You may also dynamically add or replace ``check_*`` methods in this way. + """ + + on = True + """If True (the default), run all checks; if False, turn off all checks.""" + + def __init__(self): + """Initialize Checker instance.""" + self._populate_known_types() + + def __call__(self): + """Run all check_* methods.""" + if self.on: + oldformatwarning = warnings.formatwarning + warnings.formatwarning = self.formatwarning + try: + for name in dir(self): + if name.startswith('check_'): + method = getattr(self, name) + if method and hasattr(method, '__call__'): + method() + finally: + warnings.formatwarning = oldformatwarning + + def formatwarning(self, message, category, filename, lineno, line=None): + """Format a warning.""" + return 'CherryPy Checker:\n%s\n\n' % message + + # This value should be set inside _cpconfig. 
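+    # check_skipped_app_config() below consults it to emit a more
+    # helpful warning when app-specific sections were passed globally.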
+ global_config_contained_paths = False + + def check_app_config_entries_dont_start_with_script_name(self): + """Check for App config with sections that repeat script_name.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + if not app.config: + continue + if sn == '': + continue + sn_atoms = sn.strip('/').split('/') + for key in app.config.keys(): + key_atoms = key.strip('/').split('/') + if key_atoms[:len(sn_atoms)] == sn_atoms: + warnings.warn( + 'The application mounted at %r has config ' + 'entries that start with its script name: %r' % (sn, + key)) + + def check_site_config_entries_in_app_config(self): + """Check for mounted Applications that have site-scoped config.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + + msg = [] + for section, entries in app.config.items(): + if section.startswith('/'): + for key, value in entries.items(): + for n in ('engine.', 'server.', 'tree.', 'checker.'): + if key.startswith(n): + msg.append('[%s] %s = %s' % + (section, key, value)) + if msg: + msg.insert(0, + 'The application mounted at %r contains the ' + 'following config entries, which are only allowed ' + 'in site-wide config. Move them to a [global] ' + 'section and pass them to cherrypy.config.update() ' + 'instead of tree.mount().' % sn) + warnings.warn(os.linesep.join(msg)) + + def check_skipped_app_config(self): + """Check for mounted Applications that have no config.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + if not app.config: + msg = 'The Application mounted at %r has an empty config.' % sn + if self.global_config_contained_paths: + msg += (' It looks like the config you passed to ' + 'cherrypy.config.update() contains application-' + 'specific sections. You must explicitly pass ' + 'application config via ' + 'cherrypy.tree.mount(..., config=app_config)') + warnings.warn(msg) + return + + def check_app_config_brackets(self): + """Check for App config with extraneous brackets in section names.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + if not app.config: + continue + for key in app.config.keys(): + if key.startswith('[') or key.endswith(']'): + warnings.warn( + 'The application mounted at %r has config ' + 'section names with extraneous brackets: %r. ' + 'Config *files* need brackets; config *dicts* ' + '(e.g. passed to tree.mount) do not.' % (sn, key)) + + def check_static_paths(self): + """Check Application config for incorrect static paths.""" + # Use the dummy Request object in the main thread. + request = cherrypy.request + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + request.app = app + for section in app.config: + # get_resource will populate request.config + request.get_resource(section + '/dummy.html') + conf = request.config.get + + if conf('tools.staticdir.on', False): + msg = '' + root = conf('tools.staticdir.root') + dir = conf('tools.staticdir.dir') + if dir is None: + msg = 'tools.staticdir.dir is not set.' + else: + fulldir = '' + if os.path.isabs(dir): + fulldir = dir + if root: + msg = ('dir is an absolute path, even ' + 'though a root is provided.') + testdir = os.path.join(root, dir[1:]) + if os.path.exists(testdir): + msg += ( + '\nIf you meant to serve the ' + 'filesystem folder at %r, remove the ' + 'leading slash from dir.' 
% (testdir,)) + else: + if not root: + msg = ( + 'dir is a relative path and ' + 'no root provided.') + else: + fulldir = os.path.join(root, dir) + if not os.path.isabs(fulldir): + msg = ('%r is not an absolute path.' % ( + fulldir,)) + + if fulldir and not os.path.exists(fulldir): + if msg: + msg += '\n' + msg += ('%r (root + dir) is not an existing ' + 'filesystem path.' % fulldir) + + if msg: + warnings.warn('%s\nsection: [%s]\nroot: %r\ndir: %r' + % (msg, section, root, dir)) + + # -------------------------- Compatibility -------------------------- # + obsolete = { + 'server.default_content_type': 'tools.response_headers.headers', + 'log_access_file': 'log.access_file', + 'log_config_options': None, + 'log_file': 'log.error_file', + 'log_file_not_found': None, + 'log_request_headers': 'tools.log_headers.on', + 'log_to_screen': 'log.screen', + 'show_tracebacks': 'request.show_tracebacks', + 'throw_errors': 'request.throw_errors', + 'profiler.on': ('cherrypy.tree.mount(profiler.make_app(' + 'cherrypy.Application(Root())))'), + } + + deprecated = {} + + def _compat(self, config): + """Process config and warn on each obsolete or deprecated entry.""" + for section, conf in config.items(): + if isinstance(conf, dict): + for k in conf: + if k in self.obsolete: + warnings.warn('%r is obsolete. Use %r instead.\n' + 'section: [%s]' % + (k, self.obsolete[k], section)) + elif k in self.deprecated: + warnings.warn('%r is deprecated. Use %r instead.\n' + 'section: [%s]' % + (k, self.deprecated[k], section)) + else: + if section in self.obsolete: + warnings.warn('%r is obsolete. Use %r instead.' + % (section, self.obsolete[section])) + elif section in self.deprecated: + warnings.warn('%r is deprecated. Use %r instead.' + % (section, self.deprecated[section])) + + def check_compatibility(self): + """Process config and warn on each obsolete or deprecated entry.""" + self._compat(cherrypy.config) + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + self._compat(app.config) + + # ------------------------ Known Namespaces ------------------------ # + extra_config_namespaces = [] + + def _known_ns(self, app): + ns = ['wsgi'] + ns.extend(app.toolboxes) + ns.extend(app.namespaces) + ns.extend(app.request_class.namespaces) + ns.extend(cherrypy.config.namespaces) + ns += self.extra_config_namespaces + + for section, conf in app.config.items(): + is_path_section = section.startswith('/') + if is_path_section and isinstance(conf, dict): + for k in conf: + atoms = k.split('.') + if len(atoms) > 1: + if atoms[0] not in ns: + # Spit out a special warning if a known + # namespace is preceded by "cherrypy." 
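                            # (Illustrative, not upstream text: an entry such
                            # as 'cherrypy.tools.gzip.on = True' under a path
                            # section takes this branch and the warning
                            # suggests 'tools.gzip.on' instead.)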
+ if atoms[0] == 'cherrypy' and atoms[1] in ns: + msg = ( + 'The config entry %r is invalid; ' + 'try %r instead.\nsection: [%s]' + % (k, '.'.join(atoms[1:]), section)) + else: + msg = ( + 'The config entry %r is invalid, ' + 'because the %r config namespace ' + 'is unknown.\n' + 'section: [%s]' % (k, atoms[0], section)) + warnings.warn(msg) + elif atoms[0] == 'tools': + if atoms[1] not in dir(cherrypy.tools): + msg = ( + 'The config entry %r may be invalid, ' + 'because the %r tool was not found.\n' + 'section: [%s]' % (k, atoms[1], section)) + warnings.warn(msg) + + def check_config_namespaces(self): + """Process config and warn on each unknown config namespace.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + self._known_ns(app) + + # -------------------------- Config Types -------------------------- # + known_config_types = {} + + def _populate_known_types(self): + b = [x for x in vars(builtins).values() + if type(x) is type(str)] + + def traverse(obj, namespace): + for name in dir(obj): + # Hack for 3.2's warning about body_params + if name == 'body_params': + continue + vtype = type(getattr(obj, name, None)) + if vtype in b: + self.known_config_types[namespace + '.' + name] = vtype + + traverse(cherrypy.request, 'request') + traverse(cherrypy.response, 'response') + traverse(cherrypy.server, 'server') + traverse(cherrypy.engine, 'engine') + traverse(cherrypy.log, 'log') + + def _known_types(self, config): + msg = ('The config entry %r in section %r is of type %r, ' + 'which does not match the expected type %r.') + + for section, conf in config.items(): + if not isinstance(conf, dict): + conf = {section: conf} + for k, v in conf.items(): + if v is not None: + expected_type = self.known_config_types.get(k, None) + vtype = type(v) + if expected_type and vtype != expected_type: + warnings.warn(msg % (k, section, vtype.__name__, + expected_type.__name__)) + + def check_config_types(self): + """Assert that config values are of the same type as default values.""" + self._known_types(cherrypy.config) + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + self._known_types(app.config) + + # -------------------- Specific config warnings -------------------- # + def check_localhost(self): + """Warn if any socket_host is 'localhost'. See #711.""" + for k, v in cherrypy.config.items(): + if k == 'server.socket_host' and v == 'localhost': + warnings.warn("The use of 'localhost' as a socket host can " + 'cause problems on newer systems, since ' + "'localhost' can map to either an IPv4 or an " + "IPv6 address. You should use '127.0.0.1' " + "or '[::1]' instead.") diff --git a/resources/lib/cherrypy/_cpcompat.py b/resources/lib/cherrypy/_cpcompat.py new file mode 100644 index 0000000..a43f6d3 --- /dev/null +++ b/resources/lib/cherrypy/_cpcompat.py @@ -0,0 +1,59 @@ +"""Compatibility code for using CherryPy with various versions of Python. + +To retain compatibility with older Python versions, this module provides a +useful abstraction over the differences between Python versions, sometimes by +preferring a newer idiom, sometimes an older one, and sometimes a custom one. + +In particular, Python 2 uses str and '' for byte strings, while Python 3 +uses str and '' for unicode strings. We will call each of these the 'native +string' type for each version. 
Because of this major difference, this module +provides +two functions: 'ntob', which translates native strings (of type 'str') into +byte strings regardless of Python version, and 'ntou', which translates native +strings to unicode strings. + +Try not to use the compatibility functions 'ntob', 'ntou', 'tonative'. +They were created with Python 2.3-2.5 compatibility in mind. +Instead, use unicode literals (from __future__) and bytes literals +and their .encode/.decode methods as needed. +""" + +import http.client + + +def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given + encoding. + """ + assert_native(n) + # In Python 3, the native string type is unicode + return n.encode(encoding) + + +def ntou(n, encoding='ISO-8859-1'): + """Return the given native string as a unicode string with the given + encoding. + """ + assert_native(n) + # In Python 3, the native string type is unicode + return n + + +def tonative(n, encoding='ISO-8859-1'): + """Return the given string as a native string in the given encoding.""" + # In Python 3, the native string type is unicode + if isinstance(n, bytes): + return n.decode(encoding) + return n + + +def assert_native(n): + if not isinstance(n, str): + raise TypeError('n must be a native str (got %s)' % type(n).__name__) + + +# Some platforms don't expose HTTPSConnection, so handle it separately +HTTPSConnection = getattr(http.client, 'HTTPSConnection', None) + + +text_or_bytes = str, bytes diff --git a/resources/lib/cherrypy/_cpconfig.py b/resources/lib/cherrypy/_cpconfig.py new file mode 100644 index 0000000..8e3fd61 --- /dev/null +++ b/resources/lib/cherrypy/_cpconfig.py @@ -0,0 +1,296 @@ +""" +Configuration system for CherryPy. + +Configuration in CherryPy is implemented via dictionaries. Keys are strings +which name the mapped value, which may be of any type. + + +Architecture +------------ + +CherryPy Requests are part of an Application, which runs in a global context, +and configuration data may apply to any of those three scopes: + +Global + Configuration entries which apply everywhere are stored in + cherrypy.config. + +Application + Entries which apply to each mounted application are stored + on the Application object itself, as 'app.config'. This is a two-level + dict where each key is a path, or "relative URL" (for example, "/" or + "/path/to/my/page"), and each value is a config dict. Usually, this + data is provided in the call to tree.mount(root(), config=conf), + although you may also use app.merge(conf). + +Request + Each Request object possesses a single 'Request.config' dict. + Early in the request process, this dict is populated by merging global + config entries, Application entries (whose path equals or is a parent + of Request.path_info), and any config acquired while looking up the + page handler (see next). + + +Declaration +----------- + +Configuration data may be supplied as a Python dictionary, as a filename, +or as an open file object. When you supply a filename or file, CherryPy +uses Python's builtin ConfigParser; you declare Application config by +writing each path as a section header:: + + [/path/to/my/page] + request.stream = True + +To declare global configuration entries, place them in a [global] section. + +You may also declare config entries directly on the classes and methods +(page handlers) that make up your CherryPy application via the ``_cp_config`` +attribute, set with the ``cherrypy.config`` decorator. 
For example:: + + @cherrypy.config(**{'tools.gzip.on': True}) + class Demo: + + @cherrypy.expose + @cherrypy.config(**{'request.show_tracebacks': False}) + def index(self): + return "Hello world" + +.. note:: + + This behavior is only guaranteed for the default dispatcher. + Other dispatchers may have different restrictions on where + you can attach config attributes. + + +Namespaces +---------- + +Configuration keys are separated into namespaces by the first "." in the key. +Current namespaces: + +engine + Controls the 'application engine', including autoreload. + These can only be declared in the global config. + +tree + Grafts cherrypy.Application objects onto cherrypy.tree. + These can only be declared in the global config. + +hooks + Declares additional request-processing functions. + +log + Configures the logging for each application. + These can only be declared in the global or / config. + +request + Adds attributes to each Request. + +response + Adds attributes to each Response. + +server + Controls the default HTTP server via cherrypy.server. + These can only be declared in the global config. + +tools + Runs and configures additional request-processing packages. + +wsgi + Adds WSGI middleware to an Application's "pipeline". + These can only be declared in the app's root config ("/"). + +checker + Controls the 'checker', which looks for common errors in + app state (including config) when the engine starts. + Global config only. + +The only key that does not exist in a namespace is the "environment" entry. +This special entry 'imports' other config entries from a template stored in +cherrypy._cpconfig.environments[environment]. It only applies to the global +config, and only when you use cherrypy.config.update. + +You can define your own namespaces to be called at the Global, Application, +or Request level, by adding a named handler to cherrypy.config.namespaces, +app.namespaces, or app.request_class.namespaces. The name can +be any string, and the handler must be either a callable or a (Python 2.5 +style) context manager. +""" + +import cherrypy +from cherrypy._cpcompat import text_or_bytes +from cherrypy.lib import reprconf + + +def _if_filename_register_autoreload(ob): + """Register for autoreload if ob is a string (presumed filename).""" + is_filename = isinstance(ob, text_or_bytes) + is_filename and cherrypy.engine.autoreload.files.add(ob) + + +def merge(base, other): + """Merge one app config (from a dict, file, or filename) into another. + + If the given config is a filename, it will be appended to + the list of files to monitor for "autoreload" changes. + """ + _if_filename_register_autoreload(other) + + # Load other into base + for section, value_map in reprconf.Parser.load(other).items(): + if not isinstance(value_map, dict): + raise ValueError( + 'Application config must include section headers, but the ' + "config you tried to merge doesn't have any sections. 
" + 'Wrap your config in another dict with paths as section ' + "headers, for example: {'/': config}.") + base.setdefault(section, {}).update(value_map) + + +class Config(reprconf.Config): + """The 'global' configuration data for the entire CherryPy process.""" + + def update(self, config): + """Update self from a dict, file or filename.""" + _if_filename_register_autoreload(config) + super(Config, self).update(config) + + def _apply(self, config): + """Update self from a dict.""" + if isinstance(config.get('global'), dict): + if len(config) > 1: + cherrypy.checker.global_config_contained_paths = True + config = config['global'] + if 'tools.staticdir.dir' in config: + config['tools.staticdir.section'] = 'global' + super(Config, self)._apply(config) + + @staticmethod + def __call__(**kwargs): + """Decorate for page handlers to set _cp_config.""" + def tool_decorator(f): + _Vars(f).setdefault('_cp_config', {}).update(kwargs) + return f + return tool_decorator + + +class _Vars(object): + """Adapter allowing setting a default attribute on a function or class.""" + + def __init__(self, target): + self.target = target + + def setdefault(self, key, default): + if not hasattr(self.target, key): + setattr(self.target, key, default) + return getattr(self.target, key) + + +# Sphinx begin config.environments +Config.environments = environments = { + 'staging': { + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': False, + 'request.show_mismatched_params': False, + }, + 'production': { + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': False, + 'request.show_mismatched_params': False, + 'log.screen': False, + }, + 'embedded': { + # For use with CherryPy embedded in another deployment stack. + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': False, + 'request.show_mismatched_params': False, + 'log.screen': False, + 'engine.SIGHUP': None, + 'engine.SIGTERM': None, + }, + 'test_suite': { + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': True, + 'request.show_mismatched_params': True, + 'log.screen': False, + }, +} +# Sphinx end config.environments + + +def _server_namespace_handler(k, v): + """Config handler for the "server" namespace.""" + atoms = k.split('.', 1) + if len(atoms) > 1: + # Special-case config keys of the form 'server.servername.socket_port' + # to configure additional HTTP servers. + if not hasattr(cherrypy, 'servers'): + cherrypy.servers = {} + + servername, k = atoms + if servername not in cherrypy.servers: + from cherrypy import _cpserver + cherrypy.servers[servername] = _cpserver.Server() + # On by default, but 'on = False' can unsubscribe it (see below). + cherrypy.servers[servername].subscribe() + + if k == 'on': + if v: + cherrypy.servers[servername].subscribe() + else: + cherrypy.servers[servername].unsubscribe() + else: + setattr(cherrypy.servers[servername], k, v) + else: + setattr(cherrypy.server, k, v) + + +Config.namespaces['server'] = _server_namespace_handler + + +def _engine_namespace_handler(k, v): + """Config handler for the "engine" namespace.""" + engine = cherrypy.engine + + if k in {'SIGHUP', 'SIGTERM'}: + engine.subscribe(k, v) + return + + if '.' 
in k: + plugin, attrname = k.split('.', 1) + plugin = getattr(engine, plugin) + op = 'subscribe' if v else 'unsubscribe' + sub_unsub = getattr(plugin, op, None) + if attrname == 'on' and callable(sub_unsub): + sub_unsub() + return + setattr(plugin, attrname, v) + else: + setattr(engine, k, v) + + +Config.namespaces['engine'] = _engine_namespace_handler + + +def _tree_namespace_handler(k, v): + """Namespace handler for the 'tree' config namespace.""" + if isinstance(v, dict): + for script_name, app in v.items(): + cherrypy.tree.graft(app, script_name) + msg = 'Mounted: %s on %s' % (app, script_name or '/') + cherrypy.engine.log(msg) + else: + cherrypy.tree.graft(v, v.script_name) + cherrypy.engine.log('Mounted: %s on %s' % (v, v.script_name or '/')) + + +Config.namespaces['tree'] = _tree_namespace_handler diff --git a/resources/lib/cherrypy/_cpdispatch.py b/resources/lib/cherrypy/_cpdispatch.py new file mode 100644 index 0000000..83eb79c --- /dev/null +++ b/resources/lib/cherrypy/_cpdispatch.py @@ -0,0 +1,686 @@ +"""CherryPy dispatchers. + +A 'dispatcher' is the object which looks up the 'page handler' callable +and collects config for the current request based on the path_info, other +request attributes, and the application architecture. The core calls the +dispatcher as early as possible, passing it a 'path_info' argument. + +The default dispatcher discovers the page handler by matching path_info +to a hierarchical arrangement of objects, starting at request.app.root. +""" + +import string +import sys +import types +try: + classtype = (type, types.ClassType) +except AttributeError: + classtype = type + +import cherrypy + + +class PageHandler(object): + + """Callable which sets response.body.""" + + def __init__(self, callable, *args, **kwargs): + self.callable = callable + self.args = args + self.kwargs = kwargs + + @property + def args(self): + """The ordered args should be accessible from post dispatch hooks.""" + return cherrypy.serving.request.args + + @args.setter + def args(self, args): + cherrypy.serving.request.args = args + return cherrypy.serving.request.args + + @property + def kwargs(self): + """The named kwargs should be accessible from post dispatch hooks.""" + return cherrypy.serving.request.kwargs + + @kwargs.setter + def kwargs(self, kwargs): + cherrypy.serving.request.kwargs = kwargs + return cherrypy.serving.request.kwargs + + def __call__(self): + try: + return self.callable(*self.args, **self.kwargs) + except TypeError: + x = sys.exc_info()[1] + try: + test_callable_spec(self.callable, self.args, self.kwargs) + except cherrypy.HTTPError: + raise sys.exc_info()[1] + except Exception: + raise x + raise + + +def test_callable_spec(callable, callable_args, callable_kwargs): + """ + Inspect callable and test to see if the given args are suitable for it. + + When an error occurs during the handler's invoking stage there are 2 + erroneous cases: + 1. Too many parameters passed to a function which doesn't define + one of *args or **kwargs. + 2. Too little parameters are passed to the function. + + There are 3 sources of parameters to a cherrypy handler. + 1. query string parameters are passed as keyword parameters to the + handler. + 2. body parameters are also passed as keyword parameters. + 3. when partial matching occurs, the final path atoms are passed as + positional args. + Both the query string and path atoms are part of the URI. If they are + incorrect, then a 404 Not Found should be raised. 
Conversely the body + parameters are part of the request; if they are invalid a 400 Bad Request. + """ + show_mismatched_params = getattr( + cherrypy.serving.request, 'show_mismatched_params', False) + try: + (args, varargs, varkw, defaults) = getargspec(callable) + except TypeError: + if isinstance(callable, object) and hasattr(callable, '__call__'): + (args, varargs, varkw, + defaults) = getargspec(callable.__call__) + else: + # If it wasn't one of our own types, re-raise + # the original error + raise + + if args and ( + # For callable objects, which have a __call__(self) method + hasattr(callable, '__call__') or + # For normal methods + inspect.ismethod(callable) + ): + # Strip 'self' + args = args[1:] + + arg_usage = dict([(arg, 0,) for arg in args]) + vararg_usage = 0 + varkw_usage = 0 + extra_kwargs = set() + + for i, value in enumerate(callable_args): + try: + arg_usage[args[i]] += 1 + except IndexError: + vararg_usage += 1 + + for key in callable_kwargs.keys(): + try: + arg_usage[key] += 1 + except KeyError: + varkw_usage += 1 + extra_kwargs.add(key) + + # figure out which args have defaults. + args_with_defaults = args[-len(defaults or []):] + for i, val in enumerate(defaults or []): + # Defaults take effect only when the arg hasn't been used yet. + if arg_usage[args_with_defaults[i]] == 0: + arg_usage[args_with_defaults[i]] += 1 + + missing_args = [] + multiple_args = [] + for key, usage in arg_usage.items(): + if usage == 0: + missing_args.append(key) + elif usage > 1: + multiple_args.append(key) + + if missing_args: + # In the case where the method allows body arguments + # there are 3 potential errors: + # 1. not enough query string parameters -> 404 + # 2. not enough body parameters -> 400 + # 3. not enough path parts (partial matches) -> 404 + # + # We can't actually tell which case it is, + # so I'm raising a 404 because that covers 2/3 of the + # possibilities + # + # In the case where the method does not allow body + # arguments it's definitely a 404. 
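            # (Illustrative example, not upstream text: for a handler
            #   def page(self, id): ...
            # requested as GET /page with no 'id' in the query string, body,
            # or path, arg_usage['id'] stays 0, so this raises 404 with
            # "Missing parameters: id" when show_mismatched_params is True.)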
+ message = None + if show_mismatched_params: + message = 'Missing parameters: %s' % ','.join(missing_args) + raise cherrypy.HTTPError(404, message=message) + + # the extra positional arguments come from the path - 404 Not Found + if not varargs and vararg_usage > 0: + raise cherrypy.HTTPError(404) + + body_params = cherrypy.serving.request.body.params or {} + body_params = set(body_params.keys()) + qs_params = set(callable_kwargs.keys()) - body_params + + if multiple_args: + if qs_params.intersection(set(multiple_args)): + # If any of the multiple parameters came from the query string then + # it's a 404 Not Found + error = 404 + else: + # Otherwise it's a 400 Bad Request + error = 400 + + message = None + if show_mismatched_params: + message = 'Multiple values for parameters: '\ + '%s' % ','.join(multiple_args) + raise cherrypy.HTTPError(error, message=message) + + if not varkw and varkw_usage > 0: + + # If there were extra query string parameters, it's a 404 Not Found + extra_qs_params = set(qs_params).intersection(extra_kwargs) + if extra_qs_params: + message = None + if show_mismatched_params: + message = 'Unexpected query string '\ + 'parameters: %s' % ', '.join(extra_qs_params) + raise cherrypy.HTTPError(404, message=message) + + # If there were any extra body parameters, it's a 400 Not Found + extra_body_params = set(body_params).intersection(extra_kwargs) + if extra_body_params: + message = None + if show_mismatched_params: + message = 'Unexpected body parameters: '\ + '%s' % ', '.join(extra_body_params) + raise cherrypy.HTTPError(400, message=message) + + +try: + import inspect +except ImportError: + def test_callable_spec(callable, args, kwargs): # noqa: F811 + return None +else: + getargspec = inspect.getargspec + # Python 3 requires using getfullargspec if + # keyword-only arguments are present + if hasattr(inspect, 'getfullargspec'): + def getargspec(callable): + return inspect.getfullargspec(callable)[:4] + + +class LateParamPageHandler(PageHandler): + + """When passing cherrypy.request.params to the page handler, we do not + want to capture that dict too early; we want to give tools like the + decoding tool a chance to modify the params dict in-between the lookup + of the handler and the actual calling of the handler. This subclass + takes that into account, and allows request.params to be 'bound late' + (it's more complicated than that, but that's the effect). + """ + + @property + def kwargs(self): + """Page handler kwargs (with cherrypy.request.params copied in).""" + kwargs = cherrypy.serving.request.params.copy() + if self._kwargs: + kwargs.update(self._kwargs) + return kwargs + + @kwargs.setter + def kwargs(self, kwargs): + cherrypy.serving.request.kwargs = kwargs + self._kwargs = kwargs + + +if sys.version_info < (3, 0): + punctuation_to_underscores = string.maketrans( + string.punctuation, '_' * len(string.punctuation)) + + def validate_translator(t): + if not isinstance(t, str) or len(t) != 256: + raise ValueError( + 'The translate argument must be a str of len 256.') +else: + punctuation_to_underscores = str.maketrans( + string.punctuation, '_' * len(string.punctuation)) + + def validate_translator(t): + if not isinstance(t, dict): + raise ValueError('The translate argument must be a dict.') + + +class Dispatcher(object): + + """CherryPy Dispatcher which walks a tree of objects to find a handler. 
+ + The tree is rooted at cherrypy.request.app.root, and each hierarchical + component in the path_info argument is matched to a corresponding nested + attribute of the root object. Matching handlers must have an 'exposed' + attribute which evaluates to True. The special method name "index" + matches a URI which ends in a slash ("/"). The special method name + "default" may match a portion of the path_info (but only when no longer + substring of the path_info matches some other object). + + This is the default, built-in dispatcher for CherryPy. + """ + + dispatch_method_name = '_cp_dispatch' + """ + The name of the dispatch method that nodes may optionally implement + to provide their own dynamic dispatch algorithm. + """ + + def __init__(self, dispatch_method_name=None, + translate=punctuation_to_underscores): + validate_translator(translate) + self.translate = translate + if dispatch_method_name: + self.dispatch_method_name = dispatch_method_name + + def __call__(self, path_info): + """Set handler and config for the current request.""" + request = cherrypy.serving.request + func, vpath = self.find_handler(path_info) + + if func: + # Decode any leftover %2F in the virtual_path atoms. + vpath = [x.replace('%2F', '/') for x in vpath] + request.handler = LateParamPageHandler(func, *vpath) + else: + request.handler = cherrypy.NotFound() + + def find_handler(self, path): + """Return the appropriate page handler, plus any virtual path. + + This will return two objects. The first will be a callable, + which can be used to generate page output. Any parameters from + the query string or request body will be sent to that callable + as keyword arguments. + + The callable is found by traversing the application's tree, + starting from cherrypy.request.app.root, and matching path + components to successive objects in the tree. For example, the + URL "/path/to/handler" might return root.path.to.handler. + + The second object returned will be a list of names which are + 'virtual path' components: parts of the URL which are dynamic, + and were not used when looking up the handler. + These virtual path components are passed to the handler as + positional arguments. + """ + request = cherrypy.serving.request + app = request.app + root = app.root + dispatch_name = self.dispatch_method_name + + # Get config for the root object/path. + fullpath = [x for x in path.strip('/').split('/') if x] + ['index'] + fullpath_len = len(fullpath) + segleft = fullpath_len + nodeconf = {} + if hasattr(root, '_cp_config'): + nodeconf.update(root._cp_config) + if '/' in app.config: + nodeconf.update(app.config['/']) + object_trail = [['root', root, nodeconf, segleft]] + + node = root + iternames = fullpath[:] + while iternames: + name = iternames[0] + # map to legal Python identifiers (e.g. replace '.' with '_') + objname = name.translate(self.translate) + + nodeconf = {} + subnode = getattr(node, objname, None) + pre_len = len(iternames) + if subnode is None: + dispatch = getattr(node, dispatch_name, None) + if dispatch and hasattr(dispatch, '__call__') and not \ + getattr(dispatch, 'exposed', False) and \ + pre_len > 1: + # Don't expose the hidden 'index' token to _cp_dispatch + # We skip this if pre_len == 1 since it makes no sense + # to call a dispatcher when we have no tokens left. + index_name = iternames.pop() + subnode = dispatch(vpath=iternames) + iternames.append(index_name) + else: + # We didn't find a path, but keep processing in case there + # is a default() handler. 
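                    # (Illustrative, not upstream text: with a tree like
                    #   class Root:
                    #       @cherrypy.expose
                    #       def default(self, *vpath): ...
                    # a request for /2020/06 finds no 'root.2020' attribute,
                    # so both atoms fall through here and are later passed to
                    # default() as positional args ('2020', '06').)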
+ iternames.pop(0) + else: + # We found the path, remove the vpath entry + iternames.pop(0) + segleft = len(iternames) + if segleft > pre_len: + # No path segment was removed. Raise an error. + raise cherrypy.CherryPyException( + 'A vpath segment was added. Custom dispatchers may only ' + 'remove elements. While trying to process ' + '{0} in {1}'.format(name, fullpath) + ) + elif segleft == pre_len: + # Assume that the handler used the current path segment, but + # did not pop it. This allows things like + # return getattr(self, vpath[0], None) + iternames.pop(0) + segleft -= 1 + node = subnode + + if node is not None: + # Get _cp_config attached to this node. + if hasattr(node, '_cp_config'): + nodeconf.update(node._cp_config) + + # Mix in values from app.config for this path. + existing_len = fullpath_len - pre_len + if existing_len != 0: + curpath = '/' + '/'.join(fullpath[0:existing_len]) + else: + curpath = '' + new_segs = fullpath[fullpath_len - pre_len:fullpath_len - segleft] + for seg in new_segs: + curpath += '/' + seg + if curpath in app.config: + nodeconf.update(app.config[curpath]) + + object_trail.append([name, node, nodeconf, segleft]) + + def set_conf(): + """Collapse all object_trail config into cherrypy.request.config. + """ + base = cherrypy.config.copy() + # Note that we merge the config from each node + # even if that node was None. + for name, obj, conf, segleft in object_trail: + base.update(conf) + if 'tools.staticdir.dir' in conf: + base['tools.staticdir.section'] = '/' + \ + '/'.join(fullpath[0:fullpath_len - segleft]) + return base + + # Try successive objects (reverse order) + num_candidates = len(object_trail) - 1 + for i in range(num_candidates, -1, -1): + + name, candidate, nodeconf, segleft = object_trail[i] + if candidate is None: + continue + + # Try a "default" method on the current leaf. + if hasattr(candidate, 'default'): + defhandler = candidate.default + if getattr(defhandler, 'exposed', False): + # Insert any extra _cp_config from the default handler. + conf = getattr(defhandler, '_cp_config', {}) + object_trail.insert( + i + 1, ['default', defhandler, conf, segleft]) + request.config = set_conf() + # See https://github.com/cherrypy/cherrypy/issues/613 + request.is_index = path.endswith('/') + return defhandler, fullpath[fullpath_len - segleft:-1] + + # Uncomment the next line to restrict positional params to + # "default". + # if i < num_candidates - 2: continue + + # Try the current leaf. + if getattr(candidate, 'exposed', False): + request.config = set_conf() + if i == num_candidates: + # We found the extra ".index". Mark request so tools + # can redirect if path_info has no trailing slash. + request.is_index = True + else: + # We're not at an 'index' handler. Mark request so tools + # can redirect if path_info has NO trailing slash. + # Note that this also includes handlers which take + # positional parameters (virtual paths). + request.is_index = False + return candidate, fullpath[fullpath_len - segleft:-1] + + # We didn't find anything + request.config = set_conf() + return None, [] + + +class MethodDispatcher(Dispatcher): + + """Additional dispatch based on cherrypy.request.method.upper(). + + Methods named GET, POST, etc will be called on an exposed class. + The method names must be all caps; the appropriate Allow header + will be output showing all capitalized method names as allowable + HTTP verbs. + + Note that the containing class must be exposed, not the methods. 
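
    For example, a minimal sketch (illustrative, not part of the upstream
    docstring)::

        @cherrypy.expose
        class Resource(object):

            def GET(self):
                return 'read'

            def POST(self, payload):
                return 'created'

    A GET request dispatched here calls ``Resource.GET``, and the response
    carries ``Allow: GET, HEAD, POST`` (HEAD falls back to GET when no HEAD
    method is defined).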
+ """ + + def __call__(self, path_info): + """Set handler and config for the current request.""" + request = cherrypy.serving.request + resource, vpath = self.find_handler(path_info) + + if resource: + # Set Allow header + avail = [m for m in dir(resource) if m.isupper()] + if 'GET' in avail and 'HEAD' not in avail: + avail.append('HEAD') + avail.sort() + cherrypy.serving.response.headers['Allow'] = ', '.join(avail) + + # Find the subhandler + meth = request.method.upper() + func = getattr(resource, meth, None) + if func is None and meth == 'HEAD': + func = getattr(resource, 'GET', None) + if func: + # Grab any _cp_config on the subhandler. + if hasattr(func, '_cp_config'): + request.config.update(func._cp_config) + + # Decode any leftover %2F in the virtual_path atoms. + vpath = [x.replace('%2F', '/') for x in vpath] + request.handler = LateParamPageHandler(func, *vpath) + else: + request.handler = cherrypy.HTTPError(405) + else: + request.handler = cherrypy.NotFound() + + +class RoutesDispatcher(object): + + """A Routes based dispatcher for CherryPy.""" + + def __init__(self, full_result=False, **mapper_options): + """ + Routes dispatcher + + Set full_result to True if you wish the controller + and the action to be passed on to the page handler + parameters. By default they won't be. + """ + import routes + self.full_result = full_result + self.controllers = {} + self.mapper = routes.Mapper(**mapper_options) + self.mapper.controller_scan = self.controllers.keys + + def connect(self, name, route, controller, **kwargs): + self.controllers[name] = controller + self.mapper.connect(name, route, controller=name, **kwargs) + + def redirect(self, url): + raise cherrypy.HTTPRedirect(url) + + def __call__(self, path_info): + """Set handler and config for the current request.""" + func = self.find_handler(path_info) + if func: + cherrypy.serving.request.handler = LateParamPageHandler(func) + else: + cherrypy.serving.request.handler = cherrypy.NotFound() + + def find_handler(self, path_info): + """Find the right page handler, and set request.config.""" + import routes + + request = cherrypy.serving.request + + config = routes.request_config() + config.mapper = self.mapper + if hasattr(request, 'wsgi_environ'): + config.environ = request.wsgi_environ + config.host = request.headers.get('Host', None) + config.protocol = request.scheme + config.redirect = self.redirect + + result = self.mapper.match(path_info) + + config.mapper_dict = result + params = {} + if result: + params = result.copy() + if not self.full_result: + params.pop('controller', None) + params.pop('action', None) + request.params.update(params) + + # Get config for the root object/path. + request.config = base = cherrypy.config.copy() + curpath = '' + + def merge(nodeconf): + if 'tools.staticdir.dir' in nodeconf: + nodeconf['tools.staticdir.section'] = curpath or '/' + base.update(nodeconf) + + app = request.app + root = app.root + if hasattr(root, '_cp_config'): + merge(root._cp_config) + if '/' in app.config: + merge(app.config['/']) + + # Mix in values from app.config. + atoms = [x for x in path_info.split('/') if x] + if atoms: + last = atoms.pop() + else: + last = None + for atom in atoms: + curpath = '/'.join((curpath, atom)) + if curpath in app.config: + merge(app.config[curpath]) + + handler = None + if result: + controller = result.get('controller') + controller = self.controllers.get(controller, controller) + if controller: + if isinstance(controller, classtype): + controller = controller() + # Get config from the controller. 
+ if hasattr(controller, '_cp_config'): + merge(controller._cp_config) + + action = result.get('action') + if action is not None: + handler = getattr(controller, action, None) + # Get config from the handler + if hasattr(handler, '_cp_config'): + merge(handler._cp_config) + else: + handler = controller + + # Do the last path atom here so it can + # override the controller's _cp_config. + if last: + curpath = '/'.join((curpath, last)) + if curpath in app.config: + merge(app.config[curpath]) + + return handler + + +def XMLRPCDispatcher(next_dispatcher=Dispatcher()): + from cherrypy.lib import xmlrpcutil + + def xmlrpc_dispatch(path_info): + path_info = xmlrpcutil.patched_path(path_info) + return next_dispatcher(path_info) + return xmlrpc_dispatch + + +def VirtualHost(next_dispatcher=Dispatcher(), use_x_forwarded_host=True, + **domains): + """ + Select a different handler based on the Host header. + + This can be useful when running multiple sites within one CP server. + It allows several domains to point to different parts of a single + website structure. For example:: + + http://www.domain.example -> root + http://www.domain2.example -> root/domain2/ + http://www.domain2.example:443 -> root/secure + + can be accomplished via the following config:: + + [/] + request.dispatch = cherrypy.dispatch.VirtualHost( + **{'www.domain2.example': '/domain2', + 'www.domain2.example:443': '/secure', + }) + + next_dispatcher + The next dispatcher object in the dispatch chain. + The VirtualHost dispatcher adds a prefix to the URL and calls + another dispatcher. Defaults to cherrypy.dispatch.Dispatcher(). + + use_x_forwarded_host + If True (the default), any "X-Forwarded-Host" + request header will be used instead of the "Host" header. This + is commonly added by HTTP servers (such as Apache) when proxying. + + ``**domains`` + A dict of {host header value: virtual prefix} pairs. + The incoming "Host" request header is looked up in this dict, + and, if a match is found, the corresponding "virtual prefix" + value will be prepended to the URL path before calling the + next dispatcher. Note that you often need separate entries + for "example.com" and "www.example.com". In addition, "Host" + headers may contain the port number. + """ + from cherrypy.lib import httputil + + def vhost_dispatch(path_info): + request = cherrypy.serving.request + header = request.headers.get + + domain = header('Host', '') + if use_x_forwarded_host: + domain = header('X-Forwarded-Host', domain) + + prefix = domains.get(domain, '') + if prefix: + path_info = httputil.urljoin(prefix, path_info) + + result = next_dispatcher(path_info) + + # Touch up staticdir config. See + # https://github.com/cherrypy/cherrypy/issues/614. + section = request.config.get('tools.staticdir.section') + if section: + section = section[len(prefix):] + request.config['tools.staticdir.section'] = section + + return result + return vhost_dispatch diff --git a/resources/lib/cherrypy/_cperror.py b/resources/lib/cherrypy/_cperror.py new file mode 100644 index 0000000..4e72768 --- /dev/null +++ b/resources/lib/cherrypy/_cperror.py @@ -0,0 +1,619 @@ +"""Exception classes for CherryPy. + +CherryPy provides (and uses) exceptions for declaring that the HTTP response +should be a status other than the default "200 OK". You can ``raise`` them like +normal Python exceptions. You can also call them and they will raise +themselves; this means you can set an +:class:`HTTPError` +or :class:`HTTPRedirect` as the +:attr:`request.handler`. + +.. 
_redirectingpost: + +Redirecting POST +================ + +When you GET a resource and are redirected by the server to another Location, +there's generally no problem since GET is both a "safe method" (there should +be no side-effects) and an "idempotent method" (multiple calls are no different +than a single call). + +POST, however, is neither safe nor idempotent--if you +charge a credit card, you don't want to be charged twice by a redirect! + +For this reason, *none* of the 3xx responses permit a user-agent (browser) to +resubmit a POST on redirection without first confirming the action with the +user: + +===== ================================= =========== +300 Multiple Choices Confirm with the user +301 Moved Permanently Confirm with the user +302 Found (Object moved temporarily) Confirm with the user +303 See Other GET the new URI; no confirmation +304 Not modified for conditional GET only; + POST should not raise this error +305 Use Proxy Confirm with the user +307 Temporary Redirect Confirm with the user +308 Permanent Redirect No confirmation +===== ================================= =========== + +However, browsers have historically implemented these restrictions poorly; +in particular, many browsers do not force the user to confirm 301, 302 +or 307 when redirecting POST. For this reason, CherryPy defaults to 303, +which most user-agents appear to have implemented correctly. Therefore, if +you raise HTTPRedirect for a POST request, the user-agent will most likely +attempt to GET the new URI (without asking for confirmation from the user). +We realize this is confusing for developers, but it's the safest thing we +could do. You are of course free to raise ``HTTPRedirect(uri, status=302)`` +or any other 3xx status if you know what you're doing, but given the +environment, we couldn't let any of those be the default. + +Custom Error Handling +===================== + +.. image:: /refman/cperrors.gif + +Anticipated HTTP responses +-------------------------- + +The 'error_page' config namespace can be used to provide custom HTML output for +expected responses (like 404 Not Found). Supply a filename from which the +output will be read. The contents will be interpolated with the values +%(status)s, %(message)s, %(traceback)s, and %(version)s using plain old Python +`string formatting +`_. + +:: + + _cp_config = { + 'error_page.404': os.path.join(localDir, "static/index.html") + } + + +Beginning in version 3.1, you may also provide a function or other callable as +an error_page entry. It will be passed the same status, message, traceback and +version arguments that are interpolated into templates:: + + def error_page_402(status, message, traceback, version): + return "Error %s - Well, I'm very sorry but you haven't paid!" % status + cherrypy.config.update({'error_page.402': error_page_402}) + +Also in 3.1, in addition to the numbered error codes, you may also supply +"error_page.default" to handle all codes which do not have their own error_page +entry. + + + +Unanticipated errors +-------------------- + +CherryPy also has a generic error handling mechanism: whenever an unanticipated +error occurs in your code, it will call +:func:`Request.error_response` to +set the response status, headers, and body. By default, this is the same +output as +:class:`HTTPError(500) `. If you want to provide +some other behavior, you generally replace "request.error_response". 
+ +Here is some sample code that shows how to display a custom error message and +send an e-mail containing the error:: + + from cherrypy import _cperror + + def handle_error(): + cherrypy.response.status = 500 + cherrypy.response.body = [ + "Sorry, an error occurred" + ] + sendMail('error@domain.com', + 'Error in your web app', + _cperror.format_exc()) + + @cherrypy.config(**{'request.error_response': handle_error}) + class Root: + pass + +Note that you have to explicitly set +:attr:`response.body ` +and not simply return an error message as a result. +""" + +import io +import contextlib +import urllib.parse +from sys import exc_info as _exc_info +from traceback import format_exception as _format_exception +from xml.sax import saxutils +import html + +from more_itertools import always_iterable + +import cherrypy +from cherrypy._cpcompat import ntob +from cherrypy._cpcompat import tonative +from cherrypy._helper import classproperty +from cherrypy.lib import httputil as _httputil + + +class CherryPyException(Exception): + + """A base class for CherryPy exceptions.""" + pass + + +class InternalRedirect(CherryPyException): + + """Exception raised to switch to the handler for a different URL. + + This exception will redirect processing to another path within the site + (without informing the client). Provide the new path as an argument when + raising the exception. Provide any params in the querystring for the new + URL. + """ + + def __init__(self, path, query_string=''): + self.request = cherrypy.serving.request + + self.query_string = query_string + if '?' in path: + # Separate any params included in the path + path, self.query_string = path.split('?', 1) + + # Note that urljoin will "do the right thing" whether url is: + # 1. a URL relative to root (e.g. "/dummy") + # 2. a URL relative to the current path + # Note that any query string will be discarded. + path = urllib.parse.urljoin(self.request.path_info, path) + + # Set a 'path' member attribute so that code which traps this + # error can have access to it. + self.path = path + + CherryPyException.__init__(self, path, self.query_string) + + +class HTTPRedirect(CherryPyException): + + """Exception raised when the request should be redirected. + + This exception will force a HTTP redirect to the URL or URL's you give it. + The new URL must be passed as the first argument to the Exception, + e.g., HTTPRedirect(newUrl). Multiple URLs are allowed in a list. + If a URL is absolute, it will be used as-is. If it is relative, it is + assumed to be relative to the current cherrypy.request.path_info. + + If one of the provided URL is a unicode object, it will be encoded + using the default encoding or the one passed in parameter. + + There are multiple types of redirect, from which you can select via the + ``status`` argument. If you do not provide a ``status`` arg, it defaults to + 303 (or 302 if responding with HTTP/1.0). + + Examples:: + + raise cherrypy.HTTPRedirect("") + raise cherrypy.HTTPRedirect("/abs/path", 307) + raise cherrypy.HTTPRedirect(["path1", "path2?a=1&b=2"], 301) + + See :ref:`redirectingpost` for additional caveats. + """ + + urls = None + """The list of URL's to emit.""" + + encoding = 'utf-8' + """The encoding when passed urls are not native strings""" + + def __init__(self, urls, status=None, encoding=None): + self.urls = abs_urls = [ + # Note that urljoin will "do the right thing" whether url is: + # 1. a complete URL with host (e.g. "http://www.example.com/test") + # 2. a URL relative to root (e.g. "/dummy") + # 3. 
a URL relative to the current path + # Note that any query string in cherrypy.request is discarded. + urllib.parse.urljoin( + cherrypy.url(), + tonative(url, encoding or self.encoding), + ) + for url in always_iterable(urls) + ] + + status = ( + int(status) + if status is not None + else self.default_status + ) + if not 300 <= status <= 399: + raise ValueError('status must be between 300 and 399.') + + CherryPyException.__init__(self, abs_urls, status) + + @classproperty + def default_status(cls): + """ + The default redirect status for the request. + + RFC 2616 indicates a 301 response code fits our goal; however, + browser support for 301 is quite messy. Use 302/303 instead. See + http://www.alanflavell.org.uk/www/post-redirect.html + """ + return 303 if cherrypy.serving.request.protocol >= (1, 1) else 302 + + @property + def status(self): + """The integer HTTP status code to emit.""" + _, status = self.args[:2] + return status + + def set_response(self): + """Modify cherrypy.response status, headers, and body to represent + self. + + CherryPy uses this internally, but you can also use it to create an + HTTPRedirect object and set its output without *raising* the exception. + """ + response = cherrypy.serving.response + response.status = status = self.status + + if status in (300, 301, 302, 303, 307, 308): + response.headers['Content-Type'] = 'text/html;charset=utf-8' + # "The ... URI SHOULD be given by the Location field + # in the response." + response.headers['Location'] = self.urls[0] + + # "Unless the request method was HEAD, the entity of the response + # SHOULD contain a short hypertext note with a hyperlink to the + # new URI(s)." + msg = { + 300: 'This resource can be found at ', + 301: 'This resource has permanently moved to ', + 302: 'This resource resides temporarily at ', + 303: 'This resource can be found at ', + 307: 'This resource has moved temporarily to ', + 308: 'This resource has been moved to ', + }[status] + msg += '%s.' + msgs = [ + msg % (saxutils.quoteattr(u), html.escape(u, quote=False)) + for u in self.urls + ] + response.body = ntob('
\n'.join(msgs), 'utf-8') + # Previous code may have set C-L, so we have to reset it + # (allow finalize to set it). + response.headers.pop('Content-Length', None) + elif status == 304: + # Not Modified. + # "The response MUST include the following header fields: + # Date, unless its omission is required by section 14.18.1" + # The "Date" header should have been set in Response.__init__ + + # "...the response SHOULD NOT include other entity-headers." + for key in ('Allow', 'Content-Encoding', 'Content-Language', + 'Content-Length', 'Content-Location', 'Content-MD5', + 'Content-Range', 'Content-Type', 'Expires', + 'Last-Modified'): + if key in response.headers: + del response.headers[key] + + # "The 304 response MUST NOT contain a message-body." + response.body = None + # Previous code may have set C-L, so we have to reset it. + response.headers.pop('Content-Length', None) + elif status == 305: + # Use Proxy. + # self.urls[0] should be the URI of the proxy. + response.headers['Location'] = ntob(self.urls[0], 'utf-8') + response.body = None + # Previous code may have set C-L, so we have to reset it. + response.headers.pop('Content-Length', None) + else: + raise ValueError('The %s status code is unknown.' % status) + + def __call__(self): + """Use this exception as a request.handler (raise self).""" + raise self + + +def clean_headers(status): + """Remove any headers which should not apply to an error response.""" + response = cherrypy.serving.response + + # Remove headers which applied to the original content, + # but do not apply to the error page. + respheaders = response.headers + for key in ['Accept-Ranges', 'Age', 'ETag', 'Location', 'Retry-After', + 'Vary', 'Content-Encoding', 'Content-Length', 'Expires', + 'Content-Location', 'Content-MD5', 'Last-Modified']: + if key in respheaders: + del respheaders[key] + + if status != 416: + # A server sending a response with status code 416 (Requested + # range not satisfiable) SHOULD include a Content-Range field + # with a byte-range-resp-spec of "*". The instance-length + # specifies the current length of the selected resource. + # A response with status code 206 (Partial Content) MUST NOT + # include a Content-Range field with a byte-range- resp-spec of "*". + if 'Content-Range' in respheaders: + del respheaders['Content-Range'] + + +class HTTPError(CherryPyException): + + """Exception used to return an HTTP error code (4xx-5xx) to the client. + + This exception can be used to automatically send a response using a + http status code, with an appropriate error page. It takes an optional + ``status`` argument (which must be between 400 and 599); it defaults to 500 + ("Internal Server Error"). It also takes an optional ``message`` argument, + which will be returned in the response body. See + `RFC2616 `_ + for a complete list of available error codes and when to use them. + + Examples:: + + raise cherrypy.HTTPError(403) + raise cherrypy.HTTPError( + "403 Forbidden", "You are not allowed to access this resource.") + """ + + status = None + """The HTTP status code. May be of type int or str (with a Reason-Phrase). 
+ """ + + code = None + """The integer HTTP status code.""" + + reason = None + """The HTTP Reason-Phrase string.""" + + def __init__(self, status=500, message=None): + self.status = status + try: + self.code, self.reason, defaultmsg = _httputil.valid_status(status) + except ValueError: + raise self.__class__(500, _exc_info()[1].args[0]) + + if self.code < 400 or self.code > 599: + raise ValueError('status must be between 400 and 599.') + + # See http://www.python.org/dev/peps/pep-0352/ + # self.message = message + self._message = message or defaultmsg + CherryPyException.__init__(self, status, message) + + def set_response(self): + """Modify cherrypy.response status, headers, and body to represent + self. + + CherryPy uses this internally, but you can also use it to create an + HTTPError object and set its output without *raising* the exception. + """ + response = cherrypy.serving.response + + clean_headers(self.code) + + # In all cases, finalize will be called after this method, + # so don't bother cleaning up response values here. + response.status = self.status + tb = None + if cherrypy.serving.request.show_tracebacks: + tb = format_exc() + + response.headers.pop('Content-Length', None) + + content = self.get_error_page(self.status, traceback=tb, + message=self._message) + response.body = content + + _be_ie_unfriendly(self.code) + + def get_error_page(self, *args, **kwargs): + return get_error_page(*args, **kwargs) + + def __call__(self): + """Use this exception as a request.handler (raise self).""" + raise self + + @classmethod + @contextlib.contextmanager + def handle(cls, exception, status=500, message=''): + """Translate exception into an HTTPError.""" + try: + yield + except exception as exc: + raise cls(status, message or str(exc)) + + +class NotFound(HTTPError): + + """Exception raised when a URL could not be mapped to any handler (404). + + This is equivalent to raising + :class:`HTTPError("404 Not Found") `. + """ + + def __init__(self, path=None): + if path is None: + request = cherrypy.serving.request + path = request.script_name + request.path_info + self.args = (path,) + HTTPError.__init__(self, 404, "The path '%s' was not found." % path) + + +_HTTPErrorTemplate = ''' + + + + %(status)s + + + +

    <body>
        <h2>%(status)s</h2>
        <p>%(message)s</p>
        <pre id="traceback">%(traceback)s</pre>
        <div id="powered_by">
            <span>
                Powered by
                <a href="http://www.cherrypy.org">CherryPy %(version)s</a>
            </span>
        </div>
    </body>
</html>
+ + +''' + + +def get_error_page(status, **kwargs): + """Return an HTML page, containing a pretty error response. + + status should be an int or a str. + kwargs will be interpolated into the page template. + """ + try: + code, reason, message = _httputil.valid_status(status) + except ValueError: + raise cherrypy.HTTPError(500, _exc_info()[1].args[0]) + + # We can't use setdefault here, because some + # callers send None for kwarg values. + if kwargs.get('status') is None: + kwargs['status'] = '%s %s' % (code, reason) + if kwargs.get('message') is None: + kwargs['message'] = message + if kwargs.get('traceback') is None: + kwargs['traceback'] = '' + if kwargs.get('version') is None: + kwargs['version'] = cherrypy.__version__ + + for k, v in kwargs.items(): + if v is None: + kwargs[k] = '' + else: + kwargs[k] = html.escape(kwargs[k], quote=False) + + # Use a custom template or callable for the error page? + pages = cherrypy.serving.request.error_page + error_page = pages.get(code) or pages.get('default') + + # Default template, can be overridden below. + template = _HTTPErrorTemplate + if error_page: + try: + if hasattr(error_page, '__call__'): + # The caller function may be setting headers manually, + # so we delegate to it completely. We may be returning + # an iterator as well as a string here. + # + # We *must* make sure any content is not unicode. + result = error_page(**kwargs) + if cherrypy.lib.is_iterator(result): + from cherrypy.lib.encoding import UTF8StreamEncoder + return UTF8StreamEncoder(result) + elif isinstance(result, str): + return result.encode('utf-8') + else: + if not isinstance(result, bytes): + raise ValueError( + 'error page function did not ' + 'return a bytestring, str or an ' + 'iterator - returned object of type %s.' + % (type(result).__name__)) + return result + else: + # Load the template from this path. + template = io.open(error_page, newline='').read() + except Exception: + e = _format_exception(*_exc_info())[-1] + m = kwargs['message'] + if m: + m += '
<br />' + m += 'In addition, the custom error page failed:\n<br />
%s' % e + kwargs['message'] = m + + response = cherrypy.serving.response + response.headers['Content-Type'] = 'text/html;charset=utf-8' + result = template % kwargs + return result.encode('utf-8') + + +_ie_friendly_error_sizes = { + 400: 512, 403: 256, 404: 512, 405: 256, + 406: 512, 408: 512, 409: 512, 410: 256, + 500: 512, 501: 512, 505: 512, +} + + +def _be_ie_unfriendly(status): + response = cherrypy.serving.response + + # For some statuses, Internet Explorer 5+ shows "friendly error + # messages" instead of our response.body if the body is smaller + # than a given size. Fix this by returning a body over that size + # (by adding whitespace). + # See http://support.microsoft.com/kb/q218155/ + s = _ie_friendly_error_sizes.get(status, 0) + if s: + s += 1 + # Since we are issuing an HTTP error status, we assume that + # the entity is short, and we should just collapse it. + content = response.collapse_body() + content_length = len(content) + if content_length and content_length < s: + # IN ADDITION: the response must be written to IE + # in one chunk or it will still get replaced! Bah. + content = content + (b' ' * (s - content_length)) + response.body = content + response.headers['Content-Length'] = str(len(content)) + + +def format_exc(exc=None): + """Return exc (or sys.exc_info if None), formatted.""" + try: + if exc is None: + exc = _exc_info() + if exc == (None, None, None): + return '' + import traceback + return ''.join(traceback.format_exception(*exc)) + finally: + del exc + + +def bare_error(extrabody=None): + """Produce status, headers, body for a critical error. + + Returns a triple without calling any other questionable functions, + so it should be as error-free as possible. Call it from an HTTP server + if you get errors outside of the request. + + If extrabody is None, a friendly but rather unhelpful error message + is set in the body. If extrabody is a string, it will be appended + as-is to the body. + """ + + # The whole point of this function is to be a last line-of-defense + # in handling errors. That is, it must not raise any errors itself; + # it cannot be allowed to fail. Therefore, don't add to it! + # In particular, don't call any other CP functions. + + body = b'Unrecoverable error in the server.' + if extrabody is not None: + if not isinstance(extrabody, bytes): + extrabody = extrabody.encode('utf-8') + body += b'\n' + extrabody + + return (b'500 Internal Server Error', + [(b'Content-Type', b'text/plain'), + (b'Content-Length', ntob(str(len(body)), 'ISO-8859-1'))], + [body]) diff --git a/resources/lib/cherrypy/_cplogging.py b/resources/lib/cherrypy/_cplogging.py new file mode 100644 index 0000000..151d3b4 --- /dev/null +++ b/resources/lib/cherrypy/_cplogging.py @@ -0,0 +1,457 @@ +""" +Simple config +============= + +Although CherryPy uses the :mod:`Python logging module `, it does so +behind the scenes so that simple logging is simple, but complicated logging +is still possible. "Simple" logging means that you can log to the screen +(i.e. console/stdout) or to a file, and that you can easily have separate +error and access log files. + +Here are the simplified logging settings. You use these by adding lines to +your config file or dict. You should set these at either the global level or +per application (see next), but generally not both. + + * ``log.screen``: Set this to True to have both "error" and "access" messages + printed to stdout. + * ``log.access_file``: Set this to an absolute filename where you want + "access" messages written. 
+ * ``log.error_file``: Set this to an absolute filename where you want "error"
+   messages written.
+
+Many events are automatically logged; to log your own application events, call
+:func:`cherrypy.log`.
+
+Architecture
+============
+
+Separate scopes
+---------------
+
+CherryPy provides log managers at both the global and application layers.
+This means you can have one set of logging rules for your entire site,
+and another set of rules specific to each application. The global log
+manager is found at :func:`cherrypy.log`, and the log manager for each
+application is found at :attr:`app.log`.
+If you're inside a request, the latter is reachable from
+``cherrypy.request.app.log``; if you're outside a request, you'll have to
+obtain a reference to the ``app``: either the return value of
+:func:`tree.mount()` or, if you used
+:func:`quickstart()` instead, via
+``cherrypy.tree.apps['/']``.
+
+By default, the global logs are named "cherrypy.error" and "cherrypy.access",
+and the application logs are named "cherrypy.error.2378745" and
+"cherrypy.access.2378745" (the number is the id of the Application object).
+This means that the application logs "bubble up" to the site logs, so if your
+application has no log handlers, the site-level handlers will still log the
+messages.
+
+Errors vs. Access
+-----------------
+
+Each log manager handles both "access" messages (one per HTTP request) and
+"error" messages (everything else). Note that the "error" log is not just for
+errors! The format of access messages is highly formalized, but the error log
+isn't--it receives messages from a variety of sources (including full error
+tracebacks, if enabled).
+
+If the access log and the error log are written to the same source, a
+specially crafted error message may forge an access log entry, as described
+in CWE-117. In that case it is the application developer's responsibility to
+escape data manually before passing it to CherryPy's log() functionality;
+otherwise the application may be vulnerable to CWE-117. This can be done with
+a custom handler that escapes any special characters, attached as described
+below.
+
+Custom Handlers
+===============
+
+The simple settings above work by manipulating Python's standard :mod:`logging`
+module. So when you need something more complex, the full power of the standard
+module is yours to exploit. You can borrow or create custom handlers, formats,
+filters, and much more. Here's an example that skips the standard FileHandler
+and uses a RotatingFileHandler instead:
+
+::
+
+    #python
+    log = app.log
+
+    # Remove the default FileHandlers if present.
+    log.error_file = ""
+    log.access_file = ""
+
+    maxBytes = getattr(log, "rot_maxBytes", 10000000)
+    backupCount = getattr(log, "rot_backupCount", 1000)
+
+    # Make a new RotatingFileHandler for the error log.
+    fname = getattr(log, "rot_error_file", "error.log")
+    h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount)
+    h.setLevel(DEBUG)
+    h.setFormatter(_cplogging.logfmt)
+    log.error_log.addHandler(h)
+
+    # Make a new RotatingFileHandler for the access log.
+    fname = getattr(log, "rot_access_file", "access.log")
+    h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount)
+    h.setLevel(DEBUG)
+    h.setFormatter(_cplogging.logfmt)
+    log.access_log.addHandler(h)
+
+
+The ``rot_*`` attributes are pulled straight from the application log object.
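+For example, a config file could supply them like this (a sketch: the
+``rot_*`` names are arbitrary and meaningful only to the handler code above,
+which reads them back via ``getattr``)::
+
+    [/]
+    log.rot_maxBytes = 10000000
+    log.rot_backupCount = 1000
+    log.rot_error_file = "rotating_error.log"
+    log.rot_access_file = "rotating_access.log"
+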
+Since "log.*" config entries simply set attributes on the log object, you can +add custom attributes to your heart's content. Note that these handlers are +used ''instead'' of the default, simple handlers outlined above (so don't set +the "log.error_file" config entry, for example). +""" + +import datetime +import logging +import os +import sys + +import cherrypy +from cherrypy import _cperror + + +# Silence the no-handlers "warning" (stderr write!) in stdlib logging +logging.Logger.manager.emittedNoHandlerWarning = 1 +logfmt = logging.Formatter('%(message)s') + + +class NullHandler(logging.Handler): + + """A no-op logging handler to silence the logging.lastResort handler.""" + + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None + + +class LogManager(object): + + """An object to assist both simple and advanced logging. + + ``cherrypy.log`` is an instance of this class. + """ + + appid = None + """The id() of the Application object which owns this log manager. If this + is a global log manager, appid is None.""" + + error_log = None + """The actual :class:`logging.Logger` instance for error messages.""" + + access_log = None + """The actual :class:`logging.Logger` instance for access messages.""" + + access_log_format = '{h} {l} {u} {t} "{r}" {s} {b} "{f}" "{a}"' + + logger_root = None + """The "top-level" logger name. + + This string will be used as the first segment in the Logger names. + The default is "cherrypy", for example, in which case the Logger names + will be of the form:: + + cherrypy.error. + cherrypy.access. + """ + + def __init__(self, appid=None, logger_root='cherrypy'): + self.logger_root = logger_root + self.appid = appid + if appid is None: + self.error_log = logging.getLogger('%s.error' % logger_root) + self.access_log = logging.getLogger('%s.access' % logger_root) + else: + self.error_log = logging.getLogger( + '%s.error.%s' % (logger_root, appid)) + self.access_log = logging.getLogger( + '%s.access.%s' % (logger_root, appid)) + self.error_log.setLevel(logging.INFO) + self.access_log.setLevel(logging.INFO) + + # Silence the no-handlers "warning" (stderr write!) in stdlib logging + self.error_log.addHandler(NullHandler()) + self.access_log.addHandler(NullHandler()) + + cherrypy.engine.subscribe('graceful', self.reopen_files) + + def reopen_files(self): + """Close and reopen all file handlers.""" + for log in (self.error_log, self.access_log): + for h in log.handlers: + if isinstance(h, logging.FileHandler): + h.acquire() + h.stream.close() + h.stream = open(h.baseFilename, h.mode) + h.release() + + def error(self, msg='', context='', severity=logging.INFO, + traceback=False): + """Write the given ``msg`` to the error log. + + This is not just for errors! Applications may call this at any time + to log application-specific information. + + If ``traceback`` is True, the traceback of the current exception + (if any) will be appended to ``msg``. + """ + exc_info = None + if traceback: + exc_info = _cperror._exc_info() + + self.error_log.log( + severity, + ' '.join((self.time(), context, msg)), + exc_info=exc_info, + ) + + def __call__(self, *args, **kwargs): + """An alias for ``error``.""" + return self.error(*args, **kwargs) + + def access(self): + """Write to the access log (in Apache/NCSA Combined Log format). + + See the + `apache documentation + `_ + for format details. + + CherryPy calls this automatically for you. Note there are no arguments; + it collects the data itself from + :class:`cherrypy.request`. 
+ + Like Apache started doing in 2.0.46, non-printable and other special + characters in %r (and we expand that to all parts) are escaped using + \\xhh sequences, where hh stands for the hexadecimal representation + of the raw byte. Exceptions from this rule are " and \\, which are + escaped by prepending a backslash, and all whitespace characters, + which are written in their C-style notation (\\n, \\t, etc). + """ + request = cherrypy.serving.request + remote = request.remote + response = cherrypy.serving.response + outheaders = response.headers + inheaders = request.headers + if response.output_status is None: + status = '-' + else: + status = response.output_status.split(b' ', 1)[0] + status = status.decode('ISO-8859-1') + + atoms = {'h': remote.name or remote.ip, + 'l': '-', + 'u': getattr(request, 'login', None) or '-', + 't': self.time(), + 'r': request.request_line, + 's': status, + 'b': dict.get(outheaders, 'Content-Length', '') or '-', + 'f': dict.get(inheaders, 'Referer', ''), + 'a': dict.get(inheaders, 'User-Agent', ''), + 'o': dict.get(inheaders, 'Host', '-'), + 'i': request.unique_id, + 'z': LazyRfc3339UtcTime(), + } + for k, v in atoms.items(): + if not isinstance(v, str): + v = str(v) + v = v.replace('"', '\\"').encode('utf8') + # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc + # and backslash for us. All we have to do is strip the quotes. + v = repr(v)[2:-1] + + # in python 3.0 the repr of bytes (as returned by encode) + # uses double \'s. But then the logger escapes them yet, again + # resulting in quadruple slashes. Remove the extra one here. + v = v.replace('\\\\', '\\') + + # Escape double-quote. + atoms[k] = v + + try: + self.access_log.log( + logging.INFO, self.access_log_format.format(**atoms)) + except Exception: + self(traceback=True) + + def time(self): + """Return now() in Apache Common Log Format (no timezone).""" + now = datetime.datetime.now() + monthnames = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', + 'jul', 'aug', 'sep', 'oct', 'nov', 'dec'] + month = monthnames[now.month - 1].capitalize() + return ('[%02d/%s/%04d:%02d:%02d:%02d]' % + (now.day, month, now.year, now.hour, now.minute, now.second)) + + def _get_builtin_handler(self, log, key): + for h in log.handlers: + if getattr(h, '_cpbuiltin', None) == key: + return h + + # ------------------------- Screen handlers ------------------------- # + def _set_screen_handler(self, log, enable, stream=None): + h = self._get_builtin_handler(log, 'screen') + if enable: + if not h: + if stream is None: + stream = sys.stderr + h = logging.StreamHandler(stream) + h.setFormatter(logfmt) + h._cpbuiltin = 'screen' + log.addHandler(h) + elif h: + log.handlers.remove(h) + + @property + def screen(self): + """Turn stderr/stdout logging on or off. + + If you set this to True, it'll add the appropriate StreamHandler for + you. If you set it to False, it will remove the handler. 
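+        A minimal usage sketch (``cherrypy.log`` is a ``LogManager``
+        instance, per the class docstring above)::
+
+            cherrypy.log.screen = True   # attach handlers for stderr/stdout
+            cherrypy.log.screen = False  # remove them again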
+ """ + h = self._get_builtin_handler + has_h = h(self.error_log, 'screen') or h(self.access_log, 'screen') + return bool(has_h) + + @screen.setter + def screen(self, newvalue): + self._set_screen_handler(self.error_log, newvalue, stream=sys.stderr) + self._set_screen_handler(self.access_log, newvalue, stream=sys.stdout) + + # -------------------------- File handlers -------------------------- # + + def _add_builtin_file_handler(self, log, fname): + h = logging.FileHandler(fname) + h.setFormatter(logfmt) + h._cpbuiltin = 'file' + log.addHandler(h) + + def _set_file_handler(self, log, filename): + h = self._get_builtin_handler(log, 'file') + if filename: + if h: + if h.baseFilename != os.path.abspath(filename): + h.close() + log.handlers.remove(h) + self._add_builtin_file_handler(log, filename) + else: + self._add_builtin_file_handler(log, filename) + else: + if h: + h.close() + log.handlers.remove(h) + + @property + def error_file(self): + """The filename for self.error_log. + + If you set this to a string, it'll add the appropriate FileHandler for + you. If you set it to ``None`` or ``''``, it will remove the handler. + """ + h = self._get_builtin_handler(self.error_log, 'file') + if h: + return h.baseFilename + return '' + + @error_file.setter + def error_file(self, newvalue): + self._set_file_handler(self.error_log, newvalue) + + @property + def access_file(self): + """The filename for self.access_log. + + If you set this to a string, it'll add the appropriate FileHandler for + you. If you set it to ``None`` or ``''``, it will remove the handler. + """ + h = self._get_builtin_handler(self.access_log, 'file') + if h: + return h.baseFilename + return '' + + @access_file.setter + def access_file(self, newvalue): + self._set_file_handler(self.access_log, newvalue) + + # ------------------------- WSGI handlers ------------------------- # + + def _set_wsgi_handler(self, log, enable): + h = self._get_builtin_handler(log, 'wsgi') + if enable: + if not h: + h = WSGIErrorHandler() + h.setFormatter(logfmt) + h._cpbuiltin = 'wsgi' + log.addHandler(h) + elif h: + log.handlers.remove(h) + + @property + def wsgi(self): + """Write errors to wsgi.errors. + + If you set this to True, it'll add the appropriate + :class:`WSGIErrorHandler` for you + (which writes errors to ``wsgi.errors``). + If you set it to False, it will remove the handler. + """ + return bool(self._get_builtin_handler(self.error_log, 'wsgi')) + + @wsgi.setter + def wsgi(self, newvalue): + self._set_wsgi_handler(self.error_log, newvalue) + + +class WSGIErrorHandler(logging.Handler): + + "A handler class which writes logging records to environ['wsgi.errors']." + + def flush(self): + """Flushes the stream.""" + try: + stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors') + except (AttributeError, KeyError): + pass + else: + stream.flush() + + def emit(self, record): + """Emit a record.""" + try: + stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors') + except (AttributeError, KeyError): + pass + else: + try: + msg = self.format(record) + fs = '%s\n' + import types + # if no unicode support... 
+ if not hasattr(types, 'UnicodeType'): + stream.write(fs % msg) + else: + try: + stream.write(fs % msg) + except UnicodeError: + stream.write(fs % msg.encode('UTF-8')) + self.flush() + except Exception: + self.handleError(record) + + +class LazyRfc3339UtcTime(object): + def __str__(self): + """Return now() in RFC3339 UTC Format.""" + now = datetime.datetime.now() + return now.isoformat('T') + 'Z' diff --git a/resources/lib/cherrypy/_cpmodpy.py b/resources/lib/cherrypy/_cpmodpy.py new file mode 100644 index 0000000..0e608c4 --- /dev/null +++ b/resources/lib/cherrypy/_cpmodpy.py @@ -0,0 +1,354 @@ +"""Native adapter for serving CherryPy via mod_python + +Basic usage: + +########################################## +# Application in a module called myapp.py +########################################## + +import cherrypy + +class Root: + @cherrypy.expose + def index(self): + return 'Hi there, Ho there, Hey there' + + +# We will use this method from the mod_python configuration +# as the entry point to our application +def setup_server(): + cherrypy.tree.mount(Root()) + cherrypy.config.update({'environment': 'production', + 'log.screen': False, + 'show_tracebacks': False}) + +########################################## +# mod_python settings for apache2 +# This should reside in your httpd.conf +# or a file that will be loaded at +# apache startup +########################################## + +# Start +DocumentRoot "/" +Listen 8080 +LoadModule python_module /usr/lib/apache2/modules/mod_python.so + + + PythonPath "sys.path+['/path/to/my/application']" + SetHandler python-program + PythonHandler cherrypy._cpmodpy::handler + PythonOption cherrypy.setup myapp::setup_server + PythonDebug On + +# End + +The actual path to your mod_python.so is dependent on your +environment. In this case we suppose a global mod_python +installation on a Linux distribution such as Ubuntu. + +We do set the PythonPath configuration setting so that +your application can be found by from the user running +the apache2 instance. Of course if your application +resides in the global site-package this won't be needed. + +Then restart apache2 and access http://127.0.0.1:8080 +""" + +import io +import logging +import os +import re +import sys + +from more_itertools import always_iterable + +import cherrypy +from cherrypy._cperror import format_exc, bare_error +from cherrypy.lib import httputil + + +# ------------------------------ Request-handling + + +def setup(req): + from mod_python import apache + + # Run any setup functions defined by a "PythonOption cherrypy.setup" + # directive. 
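    # For example, the httpd.conf line
    #     PythonOption cherrypy.setup myapp::setup_server
    # arrives here as options == {'cherrypy.setup': 'myapp::setup_server'};
    # each whitespace-separated "module::function" atom is imported and
    # called in the loop below.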
+ options = req.get_options() + if 'cherrypy.setup' in options: + for function in options['cherrypy.setup'].split(): + atoms = function.split('::', 1) + if len(atoms) == 1: + mod = __import__(atoms[0], globals(), locals()) + else: + modname, fname = atoms + mod = __import__(modname, globals(), locals(), [fname]) + func = getattr(mod, fname) + func() + + cherrypy.config.update({'log.screen': False, + 'tools.ignore_headers.on': True, + 'tools.ignore_headers.headers': ['Range'], + }) + + engine = cherrypy.engine + if hasattr(engine, 'signal_handler'): + engine.signal_handler.unsubscribe() + if hasattr(engine, 'console_control_handler'): + engine.console_control_handler.unsubscribe() + engine.autoreload.unsubscribe() + cherrypy.server.unsubscribe() + + @engine.subscribe('log') + def _log(msg, level): + newlevel = apache.APLOG_ERR + if logging.DEBUG >= level: + newlevel = apache.APLOG_DEBUG + elif logging.INFO >= level: + newlevel = apache.APLOG_INFO + elif logging.WARNING >= level: + newlevel = apache.APLOG_WARNING + # On Windows, req.server is required or the msg will vanish. See + # http://www.modpython.org/pipermail/mod_python/2003-October/014291.html + # Also, "When server is not specified...LogLevel does not apply..." + apache.log_error(msg, newlevel, req.server) + + engine.start() + + def cherrypy_cleanup(data): + engine.exit() + try: + # apache.register_cleanup wasn't available until 3.1.4. + apache.register_cleanup(cherrypy_cleanup) + except AttributeError: + req.server.register_cleanup(req, cherrypy_cleanup) + + +class _ReadOnlyRequest: + expose = ('read', 'readline', 'readlines') + + def __init__(self, req): + for method in self.expose: + self.__dict__[method] = getattr(req, method) + + +recursive = False + +_isSetUp = False + + +def handler(req): + from mod_python import apache + try: + global _isSetUp + if not _isSetUp: + setup(req) + _isSetUp = True + + # Obtain a Request object from CherryPy + local = req.connection.local_addr + local = httputil.Host( + local[0], local[1], req.connection.local_host or '') + remote = req.connection.remote_addr + remote = httputil.Host( + remote[0], remote[1], req.connection.remote_host or '') + + scheme = req.parsed_uri[0] or 'http' + req.get_basic_auth_pw() + + try: + # apache.mpm_query only became available in mod_python 3.1 + q = apache.mpm_query + threaded = q(apache.AP_MPMQ_IS_THREADED) + forked = q(apache.AP_MPMQ_IS_FORKED) + except AttributeError: + bad_value = ("You must provide a PythonOption '%s', " + "either 'on' or 'off', when running a version " + 'of mod_python < 3.1') + + options = req.get_options() + + threaded = options.get('multithread', '').lower() + if threaded == 'on': + threaded = True + elif threaded == 'off': + threaded = False + else: + raise ValueError(bad_value % 'multithread') + + forked = options.get('multiprocess', '').lower() + if forked == 'on': + forked = True + elif forked == 'off': + forked = False + else: + raise ValueError(bad_value % 'multiprocess') + + sn = cherrypy.tree.script_name(req.uri or '/') + if sn is None: + send_response(req, '404 Not Found', [], '') + else: + app = cherrypy.tree.apps[sn] + method = req.method + path = req.uri + qs = req.args or '' + reqproto = req.protocol + headers = list(req.headers_in.copy().items()) + rfile = _ReadOnlyRequest(req) + prev = None + + try: + redirections = [] + while True: + request, response = app.get_serving(local, remote, scheme, + 'HTTP/1.1') + request.login = req.user + request.multithread = bool(threaded) + request.multiprocess = bool(forked) + request.app = 
app + request.prev = prev + + # Run the CherryPy Request object and obtain the response + try: + request.run(method, path, qs, reqproto, headers, rfile) + break + except cherrypy.InternalRedirect: + ir = sys.exc_info()[1] + app.release_serving() + prev = request + + if not recursive: + if ir.path in redirections: + raise RuntimeError( + 'InternalRedirector visited the same URL ' + 'twice: %r' % ir.path) + else: + # Add the *previous* path_info + qs to + # redirections. + if qs: + qs = '?' + qs + redirections.append(sn + path + qs) + + # Munge environment and try again. + method = 'GET' + path = ir.path + qs = ir.query_string + rfile = io.BytesIO() + + send_response( + req, response.output_status, response.header_list, + response.body, response.stream) + finally: + app.release_serving() + except Exception: + tb = format_exc() + cherrypy.log(tb, 'MOD_PYTHON', severity=logging.ERROR) + s, h, b = bare_error() + send_response(req, s, h, b) + return apache.OK + + +def send_response(req, status, headers, body, stream=False): + # Set response status + req.status = int(status[:3]) + + # Set response headers + req.content_type = 'text/plain' + for header, value in headers: + if header.lower() == 'content-type': + req.content_type = value + continue + req.headers_out.add(header, value) + + if stream: + # Flush now so the status and headers are sent immediately. + req.flush() + + # Set response body + for seg in always_iterable(body): + req.write(seg) + + +# --------------- Startup tools for CherryPy + mod_python --------------- # +try: + import subprocess + + def popen(fullcmd): + p = subprocess.Popen(fullcmd, shell=True, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + close_fds=True) + return p.stdout +except ImportError: + def popen(fullcmd): + pipein, pipeout = os.popen4(fullcmd) + return pipeout + + +def read_process(cmd, args=''): + fullcmd = '%s %s' % (cmd, args) + pipeout = popen(fullcmd) + try: + firstline = pipeout.readline() + cmd_not_found = re.search( + b'(not recognized|No such file|not found)', + firstline, + re.IGNORECASE + ) + if cmd_not_found: + raise IOError('%s must be on your system path.' % cmd) + output = firstline + pipeout.read() + finally: + pipeout.close() + return output + + +class ModPythonServer(object): + + template = """ +# Apache2 server configuration file for running CherryPy with mod_python. 
+ +DocumentRoot "/" +Listen %(port)s +LoadModule python_module modules/mod_python.so + + + SetHandler python-program + PythonHandler %(handler)s + PythonDebug On +%(opts)s + +""" + + def __init__(self, loc='/', port=80, opts=None, apache_path='apache', + handler='cherrypy._cpmodpy::handler'): + self.loc = loc + self.port = port + self.opts = opts + self.apache_path = apache_path + self.handler = handler + + def start(self): + opts = ''.join([' PythonOption %s %s\n' % (k, v) + for k, v in self.opts]) + conf_data = self.template % {'port': self.port, + 'loc': self.loc, + 'opts': opts, + 'handler': self.handler, + } + + mpconf = os.path.join(os.path.dirname(__file__), 'cpmodpy.conf') + f = open(mpconf, 'wb') + try: + f.write(conf_data) + finally: + f.close() + + response = read_process(self.apache_path, '-k start -f %s' % mpconf) + self.ready = True + return response + + def stop(self): + os.popen('apache -k stop') + self.ready = False diff --git a/resources/lib/cherrypy/_cpnative_server.py b/resources/lib/cherrypy/_cpnative_server.py new file mode 100644 index 0000000..e9671d2 --- /dev/null +++ b/resources/lib/cherrypy/_cpnative_server.py @@ -0,0 +1,168 @@ +"""Native adapter for serving CherryPy via its builtin server.""" + +import logging +import sys +import io + +import cheroot.server + +import cherrypy +from cherrypy._cperror import format_exc, bare_error +from cherrypy.lib import httputil +from ._cpcompat import tonative + + +class NativeGateway(cheroot.server.Gateway): + """Native gateway implementation allowing to bypass WSGI.""" + + recursive = False + + def respond(self): + """Obtain response from CherryPy machinery and then send it.""" + req = self.req + try: + # Obtain a Request object from CherryPy + local = req.server.bind_addr # FIXME: handle UNIX sockets + local = tonative(local[0]), local[1] + local = httputil.Host(local[0], local[1], '') + remote = tonative(req.conn.remote_addr), req.conn.remote_port + remote = httputil.Host(remote[0], remote[1], '') + + scheme = tonative(req.scheme) + sn = cherrypy.tree.script_name(tonative(req.uri or '/')) + if sn is None: + self.send_response('404 Not Found', [], ['']) + else: + app = cherrypy.tree.apps[sn] + method = tonative(req.method) + path = tonative(req.path) + qs = tonative(req.qs or '') + headers = ( + (tonative(h), tonative(v)) + for h, v in req.inheaders.items() + ) + rfile = req.rfile + prev = None + + try: + redirections = [] + while True: + request, response = app.get_serving( + local, remote, scheme, 'HTTP/1.1') + request.multithread = True + request.multiprocess = False + request.app = app + request.prev = prev + + # Run the CherryPy Request object and obtain the + # response + try: + request.run( + method, path, qs, + tonative(req.request_protocol), + headers, rfile, + ) + break + except cherrypy.InternalRedirect: + ir = sys.exc_info()[1] + app.release_serving() + prev = request + + if not self.recursive: + if ir.path in redirections: + raise RuntimeError( + 'InternalRedirector visited the same ' + 'URL twice: %r' % ir.path) + else: + # Add the *previous* path_info + qs to + # redirections. + if qs: + qs = '?' + qs + redirections.append(sn + path + qs) + + # Munge environment and try again. 
+ method = 'GET' + path = ir.path + qs = ir.query_string + rfile = io.BytesIO() + + self.send_response( + response.output_status, response.header_list, + response.body) + finally: + app.release_serving() + except Exception: + tb = format_exc() + # print tb + cherrypy.log(tb, 'NATIVE_ADAPTER', severity=logging.ERROR) + s, h, b = bare_error() + self.send_response(s, h, b) + + def send_response(self, status, headers, body): + """Send response to HTTP request.""" + req = self.req + + # Set response status + req.status = status or b'500 Server Error' + + # Set response headers + for header, value in headers: + req.outheaders.append((header, value)) + if (req.ready and not req.sent_headers): + req.sent_headers = True + req.send_headers() + + # Set response body + for seg in body: + req.write(seg) + + +class CPHTTPServer(cheroot.server.HTTPServer): + """Wrapper for cheroot.server.HTTPServer. + + cheroot has been designed to not reference CherryPy in any way, + so that it can be used in other frameworks and applications. + Therefore, we wrap it here, so we can apply some attributes + from config -> cherrypy.server -> HTTPServer. + """ + + def __init__(self, server_adapter=cherrypy.server): + """Initialize CPHTTPServer.""" + self.server_adapter = server_adapter + + server_name = (self.server_adapter.socket_host or + self.server_adapter.socket_file or + None) + + cheroot.server.HTTPServer.__init__( + self, server_adapter.bind_addr, NativeGateway, + minthreads=server_adapter.thread_pool, + maxthreads=server_adapter.thread_pool_max, + server_name=server_name) + + self.max_request_header_size = ( + self.server_adapter.max_request_header_size or 0) + self.max_request_body_size = ( + self.server_adapter.max_request_body_size or 0) + self.request_queue_size = self.server_adapter.socket_queue_size + self.timeout = self.server_adapter.socket_timeout + self.shutdown_timeout = self.server_adapter.shutdown_timeout + self.protocol = self.server_adapter.protocol_version + self.nodelay = self.server_adapter.nodelay + + ssl_module = self.server_adapter.ssl_module or 'pyopenssl' + if self.server_adapter.ssl_context: + adapter_class = cheroot.server.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain, + self.server_adapter.ssl_ciphers) + self.ssl_adapter.context = self.server_adapter.ssl_context + elif self.server_adapter.ssl_certificate: + adapter_class = cheroot.server.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain, + self.server_adapter.ssl_ciphers) diff --git a/resources/lib/cherrypy/_cpreqbody.py b/resources/lib/cherrypy/_cpreqbody.py new file mode 100644 index 0000000..4d3cefe --- /dev/null +++ b/resources/lib/cherrypy/_cpreqbody.py @@ -0,0 +1,993 @@ +"""Request body processing for CherryPy. + +.. versionadded:: 3.2 + +Application authors have complete control over the parsing of HTTP request +entities. In short, +:attr:`cherrypy.request.body` +is now always set to an instance of +:class:`RequestBody`, +and *that* class is a subclass of :class:`Entity`. + +When an HTTP request includes an entity body, it is often desirable to +provide that information to applications in a form other than the raw bytes. +Different content types demand different approaches. Examples: + + * For a GIF file, we want the raw bytes in a stream. 
+ * An HTML form is better parsed into its component fields, and each text field + decoded from bytes to unicode. + * A JSON body should be deserialized into a Python dict or list. + +When the request contains a Content-Type header, the media type is used as a +key to look up a value in the +:attr:`request.body.processors` dict. +If the full media +type is not found, then the major type is tried; for example, if no processor +is found for the 'image/jpeg' type, then we look for a processor for the +'image' types altogether. If neither the full type nor the major type has a +matching processor, then a default processor is used +(:func:`default_proc`). For most +types, this means no processing is done, and the body is left unread as a +raw byte stream. Processors are configurable in an 'on_start_resource' hook. + +Some processors, especially those for the 'text' types, attempt to decode bytes +to unicode. If the Content-Type request header includes a 'charset' parameter, +this is used to decode the entity. Otherwise, one or more default charsets may +be attempted, although this decision is up to each processor. If a processor +successfully decodes an Entity or Part, it should set the +:attr:`charset` attribute +on the Entity or Part to the name of the successful charset, so that +applications can easily re-encode or transcode the value if they wish. + +If the Content-Type of the request entity is of major type 'multipart', then +the above parsing process, and possibly a decoding process, is performed for +each part. + +For both the full entity and multipart parts, a Content-Disposition header may +be used to fill :attr:`name` and +:attr:`filename` attributes on the +request.body or the Part. + +.. _custombodyprocessors: + +Custom Processors +================= + +You can add your own processors for any specific or major MIME type. Simply add +it to the :attr:`processors` dict in a +hook/tool that runs at ``on_start_resource`` or ``before_request_body``. +Here's the built-in JSON tool for an example:: + + def json_in(force=True, debug=False): + request = cherrypy.serving.request + def json_processor(entity): + '''Read application/json data into request.json.''' + if not entity.headers.get("Content-Length", ""): + raise cherrypy.HTTPError(411) + + body = entity.fp.read() + try: + request.json = json_decode(body) + except ValueError: + raise cherrypy.HTTPError(400, 'Invalid JSON document') + if force: + request.body.processors.clear() + request.body.default_proc = cherrypy.HTTPError( + 415, 'Expected an application/json content type') + request.body.processors['application/json'] = json_processor + +We begin by defining a new ``json_processor`` function to stick in the +``processors`` dictionary. All processor functions take a single argument, +the ``Entity`` instance they are to process. It will be called whenever a +request is received (for those URI's where the tool is turned on) which +has a ``Content-Type`` of "application/json". + +First, it checks for a valid ``Content-Length`` (raising 411 if not valid), +then reads the remaining bytes on the socket. The ``fp`` object knows its +own length, so it won't hang waiting for data that never arrives. It will +return when all data has been read. Then, we decode those bytes using +Python's built-in ``json`` module, and stick the decoded result onto +``request.json`` . If it cannot be decoded, we raise 400. 
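+For reference, a page handler consuming the built-in tool that this example
+mirrors could look like the following sketch (``endpoint`` is hypothetical)::
+
+    @cherrypy.expose
+    @cherrypy.tools.json_in()
+    def endpoint(self):
+        return 'Received: %r' % (cherrypy.request.json,)
+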
+ +If the "force" argument is True (the default), the ``Tool`` clears the +``processors`` dict so that request entities of other ``Content-Types`` +aren't parsed at all. Since there's no entry for those invalid MIME +types, the ``default_proc`` method of ``cherrypy.request.body`` is +called. But this does nothing by default (usually to provide the page +handler an opportunity to handle it.) +But in our case, we want to raise 415, so we replace +``request.body.default_proc`` +with the error (``HTTPError`` instances, when called, raise themselves). + +If we were defining a custom processor, we can do so without making a ``Tool``. +Just add the config entry:: + + request.body.processors = {'application/json': json_processor} + +Note that you can only replace the ``processors`` dict wholesale this way, +not update the existing one. +""" + +try: + from io import DEFAULT_BUFFER_SIZE +except ImportError: + DEFAULT_BUFFER_SIZE = 8192 +import re +import sys +import tempfile +from urllib.parse import unquote + +import cheroot.server + +import cherrypy +from cherrypy._cpcompat import ntou +from cherrypy.lib import httputil + + +def unquote_plus(bs): + """Bytes version of urllib.parse.unquote_plus.""" + bs = bs.replace(b'+', b' ') + atoms = bs.split(b'%') + for i in range(1, len(atoms)): + item = atoms[i] + try: + pct = int(item[:2], 16) + atoms[i] = bytes([pct]) + item[2:] + except ValueError: + pass + return b''.join(atoms) + + +# ------------------------------- Processors -------------------------------- # + +def process_urlencoded(entity): + """Read application/x-www-form-urlencoded data into entity.params.""" + qs = entity.fp.read() + for charset in entity.attempt_charsets: + try: + params = {} + for aparam in qs.split(b'&'): + for pair in aparam.split(b';'): + if not pair: + continue + + atoms = pair.split(b'=', 1) + if len(atoms) == 1: + atoms.append(b'') + + key = unquote_plus(atoms[0]).decode(charset) + value = unquote_plus(atoms[1]).decode(charset) + + if key in params: + if not isinstance(params[key], list): + params[key] = [params[key]] + params[key].append(value) + else: + params[key] = value + except UnicodeDecodeError: + pass + else: + entity.charset = charset + break + else: + raise cherrypy.HTTPError( + 400, 'The request entity could not be decoded. The following ' + 'charsets were attempted: %s' % repr(entity.attempt_charsets)) + + # Now that all values have been successfully parsed and decoded, + # apply them to the entity.params dict. 
+ for key, value in params.items(): + if key in entity.params: + if not isinstance(entity.params[key], list): + entity.params[key] = [entity.params[key]] + entity.params[key].append(value) + else: + entity.params[key] = value + + +def process_multipart(entity): + """Read all multipart parts into entity.parts.""" + ib = '' + if 'boundary' in entity.content_type.params: + # http://tools.ietf.org/html/rfc2046#section-5.1.1 + # "The grammar for parameters on the Content-type field is such that it + # is often necessary to enclose the boundary parameter values in quotes + # on the Content-type line" + ib = entity.content_type.params['boundary'].strip('"') + + if not re.match('^[ -~]{0,200}[!-~]$', ib): + raise ValueError('Invalid boundary in multipart form: %r' % (ib,)) + + ib = ('--' + ib).encode('ascii') + + # Find the first marker + while True: + b = entity.readline() + if not b: + return + + b = b.strip() + if b == ib: + break + + # Read all parts + while True: + part = entity.part_class.from_fp(entity.fp, ib) + entity.parts.append(part) + part.process() + if part.fp.done: + break + + +def process_multipart_form_data(entity): + """Read all multipart/form-data parts into entity.parts or entity.params. + """ + process_multipart(entity) + + kept_parts = [] + for part in entity.parts: + if part.name is None: + kept_parts.append(part) + else: + if part.filename is None: + # It's a regular field + value = part.fullvalue() + else: + # It's a file upload. Retain the whole part so consumer code + # has access to its .file and .filename attributes. + value = part + + if part.name in entity.params: + if not isinstance(entity.params[part.name], list): + entity.params[part.name] = [entity.params[part.name]] + entity.params[part.name].append(value) + else: + entity.params[part.name] = value + + entity.parts = kept_parts + + +def _old_process_multipart(entity): + """The behavior of 3.2 and lower. Deprecated and will be changed in 3.3.""" + process_multipart(entity) + + params = entity.params + + for part in entity.parts: + if part.name is None: + key = ntou('parts') + else: + key = part.name + + if part.filename is None: + # It's a regular field + value = part.fullvalue() + else: + # It's a file upload. Retain the whole part so consumer code + # has access to its .file and .filename attributes. + value = part + + if key in params: + if not isinstance(params[key], list): + params[key] = [params[key]] + params[key].append(value) + else: + params[key] = value + + +# -------------------------------- Entities --------------------------------- # +class Entity(object): + + """An HTTP request body, or MIME multipart body. + + This class collects information about the HTTP request entity. When a + given entity is of MIME type "multipart", each part is parsed into its own + Entity instance, and the set of parts stored in + :attr:`entity.parts`. + + Between the ``before_request_body`` and ``before_handler`` tools, CherryPy + tries to process the request body (if any) by calling + :func:`request.body.process`. + This uses the ``content_type`` of the Entity to look up a suitable + processor in + :attr:`Entity.processors`, + a dict. + If a matching processor cannot be found for the complete Content-Type, + it tries again using the major type. For example, if a request with an + entity of type "image/jpeg" arrives, but no processor can be found for + that complete type, then one is sought for the major type "image". 
If a + processor is still not found, then the + :func:`default_proc` method + of the Entity is called (which does nothing by default; you can + override this too). + + CherryPy includes processors for the "application/x-www-form-urlencoded" + type, the "multipart/form-data" type, and the "multipart" major type. + CherryPy 3.2 processes these types almost exactly as older versions. + Parts are passed as arguments to the page handler using their + ``Content-Disposition.name`` if given, otherwise in a generic "parts" + argument. Each such part is either a string, or the + :class:`Part` itself if it's a file. (In this + case it will have ``file`` and ``filename`` attributes, or possibly a + ``value`` attribute). Each Part is itself a subclass of + Entity, and has its own ``process`` method and ``processors`` dict. + + There is a separate processor for the "multipart" major type which is more + flexible, and simply stores all multipart parts in + :attr:`request.body.parts`. You can + enable it with:: + + cherrypy.request.body.processors['multipart'] = \ + _cpreqbody.process_multipart + + in an ``on_start_resource`` tool. + """ + + # http://tools.ietf.org/html/rfc2046#section-4.1.2: + # "The default character set, which must be assumed in the + # absence of a charset parameter, is US-ASCII." + # However, many browsers send data in utf-8 with no charset. + attempt_charsets = ['utf-8'] + r"""A list of strings, each of which should be a known encoding. + + When the Content-Type of the request body warrants it, each of the given + encodings will be tried in order. The first one to successfully decode the + entity without raising an error is stored as + :attr:`entity.charset`. This defaults + to ``['utf-8']`` (plus 'ISO-8859-1' for "text/\*" types, as required by + `HTTP/1.1 + `_), + but ``['us-ascii', 'utf-8']`` for multipart parts. + """ + + charset = None + """The successful decoding; see "attempt_charsets" above.""" + + content_type = None + """The value of the Content-Type request header. + + If the Entity is part of a multipart payload, this will be the Content-Type + given in the MIME headers for this part. + """ + + default_content_type = 'application/x-www-form-urlencoded' + """This defines a default ``Content-Type`` to use if no Content-Type header + is given. The empty string is used for RequestBody, which results in the + request body not being read or parsed at all. This is by design; a missing + ``Content-Type`` header in the HTTP request entity is an error at best, + and a security hole at worst. For multipart parts, however, the MIME spec + declares that a part with no Content-Type defaults to "text/plain" + (see :class:`Part`). + """ + + filename = None + """The ``Content-Disposition.filename`` header, if available.""" + + fp = None + """The readable socket file object.""" + + headers = None + """A dict of request/multipart header names and values. + + This is a copy of the ``request.headers`` for the ``request.body``; + for multipart parts, it is the set of headers for that part. 
+ """ + + length = None + """The value of the ``Content-Length`` header, if provided.""" + + name = None + """The "name" parameter of the ``Content-Disposition`` header, if any.""" + + params = None + """ + If the request Content-Type is 'application/x-www-form-urlencoded' or + multipart, this will be a dict of the params pulled from the entity + body; that is, it will be the portion of request.params that come + from the message body (sometimes called "POST params", although they + can be sent with various HTTP method verbs). This value is set between + the 'before_request_body' and 'before_handler' hooks (assuming that + process_request_body is True).""" + + processors = {'application/x-www-form-urlencoded': process_urlencoded, + 'multipart/form-data': process_multipart_form_data, + 'multipart': process_multipart, + } + """A dict of Content-Type names to processor methods.""" + + parts = None + """A list of Part instances if ``Content-Type`` is of major type + "multipart".""" + + part_class = None + """The class used for multipart parts. + + You can replace this with custom subclasses to alter the processing of + multipart parts. + """ + + def __init__(self, fp, headers, params=None, parts=None): + # Make an instance-specific copy of the class processors + # so Tools, etc. can replace them per-request. + self.processors = self.processors.copy() + + self.fp = fp + self.headers = headers + + if params is None: + params = {} + self.params = params + + if parts is None: + parts = [] + self.parts = parts + + # Content-Type + self.content_type = headers.elements('Content-Type') + if self.content_type: + self.content_type = self.content_type[0] + else: + self.content_type = httputil.HeaderElement.from_str( + self.default_content_type) + + # Copy the class 'attempt_charsets', prepending any Content-Type + # charset + dec = self.content_type.params.get('charset', None) + if dec: + self.attempt_charsets = [dec] + [c for c in self.attempt_charsets + if c != dec] + else: + self.attempt_charsets = self.attempt_charsets[:] + + # Length + self.length = None + clen = headers.get('Content-Length', None) + # If Transfer-Encoding is 'chunked', ignore any Content-Length. + if ( + clen is not None and + 'chunked' not in headers.get('Transfer-Encoding', '') + ): + try: + self.length = int(clen) + except ValueError: + pass + + # Content-Disposition + self.name = None + self.filename = None + disp = headers.elements('Content-Disposition') + if disp: + disp = disp[0] + if 'name' in disp.params: + self.name = disp.params['name'] + if self.name.startswith('"') and self.name.endswith('"'): + self.name = self.name[1:-1] + if 'filename' in disp.params: + self.filename = disp.params['filename'] + if ( + self.filename.startswith('"') and + self.filename.endswith('"') + ): + self.filename = self.filename[1:-1] + if 'filename*' in disp.params: + # @see https://tools.ietf.org/html/rfc5987 + encoding, lang, filename = disp.params['filename*'].split("'") + self.filename = unquote(str(filename), encoding) + + def read(self, size=None, fp_out=None): + return self.fp.read(size, fp_out) + + def readline(self, size=None): + return self.fp.readline(size) + + def readlines(self, sizehint=None): + return self.fp.readlines(sizehint) + + def __iter__(self): + return self + + def __next__(self): + line = self.readline() + if not line: + raise StopIteration + return line + + def next(self): + return self.__next__() + + def read_into_file(self, fp_out=None): + """Read the request body into fp_out (or make_file() if None). 
+ + Return fp_out. + """ + if fp_out is None: + fp_out = self.make_file() + self.read(fp_out=fp_out) + return fp_out + + def make_file(self): + """Return a file-like object into which the request body will be read. + + By default, this will return a TemporaryFile. Override as needed. + See also :attr:`cherrypy._cpreqbody.Part.maxrambytes`.""" + return tempfile.TemporaryFile() + + def fullvalue(self): + """Return this entity as a string, whether stored in a file or not.""" + if self.file: + # It was stored in a tempfile. Read it. + self.file.seek(0) + value = self.file.read() + self.file.seek(0) + else: + value = self.value + value = self.decode_entity(value) + return value + + def decode_entity(self, value): + """Return a given byte encoded value as a string""" + for charset in self.attempt_charsets: + try: + value = value.decode(charset) + except UnicodeDecodeError: + pass + else: + self.charset = charset + return value + else: + raise cherrypy.HTTPError( + 400, + 'The request entity could not be decoded. The following ' + 'charsets were attempted: %s' % repr(self.attempt_charsets) + ) + + def process(self): + """Execute the best-match processor for the given media type.""" + proc = None + ct = self.content_type.value + try: + proc = self.processors[ct] + except KeyError: + toptype = ct.split('/', 1)[0] + try: + proc = self.processors[toptype] + except KeyError: + pass + if proc is None: + self.default_proc() + else: + proc(self) + + def default_proc(self): + """Called if a more-specific processor is not found for the + ``Content-Type``. + """ + # Leave the fp alone for someone else to read. This works fine + # for request.body, but the Part subclasses need to override this + # so they can move on to the next part. + pass + + +class Part(Entity): + + """A MIME part entity, part of a multipart entity.""" + + # "The default character set, which must be assumed in the absence of a + # charset parameter, is US-ASCII." + attempt_charsets = ['us-ascii', 'utf-8'] + r"""A list of strings, each of which should be a known encoding. + + When the Content-Type of the request body warrants it, each of the given + encodings will be tried in order. The first one to successfully decode the + entity without raising an error is stored as + :attr:`entity.charset`. This defaults + to ``['utf-8']`` (plus 'ISO-8859-1' for "text/\*" types, as required by + `HTTP/1.1 + `_), + but ``['us-ascii', 'utf-8']`` for multipart parts. + """ + + boundary = None + """The MIME multipart boundary.""" + + default_content_type = 'text/plain' + """This defines a default ``Content-Type`` to use if no Content-Type header + is given. The empty string is used for RequestBody, which results in the + request body not being read or parsed at all. This is by design; a missing + ``Content-Type`` header in the HTTP request entity is an error at best, + and a security hole at worst. For multipart parts, however (this class), + the MIME spec declares that a part with no Content-Type defaults to + "text/plain". + """ + + # This is the default in stdlib cgi. We may want to increase it. + maxrambytes = 1000 + """The threshold of bytes after which point the ``Part`` will store + its data in a file (generated by + :func:`make_file`) + instead of a string. Defaults to 1000, just like the :mod:`cgi` + module in Python's standard library. 
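+    To trade RAM for disk sooner or later, this class attribute can be
+    adjusted before requests are served (a sketch; the 1 MiB figure is
+    arbitrary)::
+
+        from cherrypy import _cpreqbody
+        _cpreqbody.Part.maxrambytes = 1024 * 1024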
+ """ + + def __init__(self, fp, headers, boundary): + Entity.__init__(self, fp, headers) + self.boundary = boundary + self.file = None + self.value = None + + @classmethod + def from_fp(cls, fp, boundary): + headers = cls.read_headers(fp) + return cls(fp, headers, boundary) + + @classmethod + def read_headers(cls, fp): + headers = httputil.HeaderMap() + while True: + line = fp.readline() + if not line: + # No more data--illegal end of headers + raise EOFError('Illegal end of headers.') + + if line == b'\r\n': + # Normal end of headers + break + if not line.endswith(b'\r\n'): + raise ValueError('MIME requires CRLF terminators: %r' % line) + + if line[0] in b' \t': + # It's a continuation line. + v = line.strip().decode('ISO-8859-1') + else: + k, v = line.split(b':', 1) + k = k.strip().decode('ISO-8859-1') + v = v.strip().decode('ISO-8859-1') + + existing = headers.get(k) + if existing: + v = ', '.join((existing, v)) + headers[k] = v + + return headers + + def read_lines_to_boundary(self, fp_out=None): + """Read bytes from self.fp and return or write them to a file. + + If the 'fp_out' argument is None (the default), all bytes read are + returned in a single byte string. + + If the 'fp_out' argument is not None, it must be a file-like + object that supports the 'write' method; all bytes read will be + written to the fp, and that fp is returned. + """ + endmarker = self.boundary + b'--' + delim = b'' + prev_lf = True + lines = [] + seen = 0 + while True: + line = self.fp.readline(1 << 16) + if not line: + raise EOFError('Illegal end of multipart body.') + if line.startswith(b'--') and prev_lf: + strippedline = line.strip() + if strippedline == self.boundary: + break + if strippedline == endmarker: + self.fp.finish() + break + + line = delim + line + + if line.endswith(b'\r\n'): + delim = b'\r\n' + line = line[:-2] + prev_lf = True + elif line.endswith(b'\n'): + delim = b'\n' + line = line[:-1] + prev_lf = True + else: + delim = b'' + prev_lf = False + + if fp_out is None: + lines.append(line) + seen += len(line) + if seen > self.maxrambytes: + fp_out = self.make_file() + for line in lines: + fp_out.write(line) + else: + fp_out.write(line) + + if fp_out is None: + result = b''.join(lines) + return result + else: + fp_out.seek(0) + return fp_out + + def default_proc(self): + """Called if a more-specific processor is not found for the + ``Content-Type``. + """ + if self.filename: + # Always read into a file if a .filename was given. + self.file = self.read_into_file() + else: + result = self.read_lines_to_boundary() + if isinstance(result, bytes): + self.value = result + else: + self.file = result + + def read_into_file(self, fp_out=None): + """Read the request body into fp_out (or make_file() if None). + + Return fp_out. + """ + if fp_out is None: + fp_out = self.make_file() + self.read_lines_to_boundary(fp_out=fp_out) + return fp_out + + +Entity.part_class = Part + +inf = float('inf') + + +class SizedReader: + + def __init__(self, fp, length, maxbytes, bufsize=DEFAULT_BUFFER_SIZE, + has_trailers=False): + # Wrap our fp in a buffer so peek() works + self.fp = fp + self.length = length + self.maxbytes = maxbytes + self.buffer = b'' + self.bufsize = bufsize + self.bytes_read = 0 + self.done = False + self.has_trailers = has_trailers + + def read(self, size=None, fp_out=None): + """Read bytes from the request body and return or write them to a file. + + A number of bytes less than or equal to the 'size' argument are read + off the socket. 
The actual number of bytes read are tracked in + self.bytes_read. The number may be smaller than 'size' when 1) the + client sends fewer bytes, 2) the 'Content-Length' request header + specifies fewer bytes than requested, or 3) the number of bytes read + exceeds self.maxbytes (in which case, 413 is raised). + + If the 'fp_out' argument is None (the default), all bytes read are + returned in a single byte string. + + If the 'fp_out' argument is not None, it must be a file-like + object that supports the 'write' method; all bytes read will be + written to the fp, and None is returned. + """ + + if self.length is None: + if size is None: + remaining = inf + else: + remaining = size + else: + remaining = self.length - self.bytes_read + if size and size < remaining: + remaining = size + if remaining == 0: + self.finish() + if fp_out is None: + return b'' + else: + return None + + chunks = [] + + # Read bytes from the buffer. + if self.buffer: + if remaining is inf: + data = self.buffer + self.buffer = b'' + else: + data = self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + datalen = len(data) + remaining -= datalen + + # Check lengths. + self.bytes_read += datalen + if self.maxbytes and self.bytes_read > self.maxbytes: + raise cherrypy.HTTPError(413) + + # Store the data. + if fp_out is None: + chunks.append(data) + else: + fp_out.write(data) + + # Read bytes from the socket. + while remaining > 0: + chunksize = min(remaining, self.bufsize) + try: + data = self.fp.read(chunksize) + except Exception: + e = sys.exc_info()[1] + if e.__class__.__name__ == 'MaxSizeExceeded': + # Post data is too big + raise cherrypy.HTTPError( + 413, 'Maximum request length: %r' % e.args[1]) + else: + raise + if not data: + self.finish() + break + datalen = len(data) + remaining -= datalen + + # Check lengths. + self.bytes_read += datalen + if self.maxbytes and self.bytes_read > self.maxbytes: + raise cherrypy.HTTPError(413) + + # Store the data. + if fp_out is None: + chunks.append(data) + else: + fp_out.write(data) + + if fp_out is None: + return b''.join(chunks) + + def readline(self, size=None): + """Read a line from the request body and return it.""" + chunks = [] + while size is None or size > 0: + chunksize = self.bufsize + if size is not None and size < self.bufsize: + chunksize = size + data = self.read(chunksize) + if not data: + break + pos = data.find(b'\n') + 1 + if pos: + chunks.append(data[:pos]) + remainder = data[pos:] + self.buffer += remainder + self.bytes_read -= len(remainder) + break + else: + chunks.append(data) + return b''.join(chunks) + + def readlines(self, sizehint=None): + """Read lines from the request body and return them.""" + if self.length is not None: + if sizehint is None: + sizehint = self.length - self.bytes_read + else: + sizehint = min(sizehint, self.length - self.bytes_read) + + lines = [] + seen = 0 + while True: + line = self.readline() + if not line: + break + lines.append(line) + seen += len(line) + if seen >= sizehint: + break + return lines + + def finish(self): + self.done = True + if self.has_trailers and hasattr(self.fp, 'read_trailer_lines'): + self.trailers = {} + + try: + for line in self.fp.read_trailer_lines(): + if line[0] in b' \t': + # It's a continuation line. 
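                        # (Header folding: a leading space or tab means the
                        # line continues the previous field's value.)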
+ v = line.strip() + else: + try: + k, v = line.split(b':', 1) + except ValueError: + raise ValueError('Illegal header line.') + k = k.strip().title() + v = v.strip() + + if k in cheroot.server.comma_separated_headers: + existing = self.trailers.get(k) + if existing: + v = b', '.join((existing, v)) + self.trailers[k] = v + except Exception: + e = sys.exc_info()[1] + if e.__class__.__name__ == 'MaxSizeExceeded': + # Post data is too big + raise cherrypy.HTTPError( + 413, 'Maximum request length: %r' % e.args[1]) + else: + raise + + +class RequestBody(Entity): + + """The entity of the HTTP request.""" + + bufsize = 8 * 1024 + """The buffer size used when reading the socket.""" + + # Don't parse the request body at all if the client didn't provide + # a Content-Type header. See + # https://github.com/cherrypy/cherrypy/issues/790 + default_content_type = '' + """This defines a default ``Content-Type`` to use if no Content-Type header + is given. The empty string is used for RequestBody, which results in the + request body not being read or parsed at all. This is by design; a missing + ``Content-Type`` header in the HTTP request entity is an error at best, + and a security hole at worst. For multipart parts, however, the MIME spec + declares that a part with no Content-Type defaults to "text/plain" + (see :class:`Part`). + """ + + maxbytes = None + """Raise ``MaxSizeExceeded`` if more bytes than this are read from + the socket. + """ + + def __init__(self, fp, headers, params=None, request_params=None): + Entity.__init__(self, fp, headers, params) + + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7.1 + # When no explicit charset parameter is provided by the + # sender, media subtypes of the "text" type are defined + # to have a default charset value of "ISO-8859-1" when + # received via HTTP. + if self.content_type.value.startswith('text/'): + for c in ('ISO-8859-1', 'iso-8859-1', 'Latin-1', 'latin-1'): + if c in self.attempt_charsets: + break + else: + self.attempt_charsets.append('ISO-8859-1') + + # Temporary fix while deprecating passing .parts as .params. + self.processors['multipart'] = _old_process_multipart + + if request_params is None: + request_params = {} + self.request_params = request_params + + def process(self): + """Process the request entity based on its Content-Type.""" + # "The presence of a message-body in a request is signaled by the + # inclusion of a Content-Length or Transfer-Encoding header field in + # the request's message-headers." + # It is possible to send a POST request with no body, for example; + # however, app developers are responsible in that case to set + # cherrypy.request.process_body to False so this method isn't called. + h = cherrypy.serving.request.headers + if 'Content-Length' not in h and 'Transfer-Encoding' not in h: + raise cherrypy.HTTPError(411) + + self.fp = SizedReader(self.fp, self.length, + self.maxbytes, bufsize=self.bufsize, + has_trailers='Trailer' in h) + super(RequestBody, self).process() + + # Body params should also be a part of the request_params + # add them in here. 
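        # (Same accumulation rule as elsewhere in this module: a key that
        # appears both in the query string and in the body, or more than
        # once in the body, ends up as a list of values.)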
+ request_params = self.request_params + for key, value in self.params.items(): + if key in request_params: + if not isinstance(request_params[key], list): + request_params[key] = [request_params[key]] + request_params[key].append(value) + else: + request_params[key] = value diff --git a/resources/lib/cherrypy/_cprequest.py b/resources/lib/cherrypy/_cprequest.py new file mode 100644 index 0000000..9b86bd6 --- /dev/null +++ b/resources/lib/cherrypy/_cprequest.py @@ -0,0 +1,932 @@ +import sys +import time +import collections +import operator +from http.cookies import SimpleCookie, CookieError + +import uuid + +from more_itertools import consume + +import cherrypy +from cherrypy._cpcompat import ntob +from cherrypy import _cpreqbody +from cherrypy._cperror import format_exc, bare_error +from cherrypy.lib import httputil, reprconf, encoding + + +class Hook(object): + + """A callback and its metadata: failsafe, priority, and kwargs.""" + + callback = None + """ + The bare callable that this Hook object is wrapping, which will + be called when the Hook is called.""" + + failsafe = False + """ + If True, the callback is guaranteed to run even if other callbacks + from the same call point raise exceptions.""" + + priority = 50 + """ + Defines the order of execution for a list of Hooks. Priority numbers + should be limited to the closed interval [0, 100], but values outside + this range are acceptable, as are fractional values.""" + + kwargs = {} + """ + A set of keyword arguments that will be passed to the + callable on each call.""" + + def __init__(self, callback, failsafe=None, priority=None, **kwargs): + self.callback = callback + + if failsafe is None: + failsafe = getattr(callback, 'failsafe', False) + self.failsafe = failsafe + + if priority is None: + priority = getattr(callback, 'priority', 50) + self.priority = priority + + self.kwargs = kwargs + + def __lt__(self, other): + """ + Hooks sort by priority, ascending, such that + hooks of lower priority are run first. + """ + return self.priority < other.priority + + def __call__(self): + """Run self.callback(**self.kwargs).""" + return self.callback(**self.kwargs) + + def __repr__(self): + cls = self.__class__ + return ('%s.%s(callback=%r, failsafe=%r, priority=%r, %s)' + % (cls.__module__, cls.__name__, self.callback, + self.failsafe, self.priority, + ', '.join(['%s=%r' % (k, v) + for k, v in self.kwargs.items()]))) + + +class HookMap(dict): + + """A map of call points to lists of callbacks (Hook objects).""" + + def __new__(cls, points=None): + d = dict.__new__(cls) + for p in points or []: + d[p] = [] + return d + + def __init__(self, *a, **kw): + pass + + def attach(self, point, callback, failsafe=None, priority=None, **kwargs): + """Append a new Hook made from the supplied arguments.""" + self[point].append(Hook(callback, failsafe, priority, **kwargs)) + + def run(self, point): + """Execute all registered Hooks (callbacks) for the given point.""" + self.run_hooks(iter(sorted(self[point]))) + + @classmethod + def run_hooks(cls, hooks): + """Execute the indicated hooks, trapping errors. + + Hooks with ``.failsafe == True`` are guaranteed to run + even if others at the same hookpoint fail. In this case, + log the failure and proceed on to the next hook. The only + way to stop all processing from one of these hooks is + to raise a BaseException like SystemExit or + KeyboardInterrupt and stop the whole server. 
+ """ + assert isinstance(hooks, collections.abc.Iterator) + quiet_errors = ( + cherrypy.HTTPError, + cherrypy.HTTPRedirect, + cherrypy.InternalRedirect, + ) + safe = filter(operator.attrgetter('failsafe'), hooks) + for hook in hooks: + try: + hook() + except quiet_errors: + cls.run_hooks(safe) + raise + except Exception: + cherrypy.log(traceback=True, severity=40) + cls.run_hooks(safe) + raise + + def __copy__(self): + newmap = self.__class__() + # We can't just use 'update' because we want copies of the + # mutable values (each is a list) as well. + for k, v in self.items(): + newmap[k] = v[:] + return newmap + copy = __copy__ + + def __repr__(self): + cls = self.__class__ + return '%s.%s(points=%r)' % ( + cls.__module__, + cls.__name__, + list(self) + ) + + +# Config namespace handlers + +def hooks_namespace(k, v): + """Attach bare hooks declared in config.""" + # Use split again to allow multiple hooks for a single + # hookpoint per path (e.g. "hooks.before_handler.1"). + # Little-known fact you only get from reading source ;) + hookpoint = k.split('.', 1)[0] + if isinstance(v, str): + v = cherrypy.lib.reprconf.attributes(v) + if not isinstance(v, Hook): + v = Hook(v) + cherrypy.serving.request.hooks[hookpoint].append(v) + + +def request_namespace(k, v): + """Attach request attributes declared in config.""" + # Provides config entries to set request.body attrs (like + # attempt_charsets). + if k[:5] == 'body.': + setattr(cherrypy.serving.request.body, k[5:], v) + else: + setattr(cherrypy.serving.request, k, v) + + +def response_namespace(k, v): + """Attach response attributes declared in config.""" + # Provides config entries to set default response headers + # http://cherrypy.org/ticket/889 + if k[:8] == 'headers.': + cherrypy.serving.response.headers[k.split('.', 1)[1]] = v + else: + setattr(cherrypy.serving.response, k, v) + + +def error_page_namespace(k, v): + """Attach error pages declared in config.""" + if k != 'default': + k = int(k) + cherrypy.serving.request.error_page[k] = v + + +hookpoints = ['on_start_resource', 'before_request_body', + 'before_handler', 'before_finalize', + 'on_end_resource', 'on_end_request', + 'before_error_response', 'after_error_response'] + + +class Request(object): + + """An HTTP request. + + This object represents the metadata of an HTTP request message; + that is, it contains attributes which describe the environment + in which the request URL, headers, and body were sent (if you + want tools to interpret the headers and body, those are elsewhere, + mostly in Tools). This 'metadata' consists of socket data, + transport characteristics, and the Request-Line. This object + also contains data regarding the configuration in effect for + the given URL, and the execution plan for generating a response. + """ + + prev = None + """ + The previous Request object (if any). This should be None + unless we are processing an InternalRedirect.""" + + # Conversation/connection attributes + local = httputil.Host('127.0.0.1', 80) + 'An httputil.Host(ip, port, hostname) object for the server socket.' + + remote = httputil.Host('127.0.0.1', 1111) + 'An httputil.Host(ip, port, hostname) object for the client socket.' + + scheme = 'http' + """ + The protocol used between client and server. In most cases, + this will be either 'http' or 'https'.""" + + server_protocol = 'HTTP/1.1' + """ + The HTTP version for which the HTTP server is at least + conditionally compliant.""" + + base = '' + """The (scheme://host) portion of the requested URL. + In some cases (e.g. 
when proxying via mod_rewrite), this may contain
+    path segments which cherrypy.url uses when constructing URLs, but
+    which otherwise are ignored by CherryPy. Regardless, this value
+    MUST NOT end in a slash."""
+
+    # Request-Line attributes
+    request_line = ''
+    """
+    The complete Request-Line received from the client. This is a
+    single string consisting of the request method, URI, and protocol
+    version (joined by spaces). Any final CRLF is removed."""
+
+    method = 'GET'
+    """
+    Indicates the HTTP method to be performed on the resource identified
+    by the Request-URI. Common methods include GET, HEAD, POST, PUT, and
+    DELETE. CherryPy allows any extension method; however, various HTTP
+    servers and gateways may restrict the set of allowable methods.
+    CherryPy applications SHOULD restrict the set (on a per-URI basis)."""
+
+    query_string = ''
+    """
+    The query component of the Request-URI, a string of information to be
+    interpreted by the resource. The query portion of a URI follows the
+    path component, and is separated by a '?'. For example, the URI
+    'http://www.cherrypy.org/wiki?a=3&b=4' has the query component,
+    'a=3&b=4'."""
+
+    query_string_encoding = 'utf8'
+    """
+    The encoding expected for query string arguments after % HEX HEX decoding.
+    If a query string is provided that cannot be decoded with this encoding,
+    404 is raised (since technically it's a different URI). If you want
+    arbitrary encodings to not error, set this to 'Latin-1'; you can then
+    encode back to bytes and re-decode to whatever encoding you like later.
+    """
+
+    protocol = (1, 1)
+    """The HTTP protocol version corresponding to the set
+    of features which should be allowed in the response. If BOTH
+    the client's request message AND the server's level of HTTP
+    compliance is HTTP/1.1, this attribute will be the tuple (1, 1).
+    If either is 1.0, this attribute will be the tuple (1, 0).
+    Lower HTTP protocol versions are not explicitly supported."""
+
+    params = {}
+    """
+    A dict which combines query string (GET) and request entity (POST)
+    variables. This is populated in two stages: GET params are added
+    before the 'on_start_resource' hook, and POST params are added
+    between the 'before_request_body' and 'before_handler' hooks."""
+
+    # Message attributes
+    header_list = []
+    """
+    A list of the HTTP request headers as (name, value) tuples.
+    In general, you should use request.headers (a dict) instead."""
+
+    headers = httputil.HeaderMap()
+    """
+    A dict-like object containing the request headers. Keys are header
+    names (in Title-Case format); however, you may get and set them in
+    a case-insensitive manner. That is, headers['Content-Type'] and
+    headers['content-type'] refer to the same value. Values are header
+    values (decoded according to :rfc:`2047` if necessary). See also:
+    httputil.HeaderMap, httputil.HeaderElement."""
+
+    cookie = SimpleCookie()
+    """See help(Cookie)."""
+
+    rfile = None
+    """
+    If the request included an entity (body), it will be available
+    as a stream in this attribute. However, the rfile will normally
+    be read for you between the 'before_request_body' hook and the
+    'before_handler' hook, and the resulting string is placed into
+    either request.params or the request.body attribute.
+
+    You may disable the automatic consumption of the rfile by setting
+    request.process_request_body to False, either in config for the desired
+    path, or in an 'on_start_resource' or 'before_request_body' hook.
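
The hooks config namespace shown earlier accepts bare callables, so custom per-path hooks need no Tool subclass. An illustrative sketch (the `audit` function and `Root` class are invented for the example)::

    import cherrypy

    def audit():
        # Runs at the 'before_handler' hook point for matching paths.
        cherrypy.log('handling %s' % cherrypy.request.path_info)

    class Root:
        @cherrypy.expose
        def index(self):
            return b'ok'

    cherrypy.quickstart(Root(), '/', {'/': {'hooks.before_handler': audit}})
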
+ + WARNING: In almost every case, you should not attempt to read from the + rfile stream after CherryPy's automatic mechanism has read it. If you + turn off the automatic parsing of rfile, you should read exactly the + number of bytes specified in request.headers['Content-Length']. + Ignoring either of these warnings may result in a hung request thread + or in corruption of the next (pipelined) request. + """ + + process_request_body = True + """ + If True, the rfile (if any) is automatically read and parsed, + and the result placed into request.params or request.body.""" + + methods_with_bodies = ('POST', 'PUT', 'PATCH') + """ + A sequence of HTTP methods for which CherryPy will automatically + attempt to read a body from the rfile. If you are going to change + this property, modify it on the configuration (recommended) + or on the "hook point" `on_start_resource`. + """ + + body = None + """ + If the request Content-Type is 'application/x-www-form-urlencoded' + or multipart, this will be None. Otherwise, this will be an instance + of :class:`RequestBody` (which you + can .read()); this value is set between the 'before_request_body' and + 'before_handler' hooks (assuming that process_request_body is True).""" + + # Dispatch attributes + dispatch = cherrypy.dispatch.Dispatcher() + """ + The object which looks up the 'page handler' callable and collects + config for the current request based on the path_info, other + request attributes, and the application architecture. The core + calls the dispatcher as early as possible, passing it a 'path_info' + argument. + + The default dispatcher discovers the page handler by matching path_info + to a hierarchical arrangement of objects, starting at request.app.root. + See help(cherrypy.dispatch) for more information.""" + + script_name = '' + """ + The 'mount point' of the application which is handling this request. + + This attribute MUST NOT end in a slash. If the script_name refers to + the root of the URI, it MUST be an empty string (not "/"). + """ + + path_info = '/' + """ + The 'relative path' portion of the Request-URI. This is relative + to the script_name ('mount point') of the application which is + handling this request.""" + + login = None + """ + When authentication is used during the request processing this is + set to 'False' if it failed and to the 'username' value if it succeeded. + The default 'None' implies that no authentication happened.""" + + # Note that cherrypy.url uses "if request.app:" to determine whether + # the call is during a real HTTP request or not. So leave this None. + app = None + """The cherrypy.Application object which is handling this request.""" + + handler = None + """ + The function, method, or other callable which CherryPy will call to + produce the response. The discovery of the handler and the arguments + it will receive are determined by the request.dispatch object. + By default, the handler is discovered by walking a tree of objects + starting at request.app.root, and is then passed all HTTP params + (from the query string and POST body) as keyword arguments.""" + + toolmaps = {} + """ + A nested dict of all Toolboxes and Tools in effect for this request, + of the form: {Toolbox.namespace: {Tool.name: config dict}}.""" + + config = None + """ + A flat dict of all configuration entries which apply to the + current request. 
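
Per the advice above, `methods_with_bodies` is best changed through configuration rather than by mutating the class. A sketch using the request config namespace (the values are illustrative)::

    config = {
        '/': {
            # request_namespace sets this on cherrypy.serving.request
            'request.methods_with_bodies': ('POST', 'PUT', 'PATCH', 'DELETE'),
            # 'request.body.*' keys are applied to request.body instead
            'request.body.attempt_charsets': ['utf-8', 'ISO-8859-1'],
        },
    }
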
These entries are collected from global config, + application config (based on request.path_info), and from handler + config (exactly how is governed by the request.dispatch object in + effect for this request; by default, handler config can be attached + anywhere in the tree between request.app.root and the final handler, + and inherits downward).""" + + is_index = None + """ + This will be True if the current request is mapped to an 'index' + resource handler (also, a 'default' handler if path_info ends with + a slash). The value may be used to automatically redirect the + user-agent to a 'more canonical' URL which either adds or removes + the trailing slash. See cherrypy.tools.trailing_slash.""" + + hooks = HookMap(hookpoints) + """ + A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}. + Each key is a str naming the hook point, and each value is a list + of hooks which will be called at that hook point during this request. + The list of hooks is generally populated as early as possible (mostly + from Tools specified in config), but may be extended at any time. + See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools.""" + + error_response = cherrypy.HTTPError(500).set_response + """ + The no-arg callable which will handle unexpected, untrapped errors + during request processing. This is not used for expected exceptions + (like NotFound, HTTPError, or HTTPRedirect) which are raised in + response to expected conditions (those should be customized either + via request.error_page or by overriding HTTPError.set_response). + By default, error_response uses HTTPError(500) to return a generic + error response to the user-agent.""" + + error_page = {} + """ + A dict of {error code: response filename or callable} pairs. + + The error code must be an int representing a given HTTP error code, + or the string 'default', which will be used if no matching entry + is found for a given numeric code. + + If a filename is provided, the file should contain a Python string- + formatting template, and can expect by default to receive format + values with the mapping keys %(status)s, %(message)s, %(traceback)s, + and %(version)s. The set of format mappings can be extended by + overriding HTTPError.set_response. + + If a callable is provided, it will be called by default with keyword + arguments 'status', 'message', 'traceback', and 'version', as for a + string-formatting template. The callable must return a string or + iterable of strings which will be set to response.body. It may also + override headers or perform any other processing. + + If no entry is given for an error code, and no 'default' entry exists, + a default template will be used. + """ + + show_tracebacks = True + """ + If True, unexpected errors encountered during request processing will + include a traceback in the response body.""" + + show_mismatched_params = True + """ + If True, mismatched parameters encountered during PageHandler invocation + processing will be included in the response body.""" + + throws = (KeyboardInterrupt, SystemExit, cherrypy.InternalRedirect) + """The sequence of exceptions which Request.run does not trap.""" + + throw_errors = False + """ + If True, Request.run will not trap any errors (except HTTPRedirect and + HTTPError, which are more properly called 'exceptions', not errors).""" + + closed = False + """True once the close method has been called, False otherwise.""" + + stage = None + """ + A string containing the stage reached in the request-handling process. 
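
A sketch of wiring `error_page` entries from config, matching the contract described above (the `friendly_error` handler is invented; it must return a string or iterable of strings)::

    import cherrypy

    def friendly_error(status, message, traceback, version):
        return 'Error %s: %s' % (status, message)

    config = {'/': {'error_page.404': friendly_error,
                    'error_page.default': friendly_error}}
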
+ This is useful when debugging a live server with hung requests.""" + + unique_id = None + """A lazy object generating and memorizing UUID4 on ``str()`` render.""" + + namespaces = reprconf.NamespaceSet( + **{'hooks': hooks_namespace, + 'request': request_namespace, + 'response': response_namespace, + 'error_page': error_page_namespace, + 'tools': cherrypy.tools, + }) + + def __init__(self, local_host, remote_host, scheme='http', + server_protocol='HTTP/1.1'): + """Populate a new Request object. + + local_host should be an httputil.Host object with the server info. + remote_host should be an httputil.Host object with the client info. + scheme should be a string, either "http" or "https". + """ + self.local = local_host + self.remote = remote_host + self.scheme = scheme + self.server_protocol = server_protocol + + self.closed = False + + # Put a *copy* of the class error_page into self. + self.error_page = self.error_page.copy() + + # Put a *copy* of the class namespaces into self. + self.namespaces = self.namespaces.copy() + + self.stage = None + + self.unique_id = LazyUUID4() + + def close(self): + """Run cleanup code. (Core)""" + if not self.closed: + self.closed = True + self.stage = 'on_end_request' + self.hooks.run('on_end_request') + self.stage = 'close' + + def run(self, method, path, query_string, req_protocol, headers, rfile): + r"""Process the Request. (Core) + + method, path, query_string, and req_protocol should be pulled directly + from the Request-Line (e.g. "GET /path?key=val HTTP/1.0"). + + path + This should be %XX-unquoted, but query_string should not be. + + When using Python 2, they both MUST be byte strings, + not unicode strings. + + When using Python 3, they both MUST be unicode strings, + not byte strings, and preferably not bytes \x00-\xFF + disguised as unicode. + + headers + A list of (name, value) tuples. + + rfile + A file-like object containing the HTTP request entity. + + When run() is done, the returned object should have 3 attributes: + + * status, e.g. "200 OK" + * header_list, a list of (name, value) tuples + * body, an iterable yielding strings + + Consumer code (HTTP servers) should then access these response + attributes to build the outbound stream. + + """ + response = cherrypy.serving.response + self.stage = 'run' + try: + self.error_response = cherrypy.HTTPError(500).set_response + + self.method = method + path = path or '/' + self.query_string = query_string or '' + self.params = {} + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + rp = int(req_protocol[5]), int(req_protocol[7]) + sp = int(self.server_protocol[5]), int(self.server_protocol[7]) + self.protocol = min(rp, sp) + response.headers.protocol = self.protocol + + # Rebuild first line of the request (e.g. "GET /path HTTP/1.0"). + url = path + if query_string: + url += '?' 
+ query_string + self.request_line = '%s %s %s' % (method, url, req_protocol) + + self.header_list = list(headers) + self.headers = httputil.HeaderMap() + + self.rfile = rfile + self.body = None + + self.cookie = SimpleCookie() + self.handler = None + + # path_info should be the path from the + # app root (script_name) to the handler. + self.script_name = self.app.script_name + self.path_info = pi = path[len(self.script_name):] + + self.stage = 'respond' + self.respond(pi) + + except self.throws: + raise + except Exception: + if self.throw_errors: + raise + else: + # Failure in setup, error handler or finalize. Bypass them. + # Can't use handle_error because we may not have hooks yet. + cherrypy.log(traceback=True, severity=40) + if self.show_tracebacks: + body = format_exc() + else: + body = '' + r = bare_error(body) + response.output_status, response.header_list, response.body = r + + if self.method == 'HEAD': + # HEAD requests MUST NOT return a message-body in the response. + response.body = [] + + try: + cherrypy.log.access() + except Exception: + cherrypy.log.error(traceback=True) + + return response + + def respond(self, path_info): + """Generate a response for the resource at self.path_info. (Core)""" + try: + try: + try: + self._do_respond(path_info) + except (cherrypy.HTTPRedirect, cherrypy.HTTPError): + inst = sys.exc_info()[1] + inst.set_response() + self.stage = 'before_finalize (HTTPError)' + self.hooks.run('before_finalize') + cherrypy.serving.response.finalize() + finally: + self.stage = 'on_end_resource' + self.hooks.run('on_end_resource') + except self.throws: + raise + except Exception: + if self.throw_errors: + raise + self.handle_error() + + def _do_respond(self, path_info): + response = cherrypy.serving.response + + if self.app is None: + raise cherrypy.NotFound() + + self.hooks = self.__class__.hooks.copy() + self.toolmaps = {} + + # Get the 'Host' header, so we can HTTPRedirect properly. + self.stage = 'process_headers' + self.process_headers() + + self.stage = 'get_resource' + self.get_resource(path_info) + + self.body = _cpreqbody.RequestBody( + self.rfile, self.headers, request_params=self.params) + + self.namespaces(self.config) + + self.stage = 'on_start_resource' + self.hooks.run('on_start_resource') + + # Parse the querystring + self.stage = 'process_query_string' + self.process_query_string() + + # Process the body + if self.process_request_body: + if self.method not in self.methods_with_bodies: + self.process_request_body = False + self.stage = 'before_request_body' + self.hooks.run('before_request_body') + if self.process_request_body: + self.body.process() + + # Run the handler + self.stage = 'before_handler' + self.hooks.run('before_handler') + if self.handler: + self.stage = 'handler' + response.body = self.handler() + + # Finalize + self.stage = 'before_finalize' + self.hooks.run('before_finalize') + response.finalize() + + def process_query_string(self): + """Parse the query string into Python structures. (Core)""" + try: + p = httputil.parse_query_string( + self.query_string, encoding=self.query_string_encoding) + except UnicodeDecodeError: + raise cherrypy.HTTPError( + 404, 'The given query string could not be processed. Query ' + 'strings for this resource must be encoded with %r.' % + self.query_string_encoding) + + self.params.update(p) + + def process_headers(self): + """Parse HTTP header data into Python structures. 
(Core)"""
+        # Process the headers into self.headers
+        headers = self.headers
+        for name, value in self.header_list:
+            # Call title() now (and use dict.__method__(headers))
+            # so title doesn't have to be called twice.
+            name = name.title()
+            value = value.strip()
+
+            headers[name] = httputil.decode_TEXT_maybe(value)
+
+            # Some clients, notably Konqueror, supply multiple
+            # cookies on different lines with the same key. To
+            # handle this case, store all cookies in self.cookie.
+            if name == 'Cookie':
+                try:
+                    self.cookie.load(value)
+                except CookieError as exc:
+                    raise cherrypy.HTTPError(400, str(exc))
+
+        if not dict.__contains__(headers, 'Host'):
+            # All Internet-based HTTP/1.1 servers MUST respond with a 400
+            # (Bad Request) status code to any HTTP/1.1 request message
+            # which lacks a Host header field.
+            if self.protocol >= (1, 1):
+                msg = "HTTP/1.1 requires a 'Host' request header."
+                raise cherrypy.HTTPError(400, msg)
+        host = dict.get(headers, 'Host')
+        if not host:
+            host = self.local.name or self.local.ip
+        self.base = '%s://%s' % (self.scheme, host)
+
+    def get_resource(self, path):
+        """Call a dispatcher (which sets self.handler and .config). (Core)"""
+        # First, see if there is a custom dispatch at this URI. Custom
+        # dispatchers can only be specified in app.config, not in _cp_config
+        # (since custom dispatchers may not even have an app.root).
+        dispatch = self.app.find_config(
+            path, 'request.dispatch', self.dispatch)
+
+        # dispatch() should set self.handler and self.config
+        dispatch(path)
+
+    def handle_error(self):
+        """Handle the last unanticipated exception. (Core)"""
+        try:
+            self.hooks.run('before_error_response')
+            if self.error_response:
+                self.error_response()
+            self.hooks.run('after_error_response')
+            cherrypy.serving.response.finalize()
+        except cherrypy.HTTPRedirect:
+            inst = sys.exc_info()[1]
+            inst.set_response()
+            cherrypy.serving.response.finalize()
+
+
+class ResponseBody(object):
+
+    """The body of the HTTP response (the response entity)."""
+
+    unicode_err = ('Page handlers MUST return bytes. Use tools.encode '
+                   'if you wish to return unicode.')
+
+    def __get__(self, obj, objclass=None):
+        if obj is None:
+            # When calling on the class instead of an instance...
+            return self
+        else:
+            return obj._body
+
+    def __set__(self, obj, value):
+        # Convert the given value to an iterable object.
+        if isinstance(value, str):
+            raise ValueError(self.unicode_err)
+        elif isinstance(value, list):
+            # every item in a list must be bytes...
+            if any(isinstance(item, str) for item in value):
+                raise ValueError(self.unicode_err)
+
+        obj._body = encoding.prepare_iter(value)
+
+
+class Response(object):
+
+    """An HTTP Response, including status, headers, and body."""
+
+    status = ''
+    """The HTTP Status-Code and Reason-Phrase."""
+
+    header_list = []
+    """
+    A list of the HTTP response headers as (name, value) tuples.
+    In general, you should use response.headers (a dict) instead. This
+    attribute is generated from response.headers and is not valid until
+    after the finalize phase."""
+
+    headers = httputil.HeaderMap()
+    """
+    A dict-like object containing the response headers. Keys are header
+    names (in Title-Case format); however, you may get and set them in
+    a case-insensitive manner. That is, headers['Content-Type'] and
+    headers['content-type'] refer to the same value. Values are header
+    values (decoded according to :rfc:`2047` if necessary).
+
+    ..
seealso:: classes :class:`HeaderMap`, :class:`HeaderElement` + """ + + cookie = SimpleCookie() + """See help(Cookie).""" + + body = ResponseBody() + """The body (entity) of the HTTP response.""" + + time = None + """The value of time.time() when created. Use in HTTP dates.""" + + stream = False + """If False, buffer the response body.""" + + def __init__(self): + self.status = None + self.header_list = None + self._body = [] + self.time = time.time() + + self.headers = httputil.HeaderMap() + # Since we know all our keys are titled strings, we can + # bypass HeaderMap.update and get a big speed boost. + dict.update(self.headers, { + 'Content-Type': 'text/html', + 'Server': 'CherryPy/' + cherrypy.__version__, + 'Date': httputil.HTTPDate(self.time), + }) + self.cookie = SimpleCookie() + + def collapse_body(self): + """Collapse self.body to a single string; replace it and return it.""" + new_body = b''.join(self.body) + self.body = new_body + return new_body + + def _flush_body(self): + """ + Discard self.body but consume any generator such that + any finalization can occur, such as is required by + caching.tee_output(). + """ + consume(iter(self.body)) + + def finalize(self): + """Transform headers (and cookies) into self.header_list. (Core)""" + try: + code, reason, _ = httputil.valid_status(self.status) + except ValueError: + raise cherrypy.HTTPError(500, sys.exc_info()[1].args[0]) + + headers = self.headers + + self.status = '%s %s' % (code, reason) + self.output_status = ntob(str(code), 'ascii') + \ + b' ' + headers.encode(reason) + + if self.stream: + # The upshot: wsgiserver will chunk the response if + # you pop Content-Length (or set it explicitly to None). + # Note that lib.static sets C-L to the file's st_size. + if dict.get(headers, 'Content-Length') is None: + dict.pop(headers, 'Content-Length', None) + elif code < 200 or code in (204, 205, 304): + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." + dict.pop(headers, 'Content-Length', None) + self._flush_body() + self.body = b'' + else: + # Responses which are not streamed should have a Content-Length, + # but allow user code to set Content-Length if desired. + if dict.get(headers, 'Content-Length') is None: + content = self.collapse_body() + dict.__setitem__(headers, 'Content-Length', len(content)) + + # Transform our header dict into a list of tuples. + self.header_list = h = headers.output() + + cookie = self.cookie.output() + if cookie: + for line in cookie.split('\r\n'): + name, value = line.split(': ', 1) + if isinstance(name, str): + name = name.encode('ISO-8859-1') + if isinstance(value, str): + value = headers.encode(value) + h.append((name, value)) + + +class LazyUUID4(object): + def __str__(self): + """Return UUID4 and keep it for future calls.""" + return str(self.uuid4) + + @property + def uuid4(self): + """Provide unique id on per-request basis using UUID4. + + It's evaluated lazily on render. 
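
A minimal handler-side sketch of the Response contract above: status and headers may be set freely before finalize, and the body must be bytes (the `Status` class is invented for the example)::

    import cherrypy

    class Status:
        @cherrypy.expose
        def ping(self):
            resp = cherrypy.response
            resp.status = '202 Accepted'
            resp.headers['Content-Type'] = 'text/plain'
            # ResponseBody rejects str; only bytes are allowed here.
            return b'pong'
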
+ """ + try: + self._uuid4 + except AttributeError: + # evaluate on first access + self._uuid4 = uuid.uuid4() + + return self._uuid4 diff --git a/resources/lib/cherrypy/_cpserver.py b/resources/lib/cherrypy/_cpserver.py new file mode 100644 index 0000000..5f8d98f --- /dev/null +++ b/resources/lib/cherrypy/_cpserver.py @@ -0,0 +1,241 @@ +"""Manage HTTP servers with CherryPy.""" + +import cherrypy +from cherrypy.lib.reprconf import attributes +from cherrypy._cpcompat import text_or_bytes +from cherrypy.process.servers import ServerAdapter + + +__all__ = ('Server', ) + + +class Server(ServerAdapter): + """An adapter for an HTTP server. + + You can set attributes (like socket_host and socket_port) + on *this* object (which is probably cherrypy.server), and call + quickstart. For example:: + + cherrypy.server.socket_port = 80 + cherrypy.quickstart() + """ + + socket_port = 8080 + """The TCP port on which to listen for connections.""" + + _socket_host = '127.0.0.1' + + @property + def socket_host(self): # noqa: D401; irrelevant for properties + """The hostname or IP address on which to listen for connections. + + Host values may be any IPv4 or IPv6 address, or any valid hostname. + The string 'localhost' is a synonym for '127.0.0.1' (or '::1', if + your hosts file prefers IPv6). The string '0.0.0.0' is a special + IPv4 entry meaning "any active interface" (INADDR_ANY), and '::' + is the similar IN6ADDR_ANY for IPv6. The empty string or None are + not allowed. + """ + return self._socket_host + + @socket_host.setter + def socket_host(self, value): + if value == '': + raise ValueError("The empty string ('') is not an allowed value. " + "Use '0.0.0.0' instead to listen on all active " + 'interfaces (INADDR_ANY).') + self._socket_host = value + + socket_file = None + """If given, the name of the UNIX socket to use instead of TCP/IP. + + When this option is not None, the `socket_host` and `socket_port` options + are ignored.""" + + socket_queue_size = 5 + """The 'backlog' argument to socket.listen(); specifies the maximum number + of queued connections (default 5).""" + + socket_timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + accepted_queue_size = -1 + """The maximum number of requests which will be queued up before + the server refuses to accept it (default -1, meaning no limit).""" + + accepted_queue_timeout = 10 + """The timeout in seconds for attempting to add a request to the + queue when the queue is full (default 10).""" + + shutdown_timeout = 5 + """The time to wait for HTTP worker threads to clean up.""" + + protocol_version = 'HTTP/1.1' + """The version string to write in the Status-Line of all HTTP responses, + for example, "HTTP/1.1" (the default). Depending on the HTTP server used, + this should also limit the supported features used in the response.""" + + thread_pool = 10 + """The number of worker threads to start up in the pool.""" + + thread_pool_max = -1 + """The maximum size of the worker-thread pool. Use -1 to indicate no limit. + """ + + max_request_header_size = 500 * 1024 + """The maximum number of bytes allowable in the request headers. + If exceeded, the HTTP server should return "413 Request Entity Too Large". + """ + + max_request_body_size = 100 * 1024 * 1024 + """The maximum number of bytes allowable in the request body. 
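
Typical ways to point the Server adapter above at a socket, either via attributes or the equivalent `server.*` config keys (the port and pool size are illustrative)::

    import cherrypy

    cherrypy.server.socket_host = '127.0.0.1'
    cherrypy.server.socket_port = 8080
    # Equivalent spelling through global config:
    cherrypy.config.update({'server.thread_pool': 10})
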
If exceeded,
+    the HTTP server should return "413 Request Entity Too Large"."""
+
+    instance = None
+    """If not None, this should be an HTTP server instance (such as
+    cheroot.wsgi.Server) which cherrypy.server will control.
+    Use this when you need
+    more control over object instantiation than is available in the various
+    configuration options."""
+
+    ssl_context = None
+    """When using PyOpenSSL, an instance of SSL.Context."""
+
+    ssl_certificate = None
+    """The filename of the SSL certificate to use."""
+
+    ssl_certificate_chain = None
+    """When using PyOpenSSL, the certificate chain to pass to
+    Context.load_verify_locations."""
+
+    ssl_private_key = None
+    """The filename of the private key to use with SSL."""
+
+    ssl_ciphers = None
+    """The list of SSL ciphers to use."""
+
+    ssl_module = 'builtin'
+    """The name of a registered SSL adaptation module to use with
+    the builtin WSGI server. Builtin options are: 'builtin' (to
+    use the SSL library built into recent versions of Python).
+    You may also register your own classes in the
+    cheroot.server.ssl_adapters dict."""
+
+    statistics = False
+    """Turns statistics-gathering on or off for aware HTTP servers."""
+
+    nodelay = True
+    """If True (the default since 3.1), sets the TCP_NODELAY socket option."""
+
+    wsgi_version = (1, 0)
+    """The WSGI version tuple to use with the builtin WSGI server.
+    The provided options are (1, 0) [which includes support for PEP 3333,
+    which declares it covers WSGI version 1.0.1 but still mandates the
+    wsgi.version (1, 0)] and ('u', 0), an experimental unicode version.
+    You may create and register your own experimental versions of the WSGI
+    protocol by adding custom classes to the cheroot.server.wsgi_gateways dict.
+    """
+
+    peercreds = False
+    """If True, peer cred lookup for a UNIX domain socket will be put into
+    the WSGI env.
+
+    This information will then be available through WSGI env vars:
+    * X_REMOTE_PID
+    * X_REMOTE_UID
+    * X_REMOTE_GID
+    """
+
+    peercreds_resolve = False
+    """If True, username/group will be looked up in the OS from peercreds.
+
+    This information will then be available through WSGI env vars:
+    * REMOTE_USER
+    * X_REMOTE_USER
+    * X_REMOTE_GROUP
+    """
+
+    def __init__(self):
+        """Initialize Server instance."""
+        self.bus = cherrypy.engine
+        self.httpserver = None
+        self.interrupt = None
+        self.running = False
+
+    def httpserver_from_self(self, httpserver=None):
+        """Return a (httpserver, bind_addr) pair based on self attributes."""
+        if httpserver is None:
+            httpserver = self.instance
+        if httpserver is None:
+            from cherrypy import _cpwsgi_server
+            httpserver = _cpwsgi_server.CPWSGIServer(self)
+        if isinstance(httpserver, text_or_bytes):
+            # Is anyone using this? Can I add an arg?
+            httpserver = attributes(httpserver)(self)
+        return httpserver, self.bind_addr
+
+    def start(self):
+        """Start the HTTP server."""
+        if not self.httpserver:
+            self.httpserver, self.bind_addr = self.httpserver_from_self()
+        super(Server, self).start()
+    start.priority = 75
+
+    @property
+    def bind_addr(self):
+        """Return bind address.
+
+        A (host, port) tuple for TCP sockets or a str for Unix domain sockets.
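
A sketch of enabling TLS with the builtin adapter described above (the certificate paths are placeholders)::

    import cherrypy

    cherrypy.server.ssl_module = 'builtin'
    cherrypy.server.ssl_certificate = 'cert.pem'     # placeholder path
    cherrypy.server.ssl_private_key = 'privkey.pem'  # placeholder path
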
+        """
+        if self.socket_file:
+            return self.socket_file
+        if self.socket_host is None and self.socket_port is None:
+            return None
+        return (self.socket_host, self.socket_port)
+
+    @bind_addr.setter
+    def bind_addr(self, value):
+        if value is None:
+            self.socket_file = None
+            self.socket_host = None
+            self.socket_port = None
+        elif isinstance(value, text_or_bytes):
+            self.socket_file = value
+            self.socket_host = None
+            self.socket_port = None
+        else:
+            try:
+                self.socket_host, self.socket_port = value
+                self.socket_file = None
+            except ValueError:
+                raise ValueError('bind_addr must be a (host, port) tuple '
+                                 '(for TCP sockets) or a string (for Unix '
+                                 'domain sockets), not %r' % value)
+
+    def base(self):
+        """Return the base for this server.
+
+        i.e. scheme://host[:port] or sock file
+        """
+        if self.socket_file:
+            return self.socket_file
+
+        host = self.socket_host
+        if host in ('0.0.0.0', '::'):
+            # 0.0.0.0 is INADDR_ANY and :: is IN6ADDR_ANY.
+            # Look up the host name, which should be the
+            # safest thing to spit out in a URL.
+            import socket
+            host = socket.gethostname()
+
+        port = self.socket_port
+
+        if self.ssl_certificate:
+            scheme = 'https'
+            if port != 443:
+                host += ':%s' % port
+        else:
+            scheme = 'http'
+            if port != 80:
+                host += ':%s' % port
+
+        return '%s://%s' % (scheme, host)
diff --git a/resources/lib/cherrypy/_cptools.py b/resources/lib/cherrypy/_cptools.py
new file mode 100644
index 0000000..716f99a
--- /dev/null
+++ b/resources/lib/cherrypy/_cptools.py
@@ -0,0 +1,502 @@
+"""CherryPy tools. A "tool" is any helper, adapted to CP.
+
+Tools are usually designed to be used in a variety of ways (although some
+may only offer one if they choose):
+
+    Library calls
+        All tools are callables that can be used wherever needed.
+        The arguments are straightforward and should be detailed within the
+        docstring.
+
+    Function decorators
+        All tools, when called, may be used as decorators which configure
+        individual CherryPy page handlers (methods on the CherryPy tree).
+        That is, "@tools.anytool()" should "turn on" the tool via the
+        decorated function's _cp_config attribute.
+
+    CherryPy config
+        If a tool exposes a "_setup" callable, it will be called
+        once per Request (if the feature is "turned on" via config).
+
+Tools may be implemented as any object with a namespace. The builtins
+are generally either modules or instances of the tools.Tool class.
+"""
+
+import cherrypy
+from cherrypy._helper import expose
+
+from cherrypy.lib import cptools, encoding, static, jsontools
+from cherrypy.lib import sessions as _sessions, xmlrpcutil as _xmlrpc
+from cherrypy.lib import caching as _caching
+from cherrypy.lib import auth_basic, auth_digest
+
+
+def _getargs(func):
+    """Return the names of all static arguments to the given function."""
+    # Use this instead of importing inspect for less mem overhead.
+    import types
+    if isinstance(func, types.MethodType):
+        func = func.__func__
+    co = func.__code__
+    return co.co_varnames[:co.co_argcount]
+
+
+_attr_error = (
+    'CherryPy Tools cannot be turned on directly. Instead, turn them '
+    'on via config, or use them as decorators on your page handlers.'
+)
+
+
+class Tool(object):
+
+    """A registered function for use with CherryPy request-processing hooks.
+
+    help(tool.callable) should give you more information about this Tool.
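
The three usage styles from the module docstring above, shown with the builtin gzip tool (the handler is illustrative)::

    import cherrypy

    class Root:
        # Decorator style: turns the tool on for this handler only.
        @cherrypy.expose
        @cherrypy.tools.gzip()
        def index(self):
            return b'compressed when the client accepts gzip'

    # Config style: same effect, selected by path.
    config = {'/': {'tools.gzip.on': True}}
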
+    """
+
+    namespace = 'tools'
+
+    def __init__(self, point, callable, name=None, priority=50):
+        self._point = point
+        self.callable = callable
+        self._name = name
+        self._priority = priority
+        self.__doc__ = self.callable.__doc__
+        self._setargs()
+
+    @property
+    def on(self):
+        raise AttributeError(_attr_error)
+
+    @on.setter
+    def on(self, value):
+        raise AttributeError(_attr_error)
+
+    def _setargs(self):
+        """Copy func parameter names to obj attributes."""
+        try:
+            for arg in _getargs(self.callable):
+                setattr(self, arg, None)
+        except (TypeError, AttributeError):
+            if hasattr(self.callable, '__call__'):
+                for arg in _getargs(self.callable.__call__):
+                    setattr(self, arg, None)
+        # IronPython 1.0 raises NotImplementedError because
+        # inspect.getargspec tries to access Python bytecode
+        # in co_code attribute.
+        except NotImplementedError:
+            pass
+        # IronPython 1B1 may raise IndexError in some cases,
+        # but if we trap it here it doesn't prevent CP from working.
+        except IndexError:
+            pass
+
+    def _merged_args(self, d=None):
+        """Return a dict of configuration entries for this Tool."""
+        if d:
+            conf = d.copy()
+        else:
+            conf = {}
+
+        tm = cherrypy.serving.request.toolmaps[self.namespace]
+        if self._name in tm:
+            conf.update(tm[self._name])
+
+        if 'on' in conf:
+            del conf['on']
+
+        return conf
+
+    def __call__(self, *args, **kwargs):
+        """Compile-time decorator (turn on the tool in config).
+
+        For example::
+
+            @expose
+            @tools.proxy()
+            def whats_my_base(self):
+                return cherrypy.request.base
+        """
+        if args:
+            raise TypeError('The %r Tool does not accept positional '
+                            'arguments; you must use keyword arguments.'
+                            % self._name)
+
+        def tool_decorator(f):
+            if not hasattr(f, '_cp_config'):
+                f._cp_config = {}
+            subspace = self.namespace + '.' + self._name + '.'
+            f._cp_config[subspace + 'on'] = True
+            for k, v in kwargs.items():
+                f._cp_config[subspace + k] = v
+            return f
+        return tool_decorator
+
+    def _setup(self):
+        """Hook this tool into cherrypy.request.
+
+        The standard CherryPy request object will automatically call this
+        method when the tool is "turned on" in config.
+        """
+        conf = self._merged_args()
+        p = conf.pop('priority', None)
+        if p is None:
+            p = getattr(self.callable, 'priority', self._priority)
+        cherrypy.serving.request.hooks.attach(self._point, self.callable,
+                                              priority=p, **conf)
+
+
+class HandlerTool(Tool):
+
+    """Tool which is called 'before main', that may skip normal handlers.
+
+    If the tool successfully handles the request (by setting response.body),
+    it should return True. This will cause CherryPy to skip any 'normal' page
+    handler. If the tool did not handle the request, it should return False
+    to tell CherryPy to continue on and call the normal page handler. If the
+    tool is declared AS a page handler (see the 'handler' method), returning
+    False will raise NotFound.
+    """
+
+    def __init__(self, callable, name=None):
+        Tool.__init__(self, 'before_handler', callable, name)
+
+    def handler(self, *args, **kwargs):
+        """Use this tool as a CherryPy page handler.
+
+        For example::
+
+            class Root:
+                nav = tools.staticdir.handler(section="/nav", dir="nav",
+                                              root=absDir)
+        """
+        @expose
+        def handle_func(*a, **kw):
+            handled = self.callable(*args, **self._merged_args(kwargs))
+            if not handled:
+                raise cherrypy.NotFound()
+            return cherrypy.serving.response.body
+        return handle_func
+
+    def _wrapper(self, **kwargs):
+        if self.callable(**kwargs):
+            cherrypy.serving.request.handler = None
+
+    def _setup(self):
+        """Hook this tool into cherrypy.request.
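
A sketch of registering a custom Tool instance on the default toolbox, using the hook machinery above (the `log_referer` callback is invented)::

    import cherrypy

    def log_referer():
        cherrypy.log(cherrypy.request.headers.get('Referer', '-'))

    cherrypy.tools.log_referer = cherrypy.Tool('before_handler', log_referer,
                                               priority=60)
    # Enable per path: {'/': {'tools.log_referer.on': True}}
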
+ + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + conf = self._merged_args() + p = conf.pop('priority', None) + if p is None: + p = getattr(self.callable, 'priority', self._priority) + cherrypy.serving.request.hooks.attach(self._point, self._wrapper, + priority=p, **conf) + + +class HandlerWrapperTool(Tool): + + """Tool which wraps request.handler in a provided wrapper function. + + The 'newhandler' arg must be a handler wrapper function that takes a + 'next_handler' argument, plus ``*args`` and ``**kwargs``. Like all + page handler + functions, it must return an iterable for use as cherrypy.response.body. + + For example, to allow your 'inner' page handlers to return dicts + which then get interpolated into a template:: + + def interpolator(next_handler, *args, **kwargs): + filename = cherrypy.request.config.get('template') + cherrypy.response.template = env.get_template(filename) + response_dict = next_handler(*args, **kwargs) + return cherrypy.response.template.render(**response_dict) + cherrypy.tools.jinja = HandlerWrapperTool(interpolator) + """ + + def __init__(self, newhandler, point='before_handler', name=None, + priority=50): + self.newhandler = newhandler + self._point = point + self._name = name + self._priority = priority + + def callable(self, *args, **kwargs): + innerfunc = cherrypy.serving.request.handler + + def wrap(*args, **kwargs): + return self.newhandler(innerfunc, *args, **kwargs) + cherrypy.serving.request.handler = wrap + + +class ErrorTool(Tool): + + """Tool which is used to replace the default request.error_response.""" + + def __init__(self, callable, name=None): + Tool.__init__(self, None, callable, name) + + def _wrapper(self): + self.callable(**self._merged_args()) + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + cherrypy.serving.request.error_response = self._wrapper + + +# Builtin tools # + + +class SessionTool(Tool): + + """Session Tool for CherryPy. + + sessions.locking + When 'implicit' (the default), the session will be locked for you, + just before running the page handler. + + When 'early', the session will be locked before reading the request + body. This is off by default for safety reasons; for example, + a large upload would block the session, denying an AJAX + progress meter + (`issue `_). + + When 'explicit' (or any other value), you need to call + cherrypy.session.acquire_lock() yourself before using + session data. + """ + + def __init__(self): + # _sessions.init must be bound after headers are read + Tool.__init__(self, 'before_request_body', _sessions.init) + + def _lock_session(self): + cherrypy.serving.session.acquire_lock() + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + hooks = cherrypy.serving.request.hooks + + conf = self._merged_args() + + p = conf.pop('priority', None) + if p is None: + p = getattr(self.callable, 'priority', self._priority) + + hooks.attach(self._point, self.callable, priority=p, **conf) + + locking = conf.pop('locking', 'implicit') + if locking == 'implicit': + hooks.attach('before_handler', self._lock_session) + elif locking == 'early': + # Lock before the request body (but after _sessions.init runs!) 
+            hooks.attach('before_request_body', self._lock_session,
+                         priority=60)
+        else:
+            # Don't lock
+            pass
+
+        hooks.attach('before_finalize', _sessions.save)
+        hooks.attach('on_end_request', _sessions.close)
+
+    def regenerate(self):
+        """Drop the current session and make a new one (with a new id)."""
+        sess = cherrypy.serving.session
+        sess.regenerate()
+
+        # Grab cookie-relevant tool args
+        relevant = 'path', 'path_header', 'name', 'timeout', 'domain', 'secure'
+        conf = dict(
+            (k, v)
+            for k, v in self._merged_args().items()
+            if k in relevant
+        )
+        _sessions.set_response_cookie(**conf)
+
+
+class XMLRPCController(object):
+
+    """A Controller (page handler collection) for XML-RPC.
+
+    To use it, have your controllers subclass this base class (it will
+    turn on the tool for you).
+
+    You can also supply the following optional config entries::
+
+        tools.xmlrpc.encoding: 'utf-8'
+        tools.xmlrpc.allow_none: 0
+
+    XML-RPC is a rather discontinuous layer over HTTP; dispatching to the
+    appropriate handler must first be performed according to the URL, and
+    then a second dispatch step must take place according to the RPC method
+    specified in the request body. It also allows a superfluous "/RPC2"
+    prefix in the URL, supplies its own handler args in the body, and
+    requires a 200 OK "Fault" response instead of 404 when the desired
+    method is not found.
+
+    Therefore, XML-RPC cannot be implemented for CherryPy via a Tool alone.
+    This Controller acts as the dispatch target for the first half (based
+    on the URL); it then reads the RPC method from the request body and
+    does its own second dispatch step based on that method. It also reads
+    body params, and returns a Fault on error.
+
+    The XMLRPCDispatcher strips any /RPC2 prefix; if you aren't using /RPC2
+    in your URLs, you can safely skip turning on the XMLRPCDispatcher.
+    Otherwise, you need to declare it in config::
+
+        request.dispatch: cherrypy.dispatch.XMLRPCDispatcher()
+    """
+
+    # Note we're hard-coding this into the 'tools' namespace. We could do
+    # a huge amount of work to make it relocatable, but the only reason why
+    # would be if someone actually disabled the default_toolbox. Meh.
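
A usage sketch for the SessionTool above; once switched on in config, the tool locks, saves, and closes the session around the handler (the handler class and session key are invented)::

    import cherrypy

    class App:
        @cherrypy.expose
        def track(self):
            count = cherrypy.session.get('count', 0) + 1
            cherrypy.session['count'] = count
            return ('visit %d' % count).encode('ascii')

    config = {'/': {'tools.sessions.on': True,
                    'tools.sessions.timeout': 60}}  # minutes
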
+ _cp_config = {'tools.xmlrpc.on': True} + + @expose + def default(self, *vpath, **params): + rpcparams, rpcmethod = _xmlrpc.process_body() + + subhandler = self + for attr in str(rpcmethod).split('.'): + subhandler = getattr(subhandler, attr, None) + + if subhandler and getattr(subhandler, 'exposed', False): + body = subhandler(*(vpath + rpcparams), **params) + + else: + # https://github.com/cherrypy/cherrypy/issues/533 + # if a method is not found, an xmlrpclib.Fault should be returned + # raising an exception here will do that; see + # cherrypy.lib.xmlrpcutil.on_error + raise Exception('method "%s" is not supported' % attr) + + conf = cherrypy.serving.request.toolmaps['tools'].get('xmlrpc', {}) + _xmlrpc.respond(body, + conf.get('encoding', 'utf-8'), + conf.get('allow_none', 0)) + return cherrypy.serving.response.body + + +class SessionAuthTool(HandlerTool): + pass + + +class CachingTool(Tool): + + """Caching Tool for CherryPy.""" + + def _wrapper(self, **kwargs): + request = cherrypy.serving.request + if _caching.get(**kwargs): + request.handler = None + else: + if request.cacheable: + # Note the devious technique here of adding hooks on the fly + request.hooks.attach('before_finalize', _caching.tee_output, + priority=100) + _wrapper.priority = 90 + + def _setup(self): + """Hook caching into cherrypy.request.""" + conf = self._merged_args() + + p = conf.pop('priority', None) + cherrypy.serving.request.hooks.attach('before_handler', self._wrapper, + priority=p, **conf) + + +class Toolbox(object): + + """A collection of Tools. + + This object also functions as a config namespace handler for itself. + Custom toolboxes should be added to each Application's toolboxes dict. + """ + + def __init__(self, namespace): + self.namespace = namespace + + def __setattr__(self, name, value): + # If the Tool._name is None, supply it from the attribute name. + if isinstance(value, Tool): + if value._name is None: + value._name = name + value.namespace = self.namespace + object.__setattr__(self, name, value) + + def __enter__(self): + """Populate request.toolmaps from tools specified in config.""" + cherrypy.serving.request.toolmaps[self.namespace] = map = {} + + def populate(k, v): + toolname, arg = k.split('.', 1) + bucket = map.setdefault(toolname, {}) + bucket[arg] = v + return populate + + def __exit__(self, exc_type, exc_val, exc_tb): + """Run tool._setup() for each tool in our toolmap.""" + map = cherrypy.serving.request.toolmaps.get(self.namespace) + if map: + for name, settings in map.items(): + if settings.get('on', False): + tool = getattr(self, name) + tool._setup() + + def register(self, point, **kwargs): + """ + Return a decorator which registers the function + at the given hook point. 
+ """ + def decorator(func): + attr_name = kwargs.get('name', func.__name__) + tool = Tool(point, func, **kwargs) + setattr(self, attr_name, tool) + return func + return decorator + + +default_toolbox = _d = Toolbox('tools') +_d.session_auth = SessionAuthTool(cptools.session_auth) +_d.allow = Tool('on_start_resource', cptools.allow) +_d.proxy = Tool('before_request_body', cptools.proxy, priority=30) +_d.response_headers = Tool('on_start_resource', cptools.response_headers) +_d.log_tracebacks = Tool('before_error_response', cptools.log_traceback) +_d.log_headers = Tool('before_error_response', cptools.log_request_headers) +_d.log_hooks = Tool('on_end_request', cptools.log_hooks, priority=100) +_d.err_redirect = ErrorTool(cptools.redirect) +_d.etags = Tool('before_finalize', cptools.validate_etags, priority=75) +_d.decode = Tool('before_request_body', encoding.decode) +# the order of encoding, gzip, caching is important +_d.encode = Tool('before_handler', encoding.ResponseEncoder, priority=70) +_d.gzip = Tool('before_finalize', encoding.gzip, priority=80) +_d.staticdir = HandlerTool(static.staticdir) +_d.staticfile = HandlerTool(static.staticfile) +_d.sessions = SessionTool() +_d.xmlrpc = ErrorTool(_xmlrpc.on_error) +_d.caching = CachingTool('before_handler', _caching.get, 'caching') +_d.expires = Tool('before_finalize', _caching.expires) +_d.ignore_headers = Tool('before_request_body', cptools.ignore_headers) +_d.referer = Tool('before_request_body', cptools.referer) +_d.trailing_slash = Tool('before_handler', cptools.trailing_slash, priority=60) +_d.flatten = Tool('before_finalize', cptools.flatten) +_d.accept = Tool('on_start_resource', cptools.accept) +_d.redirect = Tool('on_start_resource', cptools.redirect) +_d.autovary = Tool('on_start_resource', cptools.autovary, priority=0) +_d.json_in = Tool('before_request_body', jsontools.json_in, priority=30) +_d.json_out = Tool('before_handler', jsontools.json_out, priority=30) +_d.auth_basic = Tool('before_handler', auth_basic.basic_auth, priority=1) +_d.auth_digest = Tool('before_handler', auth_digest.digest_auth, priority=1) +_d.params = Tool('before_handler', cptools.convert_params, priority=15) + +del _d, cptools, encoding, static diff --git a/resources/lib/cherrypy/_cptree.py b/resources/lib/cherrypy/_cptree.py new file mode 100644 index 0000000..917c5b1 --- /dev/null +++ b/resources/lib/cherrypy/_cptree.py @@ -0,0 +1,302 @@ +"""CherryPy Application and Tree objects.""" + +import os + +import cherrypy +from cherrypy import _cpconfig, _cplogging, _cprequest, _cpwsgi, tools +from cherrypy.lib import httputil, reprconf + + +class Application(object): + """A CherryPy Application. + + Servers and gateways should not instantiate Request objects directly. + Instead, they should ask an Application object for a request object. + + An instance of this class may also be used as a WSGI callable + (WSGI application object) for itself. + """ + + root = None + """The top-most container of page handlers for this app. Handlers should + be arranged in a hierarchy of attributes, matching the expected URI + hierarchy; the default dispatcher then searches this hierarchy for a + matching handler. When using a dispatcher other than the default, + this value may be None.""" + + config = {} + """A dict of {path: pathconf} pairs, where 'pathconf' is itself a dict + of {key: value} pairs.""" + + namespaces = reprconf.NamespaceSet() + toolboxes = {'tools': cherrypy.tools} + + log = None + """A LogManager instance. 
See _cplogging.""" + + wsgiapp = None + """A CPWSGIApp instance. See _cpwsgi.""" + + request_class = _cprequest.Request + response_class = _cprequest.Response + + relative_urls = False + + def __init__(self, root, script_name='', config=None): + """Initialize Application with given root.""" + self.log = _cplogging.LogManager(id(self), cherrypy.log.logger_root) + self.root = root + self.script_name = script_name + self.wsgiapp = _cpwsgi.CPWSGIApp(self) + + self.namespaces = self.namespaces.copy() + self.namespaces['log'] = lambda k, v: setattr(self.log, k, v) + self.namespaces['wsgi'] = self.wsgiapp.namespace_handler + + self.config = self.__class__.config.copy() + if config: + self.merge(config) + + def __repr__(self): + """Generate a representation of the Application instance.""" + return '%s.%s(%r, %r)' % (self.__module__, self.__class__.__name__, + self.root, self.script_name) + + script_name_doc = """The URI "mount point" for this app. A mount point + is that portion of the URI which is constant for all URIs that are + serviced by this application; it does not include scheme, host, or proxy + ("virtual host") portions of the URI. + + For example, if script_name is "/my/cool/app", then the URL + "http://www.example.com/my/cool/app/page1" might be handled by a + "page1" method on the root object. + + The value of script_name MUST NOT end in a slash. If the script_name + refers to the root of the URI, it MUST be an empty string (not "/"). + + If script_name is explicitly set to None, then the script_name will be + provided for each call from request.wsgi_environ['SCRIPT_NAME']. + """ + + @property + def script_name(self): # noqa: D401; irrelevant for properties + """The URI "mount point" for this app. + + A mount point is that portion of the URI which is constant for all URIs + that are serviced by this application; it does not include scheme, + host, or proxy ("virtual host") portions of the URI. + + For example, if script_name is "/my/cool/app", then the URL + "http://www.example.com/my/cool/app/page1" might be handled by a + "page1" method on the root object. + + The value of script_name MUST NOT end in a slash. If the script_name + refers to the root of the URI, it MUST be an empty string (not "/"). + + If script_name is explicitly set to None, then the script_name will be + provided for each call from request.wsgi_environ['SCRIPT_NAME']. + """ + if self._script_name is not None: + return self._script_name + + # A `_script_name` with a value of None signals that the script name + # should be pulled from WSGI environ. + return cherrypy.serving.request.wsgi_environ['SCRIPT_NAME'].rstrip('/') + + @script_name.setter + def script_name(self, value): + if value: + value = value.rstrip('/') + self._script_name = value + + def merge(self, config): + """Merge the given config into self.config.""" + _cpconfig.merge(self.config, config) + + # Handle namespaces specified in config. 
+        self.namespaces(self.config.get('/', {}))
+
+    def find_config(self, path, key, default=None):
+        """Return the most-specific value for key along path, or default."""
+        trail = path or '/'
+        while trail:
+            nodeconf = self.config.get(trail, {})
+
+            if key in nodeconf:
+                return nodeconf[key]
+
+            lastslash = trail.rfind('/')
+            if lastslash == -1:
+                break
+            elif lastslash == 0 and trail != '/':
+                trail = '/'
+            else:
+                trail = trail[:lastslash]
+
+        return default
+
+    def get_serving(self, local, remote, scheme, sproto):
+        """Create and return a Request and Response object."""
+        req = self.request_class(local, remote, scheme, sproto)
+        req.app = self
+
+        for name, toolbox in self.toolboxes.items():
+            req.namespaces[name] = toolbox
+
+        resp = self.response_class()
+        cherrypy.serving.load(req, resp)
+        cherrypy.engine.publish('acquire_thread')
+        cherrypy.engine.publish('before_request')
+
+        return req, resp
+
+    def release_serving(self):
+        """Release the current serving (request and response)."""
+        req = cherrypy.serving.request
+
+        cherrypy.engine.publish('after_request')
+
+        try:
+            req.close()
+        except Exception:
+            cherrypy.log(traceback=True, severity=40)
+
+        cherrypy.serving.clear()
+
+    def __call__(self, environ, start_response):
+        """Call a WSGI-callable."""
+        return self.wsgiapp(environ, start_response)
+
+
+class Tree(object):
+    """A registry of CherryPy applications, mounted at diverse points.
+
+    An instance of this class may also be used as a WSGI callable
+    (WSGI application object), in which case it dispatches to all
+    mounted apps.
+    """
+
+    apps = {}
+    """
+    A dict of the form {script name: application}, where "script name"
+    is a string declaring the URI mount point (no trailing slash), and
+    "application" is an instance of cherrypy.Application (or an arbitrary
+    WSGI callable if you happen to be using a WSGI server)."""
+
+    def __init__(self):
+        """Initialize registry Tree."""
+        self.apps = {}
+
+    def mount(self, root, script_name='', config=None):
+        """Mount a new app from a root object, script_name, and config.
+
+        root
+            An instance of a "controller class" (a collection of page
+            handler methods) which represents the root of the application.
+            This may also be an Application instance, or None if using
+            a dispatcher other than the default.
+
+        script_name
+            A string containing the "mount point" of the application.
+            This should start with a slash, and be the path portion of the
+            URL at which to mount the given root. For example, if root.index()
+            will handle requests to "http://www.example.com:8080/dept/app1/",
+            then the script_name argument would be "/dept/app1".
+
+            It MUST NOT end in a slash. If the script_name refers to the
+            root of the URI, it MUST be an empty string (not "/").
+
+        config
+            A file or dict containing application config.
+        """
+        if script_name is None:
+            raise TypeError(
+                "The 'script_name' argument may not be None. Application "
+                'objects may, however, possess a script_name of None (in '
+                'order to inspect the WSGI environ for SCRIPT_NAME upon each '
+                'request). You cannot mount such Applications on this Tree; '
+                'you must pass them to a WSGI server interface directly.')
+
+        # Next line both 1) strips trailing slash and 2) maps "/" -> "".
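+        # (e.g. '/dept/app1/' becomes '/dept/app1', and '/' becomes ''.)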
+ script_name = script_name.rstrip('/') + + if isinstance(root, Application): + app = root + if script_name != '' and script_name != app.script_name: + raise ValueError( + 'Cannot specify a different script name and pass an ' + 'Application instance to cherrypy.mount') + script_name = app.script_name + else: + app = Application(root, script_name) + + # If mounted at "", add favicon.ico + needs_favicon = ( + script_name == '' + and root is not None + and not hasattr(root, 'favicon_ico') + ) + if needs_favicon: + favicon = os.path.join( + os.getcwd(), + os.path.dirname(__file__), + 'favicon.ico', + ) + root.favicon_ico = tools.staticfile.handler(favicon) + + if config: + app.merge(config) + + self.apps[script_name] = app + + return app + + def graft(self, wsgi_callable, script_name=''): + """Mount a wsgi callable at the given script_name.""" + # Next line both 1) strips trailing slash and 2) maps "/" -> "". + script_name = script_name.rstrip('/') + self.apps[script_name] = wsgi_callable + + def script_name(self, path=None): + """Return the script_name of the app at the given path, or None. + + If path is None, cherrypy.request is used. + """ + if path is None: + try: + request = cherrypy.serving.request + path = httputil.urljoin(request.script_name, + request.path_info) + except AttributeError: + return None + + while True: + if path in self.apps: + return path + + if path == '': + return None + + # Move one node up the tree and try again. + path = path[:path.rfind('/')] + + def __call__(self, environ, start_response): + """Pre-initialize WSGI env and call WSGI-callable.""" + # If you're calling this, then you're probably setting SCRIPT_NAME + # to '' (some WSGI servers always set SCRIPT_NAME to ''). + # Try to look up the app using the full path. + env1x = environ + path = httputil.urljoin(env1x.get('SCRIPT_NAME', ''), + env1x.get('PATH_INFO', '')) + sn = self.script_name(path or '/') + if sn is None: + start_response('404 Not Found', []) + return [] + + app = self.apps[sn] + + # Correct the SCRIPT_NAME and PATH_INFO environ entries. + environ = environ.copy() + environ['SCRIPT_NAME'] = sn + environ['PATH_INFO'] = path[len(sn.rstrip('/')):] + return app(environ, start_response) diff --git a/resources/lib/cherrypy/_cpwsgi.py b/resources/lib/cherrypy/_cpwsgi.py new file mode 100644 index 0000000..b4f55fd --- /dev/null +++ b/resources/lib/cherrypy/_cpwsgi.py @@ -0,0 +1,451 @@ +"""WSGI interface (see PEP 333 and 3333). + +Note that WSGI environ keys and values are 'native strings'; that is, +whatever the type of "" is. For Python 2, that's a byte string; for Python 3, +it's a unicode string. But PEP 3333 says: "even if Python's str type is +actually Unicode "under the hood", the content of native strings must +still be translatable to bytes via the Latin-1 encoding!" +""" + +import sys as _sys +import io + +import cherrypy as _cherrypy +from cherrypy._cpcompat import ntou +from cherrypy import _cperror +from cherrypy.lib import httputil +from cherrypy.lib import is_closable_iterator + + +def downgrade_wsgi_ux_to_1x(environ): + """Return a new environ dict for WSGI 1.x from the given WSGI u.x environ. 
+ """ + env1x = {} + + url_encoding = environ[ntou('wsgi.url_encoding')] + for k, v in environ.copy().items(): + if k in [ntou('PATH_INFO'), ntou('SCRIPT_NAME'), ntou('QUERY_STRING')]: + v = v.encode(url_encoding) + elif isinstance(v, str): + v = v.encode('ISO-8859-1') + env1x[k.encode('ISO-8859-1')] = v + + return env1x + + +class VirtualHost(object): + + """Select a different WSGI application based on the Host header. + + This can be useful when running multiple sites within one CP server. + It allows several domains to point to different applications. For example:: + + root = Root() + RootApp = cherrypy.Application(root) + Domain2App = cherrypy.Application(root) + SecureApp = cherrypy.Application(Secure()) + + vhost = cherrypy._cpwsgi.VirtualHost( + RootApp, + domains={ + 'www.domain2.example': Domain2App, + 'www.domain2.example:443': SecureApp, + }, + ) + + cherrypy.tree.graft(vhost) + """ + default = None + """Required. The default WSGI application.""" + + use_x_forwarded_host = True + """If True (the default), any "X-Forwarded-Host" + request header will be used instead of the "Host" header. This + is commonly added by HTTP servers (such as Apache) when proxying.""" + + domains = {} + """A dict of {host header value: application} pairs. + The incoming "Host" request header is looked up in this dict, + and, if a match is found, the corresponding WSGI application + will be called instead of the default. Note that you often need + separate entries for "example.com" and "www.example.com". + In addition, "Host" headers may contain the port number. + """ + + def __init__(self, default, domains=None, use_x_forwarded_host=True): + self.default = default + self.domains = domains or {} + self.use_x_forwarded_host = use_x_forwarded_host + + def __call__(self, environ, start_response): + domain = environ.get('HTTP_HOST', '') + if self.use_x_forwarded_host: + domain = environ.get('HTTP_X_FORWARDED_HOST', domain) + + nextapp = self.domains.get(domain) + if nextapp is None: + nextapp = self.default + return nextapp(environ, start_response) + + +class InternalRedirector(object): + + """WSGI middleware that handles raised cherrypy.InternalRedirect.""" + + def __init__(self, nextapp, recursive=False): + self.nextapp = nextapp + self.recursive = recursive + + def __call__(self, environ, start_response): + redirections = [] + while True: + environ = environ.copy() + try: + return self.nextapp(environ, start_response) + except _cherrypy.InternalRedirect: + ir = _sys.exc_info()[1] + sn = environ.get('SCRIPT_NAME', '') + path = environ.get('PATH_INFO', '') + qs = environ.get('QUERY_STRING', '') + + # Add the *previous* path_info + qs to redirections. + old_uri = sn + path + if qs: + old_uri += '?' + qs + redirections.append(old_uri) + + if not self.recursive: + # Check to see if the new URI has been redirected to + # already + new_uri = sn + ir.path + if ir.query_string: + new_uri += '?' + ir.query_string + if new_uri in redirections: + ir.request.close() + tmpl = ( + 'InternalRedirector visited the same URL twice: %r' + ) + raise RuntimeError(tmpl % new_uri) + + # Munge the environment and try again. 
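+                # (Replay the redirect target in-process as a body-less GET;
+                # the prior request object stays reachable via the
+                # 'cherrypy.previous_request' environ key set below.)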
+ environ['REQUEST_METHOD'] = 'GET' + environ['PATH_INFO'] = ir.path + environ['QUERY_STRING'] = ir.query_string + environ['wsgi.input'] = io.BytesIO() + environ['CONTENT_LENGTH'] = '0' + environ['cherrypy.previous_request'] = ir.request + + +class ExceptionTrapper(object): + + """WSGI middleware that traps exceptions.""" + + def __init__(self, nextapp, throws=(KeyboardInterrupt, SystemExit)): + self.nextapp = nextapp + self.throws = throws + + def __call__(self, environ, start_response): + return _TrappedResponse( + self.nextapp, + environ, + start_response, + self.throws + ) + + +class _TrappedResponse(object): + + response = iter([]) + + def __init__(self, nextapp, environ, start_response, throws): + self.nextapp = nextapp + self.environ = environ + self.start_response = start_response + self.throws = throws + self.started_response = False + self.response = self.trap( + self.nextapp, self.environ, self.start_response, + ) + self.iter_response = iter(self.response) + + def __iter__(self): + self.started_response = True + return self + + def __next__(self): + return self.trap(next, self.iter_response) + + def close(self): + if hasattr(self.response, 'close'): + self.response.close() + + def trap(self, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except self.throws: + raise + except StopIteration: + raise + except Exception: + tb = _cperror.format_exc() + _cherrypy.log(tb, severity=40) + if not _cherrypy.request.show_tracebacks: + tb = '' + s, h, b = _cperror.bare_error(tb) + if True: + # What fun. + s = s.decode('ISO-8859-1') + h = [ + (k.decode('ISO-8859-1'), v.decode('ISO-8859-1')) + for k, v in h + ] + if self.started_response: + # Empty our iterable (so future calls raise StopIteration) + self.iter_response = iter([]) + else: + self.iter_response = iter(b) + + try: + self.start_response(s, h, _sys.exc_info()) + except Exception: + # "The application must not trap any exceptions raised by + # start_response, if it called start_response with exc_info. + # Instead, it should allow such exceptions to propagate + # back to the server or gateway." + # But we still log and call close() to clean up ourselves. + _cherrypy.log(traceback=True, severity=40) + raise + + if self.started_response: + return b''.join(b) + else: + return b + + +# WSGI-to-CP Adapter # + + +class AppResponse(object): + + """WSGI response iterable for CherryPy applications.""" + + def __init__(self, environ, start_response, cpapp): + self.cpapp = cpapp + try: + self.environ = environ + self.run() + + r = _cherrypy.serving.response + + outstatus = r.output_status + if not isinstance(outstatus, bytes): + raise TypeError('response.output_status is not a byte string.') + + outheaders = [] + for k, v in r.header_list: + if not isinstance(k, bytes): + tmpl = 'response.header_list key %r is not a byte string.' + raise TypeError(tmpl % k) + if not isinstance(v, bytes): + tmpl = ( + 'response.header_list value %r is not a byte string.' + ) + raise TypeError(tmpl % v) + outheaders.append((k, v)) + + if True: + # According to PEP 3333, when using Python 3, the response + # status and headers must be bytes masquerading as unicode; + # that is, they must be of type "str" but are restricted to + # code points in the "latin-1" set. 
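+                # (Decoding as ISO-8859-1 maps each byte straight to the
+                # code point of the same value, so nothing is lost.)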
+ outstatus = outstatus.decode('ISO-8859-1') + outheaders = [ + (k.decode('ISO-8859-1'), v.decode('ISO-8859-1')) + for k, v in outheaders + ] + + self.iter_response = iter(r.body) + self.write = start_response(outstatus, outheaders) + except BaseException: + self.close() + raise + + def __iter__(self): + return self + + def __next__(self): + return next(self.iter_response) + + def close(self): + """Close and de-reference the current request and response. (Core)""" + streaming = _cherrypy.serving.response.stream + self.cpapp.release_serving() + + # We avoid the expense of examining the iterator to see if it's + # closable unless we are streaming the response, as that's the + # only situation where we are going to have an iterator which + # may not have been exhausted yet. + if streaming and is_closable_iterator(self.iter_response): + iter_close = self.iter_response.close + try: + iter_close() + except Exception: + _cherrypy.log(traceback=True, severity=40) + + def run(self): + """Create a Request object using environ.""" + env = self.environ.get + + local = httputil.Host( + '', + int(env('SERVER_PORT', 80) or -1), + env('SERVER_NAME', ''), + ) + remote = httputil.Host( + env('REMOTE_ADDR', ''), + int(env('REMOTE_PORT', -1) or -1), + env('REMOTE_HOST', ''), + ) + scheme = env('wsgi.url_scheme') + sproto = env('ACTUAL_SERVER_PROTOCOL', 'HTTP/1.1') + request, resp = self.cpapp.get_serving(local, remote, scheme, sproto) + + # LOGON_USER is served by IIS, and is the name of the + # user after having been mapped to a local account. + # Both IIS and Apache set REMOTE_USER, when possible. + request.login = env('LOGON_USER') or env('REMOTE_USER') or None + request.multithread = self.environ['wsgi.multithread'] + request.multiprocess = self.environ['wsgi.multiprocess'] + request.wsgi_environ = self.environ + request.prev = env('cherrypy.previous_request', None) + + meth = self.environ['REQUEST_METHOD'] + + path = httputil.urljoin( + self.environ.get('SCRIPT_NAME', ''), + self.environ.get('PATH_INFO', ''), + ) + qs = self.environ.get('QUERY_STRING', '') + + path, qs = self.recode_path_qs(path, qs) or (path, qs) + + rproto = self.environ.get('SERVER_PROTOCOL') + headers = self.translate_headers(self.environ) + rfile = self.environ['wsgi.input'] + request.run(meth, path, qs, rproto, headers, rfile) + + headerNames = { + 'HTTP_CGI_AUTHORIZATION': 'Authorization', + 'CONTENT_LENGTH': 'Content-Length', + 'CONTENT_TYPE': 'Content-Type', + 'REMOTE_HOST': 'Remote-Host', + 'REMOTE_ADDR': 'Remote-Addr', + } + + def recode_path_qs(self, path, qs): + # This isn't perfect; if the given PATH_INFO is in the + # wrong encoding, it may fail to match the appropriate config + # section URI. But meh. + old_enc = self.environ.get('wsgi.url_encoding', 'ISO-8859-1') + new_enc = self.cpapp.find_config( + self.environ.get('PATH_INFO', ''), + 'request.uri_encoding', 'utf-8', + ) + if new_enc.lower() == old_enc.lower(): + return + + # Even though the path and qs are unicode, the WSGI server + # is required by PEP 3333 to coerce them to ISO-8859-1 + # masquerading as unicode. So we have to encode back to + # bytes and then decode again using the "correct" encoding. + try: + return ( + path.encode(old_enc).decode(new_enc), + qs.encode(old_enc).decode(new_enc), + ) + except (UnicodeEncodeError, UnicodeDecodeError): + # Just pass them through without transcoding and hope. 
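+            # (Returning None makes the caller keep the original pair via
+            # "path, qs = self.recode_path_qs(path, qs) or (path, qs)".)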
+ pass + + def translate_headers(self, environ): + """Translate CGI-environ header names to HTTP header names.""" + for cgiName in environ: + # We assume all incoming header keys are uppercase already. + if cgiName in self.headerNames: + yield self.headerNames[cgiName], environ[cgiName] + elif cgiName[:5] == 'HTTP_': + # Hackish attempt at recovering original header names. + translatedHeader = cgiName[5:].replace('_', '-') + yield translatedHeader, environ[cgiName] + + +class CPWSGIApp(object): + + """A WSGI application object for a CherryPy Application.""" + + pipeline = [ + ('ExceptionTrapper', ExceptionTrapper), + ('InternalRedirector', InternalRedirector), + ] + """A list of (name, wsgiapp) pairs. Each 'wsgiapp' MUST be a + constructor that takes an initial, positional 'nextapp' argument, + plus optional keyword arguments, and returns a WSGI application + (that takes environ and start_response arguments). The 'name' can + be any you choose, and will correspond to keys in self.config.""" + + head = None + """Rather than nest all apps in the pipeline on each call, it's only + done the first time, and the result is memoized into self.head. Set + this to None again if you change self.pipeline after calling self.""" + + config = {} + """A dict whose keys match names listed in the pipeline. Each + value is a further dict which will be passed to the corresponding + named WSGI callable (from the pipeline) as keyword arguments.""" + + response_class = AppResponse + """The class to instantiate and return as the next app in the WSGI chain. + """ + + def __init__(self, cpapp, pipeline=None): + self.cpapp = cpapp + self.pipeline = self.pipeline[:] + if pipeline: + self.pipeline.extend(pipeline) + self.config = self.config.copy() + + def tail(self, environ, start_response): + """WSGI application callable for the actual CherryPy application. + + You probably shouldn't call this; call self.__call__ instead, + so that any WSGI middleware in self.pipeline can run first. + """ + return self.response_class(environ, start_response, self.cpapp) + + def __call__(self, environ, start_response): + head = self.head + if head is None: + # Create and nest the WSGI apps in our pipeline (in reverse order). + # Then memoize the result in self.head. + head = self.tail + for name, callable in self.pipeline[::-1]: + conf = self.config.get(name, {}) + head = callable(head, **conf) + self.head = head + return head(environ, start_response) + + def namespace_handler(self, k, v): + """Config handler for the 'wsgi' namespace.""" + if k == 'pipeline': + # Note this allows multiple 'wsgi.pipeline' config entries + # (but each entry will be processed in a 'random' order). + # It should also allow developers to set default middleware + # in code (passed to self.__init__) that deployers can add to + # (but not remove) via config. + self.pipeline.extend(v) + elif k == 'response_class': + self.response_class = v + else: + name, arg = k.split('.', 1) + bucket = self.config.setdefault(name, {}) + bucket[arg] = v diff --git a/resources/lib/cherrypy/_cpwsgi_server.py b/resources/lib/cherrypy/_cpwsgi_server.py new file mode 100644 index 0000000..11dd846 --- /dev/null +++ b/resources/lib/cherrypy/_cpwsgi_server.py @@ -0,0 +1,110 @@ +""" +WSGI server interface (see PEP 333). + +This adds some CP-specific bits to the framework-agnostic cheroot package. +""" +import sys + +import cheroot.wsgi +import cheroot.server + +import cherrypy + + +class CPWSGIHTTPRequest(cheroot.server.HTTPRequest): + """Wrapper for cheroot.server.HTTPRequest. 
+
+    This is a compatibility layer that preserves the URI parsing mode
+    Cheroot used before v5.8.0.
+    """
+
+    def __init__(self, server, conn):
+        """Initialize HTTP request container instance.
+
+        Args:
+            server (cheroot.server.HTTPServer):
+                web server object receiving this request
+            conn (cheroot.server.HTTPConnection):
+                HTTP connection object for this request
+        """
+        super(CPWSGIHTTPRequest, self).__init__(
+            server, conn, proxy_mode=True
+        )
+
+
+class CPWSGIServer(cheroot.wsgi.Server):
+    """Wrapper for cheroot.wsgi.Server.
+
+    cheroot has been designed to not reference CherryPy in any way,
+    so that it can be used in other frameworks and applications. Therefore,
+    we wrap it here, so we can set our own mount points from cherrypy.tree
+    and apply some attributes from config -> cherrypy.server -> wsgi.Server.
+    """
+
+    fmt = 'CherryPy/{cherrypy.__version__} {cheroot.wsgi.Server.version}'
+    version = fmt.format(**globals())
+
+    def __init__(self, server_adapter=cherrypy.server):
+        """Initialize CPWSGIServer instance.
+
+        Args:
+            server_adapter (cherrypy._cpserver.Server): ...
+        """
+        self.server_adapter = server_adapter
+        self.max_request_header_size = (
+            self.server_adapter.max_request_header_size or 0
+        )
+        self.max_request_body_size = (
+            self.server_adapter.max_request_body_size or 0
+        )
+
+        server_name = (self.server_adapter.socket_host or
+                       self.server_adapter.socket_file or
+                       None)
+
+        self.wsgi_version = self.server_adapter.wsgi_version
+
+        super(CPWSGIServer, self).__init__(
+            server_adapter.bind_addr, cherrypy.tree,
+            self.server_adapter.thread_pool,
+            server_name,
+            max=self.server_adapter.thread_pool_max,
+            request_queue_size=self.server_adapter.socket_queue_size,
+            timeout=self.server_adapter.socket_timeout,
+            shutdown_timeout=self.server_adapter.shutdown_timeout,
+            accepted_queue_size=self.server_adapter.accepted_queue_size,
+            accepted_queue_timeout=self.server_adapter.accepted_queue_timeout,
+            peercreds_enabled=self.server_adapter.peercreds,
+            peercreds_resolve_enabled=self.server_adapter.peercreds_resolve,
+        )
+        self.ConnectionClass.RequestHandlerClass = CPWSGIHTTPRequest
+
+        self.protocol = self.server_adapter.protocol_version
+        self.nodelay = self.server_adapter.nodelay
+
+        if sys.version_info >= (3, 0):
+            ssl_module = self.server_adapter.ssl_module or 'builtin'
+        else:
+            ssl_module = self.server_adapter.ssl_module or 'pyopenssl'
+        if self.server_adapter.ssl_context:
+            adapter_class = cheroot.server.get_ssl_adapter_class(ssl_module)
+            self.ssl_adapter = adapter_class(
+                self.server_adapter.ssl_certificate,
+                self.server_adapter.ssl_private_key,
+                self.server_adapter.ssl_certificate_chain,
+                self.server_adapter.ssl_ciphers)
+            self.ssl_adapter.context = self.server_adapter.ssl_context
+        elif self.server_adapter.ssl_certificate:
+            adapter_class = cheroot.server.get_ssl_adapter_class(ssl_module)
+            self.ssl_adapter = adapter_class(
+                self.server_adapter.ssl_certificate,
+                self.server_adapter.ssl_private_key,
+                self.server_adapter.ssl_certificate_chain,
+                self.server_adapter.ssl_ciphers)
+
+        self.stats['Enabled'] = getattr(
+            self.server_adapter, 'statistics', False)
+
+    def error_log(self, msg='', level=20, traceback=False):
+        """Write given message to the error log."""
+        cherrypy.engine.log(msg, level, traceback)
diff --git a/resources/lib/cherrypy/_helper.py b/resources/lib/cherrypy/_helper.py
new file mode 100644
index 0000000..d57cd1f
--- /dev/null
+++ b/resources/lib/cherrypy/_helper.py
@@ -0,0 +1,348 @@
+"""Helper functions for CP apps."""
+
+import urllib.parse
+
+from 
cherrypy._cpcompat import text_or_bytes + +import cherrypy + + +def expose(func=None, alias=None): + """Expose the function or class. + + Optionally provide an alias or set of aliases. + """ + def expose_(func): + func.exposed = True + if alias is not None: + if isinstance(alias, text_or_bytes): + parents[alias.replace('.', '_')] = func + else: + for a in alias: + parents[a.replace('.', '_')] = func + return func + + import sys + import types + decoratable_types = types.FunctionType, types.MethodType, type, + if isinstance(func, decoratable_types): + if alias is None: + # @expose + func.exposed = True + return func + else: + # func = expose(func, alias) + parents = sys._getframe(1).f_locals + return expose_(func) + elif func is None: + if alias is None: + # @expose() + parents = sys._getframe(1).f_locals + return expose_ + else: + # @expose(alias="alias") or + # @expose(alias=["alias1", "alias2"]) + parents = sys._getframe(1).f_locals + return expose_ + else: + # @expose("alias") or + # @expose(["alias1", "alias2"]) + parents = sys._getframe(1).f_locals + alias = func + return expose_ + + +def popargs(*args, **kwargs): + """Decorate _cp_dispatch. + + (cherrypy.dispatch.Dispatcher.dispatch_method_name) + + Optional keyword argument: handler=(Object or Function) + + Provides a _cp_dispatch function that pops off path segments into + cherrypy.request.params under the names specified. The dispatch + is then forwarded on to the next vpath element. + + Note that any existing (and exposed) member function of the class that + popargs is applied to will override that value of the argument. For + instance, if you have a method named "list" on the class decorated with + popargs, then accessing "/list" will call that function instead of popping + it off as the requested parameter. This restriction applies to all + _cp_dispatch functions. The only way around this restriction is to create + a "blank class" whose only function is to provide _cp_dispatch. + + If there are path elements after the arguments, or more arguments + are requested than are available in the vpath, then the 'handler' + keyword argument specifies the next object to handle the parameterized + request. If handler is not specified or is None, then self is used. + If handler is a function rather than an instance, then that function + will be called with the args specified and the return value from that + function used as the next object INSTEAD of adding the parameters to + cherrypy.request.args. + + This decorator may be used in one of two ways: + + As a class decorator: + + .. code-block:: python + + @cherrypy.popargs('year', 'month', 'day') + class Blog: + def index(self, year=None, month=None, day=None): + #Process the parameters here; any url like + #/, /2009, /2009/12, or /2009/12/31 + #will fill in the appropriate parameters. + + def create(self): + #This link will still be available at /create. + #Defined functions take precedence over arguments. + + Or as a member of a class: + + .. code-block:: python + + class Blog: + _cp_dispatch = cherrypy.popargs('year', 'month', 'day') + #... + + The handler argument may be used to mix arguments with built in functions. + For instance, the following setup allows different activities at the + day, month, and year level: + + .. 
code-block:: python + + class DayHandler: + def index(self, year, month, day): + #Do something with this day; probably list entries + + def delete(self, year, month, day): + #Delete all entries for this day + + @cherrypy.popargs('day', handler=DayHandler()) + class MonthHandler: + def index(self, year, month): + #Do something with this month; probably list entries + + def delete(self, year, month): + #Delete all entries for this month + + @cherrypy.popargs('month', handler=MonthHandler()) + class YearHandler: + def index(self, year): + #Do something with this year + + #... + + @cherrypy.popargs('year', handler=YearHandler()) + class Root: + def index(self): + #... + + """ + # Since keyword arg comes after *args, we have to process it ourselves + # for lower versions of python. + + handler = None + handler_call = False + for k, v in kwargs.items(): + if k == 'handler': + handler = v + else: + tm = "cherrypy.popargs() got an unexpected keyword argument '{0}'" + raise TypeError(tm.format(k)) + + import inspect + + if handler is not None \ + and (hasattr(handler, '__call__') or inspect.isclass(handler)): + handler_call = True + + def decorated(cls_or_self=None, vpath=None): + if inspect.isclass(cls_or_self): + # cherrypy.popargs is a class decorator + cls = cls_or_self + name = cherrypy.dispatch.Dispatcher.dispatch_method_name + setattr(cls, name, decorated) + return cls + + # We're in the actual function + self = cls_or_self + parms = {} + for arg in args: + if not vpath: + break + parms[arg] = vpath.pop(0) + + if handler is not None: + if handler_call: + return handler(**parms) + else: + cherrypy.request.params.update(parms) + return handler + + cherrypy.request.params.update(parms) + + # If we are the ultimate handler, then to prevent our _cp_dispatch + # from being called again, we will resolve remaining elements through + # getattr() directly. + if vpath: + return getattr(self, vpath.pop(0), None) + else: + return self + + return decorated + + +def url(path='', qs='', script_name=None, base=None, relative=None): + """Create an absolute URL for the given path. + + If 'path' starts with a slash ('/'), this will return + (base + script_name + path + qs). + If it does not start with a slash, this returns + (base + script_name [+ request.path_info] + path + qs). + + If script_name is None, cherrypy.request will be used + to find a script_name, if available. + + If base is None, cherrypy.request.base will be used (if available). + Note that you can use cherrypy.tools.proxy to change this. + + Finally, note that this function can be used to obtain an absolute URL + for the current request path (minus the querystring) by passing no args. + If you call url(qs=cherrypy.request.query_string), you should get the + original browser URL (assuming no internal redirections). + + If relative is None or not provided, request.app.relative_urls will + be used (if available, else False). If False, the output will be an + absolute URL (including the scheme, host, vhost, and script_name). + If True, the output will instead be a URL that is relative to the + current request path, perhaps including '..' atoms. If relative is + the string 'server', the output will instead be a URL that is + relative to the server root; i.e., it will start with a slash. + """ + if isinstance(qs, (tuple, list, dict)): + qs = urllib.parse.urlencode(qs) + if qs: + qs = '?' 
+ qs
+
+    if cherrypy.request.app:
+        if not path.startswith('/'):
+            # Append/remove trailing slash from path_info as needed
+            # (this is to support mistyped URLs without redirecting;
+            # if you want to redirect, use tools.trailing_slash).
+            pi = cherrypy.request.path_info
+            if cherrypy.request.is_index is True:
+                if not pi.endswith('/'):
+                    pi = pi + '/'
+            elif cherrypy.request.is_index is False:
+                if pi.endswith('/') and pi != '/':
+                    pi = pi[:-1]
+
+            if path == '':
+                path = pi
+            else:
+                path = urllib.parse.urljoin(pi, path)
+
+        if script_name is None:
+            script_name = cherrypy.request.script_name
+        if base is None:
+            base = cherrypy.request.base
+
+        newurl = base + script_name + normalize_path(path) + qs
+    else:
+        # No request.app (we're being called outside a request).
+        # We'll have to guess the base from server.* attributes.
+        # This will produce very different results from the above
+        # if you're using vhosts or tools.proxy.
+        if base is None:
+            base = cherrypy.server.base()
+
+        path = (script_name or '') + path
+        newurl = base + normalize_path(path) + qs
+
+    # At this point, we should have a fully-qualified absolute URL.
+
+    if relative is None:
+        relative = getattr(cherrypy.request.app, 'relative_urls', False)
+
+    # See http://www.ietf.org/rfc/rfc2396.txt
+    if relative == 'server':
+        # "A relative reference beginning with a single slash character is
+        # termed an absolute-path reference, as defined by ..."
+        # This is also sometimes called "server-relative".
+        newurl = '/' + '/'.join(newurl.split('/', 3)[3:])
+    elif relative:
+        # "A relative reference that does not begin with a scheme name
+        # or a slash character is termed a relative-path reference."
+        old = url(relative=False).split('/')[:-1]
+        new = newurl.split('/')
+        while old and new:
+            a, b = old[0], new[0]
+            if a != b:
+                break
+            old.pop(0)
+            new.pop(0)
+        new = (['..'] * len(old)) + new
+        newurl = '/'.join(new)
+
+    return newurl
+
+
+def normalize_path(path):
+    """Resolve given path from relative into absolute form."""
+    if './' not in path:
+        return path
+
+    # Normalize the URL by removing ./ and ../
+    atoms = []
+    for atom in path.split('/'):
+        if atom == '.':
+            pass
+        elif atom == '..':
+            # Don't pop from empty list
+            # (i.e. ignore redundant '..')
+            if atoms:
+                atoms.pop()
+        elif atom:
+            atoms.append(atom)
+
+    newpath = '/'.join(atoms)
+    # Preserve leading '/'
+    if path.startswith('/'):
+        newpath = '/' + newpath
+
+    return newpath
+
+
+####
+# Inlined from jaraco.classes 1.4.3
+# Ref #1673
+class _ClassPropertyDescriptor(object):
+    """Descriptor for a read-only class-based property.
+
+    Turns a classmethod-decorated func into a read-only property of that class
+    type (meaning the value cannot be set).
+    """
+
+    def __init__(self, fget, fset=None):
+        """Initialize a class property descriptor.
+
+        Instantiated by ``_helper.classproperty``.
+        """
+        self.fget = fget
+        self.fset = fset
+
+    def __get__(self, obj, klass=None):
+        """Return property value."""
+        if klass is None:
+            klass = type(obj)
+        return self.fget.__get__(obj, klass)()
+
+
+def classproperty(func):  # noqa: D401; irrelevant for properties
+    """Decorator like classmethod to implement a static class property."""
+    if not isinstance(func, (classmethod, staticmethod)):
+        func = classmethod(func)
+
+    return _ClassPropertyDescriptor(func)
+####
diff --git a/resources/lib/cherrypy/_json.py b/resources/lib/cherrypy/_json.py
new file mode 100644
index 0000000..0c2a0f0
--- /dev/null
+++ b/resources/lib/cherrypy/_json.py
@@ -0,0 +1,25 @@
+"""
+JSON support.
+
+Expose preferred json module as json and provide encode/decode
+convenience functions.
+"""
+
+try:
+    # Prefer simplejson
+    import simplejson as json
+except ImportError:
+    import json
+
+
+__all__ = ['json', 'encode', 'decode']
+
+
+decode = json.JSONDecoder().decode
+_encode = json.JSONEncoder().iterencode
+
+
+def encode(value):
+    """Encode to bytes."""
+    for chunk in _encode(value):
+        yield chunk.encode('utf-8')
diff --git a/resources/lib/cherrypy/daemon.py b/resources/lib/cherrypy/daemon.py
new file mode 100644
index 0000000..74488c0
--- /dev/null
+++ b/resources/lib/cherrypy/daemon.py
@@ -0,0 +1,107 @@
+"""The CherryPy daemon."""
+
+import sys
+
+import cherrypy
+from cherrypy.process import plugins, servers
+from cherrypy import Application
+
+
+def start(configfiles=None, daemonize=False, environment=None,
+          fastcgi=False, scgi=False, pidfile=None, imports=None,
+          cgi=False):
+    """Subscribe all engine plugins and start the engine."""
+    sys.path = [''] + sys.path
+    for i in imports or []:
+        exec('import %s' % i)
+
+    for c in configfiles or []:
+        cherrypy.config.update(c)
+        # If there's only one app mounted, merge config into it.
+        if len(cherrypy.tree.apps) == 1:
+            for app in cherrypy.tree.apps.values():
+                if isinstance(app, Application):
+                    app.merge(c)
+
+    engine = cherrypy.engine
+
+    if environment is not None:
+        cherrypy.config.update({'environment': environment})
+
+    # Only daemonize if asked to.
+    if daemonize:
+        # Don't print anything to stdout/stderr.
+        cherrypy.config.update({'log.screen': False})
+        plugins.Daemonizer(engine).subscribe()
+
+    if pidfile:
+        plugins.PIDFile(engine, pidfile).subscribe()
+
+    if hasattr(engine, 'signal_handler'):
+        engine.signal_handler.subscribe()
+    if hasattr(engine, 'console_control_handler'):
+        engine.console_control_handler.subscribe()
+
+    if (fastcgi and (scgi or cgi)) or (scgi and cgi):
+        cherrypy.log.error('You may only specify one of the cgi, fastcgi, and '
+                           'scgi options.', 'ENGINE')
+        sys.exit(1)
+    elif fastcgi or scgi or cgi:
+        # Turn off autoreload when using *cgi.
+        cherrypy.config.update({'engine.autoreload.on': False})
+        # Turn off the default HTTP server (which is subscribed by default).
+        cherrypy.server.unsubscribe()
+
+        addr = cherrypy.server.bind_addr
+        cls = (
+            servers.FlupFCGIServer if fastcgi else
+            servers.FlupSCGIServer if scgi else
+            servers.FlupCGIServer
+        )
+        f = cls(application=cherrypy.tree, bindAddress=addr)
+        s = servers.ServerAdapter(engine, httpserver=f, bind_addr=addr)
+        s.subscribe()
+
+    # Always start the engine; this will start all other services
+    try:
+        engine.start()
+    except Exception:
+        # Assume the error has been logged already via bus.log.
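+        # (Exit with a non-zero status so the caller can detect the
+        # failed start.)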
+ sys.exit(1) + else: + engine.block() + + +def run(): + """Run cherryd CLI.""" + from optparse import OptionParser + + p = OptionParser() + p.add_option('-c', '--config', action='append', dest='config', + help='specify config file(s)') + p.add_option('-d', action='store_true', dest='daemonize', + help='run the server as a daemon') + p.add_option('-e', '--environment', dest='environment', default=None, + help='apply the given config environment') + p.add_option('-f', action='store_true', dest='fastcgi', + help='start a fastcgi server instead of the default HTTP ' + 'server') + p.add_option('-s', action='store_true', dest='scgi', + help='start a scgi server instead of the default HTTP server') + p.add_option('-x', action='store_true', dest='cgi', + help='start a cgi server instead of the default HTTP server') + p.add_option('-i', '--import', action='append', dest='imports', + help='specify modules to import') + p.add_option('-p', '--pidfile', dest='pidfile', default=None, + help='store the process id in the given file') + p.add_option('-P', '--Path', action='append', dest='Path', + help='add the given paths to sys.path') + options, args = p.parse_args() + + if options.Path: + for p in options.Path: + sys.path.insert(0, p) + + start(options.config, options.daemonize, + options.environment, options.fastcgi, options.scgi, + options.pidfile, options.imports, options.cgi) diff --git a/resources/lib/cherrypy/favicon.ico b/resources/lib/cherrypy/favicon.ico new file mode 100644 index 0000000..f0d7e61 Binary files /dev/null and b/resources/lib/cherrypy/favicon.ico differ diff --git a/resources/lib/cherrypy/lib/__init__.py b/resources/lib/cherrypy/lib/__init__.py new file mode 100644 index 0000000..f815f76 --- /dev/null +++ b/resources/lib/cherrypy/lib/__init__.py @@ -0,0 +1,96 @@ +"""CherryPy Library.""" + + +def is_iterator(obj): + """Detect if the object provided implements the iterator protocol. + + (i.e. like a generator). + + This will return False for objects which are iterable, + but not iterators themselves. + """ + from types import GeneratorType + if isinstance(obj, GeneratorType): + return True + elif not hasattr(obj, '__iter__'): + return False + else: + # Types which implement the protocol must return themselves when + # invoking 'iter' upon them. + return iter(obj) is obj + + +def is_closable_iterator(obj): + """Detect if the given object is both closable and iterator.""" + # Not an iterator. + if not is_iterator(obj): + return False + + # A generator - the easiest thing to deal with. + import inspect + if inspect.isgenerator(obj): + return True + + # A custom iterator. Look for a close method... + if not (hasattr(obj, 'close') and callable(obj.close)): + return False + + # ... which doesn't require any arguments. + try: + inspect.getcallargs(obj.close) + except TypeError: + return False + else: + return True + + +class file_generator(object): + """Yield the given input (a file object) in chunks (default 64k). 
+
+    (Core)
+    """
+
+    def __init__(self, input, chunkSize=65536):
+        """Initialize file_generator with file ``input`` for chunked access."""
+        self.input = input
+        self.chunkSize = chunkSize
+
+    def __iter__(self):
+        """Return iterator."""
+        return self
+
+    def __next__(self):
+        """Return next chunk of file."""
+        chunk = self.input.read(self.chunkSize)
+        if chunk:
+            return chunk
+        else:
+            if hasattr(self.input, 'close'):
+                self.input.close()
+            raise StopIteration()
+    next = __next__
+
+
+def file_generator_limited(fileobj, count, chunk_size=65536):
+    """Yield the given file object in chunks.
+
+    Stops after `count` bytes have been emitted.
+    Default chunk size is 64kB. (Core)
+    """
+    remaining = count
+    while remaining > 0:
+        chunk = fileobj.read(min(chunk_size, remaining))
+        chunklen = len(chunk)
+        if chunklen == 0:
+            return
+        remaining -= chunklen
+        yield chunk
+
+
+def set_vary_header(response, header_name):
+    """Add a Vary header to a response."""
+    varies = response.headers.get('Vary', '')
+    varies = [x.strip() for x in varies.split(',') if x.strip()]
+    if header_name not in varies:
+        varies.append(header_name)
+    response.headers['Vary'] = ', '.join(varies)
diff --git a/resources/lib/cherrypy/lib/auth_basic.py b/resources/lib/cherrypy/lib/auth_basic.py
new file mode 100644
index 0000000..ad379a2
--- /dev/null
+++ b/resources/lib/cherrypy/lib/auth_basic.py
@@ -0,0 +1,120 @@
+# This file is part of CherryPy
+# -*- coding: utf-8 -*-
+# vim:ts=4:sw=4:expandtab:fileencoding=utf-8
+"""HTTP Basic Authentication tool.
+
+This module provides a CherryPy 3.x tool which implements
+the server-side of HTTP Basic Access Authentication, as described in
+:rfc:`2617`.
+
+Example usage, using the built-in checkpassword_dict function which uses a dict
+as the credentials store::
+
+    userpassdict = {'bird' : 'bebop', 'ornette' : 'wayout'}
+    checkpassword = cherrypy.lib.auth_basic.checkpassword_dict(userpassdict)
+    basic_auth = {'tools.auth_basic.on': True,
+                  'tools.auth_basic.realm': 'earth',
+                  'tools.auth_basic.checkpassword': checkpassword,
+                  'tools.auth_basic.accept_charset': 'UTF-8',
+    }
+    app_config = { '/' : basic_auth }
+
+"""
+
+import binascii
+import unicodedata
+import base64
+
+import cherrypy
+from cherrypy._cpcompat import ntou, tonative
+
+
+__author__ = 'visteya'
+__date__ = 'April 2009'
+
+
+def checkpassword_dict(user_password_dict):
+    """Returns a checkpassword function which checks credentials
+    against a dictionary of the form: {username : password}.
+
+    If you want a simple dictionary-based authentication scheme, use
+    checkpassword_dict(my_credentials_dict) as the value for the
+    checkpassword argument to basic_auth().
+    """
+    def checkpassword(realm, user, password):
+        p = user_password_dict.get(user)
+        return p and p == password or False
+
+    return checkpassword
+
+
+def basic_auth(realm, checkpassword, debug=False, accept_charset='utf-8'):
+    """A CherryPy tool which hooks at before_handler to perform
+    HTTP Basic Access Authentication, as specified in :rfc:`2617`
+    and :rfc:`7617`.
+
+    If the request has an 'authorization' header with a 'Basic' scheme, this
+    tool attempts to authenticate the credentials supplied in that header. If
+    the request has no 'authorization' header, or if it does but the scheme is
+    not 'Basic', or if authentication fails, the tool sends a 401 response with
+    a 'WWW-Authenticate' Basic header.
+
+    realm
+        A string containing the authentication realm.
+
+    checkpassword
+        A callable which checks the authentication credentials.
+        Its signature is checkpassword(realm, username, password), where
+        username and password are the values obtained from the request's
+        'authorization' header. If authentication succeeds, checkpassword
+        returns True, else it returns False.
+
+    """
+
+    fallback_charset = 'ISO-8859-1'
+
+    if '"' in realm:
+        raise ValueError('Realm cannot contain the " (quote) character.')
+    request = cherrypy.serving.request
+
+    auth_header = request.headers.get('authorization')
+    if auth_header is not None:
+        # split() error, base64.decodestring() error
+        msg = 'Bad Request'
+        with cherrypy.HTTPError.handle((ValueError, binascii.Error), 400, msg):
+            scheme, params = auth_header.split(' ', 1)
+            if scheme.lower() == 'basic':
+                charsets = accept_charset, fallback_charset
+                decoded_params = base64.b64decode(params.encode('ascii'))
+                decoded_params = _try_decode(decoded_params, charsets)
+                decoded_params = ntou(decoded_params)
+                decoded_params = unicodedata.normalize('NFC', decoded_params)
+                decoded_params = tonative(decoded_params)
+                username, password = decoded_params.split(':', 1)
+                if checkpassword(realm, username, password):
+                    if debug:
+                        cherrypy.log('Auth succeeded', 'TOOLS.AUTH_BASIC')
+                    request.login = username
+                    return  # successful authentication
+
+    charset = accept_charset.upper()
+    charset_declaration = (
+        (', charset="%s"' % charset)
+        if charset != fallback_charset
+        else ''
+    )
+    # Respond with 401 status and a WWW-Authenticate header
+    cherrypy.serving.response.headers['www-authenticate'] = (
+        'Basic realm="%s"%s' % (realm, charset_declaration)
+    )
+    raise cherrypy.HTTPError(
+        401, 'You are not authorized to access that resource')
+
+
+def _try_decode(subject, charsets):
+    for charset in charsets[:-1]:
+        try:
+            return tonative(subject, charset)
+        except ValueError:
+            pass
+    return tonative(subject, charsets[-1])
diff --git a/resources/lib/cherrypy/lib/auth_digest.py b/resources/lib/cherrypy/lib/auth_digest.py
new file mode 100644
index 0000000..fbb5df6
--- /dev/null
+++ b/resources/lib/cherrypy/lib/auth_digest.py
@@ -0,0 +1,463 @@
+# This file is part of CherryPy
+# -*- coding: utf-8 -*-
+# vim:ts=4:sw=4:expandtab:fileencoding=utf-8
+"""HTTP Digest Authentication tool.
+
+An implementation of the server-side of HTTP Digest Access
+Authentication, which is described in :rfc:`2617`.
+
+Example usage, using the built-in get_ha1_dict_plain function which uses a dict
+of plaintext passwords as the credentials store::
+
+    userpassdict = {'alice' : '4x5istwelve'}
+    get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(userpassdict)
+    digest_auth = {'tools.auth_digest.on': True,
+                   'tools.auth_digest.realm': 'wonderland',
+                   'tools.auth_digest.get_ha1': get_ha1,
+                   'tools.auth_digest.key': 'a565c27146791cfb',
+                   'tools.auth_digest.accept_charset': 'UTF-8',
+    }
+    app_config = { '/' : digest_auth }
+"""
+
+import time
+import functools
+from hashlib import md5
+from urllib.request import parse_http_list, parse_keqv_list
+
+import cherrypy
+from cherrypy._cpcompat import ntob, tonative
+
+
+__author__ = 'visteya'
+__date__ = 'April 2009'
+
+
+def md5_hex(s):
+    return md5(ntob(s, 'utf-8')).hexdigest()
+
+
+qop_auth = 'auth'
+qop_auth_int = 'auth-int'
+valid_qops = (qop_auth, qop_auth_int)
+
+valid_algorithms = ('MD5', 'MD5-sess')
+
+FALLBACK_CHARSET = 'ISO-8859-1'
+DEFAULT_CHARSET = 'UTF-8'
+
+
+def TRACE(msg):
+    cherrypy.log(msg, context='TOOLS.AUTH_DIGEST')
+
+# Three helper functions for users of the tool, providing three variants
+# of get_ha1() functions for three different kinds of credential stores.
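+#
+# For example (illustrative, reusing the userpassdict from the module
+# docstring above): get_ha1_dict_plain(userpassdict)('wonderland', 'alice')
+# would return md5_hex('alice:wonderland:4x5istwelve').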
+ + +def get_ha1_dict_plain(user_password_dict): + """Returns a get_ha1 function which obtains a plaintext password from a + dictionary of the form: {username : password}. + + If you want a simple dictionary-based authentication scheme, with plaintext + passwords, use get_ha1_dict_plain(my_userpass_dict) as the value for the + get_ha1 argument to digest_auth(). + """ + def get_ha1(realm, username): + password = user_password_dict.get(username) + if password: + return md5_hex('%s:%s:%s' % (username, realm, password)) + return None + + return get_ha1 + + +def get_ha1_dict(user_ha1_dict): + """Returns a get_ha1 function which obtains a HA1 password hash from a + dictionary of the form: {username : HA1}. + + If you want a dictionary-based authentication scheme, but with + pre-computed HA1 hashes instead of plain-text passwords, use + get_ha1_dict(my_userha1_dict) as the value for the get_ha1 + argument to digest_auth(). + """ + def get_ha1(realm, username): + return user_ha1_dict.get(username) + + return get_ha1 + + +def get_ha1_file_htdigest(filename): + """Returns a get_ha1 function which obtains a HA1 password hash from a + flat file with lines of the same format as that produced by the Apache + htdigest utility. For example, for realm 'wonderland', username 'alice', + and password '4x5istwelve', the htdigest line would be:: + + alice:wonderland:3238cdfe91a8b2ed8e39646921a02d4c + + If you want to use an Apache htdigest file as the credentials store, + then use get_ha1_file_htdigest(my_htdigest_file) as the value for the + get_ha1 argument to digest_auth(). It is recommended that the filename + argument be an absolute path, to avoid problems. + """ + def get_ha1(realm, username): + result = None + f = open(filename, 'r') + for line in f: + u, r, ha1 = line.rstrip().split(':') + if u == username and r == realm: + result = ha1 + break + f.close() + return result + + return get_ha1 + + +def synthesize_nonce(s, key, timestamp=None): + """Synthesize a nonce value which resists spoofing and can be checked + for staleness. Returns a string suitable as the value for 'nonce' in + the www-authenticate header. + + s + A string related to the resource, such as the hostname of the server. + + key + A secret string known only to the server. + + timestamp + An integer seconds-since-the-epoch timestamp + + """ + if timestamp is None: + timestamp = int(time.time()) + h = md5_hex('%s:%s:%s' % (timestamp, s, key)) + nonce = '%s:%s' % (timestamp, h) + return nonce + + +def H(s): + """The hash function H""" + return md5_hex(s) + + +def _try_decode_header(header, charset): + global FALLBACK_CHARSET + + for enc in (charset, FALLBACK_CHARSET): + try: + return tonative(ntob(tonative(header, 'latin1'), 'latin1'), enc) + except ValueError as ve: + last_err = ve + else: + raise last_err + + +class HttpDigestAuthorization(object): + """ + Parses a Digest Authorization header and performs + re-calculation of the digest. 
+ """ + + scheme = 'digest' + + def errmsg(self, s): + return 'Digest Authorization header: %s' % s + + @classmethod + def matches(cls, header): + scheme, _, _ = header.partition(' ') + return scheme.lower() == cls.scheme + + def __init__( + self, auth_header, http_method, + debug=False, accept_charset=DEFAULT_CHARSET[:], + ): + self.http_method = http_method + self.debug = debug + + if not self.matches(auth_header): + raise ValueError('Authorization scheme is not "Digest"') + + self.auth_header = _try_decode_header(auth_header, accept_charset) + + scheme, params = self.auth_header.split(' ', 1) + + # make a dict of the params + items = parse_http_list(params) + paramsd = parse_keqv_list(items) + + self.realm = paramsd.get('realm') + self.username = paramsd.get('username') + self.nonce = paramsd.get('nonce') + self.uri = paramsd.get('uri') + self.method = paramsd.get('method') + self.response = paramsd.get('response') # the response digest + self.algorithm = paramsd.get('algorithm', 'MD5').upper() + self.cnonce = paramsd.get('cnonce') + self.opaque = paramsd.get('opaque') + self.qop = paramsd.get('qop') # qop + self.nc = paramsd.get('nc') # nonce count + + # perform some correctness checks + if self.algorithm not in valid_algorithms: + raise ValueError( + self.errmsg("Unsupported value for algorithm: '%s'" % + self.algorithm)) + + has_reqd = ( + self.username and + self.realm and + self.nonce and + self.uri and + self.response + ) + if not has_reqd: + raise ValueError( + self.errmsg('Not all required parameters are present.')) + + if self.qop: + if self.qop not in valid_qops: + raise ValueError( + self.errmsg("Unsupported value for qop: '%s'" % self.qop)) + if not (self.cnonce and self.nc): + raise ValueError( + self.errmsg('If qop is sent then ' + 'cnonce and nc MUST be present')) + else: + if self.cnonce or self.nc: + raise ValueError( + self.errmsg('If qop is not sent, ' + 'neither cnonce nor nc can be present')) + + def __str__(self): + return 'authorization : %s' % self.auth_header + + def validate_nonce(self, s, key): + """Validate the nonce. + Returns True if nonce was generated by synthesize_nonce() and the + timestamp is not spoofed, else returns False. + + s + A string related to the resource, such as the hostname of + the server. + + key + A secret string known only to the server. + + Both s and key must be the same values which were used to synthesize + the nonce we are trying to validate. + """ + try: + timestamp, hashpart = self.nonce.split(':', 1) + s_timestamp, s_hashpart = synthesize_nonce( + s, key, timestamp).split(':', 1) + is_valid = s_hashpart == hashpart + if self.debug: + TRACE('validate_nonce: %s' % is_valid) + return is_valid + except ValueError: # split() error + pass + return False + + def is_nonce_stale(self, max_age_seconds=600): + """Returns True if a validated nonce is stale. The nonce contains a + timestamp in plaintext and also a secure hash of the timestamp. + You should first validate the nonce to ensure the plaintext + timestamp is not spoofed. + """ + try: + timestamp, hashpart = self.nonce.split(':', 1) + if int(timestamp) + max_age_seconds > int(time.time()): + return False + except ValueError: # int() error + pass + if self.debug: + TRACE('nonce is stale') + return True + + def HA2(self, entity_body=''): + """Returns the H(A2) string. 
See :rfc:`2617` section 3.2.2.3.""" + # RFC 2617 3.2.2.3 + # If the "qop" directive's value is "auth" or is unspecified, + # then A2 is: + # A2 = method ":" digest-uri-value + # + # If the "qop" value is "auth-int", then A2 is: + # A2 = method ":" digest-uri-value ":" H(entity-body) + if self.qop is None or self.qop == 'auth': + a2 = '%s:%s' % (self.http_method, self.uri) + elif self.qop == 'auth-int': + a2 = '%s:%s:%s' % (self.http_method, self.uri, H(entity_body)) + else: + # in theory, this should never happen, since I validate qop in + # __init__() + raise ValueError(self.errmsg('Unrecognized value for qop!')) + return H(a2) + + def request_digest(self, ha1, entity_body=''): + """Calculates the Request-Digest. See :rfc:`2617` section 3.2.2.1. + + ha1 + The HA1 string obtained from the credentials store. + + entity_body + If 'qop' is set to 'auth-int', then A2 includes a hash + of the "entity body". The entity body is the part of the + message which follows the HTTP headers. See :rfc:`2617` section + 4.3. This refers to the entity the user agent sent in the + request which has the Authorization header. Typically GET + requests don't have an entity, and POST requests do. + + """ + ha2 = self.HA2(entity_body) + # Request-Digest -- RFC 2617 3.2.2.1 + if self.qop: + req = '%s:%s:%s:%s:%s' % ( + self.nonce, self.nc, self.cnonce, self.qop, ha2) + else: + req = '%s:%s' % (self.nonce, ha2) + + # RFC 2617 3.2.2.2 + # + # If the "algorithm" directive's value is "MD5" or is unspecified, + # then A1 is: + # A1 = unq(username-value) ":" unq(realm-value) ":" passwd + # + # If the "algorithm" directive's value is "MD5-sess", then A1 is + # calculated only once - on the first request by the client following + # receipt of a WWW-Authenticate challenge from the server. + # A1 = H( unq(username-value) ":" unq(realm-value) ":" passwd ) + # ":" unq(nonce-value) ":" unq(cnonce-value) + if self.algorithm == 'MD5-sess': + ha1 = H('%s:%s:%s' % (ha1, self.nonce, self.cnonce)) + + digest = H('%s:%s' % (ha1, req)) + return digest + + +def _get_charset_declaration(charset): + global FALLBACK_CHARSET + charset = charset.upper() + return ( + (', charset="%s"' % charset) + if charset != FALLBACK_CHARSET + else '' + ) + + +def www_authenticate( + realm, key, algorithm='MD5', nonce=None, qop=qop_auth, + stale=False, accept_charset=DEFAULT_CHARSET[:], +): + """Constructs a WWW-Authenticate header for Digest authentication.""" + if qop not in valid_qops: + raise ValueError("Unsupported value for qop: '%s'" % qop) + if algorithm not in valid_algorithms: + raise ValueError("Unsupported value for algorithm: '%s'" % algorithm) + + HEADER_PATTERN = ( + 'Digest realm="%s", nonce="%s", algorithm="%s", qop="%s"%s%s' + ) + + if nonce is None: + nonce = synthesize_nonce(realm, key) + + stale_param = ', stale="true"' if stale else '' + + charset_declaration = _get_charset_declaration(accept_charset) + + return HEADER_PATTERN % ( + realm, nonce, algorithm, qop, stale_param, charset_declaration, + ) + + +def digest_auth(realm, get_ha1, key, debug=False, accept_charset='utf-8'): + """A CherryPy tool that hooks at before_handler to perform + HTTP Digest Access Authentication, as specified in :rfc:`2617`. + + If the request has an 'authorization' header with a 'Digest' scheme, + this tool authenticates the credentials supplied in that header. + If the request has no 'authorization' header, or if it does but the + scheme is not "Digest", or if authentication fails, the tool sends + a 401 response with a 'WWW-Authenticate' Digest header. 
+ + realm + A string containing the authentication realm. + + get_ha1 + A callable that looks up a username in a credentials store + and returns the HA1 string, which is defined in the RFC to be + MD5(username : realm : password). The function's signature is: + ``get_ha1(realm, username)`` + where username is obtained from the request's 'authorization' header. + If username is not found in the credentials store, get_ha1() returns + None. + + key + A secret string known only to the server, used in the synthesis + of nonces. + + """ + request = cherrypy.serving.request + + auth_header = request.headers.get('authorization') + + respond_401 = functools.partial( + _respond_401, realm, key, accept_charset, debug) + + if not HttpDigestAuthorization.matches(auth_header or ''): + respond_401() + + msg = 'The Authorization header could not be parsed.' + with cherrypy.HTTPError.handle(ValueError, 400, msg): + auth = HttpDigestAuthorization( + auth_header, request.method, + debug=debug, accept_charset=accept_charset, + ) + + if debug: + TRACE(str(auth)) + + if not auth.validate_nonce(realm, key): + respond_401() + + ha1 = get_ha1(realm, auth.username) + + if ha1 is None: + respond_401() + + # note that for request.body to be available we need to + # hook in at before_handler, not on_start_resource like + # 3.1.x digest_auth does. + digest = auth.request_digest(ha1, entity_body=request.body) + if digest != auth.response: + respond_401() + + # authenticated + if debug: + TRACE('digest matches auth.response') + # Now check if nonce is stale. + # The choice of ten minutes' lifetime for nonce is somewhat + # arbitrary + if auth.is_nonce_stale(max_age_seconds=600): + respond_401(stale=True) + + request.login = auth.username + if debug: + TRACE('authentication of %s successful' % auth.username) + + +def _respond_401(realm, key, accept_charset, debug, **kwargs): + """ + Respond with 401 status and a WWW-Authenticate header + """ + header = www_authenticate( + realm, key, + accept_charset=accept_charset, + **kwargs + ) + if debug: + TRACE(header) + cherrypy.serving.response.headers['WWW-Authenticate'] = header + raise cherrypy.HTTPError( + 401, 'You are not authorized to access that resource') diff --git a/resources/lib/cherrypy/lib/caching.py b/resources/lib/cherrypy/lib/caching.py new file mode 100644 index 0000000..08d2d8e --- /dev/null +++ b/resources/lib/cherrypy/lib/caching.py @@ -0,0 +1,478 @@ +""" +CherryPy implements a simple caching system as a pluggable Tool. This tool +tries to be an (in-process) HTTP/1.1-compliant cache. It's not quite there +yet, but it's probably good enough for most sites. + +In general, GET responses are cached (along with selecting headers) and, if +another request arrives for the same resource, the caching Tool will return 304 +Not Modified if possible, or serve the cached response otherwise. It also sets +request.cached to True if serving a cached representation, and sets +request.cacheable to False (so it doesn't get cached again). + +If POST, PUT, or DELETE requests are made for a cached resource, they +invalidate (delete) any cached response. + +Usage +===== + +Configuration file example:: + + [/] + tools.caching.on = True + tools.caching.delay = 3600 + +You may use a class other than the default +:class:`MemoryCache` by supplying the config +entry ``cache_class``; supply the full dotted name of the replacement class +as the config value. It must implement the basic methods ``get``, ``put``, +``delete``, and ``clear``. 
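+
+For example (illustrative; ``mypkg.DiskCache`` is a hypothetical stand-in
+for your own class)::
+
+    [/]
+    tools.caching.on = True
+    tools.caching.cache_class = mypkg.DiskCache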
+ +You may set any attribute, including overriding methods, on the cache +instance by providing them in config. The above sets the +:attr:`delay` attribute, for example. +""" + +import datetime +import sys +import threading +import time + +import cherrypy +from cherrypy.lib import cptools, httputil + + +class Cache(object): + + """Base class for Cache implementations.""" + + def get(self): + """Return the current variant if in the cache, else None.""" + raise NotImplementedError + + def put(self, obj, size): + """Store the current variant in the cache.""" + raise NotImplementedError + + def delete(self): + """Remove ALL cached variants of the current resource.""" + raise NotImplementedError + + def clear(self): + """Reset the cache to its initial, empty state.""" + raise NotImplementedError + + +# ------------------------------ Memory Cache ------------------------------- # +class AntiStampedeCache(dict): + + """A storage system for cached items which reduces stampede collisions.""" + + def wait(self, key, timeout=5, debug=False): + """Return the cached value for the given key, or None. + + If timeout is not None, and the value is already + being calculated by another thread, wait until the given timeout has + elapsed. If the value is available before the timeout expires, it is + returned. If not, None is returned, and a sentinel placed in the cache + to signal other threads to wait. + + If timeout is None, no waiting is performed nor sentinels used. + """ + value = self.get(key) + if isinstance(value, threading.Event): + if timeout is None: + # Ignore the other thread and recalc it ourselves. + if debug: + cherrypy.log('No timeout', 'TOOLS.CACHING') + return None + + # Wait until it's done or times out. + if debug: + cherrypy.log('Waiting up to %s seconds' % + timeout, 'TOOLS.CACHING') + value.wait(timeout) + if value.result is not None: + # The other thread finished its calculation. Use it. + if debug: + cherrypy.log('Result!', 'TOOLS.CACHING') + return value.result + # Timed out. Stick an Event in the slot so other threads wait + # on this one to finish calculating the value. + if debug: + cherrypy.log('Timed out', 'TOOLS.CACHING') + e = threading.Event() + e.result = None + dict.__setitem__(self, key, e) + + return None + elif value is None: + # Stick an Event in the slot so other threads wait + # on this one to finish calculating the value. + if debug: + cherrypy.log('Timed out', 'TOOLS.CACHING') + e = threading.Event() + e.result = None + dict.__setitem__(self, key, e) + return value + + def __setitem__(self, key, value): + """Set the cached value for the given key.""" + existing = self.get(key) + dict.__setitem__(self, key, value) + if isinstance(existing, threading.Event): + # Set Event.result so other threads waiting on it have + # immediate access without needing to poll the cache again. + existing.result = value + existing.set() + + +class MemoryCache(Cache): + + """An in-memory cache for varying response content. + + Each key in self.store is a URI, and each value is an AntiStampedeCache. + The response for any given URI may vary based on the values of + "selecting request headers"; that is, those named in the Vary + response header. We assume the list of header names to be constant + for each URI throughout the lifetime of the application, and store + that list in ``self.store[uri].selecting_headers``. 
+ + The items contained in ``self.store[uri]`` have keys which are tuples of + request header values (in the same order as the names in its + selecting_headers), and values which are the actual responses. + """ + + maxobjects = 1000 + """The maximum number of cached objects; defaults to 1000.""" + + maxobj_size = 100000 + """The maximum size of each cached object in bytes; defaults to 100 KB.""" + + maxsize = 10000000 + """The maximum size of the entire cache in bytes; defaults to 10 MB.""" + + delay = 600 + """Seconds until the cached content expires; defaults to 600 (10 minutes). + """ + + antistampede_timeout = 5 + """Seconds to wait for other threads to release a cache lock.""" + + expire_freq = 0.1 + """Seconds to sleep between cache expiration sweeps.""" + + debug = False + + def __init__(self): + self.clear() + + # Run self.expire_cache in a separate daemon thread. + t = threading.Thread(target=self.expire_cache, name='expire_cache') + self.expiration_thread = t + t.daemon = True + t.start() + + def clear(self): + """Reset the cache to its initial, empty state.""" + self.store = {} + self.expirations = {} + self.tot_puts = 0 + self.tot_gets = 0 + self.tot_hist = 0 + self.tot_expires = 0 + self.tot_non_modified = 0 + self.cursize = 0 + + def expire_cache(self): + """Continuously examine cached objects, expiring stale ones. + + This function is designed to be run in its own daemon thread, + referenced at ``self.expiration_thread``. + """ + # It's possible that "time" will be set to None + # arbitrarily, so we check "while time" to avoid exceptions. + # See tickets #99 and #180 for more information. + while time: + now = time.time() + # Must make a copy of expirations so it doesn't change size + # during iteration + for expiration_time, objects in self.expirations.copy().items(): + if expiration_time <= now: + for obj_size, uri, sel_header_values in objects: + try: + del self.store[uri][tuple(sel_header_values)] + self.tot_expires += 1 + self.cursize -= obj_size + except KeyError: + # the key may have been deleted elsewhere + pass + del self.expirations[expiration_time] + time.sleep(self.expire_freq) + + def get(self): + """Return the current variant if in the cache, else None.""" + request = cherrypy.serving.request + self.tot_gets += 1 + + uri = cherrypy.url(qs=request.query_string) + uricache = self.store.get(uri) + if uricache is None: + return None + + header_values = [request.headers.get(h, '') + for h in uricache.selecting_headers] + variant = uricache.wait(key=tuple(sorted(header_values)), + timeout=self.antistampede_timeout, + debug=self.debug) + if variant is not None: + self.tot_hist += 1 + return variant + + def put(self, variant, size): + """Store the current variant in the cache.""" + request = cherrypy.serving.request + response = cherrypy.serving.response + + uri = cherrypy.url(qs=request.query_string) + uricache = self.store.get(uri) + if uricache is None: + uricache = AntiStampedeCache() + uricache.selecting_headers = [ + e.value for e in response.headers.elements('Vary')] + self.store[uri] = uricache + + if len(self.store) < self.maxobjects: + total_size = self.cursize + size + + # checks if there's space for the object + if (size < self.maxobj_size and total_size < self.maxsize): + # add to the expirations list + expiration_time = response.time + self.delay + bucket = self.expirations.setdefault(expiration_time, []) + bucket.append((size, uri, uricache.selecting_headers)) + + # add to the cache + header_values = [request.headers.get(h, '') + for h in 
uricache.selecting_headers] + uricache[tuple(sorted(header_values))] = variant + self.tot_puts += 1 + self.cursize = total_size + + def delete(self): + """Remove ALL cached variants of the current resource.""" + uri = cherrypy.url(qs=cherrypy.serving.request.query_string) + self.store.pop(uri, None) + + +def get(invalid_methods=('POST', 'PUT', 'DELETE'), debug=False, **kwargs): + """Try to obtain cached output. If fresh enough, raise HTTPError(304). + + If POST, PUT, or DELETE: + * invalidates (deletes) any cached response for this resource + * sets request.cached = False + * sets request.cacheable = False + + else if a cached copy exists: + * sets request.cached = True + * sets request.cacheable = False + * sets response.headers to the cached values + * checks the cached Last-Modified response header against the + current If-(Un)Modified-Since request headers; raises 304 + if necessary. + * sets response.status and response.body to the cached values + * returns True + + otherwise: + * sets request.cached = False + * sets request.cacheable = True + * returns False + """ + request = cherrypy.serving.request + response = cherrypy.serving.response + + if not hasattr(cherrypy, '_cache'): + # Make a process-wide Cache object. + cherrypy._cache = kwargs.pop('cache_class', MemoryCache)() + + # Take all remaining kwargs and set them on the Cache object. + for k, v in kwargs.items(): + setattr(cherrypy._cache, k, v) + cherrypy._cache.debug = debug + + # POST, PUT, DELETE should invalidate (delete) the cached copy. + # See http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.10. + if request.method in invalid_methods: + if debug: + cherrypy.log('request.method %r in invalid_methods %r' % + (request.method, invalid_methods), 'TOOLS.CACHING') + cherrypy._cache.delete() + request.cached = False + request.cacheable = False + return False + + if 'no-cache' in [e.value for e in request.headers.elements('Pragma')]: + request.cached = False + request.cacheable = True + return False + + cache_data = cherrypy._cache.get() + request.cached = bool(cache_data) + request.cacheable = not request.cached + if request.cached: + # Serve the cached copy. + max_age = cherrypy._cache.delay + for v in [e.value for e in request.headers.elements('Cache-Control')]: + atoms = v.split('=', 1) + directive = atoms.pop(0) + if directive == 'max-age': + if len(atoms) != 1 or not atoms[0].isdigit(): + raise cherrypy.HTTPError( + 400, 'Invalid Cache-Control header') + max_age = int(atoms[0]) + break + elif directive == 'no-cache': + if debug: + cherrypy.log( + 'Ignoring cache due to Cache-Control: no-cache', + 'TOOLS.CACHING') + request.cached = False + request.cacheable = True + return False + + if debug: + cherrypy.log('Reading response from cache', 'TOOLS.CACHING') + s, h, b, create_time = cache_data + age = int(response.time - create_time) + if (age > max_age): + if debug: + cherrypy.log('Ignoring cache due to age > %d' % max_age, + 'TOOLS.CACHING') + request.cached = False + request.cacheable = True + return False + + # Copy the response headers. See + # https://github.com/cherrypy/cherrypy/issues/721. + response.headers = rh = httputil.HeaderMap() + for k in h: + dict.__setitem__(rh, k, dict.__getitem__(h, k)) + + # Add the required Age header + response.headers['Age'] = str(age) + + try: + # Note that validate_since depends on a Last-Modified header; + # this was put into the cached copy, and should have been + # resurrected just above (response.headers = cache_data[1]). 
+ cptools.validate_since() + except cherrypy.HTTPRedirect: + x = sys.exc_info()[1] + if x.status == 304: + cherrypy._cache.tot_non_modified += 1 + raise + + # serve it & get out from the request + response.status = s + response.body = b + else: + if debug: + cherrypy.log('request is not cached', 'TOOLS.CACHING') + return request.cached + + +def tee_output(): + """Tee response output to cache storage. Internal.""" + # Used by CachingTool by attaching to request.hooks + + request = cherrypy.serving.request + if 'no-store' in request.headers.values('Cache-Control'): + return + + def tee(body): + """Tee response.body into a list.""" + if ('no-cache' in response.headers.values('Pragma') or + 'no-store' in response.headers.values('Cache-Control')): + for chunk in body: + yield chunk + return + + output = [] + for chunk in body: + output.append(chunk) + yield chunk + + # Save the cache data, but only if the body isn't empty. + # e.g. a 304 Not Modified on a static file response will + # have an empty body. + # If the body is empty, delete the cache because it + # contains a stale Threading._Event object that will + # stall all consecutive requests until the _Event times + # out + body = b''.join(output) + if not body: + cherrypy._cache.delete() + else: + cherrypy._cache.put((response.status, response.headers or {}, + body, response.time), len(body)) + + response = cherrypy.serving.response + response.body = tee(response.body) + + +def expires(secs=0, force=False, debug=False): + """Tool for influencing cache mechanisms using the 'Expires' header. + + secs + Must be either an int or a datetime.timedelta, and indicates the + number of seconds between response.time and when the response should + expire. The 'Expires' header will be set to response.time + secs. + If secs is zero, the 'Expires' header is set one year in the past, and + the following "cache prevention" headers are also set: + + * Pragma: no-cache + * Cache-Control': no-cache, must-revalidate + + force + If False, the following headers are checked: + + * Etag + * Last-Modified + * Age + * Expires + + If any are already present, none of the above response headers are set. + + """ + + response = cherrypy.serving.response + headers = response.headers + + cacheable = False + if not force: + # some header names that indicate that the response can be cached + for indicator in ('Etag', 'Last-Modified', 'Age', 'Expires'): + if indicator in headers: + cacheable = True + break + + if not cacheable and not force: + if debug: + cherrypy.log('request is not cacheable', 'TOOLS.EXPIRES') + else: + if debug: + cherrypy.log('request is cacheable', 'TOOLS.EXPIRES') + if isinstance(secs, datetime.timedelta): + secs = (86400 * secs.days) + secs.seconds + + if secs == 0: + if force or ('Pragma' not in headers): + headers['Pragma'] = 'no-cache' + if cherrypy.serving.request.protocol >= (1, 1): + if force or 'Cache-Control' not in headers: + headers['Cache-Control'] = 'no-cache, must-revalidate' + # Set an explicit Expires date in the past. + expiry = httputil.HTTPDate(1169942400.0) + else: + expiry = httputil.HTTPDate(response.time + secs) + if force or 'Expires' not in headers: + headers['Expires'] = expiry diff --git a/resources/lib/cherrypy/lib/covercp.py b/resources/lib/cherrypy/lib/covercp.py new file mode 100644 index 0000000..3e21971 --- /dev/null +++ b/resources/lib/cherrypy/lib/covercp.py @@ -0,0 +1,390 @@ +"""Code-coverage tools for CherryPy. 
+
+To use this module, or the coverage tools in the test suite,
+you need to download 'coverage.py', either Gareth Rees' `original
+implementation <http://www.garethrees.org/2001/12/04/python-coverage/>`_
+or Ned Batchelder's `enhanced version:
+<http://www.nedbatchelder.com/code/modules/coverage.html>`_
+
+To turn on coverage tracing, use the following code::
+
+    cherrypy.engine.subscribe('start', covercp.start)
+
+DO NOT subscribe anything on the 'start_thread' channel, as previously
+recommended. Calling start once in the main thread should be sufficient
+to start coverage on all threads. Calling start again in each thread
+effectively clears any coverage data gathered up to that point.
+
+Run your code, then use the ``covercp.serve()`` function to browse the
+results in a web browser. If you run this module from the command line,
+it will call ``serve()`` for you.
+"""
+
+import re
+import sys
+import cgi
+import os
+import os.path
+import urllib.parse
+
+import cherrypy
+
+
+localFile = os.path.join(os.path.dirname(__file__), 'coverage.cache')
+
+the_coverage = None
+try:
+    from coverage import coverage
+    the_coverage = coverage(data_file=localFile)
+
+    def start():
+        the_coverage.start()
+except ImportError:
+    # Setting the_coverage to None will raise errors
+    # that need to be trapped downstream.
+    the_coverage = None
+
+    import warnings
+    warnings.warn(
+        'No code coverage will be performed; '
+        'coverage.py could not be imported.')
+
+    def start():
+        pass
+start.priority = 20
+
+TEMPLATE_MENU = """<html>
+<head>
+    <title>CherryPy Coverage Menu</title>
+</head>
+<body>
+<h2>CherryPy Coverage</h2>
""" + +TEMPLATE_FORM = """ +
+
+ + Show percentages +
+ Hide files over + %%
+ Exclude files matching
+ +
+ + +
+
""" + +TEMPLATE_FRAMESET = """ +CherryPy coverage data + + + + + +""" + +TEMPLATE_COVERAGE = """ + + Coverage for %(name)s + + + +

%(name)s

+

%(fullpath)s

+

Coverage: %(pc)s%%

""" + +TEMPLATE_LOC_COVERED = """ + %s  + %s +\n""" +TEMPLATE_LOC_NOT_COVERED = """ + %s  + %s +\n""" +TEMPLATE_LOC_EXCLUDED = """ + %s  + %s +\n""" + +TEMPLATE_ITEM = ( + "%s%s%s\n" +) + + +def _percent(statements, missing): + s = len(statements) + e = s - len(missing) + if s > 0: + return int(round(100.0 * e / s)) + return 0 + + +def _show_branch(root, base, path, pct=0, showpct=False, exclude='', + coverage=the_coverage): + + # Show the directory name and any of our children + dirs = [k for k, v in root.items() if v] + dirs.sort() + for name in dirs: + newpath = os.path.join(path, name) + + if newpath.lower().startswith(base): + relpath = newpath[len(base):] + yield '| ' * relpath.count(os.sep) + yield ( + "%s\n" % + (newpath, urllib.parse.quote_plus(exclude), name) + ) + + for chunk in _show_branch( + root[name], base, newpath, pct, showpct, + exclude, coverage=coverage + ): + yield chunk + + # Now list the files + if path.lower().startswith(base): + relpath = path[len(base):] + files = [k for k, v in root.items() if not v] + files.sort() + for name in files: + newpath = os.path.join(path, name) + + pc_str = '' + if showpct: + try: + _, statements, _, missing, _ = coverage.analysis2(newpath) + except Exception: + # Yes, we really want to pass on all errors. + pass + else: + pc = _percent(statements, missing) + pc_str = ('%3d%% ' % pc).replace(' ', ' ') + if pc < float(pct) or pc == -1: + pc_str = "%s" % pc_str + else: + pc_str = "%s" % pc_str + + yield TEMPLATE_ITEM % ('| ' * (relpath.count(os.sep) + 1), + pc_str, newpath, name) + + +def _skip_file(path, exclude): + if exclude: + return bool(re.search(exclude, path)) + + +def _graft(path, tree): + d = tree + + p = path + atoms = [] + while True: + p, tail = os.path.split(p) + if not tail: + break + atoms.append(tail) + atoms.append(p) + if p != '/': + atoms.append('/') + + atoms.reverse() + for node in atoms: + if node: + d = d.setdefault(node, {}) + + +def get_tree(base, exclude, coverage=the_coverage): + """Return covered module names as a nested dict.""" + tree = {} + runs = coverage.data.executed_files() + for path in runs: + if not _skip_file(path, exclude) and not os.path.isdir(path): + _graft(path, tree) + return tree + + +class CoverStats(object): + + def __init__(self, coverage, root=None): + self.coverage = coverage + if root is None: + # Guess initial depth. Files outside this path will not be + # reachable from the web interface. + root = os.path.dirname(cherrypy.__file__) + self.root = root + + @cherrypy.expose + def index(self): + return TEMPLATE_FRAMESET % self.root.lower() + + @cherrypy.expose + def menu(self, base='/', pct='50', showpct='', + exclude=r'python\d\.\d|test|tut\d|tutorial'): + + # The coverage module uses all-lower-case names. + base = base.lower().rstrip(os.sep) + + yield TEMPLATE_MENU + yield TEMPLATE_FORM % locals() + + # Start by showing links for parent paths + yield "
" + path = '' + atoms = base.split(os.sep) + atoms.pop() + for atom in atoms: + path += atom + os.sep + yield ("%s %s" + % (path, urllib.parse.quote_plus(exclude), atom, os.sep)) + yield '
' + + yield "
" + + # Then display the tree + tree = get_tree(base, exclude, self.coverage) + if not tree: + yield '

No modules covered.

' + else: + for chunk in _show_branch(tree, base, '/', pct, + showpct == 'checked', exclude, + coverage=self.coverage): + yield chunk + + yield '
' + yield '' + + def annotated_file(self, filename, statements, excluded, missing): + source = open(filename, 'r') + buffer = [] + for lineno, line in enumerate(source.readlines()): + lineno += 1 + line = line.strip('\n\r') + empty_the_buffer = True + if lineno in excluded: + template = TEMPLATE_LOC_EXCLUDED + elif lineno in missing: + template = TEMPLATE_LOC_NOT_COVERED + elif lineno in statements: + template = TEMPLATE_LOC_COVERED + else: + empty_the_buffer = False + buffer.append((lineno, line)) + if empty_the_buffer: + for lno, pastline in buffer: + yield template % (lno, cgi.escape(pastline)) + buffer = [] + yield template % (lineno, cgi.escape(line)) + + @cherrypy.expose + def report(self, name): + filename, statements, excluded, missing, _ = self.coverage.analysis2( + name) + pc = _percent(statements, missing) + yield TEMPLATE_COVERAGE % dict(name=os.path.basename(name), + fullpath=name, + pc=pc) + yield '\n' + for line in self.annotated_file(filename, statements, excluded, + missing): + yield line + yield '
' + yield '' + yield '' + + +def serve(path=localFile, port=8080, root=None): + if coverage is None: + raise ImportError('The coverage module could not be imported.') + from coverage import coverage + cov = coverage(data_file=path) + cov.load() + + cherrypy.config.update({'server.socket_port': int(port), + 'server.thread_pool': 10, + 'environment': 'production', + }) + cherrypy.quickstart(CoverStats(cov, root)) + + +if __name__ == '__main__': + serve(*tuple(sys.argv[1:])) diff --git a/resources/lib/cherrypy/lib/cpstats.py b/resources/lib/cherrypy/lib/cpstats.py new file mode 100644 index 0000000..111af06 --- /dev/null +++ b/resources/lib/cherrypy/lib/cpstats.py @@ -0,0 +1,694 @@ +"""CPStats, a package for collecting and reporting on program statistics. + +Overview +======== + +Statistics about program operation are an invaluable monitoring and debugging +tool. Unfortunately, the gathering and reporting of these critical values is +usually ad-hoc. This package aims to add a centralized place for gathering +statistical performance data, a structure for recording that data which +provides for extrapolation of that data into more useful information, +and a method of serving that data to both human investigators and +monitoring software. Let's examine each of those in more detail. + +Data Gathering +-------------- + +Just as Python's `logging` module provides a common importable for gathering +and sending messages, performance statistics would benefit from a similar +common mechanism, and one that does *not* require each package which wishes +to collect stats to import a third-party module. Therefore, we choose to +re-use the `logging` module by adding a `statistics` object to it. + +That `logging.statistics` object is a nested dict. It is not a custom class, +because that would: + + 1. require libraries and applications to import a third-party module in + order to participate + 2. inhibit innovation in extrapolation approaches and in reporting tools, and + 3. be slow. + +There are, however, some specifications regarding the structure of the dict.:: + + { + +----"SQLAlchemy": { + | "Inserts": 4389745, + | "Inserts per Second": + | lambda s: s["Inserts"] / (time() - s["Start"]), + | C +---"Table Statistics": { + | o | "widgets": {-----------+ + N | l | "Rows": 1.3M, | Record + a | l | "Inserts": 400, | + m | e | },---------------------+ + e | c | "froobles": { + s | t | "Rows": 7845, + p | i | "Inserts": 0, + a | o | }, + c | n +---}, + e | "Slow Queries": + | [{"Query": "SELECT * FROM widgets;", + | "Processing Time": 47.840923343, + | }, + | ], + +----}, + } + +The `logging.statistics` dict has four levels. The topmost level is nothing +more than a set of names to introduce modularity, usually along the lines of +package names. If the SQLAlchemy project wanted to participate, for example, +it might populate the item `logging.statistics['SQLAlchemy']`, whose value +would be a second-layer dict we call a "namespace". Namespaces help multiple +packages to avoid collisions over key names, and make reports easier to read, +to boot. The maintainers of SQLAlchemy should feel free to use more than one +namespace if needed (such as 'SQLAlchemy ORM'). Note that there are no case +or other syntax constraints on the namespace names; they should be chosen +to be maximally readable by humans (neither too short nor too long). + +Each namespace, then, is a dict of named statistical values, such as +'Requests/sec' or 'Uptime'. 
You should choose names which will look +good on a report: spaces and capitalization are just fine. + +In addition to scalars, values in a namespace MAY be a (third-layer) +dict, or a list, called a "collection". For example, the CherryPy +:class:`StatsTool` keeps track of what each request is doing (or has most +recently done) in a 'Requests' collection, where each key is a thread ID; each +value in the subdict MUST be a fourth dict (whew!) of statistical data about +each thread. We call each subdict in the collection a "record". Similarly, +the :class:`StatsTool` also keeps a list of slow queries, where each record +contains data about each slow query, in order. + +Values in a namespace or record may also be functions, which brings us to: + +Extrapolation +------------- + +The collection of statistical data needs to be fast, as close to unnoticeable +as possible to the host program. That requires us to minimize I/O, for example, +but in Python it also means we need to minimize function calls. So when you +are designing your namespace and record values, try to insert the most basic +scalar values you already have on hand. + +When it comes time to report on the gathered data, however, we usually have +much more freedom in what we can calculate. Therefore, whenever reporting +tools (like the provided :class:`StatsPage` CherryPy class) fetch the contents +of `logging.statistics` for reporting, they first call +`extrapolate_statistics` (passing the whole `statistics` dict as the only +argument). This makes a deep copy of the statistics dict so that the +reporting tool can both iterate over it and even change it without harming +the original. But it also expands any functions in the dict by calling them. +For example, you might have a 'Current Time' entry in the namespace with the +value "lambda scope: time.time()". The "scope" parameter is the current +namespace dict (or record, if we're currently expanding one of those +instead), allowing you access to existing static entries. If you're truly +evil, you can even modify more than one entry at a time. + +However, don't try to calculate an entry and then use its value in further +extrapolations; the order in which the functions are called is not guaranteed. +This can lead to a certain amount of duplicated work (or a redesign of your +schema), but that's better than complicating the spec. + +After the whole thing has been extrapolated, it's time for: + +Reporting +--------- + +The :class:`StatsPage` class grabs the `logging.statistics` dict, extrapolates +it all, and then transforms it to HTML for easy viewing. Each namespace gets +its own header and attribute table, plus an extra table for each collection. +This is NOT part of the statistics specification; other tools can format how +they like. + +You can control which columns are output and how they are formatted by updating +StatsPage.formatting, which is a dict that mirrors the keys and nesting of +`logging.statistics`. The difference is that, instead of data values, it has +formatting values. Use None for a given key to indicate to the StatsPage that a +given column should not be output. Use a string with formatting +(such as '%.3f') to interpolate the value(s), or use a callable (such as +lambda v: v.isoformat()) for more advanced formatting. Any entry which is not +mentioned in the formatting dict is output unchanged. 
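+
+For example, a sketch (the 'My Stuff' namespace and its keys are
+hypothetical; the format values mirror the rules above)::
+
+    import time
+    from cherrypy.lib import cpstats
+
+    page = cpstats.StatsPage()
+    page.formatting['My Stuff'] = {
+        # a callable: render the timestamp as a readable string
+        'Start Time': lambda v: time.strftime(
+            '%Y-%m-%d %H:%M:%S', time.gmtime(v)),
+        # a format string: three decimal places
+        'Events/Second': '%.3f',
+        # None: omit this column from the report entirely
+        'Scratch Value': None,
+    }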
+
+Monitoring
+----------
+
+Although the HTML output takes pains to assign unique id's to each <div>
+with statistical data, you're probably better off fetching /cpstats/data,
+which outputs the whole (extrapolated) `logging.statistics` dict in JSON
+format. That is probably easier to parse, and doesn't have any formatting
+controls, so you get the "original" data in a consistently-serialized
+format. Note: there's no treatment yet for datetime objects. Try
+time.time() instead for now if you can. Nagios will probably thank you.
+
+Turning Collection Off
+----------------------
+
+It is recommended each namespace have an "Enabled" item which, if False,
+stops collection (but not reporting) of statistical data. Applications
+SHOULD provide controls to pause and resume collection by setting these
+entries to False or True, if present.
+
+
+Usage
+=====
+
+To collect statistics on CherryPy applications::
+
+    from cherrypy.lib import cpstats
+    appconfig['/']['tools.cpstats.on'] = True
+
+To collect statistics on your own code::
+
+    import logging
+    # Initialize the repository
+    if not hasattr(logging, 'statistics'): logging.statistics = {}
+    # Initialize my namespace
+    mystats = logging.statistics.setdefault('My Stuff', {})
+    # Initialize my namespace's scalars and collections
+    mystats.update({
+        'Enabled': True,
+        'Start Time': time.time(),
+        'Important Events': 0,
+        'Events/Second': lambda s: (
+            (s['Important Events'] / (time.time() - s['Start Time']))),
+    })
+    ...
+    for event in events:
+        ...
+        # Collect stats
+        if mystats.get('Enabled', False):
+            mystats['Important Events'] += 1
+
+To report statistics::
+
+    root.cpstats = cpstats.StatsPage()
+
+To format statistics reports::
+
+    See 'Reporting', above.
+
+"""
+
+import logging
+import os
+import sys
+import threading
+import time
+
+import cherrypy
+from cherrypy._json import json
+
+# ------------------------------- Statistics -------------------------------- #
+
+if not hasattr(logging, 'statistics'):
+    logging.statistics = {}
+
+
+def extrapolate_statistics(scope):
+    """Return an extrapolated copy of the given scope."""
+    c = {}
+    for k, v in scope.copy().items():
+        if isinstance(v, dict):
+            v = extrapolate_statistics(v)
+        elif isinstance(v, (list, tuple)):
+            v = [extrapolate_statistics(record) for record in v]
+        elif hasattr(v, '__call__'):
+            v = v(scope)
+        c[k] = v
+    return c
+
+
+# -------------------- CherryPy Applications Statistics --------------------- #
+
+appstats = logging.statistics.setdefault('CherryPy Applications', {})
+appstats.update({
+    'Enabled': True,
+    'Bytes Read/Request': lambda s: (
+        s['Total Requests'] and
+        (s['Total Bytes Read'] / float(s['Total Requests'])) or
+        0.0
+    ),
+    'Bytes Read/Second': lambda s: s['Total Bytes Read'] / s['Uptime'](s),
+    'Bytes Written/Request': lambda s: (
+        s['Total Requests'] and
+        (s['Total Bytes Written'] / float(s['Total Requests'])) or
+        0.0
+    ),
+    'Bytes Written/Second': lambda s: (
+        s['Total Bytes Written'] / s['Uptime'](s)
+    ),
+    'Current Time': lambda s: time.time(),
+    'Current Requests': 0,
+    'Requests/Second': lambda s: float(s['Total Requests']) / s['Uptime'](s),
+    'Server Version': cherrypy.__version__,
+    'Start Time': time.time(),
+    'Total Bytes Read': 0,
+    'Total Bytes Written': 0,
+    'Total Requests': 0,
+    'Total Time': 0,
+    'Uptime': lambda s: time.time() - s['Start Time'],
+    'Requests': {},
+})
+
+
+def proc_time(s):
+    return time.time() - s['Start Time']
+
+
+class ByteCountWrapper(object):
+
+    """Wraps a file-like object, counting the number of bytes read."""
+
def __init__(self, rfile): + self.rfile = rfile + self.bytes_read = 0 + + def read(self, size=-1): + data = self.rfile.read(size) + self.bytes_read += len(data) + return data + + def readline(self, size=-1): + data = self.rfile.readline(size) + self.bytes_read += len(data) + return data + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def next(self): + data = self.rfile.next() + self.bytes_read += len(data) + return data + + +def average_uriset_time(s): + return s['Count'] and (s['Sum'] / s['Count']) or 0 + + +def _get_threading_ident(): + if sys.version_info >= (3, 3): + return threading.get_ident() + return threading._get_ident() + + +class StatsTool(cherrypy.Tool): + + """Record various information about the current request.""" + + def __init__(self): + cherrypy.Tool.__init__(self, 'on_end_request', self.record_stop) + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + if appstats.get('Enabled', False): + cherrypy.Tool._setup(self) + self.record_start() + + def record_start(self): + """Record the beginning of a request.""" + request = cherrypy.serving.request + if not hasattr(request.rfile, 'bytes_read'): + request.rfile = ByteCountWrapper(request.rfile) + request.body.fp = request.rfile + + r = request.remote + + appstats['Current Requests'] += 1 + appstats['Total Requests'] += 1 + appstats['Requests'][_get_threading_ident()] = { + 'Bytes Read': None, + 'Bytes Written': None, + # Use a lambda so the ip gets updated by tools.proxy later + 'Client': lambda s: '%s:%s' % (r.ip, r.port), + 'End Time': None, + 'Processing Time': proc_time, + 'Request-Line': request.request_line, + 'Response Status': None, + 'Start Time': time.time(), + } + + def record_stop( + self, uriset=None, slow_queries=1.0, slow_queries_count=100, + debug=False, **kwargs): + """Record the end of a request.""" + resp = cherrypy.serving.response + w = appstats['Requests'][_get_threading_ident()] + + r = cherrypy.request.rfile.bytes_read + w['Bytes Read'] = r + appstats['Total Bytes Read'] += r + + if resp.stream: + w['Bytes Written'] = 'chunked' + else: + cl = int(resp.headers.get('Content-Length', 0)) + w['Bytes Written'] = cl + appstats['Total Bytes Written'] += cl + + w['Response Status'] = \ + getattr(resp, 'output_status', resp.status).decode() + + w['End Time'] = time.time() + p = w['End Time'] - w['Start Time'] + w['Processing Time'] = p + appstats['Total Time'] += p + + appstats['Current Requests'] -= 1 + + if debug: + cherrypy.log('Stats recorded: %s' % repr(w), 'TOOLS.CPSTATS') + + if uriset: + rs = appstats.setdefault('URI Set Tracking', {}) + r = rs.setdefault(uriset, { + 'Min': None, 'Max': None, 'Count': 0, 'Sum': 0, + 'Avg': average_uriset_time}) + if r['Min'] is None or p < r['Min']: + r['Min'] = p + if r['Max'] is None or p > r['Max']: + r['Max'] = p + r['Count'] += 1 + r['Sum'] += p + + if slow_queries and p > slow_queries: + sq = appstats.setdefault('Slow Queries', []) + sq.append(w.copy()) + if len(sq) > slow_queries_count: + sq.pop(0) + + +cherrypy.tools.cpstats = StatsTool() + + +# ---------------------- CherryPy Statistics Reporting ---------------------- # + +thisdir = 
os.path.abspath(os.path.dirname(__file__)) + +missing = object() + + +def locale_date(v): + return time.strftime('%c', time.gmtime(v)) + + +def iso_format(v): + return time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(v)) + + +def pause_resume(ns): + def _pause_resume(enabled): + pause_disabled = '' + resume_disabled = '' + if enabled: + resume_disabled = 'disabled="disabled" ' + else: + pause_disabled = 'disabled="disabled" ' + return """ +
+            <form action="pause" method="POST" style="display:inline">
+            <input type="hidden" name="namespace" value="%s" />
+            <input type="submit" value="Pause" %s/>
+            </form>
+            <form action="resume" method="POST" style="display:inline">
+            <input type="hidden" name="namespace" value="%s" />
+            <input type="submit" value="Resume" %s/>
+            </form>
+ """ % (ns, pause_disabled, ns, resume_disabled) + return _pause_resume + + +class StatsPage(object): + + formatting = { + 'CherryPy Applications': { + 'Enabled': pause_resume('CherryPy Applications'), + 'Bytes Read/Request': '%.3f', + 'Bytes Read/Second': '%.3f', + 'Bytes Written/Request': '%.3f', + 'Bytes Written/Second': '%.3f', + 'Current Time': iso_format, + 'Requests/Second': '%.3f', + 'Start Time': iso_format, + 'Total Time': '%.3f', + 'Uptime': '%.3f', + 'Slow Queries': { + 'End Time': None, + 'Processing Time': '%.3f', + 'Start Time': iso_format, + }, + 'URI Set Tracking': { + 'Avg': '%.3f', + 'Max': '%.3f', + 'Min': '%.3f', + 'Sum': '%.3f', + }, + 'Requests': { + 'Bytes Read': '%s', + 'Bytes Written': '%s', + 'End Time': None, + 'Processing Time': '%.3f', + 'Start Time': None, + }, + }, + 'CherryPy WSGIServer': { + 'Enabled': pause_resume('CherryPy WSGIServer'), + 'Connections/second': '%.3f', + 'Start time': iso_format, + }, + } + + @cherrypy.expose + def index(self): + # Transform the raw data into pretty output for HTML + yield """ + + + Statistics + + + +""" + for title, scalars, collections in self.get_namespaces(): + yield """ +

%s

+ + + +""" % title + for i, (key, value) in enumerate(scalars): + colnum = i % 3 + if colnum == 0: + yield """ + """ + yield ( + """ + """ % + vars() + ) + if colnum == 2: + yield """ + """ + + if colnum == 0: + yield """ + + + """ + elif colnum == 1: + yield """ + + """ + yield """ + +
%(key)s%(value)s
""" + + for subtitle, headers, subrows in collections: + yield """ +

%s

+ + + """ % subtitle + for key in headers: + yield """ + """ % key + yield """ + + + """ + for subrow in subrows: + yield """ + """ + for value in subrow: + yield """ + """ % value + yield """ + """ + yield """ + +
%s
%s
""" + yield """ + + +""" + + def get_namespaces(self): + """Yield (title, scalars, collections) for each namespace.""" + s = extrapolate_statistics(logging.statistics) + for title, ns in sorted(s.items()): + scalars = [] + collections = [] + ns_fmt = self.formatting.get(title, {}) + for k, v in sorted(ns.items()): + fmt = ns_fmt.get(k, {}) + if isinstance(v, dict): + headers, subrows = self.get_dict_collection(v, fmt) + collections.append((k, ['ID'] + headers, subrows)) + elif isinstance(v, (list, tuple)): + headers, subrows = self.get_list_collection(v, fmt) + collections.append((k, headers, subrows)) + else: + format = ns_fmt.get(k, missing) + if format is None: + # Don't output this column. + continue + if hasattr(format, '__call__'): + v = format(v) + elif format is not missing: + v = format % v + scalars.append((k, v)) + yield title, scalars, collections + + def get_dict_collection(self, v, formatting): + """Return ([headers], [rows]) for the given collection.""" + # E.g., the 'Requests' dict. + headers = [] + vals = v.values() + for record in vals: + for k3 in record: + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if k3 not in headers: + headers.append(k3) + headers.sort() + + subrows = [] + for k2, record in sorted(v.items()): + subrow = [k2] + for k3 in headers: + v3 = record.get(k3, '') + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if hasattr(format, '__call__'): + v3 = format(v3) + elif format is not missing: + v3 = format % v3 + subrow.append(v3) + subrows.append(subrow) + + return headers, subrows + + def get_list_collection(self, v, formatting): + """Return ([headers], [subrows]) for the given collection.""" + # E.g., the 'Slow Queries' list. + headers = [] + for record in v: + for k3 in record: + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if k3 not in headers: + headers.append(k3) + headers.sort() + + subrows = [] + for record in v: + subrow = [] + for k3 in headers: + v3 = record.get(k3, '') + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. 
+ continue + if hasattr(format, '__call__'): + v3 = format(v3) + elif format is not missing: + v3 = format % v3 + subrow.append(v3) + subrows.append(subrow) + + return headers, subrows + + if json is not None: + @cherrypy.expose + def data(self): + s = extrapolate_statistics(logging.statistics) + cherrypy.response.headers['Content-Type'] = 'application/json' + return json.dumps(s, sort_keys=True, indent=4).encode('utf-8') + + @cherrypy.expose + def pause(self, namespace): + logging.statistics.get(namespace, {})['Enabled'] = False + raise cherrypy.HTTPRedirect('./') + pause.cp_config = {'tools.allow.on': True, + 'tools.allow.methods': ['POST']} + + @cherrypy.expose + def resume(self, namespace): + logging.statistics.get(namespace, {})['Enabled'] = True + raise cherrypy.HTTPRedirect('./') + resume.cp_config = {'tools.allow.on': True, + 'tools.allow.methods': ['POST']} diff --git a/resources/lib/cherrypy/lib/cptools.py b/resources/lib/cherrypy/lib/cptools.py new file mode 100644 index 0000000..613a899 --- /dev/null +++ b/resources/lib/cherrypy/lib/cptools.py @@ -0,0 +1,637 @@ +"""Functions for builtin CherryPy tools.""" + +import logging +import re +from hashlib import md5 +import urllib.parse + +import cherrypy +from cherrypy._cpcompat import text_or_bytes +from cherrypy.lib import httputil as _httputil +from cherrypy.lib import is_iterator + + +# Conditional HTTP request support # + +def validate_etags(autotags=False, debug=False): + """Validate the current ETag against If-Match, If-None-Match headers. + + If autotags is True, an ETag response-header value will be provided + from an MD5 hash of the response body (unless some other code has + already provided an ETag header). If False (the default), the ETag + will not be automatic. + + WARNING: the autotags feature is not designed for URL's which allow + methods other than GET. For example, if a POST to the same URL returns + no content, the automatic ETag will be incorrect, breaking a fundamental + use for entity tags in a possibly destructive fashion. Likewise, if you + raise 304 Not Modified, the response body will be empty, the ETag hash + will be incorrect, and your application will break. + See :rfc:`2616` Section 14.24. + """ + response = cherrypy.serving.response + + # Guard against being run twice. + if hasattr(response, 'ETag'): + return + + status, reason, msg = _httputil.valid_status(response.status) + + etag = response.headers.get('ETag') + + # Automatic ETag generation. See warning in docstring. + if etag: + if debug: + cherrypy.log('ETag already set: %s' % etag, 'TOOLS.ETAGS') + elif not autotags: + if debug: + cherrypy.log('Autotags off', 'TOOLS.ETAGS') + elif status != 200: + if debug: + cherrypy.log('Status not 200', 'TOOLS.ETAGS') + else: + etag = response.collapse_body() + etag = '"%s"' % md5(etag).hexdigest() + if debug: + cherrypy.log('Setting ETag: %s' % etag, 'TOOLS.ETAGS') + response.headers['ETag'] = etag + + response.ETag = etag + + # "If the request would, without the If-Match header field, result in + # anything other than a 2xx or 412 status, then the If-Match header + # MUST be ignored." 
+ if debug: + cherrypy.log('Status: %s' % status, 'TOOLS.ETAGS') + if status >= 200 and status <= 299: + request = cherrypy.serving.request + + conditions = request.headers.elements('If-Match') or [] + conditions = [str(x) for x in conditions] + if debug: + cherrypy.log('If-Match conditions: %s' % repr(conditions), + 'TOOLS.ETAGS') + if conditions and not (conditions == ['*'] or etag in conditions): + raise cherrypy.HTTPError(412, 'If-Match failed: ETag %r did ' + 'not match %r' % (etag, conditions)) + + conditions = request.headers.elements('If-None-Match') or [] + conditions = [str(x) for x in conditions] + if debug: + cherrypy.log('If-None-Match conditions: %s' % repr(conditions), + 'TOOLS.ETAGS') + if conditions == ['*'] or etag in conditions: + if debug: + cherrypy.log('request.method: %s' % + request.method, 'TOOLS.ETAGS') + if request.method in ('GET', 'HEAD'): + raise cherrypy.HTTPRedirect([], 304) + else: + raise cherrypy.HTTPError(412, 'If-None-Match failed: ETag %r ' + 'matched %r' % (etag, conditions)) + + +def validate_since(): + """Validate the current Last-Modified against If-Modified-Since headers. + + If no code has set the Last-Modified response header, then no validation + will be performed. + """ + response = cherrypy.serving.response + lastmod = response.headers.get('Last-Modified') + if lastmod: + status, reason, msg = _httputil.valid_status(response.status) + + request = cherrypy.serving.request + + since = request.headers.get('If-Unmodified-Since') + if since and since != lastmod: + if (status >= 200 and status <= 299) or status == 412: + raise cherrypy.HTTPError(412) + + since = request.headers.get('If-Modified-Since') + if since and since == lastmod: + if (status >= 200 and status <= 299) or status == 304: + if request.method in ('GET', 'HEAD'): + raise cherrypy.HTTPRedirect([], 304) + else: + raise cherrypy.HTTPError(412) + + +# Tool code # + +def allow(methods=None, debug=False): + """Raise 405 if request.method not in methods (default ['GET', 'HEAD']). + + The given methods are case-insensitive, and may be in any order. + If only one method is allowed, you may supply a single string; + if more than one, supply a list of strings. + + Regardless of whether the current method is allowed or not, this + also emits an 'Allow' response header, containing the given methods. + """ + if not isinstance(methods, (tuple, list)): + methods = [methods] + methods = [m.upper() for m in methods if m] + if not methods: + methods = ['GET', 'HEAD'] + elif 'GET' in methods and 'HEAD' not in methods: + methods.append('HEAD') + + cherrypy.response.headers['Allow'] = ', '.join(methods) + if cherrypy.request.method not in methods: + if debug: + cherrypy.log('request.method %r not in methods %r' % + (cherrypy.request.method, methods), 'TOOLS.ALLOW') + raise cherrypy.HTTPError(405) + else: + if debug: + cherrypy.log('request.method %r in methods %r' % + (cherrypy.request.method, methods), 'TOOLS.ALLOW') + + +def proxy(base=None, local='X-Forwarded-Host', remote='X-Forwarded-For', + scheme='X-Forwarded-Proto', debug=False): + """Change the base URL (scheme://host[:port][/path]). + + For running a CP server behind Apache, lighttpd, or other HTTP server. + + For Apache and lighttpd, you should leave the 'local' argument at the + default value of 'X-Forwarded-Host'. For Squid, you probably want to set + tools.proxy.local = 'Origin'. 
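+
+    For example, a minimal sketch of a config for an app behind a proxy
+    that sends the default X-Forwarded-* headers (those header names are
+    the defaults, so only the tool switch is needed)::
+
+        [/]
+        tools.proxy.on = True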
+ + If you want the new request.base to include path info (not just the host), + you must explicitly set base to the full base path, and ALSO set 'local' + to '', so that the X-Forwarded-Host request header (which never includes + path info) does not override it. Regardless, the value for 'base' MUST + NOT end in a slash. + + cherrypy.request.remote.ip (the IP address of the client) will be + rewritten if the header specified by the 'remote' arg is valid. + By default, 'remote' is set to 'X-Forwarded-For'. If you do not + want to rewrite remote.ip, set the 'remote' arg to an empty string. + """ + + request = cherrypy.serving.request + + if scheme: + s = request.headers.get(scheme, None) + if debug: + cherrypy.log('Testing scheme %r:%r' % (scheme, s), 'TOOLS.PROXY') + if s == 'on' and 'ssl' in scheme.lower(): + # This handles e.g. webfaction's 'X-Forwarded-Ssl: on' header + scheme = 'https' + else: + # This is for lighttpd/pound/Mongrel's 'X-Forwarded-Proto: https' + scheme = s + if not scheme: + scheme = request.base[:request.base.find('://')] + + if local: + lbase = request.headers.get(local, None) + if debug: + cherrypy.log('Testing local %r:%r' % (local, lbase), 'TOOLS.PROXY') + if lbase is not None: + base = lbase.split(',')[0] + if not base: + default = urllib.parse.urlparse(request.base).netloc + base = request.headers.get('Host', default) + + if base.find('://') == -1: + # add http:// or https:// if needed + base = scheme + '://' + base + + request.base = base + + if remote: + xff = request.headers.get(remote) + if debug: + cherrypy.log('Testing remote %r:%r' % (remote, xff), 'TOOLS.PROXY') + if xff: + if remote == 'X-Forwarded-For': + # Grab the first IP in a comma-separated list. Ref #1268. + xff = next(ip.strip() for ip in xff.split(',')) + request.remote.ip = xff + + +def ignore_headers(headers=('Range',), debug=False): + """Delete request headers whose field names are included in 'headers'. + + This is a useful tool for working behind certain HTTP servers; + for example, Apache duplicates the work that CP does for 'Range' + headers, and will doubly-truncate the response. + """ + request = cherrypy.serving.request + for name in headers: + if name in request.headers: + if debug: + cherrypy.log('Ignoring request header %r' % name, + 'TOOLS.IGNORE_HEADERS') + del request.headers[name] + + +def response_headers(headers=None, debug=False): + """Set headers on the response.""" + if debug: + cherrypy.log('Setting response headers: %s' % repr(headers), + 'TOOLS.RESPONSE_HEADERS') + for name, value in (headers or []): + cherrypy.serving.response.headers[name] = value + + +response_headers.failsafe = True + + +def referer(pattern, accept=True, accept_missing=False, error=403, + message='Forbidden Referer header.', debug=False): + """Raise HTTPError if Referer header does/does not match the given pattern. + + pattern + A regular expression pattern to test against the Referer. + + accept + If True, the Referer must match the pattern; if False, + the Referer must NOT match the pattern. + + accept_missing + If True, permit requests with no Referer header. + + error + The HTTP error code to return to the client on failure. + + message + A string to include in the response body on failure. 
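+
+    For example, a sketch that only accepts requests referred from pages
+    on example.com (the domain and pattern are illustrative)::
+
+        [/images]
+        tools.referer.on = True
+        tools.referer.pattern = r'http://[^/]*\.?example\.com'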
+ + """ + try: + ref = cherrypy.serving.request.headers['Referer'] + match = bool(re.match(pattern, ref)) + if debug: + cherrypy.log('Referer %r matches %r' % (ref, pattern), + 'TOOLS.REFERER') + if accept == match: + return + except KeyError: + if debug: + cherrypy.log('No Referer header', 'TOOLS.REFERER') + if accept_missing: + return + + raise cherrypy.HTTPError(error, message) + + +class SessionAuth(object): + + """Assert that the user is logged in.""" + + session_key = 'username' + debug = False + + def check_username_and_password(self, username, password): + pass + + def anonymous(self): + """Provide a temporary user name for anonymous users.""" + pass + + def on_login(self, username): + pass + + def on_logout(self, username): + pass + + def on_check(self, username): + pass + + def login_screen(self, from_page='..', username='', error_msg='', + **kwargs): + return (str(""" +Message: %(error_msg)s +
+ Login: +
+ Password: +
+ +
+ +
+""") % vars()).encode('utf-8') + + def do_login(self, username, password, from_page='..', **kwargs): + """Login. May raise redirect, or return True if request handled.""" + response = cherrypy.serving.response + error_msg = self.check_username_and_password(username, password) + if error_msg: + body = self.login_screen(from_page, username, error_msg) + response.body = body + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers['Content-Length'] + return True + else: + cherrypy.serving.request.login = username + cherrypy.session[self.session_key] = username + self.on_login(username) + raise cherrypy.HTTPRedirect(from_page or '/') + + def do_logout(self, from_page='..', **kwargs): + """Logout. May raise redirect, or return True if request handled.""" + sess = cherrypy.session + username = sess.get(self.session_key) + sess[self.session_key] = None + if username: + cherrypy.serving.request.login = None + self.on_logout(username) + raise cherrypy.HTTPRedirect(from_page) + + def do_check(self): + """Assert username. Raise redirect, or return True if request handled. + """ + sess = cherrypy.session + request = cherrypy.serving.request + response = cherrypy.serving.response + + username = sess.get(self.session_key) + if not username: + sess[self.session_key] = username = self.anonymous() + self._debug_message('No session[username], trying anonymous') + if not username: + url = cherrypy.url(qs=request.query_string) + self._debug_message( + 'No username, routing to login_screen with from_page %(url)r', + locals(), + ) + response.body = self.login_screen(url) + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers['Content-Length'] + return True + self._debug_message('Setting request.login to %(username)r', locals()) + request.login = username + self.on_check(username) + + def _debug_message(self, template, context={}): + if not self.debug: + return + cherrypy.log(template % context, 'TOOLS.SESSAUTH') + + def run(self): + request = cherrypy.serving.request + response = cherrypy.serving.response + + path = request.path_info + if path.endswith('login_screen'): + self._debug_message('routing %(path)r to login_screen', locals()) + response.body = self.login_screen() + return True + elif path.endswith('do_login'): + if request.method != 'POST': + response.headers['Allow'] = 'POST' + self._debug_message('do_login requires POST') + raise cherrypy.HTTPError(405) + self._debug_message('routing %(path)r to do_login', locals()) + return self.do_login(**request.params) + elif path.endswith('do_logout'): + if request.method != 'POST': + response.headers['Allow'] = 'POST' + raise cherrypy.HTTPError(405) + self._debug_message('routing %(path)r to do_logout', locals()) + return self.do_logout(**request.params) + else: + self._debug_message('No special path, running do_check') + return self.do_check() + + +def session_auth(**kwargs): + """Session authentication hook. 
+ + Any attribute of the SessionAuth class may be overridden + via a keyword arg to this function: + + """ + '\n '.join( + '{!s}: {!s}'.format(k, type(getattr(SessionAuth, k)).__name__) + for k in dir(SessionAuth) + if not k.startswith('__') + ) + sa = SessionAuth() + for k, v in kwargs.items(): + setattr(sa, k, v) + return sa.run() + + +def log_traceback(severity=logging.ERROR, debug=False): + """Write the last error's traceback to the cherrypy error log.""" + cherrypy.log('', 'HTTP', severity=severity, traceback=True) + + +def log_request_headers(debug=False): + """Write request headers to the cherrypy error log.""" + h = [' %s: %s' % (k, v) for k, v in cherrypy.serving.request.header_list] + cherrypy.log('\nRequest Headers:\n' + '\n'.join(h), 'HTTP') + + +def log_hooks(debug=False): + """Write request.hooks to the cherrypy error log.""" + request = cherrypy.serving.request + + msg = [] + # Sort by the standard points if possible. + from cherrypy import _cprequest + points = _cprequest.hookpoints + for k in request.hooks.keys(): + if k not in points: + points.append(k) + + for k in points: + msg.append(' %s:' % k) + v = request.hooks.get(k, []) + v.sort() + for h in v: + msg.append(' %r' % h) + cherrypy.log('\nRequest Hooks for ' + cherrypy.url() + + ':\n' + '\n'.join(msg), 'HTTP') + + +def redirect(url='', internal=True, debug=False): + """Raise InternalRedirect or HTTPRedirect to the given url.""" + if debug: + cherrypy.log('Redirecting %sto: %s' % + ({True: 'internal ', False: ''}[internal], url), + 'TOOLS.REDIRECT') + if internal: + raise cherrypy.InternalRedirect(url) + else: + raise cherrypy.HTTPRedirect(url) + + +def trailing_slash(missing=True, extra=False, status=None, debug=False): + """Redirect if path_info has (missing|extra) trailing slash.""" + request = cherrypy.serving.request + pi = request.path_info + + if debug: + cherrypy.log('is_index: %r, missing: %r, extra: %r, path_info: %r' % + (request.is_index, missing, extra, pi), + 'TOOLS.TRAILING_SLASH') + if request.is_index is True: + if missing: + if not pi.endswith('/'): + new_url = cherrypy.url(pi + '/', request.query_string) + raise cherrypy.HTTPRedirect(new_url, status=status or 301) + elif request.is_index is False: + if extra: + # If pi == '/', don't redirect to ''! + if pi.endswith('/') and pi != '/': + new_url = cherrypy.url(pi[:-1], request.query_string) + raise cherrypy.HTTPRedirect(new_url, status=status or 301) + + +def flatten(debug=False): + """Wrap response.body in a generator that recursively iterates over body. + + This allows cherrypy.response.body to consist of 'nested generators'; + that is, a set of generators that yield generators. + """ + def flattener(input): + numchunks = 0 + for x in input: + if not is_iterator(x): + numchunks += 1 + yield x + else: + for y in flattener(x): + numchunks += 1 + yield y + if debug: + cherrypy.log('Flattened %d chunks' % numchunks, 'TOOLS.FLATTEN') + response = cherrypy.serving.response + response.body = flattener(response.body) + + +def accept(media=None, debug=False): + """Return the client's preferred media-type (from the given Content-Types). + + If 'media' is None (the default), no test will be performed. + + If 'media' is provided, it should be the Content-Type value (as a string) + or values (as a list or tuple of strings) which the current resource + can emit. The client's acceptable media ranges (as declared in the + Accept request header) will be matched in order to these Content-Type + values; the first such string is returned. 
That is, the return value + will always be one of the strings provided in the 'media' arg (or None + if 'media' is None). + + If no match is found, then HTTPError 406 (Not Acceptable) is raised. + Note that most web browsers send */* as a (low-quality) acceptable + media range, which should match any Content-Type. In addition, "...if + no Accept header field is present, then it is assumed that the client + accepts all media types." + + Matching types are checked in order of client preference first, + and then in the order of the given 'media' values. + + Note that this function does not honor accept-params (other than "q"). + """ + if not media: + return + if isinstance(media, text_or_bytes): + media = [media] + request = cherrypy.serving.request + + # Parse the Accept request header, and try to match one + # of the requested media-ranges (in order of preference). + ranges = request.headers.elements('Accept') + if not ranges: + # Any media type is acceptable. + if debug: + cherrypy.log('No Accept header elements', 'TOOLS.ACCEPT') + return media[0] + else: + # Note that 'ranges' is sorted in order of preference + for element in ranges: + if element.qvalue > 0: + if element.value == '*/*': + # Matches any type or subtype + if debug: + cherrypy.log('Match due to */*', 'TOOLS.ACCEPT') + return media[0] + elif element.value.endswith('/*'): + # Matches any subtype + mtype = element.value[:-1] # Keep the slash + for m in media: + if m.startswith(mtype): + if debug: + cherrypy.log('Match due to %s' % element.value, + 'TOOLS.ACCEPT') + return m + else: + # Matches exact value + if element.value in media: + if debug: + cherrypy.log('Match due to %s' % element.value, + 'TOOLS.ACCEPT') + return element.value + + # No suitable media-range found. + ah = request.headers.get('Accept') + if ah is None: + msg = 'Your client did not send an Accept header.' + else: + msg = 'Your client sent this Accept header: %s.' % ah + msg += (' But this resource only emits these media types: %s.' % + ', '.join(media)) + raise cherrypy.HTTPError(406, msg) + + +class MonitoredHeaderMap(_httputil.HeaderMap): + + def transform_key(self, key): + self.accessed_headers.add(key) + return super(MonitoredHeaderMap, self).transform_key(key) + + def __init__(self): + self.accessed_headers = set() + super(MonitoredHeaderMap, self).__init__() + + +def autovary(ignore=None, debug=False): + """Auto-populate the Vary response header based on request.header access. + """ + request = cherrypy.serving.request + + req_h = request.headers + request.headers = MonitoredHeaderMap() + request.headers.update(req_h) + if ignore is None: + ignore = set(['Content-Disposition', 'Content-Length', 'Content-Type']) + + def set_response_header(): + resp_h = cherrypy.serving.response.headers + v = set([e.value for e in resp_h.elements('Vary')]) + if debug: + cherrypy.log( + 'Accessed headers: %s' % request.headers.accessed_headers, + 'TOOLS.AUTOVARY') + v = v.union(request.headers.accessed_headers) + v = v.difference(ignore) + v = list(v) + v.sort() + resp_h['Vary'] = ', '.join(v) + request.hooks.attach('before_finalize', set_response_header, 95) + + +def convert_params(exception=ValueError, error=400): + """Convert request params based on function annotations, with error handling. + + exception + Exception class to catch. + + status + The HTTP error code to return to the client on failure. 
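+
+    For example, a sketch of a page handler whose annotations drive the
+    conversion once this function is hooked in as a tool (the handler and
+    parameter names are illustrative)::
+
+        @cherrypy.expose
+        def resize(self, width: int, height: int):
+            ...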
+ """ + request = cherrypy.serving.request + types = request.handler.callable.__annotations__ + with cherrypy.HTTPError.handle(exception, error): + for key in set(types).intersection(request.params): + request.params[key] = types[key](request.params[key]) diff --git a/resources/lib/cherrypy/lib/encoding.py b/resources/lib/cherrypy/lib/encoding.py new file mode 100644 index 0000000..54a7a8a --- /dev/null +++ b/resources/lib/cherrypy/lib/encoding.py @@ -0,0 +1,434 @@ +import struct +import time +import io + +import cherrypy +from cherrypy._cpcompat import text_or_bytes +from cherrypy.lib import file_generator +from cherrypy.lib import is_closable_iterator +from cherrypy.lib import set_vary_header + + +def decode(encoding=None, default_encoding='utf-8'): + """Replace or extend the list of charsets used to decode a request entity. + + Either argument may be a single string or a list of strings. + + encoding + If not None, restricts the set of charsets attempted while decoding + a request entity to the given set (even if a different charset is + given in the Content-Type request header). + + default_encoding + Only in effect if the 'encoding' argument is not given. + If given, the set of charsets attempted while decoding a request + entity is *extended* with the given value(s). + + """ + body = cherrypy.request.body + if encoding is not None: + if not isinstance(encoding, list): + encoding = [encoding] + body.attempt_charsets = encoding + elif default_encoding: + if not isinstance(default_encoding, list): + default_encoding = [default_encoding] + body.attempt_charsets = body.attempt_charsets + default_encoding + + +class UTF8StreamEncoder: + def __init__(self, iterator): + self._iterator = iterator + + def __iter__(self): + return self + + def next(self): + return self.__next__() + + def __next__(self): + res = next(self._iterator) + if isinstance(res, str): + res = res.encode('utf-8') + return res + + def close(self): + if is_closable_iterator(self._iterator): + self._iterator.close() + + def __getattr__(self, attr): + if attr.startswith('__'): + raise AttributeError(self, attr) + return getattr(self._iterator, attr) + + +class ResponseEncoder: + + default_encoding = 'utf-8' + failmsg = 'Response body could not be encoded with %r.' + encoding = None + errors = 'strict' + text_only = True + add_charset = True + debug = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + self.attempted_charsets = set() + request = cherrypy.serving.request + if request.handler is not None: + # Replace request.handler with self + if self.debug: + cherrypy.log('Replacing request.handler', 'TOOLS.ENCODE') + self.oldhandler = request.handler + request.handler = self + + def encode_stream(self, encoding): + """Encode a streaming response body. + + Use a generator wrapper, and just pray it works as the stream is + being written out. 
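For reference, the `decode` tool added at the top of this file is normally switched on through config. A sketch assuming the standard CherryPy tool wiring (values illustrative):

```python
config = {
    '/': {
        'tools.decode.on': True,
        # Attempt UTF-8 only, regardless of the request Content-Type charset.
        'tools.decode.encoding': 'utf-8',
    }
}
# e.g. cherrypy.tree.mount(app, '/', config)
```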
+ """ + if encoding in self.attempted_charsets: + return False + self.attempted_charsets.add(encoding) + + def encoder(body): + for chunk in body: + if isinstance(chunk, str): + chunk = chunk.encode(encoding, self.errors) + yield chunk + self.body = encoder(self.body) + return True + + def encode_string(self, encoding): + """Encode a buffered response body.""" + if encoding in self.attempted_charsets: + return False + self.attempted_charsets.add(encoding) + body = [] + for chunk in self.body: + if isinstance(chunk, str): + try: + chunk = chunk.encode(encoding, self.errors) + except (LookupError, UnicodeError): + return False + body.append(chunk) + self.body = body + return True + + def find_acceptable_charset(self): + request = cherrypy.serving.request + response = cherrypy.serving.response + + if self.debug: + cherrypy.log('response.stream %r' % + response.stream, 'TOOLS.ENCODE') + if response.stream: + encoder = self.encode_stream + else: + encoder = self.encode_string + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + # Encoded strings may be of different lengths from their + # unicode equivalents, and even from each other. For example: + # >>> t = u"\u7007\u3040" + # >>> len(t) + # 2 + # >>> len(t.encode("UTF-8")) + # 6 + # >>> len(t.encode("utf7")) + # 8 + del response.headers['Content-Length'] + + # Parse the Accept-Charset request header, and try to provide one + # of the requested charsets (in order of user preference). + encs = request.headers.elements('Accept-Charset') + charsets = [enc.value.lower() for enc in encs] + if self.debug: + cherrypy.log('charsets %s' % repr(charsets), 'TOOLS.ENCODE') + + if self.encoding is not None: + # If specified, force this encoding to be used, or fail. + encoding = self.encoding.lower() + if self.debug: + cherrypy.log('Specified encoding %r' % + encoding, 'TOOLS.ENCODE') + if (not charsets) or '*' in charsets or encoding in charsets: + if self.debug: + cherrypy.log('Attempting encoding %r' % + encoding, 'TOOLS.ENCODE') + if encoder(encoding): + return encoding + else: + if not encs: + if self.debug: + cherrypy.log('Attempting default encoding %r' % + self.default_encoding, 'TOOLS.ENCODE') + # Any character-set is acceptable. + if encoder(self.default_encoding): + return self.default_encoding + else: + raise cherrypy.HTTPError(500, self.failmsg % + self.default_encoding) + else: + for element in encs: + if element.qvalue > 0: + if element.value == '*': + # Matches any charset. Try our default. + if self.debug: + cherrypy.log('Attempting default encoding due ' + 'to %r' % element, 'TOOLS.ENCODE') + if encoder(self.default_encoding): + return self.default_encoding + else: + encoding = element.value + if self.debug: + cherrypy.log('Attempting encoding %s (qvalue >' + '0)' % element, 'TOOLS.ENCODE') + if encoder(encoding): + return encoding + + if '*' not in charsets: + # If no "*" is present in an Accept-Charset field, then all + # character sets not explicitly mentioned get a quality + # value of 0, except for ISO-8859-1, which gets a quality + # value of 1 if not explicitly mentioned. + iso = 'iso-8859-1' + if iso not in charsets: + if self.debug: + cherrypy.log('Attempting ISO-8859-1 encoding', + 'TOOLS.ENCODE') + if encoder(iso): + return iso + + # No suitable encoding found. + ac = request.headers.get('Accept-Charset') + if ac is None: + msg = 'Your client did not send an Accept-Charset header.' + else: + msg = 'Your client sent this Accept-Charset header: %s.' 
% ac
+        _charsets = ', '.join(sorted(self.attempted_charsets))
+        msg += ' We tried these charsets: %s.' % (_charsets,)
+        raise cherrypy.HTTPError(406, msg)
+
+    def __call__(self, *args, **kwargs):
+        response = cherrypy.serving.response
+        self.body = self.oldhandler(*args, **kwargs)
+
+        self.body = prepare_iter(self.body)
+
+        ct = response.headers.elements('Content-Type')
+        if self.debug:
+            cherrypy.log('Content-Type: %r' % [str(h)
+                                               for h in ct], 'TOOLS.ENCODE')
+        if ct and self.add_charset:
+            ct = ct[0]
+            if self.text_only:
+                if ct.value.lower().startswith('text/'):
+                    if self.debug:
+                        cherrypy.log(
+                            'Content-Type %s starts with "text/"' % ct,
+                            'TOOLS.ENCODE')
+                    do_find = True
+                else:
+                    if self.debug:
+                        cherrypy.log('Not finding because Content-Type %s '
+                                     'does not start with "text/"' % ct,
+                                     'TOOLS.ENCODE')
+                    do_find = False
+            else:
+                if self.debug:
+                    cherrypy.log('Finding because not text_only',
+                                 'TOOLS.ENCODE')
+                do_find = True
+
+            if do_find:
+                # Set "charset=..." param on response Content-Type header
+                ct.params['charset'] = self.find_acceptable_charset()
+                if self.debug:
+                    cherrypy.log('Setting Content-Type %s' % ct,
+                                 'TOOLS.ENCODE')
+                response.headers['Content-Type'] = str(ct)
+
+        return self.body
+
+
+def prepare_iter(value):
+    """
+    Ensure response body is iterable and resolves to False when empty.
+    """
+    if isinstance(value, text_or_bytes):
+        # strings get wrapped in a list because iterating over a single
+        # item list is much faster than iterating over every character
+        # in a long string.
+        if value:
+            value = [value]
+        else:
+            # [''] doesn't evaluate to False, so replace it with [].
+            value = []
+    # Don't use isinstance here; io.IOBase which has an ABC takes
+    # 1000 times as long as, say, isinstance(value, str)
+    elif hasattr(value, 'read'):
+        value = file_generator(value)
+    elif value is None:
+        value = []
+    return value
+
+
+# GZIP
+
+
+def compress(body, compress_level):
+    """Compress 'body' at the given compress_level."""
+    import zlib
+
+    # See http://www.gzip.org/zlib/rfc-gzip.html
+    yield b'\x1f\x8b'  # ID1 and ID2: gzip marker
+    yield b'\x08'  # CM: compression method
+    yield b'\x00'  # FLG: none set
+    # MTIME: 4 bytes
+    yield struct.pack('<L', int(time.time()) & int('FFFFFFFF', 16))
+    yield b'\x02'  # XFL: max compression, slowest algo
+    yield b'\xff'  # OS: unknown
+
+    crc = zlib.crc32(b'')
+    size = 0
+    zobj = zlib.compressobj(compress_level,
+                            zlib.DEFLATED, -zlib.MAX_WBITS,
+                            zlib.DEF_MEM_LEVEL, 0)
+    for line in body:
+        size += len(line)
+        crc = zlib.crc32(line, crc)
+        yield zobj.compress(line)
+    yield zobj.flush()
+
+    # CRC-32: 4 bytes
+    yield struct.pack('<L', crc & int('FFFFFFFF', 16))
+    # ISIZE: 4 bytes
+    yield struct.pack('<L', size & int('FFFFFFFF', 16))
+
+
+def decompress(body):
+    import gzip
+
+    zbuf = io.BytesIO()
+    zbuf.write(body)
+    zbuf.seek(0)
+    zfile = gzip.GzipFile(mode='rb', fileobj=zbuf)
+    data = zfile.read()
+    zfile.close()
+    return data
+
+
+def gzip(debug=False, mime_types=['text/html', 'text/plain'],
+         compress_level=5):
+    """Try to gzip the response body if Content-Type in mime_types.
+
+    cherrypy.response.headers['Content-Type'] must be set to one of the
+    values in the mime_types arg before calling this function.
+
+    The provided list of mime-types must be of one of the following form:
+        * `type/subtype`
+        * `type/*`
+        * `type/*+subtype`
+
+    No compression is performed if any of the following hold:
+        * The client sends no Accept-Encoding request header
+        * No 'gzip' or 'x-gzip' is present in the Accept-Encoding header
+        * No 'gzip' or 'x-gzip' with a qvalue > 0 is present
+        * The 'identity' value is given with a qvalue > 0.
+
+    """
+    request = cherrypy.serving.request
+    response = cherrypy.serving.response
+
+    set_vary_header(response, 'Accept-Encoding')
+
+    if not response.body:
+        # Response body is empty (might be a 304 for instance)
+        if debug:
+            cherrypy.log('No response body', context='TOOLS.GZIP')
+        return
+
+    # If returning cached content (which should already have been gzipped),
+    # don't re-zip.
+    if getattr(request, 'cached', False):
+        if debug:
+            cherrypy.log('Not gzipping cached response', context='TOOLS.GZIP')
+        return
+
+    acceptable = request.headers.elements('Accept-Encoding')
+    if not acceptable:
+        # If no Accept-Encoding field is present in a request,
+        # the server MAY assume that the client will accept any
+        # content coding. In this case, if "identity" is one of
+        # the available content-codings, then the server SHOULD use
+        # the "identity" content-coding, unless it has additional
+        # information that a different content-coding is meaningful
+        # to the client.
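The `gzip` tool above is likewise usually enabled via config; a minimal sketch (the mime types and level are illustrative, and the `type/*` wildcard forms follow the matching rules described in the docstring):

```python
config = {
    '/': {
        'tools.gzip.on': True,
        'tools.gzip.mime_types': ['text/*', 'application/json'],
        'tools.gzip.compress_level': 6,
    }
}
```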
+ if debug: + cherrypy.log('No Accept-Encoding', context='TOOLS.GZIP') + return + + ct = response.headers.get('Content-Type', '').split(';')[0] + for coding in acceptable: + if coding.value == 'identity' and coding.qvalue != 0: + if debug: + cherrypy.log('Non-zero identity qvalue: %s' % coding, + context='TOOLS.GZIP') + return + if coding.value in ('gzip', 'x-gzip'): + if coding.qvalue == 0: + if debug: + cherrypy.log('Zero gzip qvalue: %s' % coding, + context='TOOLS.GZIP') + return + + if ct not in mime_types: + # If the list of provided mime-types contains tokens + # such as 'text/*' or 'application/*+xml', + # we go through them and find the most appropriate one + # based on the given content-type. + # The pattern matching is only caring about the most + # common cases, as stated above, and doesn't support + # for extra parameters. + found = False + if '/' in ct: + ct_media_type, ct_sub_type = ct.split('/') + for mime_type in mime_types: + if '/' in mime_type: + media_type, sub_type = mime_type.split('/') + if ct_media_type == media_type: + if sub_type == '*': + found = True + break + elif '+' in sub_type and '+' in ct_sub_type: + ct_left, ct_right = ct_sub_type.split('+') + left, right = sub_type.split('+') + if left == '*' and ct_right == right: + found = True + break + + if not found: + if debug: + cherrypy.log('Content-Type %s not in mime_types %r' % + (ct, mime_types), context='TOOLS.GZIP') + return + + if debug: + cherrypy.log('Gzipping', context='TOOLS.GZIP') + # Return a generator that compresses the page + response.headers['Content-Encoding'] = 'gzip' + response.body = compress(response.body, compress_level) + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers['Content-Length'] + + return + + if debug: + cherrypy.log('No acceptable encoding found.', context='GZIP') + cherrypy.HTTPError(406, 'identity, gzip').set_response() diff --git a/resources/lib/cherrypy/lib/gctools.py b/resources/lib/cherrypy/lib/gctools.py new file mode 100644 index 0000000..26746d7 --- /dev/null +++ b/resources/lib/cherrypy/lib/gctools.py @@ -0,0 +1,218 @@ +import gc +import inspect +import sys +import time + +try: + import objgraph +except ImportError: + objgraph = None + +import cherrypy +from cherrypy import _cprequest, _cpwsgi +from cherrypy.process.plugins import SimplePlugin + + +class ReferrerTree(object): + + """An object which gathers all referrers of an object to a given depth.""" + + peek_length = 40 + + def __init__(self, ignore=None, maxdepth=2, maxparents=10): + self.ignore = ignore or [] + self.ignore.append(inspect.currentframe().f_back) + self.maxdepth = maxdepth + self.maxparents = maxparents + + def ascend(self, obj, depth=1): + """Return a nested list containing referrers of the given object.""" + depth += 1 + parents = [] + + # Gather all referrers in one step to minimize + # cascading references due to repr() logic. 
+        refs = gc.get_referrers(obj)
+        self.ignore.append(refs)
+        if len(refs) > self.maxparents:
+            return [('[%s referrers]' % len(refs), [])]
+
+        try:
+            ascendcode = self.ascend.__code__
+        except AttributeError:
+            ascendcode = self.ascend.im_func.func_code
+        for parent in refs:
+            if inspect.isframe(parent) and parent.f_code is ascendcode:
+                continue
+            if parent in self.ignore:
+                continue
+            if depth <= self.maxdepth:
+                parents.append((parent, self.ascend(parent, depth)))
+            else:
+                parents.append((parent, []))
+
+        return parents
+
+    def peek(self, s):
+        """Return s, restricted to a sane length."""
+        if len(s) > (self.peek_length + 3):
+            half = self.peek_length // 2
+            return s[:half] + '...' + s[-half:]
+        else:
+            return s
+
+    def _format(self, obj, descend=True):
+        """Return a string representation of a single object."""
+        if inspect.isframe(obj):
+            filename, lineno, func, context, index = inspect.getframeinfo(obj)
+            return "<frame of function '%s'>" % func
+
+        if not descend:
+            return self.peek(repr(obj))
+
+        if isinstance(obj, dict):
+            return '{' + ', '.join(['%s: %s' % (self._format(k, descend=False),
+                                                self._format(v, descend=False))
+                                    for k, v in obj.items()]) + '}'
+        elif isinstance(obj, list):
+            return '[' + ', '.join([self._format(item, descend=False)
+                                    for item in obj]) + ']'
+        elif isinstance(obj, tuple):
+            return '(' + ', '.join([self._format(item, descend=False)
+                                    for item in obj]) + ')'
+
+        r = self.peek(repr(obj))
+        if isinstance(obj, (str, int, float)):
+            return r
+        return '%s: %s' % (type(obj), r)
+
+    def format(self, tree):
+        """Return a list of string reprs from a nested list of referrers."""
+        output = []
+
+        def ascend(branch, depth=1):
+            for parent, grandparents in branch:
+                output.append((' ' * depth) + self._format(parent))
+                if grandparents:
+                    ascend(grandparents, depth + 1)
+        ascend(tree)
+        return output
+
+
+def get_instances(cls):
+    return [x for x in gc.get_objects() if isinstance(x, cls)]
+
+
+class RequestCounter(SimplePlugin):
+
+    def start(self):
+        self.count = 0
+
+    def before_request(self):
+        self.count += 1
+
+    def after_request(self):
+        self.count -= 1
+
+
+request_counter = RequestCounter(cherrypy.engine)
+request_counter.subscribe()
+
+
+def get_context(obj):
+    if isinstance(obj, _cprequest.Request):
+        return 'path=%s;stage=%s' % (obj.path_info, obj.stage)
+    elif isinstance(obj, _cprequest.Response):
+        return 'status=%s' % obj.status
+    elif isinstance(obj, _cpwsgi.AppResponse):
+        return 'PATH_INFO=%s' % obj.environ.get('PATH_INFO', '')
+    elif hasattr(obj, 'tb_lineno'):
+        return 'tb_lineno=%s' % obj.tb_lineno
+    return ''
+
+
+class GCRoot(object):
+
+    """A CherryPy page handler for testing reference leaks."""
+
+    classes = [
+        (_cprequest.Request, 2, 2,
+         'Should be 1 in this request thread and 1 in the main thread.'),
+        (_cprequest.Response, 2, 2,
+         'Should be 1 in this request thread and 1 in the main thread.'),
+        (_cpwsgi.AppResponse, 1, 1,
+         'Should be 1 in this request thread only.'),
+    ]
+
+    @cherrypy.expose
+    def index(self):
+        return 'Hello, world!'
+
+    @cherrypy.expose
+    def stats(self):
+        output = ['Statistics:']
+
+        for trial in range(10):
+            if request_counter.count > 0:
+                break
+            time.sleep(0.5)
+        else:
+            output.append('\nNot all requests closed properly.')
+
+        # gc_collect isn't perfectly synchronous, because it may
+        # break reference cycles that then take time to fully
+        # finalize. Call it thrice and hope for the best.
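`GCRoot` is meant to be mounted as an ordinary CherryPy app while hunting leaks; a sketch (the mount point is hypothetical):

```python
import cherrypy
from cherrypy.lib import gctools

cherrypy.tree.mount(gctools.GCRoot(), '/gcroot')
# Request /gcroot/stats between test requests to look for stray references.
```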
+ gc.collect() + gc.collect() + unreachable = gc.collect() + if unreachable: + if objgraph is not None: + final = objgraph.by_type('Nondestructible') + if final: + objgraph.show_backrefs(final, filename='finalizers.png') + + trash = {} + for x in gc.garbage: + trash[type(x)] = trash.get(type(x), 0) + 1 + if trash: + output.insert(0, '\n%s unreachable objects:' % unreachable) + trash = [(v, k) for k, v in trash.items()] + trash.sort() + for pair in trash: + output.append(' ' + repr(pair)) + + # Check declared classes to verify uncollected instances. + # These don't have to be part of a cycle; they can be + # any objects that have unanticipated referrers that keep + # them from being collected. + allobjs = {} + for cls, minobj, maxobj, msg in self.classes: + allobjs[cls] = get_instances(cls) + + for cls, minobj, maxobj, msg in self.classes: + objs = allobjs[cls] + lenobj = len(objs) + if lenobj < minobj or lenobj > maxobj: + if minobj == maxobj: + output.append( + '\nExpected %s %r references, got %s.' % + (minobj, cls, lenobj)) + else: + output.append( + '\nExpected %s to %s %r references, got %s.' % + (minobj, maxobj, cls, lenobj)) + + for obj in objs: + if objgraph is not None: + ig = [id(objs), id(inspect.currentframe())] + fname = 'graph_%s_%s.png' % (cls.__name__, id(obj)) + objgraph.show_backrefs( + obj, extra_ignore=ig, max_depth=4, too_many=20, + filename=fname, extra_info=get_context) + output.append('\nReferrers for %s (refcount=%s):' % + (repr(obj), sys.getrefcount(obj))) + t = ReferrerTree(ignore=[objs], maxdepth=3) + tree = t.ascend(obj) + output.extend(t.format(tree)) + + return '\n'.join(output) diff --git a/resources/lib/cherrypy/lib/httputil.py b/resources/lib/cherrypy/lib/httputil.py new file mode 100644 index 0000000..eedf8d8 --- /dev/null +++ b/resources/lib/cherrypy/lib/httputil.py @@ -0,0 +1,518 @@ +"""HTTP library functions. + +This module contains functions for building an HTTP application +framework: any one, not just one whose name starts with "Ch". ;) If you +reference any modules from some popular framework inside *this* module, +FuManChu will personally hang you up by your thumbs and submit you +to a public caning. +""" + +import functools +import email.utils +import re +import builtins +from binascii import b2a_base64 +from cgi import parse_header +from email.header import decode_header +from http.server import BaseHTTPRequestHandler +from urllib.parse import unquote_plus + +import jaraco.collections + +import cherrypy +from cherrypy._cpcompat import ntob, ntou + +response_codes = BaseHTTPRequestHandler.responses.copy() + +# From https://github.com/cherrypy/cherrypy/issues/361 +response_codes[500] = ('Internal Server Error', + 'The server encountered an unexpected condition ' + 'which prevented it from fulfilling the request.') +response_codes[503] = ('Service Unavailable', + 'The server is currently unable to handle the ' + 'request due to a temporary overloading or ' + 'maintenance of the server.') + + +HTTPDate = functools.partial(email.utils.formatdate, usegmt=True) + + +def urljoin(*atoms): + r"""Return the given path \*atoms, joined into a single URL. + + This will correctly join a SCRIPT_NAME and PATH_INFO into the + original URL, even if either atom is blank. + """ + url = '/'.join([x for x in atoms if x]) + while '//' in url: + url = url.replace('//', '/') + # Special-case the final url of "", and return "/" instead. + return url or '/' + + +def urljoin_bytes(*atoms): + """Return the given path `*atoms`, joined into a single URL. 
+ + This will correctly join a SCRIPT_NAME and PATH_INFO into the + original URL, even if either atom is blank. + """ + url = b'/'.join([x for x in atoms if x]) + while b'//' in url: + url = url.replace(b'//', b'/') + # Special-case the final url of "", and return "/" instead. + return url or b'/' + + +def protocol_from_http(protocol_str): + """Return a protocol tuple from the given 'HTTP/x.y' string.""" + return int(protocol_str[5]), int(protocol_str[7]) + + +def get_ranges(headervalue, content_length): + """Return a list of (start, stop) indices from a Range header, or None. + + Each (start, stop) tuple will be composed of two ints, which are suitable + for use in a slicing operation. That is, the header "Range: bytes=3-6", + if applied against a Python string, is requesting resource[3:7]. This + function will return the list [(3, 7)]. + + If this function returns an empty list, you should return HTTP 416. + """ + + if not headervalue: + return None + + result = [] + bytesunit, byteranges = headervalue.split('=', 1) + for brange in byteranges.split(','): + start, stop = [x.strip() for x in brange.split('-', 1)] + if start: + if not stop: + stop = content_length - 1 + start, stop = int(start), int(stop) + if start >= content_length: + # From rfc 2616 sec 14.16: + # "If the server receives a request (other than one + # including an If-Range request-header field) with an + # unsatisfiable Range request-header field (that is, + # all of whose byte-range-spec values have a first-byte-pos + # value greater than the current length of the selected + # resource), it SHOULD return a response code of 416 + # (Requested range not satisfiable)." + continue + if stop < start: + # From rfc 2616 sec 14.16: + # "If the server ignores a byte-range-spec because it + # is syntactically invalid, the server SHOULD treat + # the request as if the invalid Range header field + # did not exist. (Normally, this means return a 200 + # response containing the full entity)." + return None + result.append((start, stop + 1)) + else: + if not stop: + # See rfc quote above. + return None + # Negative subscript (last N bytes) + # + # RFC 2616 Section 14.35.1: + # If the entity is shorter than the specified suffix-length, + # the entire entity-body is used. 
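A worked example of the `get_ranges` semantics described above (a sketch; the asserts mirror the docstring and the RFC 2616 comments):

```python
from cherrypy.lib import httputil

# "bytes=3-6" against a 10-byte entity -> slice indices [(3, 7)]
assert httputil.get_ranges('bytes=3-6', 10) == [(3, 7)]
# Suffix range: the last four bytes.
assert httputil.get_ranges('bytes=-4', 10) == [(6, 10)]
# stop < start is syntactically invalid -> None (serve the full entity).
assert httputil.get_ranges('bytes=6-3', 10) is None
```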
+ if int(stop) > content_length: + result.append((0, content_length)) + else: + result.append((content_length - int(stop), content_length)) + + return result + + +class HeaderElement(object): + + """An element (with parameters) from an HTTP header's element list.""" + + def __init__(self, value, params=None): + self.value = value + if params is None: + params = {} + self.params = params + + def __cmp__(self, other): + return builtins.cmp(self.value, other.value) + + def __lt__(self, other): + return self.value < other.value + + def __str__(self): + p = [';%s=%s' % (k, v) for k, v in self.params.items()] + return str('%s%s' % (self.value, ''.join(p))) + + def __bytes__(self): + return ntob(self.__str__()) + + def __unicode__(self): + return ntou(self.__str__()) + + @staticmethod + def parse(elementstr): + """Transform 'token;key=val' to ('token', {'key': 'val'}).""" + initial_value, params = parse_header(elementstr) + return initial_value, params + + @classmethod + def from_str(cls, elementstr): + """Construct an instance from a string of the form 'token;key=val'.""" + ival, params = cls.parse(elementstr) + return cls(ival, params) + + +q_separator = re.compile(r'; *q *=') + + +class AcceptElement(HeaderElement): + + """An element (with parameters) from an Accept* header's element list. + + AcceptElement objects are comparable; the more-preferred object will be + "less than" the less-preferred object. They are also therefore sortable; + if you sort a list of AcceptElement objects, they will be listed in + priority order; the most preferred value will be first. Yes, it should + have been the other way around, but it's too late to fix now. + """ + + @classmethod + def from_str(cls, elementstr): + qvalue = None + # The first "q" parameter (if any) separates the initial + # media-range parameter(s) (if any) from the accept-params. + atoms = q_separator.split(elementstr, 1) + media_range = atoms.pop(0).strip() + if atoms: + # The qvalue for an Accept header can have extensions. The other + # headers cannot, but it's easier to parse them as if they did. + qvalue = HeaderElement.from_str(atoms[0].strip()) + + media_type, params = cls.parse(media_range) + if qvalue is not None: + params['q'] = qvalue + return cls(media_type, params) + + @property + def qvalue(self): + 'The qvalue, or priority, of this value.' + val = self.params.get('q', '1') + if isinstance(val, HeaderElement): + val = val.value + try: + return float(val) + except ValueError as val_err: + """Fail client requests with invalid quality value. + + Ref: https://github.com/cherrypy/cherrypy/issues/1370 + """ + raise cherrypy.HTTPError( + 400, + 'Malformed HTTP header: `{}`'. + format(str(self)), + ) from val_err + + def __cmp__(self, other): + diff = builtins.cmp(self.qvalue, other.qvalue) + if diff == 0: + diff = builtins.cmp(str(self), str(other)) + return diff + + def __lt__(self, other): + if self.qvalue == other.qvalue: + return str(self) < str(other) + else: + return self.qvalue < other.qvalue + + +RE_HEADER_SPLIT = re.compile(',(?=(?:[^"]*"[^"]*")*[^"]*$)') + + +def header_elements(fieldname, fieldvalue): + """Return a sorted HeaderElement list from a comma-separated header string. 
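To make the qvalue ordering concrete, a small sketch of `header_elements` (defined just below; the header values are illustrative):

```python
from cherrypy.lib import httputil

elements = httputil.header_elements(
    'Accept', 'text/html;q=0.8, application/json')
# Most-preferred first: a bare media type defaults to q=1.0.
assert [e.value for e in elements] == ['application/json', 'text/html']
```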
+ """ + if not fieldvalue: + return [] + + result = [] + for element in RE_HEADER_SPLIT.split(fieldvalue): + if fieldname.startswith('Accept') or fieldname == 'TE': + hv = AcceptElement.from_str(element) + else: + hv = HeaderElement.from_str(element) + result.append(hv) + + return list(reversed(sorted(result))) + + +def decode_TEXT(value): + r""" + Decode :rfc:`2047` TEXT + + >>> decode_TEXT("=?utf-8?q?f=C3=BCr?=") == b'f\xfcr'.decode('latin-1') + True + """ + atoms = decode_header(value) + decodedvalue = '' + for atom, charset in atoms: + if charset is not None: + atom = atom.decode(charset) + decodedvalue += atom + return decodedvalue + + +def decode_TEXT_maybe(value): + """ + Decode the text but only if '=?' appears in it. + """ + return decode_TEXT(value) if '=?' in value else value + + +def valid_status(status): + """Return legal HTTP status Code, Reason-phrase and Message. + + The status arg must be an int, a str that begins with an int + or the constant from ``http.client`` stdlib module. + + If status has no reason-phrase is supplied, a default reason- + phrase will be provided. + + >>> import http.client + >>> from http.server import BaseHTTPRequestHandler + >>> valid_status(http.client.ACCEPTED) == ( + ... int(http.client.ACCEPTED), + ... ) + BaseHTTPRequestHandler.responses[http.client.ACCEPTED] + True + """ + + if not status: + status = 200 + + code, reason = status, None + if isinstance(status, str): + code, _, reason = status.partition(' ') + reason = reason.strip() or None + + try: + code = int(code) + except (TypeError, ValueError): + raise ValueError('Illegal response status from server ' + '(%s is non-numeric).' % repr(code)) + + if code < 100 or code > 599: + raise ValueError('Illegal response status from server ' + '(%s is out of range).' % repr(code)) + + if code not in response_codes: + # code is unknown but not illegal + default_reason, message = '', '' + else: + default_reason, message = response_codes[code] + + if reason is None: + reason = default_reason + + return code, reason, message + + +# NOTE: the parse_qs functions that follow are modified version of those +# in the python3.0 source - we need to pass through an encoding to the unquote +# method, but the default parse_qs function doesn't allow us to. These do. + +def _parse_qs(qs, keep_blank_values=0, strict_parsing=0, encoding='utf-8'): + """Parse a query given as a string argument. + + Arguments: + + qs: URL-encoded query string to be parsed + + keep_blank_values: flag indicating whether blank values in + URL encoded queries should be treated as blank strings. A + true value indicates that blanks should be retained as blank + strings. The default false value indicates that blank values + are to be ignored and treated as if they were not included. + + strict_parsing: flag indicating what to do with parsing errors. If + false (the default), errors are silently ignored. If true, + errors raise a ValueError exception. + + Returns a dict, as G-d intended. 
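For example, under the semantics documented above (as used by `parse_query_string` below):

```python
from cherrypy.lib import httputil

# Duplicate keys collapse into a list; single keys stay strings.
assert httputil.parse_query_string('a=1&a=2&b=hi') == {'a': ['1', '2'], 'b': 'hi'}
# Server-side image-map coordinates become 'x' and 'y'.
assert httputil.parse_query_string('3,4') == {'x': 3, 'y': 4}
```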
+ """ + pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')] + d = {} + for name_value in pairs: + if not name_value and not strict_parsing: + continue + nv = name_value.split('=', 1) + if len(nv) != 2: + if strict_parsing: + raise ValueError('bad query field: %r' % (name_value,)) + # Handle case of a control-name with no equal sign + if keep_blank_values: + nv.append('') + else: + continue + if len(nv[1]) or keep_blank_values: + name = unquote_plus(nv[0], encoding, errors='strict') + value = unquote_plus(nv[1], encoding, errors='strict') + if name in d: + if not isinstance(d[name], list): + d[name] = [d[name]] + d[name].append(value) + else: + d[name] = value + return d + + +image_map_pattern = re.compile(r'[0-9]+,[0-9]+') + + +def parse_query_string(query_string, keep_blank_values=True, encoding='utf-8'): + """Build a params dictionary from a query_string. + + Duplicate key/value pairs in the provided query_string will be + returned as {'key': [val1, val2, ...]}. Single key/values will + be returned as strings: {'key': 'value'}. + """ + if image_map_pattern.match(query_string): + # Server-side image map. Map the coords to 'x' and 'y' + # (like CGI::Request does). + pm = query_string.split(',') + pm = {'x': int(pm[0]), 'y': int(pm[1])} + else: + pm = _parse_qs(query_string, keep_blank_values, encoding=encoding) + return pm + + +class CaseInsensitiveDict(jaraco.collections.KeyTransformingDict): + + """A case-insensitive dict subclass. + + Each key is changed on entry to title case. + """ + + @staticmethod + def transform_key(key): + if key is None: + # TODO(#1830): why? + return 'None' + return key.title() + + +# TEXT = +# +# A CRLF is allowed in the definition of TEXT only as part of a header +# field continuation. It is expected that the folding LWS will be +# replaced with a single SP before interpretation of the TEXT value." +if str == bytes: + header_translate_table = ''.join([chr(i) for i in range(256)]) + header_translate_deletechars = ''.join( + [chr(i) for i in range(32)]) + chr(127) +else: + header_translate_table = None + header_translate_deletechars = bytes(range(32)) + bytes([127]) + + +class HeaderMap(CaseInsensitiveDict): + + """A dict subclass for HTTP request and response headers. + + Each key is changed on entry to str(key).title(). This allows headers + to be case-insensitive and avoid duplicates. + + Values are header values (decoded according to :rfc:`2047` if necessary). + """ + + protocol = (1, 1) + encodings = ['ISO-8859-1'] + + # Someday, when http-bis is done, this will probably get dropped + # since few servers, clients, or intermediaries do it. But until then, + # we're going to obey the spec as is. + # "Words of *TEXT MAY contain characters from character sets other than + # ISO-8859-1 only when encoded according to the rules of RFC 2047." + use_rfc_2047 = True + + def elements(self, key): + """Return a sorted list of HeaderElements for the given header.""" + return header_elements(self.transform_key(key), self.get(key)) + + def values(self, key): + """Return a sorted list of HeaderElement.value for the given header.""" + return [e.value for e in self.elements(key)] + + def output(self): + """Transform self into a list of (name, value) tuples.""" + return list(self.encode_header_items(self.items())) + + @classmethod + def encode_header_items(cls, header_items): + """ + Prepare the sequence of name, value tuples into a form suitable for + transmitting on the wire for HTTP. 
+ """ + for k, v in header_items: + if not isinstance(v, str) and not isinstance(v, bytes): + v = str(v) + + yield tuple(map(cls.encode_header_item, (k, v))) + + @classmethod + def encode_header_item(cls, item): + if isinstance(item, str): + item = cls.encode(item) + + # See header_translate_* constants above. + # Replace only if you really know what you're doing. + return item.translate( + header_translate_table, header_translate_deletechars) + + @classmethod + def encode(cls, v): + """Return the given header name or value, encoded for HTTP output.""" + for enc in cls.encodings: + try: + return v.encode(enc) + except UnicodeEncodeError: + continue + + if cls.protocol == (1, 1) and cls.use_rfc_2047: + # Encode RFC-2047 TEXT + # (e.g. u"\u8200" -> "=?utf-8?b?6IiA?="). + # We do our own here instead of using the email module + # because we never want to fold lines--folding has + # been deprecated by the HTTP working group. + v = b2a_base64(v.encode('utf-8')) + return (b'=?utf-8?b?' + v.strip(b'\n') + b'?=') + + raise ValueError('Could not encode header part %r using ' + 'any of the encodings %r.' % + (v, cls.encodings)) + + +class Host(object): + + """An internet address. + + name + Should be the client's host name. If not available (because no DNS + lookup is performed), the IP address should be used instead. + + """ + + ip = '0.0.0.0' + port = 80 + name = 'unknown.tld' + + def __init__(self, ip, port, name=None): + self.ip = ip + self.port = port + if name is None: + name = ip + self.name = name + + def __repr__(self): + return 'httputil.Host(%r, %r, %r)' % (self.ip, self.port, self.name) diff --git a/resources/lib/cherrypy/lib/jsontools.py b/resources/lib/cherrypy/lib/jsontools.py new file mode 100644 index 0000000..9ca75a8 --- /dev/null +++ b/resources/lib/cherrypy/lib/jsontools.py @@ -0,0 +1,89 @@ +import cherrypy +from cherrypy import _json as json +from cherrypy._cpcompat import text_or_bytes, ntou + + +def json_processor(entity): + """Read application/json data into request.json.""" + if not entity.headers.get(ntou('Content-Length'), ntou('')): + raise cherrypy.HTTPError(411) + + body = entity.fp.read() + with cherrypy.HTTPError.handle(ValueError, 400, 'Invalid JSON document'): + cherrypy.serving.request.json = json.decode(body.decode('utf-8')) + + +def json_in(content_type=[ntou('application/json'), ntou('text/javascript')], + force=True, debug=False, processor=json_processor): + """Add a processor to parse JSON request entities: + The default processor places the parsed data into request.json. + + Incoming request entities which match the given content_type(s) will + be deserialized from JSON to the Python equivalent, and the result + stored at cherrypy.request.json. The 'content_type' argument may + be a Content-Type string or a list of allowable Content-Type strings. + + If the 'force' argument is True (the default), then entities of other + content types will not be allowed; "415 Unsupported Media Type" is + raised instead. + + Supply your own processor to use a custom decoder, or to handle the parsed + data differently. The processor can be configured via + tools.json_in.processor or via the decorator method. + + Note that the deserializer requires the client send a Content-Length + request header, or it will raise "411 Length Required". If for any + other reason the request entity cannot be deserialized from JSON, + it will raise "400 Bad Request: Invalid JSON document". 
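Typical wiring for the JSON tools in this file, as decorators on a handler (a minimal sketch; the class and method names are hypothetical):

```python
import cherrypy

class Api:
    @cherrypy.expose
    @cherrypy.tools.json_in()
    @cherrypy.tools.json_out()
    def echo(self):
        # The parsed request entity is at cherrypy.request.json; the
        # return value is serialized back to JSON by the json_out handler.
        return {'received': cherrypy.request.json}
```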
+ """ + request = cherrypy.serving.request + if isinstance(content_type, text_or_bytes): + content_type = [content_type] + + if force: + if debug: + cherrypy.log('Removing body processors %s' % + repr(request.body.processors.keys()), 'TOOLS.JSON_IN') + request.body.processors.clear() + request.body.default_proc = cherrypy.HTTPError( + 415, 'Expected an entity of content type %s' % + ', '.join(content_type)) + + for ct in content_type: + if debug: + cherrypy.log('Adding body processor for %s' % ct, 'TOOLS.JSON_IN') + request.body.processors[ct] = processor + + +def json_handler(*args, **kwargs): + value = cherrypy.serving.request._json_inner_handler(*args, **kwargs) + return json.encode(value) + + +def json_out(content_type='application/json', debug=False, + handler=json_handler): + """Wrap request.handler to serialize its output to JSON. Sets Content-Type. + + If the given content_type is None, the Content-Type response header + is not set. + + Provide your own handler to use a custom encoder. For example + cherrypy.config['tools.json_out.handler'] = , or + @json_out(handler=function). + """ + request = cherrypy.serving.request + # request.handler may be set to None by e.g. the caching tool + # to signal to all components that a response body has already + # been attached, in which case we don't need to wrap anything. + if request.handler is None: + return + if debug: + cherrypy.log('Replacing %s with JSON handler' % request.handler, + 'TOOLS.JSON_OUT') + request._json_inner_handler = request.handler + request.handler = handler + if content_type is not None: + if debug: + cherrypy.log('Setting Content-Type to %s' % + content_type, 'TOOLS.JSON_OUT') + cherrypy.serving.response.headers['Content-Type'] = content_type diff --git a/resources/lib/cherrypy/lib/locking.py b/resources/lib/cherrypy/lib/locking.py new file mode 100644 index 0000000..317fb58 --- /dev/null +++ b/resources/lib/cherrypy/lib/locking.py @@ -0,0 +1,47 @@ +import datetime + + +class NeverExpires(object): + def expired(self): + return False + + +class Timer(object): + """ + A simple timer that will indicate when an expiration time has passed. + """ + def __init__(self, expiration): + 'Create a timer that expires at `expiration` (UTC datetime)' + self.expiration = expiration + + @classmethod + def after(cls, elapsed): + """ + Return a timer that will expire after `elapsed` passes. + """ + return cls(datetime.datetime.utcnow() + elapsed) + + def expired(self): + return datetime.datetime.utcnow() >= self.expiration + + +class LockTimeout(Exception): + 'An exception when a lock could not be acquired before a timeout period' + + +class LockChecker(object): + """ + Keep track of the time and detect if a timeout has expired + """ + def __init__(self, session_id, timeout): + self.session_id = session_id + if timeout: + self.timer = Timer.after(timeout) + else: + self.timer = NeverExpires() + + def expired(self): + if self.timer.expired(): + raise LockTimeout( + 'Timeout acquiring lock for %(session_id)s' % vars(self)) + return False diff --git a/resources/lib/cherrypy/lib/profiler.py b/resources/lib/cherrypy/lib/profiler.py new file mode 100644 index 0000000..fccf2eb --- /dev/null +++ b/resources/lib/cherrypy/lib/profiler.py @@ -0,0 +1,221 @@ +"""Profiler tools for CherryPy. 
+
+CherryPy users
+==============
+
+You can profile any of your pages as follows::
+
+    from cherrypy.lib import profiler
+
+    class Root:
+        p = profiler.Profiler("/path/to/profile/dir")
+
+        @cherrypy.expose
+        def index(self):
+            self.p.run(self._index)
+
+        def _index(self):
+            return "Hello, world!"
+
+    cherrypy.tree.mount(Root())
+
+You can also turn on profiling for all requests
+using the ``make_app`` function as WSGI middleware.
+
+CherryPy developers
+===================
+
+This module can be used whenever you make changes to CherryPy,
+to get a quick sanity-check on overall CP performance. Use the
+``--profile`` flag when running the test suite. Then, use the ``serve()``
+function to browse the results in a web browser. If you run this
+module from the command line, it will call ``serve()`` for you.
+
+"""
+
+import io
+import os
+import os.path
+import sys
+import warnings
+
+import cherrypy
+
+
+try:
+    import profile
+    import pstats
+
+    def new_func_strip_path(func_name):
+        """Make profiler output more readable by adding `__init__` modules' parents
+        """
+        filename, line, name = func_name
+        if filename.endswith('__init__.py'):
+            return (
+                os.path.basename(filename[:-12]) + filename[-12:],
+                line,
+                name,
+            )
+        return os.path.basename(filename), line, name
+
+    pstats.func_strip_path = new_func_strip_path
+except ImportError:
+    profile = None
+    pstats = None
+
+
+_count = 0
+
+
+class Profiler(object):
+
+    def __init__(self, path=None):
+        if not path:
+            path = os.path.join(os.path.dirname(__file__), 'profile')
+        self.path = path
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+    def run(self, func, *args, **params):
+        """Dump profile data into self.path."""
+        global _count
+        c = _count = _count + 1
+        path = os.path.join(self.path, 'cp_%04d.prof' % c)
+        prof = profile.Profile()
+        result = prof.runcall(func, *args, **params)
+        prof.dump_stats(path)
+        return result
+
+    def statfiles(self):
+        """:rtype: list of available profiles.
+        """
+        return [f for f in os.listdir(self.path)
+                if f.startswith('cp_') and f.endswith('.prof')]
+
+    def stats(self, filename, sortby='cumulative'):
+        """:rtype stats(index): output of print_stats() for the given profile.
+        """
+        sio = io.StringIO()
+        if sys.version_info >= (2, 5):
+            s = pstats.Stats(os.path.join(self.path, filename), stream=sio)
+            s.strip_dirs()
+            s.sort_stats(sortby)
+            s.print_stats()
+        else:
+            # pstats.Stats before Python 2.5 didn't take a 'stream' arg,
+            # but just printed to stdout. So re-route stdout.
+            s = pstats.Stats(os.path.join(self.path, filename))
+            s.strip_dirs()
+            s.sort_stats(sortby)
+            oldout = sys.stdout
+            try:
+                sys.stdout = sio
+                s.print_stats()
+            finally:
+                sys.stdout = oldout
+        response = sio.getvalue()
+        sio.close()
+        return response
+
+    @cherrypy.expose
+    def index(self):
+        return """<html>
+        <head><title>CherryPy profile data</title></head>
+        <frameset cols='200, 1*'>
+            <frame src='menu' />
+            <frame name='main' src='' />
+        </frameset>
+        </html>
+        """
+
+    @cherrypy.expose
+    def menu(self):
+        yield '<h2>Profiling runs</h2>'
+        yield '<p>Click on one of the runs below to see profiling data.</p>'
+        runs = self.statfiles()
+        runs.sort()
+        for i in runs:
+            yield "<a href='report/%s' target='main'>%s</a><br />
" % ( + i, i) + + @cherrypy.expose + def report(self, filename): + cherrypy.response.headers['Content-Type'] = 'text/plain' + return self.stats(filename) + + +class ProfileAggregator(Profiler): + + def __init__(self, path=None): + Profiler.__init__(self, path) + global _count + self.count = _count = _count + 1 + self.profiler = profile.Profile() + + def run(self, func, *args, **params): + path = os.path.join(self.path, 'cp_%04d.prof' % self.count) + result = self.profiler.runcall(func, *args, **params) + self.profiler.dump_stats(path) + return result + + +class make_app: + + def __init__(self, nextapp, path=None, aggregate=False): + """Make a WSGI middleware app which wraps 'nextapp' with profiling. + + nextapp + the WSGI application to wrap, usually an instance of + cherrypy.Application. + + path + where to dump the profiling output. + + aggregate + if True, profile data for all HTTP requests will go in + a single file. If False (the default), each HTTP request will + dump its profile data into a separate file. + + """ + if profile is None or pstats is None: + msg = ('Your installation of Python does not have a profile ' + "module. If you're on Debian, try " + '`sudo apt-get install python-profiler`. ' + 'See http://www.cherrypy.org/wiki/ProfilingOnDebian ' + 'for details.') + warnings.warn(msg) + + self.nextapp = nextapp + self.aggregate = aggregate + if aggregate: + self.profiler = ProfileAggregator(path) + else: + self.profiler = Profiler(path) + + def __call__(self, environ, start_response): + def gather(): + result = [] + for line in self.nextapp(environ, start_response): + result.append(line) + return result + return self.profiler.run(gather) + + +def serve(path=None, port=8080): + if profile is None or pstats is None: + msg = ('Your installation of Python does not have a profile module. ' + "If you're on Debian, try " + '`sudo apt-get install python-profiler`. ' + 'See http://www.cherrypy.org/wiki/ProfilingOnDebian ' + 'for details.') + warnings.warn(msg) + + cherrypy.config.update({'server.socket_port': int(port), + 'server.thread_pool': 10, + 'environment': 'production', + }) + cherrypy.quickstart(Profiler(path)) + + +if __name__ == '__main__': + serve(*tuple(sys.argv[1:])) diff --git a/resources/lib/cherrypy/lib/reprconf.py b/resources/lib/cherrypy/lib/reprconf.py new file mode 100644 index 0000000..3976652 --- /dev/null +++ b/resources/lib/cherrypy/lib/reprconf.py @@ -0,0 +1,397 @@ +"""Generic configuration system using unrepr. + +Configuration data may be supplied as a Python dictionary, as a filename, +or as an open file object. When you supply a filename or file, Python's +builtin ConfigParser is used (with some extensions). + +Namespaces +---------- + +Configuration keys are separated into namespaces by the first "." in the key. + +The only key that cannot exist in a namespace is the "environment" entry. +This special entry 'imports' other config entries from a template stored in +the Config.environments dict. + +You can define your own namespaces to be called when new config is merged +by adding a named handler to Config.namespaces. The name can be any string, +and the handler must be either a callable or a context manager. +""" + +import builtins +import configparser +import operator +import sys + +from cherrypy._cpcompat import text_or_bytes + + +class NamespaceSet(dict): + + """A dict of config namespace names and handlers. 
+ + Each config entry should begin with a namespace name; the corresponding + namespace handler will be called once for each config entry in that + namespace, and will be passed two arguments: the config key (with the + namespace removed) and the config value. + + Namespace handlers may be any Python callable; they may also be + context managers, in which case their __enter__ + method should return a callable to be used as the handler. + See cherrypy.tools (the Toolbox class) for an example. + """ + + def __call__(self, config): + """Iterate through config and pass it to each namespace handler. + + config + A flat dict, where keys use dots to separate + namespaces, and values are arbitrary. + + The first name in each config key is used to look up the corresponding + namespace handler. For example, a config entry of {'tools.gzip.on': v} + will call the 'tools' namespace handler with the args: ('gzip.on', v) + """ + # Separate the given config into namespaces + ns_confs = {} + for k in config: + if '.' in k: + ns, name = k.split('.', 1) + bucket = ns_confs.setdefault(ns, {}) + bucket[name] = config[k] + + # I chose __enter__ and __exit__ so someday this could be + # rewritten using 'with' statement: + # for ns, handler in self.items(): + # with handler as callable: + # for k, v in ns_confs.get(ns, {}).items(): + # callable(k, v) + for ns, handler in self.items(): + exit = getattr(handler, '__exit__', None) + if exit: + callable = handler.__enter__() + no_exc = True + try: + try: + for k, v in ns_confs.get(ns, {}).items(): + callable(k, v) + except Exception: + # The exceptional case is handled here + no_exc = False + if exit is None: + raise + if not exit(*sys.exc_info()): + raise + # The exception is swallowed if exit() returns true + finally: + # The normal and non-local-goto cases are handled here + if no_exc and exit: + exit(None, None, None) + else: + for k, v in ns_confs.get(ns, {}).items(): + handler(k, v) + + def __repr__(self): + return '%s.%s(%s)' % (self.__module__, self.__class__.__name__, + dict.__repr__(self)) + + def __copy__(self): + newobj = self.__class__() + newobj.update(self) + return newobj + copy = __copy__ + + +class Config(dict): + + """A dict-like set of configuration data, with defaults and namespaces. + + May take a file, filename, or dict. + """ + + defaults = {} + environments = {} + namespaces = NamespaceSet() + + def __init__(self, file=None, **kwargs): + self.reset() + if file is not None: + self.update(file) + if kwargs: + self.update(kwargs) + + def reset(self): + """Reset self to default values.""" + self.clear() + dict.update(self, self.defaults) + + def update(self, config): + """Update self from a dict, file, or filename.""" + self._apply(Parser.load(config)) + + def _apply(self, config): + """Update self from a dict.""" + which_env = config.get('environment') + if which_env: + env = self.environments[which_env] + for k in env: + if k not in config: + config[k] = env[k] + + dict.update(self, config) + self.namespaces(config) + + def __setitem__(self, k, v): + dict.__setitem__(self, k, v) + self.namespaces({k: v}) + + +class Parser(configparser.ConfigParser): + + """Sub-class of ConfigParser that keeps the case of options and that + raises an exception if the file cannot be read. 
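`Config` above accepts a dict, a filename, or an open file; a minimal usage sketch (the filename is hypothetical):

```python
from cherrypy.lib import reprconf

conf = reprconf.Config()                   # start from the class defaults
conf.update({'server.socket_port': 8080})  # merge a plain dict
# Or load an INI file whose values are Python expressions (see unrepr below):
# conf = reprconf.Config('site.conf')
```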
+ """ + + def optionxform(self, optionstr): + return optionstr + + def read(self, filenames): + if isinstance(filenames, text_or_bytes): + filenames = [filenames] + for filename in filenames: + # try: + # fp = open(filename) + # except IOError: + # continue + fp = open(filename) + try: + self._read(fp, filename) + finally: + fp.close() + + def as_dict(self, raw=False, vars=None): + """Convert an INI file to a dictionary""" + # Load INI file into a dict + result = {} + for section in self.sections(): + if section not in result: + result[section] = {} + for option in self.options(section): + value = self.get(section, option, raw=raw, vars=vars) + try: + value = unrepr(value) + except Exception: + x = sys.exc_info()[1] + msg = ('Config error in section: %r, option: %r, ' + 'value: %r. Config values must be valid Python.' % + (section, option, value)) + raise ValueError(msg, x.__class__.__name__, x.args) + result[section][option] = value + return result + + def dict_from_file(self, file): + if hasattr(file, 'read'): + self.readfp(file) + else: + self.read(file) + return self.as_dict() + + @classmethod + def load(self, input): + """Resolve 'input' to dict from a dict, file, or filename.""" + is_file = ( + # Filename + isinstance(input, text_or_bytes) + # Open file object + or hasattr(input, 'read') + ) + return Parser().dict_from_file(input) if is_file else input.copy() + + +# public domain "unrepr" implementation, found on the web and then improved. + + +class _Builder: + + def build(self, o): + m = getattr(self, 'build_' + o.__class__.__name__, None) + if m is None: + raise TypeError('unrepr does not recognize %s' % + repr(o.__class__.__name__)) + return m(o) + + def astnode(self, s): + """Return a Python3 ast Node compiled from a string.""" + try: + import ast + except ImportError: + # Fallback to eval when ast package is not available, + # e.g. IronPython 1.0. + return eval(s) + + p = ast.parse('__tempvalue__ = ' + s) + return p.body[0].value + + def build_Subscript(self, o): + return self.build(o.value)[self.build(o.slice)] + + def build_Index(self, o): + return self.build(o.value) + + def _build_call35(self, o): + """ + Workaround for python 3.5 _ast.Call signature, docs found here + https://greentreesnakes.readthedocs.org/en/latest/nodes.html + """ + import ast + callee = self.build(o.func) + args = [] + if o.args is not None: + for a in o.args: + if isinstance(a, ast.Starred): + args.append(self.build(a.value)) + else: + args.append(self.build(a)) + kwargs = {} + for kw in o.keywords: + if kw.arg is None: # double asterix `**` + rst = self.build(kw.value) + if not isinstance(rst, dict): + raise TypeError('Invalid argument for call.' 
+ 'Must be a mapping object.') + # give preference to the keys set directly from arg=value + for k, v in rst.items(): + if k not in kwargs: + kwargs[k] = v + else: # defined on the call as: arg=value + kwargs[kw.arg] = self.build(kw.value) + return callee(*args, **kwargs) + + def build_Call(self, o): + if sys.version_info >= (3, 5): + return self._build_call35(o) + + callee = self.build(o.func) + + if o.args is None: + args = () + else: + args = tuple([self.build(a) for a in o.args]) + + if o.starargs is None: + starargs = () + else: + starargs = tuple(self.build(o.starargs)) + + if o.kwargs is None: + kwargs = {} + else: + kwargs = self.build(o.kwargs) + if o.keywords is not None: # direct a=b keywords + for kw in o.keywords: + # preference because is a direct keyword against **kwargs + kwargs[kw.arg] = self.build(kw.value) + return callee(*(args + starargs), **kwargs) + + def build_List(self, o): + return list(map(self.build, o.elts)) + + def build_Str(self, o): + return o.s + + def build_Num(self, o): + return o.n + + def build_Dict(self, o): + return dict([(self.build(k), self.build(v)) + for k, v in zip(o.keys, o.values)]) + + def build_Tuple(self, o): + return tuple(self.build_List(o)) + + def build_Name(self, o): + name = o.id + if name == 'None': + return None + if name == 'True': + return True + if name == 'False': + return False + + # See if the Name is a package or module. If it is, import it. + try: + return modules(name) + except ImportError: + pass + + # See if the Name is in builtins. + try: + return getattr(builtins, name) + except AttributeError: + pass + + raise TypeError('unrepr could not resolve the name %s' % repr(name)) + + def build_NameConstant(self, o): + return o.value + + build_Constant = build_NameConstant # Python 3.8 change + + def build_UnaryOp(self, o): + op, operand = map(self.build, [o.op, o.operand]) + return op(operand) + + def build_BinOp(self, o): + left, op, right = map(self.build, [o.left, o.op, o.right]) + return op(left, right) + + def build_Add(self, o): + return operator.add + + def build_Mult(self, o): + return operator.mul + + def build_USub(self, o): + return operator.neg + + def build_Attribute(self, o): + parent = self.build(o.value) + return getattr(parent, o.attr) + + def build_NoneType(self, o): + return None + + +def unrepr(s): + """Return a Python object compiled from a string.""" + if not s: + return s + b = _Builder() + obj = b.astnode(s) + return b.build(obj) + + +def modules(modulePath): + """Load a module and retrieve a reference to that module.""" + __import__(modulePath) + return sys.modules[modulePath] + + +def attributes(full_attribute_name): + """Load a module and retrieve an attribute of that module.""" + + # Parse out the path, module, and attribute + last_dot = full_attribute_name.rfind('.') + attr_name = full_attribute_name[last_dot + 1:] + mod_path = full_attribute_name[:last_dot] + + mod = modules(mod_path) + # Let an AttributeError propagate outward. + try: + attr = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + # Return a reference to the attribute. + return attr diff --git a/resources/lib/cherrypy/lib/sessions.py b/resources/lib/cherrypy/lib/sessions.py new file mode 100644 index 0000000..5b3328f --- /dev/null +++ b/resources/lib/cherrypy/lib/sessions.py @@ -0,0 +1,910 @@ +"""Session implementation for CherryPy. + +You need to edit your config file to use sessions. 
Here's an example:: + + [/] + tools.sessions.on = True + tools.sessions.storage_class = cherrypy.lib.sessions.FileSession + tools.sessions.storage_path = "/home/site/sessions" + tools.sessions.timeout = 60 + +This sets the session to be stored in files in the directory +/home/site/sessions, and the session timeout to 60 minutes. If you omit +``storage_class``, the sessions will be saved in RAM. +``tools.sessions.on`` is the only required line for working sessions, +the rest are optional. + +By default, the session ID is passed in a cookie, so the client's browser must +have cookies enabled for your site. + +To set data for the current session, use +``cherrypy.session['fieldname'] = 'fieldvalue'``; +to get data use ``cherrypy.session.get('fieldname')``. + +================ +Locking sessions +================ + +By default, the ``'locking'`` mode of sessions is ``'implicit'``, which means +the session is locked early and unlocked late. Be mindful of this default mode +for any requests that take a long time to process (streaming responses, +expensive calculations, database lookups, API calls, etc), as other concurrent +requests that also utilize sessions will hang until the session is unlocked. + +If you want to control when the session data is locked and unlocked, +set ``tools.sessions.locking = 'explicit'``. Then call +``cherrypy.session.acquire_lock()`` and ``cherrypy.session.release_lock()``. +Regardless of which mode you use, the session is guaranteed to be unlocked when +the request is complete. + +================= +Expiring Sessions +================= + +You can force a session to expire with :func:`cherrypy.lib.sessions.expire`. +Simply call that function at the point you want the session to expire, and it +will cause the session cookie to expire client-side. + +=========================== +Session Fixation Protection +=========================== + +If CherryPy receives, via a request cookie, a session id that it does not +recognize, it will reject that id and create a new one to return in the +response cookie. This `helps prevent session fixation attacks +`_. +However, CherryPy "recognizes" a session id by looking up the saved session +data for that id. Therefore, if you never save any session data, +**you will get a new session id for every request**. + +A side effect of CherryPy overwriting unrecognised session ids is that if you +have multiple, separate CherryPy applications running on a single domain (e.g. +on different ports), each app will overwrite the other's session id because by +default they use the same cookie name (``"session_id"``) but do not recognise +each others sessions. It is therefore a good idea to use a different name for +each, for example:: + + [/] + ... + tools.sessions.name = "my_app_session_id" + +================ +Sharing Sessions +================ + +If you run multiple instances of CherryPy (for example via mod_python behind +Apache prefork), you most likely cannot use the RAM session backend, since each +instance of CherryPy will have its own memory space. Use a different backend +instead, and verify that all instances are pointing at the same file or db +location. Alternately, you might try a load balancer which makes sessions +"sticky". Google is your friend, there. + +================ +Expiration Dates +================ + +The response cookie will possess an expiration date to inform the client at +which point to stop sending the cookie back in requests. If the server time +and client time differ, expect sessions to be unreliable. 
**Make sure the +system time of your server is accurate**. + +CherryPy defaults to a 60-minute session timeout, which also applies to the +cookie which is sent to the client. Unfortunately, some versions of Safari +("4 public beta" on Windows XP at least) appear to have a bug in their parsing +of the GMT expiration date--they appear to interpret the date as one hour in +the past. Sixty minutes minus one hour is pretty close to zero, so you may +experience this bug as a new session id for every request, unless the requests +are less than one second apart. To fix, try increasing the session.timeout. + +On the other extreme, some users report Firefox sending cookies after their +expiration date, although this was on a system with an inaccurate system time. +Maybe FF doesn't trust system time. +""" +import sys +import datetime +import os +import time +import threading +import binascii +import pickle + +import zc.lockfile + +import cherrypy +from cherrypy.lib import httputil +from cherrypy.lib import locking +from cherrypy.lib import is_iterator + + +missing = object() + + +class Session(object): + + """A CherryPy dict-like Session object (one per request).""" + + _id = None + + id_observers = None + "A list of callbacks to which to pass new id's." + + @property + def id(self): + """Return the current session id.""" + return self._id + + @id.setter + def id(self, value): + self._id = value + for o in self.id_observers: + o(value) + + timeout = 60 + 'Number of minutes after which to delete session data.' + + locked = False + """ + If True, this session instance has exclusive read/write access + to session data.""" + + loaded = False + """ + If True, data has been retrieved from storage. This should happen + automatically on the first attempt to access session data.""" + + clean_thread = None + 'Class-level Monitor which calls self.clean_up.' + + clean_freq = 5 + 'The poll rate for expired session cleanup in minutes.' + + originalid = None + 'The session id passed by the client. May be missing or unsafe.' + + missing = False + 'True if the session requested by the client did not exist.' + + regenerated = False + """ + True if the application called session.regenerate(). This is not set by + internal calls to regenerate the session id.""" + + debug = False + 'If True, log debug information.' + + # --------------------- Session management methods --------------------- # + + def __init__(self, id=None, **kwargs): + self.id_observers = [] + self._data = {} + + for k, v in kwargs.items(): + setattr(self, k, v) + + self.originalid = id + self.missing = False + if id is None: + if self.debug: + cherrypy.log('No id given; making a new one', 'TOOLS.SESSIONS') + self._regenerate() + else: + self.id = id + if self._exists(): + if self.debug: + cherrypy.log('Set id to %s.' % id, 'TOOLS.SESSIONS') + else: + if self.debug: + cherrypy.log('Expired or malicious session %r; ' + 'making a new one' % id, 'TOOLS.SESSIONS') + # Expired or malicious session. Make a new one. + # See https://github.com/cherrypy/cherrypy/issues/709. + self.id = None + self.missing = True + self._regenerate() + + def now(self): + """Generate the session specific concept of 'now'. + + Other session providers can override this to use alternative, + possibly timezone aware, versions of 'now'. 
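+
+        A subclass might, for instance, return an aware UTC timestamp
+        (an illustrative override, not part of this module)::
+
+            def now(self):
+                return datetime.datetime.now(datetime.timezone.utc)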
+ """ + return datetime.datetime.now() + + def regenerate(self): + """Replace the current session (with a new id).""" + self.regenerated = True + self._regenerate() + + def _regenerate(self): + if self.id is not None: + if self.debug: + cherrypy.log( + 'Deleting the existing session %r before ' + 'regeneration.' % self.id, + 'TOOLS.SESSIONS') + self.delete() + + old_session_was_locked = self.locked + if old_session_was_locked: + self.release_lock() + if self.debug: + cherrypy.log('Old lock released.', 'TOOLS.SESSIONS') + + self.id = None + while self.id is None: + self.id = self.generate_id() + # Assert that the generated id is not already stored. + if self._exists(): + self.id = None + if self.debug: + cherrypy.log('Set id to generated %s.' % self.id, + 'TOOLS.SESSIONS') + + if old_session_was_locked: + self.acquire_lock() + if self.debug: + cherrypy.log('Regenerated lock acquired.', 'TOOLS.SESSIONS') + + def clean_up(self): + """Clean up expired sessions.""" + pass + + def generate_id(self): + """Return a new session id.""" + return binascii.hexlify(os.urandom(20)).decode('ascii') + + def save(self): + """Save session data.""" + try: + # If session data has never been loaded then it's never been + # accessed: no need to save it + if self.loaded: + t = datetime.timedelta(seconds=self.timeout * 60) + expiration_time = self.now() + t + if self.debug: + cherrypy.log('Saving session %r with expiry %s' % + (self.id, expiration_time), + 'TOOLS.SESSIONS') + self._save(expiration_time) + else: + if self.debug: + cherrypy.log( + 'Skipping save of session %r (no session loaded).' % + self.id, 'TOOLS.SESSIONS') + finally: + if self.locked: + # Always release the lock if the user didn't release it + self.release_lock() + if self.debug: + cherrypy.log('Lock released after save.', 'TOOLS.SESSIONS') + + def load(self): + """Copy stored session data into this session instance.""" + data = self._load() + # data is either None or a tuple (session_data, expiration_time) + if data is None or data[1] < self.now(): + if self.debug: + cherrypy.log('Expired session %r, flushing data.' % self.id, + 'TOOLS.SESSIONS') + self._data = {} + else: + if self.debug: + cherrypy.log('Data loaded for session %r.' % self.id, + 'TOOLS.SESSIONS') + self._data = data[0] + self.loaded = True + + # Stick the clean_thread in the class, not the instance. + # The instances are created and destroyed per-request. + cls = self.__class__ + if self.clean_freq and not cls.clean_thread: + # clean_up is an instancemethod and not a classmethod, + # so that tool config can be accessed inside the method. + t = cherrypy.process.plugins.Monitor( + cherrypy.engine, self.clean_up, self.clean_freq * 60, + name='Session cleanup') + t.subscribe() + cls.clean_thread = t + t.start() + if self.debug: + cherrypy.log('Started cleanup thread.', 'TOOLS.SESSIONS') + + def delete(self): + """Delete stored session data.""" + self._delete() + if self.debug: + cherrypy.log('Deleted session %s.' % self.id, + 'TOOLS.SESSIONS') + + # -------------------- Application accessor methods -------------------- # + + def __getitem__(self, key): + if not self.loaded: + self.load() + return self._data[key] + + def __setitem__(self, key, value): + if not self.loaded: + self.load() + self._data[key] = value + + def __delitem__(self, key): + if not self.loaded: + self.load() + del self._data[key] + + def pop(self, key, default=missing): + """Remove the specified key and return the corresponding value. 
+ If key is not found, default is returned if given, + otherwise KeyError is raised. + """ + if not self.loaded: + self.load() + if default is missing: + return self._data.pop(key) + else: + return self._data.pop(key, default) + + def __contains__(self, key): + if not self.loaded: + self.load() + return key in self._data + + def get(self, key, default=None): + """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.""" + if not self.loaded: + self.load() + return self._data.get(key, default) + + def update(self, d): + """D.update(E) -> None. Update D from E: for k in E: D[k] = E[k].""" + if not self.loaded: + self.load() + self._data.update(d) + + def setdefault(self, key, default=None): + """D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D.""" + if not self.loaded: + self.load() + return self._data.setdefault(key, default) + + def clear(self): + """D.clear() -> None. Remove all items from D.""" + if not self.loaded: + self.load() + self._data.clear() + + def keys(self): + """D.keys() -> list of D's keys.""" + if not self.loaded: + self.load() + return self._data.keys() + + def items(self): + """D.items() -> list of D's (key, value) pairs, as 2-tuples.""" + if not self.loaded: + self.load() + return self._data.items() + + def values(self): + """D.values() -> list of D's values.""" + if not self.loaded: + self.load() + return self._data.values() + + +class RamSession(Session): + + # Class-level objects. Don't rebind these! + cache = {} + locks = {} + + def clean_up(self): + """Clean up expired sessions.""" + + now = self.now() + for _id, (data, expiration_time) in self.cache.copy().items(): + if expiration_time <= now: + try: + del self.cache[_id] + except KeyError: + pass + try: + if self.locks[_id].acquire(blocking=False): + lock = self.locks.pop(_id) + lock.release() + except KeyError: + pass + + # added to remove obsolete lock objects + for _id in list(self.locks): + locked = ( + _id not in self.cache + and self.locks[_id].acquire(blocking=False) + ) + if locked: + lock = self.locks.pop(_id) + lock.release() + + def _exists(self): + return self.id in self.cache + + def _load(self): + return self.cache.get(self.id) + + def _save(self, expiration_time): + self.cache[self.id] = (self._data, expiration_time) + + def _delete(self): + self.cache.pop(self.id, None) + + def acquire_lock(self): + """Acquire an exclusive lock on the currently-loaded session data.""" + self.locked = True + self.locks.setdefault(self.id, threading.RLock()).acquire() + + def release_lock(self): + """Release the lock on the currently-loaded session data.""" + self.locks[self.id].release() + self.locked = False + + def __len__(self): + """Return the number of active sessions.""" + return len(self.cache) + + +class FileSession(Session): + + """Implementation of the File backend for sessions + + storage_path + The folder where session data will be saved. Each session + will be saved as pickle.dump(data, expiration_time) in its own file; + the filename will be self.SESSION_PREFIX + self.id. + + lock_timeout + A timedelta or numeric seconds indicating how long + to block acquiring a lock. If None (default), acquiring a lock + will block indefinitely. + """ + + SESSION_PREFIX = 'session-' + LOCK_SUFFIX = '.lock' + pickle_protocol = pickle.HIGHEST_PROTOCOL + + def __init__(self, id=None, **kwargs): + # The 'storage_path' arg is required for file-based sessions. 
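+        # Normalizing it to an absolute path up front keeps the
+        # containment check in _get_file_path() reliable even if the
+        # process later changes its working directory.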
+        kwargs['storage_path'] = os.path.abspath(kwargs['storage_path'])
+        kwargs.setdefault('lock_timeout', None)
+
+        Session.__init__(self, id=id, **kwargs)
+
+        # validate self.lock_timeout
+        if isinstance(self.lock_timeout, (int, float)):
+            self.lock_timeout = datetime.timedelta(seconds=self.lock_timeout)
+        if not isinstance(self.lock_timeout, (datetime.timedelta, type(None))):
+            raise ValueError(
+                'Lock timeout must be numeric seconds or a timedelta instance.'
+            )
+
+    @classmethod
+    def setup(cls, **kwargs):
+        """Set up the storage system for file-based sessions.
+
+        This should only be called once per process; this will be done
+        automatically when using sessions.init (as the built-in Tool does).
+        """
+        # The 'storage_path' arg is required for file-based sessions.
+        kwargs['storage_path'] = os.path.abspath(kwargs['storage_path'])
+
+        for k, v in kwargs.items():
+            setattr(cls, k, v)
+
+    def _get_file_path(self):
+        f = os.path.join(self.storage_path, self.SESSION_PREFIX + self.id)
+        if not os.path.abspath(f).startswith(self.storage_path):
+            raise cherrypy.HTTPError(400, 'Invalid session id in cookie.')
+        return f
+
+    def _exists(self):
+        path = self._get_file_path()
+        return os.path.exists(path)
+
+    def _load(self, path=None):
+        assert self.locked, ('The session was loaded without being locked. '
+                             "Check your tools' priority levels.")
+        if path is None:
+            path = self._get_file_path()
+        try:
+            f = open(path, 'rb')
+            try:
+                return pickle.load(f)
+            finally:
+                f.close()
+        except (IOError, EOFError):
+            e = sys.exc_info()[1]
+            if self.debug:
+                cherrypy.log('Error loading the session pickle: %s' %
+                             e, 'TOOLS.SESSIONS')
+            return None
+
+    def _save(self, expiration_time):
+        assert self.locked, ('The session was saved without being locked. '
+                             "Check your tools' priority levels.")
+        f = open(self._get_file_path(), 'wb')
+        try:
+            pickle.dump((self._data, expiration_time), f, self.pickle_protocol)
+        finally:
+            f.close()
+
+    def _delete(self):
+        assert self.locked, ('The session was deleted without being locked. '
+                             "Check your tools' priority levels.")
+        try:
+            os.unlink(self._get_file_path())
+        except OSError:
+            pass
+
+    def acquire_lock(self, path=None):
+        """Acquire an exclusive lock on the currently-loaded session data."""
+        if path is None:
+            path = self._get_file_path()
+        path += self.LOCK_SUFFIX
+        checker = locking.LockChecker(self.id, self.lock_timeout)
+        while not checker.expired():
+            try:
+                self.lock = zc.lockfile.LockFile(path)
+            except zc.lockfile.LockError:
+                time.sleep(0.1)
+            else:
+                break
+        self.locked = True
+        if self.debug:
+            cherrypy.log('Lock acquired.', 'TOOLS.SESSIONS')
+
+    def release_lock(self, path=None):
+        """Release the lock on the currently-loaded session data."""
+        self.lock.close()
+        self.locked = False
+
+    def clean_up(self):
+        """Clean up expired sessions."""
+        now = self.now()
+        # Iterate over all session files in self.storage_path
+        for fname in os.listdir(self.storage_path):
+            have_session = (
+                fname.startswith(self.SESSION_PREFIX)
+                and not fname.endswith(self.LOCK_SUFFIX)
+            )
+            if have_session:
+                # We have a session file: lock and load it and check
+                # if it's expired. If it fails, nevermind.
+                path = os.path.join(self.storage_path, fname)
+                self.acquire_lock(path)
+                if self.debug:
+                    # This is a bit of a hack, since we're calling clean_up
+                    # on the first instance rather than the entire class,
+                    # so depending on whether you have "debug" set on the
+                    # path of the first session called, this may not run.
+ cherrypy.log('Cleanup lock acquired.', 'TOOLS.SESSIONS') + + try: + contents = self._load(path) + # _load returns None on IOError + if contents is not None: + data, expiration_time = contents + if expiration_time < now: + # Session expired: deleting it + os.unlink(path) + finally: + self.release_lock(path) + + def __len__(self): + """Return the number of active sessions.""" + return len([fname for fname in os.listdir(self.storage_path) + if (fname.startswith(self.SESSION_PREFIX) and + not fname.endswith(self.LOCK_SUFFIX))]) + + +class MemcachedSession(Session): + + # The most popular memcached client for Python isn't thread-safe. + # Wrap all .get and .set operations in a single lock. + mc_lock = threading.RLock() + + # This is a separate set of locks per session id. + locks = {} + + servers = ['localhost:11211'] + + @classmethod + def setup(cls, **kwargs): + """Set up the storage system for memcached-based sessions. + + This should only be called once per process; this will be done + automatically when using sessions.init (as the built-in Tool does). + """ + for k, v in kwargs.items(): + setattr(cls, k, v) + + import memcache + cls.cache = memcache.Client(cls.servers) + + def _exists(self): + self.mc_lock.acquire() + try: + return bool(self.cache.get(self.id)) + finally: + self.mc_lock.release() + + def _load(self): + self.mc_lock.acquire() + try: + return self.cache.get(self.id) + finally: + self.mc_lock.release() + + def _save(self, expiration_time): + # Send the expiration time as "Unix time" (seconds since 1/1/1970) + td = int(time.mktime(expiration_time.timetuple())) + self.mc_lock.acquire() + try: + if not self.cache.set(self.id, (self._data, expiration_time), td): + raise AssertionError( + 'Session data for id %r not set.' % self.id) + finally: + self.mc_lock.release() + + def _delete(self): + self.cache.delete(self.id) + + def acquire_lock(self): + """Acquire an exclusive lock on the currently-loaded session data.""" + self.locked = True + self.locks.setdefault(self.id, threading.RLock()).acquire() + if self.debug: + cherrypy.log('Lock acquired.', 'TOOLS.SESSIONS') + + def release_lock(self): + """Release the lock on the currently-loaded session data.""" + self.locks[self.id].release() + self.locked = False + + def __len__(self): + """Return the number of active sessions.""" + raise NotImplementedError + + +# Hook functions (for CherryPy tools) + +def save(): + """Save any changed session data.""" + + if not hasattr(cherrypy.serving, 'session'): + return + request = cherrypy.serving.request + response = cherrypy.serving.response + + # Guard against running twice + if hasattr(request, '_sessionsaved'): + return + request._sessionsaved = True + + if response.stream: + # If the body is being streamed, we have to save the data + # *after* the response has been written out + request.hooks.attach('on_end_request', cherrypy.session.save) + else: + # If the body is not being streamed, we save the data now + # (so we can release the lock). 
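+        # An iterator body is collapsed first so the response can
+        # still be written out after the lock is released.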
+        if is_iterator(response.body):
+            response.collapse_body()
+        cherrypy.session.save()
+
+
+save.failsafe = True
+
+
+def close():
+    """Close the session object for this request."""
+    sess = getattr(cherrypy.serving, 'session', None)
+    if getattr(sess, 'locked', False):
+        # If the session is still locked we release the lock
+        sess.release_lock()
+        if sess.debug:
+            cherrypy.log('Lock released on close.', 'TOOLS.SESSIONS')
+
+
+close.failsafe = True
+close.priority = 90
+
+
+def init(storage_type=None, path=None, path_header=None, name='session_id',
+         timeout=60, domain=None, secure=False, clean_freq=5,
+         persistent=True, httponly=False, debug=False,
+         # Py27 compat
+         # *, storage_class=RamSession,
+         **kwargs):
+    """Initialize session object (using cookies).
+
+    storage_class
+        The Session subclass to use. Defaults to RamSession.
+
+    storage_type
+        (deprecated)
+        One of 'ram', 'file', 'memcached'. This will be
+        used to look up the corresponding class in cherrypy.lib.sessions
+        globals. For example, 'file' will use the FileSession class.
+
+    path
+        The 'path' value to stick in the response cookie metadata.
+
+    path_header
+        If 'path' is None (the default), then the response
+        cookie 'path' will be pulled from request.headers[path_header].
+
+    name
+        The name of the cookie.
+
+    timeout
+        The expiration timeout (in minutes) for the stored session data.
+        If 'persistent' is True (the default), this is also the timeout
+        for the cookie.
+
+    domain
+        The cookie domain.
+
+    secure
+        If False (the default) the cookie 'secure' value will not
+        be set. If True, the cookie 'secure' value will be set (to 1).
+
+    clean_freq (minutes)
+        The poll rate for expired session cleanup.
+
+    persistent
+        If True (the default), the 'timeout' argument will be used
+        to expire the cookie. If False, the cookie will not have an expiry,
+        and the cookie will be a "session cookie" which expires when the
+        browser is closed.
+
+    httponly
+        If False (the default) the cookie 'httponly' value will not be set.
+        If True, the cookie 'httponly' value will be set (to 1).
+
+    Any additional kwargs will be bound to the new Session instance,
+    and may be specific to the storage type. See the subclass of Session
+    you're using for more information.
+    """
+
+    # Py27 compat
+    storage_class = kwargs.pop('storage_class', RamSession)
+
+    request = cherrypy.serving.request
+
+    # Guard against running twice
+    if hasattr(request, '_session_init_flag'):
+        return
+    request._session_init_flag = True
+
+    # Check if request came with a session ID
+    id = None
+    if name in request.cookie:
+        id = request.cookie[name].value
+        if debug:
+            cherrypy.log('ID obtained from request.cookie: %r' % id,
+                         'TOOLS.SESSIONS')
+
+    first_time = not hasattr(cherrypy, 'session')
+
+    if storage_type:
+        if first_time:
+            msg = 'storage_type is deprecated. Supply storage_class instead'
+            cherrypy.log(msg)
+        storage_class = storage_type.title() + 'Session'
+        storage_class = globals()[storage_class]
+
+    # call setup first time only
+    if first_time:
+        if hasattr(storage_class, 'setup'):
+            storage_class.setup(**kwargs)
+
+    # Create and attach a new Session instance to cherrypy.serving.
+    # It will possess a reference to (and lock, and lazily load)
+    # the requested session data.
+ kwargs['timeout'] = timeout + kwargs['clean_freq'] = clean_freq + cherrypy.serving.session = sess = storage_class(id, **kwargs) + sess.debug = debug + + def update_cookie(id): + """Update the cookie every time the session id changes.""" + cherrypy.serving.response.cookie[name] = id + sess.id_observers.append(update_cookie) + + # Create cherrypy.session which will proxy to cherrypy.serving.session + if not hasattr(cherrypy, 'session'): + cherrypy.session = cherrypy._ThreadLocalProxy('session') + + if persistent: + cookie_timeout = timeout + else: + # See http://support.microsoft.com/kb/223799/EN-US/ + # and http://support.mozilla.com/en-US/kb/Cookies + cookie_timeout = None + set_response_cookie(path=path, path_header=path_header, name=name, + timeout=cookie_timeout, domain=domain, secure=secure, + httponly=httponly) + + +def set_response_cookie(path=None, path_header=None, name='session_id', + timeout=60, domain=None, secure=False, httponly=False): + """Set a response cookie for the client. + + path + the 'path' value to stick in the response cookie metadata. + + path_header + if 'path' is None (the default), then the response + cookie 'path' will be pulled from request.headers[path_header]. + + name + the name of the cookie. + + timeout + the expiration timeout for the cookie. If 0 or other boolean + False, no 'expires' param will be set, and the cookie will be a + "session cookie" which expires when the browser is closed. + + domain + the cookie domain. + + secure + if False (the default) the cookie 'secure' value will not + be set. If True, the cookie 'secure' value will be set (to 1). + + httponly + If False (the default) the cookie 'httponly' value will not be set. + If True, the cookie 'httponly' value will be set (to 1). + + """ + # Set response cookie + cookie = cherrypy.serving.response.cookie + cookie[name] = cherrypy.serving.session.id + cookie[name]['path'] = ( + path or + cherrypy.serving.request.headers.get(path_header) or + '/' + ) + + if timeout: + cookie[name]['max-age'] = timeout * 60 + _add_MSIE_max_age_workaround(cookie[name], timeout) + if domain is not None: + cookie[name]['domain'] = domain + if secure: + cookie[name]['secure'] = 1 + if httponly: + if not cookie[name].isReservedKey('httponly'): + raise ValueError('The httponly cookie token is not supported.') + cookie[name]['httponly'] = 1 + + +def _add_MSIE_max_age_workaround(cookie, timeout): + """ + We'd like to use the "max-age" param as indicated in + http://www.faqs.org/rfcs/rfc2109.html but IE doesn't + save it to disk and the session is lost if people close + the browser. So we have to use the old "expires" ... sigh ... 
+    """
+    expires = time.time() + timeout * 60
+    cookie['expires'] = httputil.HTTPDate(expires)
+
+
+def expire():
+    """Expire the current session cookie."""
+    name = cherrypy.serving.request.config.get(
+        'tools.sessions.name', 'session_id')
+    one_year = 60 * 60 * 24 * 365
+    e = time.time() - one_year
+    cherrypy.serving.response.cookie[name]['expires'] = httputil.HTTPDate(e)
+    cherrypy.serving.response.cookie[name].pop('max-age', None)
diff --git a/resources/lib/cherrypy/lib/static.py b/resources/lib/cherrypy/lib/static.py
new file mode 100644
index 0000000..66a5a94
--- /dev/null
+++ b/resources/lib/cherrypy/lib/static.py
@@ -0,0 +1,416 @@
+"""Module with helpers for serving static files."""
+
+import os
+import platform
+import re
+import stat
+import mimetypes
+import urllib.parse
+import unicodedata
+
+from email.generator import _make_boundary as make_boundary
+from io import UnsupportedOperation
+
+import cherrypy
+from cherrypy._cpcompat import ntob
+from cherrypy.lib import cptools, httputil, file_generator_limited
+
+
+def _setup_mimetypes():
+    """Pre-initialize global mimetype map."""
+    if not mimetypes.inited:
+        mimetypes.init()
+    mimetypes.types_map['.dwg'] = 'image/x-dwg'
+    mimetypes.types_map['.ico'] = 'image/x-icon'
+    mimetypes.types_map['.bz2'] = 'application/x-bzip2'
+    mimetypes.types_map['.gz'] = 'application/x-gzip'
+
+
+_setup_mimetypes()
+
+
+def _make_content_disposition(disposition, file_name):
+    """Create HTTP header for downloading a file with a UTF-8 filename.
+
+    This function implements the recommendations of :rfc:`6266#appendix-D`.
+    See this and related answers: https://stackoverflow.com/a/8996249/2173868.
+    """
+    # The normalization algorithm used with `unicodedata` is the composed
+    # form with compatibility equivalence, i.e. "NFKC": it first applies the
+    # compatibility decomposition, followed by the canonical composition.
+    # Strings equivalent under NFKC should be displayed in the same manner,
+    # should be treated in the same way by applications such as alphabetizing
+    # names or searching, and may be substituted for each other.
+    # See: https://en.wikipedia.org/wiki/Unicode_equivalence.
+    ascii_name = (
+        unicodedata.normalize('NFKC', file_name).
+        encode('ascii', errors='ignore').decode()
+    )
+    header = '{}; filename="{}"'.format(disposition, ascii_name)
+    if ascii_name != file_name:
+        quoted_name = urllib.parse.quote(file_name)
+        header += '; filename*=UTF-8\'\'{}'.format(quoted_name)
+    return header
+
+
+def serve_file(path, content_type=None, disposition=None, name=None,
+               debug=False):
+    """Set status, headers, and body in order to serve the given path.
+
+    The Content-Type header will be set to the content_type arg, if provided.
+    If not provided, the Content-Type will be guessed by the file extension
+    of the 'path' argument.
+
+    If disposition is not None, the Content-Disposition header will be set
+    to "<disposition>; filename=<name>; filename*=utf-8''<name>"
+    as described in :rfc:`6266#appendix-D`.
+    If name is None, it will be set to the basename of path.
+    If disposition is None, no Content-Disposition header will be written.
+    """
+    response = cherrypy.serving.response
+
+    # If path is relative, users should fix it by making path absolute.
+    # That is, CherryPy should not guess where the application root is.
+    # It certainly should *not* use cwd (since CP may be invoked from a
+    # variety of paths). If using tools.staticdir, you can make your relative
+    # paths become absolute by supplying a value for "tools.staticdir.root".
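+    # An illustrative config (the paths are hypothetical):
+    #
+    #     [/static]
+    #     tools.staticdir.root = "/var/www/myapp"
+    #     tools.staticdir.dir = "static"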
+    if not os.path.isabs(path):
+        msg = "'%s' is not an absolute path." % path
+        if debug:
+            cherrypy.log(msg, 'TOOLS.STATICFILE')
+        raise ValueError(msg)
+
+    try:
+        st = os.stat(path)
+    except (OSError, TypeError, ValueError):
+        # OSError when file fails to stat
+        # TypeError on Python 2 when there's a null byte
+        # ValueError on Python 3 when there's a null byte
+        if debug:
+            cherrypy.log('os.stat(%r) failed' % path, 'TOOLS.STATIC')
+        raise cherrypy.NotFound()
+
+    # Check if path is a directory.
+    if stat.S_ISDIR(st.st_mode):
+        # Let the caller deal with it as they like.
+        if debug:
+            cherrypy.log('%r is a directory' % path, 'TOOLS.STATIC')
+        raise cherrypy.NotFound()
+
+    # Set the Last-Modified response header, so that
+    # modified-since validation code can work.
+    response.headers['Last-Modified'] = httputil.HTTPDate(st.st_mtime)
+    cptools.validate_since()
+
+    if content_type is None:
+        # Set content-type based on filename extension
+        ext = ''
+        i = path.rfind('.')
+        if i != -1:
+            ext = path[i:].lower()
+        content_type = mimetypes.types_map.get(ext, None)
+    if content_type is not None:
+        response.headers['Content-Type'] = content_type
+    if debug:
+        cherrypy.log('Content-Type: %r' % content_type, 'TOOLS.STATIC')
+
+    cd = None
+    if disposition is not None:
+        if name is None:
+            name = os.path.basename(path)
+        cd = _make_content_disposition(disposition, name)
+        response.headers['Content-Disposition'] = cd
+    if debug:
+        cherrypy.log('Content-Disposition: %r' % cd, 'TOOLS.STATIC')
+
+    # Set Content-Length and use an iterable (file object)
+    # this way CP won't load the whole file in memory
+    content_length = st.st_size
+    fileobj = open(path, 'rb')
+    return _serve_fileobj(fileobj, content_type, content_length, debug=debug)
+
+
+def serve_fileobj(fileobj, content_type=None, disposition=None, name=None,
+                  debug=False):
+    """Set status, headers, and body in order to serve the given file object.
+
+    The Content-Type header will be set to the content_type arg, if provided.
+
+    If disposition is not None, the Content-Disposition header will be set
+    to "<disposition>; filename=<name>; filename*=utf-8''<name>"
+    as described in :rfc:`6266#appendix-D`.
+    If name is None, 'filename' will not be set.
+    If disposition is None, no Content-Disposition header will be written.
+
+    CAUTION: If the request contains a 'Range' header, one or more seek()s will
+    be performed on the file object. This may cause undesired behavior if
+    the file object is not seekable. It could also produce undesired results
+    if the caller set the read position of the file object prior to calling
+    serve_fileobj(), expecting that the data would be served starting from that
+    position.
+    """
+    response = cherrypy.serving.response
+
+    try:
+        st = os.fstat(fileobj.fileno())
+    except AttributeError:
+        if debug:
+            cherrypy.log('os has no fstat attribute', 'TOOLS.STATIC')
+        content_length = None
+    except UnsupportedOperation:
+        content_length = None
+    else:
+        # Set the Last-Modified response header, so that
+        # modified-since validation code can work.
+ response.headers['Last-Modified'] = httputil.HTTPDate(st.st_mtime) + cptools.validate_since() + content_length = st.st_size + + if content_type is not None: + response.headers['Content-Type'] = content_type + if debug: + cherrypy.log('Content-Type: %r' % content_type, 'TOOLS.STATIC') + + cd = None + if disposition is not None: + if name is None: + cd = disposition + else: + cd = _make_content_disposition(disposition, name) + response.headers['Content-Disposition'] = cd + if debug: + cherrypy.log('Content-Disposition: %r' % cd, 'TOOLS.STATIC') + + return _serve_fileobj(fileobj, content_type, content_length, debug=debug) + + +def _serve_fileobj(fileobj, content_type, content_length, debug=False): + """Internal. Set response.body to the given file object, perhaps ranged.""" + response = cherrypy.serving.response + + # HTTP/1.0 didn't have Range/Accept-Ranges headers, or the 206 code + request = cherrypy.serving.request + if request.protocol >= (1, 1): + response.headers['Accept-Ranges'] = 'bytes' + r = httputil.get_ranges(request.headers.get('Range'), content_length) + if r == []: + response.headers['Content-Range'] = 'bytes */%s' % content_length + message = ('Invalid Range (first-byte-pos greater than ' + 'Content-Length)') + if debug: + cherrypy.log(message, 'TOOLS.STATIC') + raise cherrypy.HTTPError(416, message) + + if r: + if len(r) == 1: + # Return a single-part response. + start, stop = r[0] + if stop > content_length: + stop = content_length + r_len = stop - start + if debug: + cherrypy.log( + 'Single part; start: %r, stop: %r' % (start, stop), + 'TOOLS.STATIC') + response.status = '206 Partial Content' + response.headers['Content-Range'] = ( + 'bytes %s-%s/%s' % (start, stop - 1, content_length)) + response.headers['Content-Length'] = r_len + fileobj.seek(start) + response.body = file_generator_limited(fileobj, r_len) + else: + # Return a multipart/byteranges response. + response.status = '206 Partial Content' + boundary = make_boundary() + ct = 'multipart/byteranges; boundary=%s' % boundary + response.headers['Content-Type'] = ct + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers['Content-Length'] + + def file_ranges(): + # Apache compatibility: + yield b'\r\n' + + for start, stop in r: + if debug: + cherrypy.log( + 'Multipart; start: %r, stop: %r' % ( + start, stop), + 'TOOLS.STATIC') + yield ntob('--' + boundary, 'ascii') + yield ntob('\r\nContent-type: %s' % content_type, + 'ascii') + yield ntob( + '\r\nContent-range: bytes %s-%s/%s\r\n\r\n' % ( + start, stop - 1, content_length), + 'ascii') + fileobj.seek(start) + gen = file_generator_limited(fileobj, stop - start) + for chunk in gen: + yield chunk + yield b'\r\n' + # Final boundary + yield ntob('--' + boundary + '--', 'ascii') + + # Apache compatibility: + yield b'\r\n' + response.body = file_ranges() + return response.body + else: + if debug: + cherrypy.log('No byteranges requested', 'TOOLS.STATIC') + + # Set Content-Length and use an iterable (file object) + # this way CP won't load the whole file in memory + response.headers['Content-Length'] = content_length + response.body = fileobj + return response.body + + +def serve_download(path, name=None): + """Serve 'path' as an application/x-download attachment.""" + # This is such a common idiom I felt it deserved its own wrapper. 
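+    # Illustrative call from a handler (the path and name are hypothetical):
+    #     return serve_download('/srv/files/report.pdf', name='report.pdf')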
+ return serve_file(path, 'application/x-download', 'attachment', name) + + +def _attempt(filename, content_types, debug=False): + if debug: + cherrypy.log('Attempting %r (content_types %r)' % + (filename, content_types), 'TOOLS.STATICDIR') + try: + # you can set the content types for a + # complete directory per extension + content_type = None + if content_types: + r, ext = os.path.splitext(filename) + content_type = content_types.get(ext[1:], None) + serve_file(filename, content_type=content_type, debug=debug) + return True + except cherrypy.NotFound: + # If we didn't find the static file, continue handling the + # request. We might find a dynamic handler instead. + if debug: + cherrypy.log('NotFound', 'TOOLS.STATICFILE') + return False + + +def staticdir(section, dir, root='', match='', content_types=None, index='', + debug=False): + """Serve a static resource from the given (root +) dir. + + match + If given, request.path_info will be searched for the given + regular expression before attempting to serve static content. + + content_types + If given, it should be a Python dictionary of + {file-extension: content-type} pairs, where 'file-extension' is + a string (e.g. "gif") and 'content-type' is the value to write + out in the Content-Type response header (e.g. "image/gif"). + + index + If provided, it should be the (relative) name of a file to + serve for directory requests. For example, if the dir argument is + '/home/me', the Request-URI is 'myapp', and the index arg is + 'index.html', the file '/home/me/myapp/index.html' will be sought. + """ + request = cherrypy.serving.request + if request.method not in ('GET', 'HEAD'): + if debug: + cherrypy.log('request.method not GET or HEAD', 'TOOLS.STATICDIR') + return False + + if match and not re.search(match, request.path_info): + if debug: + cherrypy.log('request.path_info %r does not match pattern %r' % + (request.path_info, match), 'TOOLS.STATICDIR') + return False + + # Allow the use of '~' to refer to a user's home directory. + dir = os.path.expanduser(dir) + + # If dir is relative, make absolute using "root". + if not os.path.isabs(dir): + if not root: + msg = 'Static dir requires an absolute dir (or root).' + if debug: + cherrypy.log(msg, 'TOOLS.STATICDIR') + raise ValueError(msg) + dir = os.path.join(root, dir) + + # Determine where we are in the object tree relative to 'section' + # (where the static tool was defined). + if section == 'global': + section = '/' + section = section.rstrip(r'\/') + branch = request.path_info[len(section) + 1:] + branch = urllib.parse.unquote(branch.lstrip(r'\/')) + + # Requesting a file in sub-dir of the staticdir results + # in mixing of delimiter styles, e.g. C:\static\js/script.js. + # Windows accepts this form except not when the path is + # supplied in extended-path notation, e.g. \\?\C:\static\js/script.js. + # http://bit.ly/1vdioCX + if platform.system() == 'Windows': + branch = branch.replace('/', '\\') + + # If branch is "", filename will end in a slash + filename = os.path.join(dir, branch) + if debug: + cherrypy.log('Checking file %r to fulfill %r' % + (filename, request.path_info), 'TOOLS.STATICDIR') + + # There's a chance that the branch pulled from the URL might + # have ".." or similar uplevel attacks in it. Check that the final + # filename is a child of dir. 
+ if not os.path.normpath(filename).startswith(os.path.normpath(dir)): + raise cherrypy.HTTPError(403) # Forbidden + + handled = _attempt(filename, content_types) + if not handled: + # Check for an index file if a folder was requested. + if index: + handled = _attempt(os.path.join(filename, index), content_types) + if handled: + request.is_index = filename[-1] in (r'\/') + return handled + + +def staticfile(filename, root=None, match='', content_types=None, debug=False): + """Serve a static resource from the given (root +) filename. + + match + If given, request.path_info will be searched for the given + regular expression before attempting to serve static content. + + content_types + If given, it should be a Python dictionary of + {file-extension: content-type} pairs, where 'file-extension' is + a string (e.g. "gif") and 'content-type' is the value to write + out in the Content-Type response header (e.g. "image/gif"). + + """ + request = cherrypy.serving.request + if request.method not in ('GET', 'HEAD'): + if debug: + cherrypy.log('request.method not GET or HEAD', 'TOOLS.STATICFILE') + return False + + if match and not re.search(match, request.path_info): + if debug: + cherrypy.log('request.path_info %r does not match pattern %r' % + (request.path_info, match), 'TOOLS.STATICFILE') + return False + + # If filename is relative, make absolute using "root". + if not os.path.isabs(filename): + if not root: + msg = "Static tool requires an absolute filename (got '%s')." % ( + filename,) + if debug: + cherrypy.log(msg, 'TOOLS.STATICFILE') + raise ValueError(msg) + filename = os.path.join(root, filename) + + return _attempt(filename, content_types, debug=debug) diff --git a/resources/lib/cherrypy/lib/xmlrpcutil.py b/resources/lib/cherrypy/lib/xmlrpcutil.py new file mode 100644 index 0000000..29d9c4a --- /dev/null +++ b/resources/lib/cherrypy/lib/xmlrpcutil.py @@ -0,0 +1,60 @@ +"""XML-RPC tool helpers.""" +import sys +from xmlrpc.client import ( + loads as xmlrpc_loads, dumps as xmlrpc_dumps, + Fault as XMLRPCFault +) + +import cherrypy +from cherrypy._cpcompat import ntob + + +def process_body(): + """Return (params, method) from request body.""" + try: + return xmlrpc_loads(cherrypy.request.body.read()) + except Exception: + return ('ERROR PARAMS', ), 'ERRORMETHOD' + + +def patched_path(path): + """Return 'path', doctored for RPC.""" + if not path.endswith('/'): + path += '/' + if path.startswith('/RPC2/'): + # strip the first /rpc2 + path = path[5:] + return path + + +def _set_response(body): + """Set up HTTP status, headers and body within CherryPy.""" + # The XML-RPC spec (http://www.xmlrpc.com/spec) says: + # "Unless there's a lower-level error, always return 200 OK." + # Since Python's xmlrpc_client interprets a non-200 response + # as a "Protocol Error", we'll just return 200 every time. 
+ response = cherrypy.response + response.status = '200 OK' + response.body = ntob(body, 'utf-8') + response.headers['Content-Type'] = 'text/xml' + response.headers['Content-Length'] = len(body) + + +def respond(body, encoding='utf-8', allow_none=0): + """Construct HTTP response body.""" + if not isinstance(body, XMLRPCFault): + body = (body,) + + _set_response( + xmlrpc_dumps( + body, methodresponse=1, + encoding=encoding, + allow_none=allow_none + ) + ) + + +def on_error(*args, **kwargs): + """Construct HTTP response body for an error response.""" + body = str(sys.exc_info()[1]) + _set_response(xmlrpc_dumps(XMLRPCFault(1, body))) diff --git a/resources/lib/cherrypy/process/__init__.py b/resources/lib/cherrypy/process/__init__.py new file mode 100644 index 0000000..f242d22 --- /dev/null +++ b/resources/lib/cherrypy/process/__init__.py @@ -0,0 +1,17 @@ +"""Site container for an HTTP server. + +A Web Site Process Bus object is used to connect applications, servers, +and frameworks with site-wide services such as daemonization, process +reload, signal handling, drop privileges, PID file management, logging +for all of these, and many more. + +The 'plugins' module defines a few abstract and concrete services for +use with the bus. Some use tool-specific channels; see the documentation +for each class. +""" + +from .wspbus import bus +from . import plugins, servers + + +__all__ = ('bus', 'plugins', 'servers') diff --git a/resources/lib/cherrypy/process/plugins.py b/resources/lib/cherrypy/process/plugins.py new file mode 100644 index 0000000..d2f87a4 --- /dev/null +++ b/resources/lib/cherrypy/process/plugins.py @@ -0,0 +1,754 @@ +"""Site services for use with a Web Site Process Bus.""" + +import os +import re +import signal as _signal +import sys +import time +import threading +import _thread + +from cherrypy._cpcompat import text_or_bytes +from cherrypy._cpcompat import ntob + +# _module__file__base is used by Autoreload to make +# absolute any filenames retrieved from sys.modules which are not +# already absolute paths. This is to work around Python's quirk +# of importing the startup script and using a relative filename +# for it in sys.modules. +# +# Autoreload examines sys.modules afresh every time it runs. If an application +# changes the current directory by executing os.chdir(), then the next time +# Autoreload runs, it will not be able to find any filenames which are +# not absolute paths, because the current directory is not the same as when the +# module was first imported. Autoreload will then wrongly conclude the file +# has "changed", and initiate the shutdown/re-exec sequence. +# See ticket #917. +# For this workaround to have a decent probability of success, this module +# needs to be imported as early as possible, before the app has much chance +# to change the working directory. +_module__file__base = os.getcwd() + + +class SimplePlugin(object): + + """Plugin base class which auto-subscribes methods for known channels.""" + + bus = None + """A :class:`Bus `, usually cherrypy.engine. + """ + + def __init__(self, bus): + self.bus = bus + + def subscribe(self): + """Register this object as a (multi-channel) listener on the bus.""" + for channel in self.bus.listeners: + # Subscribe self.start, self.exit, etc. if present. 
+            method = getattr(self, channel, None)
+            if method is not None:
+                self.bus.subscribe(channel, method)
+
+    def unsubscribe(self):
+        """Unregister this object as a listener on the bus."""
+        for channel in self.bus.listeners:
+            # Unsubscribe self.start, self.exit, etc. if present.
+            method = getattr(self, channel, None)
+            if method is not None:
+                self.bus.unsubscribe(channel, method)
+
+
+class SignalHandler(object):
+
+    """Register bus channels (and listeners) for system signals.
+
+    You can modify what signals your application listens for, and what it does
+    when it receives signals, by modifying :attr:`SignalHandler.handlers`,
+    a dict of {signal name: callback} pairs. The default set is::
+
+        handlers = {'SIGTERM': self.bus.exit,
+                    'SIGHUP': self.handle_SIGHUP,
+                    'SIGUSR1': self.bus.graceful,
+                   }
+
+    The :func:`SignalHandler.handle_SIGHUP` method calls :func:`bus.restart()`
+    if the process is daemonized, but :func:`bus.exit()` if the process is
+    attached to a TTY. This is because Unix window
+    managers tend to send SIGHUP to terminal windows when the user closes them.
+
+    Feel free to add signals which are not available on every platform.
+    The :class:`SignalHandler` will ignore errors raised from attempting
+    to register handlers for unknown signals.
+    """
+
+    handlers = {}
+    """A map from signal names (e.g. 'SIGTERM') to handlers (e.g. bus.exit)."""
+
+    signals = {}
+    """A map from signal numbers to names."""
+
+    for k, v in vars(_signal).items():
+        if k.startswith('SIG') and not k.startswith('SIG_'):
+            signals[v] = k
+    del k, v
+
+    def __init__(self, bus):
+        self.bus = bus
+        # Set default handlers
+        self.handlers = {'SIGTERM': self.bus.exit,
+                         'SIGHUP': self.handle_SIGHUP,
+                         'SIGUSR1': self.bus.graceful,
+                         }
+
+        if sys.platform[:4] == 'java':
+            del self.handlers['SIGUSR1']
+            self.handlers['SIGUSR2'] = self.bus.graceful
+            self.bus.log('SIGUSR1 cannot be set on the JVM platform. '
+                         'Using SIGUSR2 instead.')
+            self.handlers['SIGINT'] = self._jython_SIGINT_handler
+
+        self._previous_handlers = {}
+        # used to determine if the process is a daemon in `self._is_daemonized`
+        self._original_pid = os.getpid()
+
+    def _jython_SIGINT_handler(self, signum=None, frame=None):
+        # See http://bugs.jython.org/issue1313
+        self.bus.log('Keyboard Interrupt: shutting down bus')
+        self.bus.exit()
+
+    def _is_daemonized(self):
+        """Return boolean indicating if the current process is
+        running as a daemon.
+
+        The criterion for the `daemon` condition is that the current pid
+        is not the same as the one used when the plugin was constructed
+        *and* stdin is not connected to a terminal.
+
+        Validating the tty alone is not enough when the plugin is
+        executing inside another process, such as a CI tool
+        (Buildbot, Jenkins).
+        """
+        return (
+            self._original_pid != os.getpid() and
+            not os.isatty(sys.stdin.fileno())
+        )
+
+    def subscribe(self):
+        """Subscribe self.handlers to signals."""
+        for sig, func in self.handlers.items():
+            try:
+                self.set_handler(sig, func)
+            except ValueError:
+                pass
+
+    def unsubscribe(self):
+        """Unsubscribe self.handlers from signals."""
+        for signum, handler in self._previous_handlers.items():
+            signame = self.signals[signum]
+
+            if handler is None:
+                self.bus.log('Restoring %s handler to SIG_DFL.' % signame)
+                handler = _signal.SIG_DFL
+            else:
+                self.bus.log('Restoring %s handler %r.' %
+                             (signame, handler))
+
+            try:
+                our_handler = _signal.signal(signum, handler)
+                if our_handler is None:
+                    self.bus.log('Restored old %s handler %r, but our '
+                                 'handler was not registered.' %
+                                 (signame, handler), level=30)
+            except ValueError:
+                self.bus.log('Unable to restore %s handler %r.' %
+                             (signame, handler), level=40, traceback=True)
+
+    def set_handler(self, signal, listener=None):
+        """Subscribe a handler for the given signal (number or name).
+
+        If the optional 'listener' argument is provided, it will be
+        subscribed as a listener for the given signal's channel.
+
+        If the given signal name or number is not available on the current
+        platform, ValueError is raised.
+        """
+        if isinstance(signal, text_or_bytes):
+            signum = getattr(_signal, signal, None)
+            if signum is None:
+                raise ValueError('No such signal: %r' % signal)
+            signame = signal
+        else:
+            try:
+                signame = self.signals[signal]
+            except KeyError:
+                raise ValueError('No such signal: %r' % signal)
+            signum = signal
+
+        prev = _signal.signal(signum, self._handle_signal)
+        self._previous_handlers[signum] = prev
+
+        if listener is not None:
+            self.bus.log('Listening for %s.' % signame)
+            self.bus.subscribe(signame, listener)
+
+    def _handle_signal(self, signum=None, frame=None):
+        """Python signal handler (self.set_handler subscribes it for you)."""
+        signame = self.signals[signum]
+        self.bus.log('Caught signal %s.' % signame)
+        self.bus.publish(signame)
+
+    def handle_SIGHUP(self):
+        """Restart if daemonized, else exit."""
+        if self._is_daemonized():
+            self.bus.log('SIGHUP caught while daemonized. Restarting.')
+            self.bus.restart()
+        else:
+            # not daemonized (may be foreground or background)
+            self.bus.log('SIGHUP caught but not daemonized. Exiting.')
+            self.bus.exit()
+
+
+try:
+    import pwd
+    import grp
+except ImportError:
+    pwd, grp = None, None
+
+
+class DropPrivileges(SimplePlugin):
+
+    """Drop privileges. uid/gid arguments not available on Windows.
+
+    Special thanks to Gavin Baker.
+    """
+
+    def __init__(self, bus, umask=None, uid=None, gid=None):
+        SimplePlugin.__init__(self, bus)
+        self.finalized = False
+        self.uid = uid
+        self.gid = gid
+        self.umask = umask
+
+    @property
+    def uid(self):
+        """The uid under which to run. Availability: Unix."""
+        return self._uid
+
+    @uid.setter
+    def uid(self, val):
+        if val is not None:
+            if pwd is None:
+                self.bus.log('pwd module not available; ignoring uid.',
+                             level=30)
+                val = None
+            elif isinstance(val, text_or_bytes):
+                val = pwd.getpwnam(val)[2]
+        self._uid = val
+
+    @property
+    def gid(self):
+        """The gid under which to run. Availability: Unix."""
+        return self._gid
+
+    @gid.setter
+    def gid(self, val):
+        if val is not None:
+            if grp is None:
+                self.bus.log('grp module not available; ignoring gid.',
+                             level=30)
+                val = None
+            elif isinstance(val, text_or_bytes):
+                val = grp.getgrnam(val)[2]
+        self._gid = val
+
+    @property
+    def umask(self):
+        """The default permission mode for newly created files and directories.
+
+        Usually expressed in octal format, for example, ``0644``.
+        Availability: Unix, Windows.
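+
+        An illustrative use, masking the group/other write bits::
+
+            DropPrivileges(cherrypy.engine, umask=0o022).subscribe()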
+        """
+        return self._umask
+
+    @umask.setter
+    def umask(self, val):
+        if val is not None:
+            try:
+                os.umask
+            except AttributeError:
+                self.bus.log('umask function not available; ignoring umask.',
+                             level=30)
+                val = None
+        self._umask = val
+
+    def start(self):
+        # uid/gid
+        def current_ids():
+            """Return the current (uid, gid) if available."""
+            name, group = None, None
+            if pwd:
+                name = pwd.getpwuid(os.getuid())[0]
+            if grp:
+                group = grp.getgrgid(os.getgid())[0]
+            return name, group
+
+        if self.finalized:
+            if not (self.uid is None and self.gid is None):
+                self.bus.log('Already running as uid: %r gid: %r' %
+                             current_ids())
+        else:
+            if self.uid is None and self.gid is None:
+                if pwd or grp:
+                    self.bus.log('uid/gid not set', level=30)
+            else:
+                self.bus.log('Started as uid: %r gid: %r' % current_ids())
+                if self.gid is not None:
+                    os.setgid(self.gid)
+                    os.setgroups([])
+                if self.uid is not None:
+                    os.setuid(self.uid)
+                self.bus.log('Running as uid: %r gid: %r' % current_ids())
+
+        # umask
+        if self.finalized:
+            if self.umask is not None:
+                self.bus.log('umask already set to: %03o' % self.umask)
+        else:
+            if self.umask is None:
+                self.bus.log('umask not set', level=30)
+            else:
+                old_umask = os.umask(self.umask)
+                self.bus.log('umask old: %03o, new: %03o' %
+                             (old_umask, self.umask))
+
+        self.finalized = True
+    # This is slightly higher than the priority for server.start
+    # in order to facilitate the most common use: starting on a low
+    # port (which requires root) and then dropping to another user.
+    start.priority = 77
+
+
+class Daemonizer(SimplePlugin):
+
+    """Daemonize the running script.
+
+    Use this with a Web Site Process Bus via::
+
+        Daemonizer(bus).subscribe()
+
+    When this component finishes, the process is completely decoupled from
+    the parent environment. Please note that when this component is used,
+    the return code from the parent process will still be 0 if a startup
+    error occurs in the forked children. Errors in the initial daemonizing
+    process still return proper exit codes. Therefore, if you use this
+    plugin to daemonize, don't use the return code as an accurate indicator
+    of whether the process fully started. In fact, that return code only
+    indicates if the process successfully finished the first fork.
+    """
+
+    def __init__(self, bus, stdin='/dev/null', stdout='/dev/null',
+                 stderr='/dev/null'):
+        SimplePlugin.__init__(self, bus)
+        self.stdin = stdin
+        self.stdout = stdout
+        self.stderr = stderr
+        self.finalized = False
+
+    def start(self):
+        if self.finalized:
+            self.bus.log('Already daemonized.')
+
+        # forking has issues with threads:
+        # http://www.opengroup.org/onlinepubs/000095399/functions/fork.html
+        # "The general problem with making fork() work in a multi-threaded
+        # world is what to do with all of the threads..."
+        # So we check for active threads:
+        if threading.activeCount() != 1:
+            self.bus.log('There are %r active threads. '
+                         'Daemonizing now may cause strange failures.' %
+                         threading.enumerate(), level=30)
+
+        self.daemonize(self.stdin, self.stdout, self.stderr, self.bus.log)
+
+        self.finalized = True
+    start.priority = 65
+
+    @staticmethod
+    def daemonize(
+            stdin='/dev/null', stdout='/dev/null', stderr='/dev/null',
+            logger=lambda msg: None):
+        # See http://www.erlenstar.demon.co.uk/unix/faq_2.html#SEC16
+        # (or http://www.faqs.org/faqs/unix-faq/programmer/faq/ section 1.7)
+        # and http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66012
+
+        # Finish up with the current stdout/stderr
+        sys.stdout.flush()
+        sys.stderr.flush()
+
+        error_tmpl = (
+            '{sys.argv[0]}: fork #{n} failed: ({exc.errno}) {exc.strerror}\n'
+        )
+
+        for fork in range(2):
+            msg = ['Forking once.', 'Forking twice.'][fork]
+            try:
+                pid = os.fork()
+                if pid > 0:
+                    # This is the parent; exit.
+                    logger(msg)
+                    os._exit(0)
+            except OSError as exc:
+                # Python raises OSError rather than returning negative numbers.
+                sys.exit(error_tmpl.format(sys=sys, exc=exc, n=fork + 1))
+            if fork == 0:
+                os.setsid()
+
+        os.umask(0)
+
+        si = open(stdin, 'r')
+        so = open(stdout, 'a+')
+        se = open(stderr, 'a+')
+
+        # os.dup2(fd, fd2) will close fd2 if necessary,
+        # so we don't explicitly close stdin/out/err.
+        # See http://docs.python.org/lib/os-fd-ops.html
+        os.dup2(si.fileno(), sys.stdin.fileno())
+        os.dup2(so.fileno(), sys.stdout.fileno())
+        os.dup2(se.fileno(), sys.stderr.fileno())
+
+        logger('Daemonized to PID: %s' % os.getpid())
+
+
+class PIDFile(SimplePlugin):
+
+    """Maintain a PID file via a WSPBus."""
+
+    def __init__(self, bus, pidfile):
+        SimplePlugin.__init__(self, bus)
+        self.pidfile = pidfile
+        self.finalized = False
+
+    def start(self):
+        pid = os.getpid()
+        if self.finalized:
+            self.bus.log('PID %r already written to %r.' % (pid, self.pidfile))
+        else:
+            open(self.pidfile, 'wb').write(ntob('%s\n' % pid, 'utf8'))
+            self.bus.log('PID %r written to %r.' % (pid, self.pidfile))
+            self.finalized = True
+    start.priority = 70
+
+    def exit(self):
+        try:
+            os.remove(self.pidfile)
+            self.bus.log('PID file removed: %r.' % self.pidfile)
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except Exception:
+            pass
+
+
+class PerpetualTimer(threading.Timer):
+
+    """A responsive subclass of threading.Timer whose run() method repeats.
+
+    Use this timer only when you really need a very interruptible timer;
+    this checks its 'finished' condition up to 20 times a second, which can
+    result in pretty high CPU usage.
+    """
+
+    def __init__(self, *args, **kwargs):
+        "Override parent constructor to allow 'bus' to be provided."
+        self.bus = kwargs.pop('bus', None)
+        super(PerpetualTimer, self).__init__(*args, **kwargs)
+
+    def run(self):
+        while True:
+            self.finished.wait(self.interval)
+            if self.finished.isSet():
+                return
+            try:
+                self.function(*self.args, **self.kwargs)
+            except Exception:
+                if self.bus:
+                    self.bus.log(
+                        'Error in perpetual timer thread function %r.' %
+                        self.function, level=40, traceback=True)
+                # Quit on first error to avoid massive logs.
+                raise
+
+
+class BackgroundTask(threading.Thread):
+
+    """A subclass of threading.Thread whose run() method repeats.
+
+    Use this class for most repeating tasks. It uses time.sleep() to wait
+    for each interval, which isn't very responsive; that is, even if you call
+    self.cancel(), you'll have to wait until the sleep() call finishes before
+    the thread stops. To compensate, it defaults to being daemonic, which means
+    it won't delay stopping the whole process.
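+
+    A minimal illustrative use (``poll`` is a hypothetical callable)::
+
+        task = BackgroundTask(60, poll, bus=cherrypy.engine)
+        task.start()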
+    """
+
+    def __init__(self, interval, function, args=[], kwargs={}, bus=None):
+        super(BackgroundTask, self).__init__()
+        self.interval = interval
+        self.function = function
+        self.args = args
+        self.kwargs = kwargs
+        self.running = False
+        self.bus = bus
+
+        # default to daemonic
+        self.daemon = True
+
+    def cancel(self):
+        self.running = False
+
+    def run(self):
+        self.running = True
+        while self.running:
+            time.sleep(self.interval)
+            if not self.running:
+                return
+            try:
+                self.function(*self.args, **self.kwargs)
+            except Exception:
+                if self.bus:
+                    self.bus.log('Error in background task thread function %r.'
+                                 % self.function, level=40, traceback=True)
+                # Quit on first error to avoid massive logs.
+                raise
+
+
+class Monitor(SimplePlugin):
+
+    """WSPBus listener to periodically run a callback in its own thread."""
+
+    callback = None
+    """The function to call at intervals."""
+
+    frequency = 60
+    """The time in seconds between callback runs."""
+
+    thread = None
+    """A :class:`BackgroundTask`
+    thread.
+    """
+
+    def __init__(self, bus, callback, frequency=60, name=None):
+        SimplePlugin.__init__(self, bus)
+        self.callback = callback
+        self.frequency = frequency
+        self.thread = None
+        self.name = name
+
+    def start(self):
+        """Start our callback in its own background thread."""
+        if self.frequency > 0:
+            threadname = self.name or self.__class__.__name__
+            if self.thread is None:
+                self.thread = BackgroundTask(self.frequency, self.callback,
+                                             bus=self.bus)
+                self.thread.setName(threadname)
+                self.thread.start()
+                self.bus.log('Started monitor thread %r.' % threadname)
+            else:
+                self.bus.log('Monitor thread %r already started.' % threadname)
+    start.priority = 70
+
+    def stop(self):
+        """Stop our callback's background task thread."""
+        if self.thread is None:
+            self.bus.log('No thread running for %s.' %
+                         (self.name or self.__class__.__name__))
+        else:
+            if self.thread is not threading.currentThread():
+                name = self.thread.getName()
+                self.thread.cancel()
+                if not self.thread.daemon:
+                    self.bus.log('Joining %r' % name)
+                    self.thread.join()
+                self.bus.log('Stopped thread %r.' % name)
+            self.thread = None
+
+    def graceful(self):
+        """Stop the callback's background task thread and restart it."""
+        self.stop()
+        self.start()
+
+
+class Autoreloader(Monitor):
+
+    """Monitor which re-executes the process when files change.
+
+    This :ref:`plugin` restarts the process (via :func:`os.execv`)
+    if any of the files it monitors changes (or is deleted). By default, the
+    autoreloader monitors all imported modules; you can add to the
+    set by adding to ``autoreload.files``::
+
+        cherrypy.engine.autoreload.files.add(myFile)
+
+    If there are imported files you do *not* wish to monitor, you can
+    adjust the ``match`` attribute, a regular expression. For example,
+    to stop monitoring cherrypy itself::
+
+        cherrypy.engine.autoreload.match = r'^(?!cherrypy).+'
+
+    Like all :class:`Monitor` plugins,
+    the autoreload plugin takes a ``frequency`` argument. The default is
+    1 second; that is, the autoreloader will examine files once each second.
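+
+    To poll every five seconds instead (an illustrative tweak)::
+
+        cherrypy.engine.autoreload.frequency = 5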
+ """ + + files = None + """The set of files to poll for modifications.""" + + frequency = 1 + """The interval in seconds at which to poll for modified files.""" + + match = '.*' + """A regular expression by which to match filenames.""" + + def __init__(self, bus, frequency=1, match='.*'): + self.mtimes = {} + self.files = set() + self.match = match + Monitor.__init__(self, bus, self.run, frequency) + + def start(self): + """Start our own background task thread for self.run.""" + if self.thread is None: + self.mtimes = {} + Monitor.start(self) + start.priority = 70 + + def sysfiles(self): + """Return a Set of sys.modules filenames to monitor.""" + search_mod_names = filter( + re.compile(self.match).match, + list(sys.modules.keys()), + ) + mods = map(sys.modules.get, search_mod_names) + return set(filter(None, map(self._file_for_module, mods))) + + @classmethod + def _file_for_module(cls, module): + """Return the relevant file for the module.""" + return ( + cls._archive_for_zip_module(module) + or cls._file_for_file_module(module) + ) + + @staticmethod + def _archive_for_zip_module(module): + """Return the archive filename for the module if relevant.""" + try: + return module.__loader__.archive + except AttributeError: + pass + + @classmethod + def _file_for_file_module(cls, module): + """Return the file for the module.""" + try: + return module.__file__ and cls._make_absolute(module.__file__) + except AttributeError: + pass + + @staticmethod + def _make_absolute(filename): + """Ensure filename is absolute to avoid effect of os.chdir.""" + return filename if os.path.isabs(filename) else ( + os.path.normpath(os.path.join(_module__file__base, filename)) + ) + + def run(self): + """Reload the process if registered files have been modified.""" + for filename in self.sysfiles() | self.files: + if filename: + if filename.endswith('.pyc'): + filename = filename[:-1] + + oldtime = self.mtimes.get(filename, 0) + if oldtime is None: + # Module with no .py file. Skip it. + continue + + try: + mtime = os.stat(filename).st_mtime + except OSError: + # Either a module with no .py file, or it's been deleted. + mtime = None + + if filename not in self.mtimes: + # If a module has no .py file, this will be None. + self.mtimes[filename] = mtime + else: + if mtime is None or mtime > oldtime: + # The file has been deleted or modified. + self.bus.log('Restarting because %s changed.' % + filename) + self.thread.cancel() + self.bus.log('Stopped thread %r.' % + self.thread.getName()) + self.bus.restart() + return + + +class ThreadManager(SimplePlugin): + + """Manager for HTTP request threads. + + If you have control over thread creation and destruction, publish to + the 'acquire_thread' and 'release_thread' channels (for each thread). + This will register/unregister the current thread and publish to + 'start_thread' and 'stop_thread' listeners in the bus as needed. + + If threads are created and destroyed by code you do not control + (e.g., Apache), then, at the beginning of every HTTP request, + publish to 'acquire_thread' only. You should not publish to + 'release_thread' in this case, since you do not know whether + the thread will be re-used or not. The bus will call + 'stop_thread' listeners for you when it stops. 
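+
+    A sketch of the first pattern (editor's illustration; ``handle_request``
+    and ``dispatch`` are hypothetical callables, not part of this module)::
+
+        def handle_request(req):
+            cherrypy.engine.publish('acquire_thread')
+            try:
+                return dispatch(req)
+            finally:
+                cherrypy.engine.publish('release_thread')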
+ """ + + threads = None + """A map of {thread ident: index number} pairs.""" + + def __init__(self, bus): + self.threads = {} + SimplePlugin.__init__(self, bus) + self.bus.listeners.setdefault('acquire_thread', set()) + self.bus.listeners.setdefault('start_thread', set()) + self.bus.listeners.setdefault('release_thread', set()) + self.bus.listeners.setdefault('stop_thread', set()) + + def acquire_thread(self): + """Run 'start_thread' listeners for the current thread. + + If the current thread has already been seen, any 'start_thread' + listeners will not be run again. + """ + thread_ident = _thread.get_ident() + if thread_ident not in self.threads: + # We can't just use get_ident as the thread ID + # because some platforms reuse thread ID's. + i = len(self.threads) + 1 + self.threads[thread_ident] = i + self.bus.publish('start_thread', i) + + def release_thread(self): + """Release the current thread and run 'stop_thread' listeners.""" + thread_ident = _thread.get_ident() + i = self.threads.pop(thread_ident, None) + if i is not None: + self.bus.publish('stop_thread', i) + + def stop(self): + """Release all threads and run all 'stop_thread' listeners.""" + for thread_ident, i in self.threads.items(): + self.bus.publish('stop_thread', i) + self.threads.clear() + graceful = stop diff --git a/resources/lib/cherrypy/process/servers.py b/resources/lib/cherrypy/process/servers.py new file mode 100644 index 0000000..dcb34de --- /dev/null +++ b/resources/lib/cherrypy/process/servers.py @@ -0,0 +1,416 @@ +r""" +Starting in CherryPy 3.1, cherrypy.server is implemented as an +:ref:`Engine Plugin`. It's an instance of +:class:`cherrypy._cpserver.Server`, which is a subclass of +:class:`cherrypy.process.servers.ServerAdapter`. The ``ServerAdapter`` class +is designed to control other servers, as well. + +Multiple servers/ports +====================== + +If you need to start more than one HTTP server (to serve on multiple ports, or +protocols, etc.), you can manually register each one and then start them all +with engine.start:: + + s1 = ServerAdapter( + cherrypy.engine, + MyWSGIServer(host='0.0.0.0', port=80) + ) + s2 = ServerAdapter( + cherrypy.engine, + another.HTTPServer(host='127.0.0.1', SSL=True) + ) + s1.subscribe() + s2.subscribe() + cherrypy.engine.start() + +.. index:: SCGI + +FastCGI/SCGI +============ + +There are also Flup\ **F**\ CGIServer and Flup\ **S**\ CGIServer classes in +:mod:`cherrypy.process.servers`. To start an fcgi server, for example, +wrap an instance of it in a ServerAdapter:: + + addr = ('0.0.0.0', 4000) + f = servers.FlupFCGIServer(application=cherrypy.tree, bindAddress=addr) + s = servers.ServerAdapter(cherrypy.engine, httpserver=f, bind_addr=addr) + s.subscribe() + +The :doc:`cherryd` startup script will do the above for +you via its `-f` flag. +Note that you need to download and install `flup `_ +yourself, whether you use ``cherryd`` or not. + +.. _fastcgi: +.. index:: FastCGI + +FastCGI +------- + +A very simple setup lets your cherry run with FastCGI. +You just need the flup library, +plus a running Apache server (with ``mod_fastcgi``) or lighttpd server. + +CherryPy code +^^^^^^^^^^^^^ + +hello.py:: + + #!/usr/bin/python + import cherrypy + + class HelloWorld: + '''Sample request handler class.''' + @cherrypy.expose + def index(self): + return "Hello world!" 
+ + cherrypy.tree.mount(HelloWorld()) + # CherryPy autoreload must be disabled for the flup server to work + cherrypy.config.update({'engine.autoreload.on':False}) + +Then run :doc:`/deployguide/cherryd` with the '-f' arg:: + + cherryd -c -d -f -i hello.py + +Apache +^^^^^^ + +At the top level in httpd.conf:: + + FastCgiIpcDir /tmp + FastCgiServer /path/to/cherry.fcgi -idle-timeout 120 -processes 4 + +And inside the relevant VirtualHost section:: + + # FastCGI config + AddHandler fastcgi-script .fcgi + ScriptAliasMatch (.*$) /path/to/cherry.fcgi$1 + +Lighttpd +^^^^^^^^ + +For `Lighttpd `_ you can follow these +instructions. Within ``lighttpd.conf`` make sure ``mod_fastcgi`` is +active within ``server.modules``. Then, within your ``$HTTP["host"]`` +directive, configure your fastcgi script like the following:: + + $HTTP["url"] =~ "" { + fastcgi.server = ( + "/" => ( + "script.fcgi" => ( + "bin-path" => "/path/to/your/script.fcgi", + "socket" => "/tmp/script.sock", + "check-local" => "disable", + "disable-time" => 1, + "min-procs" => 1, + "max-procs" => 1, # adjust as needed + ), + ), + ) + } # end of $HTTP["url"] =~ "^/" + +Please see `Lighttpd FastCGI Docs +`_ for +an explanation of the possible configuration options. +""" + +import os +import sys +import time +import warnings +import contextlib + +import portend + + +class Timeouts: + occupied = 5 + free = 1 + + +class ServerAdapter(object): + + """Adapter for an HTTP server. + + If you need to start more than one HTTP server (to serve on multiple + ports, or protocols, etc.), you can manually register each one and then + start them all with bus.start:: + + s1 = ServerAdapter(bus, MyWSGIServer(host='0.0.0.0', port=80)) + s2 = ServerAdapter(bus, another.HTTPServer(host='127.0.0.1', SSL=True)) + s1.subscribe() + s2.subscribe() + bus.start() + """ + + def __init__(self, bus, httpserver=None, bind_addr=None): + self.bus = bus + self.httpserver = httpserver + self.bind_addr = bind_addr + self.interrupt = None + self.running = False + + def subscribe(self): + self.bus.subscribe('start', self.start) + self.bus.subscribe('stop', self.stop) + + def unsubscribe(self): + self.bus.unsubscribe('start', self.start) + self.bus.unsubscribe('stop', self.stop) + + def start(self): + """Start the HTTP server.""" + if self.running: + self.bus.log('Already serving on %s' % self.description) + return + + self.interrupt = None + if not self.httpserver: + raise ValueError('No HTTP server has been created.') + + if not os.environ.get('LISTEN_PID', None): + # Start the httpserver in a new thread. + if isinstance(self.bind_addr, tuple): + portend.free(*self.bind_addr, timeout=Timeouts.free) + + import threading + t = threading.Thread(target=self._start_http_thread) + t.setName('HTTPServer ' + t.getName()) + t.start() + + self.wait() + self.running = True + self.bus.log('Serving on %s' % self.description) + start.priority = 75 + + @property + def description(self): + """ + A description about where this server is bound. 
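+
+        For example (editor's sketch; ``adapter`` is a hypothetical
+        instance already bound to a TCP address)::
+
+            >>> adapter.description
+            'http://127.0.0.1:8080'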
+        """
+        if self.bind_addr is None:
+            on_what = 'unknown interface (dynamic?)'
+        elif isinstance(self.bind_addr, tuple):
+            on_what = self._get_base()
+        else:
+            on_what = 'socket file: %s' % self.bind_addr
+        return on_what
+
+    def _get_base(self):
+        if not self.httpserver:
+            return ''
+        host, port = self.bound_addr
+        if getattr(self.httpserver, 'ssl_adapter', None):
+            scheme = 'https'
+            if port != 443:
+                host += ':%s' % port
+        else:
+            scheme = 'http'
+            if port != 80:
+                host += ':%s' % port
+
+        return '%s://%s' % (scheme, host)
+
+    def _start_http_thread(self):
+        """HTTP servers MUST be running in new threads, so that the
+        main thread persists to receive KeyboardInterrupt's. If an
+        exception is raised in the httpserver's thread then it's
+        trapped here, and the bus (and therefore our httpserver)
+        are shut down.
+        """
+        try:
+            self.httpserver.start()
+        except KeyboardInterrupt:
+            self.bus.log('<Ctrl-C> hit: shutting down HTTP server')
+            self.interrupt = sys.exc_info()[1]
+            self.bus.exit()
+        except SystemExit:
+            self.bus.log('SystemExit raised: shutting down HTTP server')
+            self.interrupt = sys.exc_info()[1]
+            self.bus.exit()
+            raise
+        except Exception:
+            self.interrupt = sys.exc_info()[1]
+            self.bus.log('Error in HTTP server: shutting down',
+                         traceback=True, level=40)
+            self.bus.exit()
+            raise
+
+    def wait(self):
+        """Wait until the HTTP server is ready to receive requests."""
+        while not getattr(self.httpserver, 'ready', False):
+            if self.interrupt:
+                raise self.interrupt
+            time.sleep(.1)
+
+        # bypass check when LISTEN_PID is set
+        if os.environ.get('LISTEN_PID', None):
+            return
+
+        # bypass check when running via socket-activation
+        # (for socket-activation the port will be managed by systemd)
+        if not isinstance(self.bind_addr, tuple):
+            return
+
+        # wait for port to be occupied
+        with _safe_wait(*self.bound_addr):
+            portend.occupied(*self.bound_addr, timeout=Timeouts.occupied)
+
+    @property
+    def bound_addr(self):
+        """
+        The bind address, or if it's an ephemeral port and the
+        socket has been bound, return the actual port bound.
+        """
+        host, port = self.bind_addr
+        if port == 0 and self.httpserver.socket:
+            # Bound to ephemeral port. Get the actual port allocated.
+            port = self.httpserver.socket.getsockname()[1]
+        return host, port
+
+    def stop(self):
+        """Stop the HTTP server."""
+        if self.running:
+            # stop() MUST block until the server is *truly* stopped.
+            self.httpserver.stop()
+            # Wait for the socket to be truly freed.
+            if isinstance(self.bind_addr, tuple):
+                portend.free(*self.bound_addr, timeout=Timeouts.free)
+            self.running = False
+            self.bus.log('HTTP Server %s shut down' % self.httpserver)
+        else:
+            self.bus.log('HTTP Server %s already shut down' % self.httpserver)
+    stop.priority = 25
+
+    def restart(self):
+        """Restart the HTTP server."""
+        self.stop()
+        self.start()
+
+
+class FlupCGIServer(object):
+
+    """Adapter for a flup.server.cgi.WSGIServer."""
+
+    def __init__(self, *args, **kwargs):
+        self.args = args
+        self.kwargs = kwargs
+        self.ready = False
+
+    def start(self):
+        """Start the CGI server."""
+        # We have to instantiate the server class here because its __init__
+        # starts a threadpool. If we do it too early, daemonize won't work.
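+        # (Editor's note, illustrative: the same deferred-import pattern is
+        # used by the FCGI and SCGI adapters below; no threads may exist yet
+        # when Daemonizer forks, or they would be lost in the child process.)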
+ from flup.server.cgi import WSGIServer + + self.cgiserver = WSGIServer(*self.args, **self.kwargs) + self.ready = True + self.cgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + self.ready = False + + +class FlupFCGIServer(object): + + """Adapter for a flup.server.fcgi.WSGIServer.""" + + def __init__(self, *args, **kwargs): + if kwargs.get('bindAddress', None) is None: + import socket + if not hasattr(socket, 'fromfd'): + raise ValueError( + 'Dynamic FCGI server not available on this platform. ' + 'You must use a static or external one by providing a ' + 'legal bindAddress.') + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the FCGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. + from flup.server.fcgi import WSGIServer + self.fcgiserver = WSGIServer(*self.args, **self.kwargs) + # TODO: report this bug upstream to flup. + # If we don't set _oldSIGs on Windows, we get: + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 108, in run + # self._restoreSignalHandlers() + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 156, in _restoreSignalHandlers + # for signum,handler in self._oldSIGs: + # AttributeError: 'WSGIServer' object has no attribute '_oldSIGs' + self.fcgiserver._installSignalHandlers = lambda: None + self.fcgiserver._oldSIGs = [] + self.ready = True + self.fcgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + # Forcibly stop the fcgi server main event loop. + self.fcgiserver._keepGoing = False + # Force all worker threads to die off. + self.fcgiserver._threadPool.maxSpare = ( + self.fcgiserver._threadPool._idleCount) + self.ready = False + + +class FlupSCGIServer(object): + + """Adapter for a flup.server.scgi.WSGIServer.""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the SCGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. + from flup.server.scgi import WSGIServer + self.scgiserver = WSGIServer(*self.args, **self.kwargs) + # TODO: report this bug upstream to flup. + # If we don't set _oldSIGs on Windows, we get: + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 108, in run + # self._restoreSignalHandlers() + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 156, in _restoreSignalHandlers + # for signum,handler in self._oldSIGs: + # AttributeError: 'WSGIServer' object has no attribute '_oldSIGs' + self.scgiserver._installSignalHandlers = lambda: None + self.scgiserver._oldSIGs = [] + self.ready = True + self.scgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + self.ready = False + # Forcibly stop the scgi server main event loop. + self.scgiserver._keepGoing = False + # Force all worker threads to die off. + self.scgiserver._threadPool.maxSpare = 0 + + +@contextlib.contextmanager +def _safe_wait(host, port): + """ + On systems where a loopback interface is not available and the + server is bound to all interfaces, it's difficult to determine + whether the server is in fact occupying the port. In this case, + just issue a warning and move on. See issue #1100. 
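+
+    A usage sketch (editor's illustration; it mirrors the call in
+    ``ServerAdapter.wait`` above)::
+
+        with _safe_wait('0.0.0.0', 8080):
+            portend.occupied('0.0.0.0', 8080, timeout=Timeouts.occupied)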
+ """ + try: + yield + except portend.Timeout: + if host == portend.client_host(host): + raise + msg = 'Unable to verify that the server is bound on %r' % port + warnings.warn(msg) diff --git a/resources/lib/cherrypy/process/win32.py b/resources/lib/cherrypy/process/win32.py new file mode 100644 index 0000000..b7a79b1 --- /dev/null +++ b/resources/lib/cherrypy/process/win32.py @@ -0,0 +1,183 @@ +"""Windows service. Requires pywin32.""" + +import os +import win32api +import win32con +import win32event +import win32service +import win32serviceutil + +from cherrypy.process import wspbus, plugins + + +class ConsoleCtrlHandler(plugins.SimplePlugin): + + """A WSPBus plugin for handling Win32 console events (like Ctrl-C).""" + + def __init__(self, bus): + self.is_set = False + plugins.SimplePlugin.__init__(self, bus) + + def start(self): + if self.is_set: + self.bus.log('Handler for console events already set.', level=20) + return + + result = win32api.SetConsoleCtrlHandler(self.handle, 1) + if result == 0: + self.bus.log('Could not SetConsoleCtrlHandler (error %r)' % + win32api.GetLastError(), level=40) + else: + self.bus.log('Set handler for console events.', level=20) + self.is_set = True + + def stop(self): + if not self.is_set: + self.bus.log('Handler for console events already off.', level=20) + return + + try: + result = win32api.SetConsoleCtrlHandler(self.handle, 0) + except ValueError: + # "ValueError: The object has not been registered" + result = 1 + + if result == 0: + self.bus.log('Could not remove SetConsoleCtrlHandler (error %r)' % + win32api.GetLastError(), level=40) + else: + self.bus.log('Removed handler for console events.', level=20) + self.is_set = False + + def handle(self, event): + """Handle console control events (like Ctrl-C).""" + if event in (win32con.CTRL_C_EVENT, win32con.CTRL_LOGOFF_EVENT, + win32con.CTRL_BREAK_EVENT, win32con.CTRL_SHUTDOWN_EVENT, + win32con.CTRL_CLOSE_EVENT): + self.bus.log('Console event %s: shutting down bus' % event) + + # Remove self immediately so repeated Ctrl-C doesn't re-call it. + try: + self.stop() + except ValueError: + pass + + self.bus.exit() + # 'First to return True stops the calls' + return 1 + return 0 + + +class Win32Bus(wspbus.Bus): + + """A Web Site Process Bus implementation for Win32. + + Instead of time.sleep, this bus blocks using native win32event objects. + """ + + def __init__(self): + self.events = {} + wspbus.Bus.__init__(self) + + def _get_state_event(self, state): + """Return a win32event for the given state (creating it if needed).""" + try: + return self.events[state] + except KeyError: + event = win32event.CreateEvent(None, 0, 0, + 'WSPBus %s Event (pid=%r)' % + (state.name, os.getpid())) + self.events[state] = event + return event + + @property + def state(self): + return self._state + + @state.setter + def state(self, value): + self._state = value + event = self._get_state_event(value) + win32event.PulseEvent(event) + + def wait(self, state, interval=0.1, channel=None): + """Wait for the given state(s), KeyboardInterrupt or SystemExit. + + Since this class uses native win32event objects, the interval + argument is ignored. 
+ """ + if isinstance(state, (tuple, list)): + # Don't wait for an event that beat us to the punch ;) + if self.state not in state: + events = tuple([self._get_state_event(s) for s in state]) + win32event.WaitForMultipleObjects( + events, 0, win32event.INFINITE) + else: + # Don't wait for an event that beat us to the punch ;) + if self.state != state: + event = self._get_state_event(state) + win32event.WaitForSingleObject(event, win32event.INFINITE) + + +class _ControlCodes(dict): + + """Control codes used to "signal" a service via ControlService. + + User-defined control codes are in the range 128-255. We generally use + the standard Python value for the Linux signal and add 128. Example: + + >>> signal.SIGUSR1 + 10 + control_codes['graceful'] = 128 + 10 + """ + + def key_for(self, obj): + """For the given value, return its corresponding key.""" + for key, val in self.items(): + if val is obj: + return key + raise ValueError('The given object could not be found: %r' % obj) + + +control_codes = _ControlCodes({'graceful': 138}) + + +def signal_child(service, command): + if command == 'stop': + win32serviceutil.StopService(service) + elif command == 'restart': + win32serviceutil.RestartService(service) + else: + win32serviceutil.ControlService(service, control_codes[command]) + + +class PyWebService(win32serviceutil.ServiceFramework): + + """Python Web Service.""" + + _svc_name_ = 'Python Web Service' + _svc_display_name_ = 'Python Web Service' + _svc_deps_ = None # sequence of service names on which this depends + _exe_name_ = 'pywebsvc' + _exe_args_ = None # Default to no arguments + + # Only exists on Windows 2000 or later, ignored on windows NT + _svc_description_ = 'Python Web Service' + + def SvcDoRun(self): + from cherrypy import process + process.bus.start() + process.bus.block() + + def SvcStop(self): + from cherrypy import process + self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING) + process.bus.exit() + + def SvcOther(self, control): + from cherrypy import process + process.bus.publish(control_codes.key_for(control)) + + +if __name__ == '__main__': + win32serviceutil.HandleCommandLine(PyWebService) diff --git a/resources/lib/cherrypy/process/wspbus.py b/resources/lib/cherrypy/process/wspbus.py new file mode 100644 index 0000000..ead90a4 --- /dev/null +++ b/resources/lib/cherrypy/process/wspbus.py @@ -0,0 +1,587 @@ +r"""An implementation of the Web Site Process Bus. + +This module is completely standalone, depending only on the stdlib. + +Web Site Process Bus +-------------------- + +A Bus object is used to contain and manage site-wide behavior: +daemonization, HTTP server start/stop, process reload, signal handling, +drop privileges, PID file management, logging for all of these, +and many more. + +In addition, a Bus object provides a place for each web framework +to register code that runs in response to site-wide events (like +process start and stop), or which controls or otherwise interacts with +the site-wide components mentioned above. For example, a framework which +uses file-based templates would add known template filenames to an +autoreload component. + +Ideally, a Bus object will be flexible enough to be useful in a variety +of invocation scenarios: + + 1. The deployer starts a site from the command line via a + framework-neutral deployment script; applications from multiple frameworks + are mixed in a single site. 
Command-line arguments and configuration + files are used to define site-wide components such as the HTTP server, + WSGI component graph, autoreload behavior, signal handling, etc. + 2. The deployer starts a site via some other process, such as Apache; + applications from multiple frameworks are mixed in a single site. + Autoreload and signal handling (from Python at least) are disabled. + 3. The deployer starts a site via a framework-specific mechanism; + for example, when running tests, exploring tutorials, or deploying + single applications from a single framework. The framework controls + which site-wide components are enabled as it sees fit. + +The Bus object in this package uses topic-based publish-subscribe +messaging to accomplish all this. A few topic channels are built in +('start', 'stop', 'exit', 'graceful', 'log', and 'main'). Frameworks and +site containers are free to define their own. If a message is sent to a +channel that has not been defined or has no listeners, there is no effect. + +In general, there should only ever be a single Bus object per process. +Frameworks and site containers share a single Bus object by publishing +messages and subscribing listeners. + +The Bus object works as a finite state machine which models the current +state of the process. Bus methods move it from one state to another; +those methods then publish to subscribed listeners on the channel for +the new state.:: + + O + | + V + STOPPING --> STOPPED --> EXITING -> X + A A | + | \___ | + | \ | + | V V + STARTED <-- STARTING + +""" + +import atexit + +try: + import ctypes +except ImportError: + """Google AppEngine is shipped without ctypes + + :seealso: http://stackoverflow.com/a/6523777/70170 + """ + ctypes = None + +import operator +import os +import sys +import threading +import time +import traceback as _traceback +import warnings +import subprocess +import functools + +from more_itertools import always_iterable + + +# Here I save the value of os.getcwd(), which, if I am imported early enough, +# will be the directory from which the startup script was run. This is needed +# by _do_execv(), to change back to the original directory before execv()ing a +# new process. This is a defense against the application having changed the +# current working directory (which could make sys.executable "not found" if +# sys.executable is a relative-path, and/or cause other problems). +_startup_cwd = os.getcwd() + + +class ChannelFailures(Exception): + """Exception raised during errors on Bus.publish().""" + + delimiter = '\n' + + def __init__(self, *args, **kwargs): + """Initialize ChannelFailures errors wrapper.""" + super(ChannelFailures, self).__init__(*args, **kwargs) + self._exceptions = list() + + def handle_exception(self): + """Append the current exception to self.""" + self._exceptions.append(sys.exc_info()[1]) + + def get_instances(self): + """Return a list of seen exception instances.""" + return self._exceptions[:] + + def __str__(self): + """Render the list of errors, which happened in channel.""" + exception_strings = map(repr, self.get_instances()) + return self.delimiter.join(exception_strings) + + __repr__ = __str__ + + def __bool__(self): + """Determine whether any error happened in channel.""" + return bool(self._exceptions) + __nonzero__ = __bool__ + +# Use a flag to indicate the state of the bus. 
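+# (Editor's note, illustrative: each State instance is named when it is
+# assigned onto the _StateEnum below, so repr(states.STARTED) yields
+# 'states.STARTED', and callers compare states by identity, e.g.
+# `bus.state is states.STARTED`.)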
+ + +class _StateEnum(object): + + class State(object): + name = None + + def __repr__(self): + return 'states.%s' % self.name + + def __setattr__(self, key, value): + if isinstance(value, self.State): + value.name = key + object.__setattr__(self, key, value) + + +states = _StateEnum() +states.STOPPED = states.State() +states.STARTING = states.State() +states.STARTED = states.State() +states.STOPPING = states.State() +states.EXITING = states.State() + + +try: + import fcntl +except ImportError: + max_files = 0 +else: + try: + max_files = os.sysconf('SC_OPEN_MAX') + except AttributeError: + max_files = 1024 + + +class Bus(object): + """Process state-machine and messenger for HTTP site deployment. + + All listeners for a given channel are guaranteed to be called even + if others at the same channel fail. Each failure is logged, but + execution proceeds on to the next listener. The only way to stop all + processing from inside a listener is to raise SystemExit and stop the + whole server. + """ + + states = states + state = states.STOPPED + execv = False + max_cloexec_files = max_files + + def __init__(self): + """Initialize pub/sub bus.""" + self.execv = False + self.state = states.STOPPED + channels = 'start', 'stop', 'exit', 'graceful', 'log', 'main' + self.listeners = dict( + (channel, set()) + for channel in channels + ) + self._priorities = {} + + def subscribe(self, channel, callback=None, priority=None): + """Add the given callback at the given channel (if not present). + + If callback is None, return a partial suitable for decorating + the callback. + """ + if callback is None: + return functools.partial( + self.subscribe, + channel, + priority=priority, + ) + + ch_listeners = self.listeners.setdefault(channel, set()) + ch_listeners.add(callback) + + if priority is None: + priority = getattr(callback, 'priority', 50) + self._priorities[(channel, callback)] = priority + + def unsubscribe(self, channel, callback): + """Discard the given callback (if present).""" + listeners = self.listeners.get(channel) + if listeners and callback in listeners: + listeners.discard(callback) + del self._priorities[(channel, callback)] + + def publish(self, channel, *args, **kwargs): + """Return output of all subscribers for the given channel.""" + if channel not in self.listeners: + return [] + + exc = ChannelFailures() + output = [] + + raw_items = ( + (self._priorities[(channel, listener)], listener) + for listener in self.listeners[channel] + ) + items = sorted(raw_items, key=operator.itemgetter(0)) + for priority, listener in items: + try: + output.append(listener(*args, **kwargs)) + except KeyboardInterrupt: + raise + except SystemExit: + e = sys.exc_info()[1] + # If we have previous errors ensure the exit code is non-zero + if exc and e.code == 0: + e.code = 1 + raise + except Exception: + exc.handle_exception() + if channel == 'log': + # Assume any further messages to 'log' will fail. + pass + else: + self.log('Error in %r listener %r' % (channel, listener), + level=40, traceback=True) + if exc: + raise exc + return output + + def _clean_exit(self): + """Assert that the Bus is not running in atexit handler callback.""" + if self.state != states.EXITING: + warnings.warn( + 'The main thread is exiting, but the Bus is in the %r state; ' + 'shutting it down automatically now. You must either call ' + 'bus.block() after start(), or call bus.exit() before the ' + 'main thread exits.' 
% self.state, RuntimeWarning)
+            self.exit()
+
+    def start(self):
+        """Start all services."""
+        atexit.register(self._clean_exit)
+
+        self.state = states.STARTING
+        self.log('Bus STARTING')
+        try:
+            self.publish('start')
+            self.state = states.STARTED
+            self.log('Bus STARTED')
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except Exception:
+            self.log('Shutting down due to error in start listener:',
+                     level=40, traceback=True)
+            e_info = sys.exc_info()[1]
+            try:
+                self.exit()
+            except Exception:
+                # Any stop/exit errors will be logged inside publish().
+                pass
+            # Re-raise the original error
+            raise e_info
+
+    def exit(self):
+        """Stop all services and prepare to exit the process."""
+        exitstate = self.state
+        EX_SOFTWARE = 70
+        try:
+            self.stop()
+
+            self.state = states.EXITING
+            self.log('Bus EXITING')
+            self.publish('exit')
+            # This isn't strictly necessary, but it's better than seeing
+            # "Waiting for child threads to terminate..." and then nothing.
+            self.log('Bus EXITED')
+        except Exception:
+            # This method is often called asynchronously (whether thread,
+            # signal handler, console handler, or atexit handler), so we
+            # can't just let exceptions propagate out unhandled.
+            # Assume it's been logged and just die.
+            os._exit(EX_SOFTWARE)
+
+        if exitstate == states.STARTING:
+            # exit() was called before start() finished, possibly due to
+            # Ctrl-C because a start listener got stuck. In this case,
+            # we could get stuck in a loop where Ctrl-C never exits the
+            # process, so we just call os._exit here.
+            os._exit(EX_SOFTWARE)
+
+    def restart(self):
+        """Restart the process (may close connections).
+
+        This method does not restart the process from the calling thread;
+        instead, it stops the bus and asks the main thread to call execv.
+        """
+        self.execv = True
+        self.exit()
+
+    def graceful(self):
+        """Advise all services to reload."""
+        self.log('Bus graceful')
+        self.publish('graceful')
+
+    def block(self, interval=0.1):
+        """Wait for the EXITING state, KeyboardInterrupt or SystemExit.
+
+        This function is intended to be called only by the main thread.
+        After waiting for the EXITING state, it also waits for all threads
+        to terminate, and then calls os.execv if self.execv is True. This
+        design allows another thread to call bus.restart, yet have the main
+        thread perform the actual execv call (required on some platforms).
+        """
+        try:
+            self.wait(states.EXITING, interval=interval, channel='main')
+        except (KeyboardInterrupt, IOError):
+            # The time.sleep call might raise
+            # "IOError: [Errno 4] Interrupted function call" on KBInt.
+            self.log('Keyboard Interrupt: shutting down bus')
+            self.exit()
+        except SystemExit:
+            self.log('SystemExit raised: shutting down bus')
+            self.exit()
+            raise
+
+        # Waiting for ALL child threads to finish is necessary on OS X.
+        # See https://github.com/cherrypy/cherrypy/issues/581.
+        # It's also good to let them all shut down before allowing
+        # the main thread to call atexit handlers.
+        # See https://github.com/cherrypy/cherrypy/issues/751.
+        self.log('Waiting for child threads to terminate...')
+        for t in threading.enumerate():
+            # Make sure we're not trying to join the MainThread, which
+            # would deadlock; this case exists when implemented as a
+            # Windows service, and whenever another thread executes
+            # cherrypy.engine.exit()
+            if (
+                    t != threading.currentThread() and
+                    not isinstance(t, threading._MainThread) and
+                    # Note that any dummy (external) threads are
+                    # always daemonic.
+ not t.daemon + ): + self.log('Waiting for thread %s.' % t.getName()) + t.join() + + if self.execv: + self._do_execv() + + def wait(self, state, interval=0.1, channel=None): + """Poll for the given state(s) at intervals; publish to channel.""" + states = set(always_iterable(state)) + + while self.state not in states: + time.sleep(interval) + self.publish(channel) + + def _do_execv(self): + """Re-execute the current process. + + This must be called from the main thread, because certain platforms + (OS X) don't allow execv to be called in a child thread very well. + """ + try: + args = self._get_true_argv() + except NotImplementedError: + """It's probably win32 or GAE""" + args = [sys.executable] + self._get_interpreter_argv() + sys.argv + + self.log('Re-spawning %s' % ' '.join(args)) + + self._extend_pythonpath(os.environ) + + if sys.platform[:4] == 'java': + from _systemrestart import SystemRestart + raise SystemRestart + else: + if sys.platform == 'win32': + args = ['"%s"' % arg for arg in args] + + os.chdir(_startup_cwd) + if self.max_cloexec_files: + self._set_cloexec() + os.execv(sys.executable, args) + + @staticmethod + def _get_interpreter_argv(): + """Retrieve current Python interpreter's arguments. + + Returns empty tuple in case of frozen mode, uses built-in arguments + reproduction function otherwise. + + Frozen mode is possible for the app has been packaged into a binary + executable using py2exe. In this case the interpreter's arguments are + already built-in into that executable. + + :seealso: https://github.com/cherrypy/cherrypy/issues/1526 + Ref: https://pythonhosted.org/PyInstaller/runtime-information.html + """ + return ([] + if getattr(sys, 'frozen', False) + else subprocess._args_from_interpreter_flags()) + + @staticmethod + def _get_true_argv(): + """Retrieve all real arguments of the python interpreter. + + ...even those not listed in ``sys.argv`` + + :seealso: http://stackoverflow.com/a/28338254/595220 + :seealso: http://stackoverflow.com/a/6683222/595220 + :seealso: http://stackoverflow.com/a/28414807/595220 + """ + try: + char_p = ctypes.c_wchar_p + + argv = ctypes.POINTER(char_p)() + argc = ctypes.c_int() + + ctypes.pythonapi.Py_GetArgcArgv( + ctypes.byref(argc), + ctypes.byref(argv), + ) + + _argv = argv[:argc.value] + + # The code below is trying to correctly handle special cases. + # `-c`'s argument interpreted by Python itself becomes `-c` as + # well. Same applies to `-m`. This snippet is trying to survive + # at least the case with `-m` + # Ref: https://github.com/cherrypy/cherrypy/issues/1545 + # Ref: python/cpython@418baf9 + argv_len, is_command, is_module = len(_argv), False, False + + try: + m_ind = _argv.index('-m') + if m_ind < argv_len - 1 and _argv[m_ind + 1] in ('-c', '-m'): + """ + In some older Python versions `-m`'s argument may be + substituted with `-c`, not `-m` + """ + is_module = True + except (IndexError, ValueError): + m_ind = None + + try: + c_ind = _argv.index('-c') + if c_ind < argv_len - 1 and _argv[c_ind + 1] == '-c': + is_command = True + except (IndexError, ValueError): + c_ind = None + + if is_module: + """It's containing `-m -m` sequence of arguments""" + if is_command and c_ind < m_ind: + """There's `-c -c` before `-m`""" + raise RuntimeError( + "Cannot reconstruct command from '-c'. 
Ref: " + 'https://github.com/cherrypy/cherrypy/issues/1545') + # Survive module argument here + original_module = sys.argv[0] + if not os.access(original_module, os.R_OK): + """There's no such module exist""" + raise AttributeError( + "{} doesn't seem to be a module " + 'accessible by current user'.format(original_module)) + del _argv[m_ind:m_ind + 2] # remove `-m -m` + # ... and substitute it with the original module path: + _argv.insert(m_ind, original_module) + elif is_command: + """It's containing just `-c -c` sequence of arguments""" + raise RuntimeError( + "Cannot reconstruct command from '-c'. " + 'Ref: https://github.com/cherrypy/cherrypy/issues/1545') + except AttributeError: + """It looks Py_GetArgcArgv is completely absent in some environments + + It is known, that there's no Py_GetArgcArgv in MS Windows and + ``ctypes`` module is completely absent in Google AppEngine + + :seealso: https://github.com/cherrypy/cherrypy/issues/1506 + :seealso: https://github.com/cherrypy/cherrypy/issues/1512 + :ref: http://bit.ly/2gK6bXK + """ + raise NotImplementedError + else: + return _argv + + @staticmethod + def _extend_pythonpath(env): + """Prepend current working dir to PATH environment variable if needed. + + If sys.path[0] is an empty string, the interpreter was likely + invoked with -m and the effective path is about to change on + re-exec. Add the current directory to $PYTHONPATH to ensure + that the new process sees the same path. + + This issue cannot be addressed in the general case because + Python cannot reliably reconstruct the + original command line (http://bugs.python.org/issue14208). + + (This idea filched from tornado.autoreload) + """ + path_prefix = '.' + os.pathsep + existing_path = env.get('PYTHONPATH', '') + needs_patch = ( + sys.path[0] == '' and + not existing_path.startswith(path_prefix) + ) + + if needs_patch: + env['PYTHONPATH'] = path_prefix + existing_path + + def _set_cloexec(self): + """Set the CLOEXEC flag on all open files (except stdin/out/err). + + If self.max_cloexec_files is an integer (the default), then on + platforms which support it, it represents the max open files setting + for the operating system. This function will be called just before + the process is restarted via os.execv() to prevent open files + from persisting into the new process. + + Set self.max_cloexec_files to 0 to disable this behavior. + """ + for fd in range(3, self.max_cloexec_files): # skip stdin/out/err + try: + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + except IOError: + continue + fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + + def stop(self): + """Stop all services.""" + self.state = states.STOPPING + self.log('Bus STOPPING') + self.publish('stop') + self.state = states.STOPPED + self.log('Bus STOPPED') + + def start_with_callback(self, func, args=None, kwargs=None): + """Start 'func' in a new thread T, then start self (and return T).""" + if args is None: + args = () + if kwargs is None: + kwargs = {} + args = (func,) + args + + def _callback(func, *a, **kw): + self.wait(states.STARTED) + func(*a, **kw) + t = threading.Thread(target=_callback, args=args, kwargs=kwargs) + t.setName('Bus Callback ' + t.getName()) + t.start() + + self.start() + + return t + + def log(self, msg='', level=20, traceback=False): + """Log the given message. 
Append the last traceback if requested."""
+        if traceback:
+            msg += '\n' + ''.join(_traceback.format_exception(*sys.exc_info()))
+        self.publish('log', msg, level)
+
+
+bus = Bus()
diff --git a/resources/lib/cherrypy/scaffold/__init__.py b/resources/lib/cherrypy/scaffold/__init__.py
new file mode 100644
index 0000000..bcddba2
--- /dev/null
+++ b/resources/lib/cherrypy/scaffold/__init__.py
@@ -0,0 +1,63 @@
+"""<MyProject>, a CherryPy application.
+
+Use this as a base for creating new CherryPy applications. When you want
+to make a new app, copy and paste this folder to some other location
+(maybe site-packages) and rename it to the name of your project,
+then tweak as desired.
+
+Even before any tweaking, this should serve a few demonstration pages.
+Change to this directory and run:
+
+    cherryd -c site.conf
+
+"""
+
+import cherrypy
+from cherrypy import tools, url
+
+import os
+local_dir = os.path.join(os.getcwd(), os.path.dirname(__file__))
+
+
+@cherrypy.config(**{'tools.log_tracebacks.on': True})
+class Root:
+    """Declaration of the CherryPy app URI structure."""
+
+    @cherrypy.expose
+    def index(self):
+        """Render HTML-template at the root path of the web-app."""
+        return """
+Try some <a href='%s'>other</a> path,
+or a <a href='%s'>default</a> path.<br />
+Or, just look at the pretty picture:<br />
+<img src='%s' />
+""" % (url('other'), url('else'),
+       url('files/made_with_cherrypy_small.png'))
+
+    @cherrypy.expose
+    def default(self, *args, **kwargs):
+        """Render catch-all args and kwargs."""
+        return 'args: %s kwargs: %s' % (args, kwargs)
+
+    @cherrypy.expose
+    def other(self, a=2, b='bananas', c=None):
+        """Render number of fruits based on third argument."""
+        cherrypy.response.headers['Content-Type'] = 'text/plain'
+        if c is None:
+            return 'Have %d %s.' % (int(a), b)
+        else:
+            return 'Have %d %s, %s.' % (int(a), b, c)
+
+    files = tools.staticdir.handler(
+        section='/files',
+        dir=os.path.join(local_dir, 'static'),
+        # Ignore .php files, etc.
+        match=r'\.(css|gif|html?|ico|jpe?g|js|png|swf|xml)$',
+    )
+
+
+root = Root()
+
+# Uncomment the following to use your own favicon instead of CP's default.
+# favicon_path = os.path.join(local_dir, "favicon.ico")
+# root.favicon_ico = tools.staticfile.handler(filename=favicon_path)
diff --git a/resources/lib/cherrypy/scaffold/apache-fcgi.conf b/resources/lib/cherrypy/scaffold/apache-fcgi.conf
new file mode 100644
index 0000000..6e4f144
--- /dev/null
+++ b/resources/lib/cherrypy/scaffold/apache-fcgi.conf
@@ -0,0 +1,22 @@
+# Apache2 server conf file for using CherryPy with mod_fcgid.
+
+# This doesn't have to be "C:/", but it has to be a directory somewhere, and
+# MUST match the directory used in the FastCgiExternalServer directive, below.
+DocumentRoot "C:/"
+
+ServerName 127.0.0.1
+Listen 80
+LoadModule fastcgi_module modules/mod_fastcgi.dll
+LoadModule rewrite_module modules/mod_rewrite.so
+
+<Directory "C:/">
+Options ExecCGI
+SetHandler fastcgi-script
+RewriteEngine On
+# Send requests for any URI to our fastcgi handler.
+RewriteRule ^(.*)$ /fastcgi.pyc [L]
+</Directory>
+
+# The FastCgiExternalServer directive defines filename as an external FastCGI application.
+# If filename does not begin with a slash (/) then it is assumed to be relative to the ServerRoot.
+# The filename does not have to exist in the local filesystem. URIs that Apache resolves to this
+# filename will be handled by this external FastCGI application.
+FastCgiExternalServer "C:/fastcgi.pyc" -host 127.0.0.1:8088 diff --git a/resources/lib/cherrypy/scaffold/example.conf b/resources/lib/cherrypy/scaffold/example.conf new file mode 100644 index 0000000..63250fe --- /dev/null +++ b/resources/lib/cherrypy/scaffold/example.conf @@ -0,0 +1,3 @@ +[/] +log.error_file: "error.log" +log.access_file: "access.log" diff --git a/resources/lib/cherrypy/scaffold/site.conf b/resources/lib/cherrypy/scaffold/site.conf new file mode 100644 index 0000000..6ed3898 --- /dev/null +++ b/resources/lib/cherrypy/scaffold/site.conf @@ -0,0 +1,14 @@ +[global] +# Uncomment this when you're done developing +#environment: "production" + +server.socket_host: "0.0.0.0" +server.socket_port: 8088 + +# Uncomment the following lines to run on HTTPS at the same time +#server.2.socket_host: "0.0.0.0" +#server.2.socket_port: 8433 +#server.2.ssl_certificate: '../test/test.pem' +#server.2.ssl_private_key: '../test/test.pem' + +tree.myapp: cherrypy.Application(scaffold.root, "/", "example.conf") diff --git a/resources/lib/cherrypy/scaffold/static/made_with_cherrypy_small.png b/resources/lib/cherrypy/scaffold/static/made_with_cherrypy_small.png new file mode 100644 index 0000000..724f9d7 Binary files /dev/null and b/resources/lib/cherrypy/scaffold/static/made_with_cherrypy_small.png differ diff --git a/resources/lib/connect_daemon.py b/resources/lib/connect_daemon.py index 73aa701..ac7aa71 100644 --- a/resources/lib/connect_daemon.py +++ b/resources/lib/connect_daemon.py @@ -5,7 +5,7 @@ from utils import log_msg, log_exception import xbmc import threading -import thread +import _thread import xbmcvfs @@ -39,26 +39,26 @@ def run(self): xbmc.executebuiltin("SetProperty(spotify-discovery,disabled,Home)") try: try: - log_msg("trying AP Port 443", xbmc.LOGNOTICE) + log_msg("trying AP Port 443", xbmc.LOGINFO) self.__spotty_proc = self.__spotty.run_spotty(arguments=spotty_args, disable_discovery=disable_discovery, ap_port="443") except: try: - log_msg("trying AP Port 80", xbmc.LOGNOTICE) + log_msg("trying AP Port 80", xbmc.LOGINFO) self.__spotty_proc = self.__spotty.run_spotty(arguments=spotty_args, disable_discovery=disable_discovery, ap_port="80") except: - log_msg("trying AP Port 4070", xbmc.LOGNOTICE) + log_msg("trying AP Port 4070", xbmc.LOGINFO) self.__spotty_proc = self.__spotty.run_spotty(arguments=spotty_args, disable_discovery=disable_discovery, ap_port="4070") while not self.__exit: line = self.__spotty_proc.stdout.readline() if self.__spotty_proc.returncode and self.__spotty_proc.returncode > 0 and not self.__exit: # daemon crashed ? restart ? 
- log_msg("spotty stopped!", xbmc.LOGNOTICE) + log_msg("spotty stopped!", xbmc.LOGINFO) break xbmc.sleep(100) self.daemon_active = False log_msg("Stopped Spotify Connect Daemon") except: self.daemon_active = False - log_msg("Cannot run SPOTTY, No APs available", xbmc.LOGNOTICE) + log_msg("Cannot run SPOTTY, No APs available", xbmc.LOGINFO) diff --git a/resources/lib/contextlib2.py b/resources/lib/contextlib2.py new file mode 100644 index 0000000..f08df14 --- /dev/null +++ b/resources/lib/contextlib2.py @@ -0,0 +1,436 @@ +"""contextlib2 - backports and enhancements to the contextlib module""" + +import sys +import warnings +from collections import deque +from functools import wraps + +__all__ = ["contextmanager", "closing", "ContextDecorator", "ExitStack", + "redirect_stdout", "redirect_stderr", "suppress"] + +# Backwards compatibility +__all__ += ["ContextStack"] + +class ContextDecorator(object): + "A base class or mixin that enables context managers to work as decorators." + + def refresh_cm(self): + """Returns the context manager used to actually wrap the call to the + decorated function. + + The default implementation just returns *self*. + + Overriding this method allows otherwise one-shot context managers + like _GeneratorContextManager to support use as decorators via + implicit recreation. + + DEPRECATED: refresh_cm was never added to the standard library's + ContextDecorator API + """ + warnings.warn("refresh_cm was never added to the standard library", + DeprecationWarning) + return self._recreate_cm() + + def _recreate_cm(self): + """Return a recreated instance of self. + + Allows an otherwise one-shot context manager like + _GeneratorContextManager to support use as + a decorator via implicit recreation. + + This is a private interface just for _GeneratorContextManager. + See issue #11647 for details. + """ + return self + + def __call__(self, func): + @wraps(func) + def inner(*args, **kwds): + with self._recreate_cm(): + return func(*args, **kwds) + return inner + + +class _GeneratorContextManager(ContextDecorator): + """Helper for @contextmanager decorator.""" + + def __init__(self, func, args, kwds): + self.gen = func(*args, **kwds) + self.func, self.args, self.kwds = func, args, kwds + # Issue 19330: ensure context manager instances have good docstrings + doc = getattr(func, "__doc__", None) + if doc is None: + doc = type(self).__doc__ + self.__doc__ = doc + # Unfortunately, this still doesn't provide good help output when + # inspecting the created context manager instances, since pydoc + # currently bypasses the instance docstring and shows the docstring + # for the class instead. + # See http://bugs.python.org/issue19404 for more details. 
+
+    def _recreate_cm(self):
+        # _GCM instances are one-shot context managers, so the
+        # CM must be recreated each time a decorated function is
+        # called
+        return self.__class__(self.func, self.args, self.kwds)
+
+    def __enter__(self):
+        try:
+            return next(self.gen)
+        except StopIteration:
+            raise RuntimeError("generator didn't yield")
+
+    def __exit__(self, type, value, traceback):
+        if type is None:
+            try:
+                next(self.gen)
+            except StopIteration:
+                return
+            else:
+                raise RuntimeError("generator didn't stop")
+        else:
+            if value is None:
+                # Need to force instantiation so we can reliably
+                # tell if we get the same exception back
+                value = type()
+            try:
+                self.gen.throw(type, value, traceback)
+                raise RuntimeError("generator didn't stop after throw()")
+            except StopIteration as exc:
+                # Suppress StopIteration *unless* it's the same exception that
+                # was passed to throw(). This prevents a StopIteration
+                # raised inside the "with" statement from being suppressed.
+                return exc is not value
+            except RuntimeError as exc:
+                # Don't re-raise the passed in exception
+                if exc is value:
+                    return False
+                # Likewise, avoid suppressing if a StopIteration exception
+                # was passed to throw() and later wrapped into a RuntimeError
+                # (see PEP 479).
+                if _HAVE_EXCEPTION_CHAINING and exc.__cause__ is value:
+                    return False
+                raise
+            except:
+                # only re-raise if it's *not* the exception that was
+                # passed to throw(), because __exit__() must not raise
+                # an exception unless __exit__() itself failed. But throw()
+                # has to raise the exception to signal propagation, so this
+                # fixes the impedance mismatch between the throw() protocol
+                # and the __exit__() protocol.
+                #
+                if sys.exc_info()[1] is not value:
+                    raise
+
+
+def contextmanager(func):
+    """@contextmanager decorator.
+
+    Typical usage:
+
+        @contextmanager
+        def some_generator(<arguments>):
+            <setup>
+            try:
+                yield <value>
+            finally:
+                <cleanup>
+
+    This makes this:
+
+        with some_generator(<arguments>) as <variable>:
+            <body>
+
+    equivalent to this:
+
+        <setup>
+        try:
+            <variable> = <value>
+            <body>
+        finally:
+            <cleanup>
+
+    """
+    @wraps(func)
+    def helper(*args, **kwds):
+        return _GeneratorContextManager(func, args, kwds)
+    return helper
+
+
+class closing(object):
+    """Context to automatically close something at the end of a block.
+
+    Code like this:
+
+        with closing(<module>.open(<arguments>)) as f:
+            <block>
+
+    is equivalent to this:
+
+        f = <module>.open(<arguments>)
+        try:
+            <block>
+        finally:
+            f.close()
+
+    """
+    def __init__(self, thing):
+        self.thing = thing
+    def __enter__(self):
+        return self.thing
+    def __exit__(self, *exc_info):
+        self.thing.close()
+
+
+class _RedirectStream(object):
+
+    _stream = None
+
+    def __init__(self, new_target):
+        self._new_target = new_target
+        # We use a list of old targets to make this CM re-entrant
+        self._old_targets = []
+
+    def __enter__(self):
+        self._old_targets.append(getattr(sys, self._stream))
+        setattr(sys, self._stream, self._new_target)
+        return self._new_target
+
+    def __exit__(self, exctype, excinst, exctb):
+        setattr(sys, self._stream, self._old_targets.pop())
+
+
+class redirect_stdout(_RedirectStream):
+    """Context manager for temporarily redirecting stdout to another file.
+ + # How to send help() to stderr + with redirect_stdout(sys.stderr): + help(dir) + + # How to write help() to a file + with open('help.txt', 'w') as f: + with redirect_stdout(f): + help(pow) + """ + + _stream = "stdout" + + +class redirect_stderr(_RedirectStream): + """Context manager for temporarily redirecting stderr to another file.""" + + _stream = "stderr" + + +class suppress(object): + """Context manager to suppress specified exceptions + + After the exception is suppressed, execution proceeds with the next + statement following the with statement. + + with suppress(FileNotFoundError): + os.remove(somefile) + # Execution still resumes here if the file was already removed + """ + + def __init__(self, *exceptions): + self._exceptions = exceptions + + def __enter__(self): + pass + + def __exit__(self, exctype, excinst, exctb): + # Unlike isinstance and issubclass, CPython exception handling + # currently only looks at the concrete type hierarchy (ignoring + # the instance and subclass checking hooks). While Guido considers + # that a bug rather than a feature, it's a fairly hard one to fix + # due to various internal implementation details. suppress provides + # the simpler issubclass based semantics, rather than trying to + # exactly reproduce the limitations of the CPython interpreter. + # + # See http://bugs.python.org/issue12029 for more details + return exctype is not None and issubclass(exctype, self._exceptions) + + +# Context manipulation is Python 3 only +_HAVE_EXCEPTION_CHAINING = sys.version_info[0] >= 3 +if _HAVE_EXCEPTION_CHAINING: + def _make_context_fixer(frame_exc): + def _fix_exception_context(new_exc, old_exc): + # Context may not be correct, so find the end of the chain + while 1: + exc_context = new_exc.__context__ + if exc_context is old_exc: + # Context is already set correctly (see issue 20317) + return + if exc_context is None or exc_context is frame_exc: + break + new_exc = exc_context + # Change the end of the chain to point to the exception + # we expect it to reference + new_exc.__context__ = old_exc + return _fix_exception_context + + def _reraise_with_existing_context(exc_details): + try: + # bare "raise exc_details[1]" replaces our carefully + # set-up context + fixed_ctx = exc_details[1].__context__ + raise exc_details[1] + except BaseException: + exc_details[1].__context__ = fixed_ctx + raise +else: + # No exception context in Python 2 + def _make_context_fixer(frame_exc): + return lambda new_exc, old_exc: None + + # Use 3 argument raise in Python 2, + # but use exec to avoid SyntaxError in Python 3 + def _reraise_with_existing_context(exc_details): + exc_type, exc_value, exc_tb = exc_details + exec ("raise exc_type, exc_value, exc_tb") + +# Handle old-style classes if they exist +try: + from types import InstanceType +except ImportError: + # Python 3 doesn't have old-style classes + _get_type = type +else: + # Need to handle old-style context managers on Python 2 + def _get_type(obj): + obj_type = type(obj) + if obj_type is InstanceType: + return obj.__class__ # Old-style class + return obj_type # New-style class + +# Inspired by discussions on http://bugs.python.org/issue13585 +class ExitStack(object): + """Context manager for dynamic management of a stack of exit callbacks + + For example: + + with ExitStack() as stack: + files = [stack.enter_context(open(fname)) for fname in filenames] + # All opened files will automatically be closed at the end of + # the with statement, even if attempts to open files later + # in the list raise an exception + + 
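+    Registered callbacks run in LIFO order when the block exits; a second
+    sketch (editor's illustration)::
+
+        with ExitStack() as stack:
+            stack.callback(print, 'runs second')
+            stack.callback(print, 'runs first')
+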
""" + def __init__(self): + self._exit_callbacks = deque() + + def pop_all(self): + """Preserve the context stack by transferring it to a new instance""" + new_stack = type(self)() + new_stack._exit_callbacks = self._exit_callbacks + self._exit_callbacks = deque() + return new_stack + + def _push_cm_exit(self, cm, cm_exit): + """Helper to correctly register callbacks to __exit__ methods""" + def _exit_wrapper(*exc_details): + return cm_exit(cm, *exc_details) + _exit_wrapper.__self__ = cm + self.push(_exit_wrapper) + + def push(self, exit): + """Registers a callback with the standard __exit__ method signature + + Can suppress exceptions the same way __exit__ methods can. + + Also accepts any object with an __exit__ method (registering a call + to the method instead of the object itself) + """ + # We use an unbound method rather than a bound method to follow + # the standard lookup behaviour for special methods + _cb_type = _get_type(exit) + try: + exit_method = _cb_type.__exit__ + except AttributeError: + # Not a context manager, so assume its a callable + self._exit_callbacks.append(exit) + else: + self._push_cm_exit(exit, exit_method) + return exit # Allow use as a decorator + + def callback(self, callback, *args, **kwds): + """Registers an arbitrary callback and arguments. + + Cannot suppress exceptions. + """ + def _exit_wrapper(exc_type, exc, tb): + callback(*args, **kwds) + # We changed the signature, so using @wraps is not appropriate, but + # setting __wrapped__ may still help with introspection + _exit_wrapper.__wrapped__ = callback + self.push(_exit_wrapper) + return callback # Allow use as a decorator + + def enter_context(self, cm): + """Enters the supplied context manager + + If successful, also pushes its __exit__ method as a callback and + returns the result of the __enter__ method. 
+ """ + # We look up the special methods on the type to match the with statement + _cm_type = _get_type(cm) + _exit = _cm_type.__exit__ + result = _cm_type.__enter__(cm) + self._push_cm_exit(cm, _exit) + return result + + def close(self): + """Immediately unwind the context stack""" + self.__exit__(None, None, None) + + def __enter__(self): + return self + + def __exit__(self, *exc_details): + received_exc = exc_details[0] is not None + + # We manipulate the exception state so it behaves as though + # we were actually nesting multiple with statements + frame_exc = sys.exc_info()[1] + _fix_exception_context = _make_context_fixer(frame_exc) + + # Callbacks are invoked in LIFO order to match the behaviour of + # nested context managers + suppressed_exc = False + pending_raise = False + while self._exit_callbacks: + cb = self._exit_callbacks.pop() + try: + if cb(*exc_details): + suppressed_exc = True + pending_raise = False + exc_details = (None, None, None) + except: + new_exc_details = sys.exc_info() + # simulate the stack of exceptions by setting the context + _fix_exception_context(new_exc_details[1], exc_details[1]) + pending_raise = True + exc_details = new_exc_details + if pending_raise: + _reraise_with_existing_context(exc_details) + return received_exc and suppressed_exc + +# Preserve backwards compatibility +class ContextStack(ExitStack): + """Backwards compatibility alias for ExitStack""" + + def __init__(self): + warnings.warn("ContextStack has been renamed to ExitStack", + DeprecationWarning) + super(ContextStack, self).__init__() + + def register_exit(self, callback): + return self.push(callback) + + def register(self, callback, *args, **kwds): + return self.callback(callback, *args, **kwds) + + def preserve(self): + return self.pop_all() diff --git a/resources/lib/httpproxy.py b/resources/lib/httpproxy.py index d8d8070..09662a4 100644 --- a/resources/lib/httpproxy.py +++ b/resources/lib/httpproxy.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import threading -import thread +import _thread import time import re import struct @@ -11,14 +11,13 @@ import sys import platform import logging -import os +from io import BytesIO from utils import log_msg, log_exception, create_wave_header, PROXY_PORT, StringIO import xbmc import math class Root: spotty = None - spotty_bin = None spotty_trackid = None spotty_range_l = None @@ -35,16 +34,20 @@ def _check_request(self): # Error if the requester is not allowed # for now this is a simple check just checking if the useragent matches Kodi user_agent = headers['User-Agent'].lower() - if not ("kodi" in user_agent or "osmc" in user_agent): - raise cherrypy.HTTPError(403) + # if not ("Kodi" in user_agent or "osmc" in user_agent): + # raise cherrypy.HTTPError(403) return method @cherrypy.expose + def index(self): + return "Server started" + + # @cherrypy.expose @cherrypy.tools.json_out() - @cherrypy.tools.json_in() + @cherrypy.tools.json_in() def lms(self, filename, **kwargs): - ''' fake lms hook to retrieve events form spotty daemon''' + ''' fake lms hook to retrieve events from spotty daemon''' method = cherrypy.request.method.upper() if method != "POST" or filename != "jsonrpc.js": raise cherrypy.HTTPError(405) @@ -54,7 +57,7 @@ def lms(self, filename, **kwargs): log_msg("lms event hook called. 
Event: %s" % event) # check username, it might have changed spotty_user = self.__spotty.get_username() - cur_user = xbmc.getInfoLabel("Window(Home).Property(spotify-username)").decode("utf-8") + cur_user = xbmc.getInfoLabel("Window(Home).Property(spotify-username)") if spotty_user != cur_user: log_msg("user change detected") xbmc.executebuiltin("SetProperty(spotify-cmd,__LOGOUT__,Home)") @@ -81,7 +84,7 @@ def track(self, track_id, duration, **kwargs): self._check_request() # Calculate file size, and obtain the header - duration = int(duration) + duration = int(float(duration)) wave_header, filesize = create_wave_header(duration) request_range = cherrypy.request.headers.get('Range', '') # response timeout must be at least the duration of the track: read/write loop @@ -105,7 +108,7 @@ def track(self, track_id, duration, **kwargs): range_r = filesize cherrypy.response.headers['Accept-Ranges'] = 'bytes' - cherrypy.response.headers['Content-Length'] = filesize + cherrypy.response.headers['Content-Length'] = range_r - range_l cherrypy.response.headers['Content-Range'] = "bytes %s-%s/%s" % (range_l, range_r, filesize) log_msg("partial request range: %s, length: %s" % (cherrypy.response.headers['Content-Range'], cherrypy.response.headers['Content-Length']), xbmc.LOGDEBUG) else: @@ -113,36 +116,37 @@ def track(self, track_id, duration, **kwargs): cherrypy.response.headers['Content-Type'] = 'audio/x-wav' cherrypy.response.headers['Accept-Ranges'] = 'bytes' cherrypy.response.headers['Content-Length'] = filesize + log_msg("!! Full File. Size : %s " % (filesize), xbmc.LOGDEBUG) # If method was GET, write the file content if cherrypy.request.method.upper() == 'GET': - return self.send_audio_stream(track_id, filesize, wave_header, range_l) + + if self.spotty_bin != None: + # If spotty binary still attached for a different request, try to terminate it. + log_msg("WHOOPS!!! Running spotty detected - killing it to continue.", \ + xbmc.LOGERROR) + self.kill_spotty() + + while self.spotty_bin: + time.sleep(0.1) + + return self.send_audio_stream(track_id, range_r - range_l, wave_header, range_l) + track._cp_config = {'response.stream': True} def kill_spotty(self): self.spotty_bin.terminate() + self.spotty_bin.communicate() self.spotty_bin = None self.spotty_trackid = None self.spotty_range_l = None - def send_audio_stream(self, track_id, filesize, wave_header, range_l): + def send_audio_stream(self, track_id, length, wave_header, range_l): '''chunked transfer of audio data from spotty binary''' - if self.spotty_bin != None and \ - self.spotty_trackid == track_id and \ - self.spotty_range_l == range_l: - # leave the existing spotty running and don't start a new one. - log_msg("WHOOPS!!! Running spotty still handling same request - leave it alone.", \ - xbmc.LOGERROR) - return - elif self.spotty_bin != None: - # If spotty binary still attached for a different request, try to terminate it. - log_msg("WHOOPS!!! 
Running spotty detected - killing it to continue.", \
-                    xbmc.LOGERROR)
-            self.kill_spotty()
-
-        log_msg("start transfer for track %s - range: %s" % (track_id, range_l), \
-                 xbmc.LOGDEBUG)
         try:
+            log_msg("start transfer for track %s - range: %s" % (track_id, range_l), \
+                     xbmc.LOGDEBUG)
+
+            # Initialize some loop vars
             max_buffer_size = 524288
             bytes_written = 0
@@ -152,53 +156,54 @@ def send_audio_stream(self, track_id, filesize, wave_header, range_l):
             # bytes_written = len(wave_header)
             if not range_l:
                 yield wave_header
+                bytes_written = len(wave_header)

-            # get pcm data from spotty stdout and append to our buffer
+            # get OGG data from spotty stdout and append to our buffer
             args = ["-n", "temp", "--single-track", track_id]
-            self.spotty_bin = self.__spotty.run_spotty(args, use_creds=True)
+            if self.spotty_bin == None:
+                self.spotty_bin = self.__spotty.run_spotty(args, use_creds=True)
             self.spotty_trackid = track_id
             self.spotty_range_l = range_l
-
-            # ignore the first x bytes to match the range request
+            log_msg("Info: track %s" % track_id)
+
+            # ignore the first x bytes to match the range request
             if range_l:
                 self.spotty_bin.stdout.read(range_l)
             # Loop as long as there's something to output
-            frame = self.spotty_bin.stdout.read(max_buffer_size)
-            while frame:
-                if cherrypy.response.timed_out:
-                    # A timeout occured on the cherrypy session and has been flagged - so exit
-                    # The session timer was set to be longer than the track being played so this
-                    # would probably require network problems or something bad elsewhere.
-                    log_msg("SPOTTY cherrypy response timeout: %r - %s" % \
-                            (repr(cherrypy.response.timed_out), cherrypy.response.status), xbmc.LOGERROR)
+            while bytes_written < length:
+                frame = self.spotty_bin.stdout.read(max_buffer_size)
+                if not frame:
                     break
                 bytes_written += len(frame)
                 yield frame
-                frame = self.spotty_bin.stdout.read(max_buffer_size)
+
+            log_msg("FINISH transfer for track %s - range %s - written %s" % (track_id, range_l, bytes_written), \
+                    xbmc.LOGDEBUG)
         except Exception as exc:
             log_exception(__name__, exc)
+            log_msg("EXCEPTION FINISH transfer for track %s - range %s - written %s" % (track_id, range_l, bytes_written), \
+                    xbmc.LOGDEBUG)
         finally:
             # make sure spotty always gets terminated
             if self.spotty_bin != None:
                 self.kill_spotty()
-        log_msg("FINISH transfer for track %s - range %s" % (track_id, range_l), \
-                xbmc.LOGDEBUG)

     @cherrypy.expose
     def silence(self, duration, **kwargs):
         '''stream silence audio for the given duration, used by spotify connect player'''
-        duration = int(duration)
+        duration = float(duration)
         wave_header, filesize = create_wave_header(duration)
-        output_buffer = StringIO()
+        output_buffer = BytesIO()
         output_buffer.write(wave_header)
-        output_buffer.write('\0' * (filesize - output_buffer.tell()))
-        return cherrypy.lib.static.serve_fileobj(output_buffer, content_type="audio/wav",
-                                                 name="%s.wav" % duration, filesize=output_buffer.tell())
+        output_buffer.write(bytes('\0' * (filesize - output_buffer.tell()), 'utf-8'))
+        # rewind before reading back: without seek(0) the buffer yields an
+        # empty byte string and the silence response has no body
+        output_buffer.seek(0)
+        return cherrypy.lib.static.serve_fileobj(output_buffer.read(), content_type="audio/wav", name="%s.wav" % duration, debug=True)

     @cherrypy.expose
     def nexttrack(self, **kwargs):
         '''play silence while spotify connect player is waiting for the next track'''
+        log_msg('play silence while spotify connect player is waiting for the next track', xbmc.LOGDEBUG)
         return self.silence(20)

     @cherrypy.expose
@@ -234,14 +239,10 @@ class ProxyRunner(threading.Thread):
     def __init__(self, spotty):
         self.__root = Root(spotty)
         log = cherrypy.log
-
log.access_file = '' - log.error_file = '' - log.screen = False + log.screen = True cherrypy.config.update({ - 'server.socket_host': '0.0.0.0', - 'server.socket_port': PROXY_PORT, - 'engine.timeout_monitor.frequency': 5, - 'server.shutdown_timeout': 1 + 'server.socket_host': '127.0.0.1', + 'server.socket_port': PROXY_PORT }) self.__server = cherrypy.server.httpserver = CPHTTPServer(cherrypy.server) threading.Thread.__init__(self) diff --git a/resources/lib/jaraco/__init__.py b/resources/lib/jaraco/__init__.py new file mode 100644 index 0000000..69e3be5 --- /dev/null +++ b/resources/lib/jaraco/__init__.py @@ -0,0 +1 @@ +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/resources/lib/jaraco/classes/__init__.py b/resources/lib/jaraco/classes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/resources/lib/jaraco/classes/ancestry.py b/resources/lib/jaraco/classes/ancestry.py new file mode 100644 index 0000000..a097843 --- /dev/null +++ b/resources/lib/jaraco/classes/ancestry.py @@ -0,0 +1,68 @@ +""" +Routines for obtaining the class names +of an object and its parent classes. +""" + +from more_itertools import unique_everseen + + +def all_bases(c): + """ + return a tuple of all base classes the class c has as a parent. + >>> object in all_bases(list) + True + """ + return c.mro()[1:] + + +def all_classes(c): + """ + return a tuple of all classes to which c belongs + >>> list in all_classes(list) + True + """ + return c.mro() + + +# borrowed from +# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/ + + +def iter_subclasses(cls): + """ + Generator over all subclasses of a given class, in depth-first order. + + >>> bool in list(iter_subclasses(int)) + True + >>> class A(object): pass + >>> class B(A): pass + >>> class C(A): pass + >>> class D(B,C): pass + >>> class E(D): pass + >>> + >>> for cls in iter_subclasses(A): + ... print(cls.__name__) + B + D + E + C + >>> # get ALL classes currently defined + >>> res = [cls.__name__ for cls in iter_subclasses(object)] + >>> 'type' in res + True + >>> 'tuple' in res + True + >>> len(res) > 100 + True + """ + return unique_everseen(_iter_all_subclasses(cls)) + + +def _iter_all_subclasses(cls): + try: + subs = cls.__subclasses__() + except TypeError: # fails only when cls is type + subs = cls.__subclasses__(cls) + for sub in subs: + yield sub + yield from iter_subclasses(sub) diff --git a/resources/lib/jaraco/classes/meta.py b/resources/lib/jaraco/classes/meta.py new file mode 100644 index 0000000..351cdb4 --- /dev/null +++ b/resources/lib/jaraco/classes/meta.py @@ -0,0 +1,64 @@ +""" +meta.py + +Some useful metaclasses. +""" + + +class LeafClassesMeta(type): + """ + A metaclass for classes that keeps track of all of them that + aren't base classes. 
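+    Whenever a new class is created, ``__init__`` adds it to the shared
+    registry and removes its bases, so only classes without subclasses
+    remain registered.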
+ + >>> Parent = LeafClassesMeta('MyParentClass', (), {}) + >>> Parent in Parent._leaf_classes + True + >>> Child = LeafClassesMeta('MyChildClass', (Parent,), {}) + >>> Child in Parent._leaf_classes + True + >>> Parent in Parent._leaf_classes + False + + >>> Other = LeafClassesMeta('OtherClass', (), {}) + >>> Parent in Other._leaf_classes + False + >>> len(Other._leaf_classes) + 1 + """ + + def __init__(cls, name, bases, attrs): + if not hasattr(cls, '_leaf_classes'): + cls._leaf_classes = set() + leaf_classes = getattr(cls, '_leaf_classes') + leaf_classes.add(cls) + # remove any base classes + leaf_classes -= set(bases) + + +class TagRegistered(type): + """ + As classes of this metaclass are created, they keep a registry in the + base class of all classes by a class attribute, indicated by attr_name. + + >>> FooObject = TagRegistered('FooObject', (), dict(tag='foo')) + >>> FooObject._registry['foo'] is FooObject + True + >>> BarObject = TagRegistered('Barobject', (FooObject,), dict(tag='bar')) + >>> FooObject._registry is BarObject._registry + True + >>> len(FooObject._registry) + 2 + >>> FooObject._registry['bar'] + + """ + + attr_name = 'tag' + + def __init__(cls, name, bases, namespace): + super(TagRegistered, cls).__init__(name, bases, namespace) + if not hasattr(cls, '_registry'): + cls._registry = {} + meta = cls.__class__ + attr = getattr(cls, meta.attr_name, None) + if attr: + cls._registry[attr] = cls diff --git a/resources/lib/jaraco/classes/properties.py b/resources/lib/jaraco/classes/properties.py new file mode 100644 index 0000000..f6d1685 --- /dev/null +++ b/resources/lib/jaraco/classes/properties.py @@ -0,0 +1,169 @@ +class NonDataProperty: + """Much like the property builtin, but only implements __get__, + making it a non-data property, and can be subsequently reset. + + See http://users.rcn.com/python/download/Descriptor.htm for more + information. + + >>> class X(object): + ... @NonDataProperty + ... def foo(self): + ... return 3 + >>> x = X() + >>> x.foo + 3 + >>> x.foo = 4 + >>> x.foo + 4 + >>> X.foo + + """ + + def __init__(self, fget): + assert fget is not None, "fget cannot be none" + assert callable(fget), "fget must be callable" + self.fget = fget + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return self.fget(obj) + + +class classproperty: + """ + Like @property but applies at the class level. + + + >>> class X(metaclass=classproperty.Meta): + ... val = None + ... @classproperty + ... def foo(cls): + ... return cls.val + ... @foo.setter + ... def foo(cls, val): + ... cls.val = val + >>> X.foo + >>> X.foo = 3 + >>> X.foo + 3 + >>> x = X() + >>> x.foo + 3 + >>> X.foo = 4 + >>> x.foo + 4 + + Setting the property on an instance affects the class. + + >>> x.foo = 5 + >>> x.foo + 5 + >>> X.foo + 5 + >>> vars(x) + {} + >>> X().foo + 5 + + Attempting to set an attribute where no setter was defined + results in an AttributeError: + + >>> class GetOnly(metaclass=classproperty.Meta): + ... @classproperty + ... def foo(cls): + ... return 'bar' + >>> GetOnly.foo = 3 + Traceback (most recent call last): + ... + AttributeError: can't set attribute + + It is also possible to wrap a classmethod or staticmethod in + a classproperty. + + >>> class Static(metaclass=classproperty.Meta): + ... @classproperty + ... @classmethod + ... def foo(cls): + ... return 'foo' + ... @classproperty + ... @staticmethod + ... def bar(): + ... 
return 'bar' + >>> Static.foo + 'foo' + >>> Static.bar + 'bar' + + *Legacy* + + For compatibility, if the metaclass isn't specified, the + legacy behavior will be invoked. + + >>> class X: + ... val = None + ... @classproperty + ... def foo(cls): + ... return cls.val + ... @foo.setter + ... def foo(cls, val): + ... cls.val = val + >>> X.foo + >>> X.foo = 3 + >>> X.foo + 3 + >>> x = X() + >>> x.foo + 3 + >>> X.foo = 4 + >>> x.foo + 4 + + Note, because the metaclass was not specified, setting + a value on an instance does not have the intended effect. + + >>> x.foo = 5 + >>> x.foo + 5 + >>> X.foo # should be 5 + 4 + >>> vars(x) # should be empty + {'foo': 5} + >>> X().foo # should be 5 + 4 + """ + + class Meta(type): + def __setattr__(self, key, value): + obj = self.__dict__.get(key, None) + if type(obj) is classproperty: + return obj.__set__(self, value) + return super().__setattr__(key, value) + + def __init__(self, fget, fset=None): + self.fget = self._fix_function(fget) + self.fset = fset + fset and self.setter(fset) + + def __get__(self, instance, owner=None): + return self.fget.__get__(None, owner)() + + def __set__(self, owner, value): + if not self.fset: + raise AttributeError("can't set attribute") + if type(owner) is not classproperty.Meta: + owner = type(owner) + return self.fset.__get__(None, owner)(value) + + def setter(self, fset): + self.fset = self._fix_function(fset) + return self + + @classmethod + def _fix_function(cls, fn): + """ + Ensure fn is a classmethod or staticmethod. + """ + if not isinstance(fn, (classmethod, staticmethod)): + return classmethod(fn) + return fn diff --git a/resources/lib/jaraco/collections.py b/resources/lib/jaraco/collections.py new file mode 100644 index 0000000..3ab9dc9 --- /dev/null +++ b/resources/lib/jaraco/collections.py @@ -0,0 +1,963 @@ +import re +import operator +import collections.abc +import itertools +import copy +import functools + +from jaraco.classes.properties import NonDataProperty +import jaraco.text + + +class Projection(collections.abc.Mapping): + """ + Project a set of keys over a mapping + + >>> sample = {'a': 1, 'b': 2, 'c': 3} + >>> prj = Projection(['a', 'c', 'd'], sample) + >>> prj == {'a': 1, 'c': 3} + True + + Keys should only appear if they were specified and exist in the space. + + >>> sorted(list(prj.keys())) + ['a', 'c'] + + Attempting to access a key not in the projection + results in a KeyError. + + >>> prj['b'] + Traceback (most recent call last): + ... + KeyError: 'b' + + Use the projection to update another dict. + + >>> target = {'a': 2, 'b': 2} + >>> target.update(prj) + >>> target == {'a': 1, 'b': 2, 'c': 3} + True + + Also note that Projection keeps a reference to the original dict, so + if you modify the original dict, that could modify the Projection. + + >>> del sample['a'] + >>> dict(prj) + {'c': 3} + """ + + def __init__(self, keys, space): + self._keys = tuple(keys) + self._space = space + + def __getitem__(self, key): + if key not in self._keys: + raise KeyError(key) + return self._space[key] + + def __iter__(self): + return iter(set(self._keys).intersection(self._space)) + + def __len__(self): + return len(tuple(iter(self))) + + +class DictFilter(object): + """ + Takes a dict, and simulates a sub-dict based on the keys. 
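+    Keys may be selected explicitly (``include_keys``) or matched by a
+    regular expression (``include_pattern``); both styles are shown below.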
+
+    >>> sample = {'a': 1, 'b': 2, 'c': 3}
+    >>> filtered = DictFilter(sample, ['a', 'c'])
+    >>> filtered == {'a': 1, 'c': 3}
+    True
+    >>> set(filtered.values()) == {1, 3}
+    True
+    >>> set(filtered.items()) == {('a', 1), ('c', 3)}
+    True
+
+    One can also filter by a regular expression pattern
+
+    >>> sample['d'] = 4
+    >>> sample['ef'] = 5
+
+    Here we filter for only single-character keys
+
+    >>> filtered = DictFilter(sample, include_pattern='.$')
+    >>> filtered == {'a': 1, 'b': 2, 'c': 3, 'd': 4}
+    True
+
+    >>> filtered['e']
+    Traceback (most recent call last):
+    ...
+    KeyError: 'e'
+
+    Also note that DictFilter keeps a reference to the original dict, so
+    if you modify the original dict, that could modify the filtered dict.
+
+    >>> del sample['d']
+    >>> del sample['a']
+    >>> filtered == {'b': 2, 'c': 3}
+    True
+    >>> filtered != {'b': 2, 'c': 3}
+    False
+    """
+
+    def __init__(self, dict, include_keys=[], include_pattern=None):
+        self.dict = dict
+        self.specified_keys = set(include_keys)
+        if include_pattern is not None:
+            self.include_pattern = re.compile(include_pattern)
+        else:
+            # for performance, replace the pattern_keys property
+            self.pattern_keys = set()
+
+    def get_pattern_keys(self):
+        keys = filter(self.include_pattern.match, self.dict.keys())
+        return set(keys)
+
+    pattern_keys = NonDataProperty(get_pattern_keys)
+
+    @property
+    def include_keys(self):
+        return self.specified_keys.union(self.pattern_keys)
+
+    def keys(self):
+        return self.include_keys.intersection(self.dict.keys())
+
+    def values(self):
+        return map(self.dict.get, self.keys())
+
+    def __getitem__(self, i):
+        if i not in self.include_keys:
+            raise KeyError(i)
+        return self.dict[i]
+
+    def items(self):
+        keys = self.keys()
+        values = map(self.dict.get, keys)
+        return zip(keys, values)
+
+    def __eq__(self, other):
+        return dict(self) == other
+
+    def __ne__(self, other):
+        return dict(self) != other
+
+
+def dict_map(function, dictionary):
+    """
+    dict_map is much like the built-in function map. It takes a dictionary
+    and applies a function to the values of that dictionary, returning a
+    new dictionary with the mapped values in the original keys.
+
+    >>> d = dict_map(lambda x: x + 1, dict(a=1, b=2))
+    >>> d == dict(a=2, b=3)
+    True
+    """
+    return dict((key, function(value)) for key, value in dictionary.items())
+
+
+class RangeMap(dict):
+    """
+    A dictionary-like object that uses the keys as bounds for a range.
+    Inclusion of the value for that range is determined by the
+    key_match_comparator, which defaults to less-than-or-equal.
+    A value is returned for a key if it is the first key that matches in
+    the sorted list of keys.
+
+    One may supply keyword parameters to be passed to the sort function used
+    to sort keys (i.e. cmp [python 2 only], key, reverse) as sort_params.
+
+    Let's create a map that maps 1-3 -> 'a', 4-6 -> 'b'
+
+    >>> r = RangeMap({3: 'a', 6: 'b'})  # boy, that was easy
+    >>> r[1], r[2], r[3], r[4], r[5], r[6]
+    ('a', 'a', 'a', 'b', 'b', 'b')
+
+    Even float values should work so long as the comparison operator
+    supports it.
+
+    >>> r[4.5]
+    'b'
+
+    But you'll notice that the way rangemap is defined, it must be open-ended
+    on one side.
+
+    >>> r[0]
+    'a'
+    >>> r[-1]
+    'a'
+
+    One can close the open-end of the RangeMap by using undefined_value
+
+    >>> r = RangeMap({0: RangeMap.undefined_value, 3: 'a', 6: 'b'})
+    >>> r[0]
+    Traceback (most recent call last):
+    ...
+ KeyError: 0 + + One can get the first or last elements in the range by using RangeMap.Item + + >>> last_item = RangeMap.Item(-1) + >>> r[last_item] + 'b' + + .last_item is a shortcut for Item(-1) + + >>> r[RangeMap.last_item] + 'b' + + Sometimes it's useful to find the bounds for a RangeMap + + >>> r.bounds() + (0, 6) + + RangeMap supports .get(key, default) + + >>> r.get(0, 'not found') + 'not found' + + >>> r.get(7, 'not found') + 'not found' + """ + + def __init__(self, source, sort_params={}, key_match_comparator=operator.le): + dict.__init__(self, source) + self.sort_params = sort_params + self.match = key_match_comparator + + def __getitem__(self, item): + sorted_keys = sorted(self.keys(), **self.sort_params) + if isinstance(item, RangeMap.Item): + result = self.__getitem__(sorted_keys[item]) + else: + key = self._find_first_match_(sorted_keys, item) + result = dict.__getitem__(self, key) + if result is RangeMap.undefined_value: + raise KeyError(key) + return result + + def get(self, key, default=None): + """ + Return the value for key if key is in the dictionary, else default. + If default is not given, it defaults to None, so that this method + never raises a KeyError. + """ + try: + return self[key] + except KeyError: + return default + + def _find_first_match_(self, keys, item): + is_match = functools.partial(self.match, item) + matches = list(filter(is_match, keys)) + if matches: + return matches[0] + raise KeyError(item) + + def bounds(self): + sorted_keys = sorted(self.keys(), **self.sort_params) + return (sorted_keys[RangeMap.first_item], sorted_keys[RangeMap.last_item]) + + # some special values for the RangeMap + undefined_value = type(str('RangeValueUndefined'), (object,), {})() + + class Item(int): + "RangeMap Item" + + first_item = Item(0) + last_item = Item(-1) + + +def __identity(x): + return x + + +def sorted_items(d, key=__identity, reverse=False): + """ + Return the items of the dictionary sorted by the keys + + >>> sample = dict(foo=20, bar=42, baz=10) + >>> tuple(sorted_items(sample)) + (('bar', 42), ('baz', 10), ('foo', 20)) + + >>> reverse_string = lambda s: ''.join(reversed(s)) + >>> tuple(sorted_items(sample, key=reverse_string)) + (('foo', 20), ('bar', 42), ('baz', 10)) + + >>> tuple(sorted_items(sample, reverse=True)) + (('foo', 20), ('baz', 10), ('bar', 42)) + """ + # wrap the key func so it operates on the first element of each item + def pairkey_key(item): + return key(item[0]) + + return sorted(d.items(), key=pairkey_key, reverse=reverse) + + +class KeyTransformingDict(dict): + """ + A dict subclass that transforms the keys before they're used. + Subclasses may override the default transform_key to customize behavior. + """ + + @staticmethod + def transform_key(key): # pragma: nocover + return key + + def __init__(self, *args, **kargs): + super(KeyTransformingDict, self).__init__() + # build a dictionary using the default constructs + d = dict(*args, **kargs) + # build this dictionary using transformed keys. 
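+        # (routing every pair through __setitem__ ensures transform_key
+        # is applied to each incoming key)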
+ for item in d.items(): + self.__setitem__(*item) + + def __setitem__(self, key, val): + key = self.transform_key(key) + super(KeyTransformingDict, self).__setitem__(key, val) + + def __getitem__(self, key): + key = self.transform_key(key) + return super(KeyTransformingDict, self).__getitem__(key) + + def __contains__(self, key): + key = self.transform_key(key) + return super(KeyTransformingDict, self).__contains__(key) + + def __delitem__(self, key): + key = self.transform_key(key) + return super(KeyTransformingDict, self).__delitem__(key) + + def get(self, key, *args, **kwargs): + key = self.transform_key(key) + return super(KeyTransformingDict, self).get(key, *args, **kwargs) + + def setdefault(self, key, *args, **kwargs): + key = self.transform_key(key) + return super(KeyTransformingDict, self).setdefault(key, *args, **kwargs) + + def pop(self, key, *args, **kwargs): + key = self.transform_key(key) + return super(KeyTransformingDict, self).pop(key, *args, **kwargs) + + def matching_key_for(self, key): + """ + Given a key, return the actual key stored in self that matches. + Raise KeyError if the key isn't found. + """ + try: + return next(e_key for e_key in self.keys() if e_key == key) + except StopIteration: + raise KeyError(key) + + +class FoldedCaseKeyedDict(KeyTransformingDict): + """ + A case-insensitive dictionary (keys are compared as insensitive + if they are strings). + + >>> d = FoldedCaseKeyedDict() + >>> d['heLlo'] = 'world' + >>> list(d.keys()) == ['heLlo'] + True + >>> list(d.values()) == ['world'] + True + >>> d['hello'] == 'world' + True + >>> 'hello' in d + True + >>> 'HELLO' in d + True + >>> print(repr(FoldedCaseKeyedDict({'heLlo': 'world'})).replace("u'", "'")) + {'heLlo': 'world'} + >>> d = FoldedCaseKeyedDict({'heLlo': 'world'}) + >>> print(d['hello']) + world + >>> print(d['Hello']) + world + >>> list(d.keys()) + ['heLlo'] + >>> d = FoldedCaseKeyedDict({'heLlo': 'world', 'Hello': 'world'}) + >>> list(d.values()) + ['world'] + >>> key, = d.keys() + >>> key in ['heLlo', 'Hello'] + True + >>> del d['HELLO'] + >>> d + {} + + get should work + + >>> d['Sumthin'] = 'else' + >>> d.get('SUMTHIN') + 'else' + >>> d.get('OTHER', 'thing') + 'thing' + >>> del d['sumthin'] + + setdefault should also work + + >>> d['This'] = 'that' + >>> print(d.setdefault('this', 'other')) + that + >>> len(d) + 1 + >>> print(d['this']) + that + >>> print(d.setdefault('That', 'other')) + other + >>> print(d['THAT']) + other + + Make it pop! + + >>> print(d.pop('THAT')) + other + + To retrieve the key in its originally-supplied form, use matching_key_for + + >>> print(d.matching_key_for('this')) + This + + >>> d.matching_key_for('missing') + Traceback (most recent call last): + ... + KeyError: 'missing' + """ + + @staticmethod + def transform_key(key): + return jaraco.text.FoldedCase(key) + + +class DictAdapter(object): + """ + Provide a getitem interface for attributes of an object. + + Let's say you want to get at the string.lowercase property in a formatted + string. It's easy with DictAdapter. + + >>> import string + >>> print("lowercase is %(ascii_lowercase)s" % DictAdapter(string)) + lowercase is abcdefghijklmnopqrstuvwxyz + """ + + def __init__(self, wrapped_ob): + self.object = wrapped_ob + + def __getitem__(self, name): + return getattr(self.object, name) + + +class ItemsAsAttributes(object): + """ + Mix-in class to enable a mapping object to provide items as + attributes. 
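+    Because the fallback lives in ``__getattr__``, it is only consulted
+    when normal attribute lookup fails, so real attributes always win.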
+ + >>> C = type(str('C'), (dict, ItemsAsAttributes), dict()) + >>> i = C() + >>> i['foo'] = 'bar' + >>> i.foo + 'bar' + + Natural attribute access takes precedence + + >>> i.foo = 'henry' + >>> i.foo + 'henry' + + But as you might expect, the mapping functionality is preserved. + + >>> i['foo'] + 'bar' + + A normal attribute error should be raised if an attribute is + requested that doesn't exist. + + >>> i.missing + Traceback (most recent call last): + ... + AttributeError: 'C' object has no attribute 'missing' + + It also works on dicts that customize __getitem__ + + >>> missing_func = lambda self, key: 'missing item' + >>> C = type( + ... str('C'), + ... (dict, ItemsAsAttributes), + ... dict(__missing__ = missing_func), + ... ) + >>> i = C() + >>> i.missing + 'missing item' + >>> i.foo + 'missing item' + """ + + def __getattr__(self, key): + try: + return getattr(super(ItemsAsAttributes, self), key) + except AttributeError as e: + # attempt to get the value from the mapping (return self[key]) + # but be careful not to lose the original exception context. + noval = object() + + def _safe_getitem(cont, key, missing_result): + try: + return cont[key] + except KeyError: + return missing_result + + result = _safe_getitem(self, key, noval) + if result is not noval: + return result + # raise the original exception, but use the original class + # name, not 'super'. + message, = e.args + message = message.replace('super', self.__class__.__name__, 1) + e.args = (message,) + raise + + +def invert_map(map): + """ + Given a dictionary, return another dictionary with keys and values + switched. If any of the values resolve to the same key, raises + a ValueError. + + >>> numbers = dict(a=1, b=2, c=3) + >>> letters = invert_map(numbers) + >>> letters[1] + 'a' + >>> numbers['d'] = 3 + >>> invert_map(numbers) + Traceback (most recent call last): + ... + ValueError: Key conflict in inverted mapping + """ + res = dict((v, k) for k, v in map.items()) + if not len(res) == len(map): + raise ValueError('Key conflict in inverted mapping') + return res + + +class IdentityOverrideMap(dict): + """ + A dictionary that by default maps each key to itself, but otherwise + acts like a normal dictionary. + + >>> d = IdentityOverrideMap() + >>> d[42] + 42 + >>> d['speed'] = 'speedo' + >>> print(d['speed']) + speedo + """ + + def __missing__(self, key): + return key + + +class DictStack(list, collections.abc.Mapping): + """ + A stack of dictionaries that behaves as a view on those dictionaries, + giving preference to the last. + + >>> stack = DictStack([dict(a=1, c=2), dict(b=2, a=2)]) + >>> stack['a'] + 2 + >>> stack['b'] + 2 + >>> stack['c'] + 2 + >>> stack.push(dict(a=3)) + >>> stack['a'] + 3 + >>> set(stack.keys()) == set(['a', 'b', 'c']) + True + >>> dict(**stack) == dict(a=3, c=2, b=2) + True + >>> d = stack.pop() + >>> stack['a'] + 2 + >>> d = stack.pop() + >>> stack['a'] + 1 + >>> stack.get('b', None) + """ + + def keys(self): + return list(set(itertools.chain.from_iterable(c.keys() for c in self))) + + def __getitem__(self, key): + for scope in reversed(self): + if key in scope: + return scope[key] + raise KeyError(key) + + push = list.append + + +class BijectiveMap(dict): + """ + A Bijective Map (two-way mapping). + + Implemented as a simple dictionary of 2x the size, mapping values back + to keys. + + Note, this implementation may be incomplete. If there's not a test for + your use case below, it's likely to fail, so please test and send pull + requests or patches for additional functionality needed. 
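+
+    (Each pair is stored twice, key-to-value and value-to-key, which is
+    why ``__len__`` reports half the size of the underlying dict.)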
+ + + >>> m = BijectiveMap() + >>> m['a'] = 'b' + >>> m == {'a': 'b', 'b': 'a'} + True + >>> print(m['b']) + a + + >>> m['c'] = 'd' + >>> len(m) + 2 + + Some weird things happen if you map an item to itself or overwrite a + single key of a pair, so it's disallowed. + + >>> m['e'] = 'e' + Traceback (most recent call last): + ValueError: Key cannot map to itself + + >>> m['d'] = 'e' + Traceback (most recent call last): + ValueError: Key/Value pairs may not overlap + + >>> m['e'] = 'd' + Traceback (most recent call last): + ValueError: Key/Value pairs may not overlap + + >>> print(m.pop('d')) + c + + >>> 'c' in m + False + + >>> m = BijectiveMap(dict(a='b')) + >>> len(m) + 1 + >>> print(m['b']) + a + + >>> m = BijectiveMap() + >>> m.update(a='b') + >>> m['b'] + 'a' + + >>> del m['b'] + >>> len(m) + 0 + >>> 'a' in m + False + """ + + def __init__(self, *args, **kwargs): + super(BijectiveMap, self).__init__() + self.update(*args, **kwargs) + + def __setitem__(self, item, value): + if item == value: + raise ValueError("Key cannot map to itself") + overlap = ( + item in self + and self[item] != value + or value in self + and self[value] != item + ) + if overlap: + raise ValueError("Key/Value pairs may not overlap") + super(BijectiveMap, self).__setitem__(item, value) + super(BijectiveMap, self).__setitem__(value, item) + + def __delitem__(self, item): + self.pop(item) + + def __len__(self): + return super(BijectiveMap, self).__len__() // 2 + + def pop(self, key, *args, **kwargs): + mirror = self[key] + super(BijectiveMap, self).__delitem__(mirror) + return super(BijectiveMap, self).pop(key, *args, **kwargs) + + def update(self, *args, **kwargs): + # build a dictionary using the default constructs + d = dict(*args, **kwargs) + # build this dictionary using transformed keys. + for item in d.items(): + self.__setitem__(*item) + + +class FrozenDict(collections.abc.Mapping, collections.abc.Hashable): + """ + An immutable mapping. + + >>> a = FrozenDict(a=1, b=2) + >>> b = FrozenDict(a=1, b=2) + >>> a == b + True + + >>> a == dict(a=1, b=2) + True + >>> dict(a=1, b=2) == a + True + >>> 'a' in a + True + >>> type(hash(a)) is type(0) + True + >>> set(iter(a)) == {'a', 'b'} + True + >>> len(a) + 2 + >>> a['a'] == a.get('a') == 1 + True + + >>> a['c'] = 3 + Traceback (most recent call last): + ... + TypeError: 'FrozenDict' object does not support item assignment + + >>> a.update(y=3) + Traceback (most recent call last): + ... + AttributeError: 'FrozenDict' object has no attribute 'update' + + Copies should compare equal + + >>> copy.copy(a) == a + True + + Copies should be the same type + + >>> isinstance(copy.copy(a), FrozenDict) + True + + FrozenDict supplies .copy(), even though + collections.abc.Mapping doesn't demand it. 
+ + >>> a.copy() == a + True + >>> a.copy() is not a + True + """ + + __slots__ = ['__data'] + + def __new__(cls, *args, **kwargs): + self = super(FrozenDict, cls).__new__(cls) + self.__data = dict(*args, **kwargs) + return self + + # Container + def __contains__(self, key): + return key in self.__data + + # Hashable + def __hash__(self): + return hash(tuple(sorted(self.__data.items()))) + + # Mapping + def __iter__(self): + return iter(self.__data) + + def __len__(self): + return len(self.__data) + + def __getitem__(self, key): + return self.__data[key] + + # override get for efficiency provided by dict + def get(self, *args, **kwargs): + return self.__data.get(*args, **kwargs) + + # override eq to recognize underlying implementation + def __eq__(self, other): + if isinstance(other, FrozenDict): + other = other.__data + return self.__data.__eq__(other) + + def copy(self): + "Return a shallow copy of self" + return copy.copy(self) + + +class Enumeration(ItemsAsAttributes, BijectiveMap): + """ + A convenient way to provide enumerated values + + >>> e = Enumeration('a b c') + >>> e['a'] + 0 + + >>> e.a + 0 + + >>> e[1] + 'b' + + >>> set(e.names) == set('abc') + True + + >>> set(e.codes) == set(range(3)) + True + + >>> e.get('d') is None + True + + Codes need not start with 0 + + >>> e = Enumeration('a b c', range(1, 4)) + >>> e['a'] + 1 + + >>> e[3] + 'c' + """ + + def __init__(self, names, codes=None): + if isinstance(names, str): + names = names.split() + if codes is None: + codes = itertools.count() + super(Enumeration, self).__init__(zip(names, codes)) + + @property + def names(self): + return (key for key in self if isinstance(key, str)) + + @property + def codes(self): + return (self[name] for name in self.names) + + +class Everything(object): + """ + A collection "containing" every possible thing. + + >>> 'foo' in Everything() + True + + >>> import random + >>> random.randint(1, 999) in Everything() + True + + >>> random.choice([None, 'foo', 42, ('a', 'b', 'c')]) in Everything() + True + """ + + def __contains__(self, other): + return True + + +class InstrumentedDict(collections.UserDict): + """ + Instrument an existing dictionary with additional + functionality, but always reference and mutate + the original dictionary. + + >>> orig = {'a': 1, 'b': 2} + >>> inst = InstrumentedDict(orig) + >>> inst['a'] + 1 + >>> inst['c'] = 3 + >>> orig['c'] + 3 + >>> inst.keys() == orig.keys() + True + """ + + def __init__(self, data): + super().__init__() + self.data = data + + +class Least(object): + """ + A value that is always lesser than any other + + >>> least = Least() + >>> 3 < least + False + >>> 3 > least + True + >>> least < 3 + True + >>> least <= 3 + True + >>> least > 3 + False + >>> 'x' > least + True + >>> None > least + True + """ + + def __le__(self, other): + return True + + __lt__ = __le__ + + def __ge__(self, other): + return False + + __gt__ = __ge__ + + +class Greatest(object): + """ + A value that is always greater than any other + + >>> greatest = Greatest() + >>> 3 < greatest + True + >>> 3 > greatest + False + >>> greatest < 3 + False + >>> greatest > 3 + True + >>> greatest >= 3 + True + >>> 'x' > greatest + False + >>> None > greatest + False + """ + + def __ge__(self, other): + return True + + __gt__ = __ge__ + + def __le__(self, other): + return False + + __lt__ = __le__ + + +def pop_all(items): + """ + Clear items in place and return a copy of items. 
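+    The single tuple assignment ``result, items[:] = items[:], []`` copies
+    the contents before clearing, so the operation never drops elements.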
+ + >>> items = [1, 2, 3] + >>> popped = pop_all(items) + >>> popped is items + False + >>> popped + [1, 2, 3] + >>> items + [] + """ + result, items[:] = items[:], [] + return result diff --git a/resources/lib/jaraco/functools.py b/resources/lib/jaraco/functools.py new file mode 100644 index 0000000..eb550db --- /dev/null +++ b/resources/lib/jaraco/functools.py @@ -0,0 +1,458 @@ +import functools +import time +import inspect +import collections +import types +import itertools + +import more_itertools + + +def compose(*funcs): + """ + Compose any number of unary functions into a single unary function. + + >>> import textwrap + >>> stripped = str.strip(textwrap.dedent(compose.__doc__)) + >>> compose(str.strip, textwrap.dedent)(compose.__doc__) == stripped + True + + Compose also allows the innermost function to take arbitrary arguments. + + >>> round_three = lambda x: round(x, ndigits=3) + >>> f = compose(round_three, int.__truediv__) + >>> [f(3*x, x+1) for x in range(1,10)] + [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] + """ + + def compose_two(f1, f2): + return lambda *args, **kwargs: f1(f2(*args, **kwargs)) + + return functools.reduce(compose_two, funcs) + + +def method_caller(method_name, *args, **kwargs): + """ + Return a function that will call a named method on the + target object with optional positional and keyword + arguments. + + >>> lower = method_caller('lower') + >>> lower('MyString') + 'mystring' + """ + + def call_method(target): + func = getattr(target, method_name) + return func(*args, **kwargs) + + return call_method + + +def once(func): + """ + Decorate func so it's only ever called the first time. + + This decorator can ensure that an expensive or non-idempotent function + will not be expensive on subsequent calls and is idempotent. + + >>> add_three = once(lambda a: a+3) + >>> add_three(3) + 6 + >>> add_three(9) + 6 + >>> add_three('12') + 6 + + To reset the stored value, simply clear the property ``saved_result``. + + >>> del add_three.saved_result + >>> add_three(9) + 12 + >>> add_three(8) + 12 + + Or invoke 'reset()' on it. + + >>> add_three.reset() + >>> add_three(-3) + 0 + >>> add_three(0) + 0 + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not hasattr(wrapper, 'saved_result'): + wrapper.saved_result = func(*args, **kwargs) + return wrapper.saved_result + + wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result') + return wrapper + + +def method_cache(method, cache_wrapper=None): + """ + Wrap lru_cache to support storing the cache data in the object instances. + + Abstracts the common paradigm where the method explicitly saves an + underscore-prefixed protected property on first call and returns that + subsequently. + + >>> class MyClass: + ... calls = 0 + ... + ... @method_cache + ... def method(self, value): + ... self.calls += 1 + ... return value + + >>> a = MyClass() + >>> a.method(3) + 3 + >>> for x in range(75): + ... res = a.method(x) + >>> a.calls + 75 + + Note that the apparent behavior will be exactly like that of lru_cache + except that the cache is stored on each instance, so values in one + instance will not flush values from another, and when an instance is + deleted, so are the cached values for that instance. + + >>> b = MyClass() + >>> for x in range(35): + ... 
res = b.method(x) + >>> b.calls + 35 + >>> a.method(0) + 0 + >>> a.calls + 75 + + Note that if method had been decorated with ``functools.lru_cache()``, + a.calls would have been 76 (due to the cached value of 0 having been + flushed by the 'b' instance). + + Clear the cache with ``.cache_clear()`` + + >>> a.method.cache_clear() + + Another cache wrapper may be supplied: + + >>> cache = functools.lru_cache(maxsize=2) + >>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache) + >>> a = MyClass() + >>> a.method2() + 3 + + Caution - do not subsequently wrap the method with another decorator, such + as ``@property``, which changes the semantics of the function. + + See also + http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/ + for another implementation and additional justification. + """ + cache_wrapper = cache_wrapper or functools.lru_cache() + + def wrapper(self, *args, **kwargs): + # it's the first call, replace the method with a cached, bound method + bound_method = types.MethodType(method, self) + cached_method = cache_wrapper(bound_method) + setattr(self, method.__name__, cached_method) + return cached_method(*args, **kwargs) + + return _special_method_cache(method, cache_wrapper) or wrapper + + +def _special_method_cache(method, cache_wrapper): + """ + Because Python treats special methods differently, it's not + possible to use instance attributes to implement the cached + methods. + + Instead, install the wrapper method under a different name + and return a simple proxy to that wrapper. + + https://github.com/jaraco/jaraco.functools/issues/5 + """ + name = method.__name__ + special_names = '__getattr__', '__getitem__' + if name not in special_names: + return + + wrapper_name = '__cached' + name + + def proxy(self, *args, **kwargs): + if wrapper_name not in vars(self): + bound = types.MethodType(method, self) + cache = cache_wrapper(bound) + setattr(self, wrapper_name, cache) + else: + cache = getattr(self, wrapper_name) + return cache(*args, **kwargs) + + return proxy + + +def apply(transform): + """ + Decorate a function with a transform function that is + invoked on results returned from the decorated function. + + >>> @apply(reversed) + ... def get_numbers(start): + ... return range(start, start+3) + >>> list(get_numbers(4)) + [6, 5, 4] + """ + + def wrap(func): + return compose(transform, func) + + return wrap + + +def result_invoke(action): + r""" + Decorate a function with an action function that is + invoked on the results returned from the decorated + function (for its side-effect), then return the original + result. + + >>> @result_invoke(print) + ... def add_two(a, b): + ... return a + b + >>> x = add_two(2, 3) + 5 + """ + + def wrap(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + action(result) + return result + + return wrapper + + return wrap + + +def call_aside(f, *args, **kwargs): + """ + Call a function for its side effect after initialization. + + >>> @call_aside + ... def func(): print("called") + called + >>> func() + called + + Use functools.partial to pass parameters to the initial call + + >>> @functools.partial(call_aside, name='bingo') + ... 
def func(name): print("called with", name) + called with bingo + """ + f(*args, **kwargs) + return f + + +class Throttler: + """ + Rate-limit a function (or other callable) + """ + + def __init__(self, func, max_rate=float('Inf')): + if isinstance(func, Throttler): + func = func.func + self.func = func + self.max_rate = max_rate + self.reset() + + def reset(self): + self.last_called = 0 + + def __call__(self, *args, **kwargs): + self._wait() + return self.func(*args, **kwargs) + + def _wait(self): + "ensure at least 1/max_rate seconds from last call" + elapsed = time.time() - self.last_called + must_wait = 1 / self.max_rate - elapsed + time.sleep(max(0, must_wait)) + self.last_called = time.time() + + def __get__(self, obj, type=None): + return first_invoke(self._wait, functools.partial(self.func, obj)) + + +def first_invoke(func1, func2): + """ + Return a function that when invoked will invoke func1 without + any parameters (for its side-effect) and then invoke func2 + with whatever parameters were passed, returning its result. + """ + + def wrapper(*args, **kwargs): + func1() + return func2(*args, **kwargs) + + return wrapper + + +def retry_call(func, cleanup=lambda: None, retries=0, trap=()): + """ + Given a callable func, trap the indicated exceptions + for up to 'retries' times, invoking cleanup on the + exception. On the final attempt, allow any exceptions + to propagate. + """ + attempts = itertools.count() if retries == float('inf') else range(retries) + for attempt in attempts: + try: + return func() + except trap: + cleanup() + + return func() + + +def retry(*r_args, **r_kwargs): + """ + Decorator wrapper for retry_call. Accepts arguments to retry_call + except func and then returns a decorator for the decorated function. + + Ex: + + >>> @retry(retries=3) + ... def my_func(a, b): + ... "this is my funk" + ... print(a, b) + >>> my_func.__doc__ + 'this is my funk' + """ + + def decorate(func): + @functools.wraps(func) + def wrapper(*f_args, **f_kwargs): + bound = functools.partial(func, *f_args, **f_kwargs) + return retry_call(bound, *r_args, **r_kwargs) + + return wrapper + + return decorate + + +def print_yielded(func): + """ + Convert a generator into a function that prints all yielded elements + + >>> @print_yielded + ... def x(): + ... yield 3; yield None + >>> x() + 3 + None + """ + print_all = functools.partial(map, print) + print_results = compose(more_itertools.consume, print_all, func) + return functools.wraps(func)(print_results) + + +def pass_none(func): + """ + Wrap func so it's not called if its first param is None + + >>> print_text = pass_none(print) + >>> print_text('text') + text + >>> print_text(None) + """ + + @functools.wraps(func) + def wrapper(param, *args, **kwargs): + if param is not None: + return func(param, *args, **kwargs) + + return wrapper + + +def assign_params(func, namespace): + """ + Assign parameters from namespace where func solicits. + + >>> def func(x, y=3): + ... print(x, y) + >>> assigned = assign_params(func, dict(x=2, z=4)) + >>> assigned() + 2 3 + + The usual errors are raised if a function doesn't receive + its required parameters: + + >>> assigned = assign_params(func, dict(y=3, z=4)) + >>> assigned() + Traceback (most recent call last): + TypeError: func() ...argument... + + It even works on methods: + + >>> class Handler: + ... def meth(self, arg): + ... 
print(arg) + >>> assign_params(Handler().meth, dict(arg='crystal', foo='clear'))() + crystal + """ + sig = inspect.signature(func) + params = sig.parameters.keys() + call_ns = {k: namespace[k] for k in params if k in namespace} + return functools.partial(func, **call_ns) + + +def save_method_args(method): + """ + Wrap a method such that when it is called, the args and kwargs are + saved on the method. + + >>> class MyClass: + ... @save_method_args + ... def method(self, a, b): + ... print(a, b) + >>> my_ob = MyClass() + >>> my_ob.method(1, 2) + 1 2 + >>> my_ob._saved_method.args + (1, 2) + >>> my_ob._saved_method.kwargs + {} + >>> my_ob.method(a=3, b='foo') + 3 foo + >>> my_ob._saved_method.args + () + >>> my_ob._saved_method.kwargs == dict(a=3, b='foo') + True + + The arguments are stored on the instance, allowing for + different instance to save different args. + + >>> your_ob = MyClass() + >>> your_ob.method({str('x'): 3}, b=[4]) + {'x': 3} [4] + >>> your_ob._saved_method.args + ({'x': 3},) + >>> my_ob._saved_method.args + () + """ + args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs') + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + attr_name = '_saved_' + method.__name__ + attr = args_and_kwargs(args, kwargs) + setattr(self, attr_name, attr) + return method(self, *args, **kwargs) + + return wrapper diff --git a/resources/lib/jaraco/text/Lorem ipsum.txt b/resources/lib/jaraco/text/Lorem ipsum.txt new file mode 100644 index 0000000..986f944 --- /dev/null +++ b/resources/lib/jaraco/text/Lorem ipsum.txt @@ -0,0 +1,2 @@ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. +Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus magna felis sollicitudin mauris. Integer in mauris eu nibh euismod gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue, eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis, neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis, molestie eu, feugiat in, orci. In hac habitasse platea dictumst. 
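A note on the file above: it is not stray filler. The vendored jaraco.text module introduced below reads it at import time (via importlib.resources) to provide the `lorem_ipsum` constant used by its wrap/unwrap doctests. A minimal sketch of that relationship, assuming the vendored package is importable as `jaraco.text`:

```python
# Sketch only: mirrors the `lorem_ipsum = resources.read_text(...)` line
# in jaraco/text/__init__.py below. Requires Python 3.7+ for importlib.resources.
from importlib import resources

lorem_ipsum = resources.read_text('jaraco.text', 'Lorem ipsum.txt')
print(lorem_ipsum.splitlines()[0][:26])  # Lorem ipsum dolor sit amet
```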
diff --git a/resources/lib/jaraco/text/__init__.py b/resources/lib/jaraco/text/__init__.py new file mode 100644 index 0000000..7540296 --- /dev/null +++ b/resources/lib/jaraco/text/__init__.py @@ -0,0 +1,500 @@ +from __future__ import absolute_import, unicode_literals, print_function + +import re +import itertools +import textwrap +import functools + +import six + +try: + from importlib import resources +except ImportError: # pragma: nocover + import importlib_resources as resources + +from jaraco.functools import compose, method_cache + + +def substitution(old, new): + """ + Return a function that will perform a substitution on a string + """ + return lambda s: s.replace(old, new) + + +def multi_substitution(*substitutions): + """ + Take a sequence of pairs specifying substitutions, and create + a function that performs those substitutions. + + >>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo') + 'baz' + """ + substitutions = itertools.starmap(substitution, substitutions) + # compose function applies last function first, so reverse the + # substitutions to get the expected order. + substitutions = reversed(tuple(substitutions)) + return compose(*substitutions) + + +class FoldedCase(six.text_type): + """ + A case insensitive string class; behaves just like str + except compares equal when the only variation is case. + + >>> s = FoldedCase('hello world') + + >>> s == 'Hello World' + True + + >>> 'Hello World' == s + True + + >>> s != 'Hello World' + False + + >>> s.index('O') + 4 + + >>> s.split('O') + ['hell', ' w', 'rld'] + + >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta'])) + ['alpha', 'Beta', 'GAMMA'] + + Sequence membership is straightforward. + + >>> "Hello World" in [s] + True + >>> s in ["Hello World"] + True + + You may test for set inclusion, but candidate and elements + must both be folded. + + >>> FoldedCase("Hello World") in {s} + True + >>> s in {FoldedCase("Hello World")} + True + + String inclusion works as long as the FoldedCase object + is on the right. + + >>> "hello" in FoldedCase("Hello World") + True + + But not if the FoldedCase object is on the left: + + >>> FoldedCase('hello') in 'Hello World' + False + + In that case, use in_: + + >>> FoldedCase('hello').in_('Hello World') + True + + >>> FoldedCase('hello') > FoldedCase('Hello') + False + """ + + def __lt__(self, other): + return self.lower() < other.lower() + + def __gt__(self, other): + return self.lower() > other.lower() + + def __eq__(self, other): + return self.lower() == other.lower() + + def __ne__(self, other): + return self.lower() != other.lower() + + def __hash__(self): + return hash(self.lower()) + + def __contains__(self, other): + return super(FoldedCase, self).lower().__contains__(other.lower()) + + def in_(self, other): + "Does self appear in other?" + return self in FoldedCase(other) + + # cache lower since it's likely to be called frequently. + @method_cache + def lower(self): + return super(FoldedCase, self).lower() + + def index(self, sub): + return self.lower().index(sub.lower()) + + def split(self, splitter=' ', maxsplit=0): + pattern = re.compile(re.escape(splitter), re.I) + return pattern.split(self, maxsplit) + + +def is_decodable(value): + r""" + Return True if the supplied value is decodable (using the default + encoding). 
+
+    >>> is_decodable(b'\xff')
+    False
+    >>> is_decodable(b'\x32')
+    True
+    """
+    # TODO: This code could be expressed more concisely and directly
+    # with a jaraco.context.ExceptionTrap, but that adds an unfortunate
+    # long dependency tree, so for now, use boolean literals.
+    try:
+        value.decode()
+    except UnicodeDecodeError:
+        return False
+    return True
+
+
+def is_binary(value):
+    r"""
+    Return True if the value appears to be binary (that is, it's a byte
+    string and isn't decodable).
+
+    >>> is_binary(b'\xff')
+    True
+    >>> is_binary('\xff')
+    False
+    """
+    return isinstance(value, bytes) and not is_decodable(value)
+
+
+def trim(s):
+    r"""
+    Trim something like a docstring to remove the whitespace that
+    is common due to indentation and formatting.
+
+    >>> trim("\n\tfoo = bar\n\t\tbar = baz\n")
+    'foo = bar\n\tbar = baz'
+    """
+    return textwrap.dedent(s).strip()
+
+
+def wrap(s):
+    """
+    Wrap lines of text, retaining existing newlines as
+    paragraph markers.
+
+    >>> print(wrap(lorem_ipsum))
+    Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do
+    eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad
+    minim veniam, quis nostrud exercitation ullamco laboris nisi ut
+    aliquip ex ea commodo consequat. Duis aute irure dolor in
+    reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla
+    pariatur. Excepteur sint occaecat cupidatat non proident, sunt in
+    culpa qui officia deserunt mollit anim id est laborum.
+    <BLANKLINE>
+    Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam
+    varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus
+    magna felis sollicitudin mauris. Integer in mauris eu nibh euismod
+    gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis
+    risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue,
+    eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas
+    fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla
+    a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis,
+    neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing
+    sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque
+    nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus
+    quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis,
+    molestie eu, feugiat in, orci. In hac habitasse platea dictumst.
+    """
+    paragraphs = s.splitlines()
+    wrapped = ('\n'.join(textwrap.wrap(para)) for para in paragraphs)
+    return '\n\n'.join(wrapped)
+
+
+def unwrap(s):
+    r"""
+    Given a multi-line string, return an unwrapped version.
+
+    >>> wrapped = wrap(lorem_ipsum)
+    >>> wrapped.count('\n')
+    20
+    >>> unwrapped = unwrap(wrapped)
+    >>> unwrapped.count('\n')
+    1
+    >>> print(unwrapped)
+    Lorem ipsum dolor sit amet, consectetur adipiscing ...
+    Curabitur pretium tincidunt lacus. Nulla gravida orci ...
+
+ """
+ paragraphs = re.split(r'\n\n+', s)
+ cleaned = (para.replace('\n', ' ') for para in paragraphs)
+ return '\n'.join(cleaned)
+
+
+lorem_ipsum = resources.read_text(__name__, 'Lorem ipsum.txt')
+
+
+class Splitter(object):
+ """object that will split a string with the given arguments for each call
+
+ >>> s = Splitter(',')
+ >>> s('hello, world, this is your, master calling')
+ ['hello', ' world', ' this is your', ' master calling']
+ """
+
+ def __init__(self, *args):
+ self.args = args
+
+ def __call__(self, s):
+ return s.split(*self.args)
+
+
+def indent(string, prefix=' ' * 4):
+ """
+ >>> indent('foo')
+ '    foo'
+ """
+ return prefix + string
+
+
+class WordSet(tuple):
+ """
+ Given a Python identifier, return the words that identifier represents,
+ whether in camel case, underscore-separated, etc.
+
+ >>> WordSet.parse("camelCase")
+ ('camel', 'Case')
+
+ >>> WordSet.parse("under_sep")
+ ('under', 'sep')
+
+ Acronyms should be retained
+
+ >>> WordSet.parse("firstSNL")
+ ('first', 'SNL')
+
+ >>> WordSet.parse("you_and_I")
+ ('you', 'and', 'I')
+
+ >>> WordSet.parse("A simple test")
+ ('A', 'simple', 'test')
+
+ Multiple caps should not interfere with the first cap of another word.
+
+ >>> WordSet.parse("myABCClass")
+ ('my', 'ABC', 'Class')
+
+ The result is a WordSet, so you can get the form you need.
+
+ >>> WordSet.parse("myABCClass").underscore_separated()
+ 'my_ABC_Class'
+
+ >>> WordSet.parse('a-command').camel_case()
+ 'ACommand'
+
+ >>> WordSet.parse('someIdentifier').lowered().space_separated()
+ 'some identifier'
+
+ Slices of the result should return another WordSet.
+
+ >>> WordSet.parse('taken-out-of-context')[1:].underscore_separated()
+ 'out_of_context'
+
+ >>> WordSet.from_class_name(WordSet()).lowered().space_separated()
+ 'word set'
+
+ >>> example = WordSet.parse('figured it out')
+ >>> example.headless_camel_case()
+ 'figuredItOut'
+ >>> example.dash_separated()
+ 'figured-it-out'
+
+ """
+
+ _pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))')
+
+ def capitalized(self):
+ return WordSet(word.capitalize() for word in self)
+
+ def lowered(self):
+ return WordSet(word.lower() for word in self)
+
+ def camel_case(self):
+ return ''.join(self.capitalized())
+
+ def headless_camel_case(self):
+ words = iter(self)
+ first = next(words).lower()
+ new_words = itertools.chain((first,), WordSet(words).camel_case())
+ return ''.join(new_words)
+
+ def underscore_separated(self):
+ return '_'.join(self)
+
+ def dash_separated(self):
+ return '-'.join(self)
+
+ def space_separated(self):
+ return ' '.join(self)
+
+ def __getitem__(self, item):
+ result = super(WordSet, self).__getitem__(item)
+ if isinstance(item, slice):
+ result = WordSet(result)
+ return result
+
+ # for compatibility with Python 2
+ def __getslice__(self, i, j): # pragma: nocover
+ return self.__getitem__(slice(i, j))
+
+ @classmethod
+ def parse(cls, identifier):
+ matches = cls._pattern.finditer(identifier)
+ return WordSet(match.group(0) for match in matches)
+
+ @classmethod
+ def from_class_name(cls, subject):
+ return cls.parse(subject.__class__.__name__)
+
+
+# for backward compatibility
+words = WordSet.parse
+
+
+def simple_html_strip(s):
+ r"""
+ Remove HTML from the string `s`.
+
+ >>> str(simple_html_strip(''))
+ ''
+
+ >>> print(simple_html_strip('A <bold>stormy</bold> day in paradise'))
+ A stormy day in paradise
+
+ >>> print(simple_html_strip('Somebody <!-- do not --> tell the truth.'))
+ Somebody  tell the truth.
+
+ >>> print(simple_html_strip('What about<br/>\nmultiple lines?'))
+ What about
+ multiple lines?
+ """
+ html_stripper = re.compile('(<!--.*?-->)|(<[^>]*>)|([^<]+)', re.DOTALL)
+ texts = (match.group(3) or '' for match in html_stripper.finditer(s))
+ return ''.join(texts)
+
+
+class SeparatedValues(six.text_type):
+ """
+ A string separated by a separator. Overrides __iter__ for getting
+ the values.
+
+ >>> list(SeparatedValues('a,b,c'))
+ ['a', 'b', 'c']
+
+ Whitespace is stripped and empty values are discarded.
+
+ >>> list(SeparatedValues(' a, b , c, '))
+ ['a', 'b', 'c']
+ """
+
+ separator = ','
+
+ def __iter__(self):
+ parts = self.split(self.separator)
+ return six.moves.filter(None, (part.strip() for part in parts))
+
+
+class Stripper:
+ r"""
+ Given a series of lines, find the common prefix and strip it from them.
+
+ >>> lines = [
+ ... 'abcdefg\n',
+ ... 'abc\n',
+ ... 'abcde\n',
+ ... ]
+ >>> res = Stripper.strip_prefix(lines)
+ >>> res.prefix
+ 'abc'
+ >>> list(res.lines)
+ ['defg\n', '\n', 'de\n']
+
+ If no prefix is common, nothing should be stripped.
+
+ >>> lines = [
+ ... 'abcd\n',
+ ... '1234\n',
+ ... ]
+ >>> res = Stripper.strip_prefix(lines)
+ >>> res.prefix
+ ''
+ >>> list(res.lines)
+ ['abcd\n', '1234\n']
+ """
+
+ def __init__(self, prefix, lines):
+ self.prefix = prefix
+ self.lines = map(self, lines)
+
+ @classmethod
+ def strip_prefix(cls, lines):
+ prefix_lines, lines = itertools.tee(lines)
+ prefix = functools.reduce(cls.common_prefix, prefix_lines)
+ return cls(prefix, lines)
+
+ def __call__(self, line):
+ if not self.prefix:
+ return line
+ null, prefix, rest = line.partition(self.prefix)
+ return rest
+
+ @staticmethod
+ def common_prefix(s1, s2):
+ """
+ Return the common prefix of two lines.
+ """
+ index = min(len(s1), len(s2))
+ while s1[:index] != s2[:index]:
+ index -= 1
+ return s1[:index]
+
+
+def remove_prefix(text, prefix):
+ """
+ Remove the prefix from the text if it exists.
+
+ >>> remove_prefix('underwhelming performance', 'underwhelming ')
+ 'performance'
+
+ >>> remove_prefix('something special', 'sample')
+ 'something special'
+ """
+ null, prefix, rest = text.rpartition(prefix)
+ return rest
+
+
+def remove_suffix(text, suffix):
+ """
+ Remove the suffix from the text if it exists.
+
+ >>> remove_suffix('name.git', '.git')
+ 'name'
+
+ >>> remove_suffix('something special', 'sample')
+ 'something special'
+ """
+ rest, suffix, null = text.partition(suffix)
+ return rest
+
+
+def normalize_newlines(text):
+ r"""
+ Replace alternate newlines with the canonical newline.
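+
+ Covers \r\n, \r, the NEL control (\x85) and the Unicode line and
+ paragraph separators (\u2028, \u2029), per the `newlines` list below.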
+
+ >>> normalize_newlines('Lorem Ipsum\u2029')
+ 'Lorem Ipsum\n'
+ >>> normalize_newlines('Lorem Ipsum\r\n')
+ 'Lorem Ipsum\n'
+ >>> normalize_newlines('Lorem Ipsum\x85')
+ 'Lorem Ipsum\n'
+ """
+ newlines = ['\r\n', '\r', '\n', '\u0085', '\u2028', '\u2029']
+ pattern = '|'.join(newlines)
+ return re.sub(pattern, '\n', text)
diff --git a/resources/lib/logless/__init__.py b/resources/lib/logless/__init__.py
new file mode 100644
index 0000000..e0ffc68
--- /dev/null
+++ b/resources/lib/logless/__init__.py
@@ -0,0 +1,10 @@
+"""The logless module: do more with (log)less."""
+
+from .main import logged, logged_block, get_logger, flush_buffers
+
+__all__ = [
+ "logged",
+ "logged_block",
+ "get_logger",
+ "flush_buffers",
+]
diff --git a/resources/lib/logless/main.py b/resources/lib/logless/main.py
new file mode 100644
index 0000000..085a233
--- /dev/null
+++ b/resources/lib/logless/main.py
@@ -0,0 +1,422 @@
+"""Logless logging platform."""
+
+from contextlib import contextmanager
+import datetime
+import functools
+import inspect
+import logging
+import threading
+import queue
+import os
+import sys
+import time
+
+# pylint: disable = logging-format-interpolation
+
+try:
+ import psutil
+except Exception as ex:
+ psutil = None # noqa
+ logging.warning(
+ f"Could not load psutil. Some logging functions may not be available. {ex}"
+ )
+
+
+@functools.lru_cache(20)
+def get_logger(app_name: str, debug=os.environ.get("DEBUG", None)):
+ """
+ Get a logger object.
+
+ Arguments
+ ---------
+ app_name {str} -- An app string to use for the logger.
+
+ Keyword Arguments
+ -----------------
+ debug {bool} -- Whether to include DEBUG output.
+
+ Returns
+ -------
+ logger -- A logger object.
+
+ """
+ log_format = "%(asctime)s - %(levelname)s - %(message)s"
+ log_level = logging.DEBUG if debug else logging.INFO
+ logging.basicConfig(format=log_format)
+ logger = logging.getLogger(app_name)
+ logger.setLevel(log_level)
+ return logger
+
+
+DEFAULT_HEARTBEAT_MINUTES = 5
+DEFAULT_LOGGER = get_logger("slalom.dataops.logs")
+
+
+def duration_to_string(seconds):
+ """Return duration as a concise string. e.g. "32min 3s", "4hr 34min", etc."""
+ units = ["hr", "min", "s"]
+ duration_parts = [
+ int(part) for part in str(datetime.timedelta(seconds=int(seconds))).split(":")
+ ]
+ if duration_parts[0]:
+ # if >=1hr, append seconds as tenth of minutes
+ # duration_parts[1] = duration_parts[1] + round(duration_parts[2] / 60, 1)
+ duration_parts = duration_parts[:2]
+ result = " ".join(
+ [str(part) + units[x] for x, part in enumerate(duration_parts, 0) if part]
+ )
+ return result or "0s"
+
+
+def elapsed_since(start, template="({duration} elapsed)"):
+ """Return a formatted string, e.g. '(32min 3s elapsed)'."""
+ seconds = time.time() - start
+ duration = duration_to_string(seconds=int(seconds)) # noqa
+ return fstr(template, locals())
+
+
+def flush_buffers():
+ """Flush the logging buffers, stderr, and stdout."""
+ sys.stdout.flush()
+ sys.stderr.flush()
+ for loghandler in DEFAULT_LOGGER.handlers:
+ loghandler.flush()
+ sys.stdout.flush()
+ sys.stderr.flush()
+
+
+def _convert_mem_units(
+ from_val, from_units: str = None, to_units: str = None, sig_digits=None
+):
+ """
+ Convert memory units.
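+
+ Illustrative example: _convert_mem_units(1536, from_units="K") returns
+ (1.5, "MB"): with `to_units` omitted the target units are autodetected
+ and the result comes back as an (amount, units) tuple.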
+
+ Arguments:
+ from_val {float} -- The memory amount, expressed in `from_units`.
+
+ Keyword Arguments:
+ from_units {str} -- Units of `from_val`, "B" if omitted (default: {None})
+ to_units {str} -- Target units, autodetected if omitted (default: {None})
+ sig_digits {int} -- Rounding precision, autodetected if omitted (default: {None})
+
+ Returns:
+ float -- the converted amount, or an (amount, units) tuple when
+ `to_units` was autodetected
+ """
+ from_units = from_units or "B"
+ _mem_units_map = {
+ "B": 1,
+ "K": (1024 ** 1),
+ "MB": (1024 ** 2),
+ "GB": (1024 ** 3),
+ "TB": (1024 ** 4),
+ }
+ num_bytes = from_val * _mem_units_map[from_units]
+ return_tuple = not to_units
+ cutover_factor = 800 # always defined, so an invalid `to_units` cannot raise NameError
+ if to_units not in _mem_units_map:
+ if num_bytes < 100: # < 100 B as B
+ to_units = "B"
+ elif num_bytes < cutover_factor * _mem_units_map["K"]: # < 800 K as K
+ to_units = "K"
+ elif num_bytes < cutover_factor * _mem_units_map["MB"]: # < 800 MB as MB
+ to_units = "MB"
+ elif num_bytes < cutover_factor * _mem_units_map["GB"]: # < 800 GB as GB
+ to_units = "GB"
+ else: # >= 800 GB as TB
+ to_units = "TB"
+ result = num_bytes * 1.0 / _mem_units_map[to_units]
+ if not sig_digits:
+ sig_digits = 1 if result >= 10 else 2
+ if return_tuple:
+ return round(result, sig_digits), to_units
+ return round(result, sig_digits)
+
+
+def _bytes_to_string(num_bytes, units=None):
+ """
+ Return a string that efficiently represents the number of bytes.
+
+ e.g. "476.4MB", "0.92TB", etc.
+ """
+ new_value, units = _convert_mem_units(num_bytes, from_units="B", to_units=None)
+ return f"{new_value}{units}"
+
+
+def _ram_usage_string(process_id=None):
+ """
+ Return a string representing the amount and percentage of memory used by this process.
+ """
+ if not psutil:
+ return "(unknown mem usage - missing psutil library)"
+ process = psutil.Process(process_id or os.getpid())
+ amount = _bytes_to_string(process.memory_info().rss)
+ percent = process.memory_percent()
+ return f"(mem usage: {amount}, {percent:,.1f}%)"
+
+
+def _cpu_usage_string(process_id=None):
+ """
+ Return a string representing the percentage of CPU used by this process.
+ """
+ if not psutil:
+ return "(unknown CPU usage - missing psutil library)"
+ process = psutil.Process(process_id or os.getpid())
+ return f"(CPU {process.cpu_percent(interval=0.2)}%)"
+
+
+def _get_printable_context(context: dict = None, as_str=True):
+ """Return a string or dict, obfuscating names that look like keys."""
+ printable_dict = {
+ k: (
+ v
+ if not any(
+ [
+ "secret" in k.lower(),
+ "pwd" in k.lower(),
+ "pass" in k.lower(),
+ "access.key" in k.lower(),
+ ]
+ )
+ else "****"
+ )
+ for k, v in context.items()
+ if k != "__builtins__"
+ }
+ if as_str:
+ return "\n".join([f"\t{k}:\t{v}" for k, v in printable_dict.items()])
+ return printable_dict
+
+
+def _caller_and_lineno():
+ """Return 'filename:lineno' for the site that called the current function."""
+ caller = inspect.getframeinfo(inspect.stack()[1][0])
+ return f"{caller.filename}:{caller.lineno}"
+
+
+def fstr(fstring_text, locals, globals=None):
+ """Dynamically evaluate the provided `fstring_text`.
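+
+ Note: the format string is evaluated with eval(), so only trusted
+ input should ever be passed as `fstring_text`.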
+
+ Sample usage:
+ format_str = "{i}*{i}={i*i}"
+ i = 2
+ fstr(format_str, locals()) # "2*2=4"
+ i = 4
+ fstr(format_str, locals()) # "4*4=16"
+ fstr(format_str, {"i": 12}) # "12*12=144"
+ """
+ locals = locals or {}
+ globals = globals or {}
+ result = eval(f'f"{fstring_text}"', locals, globals)
+ return result
+
+
+def heartbeat_printer(
+ desc_text,
+ msg_queue: queue.Queue,
+ interval,
+ show_memory=True,
+ show_cpu=None,
+ start_time=None,
+):
+ start_time = start_time or time.time()
+ show_memory = show_memory if show_memory is not None else True
+ show_cpu = show_cpu if show_cpu is not None else show_memory
+ time.sleep(interval)
+ while msg_queue.empty():
+ elapsed_str = elapsed_since(start_time, template="({duration} and counting...)")
+ msg = f"Still {desc_text} {elapsed_str}"
+ if show_cpu:
+ msg += _cpu_usage_string(process_id=None)
+ if show_memory:
+ msg += _ram_usage_string(process_id=None)
+ DEFAULT_LOGGER.info(msg)
+ time.sleep(interval)
+
+
+@contextmanager
+def logged_block(
+ desc_text,
+ start_msg="Beginning {desc_text}...",
+ success_msg="Completed {desc_text} {success_detail} {elapsed}",
+ success_detail="", # noqa
+ show_memory=None,
+ heartbeat_interval=DEFAULT_HEARTBEAT_MINUTES * 60,
+ **kwargs,
+):
+ """
+ Time and log the execution inside a with block.
+
+ Sample usage:
+
+ with logged_block("running '{job.name}' job", job=job_obj):
+ do_job(job)
+ """
+ start = time.time()
+ context_dict = locals().copy()
+ context_dict.update(kwargs)
+ if start_msg:
+ if show_memory:
+ start_msg = start_msg + (" " * 15) + _ram_usage_string()
+ DEFAULT_LOGGER.info(fstr(start_msg, locals=context_dict))
+ msg_queue = None
+ heartbeat = None # stays None when no heartbeat is requested
+ if heartbeat_interval:
+ msg_queue = queue.Queue()
+ heartbeat = threading.Thread(
+ target=heartbeat_printer,
+ args=[],
+ kwargs={
+ "desc_text": desc_text,
+ "msg_queue": msg_queue,
+ "interval": heartbeat_interval,
+ "show_memory": show_memory,
+ },
+ )
+ heartbeat.daemon = True
+ heartbeat.start()
+ yield
+ if heartbeat:
+ try:
+ msg_queue.put("cancel")
+ except Exception as ex:
+ DEFAULT_LOGGER.exception(f"Failed to kill heartbeat log. {ex}")
+ context_dict["elapsed"] = elapsed_since(start)
+ if success_msg:
+ if show_memory:
+ success_msg = success_msg + _ram_usage_string()
+ DEFAULT_LOGGER.info(fstr(success_msg, locals=context_dict))
+
+
+class logged():
+ """
+ Decorator class for logging function start, completion, and elapsed time.
+
+ Sample usage:
+ @logged()
+ def my_func_a():
+ pass
+
+ @logged(log_fn=logging.debug)
+ def my_func_b():
+ pass
+
+ @logged("doing a thing")
+ def my_func_c():
+ pass
+
+ @logged("doing a thing with {foo_obj.name}")
+ def my_func_d(foo_obj):
+ pass
+
+ @logged("doing a thing with '{custom_kwarg}'", custom_kwarg="foo")
+ def my_func_e(foo_obj):
+ pass
+ """
+
+ def __init__(
+ self,
+ desc_text="{fn.__name__}() for '{desc_detail}'",
+ desc_detail="",
+ start_msg="Beginning {desc_text}...",
+ success_msg="Completed {desc_text} {elapsed} ({success_detail})",
+ success_detail="",
+ buffer_lines=0,
+ log_fn=None,
+ **addl_kwargs,
+ ):
+ """All arguments optional."""
+ log_fn = log_fn or DEFAULT_LOGGER.info
+ self.default_context = addl_kwargs.copy() # start with addl. 
args + self.default_context.update(locals()) # merge all constructor args + self.buffer_lines = buffer_lines + + def print_buffer(self): + """Clear print buffer.""" + if self.buffer_lines: + nl = "\n" + flush_buffers() + sys.stdout.write(f"\n\n{('-' * 80 + nl) * self.buffer_lines}\n\n") + flush_buffers() + + def __call__(self, fn): + """Call the decorated function.""" + + def wrapped_fn(*args, **kwargs): + """ + The decorated function definition. + + Note that the log needs access to + all passed arguments to the decorator, as well as all of the function's + native args in a dictionary, even if args are not provided by keyword. + If start_msg is None or success_msg is None, those log entries are skipped. + """ + + def re_eval(context_dict, context_key: str): + """Evaluate any f-strings in context_dict[context_key], save the result.""" + try: + context_dict[context_key] = fstr( + context_dict[context_key], locals=context_dict + ) + except Exception as ex: + DEFAULT_LOGGER.warning( + f"Error evaluating '{context_key}' " + f"({context_dict.get(context_key, '(missing)')})" + f": '{ex}' with context: '{_get_printable_context(context_dict)}'" + ) + + start = time.time() + fn_context = self.default_context.copy() + fn_context["fn"] = fn + fn_context["elapsed"] = None + argspec = inspect.getfullargspec(fn) + # DEFAULT_LOGGER.info(f"argspec: {argspec}") + if argspec.defaults: + # DEFAULT_LOGGER.info( + # f"attempting to set defaults: {list(enumerate(argspec.defaults, 1))}" + # ) + for i, v in enumerate(reversed(argspec.defaults), 1): + fn_context[argspec.args[-1 * i]] = v + if argspec.kwonlydefaults: + fn_context.update(dict(argspec.kwonlydefaults)) + fn_arg_names = argspec.args.copy() + # DEFAULT_LOGGER.info(f"args: {fn_arg_names}") + if argspec.varargs is not None: + # unnamed ordered args + fn_context[argspec.varargs] = args + else: + for x, arg_value in enumerate(args, 0): + # DEFAULT_LOGGER.info( + # f"Attempting to set: fn_arg_names[{x}] = {arg_value}" + # ) + fn_context[fn_arg_names[x]] = arg_value + fn_context.update(kwargs) + desc_detail_fn = None + log_fn = fn_context["log_fn"] + # If desc_detail is callable, evaluate dynamically (both before and after) + if callable(fn_context["desc_detail"]): + desc_detail_fn = fn_context["desc_detail"] + fn_context["desc_detail"] = desc_detail_fn() + # Re-evaluate any decorator args which are fstrings + re_eval(fn_context, "desc_detail") + re_eval(fn_context, "desc_text") + # Remove 'desc_detail' if blank or unused + fn_context["desc_text"] = fn_context["desc_text"].replace("'' ", "") + re_eval(fn_context, "start_msg") + if fn_context["start_msg"]: + self.print_buffer() + log_fn(fn_context["start_msg"]) # log start of execution + result = fn(*args, **kwargs) + if fn_context["success_msg"]: # log the end of execution + if callable(fn_context["success_msg"]): + fn_context["success_msg"] = fn_context["success_msg"]() + fn_context["result"] = result + if desc_detail_fn: # If desc_detail callable, then reevaluate + fn_context["desc_detail"] = desc_detail_fn() + fn_context["elapsed"] = elapsed_since(start) + re_eval(fn_context, "success_detail") + re_eval(fn_context, "success_msg") + log_fn(fn_context["success_msg"].replace(" ()", "")) + self.print_buffer() + return result + + wrapped_fn.__doc__ = fn.__doc__ # Use docstring from inner function. 
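+ # Note: only __doc__ is carried over to the wrapper here; using
+ # functools.wraps(fn)(wrapped_fn) instead would also copy __name__,
+ # __module__ and __qualname__, should callers ever introspect those.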
+ return wrapped_fn diff --git a/resources/lib/main_service.py b/resources/lib/main_service.py index 51bbc25..8c26f5a 100644 --- a/resources/lib/main_service.py +++ b/resources/lib/main_service.py @@ -8,7 +8,7 @@ Background service which launches the spotty binary and monitors the player ''' -from utils import log_msg, ADDON_ID, log_exception, get_token, Spotty, PROXY_PORT, kill_spotty, parse_spotify_track +from utils import log_msg, ADDON_ID, log_exception, get_token, Spotty, PROXY_PORT, parse_spotify_track from player_monitor import ConnectPlayer from connect_daemon import ConnectDaemon from httpproxy import ProxyRunner @@ -24,8 +24,8 @@ import spotipy import time import threading -import thread -import StringIO +import _thread +import io class MainService: @@ -67,7 +67,7 @@ def main_loop(self): loop_timer = 5 while not self.kodimonitor.waitForAbort(loop_timer): # monitor logged in user - cmd = self.win.getProperty("spotify-cmd").decode("utf-8") + cmd = self.win.getProperty("spotify-cmd") if cmd == "__LOGOUT__": log_msg("logout cmd received") self.stop_connect_daemon() @@ -99,8 +99,8 @@ def main_loop(self): def close(self): '''shutdown, perform cleanup''' - log_msg('Shutdown requested !', xbmc.LOGNOTICE) - kill_spotty() + log_msg('Shutdown requested !', xbmc.LOGINFO) + self.spotty.kill_spotty() self.proxy_runner.stop() self.connect_player.close() self.stop_connect_daemon() @@ -108,7 +108,7 @@ def close(self): del self.addon del self.kodimonitor del self.win - log_msg('stopped', xbmc.LOGNOTICE) + log_msg('stopped', xbmc.LOGINFO) def switch_user(self, restart_daemon=False): @@ -121,10 +121,10 @@ def get_username(self): ''' get the current configured/setup username''' username = self.spotty.get_username() if not username: - username = self.addon.getSetting("username").decode("utf-8") + username = self.addon.getSetting("username") if not username and self.addon.getSetting("multi_account") == "true": - username1 = self.addon.getSetting("username1").decode("utf-8") - password1 = self.addon.getSetting("password1").decode("utf-8") + username1 = self.addon.getSetting("username1") + password1 = self.addon.getSetting("password1") if username1 and password1: self.addon.setSetting("username", username1) self.addon.setSetting("password", password1) @@ -166,7 +166,7 @@ def renew_token(self): self.sp._auth = auth_token["access_token"] me = self.sp.me() self.current_user = me["id"] - log_msg("Logged in to Spotify - Username: %s" % self.current_user, xbmc.LOGNOTICE) + log_msg("Logged in to Spotify - Username: %s" % self.current_user, xbmc.LOGINFO) # store authtoken and username as window prop for easy access by plugin entry self.win.setProperty("spotify-token", auth_token["access_token"]) self.win.setProperty("spotify-username", self.current_user) diff --git a/resources/lib/metadatautils/helpers/__init__.py b/resources/lib/metadatautils/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/resources/lib/metadatautils/helpers/animatedart.py b/resources/lib/metadatautils/helpers/animatedart.py new file mode 100644 index 0000000..a1aff28 --- /dev/null +++ b/resources/lib/metadatautils/helpers/animatedart.py @@ -0,0 +1,192 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +"""Retrieve animated artwork for kodi movies""" + +import os, sys +if sys.version_info.major == 3: + from .utils import get_json, DialogSelect, log_msg, ADDON_ID +else: + from utils import get_json, DialogSelect, log_msg, ADDON_ID +import xbmc +import xbmcvfs +import xbmcgui +import xbmcaddon +from simplecache import 
use_cache
+from datetime import timedelta
+
+
+class AnimatedArt(object):
+ """get animated artwork"""
+ ignore_cache = False
+
+ def __init__(self, simplecache=None, kodidb=None):
+ """Initialize - optionally provide SimpleCache and KodiDb object"""
+
+ if not kodidb:
+ if sys.version_info.major == 3:
+ from .kodidb import KodiDb
+ else:
+ from kodidb import KodiDb
+ self.kodidb = KodiDb()
+ else:
+ self.kodidb = kodidb
+
+ if not simplecache:
+ from simplecache import SimpleCache
+ self.cache = SimpleCache()
+ else:
+ self.cache = simplecache
+
+ @use_cache(14)
+ def get_animated_artwork(self, imdb_id, manual_select=False, ignore_cache=False):
+ """returns all available animated art for the given imdbid/tmdbid"""
+ # prefer local result
+ kodi_movie = self.kodidb.movie_by_imdbid(imdb_id)
+ if not manual_select and kodi_movie and kodi_movie["art"].get("animatedposter"):
+ result = {
+ "animatedposter": kodi_movie["art"].get("animatedposter"),
+ "animatedfanart": kodi_movie["art"].get("animatedfanart")
+ }
+ else:
+ result = {
+ "animatedposter": self.poster(imdb_id, manual_select),
+ "animatedfanart": self.fanart(imdb_id, manual_select),
+ "imdb_id": imdb_id
+ }
+ self.write_kodidb(result)
+ log_msg("get_animated_artwork for imdbid: %s - result: %s" % (imdb_id, result))
+ return result
+
+ def poster(self, imdb_id, manual_select=False):
+ """return preferred animated poster, optionally show selectdialog for manual selection"""
+ img = self.select_art(self.posters(imdb_id), manual_select, "poster")
+ return self.process_image(img, "poster", imdb_id)
+
+ def fanart(self, imdb_id, manual_select=False):
+ """return preferred animated fanart, optionally show selectdialog for manual selection"""
+ img = self.select_art(self.fanarts(imdb_id), manual_select, "fanart")
+ return self.process_image(img, "fanart", imdb_id)
+
+ def posters(self, imdb_id):
+ """return all animated posters for the given imdb_id (imdbid can also be tmdbid)"""
+ return self.get_art(imdb_id, "posters")
+
+ def fanarts(self, imdb_id):
+ """return animated fanarts for the given imdb_id (imdbid can also be tmdbid)"""
+ return self.get_art(imdb_id, "fanarts")
+
+ def get_art(self, imdb_id, art_type):
+ """get the artwork"""
+ art_db = self.get_animatedart_db()
+ if art_db.get(imdb_id):
+ return art_db[imdb_id][art_type]
+ return []
+
+ def get_animatedart_db(self):
+ """get the full animated art database as dict with imdbid and tmdbid as key
+ uses 7 day cache to prevent overloading the server"""
+ # get all animated posters from the online json file
+ cache = self.cache.get("animatedartdb")
+ if cache:
+ return cache
+ art_db = {}
+ data = get_json('http://www.consiliumb.com/animatedgifs/movies.json', None)
+ if data and data.get('movies'):
+ # only read baseURL once we know the download succeeded
+ base_url = data.get("baseURL", "")
+ for item in data['movies']:
+ for db_id in ["imdbid", "tmdbid"]:
+ key = item[db_id]
+ art_db[key] = {"posters": [], "fanarts": []}
+ for entry in item['entries']:
+ entry_new = {
+ "contributedby": entry["contributedBy"],
+ "dateadded": entry["dateAdded"],
+ "language": entry["language"],
+ "source": entry["source"],
+ "image": "%s/%s" % (base_url, entry["image"].replace(".gif", "_original.gif")),
+ "thumb": "%s/%s" % (base_url, entry["image"])}
+ if entry['type'] == 'poster':
+ art_db[key]["posters"].append(entry_new)
+ elif entry['type'] == 'background':
+ art_db[key]["fanarts"].append(entry_new)
+ self.cache.set("animatedartdb", art_db, expiration=timedelta(days=7))
+ return art_db
+
+ @staticmethod
+ def select_art(items, manual_select=False, 
art_type=""): + """select the preferred image from the list""" + image = None + if manual_select: + # show selectdialog to manually select the item + results_list = [] + # add none and browse entries + listitem = xbmcgui.ListItem(label=xbmc.getLocalizedString(231)) + listitem.setArt({'icon': "DefaultAddonNone.png"}) + results_list.append(listitem) + listitem = xbmcgui.ListItem(label=xbmc.getLocalizedString(1030)) + listitem.setArt({'icon': "DefaultFolder.png"}) + results_list.append(listitem) + for item in items: + labels = [item["contributedby"], item["dateadded"], item["language"], item["source"]] + label = " / ".join(labels) + listitem = xbmcgui.ListItem(label=label) + listitem.setArt({'icon': item["thumb"]}) + results_list.append(listitem) + if manual_select and results_list: + dialog = DialogSelect("DialogSelect.xml", "", listing=results_list, window_title=art_type) + dialog.doModal() + selected_item = dialog.result + del dialog + if selected_item == 0: + image = "" + if selected_item == 1: + # browse for image + dialog = xbmcgui.Dialog() + if sys.version_info.major == 3: + image = dialog.browse(2, xbmc.getLocalizedString(1030), 'files', mask='.gif') + else: + image = dialog.browse(2, xbmc.getLocalizedString(1030), 'files', mask='.gif').decode("utf-8") + del dialog + elif selected_item > 1: + # user has selected an image from online results + image = items[selected_item - 2]["image"] + elif items: + # just grab the first item as best match + image = items[0]["image"] + return image + + @staticmethod + def process_image(image_url, art_type, imdb_id): + """animated gifs need to be stored locally, otherwise they won't work""" + # make sure that our local path for the gif images exists + addon = xbmcaddon.Addon(ADDON_ID) + gifs_path = "%sanimatedgifs/" % addon.getAddonInfo('profile') + del addon + if not xbmcvfs.exists(gifs_path): + xbmcvfs.mkdirs(gifs_path) + # only process existing images + if not image_url or not xbmcvfs.exists(image_url): + return None + # copy the image to our local path and return the new path as value + local_filename = "%s%s_%s.gif" % (gifs_path, imdb_id, art_type) + if xbmcvfs.exists(local_filename): + xbmcvfs.delete(local_filename) + # we don't use xbmcvfs.copy because we want to wait for the action to complete + img = xbmcvfs.File(image_url) + img_data = img.readBytes() + img.close() + img = xbmcvfs.File(local_filename, 'w') + img.write(img_data) + img.close() + return local_filename + + def write_kodidb(self, artwork): + """store the animated artwork in kodi database to access it with ListItem.Art(animatedartX)""" + kodi_movie = self.kodidb.movie_by_imdbid(artwork["imdb_id"]) + if kodi_movie: + params = { + "movieid": kodi_movie["movieid"], + "art": {"animatedfanart": artwork["animatedfanart"], "animatedposter": artwork["animatedposter"]} + } + self.kodidb.set_json('VideoLibrary.SetMovieDetails', params) diff --git a/resources/lib/metadatautils/helpers/channellogos.py b/resources/lib/metadatautils/helpers/channellogos.py new file mode 100644 index 0000000..15cbfcf --- /dev/null +++ b/resources/lib/metadatautils/helpers/channellogos.py @@ -0,0 +1,59 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" + script.module.metadatautils + channellogos.py + Get channellogos from kodidb or logosdb +""" + +import os, sys +if sys.version_info.major == 3: + from .utils import get_json, get_clean_image +else: + from utils import get_json, get_clean_image +import xbmc +import xbmcvfs + + +class ChannelLogos(object): + """get channellogo""" + + def __init__(self, 
kodidb=None):
+ """Initialize - optionally provide KodiDb object"""
+ if not kodidb:
+ if sys.version_info.major == 3:
+ from .kodidb import KodiDb
+ else:
+ from kodidb import KodiDb
+ self.kodidb = KodiDb()
+ else:
+ self.kodidb = kodidb
+
+ def get_channellogo(self, channelname):
+ """get channellogo for the supplied channelname"""
+ result = {}
+ for searchmethod in [self.search_kodi]:
+ if result:
+ break
+ result = searchmethod(channelname)
+ return result
+
+ def search_kodi(self, searchphrase):
+ """search kodi json api for channel logo"""
+ result = ""
+ if xbmc.getCondVisibility("PVR.HasTVChannels"):
+ results = self.kodidb.get_json(
+ 'PVR.GetChannels',
+ fields=["thumbnail"],
+ returntype="tvchannels",
+ optparam=(
+ "channelgroupid",
+ "alltv"))
+ for item in results:
+ if item["label"] == searchphrase:
+ channelicon = get_clean_image(item['thumbnail'])
+ if channelicon and xbmcvfs.exists(channelicon):
+ result = channelicon
+ break
+ return result
\ No newline at end of file
diff --git a/resources/lib/metadatautils/helpers/extrafanart.py b/resources/lib/metadatautils/helpers/extrafanart.py
new file mode 100644
index 0000000..4a69a1d
--- /dev/null
+++ b/resources/lib/metadatautils/helpers/extrafanart.py
@@ -0,0 +1,44 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""
+ script.module.metadatautils
+ extrafanart.py
+ Get extrafanart location for kodi media
+"""
+
+import os, sys
+import xbmcvfs
+
+
+def get_extrafanart(file_path):
+ """get extrafanart path on disk based on media path"""
+ result = {}
+ efa_path = ""
+ if "plugin.video.emby" in file_path:
+ # workaround for emby addon
+ efa_path = u"plugin://plugin.video.emby/extrafanart?path=" + file_path
+ elif "plugin://" in file_path:
+ efa_path = ""
+ elif "videodb://" in file_path:
+ efa_path = ""
+ else:
+ count = 0
+ while not count == 3:
+ # lookup extrafanart folder by navigating up the tree
+ file_path = os.path.dirname(file_path)
+ try_path = file_path + u"/extrafanart/"
+ if xbmcvfs.exists(try_path):
+ efa_path = try_path
+ break
+ count += 1
+
+ if efa_path:
+ result["art"] = {"extrafanart": efa_path}
+ for count, file in enumerate(xbmcvfs.listdir(efa_path)[1]):
+ if file.lower().endswith(".jpg"):
+ if sys.version_info.major == 3:
+ result["art"]["ExtraFanArt.%s" % count] = efa_path + file
+ else:
+ result["art"]["ExtraFanArt.%s" % count] = efa_path + file.decode("utf-8")
+ return result
diff --git a/resources/lib/metadatautils/helpers/extraposter.py b/resources/lib/metadatautils/helpers/extraposter.py
new file mode 100644
index 0000000..59acc3b
--- /dev/null
+++ b/resources/lib/metadatautils/helpers/extraposter.py
@@ -0,0 +1,44 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""
+ script.module.metadatautils
+ extraposter.py
+ Get extraposter location for kodi media
+"""
+
+import os, sys
+import xbmcvfs
+
+
+def get_extraposter(file_path):
+ """get extraposter path on disk based on media path"""
+ result = {}
+ efa_path = ""
+ if "plugin.video.emby" in file_path:
+ # workaround for emby addon
+ efa_path = u"plugin://plugin.video.emby/extraposter?path=" + file_path
+ elif "plugin://" in file_path:
+ efa_path = ""
+ elif "videodb://" in file_path:
+ efa_path = ""
+ else:
+ count = 0
+ while not count == 3:
+ # lookup extraposter folder by navigating up the tree
+ file_path = os.path.dirname(file_path)
+ try_path = file_path + u"/extraposter/"
+ if xbmcvfs.exists(try_path):
+ efa_path = try_path
+ break
+ count += 1
+
+ if efa_path:
+ result["art"] = {"extraposter": efa_path}
+ for count, file in 
enumerate(xbmcvfs.listdir(efa_path)[1]):
+ if file.lower().endswith(".jpg"):
+ if sys.version_info.major == 3:
+ result["art"]["ExtraPoster.%s" % count] = efa_path + file
+ else:
+ result["art"]["ExtraPoster.%s" % count] = efa_path + file.decode("utf-8")
+ return result
diff --git a/resources/lib/metadatautils/helpers/fanarttv.py b/resources/lib/metadatautils/helpers/fanarttv.py
new file mode 100644
index 0000000..2e3d996
--- /dev/null
+++ b/resources/lib/metadatautils/helpers/fanarttv.py
@@ -0,0 +1,156 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""Get artwork for media from fanart.tv"""
+
+import os, sys
+if sys.version_info.major == 3:
+ from .utils import get_json, KODI_LANGUAGE, process_method_on_list, try_parse_int, ADDON_ID
+else:
+ from utils import get_json, KODI_LANGUAGE, process_method_on_list, try_parse_int, ADDON_ID
+from operator import itemgetter
+import xbmcaddon
+import datetime
+
+
+class FanartTv(object):
+ """get artwork from fanart.tv"""
+ base_url = 'http://webservice.fanart.tv/v3/'
+ api_key = ''
+ client_key = ''
+ ignore_cache = False
+
+ def __init__(self, simplecache=None):
+ """Initialize - optionally provide simplecache object"""
+ if not simplecache:
+ from simplecache import SimpleCache
+ self.cache = SimpleCache()
+ else:
+ self.cache = simplecache
+ addon = xbmcaddon.Addon(ADDON_ID)
+ self.client_key = addon.getSetting("fanarttv_apikey").strip()
+ del addon
+
+ def artist(self, artist_id):
+ """get artist artwork"""
+ data = self.get_data("music/%s" % artist_id)
+ mapping_table = [("artistbackground", "fanart"), ("artistthumb", "thumb"),
+ ("hdmusiclogo", "clearlogo"), ("musiclogo", "clearlogo"), ("musicbanner", "banner")]
+ return self.map_artwork(data, mapping_table)
+
+ def album(self, album_id):
+ """get album artwork"""
+ artwork = {}
+ data = self.get_data("music/albums/%s" % album_id)
+ if data:
+ mapping_table = [("cdart", "discart"), ("albumcover", "thumb")]
+ if sys.version_info.major == 3:
+ for item in data["albums"].values():
+ artwork.update(self.map_artwork(item, mapping_table))
+ else:
+ for item in data["albums"].itervalues():
+ artwork.update(self.map_artwork(item, mapping_table))
+ return artwork
+
+ def musiclabel(self, label_id):
+ """get musiclabel logo"""
+ artwork = {}
+ data = self.get_data("music/labels/%s" % label_id)
+ if data and data.get("musiclabel"):
+ for item in data["musiclabel"]:
+ # we just grab the first logo (as the result is sorted by likes)
+ if item["colour"] == "colour" and "logo_color" not in artwork:
+ artwork["logo_color"] = item["url"]
+ elif item["colour"] == "white" and "logo_white" not in artwork:
+ artwork["logo_white"] = item["url"]
+ return artwork
+
+ def movie(self, movie_id):
+ """get movie artwork"""
+ data = self.get_data("movies/%s" % movie_id)
+ mapping_table = [("hdmovielogo", "clearlogo"), ("moviedisc", "discart"), ("movielogo", "clearlogo"),
+ ("movieposter", "poster"), ("hdmovieclearart", "clearart"), ("movieart", "clearart"),
+ ("moviebackground", "fanart"), ("moviebanner", "banner"), ("moviethumb", "landscape")]
+ return self.map_artwork(data, mapping_table)
+
+ def tvshow(self, tvshow_id):
+ """get tvshow artwork"""
+ data = self.get_data("tv/%s" % tvshow_id)
+ mapping_table = [("hdtvlogo", "clearlogo"), ("clearlogo", "clearlogo"), ("hdclearart", "clearart"),
+ ("clearart", "clearart"), ("showbackground", "fanart"), ("tvthumb", "landscape"),
+ ("tvbanner", "banner"), ("characterart", "characterart"), ("tvposter", "poster")]
+ return self.map_artwork(data, mapping_table)
+
+ def 
tvseason(self, tvshow_id, season):
+ """get season artwork - banner+landscape only as the seasonposters lack a season in the json response"""
+ data = self.get_data("tv/%s" % tvshow_id)
+ artwork = {}
+ mapping_table = [("seasonthumb", "landscape"), ("seasonbanner", "banner")]
+ for artwork_mapping in mapping_table:
+ fanarttv_type = artwork_mapping[0]
+ kodi_type = artwork_mapping[1]
+ if fanarttv_type in data:
+ images = [item for item in data[fanarttv_type] if item["season"] == str(season)]
+ # score only the images matching the requested season
+ images = process_method_on_list(self.score_image, images)
+ if images:
+ images = sorted(images, key=itemgetter("score"), reverse=True)
+ images = [item["url"] for item in images]
+ artwork[kodi_type + "s"] = images
+ artwork[kodi_type] = images[0]
+ return artwork
+
+ def get_data(self, query):
+ """helper method to get data from fanart.tv json API"""
+ api_key = self.api_key
+ if not api_key:
+ api_key = '639191cb0774661597f28a47e7e2bad5' # rate limited default api key
+ url = '%s%s?api_key=%s' % (self.base_url, query, api_key)
+ if self.client_key or self.api_key:
+ if self.client_key:
+ url += '&client_key=%s' % self.client_key
+ rate_limit = None
+ expiration = datetime.timedelta(days=7)
+ else:
+ # without personal or app provided api key = rate limiting and older info from cache
+ rate_limit = ("fanart.tv", 2)
+ expiration = datetime.timedelta(days=60)
+ cache = self.cache.get(url)
+ if cache:
+ result = cache
+ else:
+ result = get_json(url, ratelimit=rate_limit)
+ self.cache.set(url, result, expiration=expiration)
+ return result
+
+ def map_artwork(self, data, mapping_table):
+ """helper method to map the artwork received from fanart.tv to kodi known formats"""
+ artwork = {}
+ if data:
+ for artwork_mapping in mapping_table:
+ fanarttv_type = artwork_mapping[0]
+ kodi_type = artwork_mapping[1]
+ images = []
+ if fanarttv_type in data and kodi_type not in artwork:
+ # artworktype is found in the data, now do some magic to select the best one
+ images = process_method_on_list(self.score_image, data[fanarttv_type])
+ # set all images in list and select the item with highest score
+ if images:
+ images = sorted(images, key=itemgetter("score"), reverse=True)
+ images = [item["url"] for item in images]
+ artwork[kodi_type + "s"] = images
+ artwork[kodi_type] = images[0]
+ return artwork
+
+ @staticmethod
+ def score_image(item):
+ """score item based on number of likes and the language"""
+ score = 0
+ item["url"] = item["url"].replace(" ", "%20")
+ score += try_parse_int(item["likes"])
+ if "lang" in item:
+ if item["lang"] == KODI_LANGUAGE:
+ score += 1000
+ elif item["lang"] == "en":
+ score += 500
+ item["score"] = score
+ return item
diff --git a/resources/lib/metadatautils/helpers/google.py b/resources/lib/metadatautils/helpers/google.py
new file mode 100644
index 0000000..2be5111
--- /dev/null
+++ b/resources/lib/metadatautils/helpers/google.py
@@ -0,0 +1,88 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""get images from google images"""
+
+import os, sys
+if sys.version_info.major == 3:
+ from .utils import DialogSelect, requests, log_exception
+else:
+ from utils import DialogSelect, requests, log_exception
+import bs4 as BeautifulSoup
+import xbmc
+import xbmcvfs
+import xbmcgui
+from simplecache import use_cache
+
+
+class GoogleImages(object):
+ """get images from google images"""
+
+ def __init__(self, simplecache=None):
+ """Initialize - optionally provide simplecache object"""
+ if not simplecache:
+ from simplecache import SimpleCache
+ self.cache = 
SimpleCache()
+ else:
+ self.cache = simplecache
+
+ def search_images(self, search_query):
+ """search google images with the given query, returns list of all images found"""
+ return self.get_data(search_query)
+
+ def search_image(self, search_query, manual_select=False):
+ """
+ search google images with the given query, returns first/best match
+ optional parameter: manual_select (bool), will show selectdialog to allow manual select by user
+ """
+ image = ""
+ images_list = []
+ for img in self.get_data(search_query):
+ img = img.replace(" ", "%20") # fix for spaces in url
+ if xbmcvfs.exists(img):
+ if not manual_select:
+ # just return the first image found (assuming that will be the best match)
+ return img
+ else:
+ # manual lookup, list results and let user pick one
+ listitem = xbmcgui.ListItem(label=img)
+ listitem.setArt({'icon': img})
+ images_list.append(listitem)
+ if manual_select and images_list:
+ dialog = DialogSelect("DialogSelect.xml", "", listing=images_list, window_title="%s - Google"
+ % xbmc.getLocalizedString(283))
+ dialog.doModal()
+ selected_item = dialog.result
+ del dialog
+ if selected_item != -1:
+ selected_item = images_list[selected_item]
+ if sys.version_info.major == 3:
+ image = selected_item.getLabel()
+ else:
+ image = selected_item.getLabel().decode("utf-8")
+ return image
+
+ @use_cache(30)
+ def get_data(self, search_query):
+ """helper method to get data from google images by scraping and parsing"""
+ params = {"site": "imghp", "tbm": "isch", "tbs": "isz:l", "q": search_query}
+ headers = {'User-agent': 'Mozilla/4.0 (compatible; MSIE 7.0; Windows Phone OS 7.0; Trident/3.1; \
+ IEMobile/7.0; LG; GW910)'}
+ html = ''
+ try:
+ html = requests.get('https://www.google.com/search', headers=headers, params=params, timeout=5).text
+ except Exception as exc:
+ log_exception(__name__, exc)
+ soup = BeautifulSoup.BeautifulSoup(html, features="html.parser")
+ results = []
+ for div in soup.findAll('div'):
+ if not div.get("id") == "images":
+ for a_link in div.findAll("a"):
+ page = a_link.get("href")
+ try:
+ img = page.split("imgurl=")[-1]
+ img = img.split("&imgrefurl=")[0]
+ results.append(img)
+ except Exception:
+ pass
+ return results
diff --git a/resources/lib/metadatautils/helpers/imdb.py b/resources/lib/metadatautils/helpers/imdb.py
new file mode 100644
index 0000000..0ca60b7
--- /dev/null
+++ b/resources/lib/metadatautils/helpers/imdb.py
@@ -0,0 +1,78 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""
+ script.module.metadatautils
+ imdb.py
+ Get metadata from imdb
+"""
+
+import os, sys
+if sys.version_info.major == 3:
+ from .utils import requests, try_parse_int
+else:
+ from utils import requests, try_parse_int
+import bs4 as BeautifulSoup
+from simplecache import use_cache
+
+
+class Imdb(object):
+ """Info from IMDB (currently only top250)"""
+
+ def __init__(self, simplecache=None, kodidb=None):
+ """Initialize - optionally provide simplecache object"""
+ if not simplecache:
+ from simplecache import SimpleCache
+ self.cache = SimpleCache()
+ else:
+ self.cache = simplecache
+ if not kodidb:
+ if sys.version_info.major == 3:
+ from .kodidb import KodiDb
+ else:
+ from kodidb import KodiDb
+ self.kodidb = KodiDb()
+ else:
+ self.kodidb = kodidb
+
+ @use_cache(2)
+ def get_top250_rating(self, imdb_id):
+ """get the top250 rating for the given imdbid"""
+ return {"IMDB.Top250": self.get_top250_db().get(imdb_id, 0)}
+
+ @use_cache(7)
+ def get_top250_db(self):
+ """
+ get the top250 listing for both movies and tvshows as dict with imdbid as 
key + uses 7 day cache to prevent overloading the server + """ + results = {} + for listing in [("top", "chttp_tt_"), ("toptv", "chttvtp_tt_")]: + html = requests.get( + "http://www.imdb.com/chart/%s" % + listing[0], headers={ + 'User-agent': 'Mozilla/5.0'}, timeout=20) + soup = BeautifulSoup.BeautifulSoup(html.text, features="html.parser") + for table in soup.findAll('table'): + if not table.get("class") == "chart full-width": + for td_def in table.findAll('td'): + if not td_def.get("class") == "titleColumn": + a_link = td_def.find("a") + if a_link: + url = a_link["href"] + imdb_id = url.split("/")[2] + imdb_rank = url.split(listing[1])[1] + results[imdb_id] = try_parse_int(imdb_rank) + self.write_kodidb(results) + return results + + def write_kodidb(self, results): + """store the top250 position in kodi database to access it with ListItem.Top250""" + for imdb_id in results: + kodi_movie = self.kodidb.movie_by_imdbid(imdb_id) + if kodi_movie: + params = { + "movieid": kodi_movie["movieid"], + "top250": results[imdb_id] + } + self.kodidb.set_json('VideoLibrary.SetMovieDetails', params) diff --git a/resources/lib/metadatautils/helpers/kodi_constants.py b/resources/lib/metadatautils/helpers/kodi_constants.py new file mode 100644 index 0000000..38ed423 --- /dev/null +++ b/resources/lib/metadatautils/helpers/kodi_constants.py @@ -0,0 +1,55 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" + script.module.metadatautils + kodi_constants.py + Several common constants for use with Kodi json api +""" +import os, sys +if sys.version_info.major == 3: + from .utils import KODI_VERSION +else: + from utils import KODI_VERSION + +FIELDS_BASE = ["dateadded", "file", "lastplayed", "plot", "title", "art", "playcount"] +FIELDS_FILE = FIELDS_BASE + ["streamdetails", "director", "resume", "runtime"] +FIELDS_MOVIES = FIELDS_FILE + ["plotoutline", "sorttitle", "cast", "votes", "showlink", "top250", "trailer", "year", + "country", "studio", "set", "genre", "mpaa", "setid", "rating", "tag", "tagline", + "writer", "originaltitle", + "imdbnumber"] +if KODI_VERSION > 16: + FIELDS_MOVIES.append("uniqueid") +FIELDS_TVSHOWS = FIELDS_BASE + ["sorttitle", "mpaa", "premiered", "year", "episode", "watchedepisodes", "votes", + "rating", "studio", "season", "genre", "cast", "episodeguide", "tag", "originaltitle", + "imdbnumber"] +FIELDS_EPISODES = FIELDS_FILE + ["cast", "productioncode", "rating", "votes", "episode", "showtitle", "tvshowid", + "season", "firstaired", "writer", "originaltitle"] +FIELDS_MUSICVIDEOS = FIELDS_FILE + ["genre", "artist", "tag", "album", "track", "studio", "year"] +FIELDS_FILES = FIELDS_FILE + ["plotoutline", "sorttitle", "cast", "votes", "trailer", "year", "country", "studio", + "genre", "mpaa", "rating", "tagline", "writer", "originaltitle", "imdbnumber", + "premiered", "episode", "showtitle", + "firstaired", "watchedepisodes", "duration", "season"] +FIELDS_SONGS = ["artist", "displayartist", "title", "rating", "fanart", "thumbnail", "duration", "disc", + "playcount", "comment", "file", "album", "lastplayed", "genre", "musicbrainzartistid", "track", + "dateadded"] +FIELDS_ALBUMS = ["title", "fanart", "thumbnail", "genre", "displayartist", "artist", + "musicbrainzalbumartistid", "year", "rating", "artistid", "musicbrainzalbumid", "theme", "description", + "type", "style", "playcount", "albumlabel", "mood", "dateadded"] +FIELDS_ARTISTS = ["born", "formed", "died", "style", "yearsactive", "mood", "fanart", "thumbnail", + "musicbrainzartistid", "disbanded", "description", "instrument"] 
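+# Usage sketch (illustrative, not part of the original module): the FIELDS_*
+# lists are passed as the "properties" of a Kodi JSON-RPC request, while the
+# FILTER_* and SORT_* dicts feed its "filter" and "sort" params, e.g. via the
+# KodiDb helper defined below:
+#   KodiDb().get_json("VideoLibrary.GetMovies", sort=SORT_RATING,
+#                     filters=[FILTER_UNWATCHED], fields=FIELDS_MOVIES,
+#                     limits=(0, 25), returntype="movies")
+# which would return the 25 highest-rated unwatched movies with full metadata.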
+FIELDS_RECORDINGS = ["art", "channel", "directory", "endtime", "file", "genre", "icon", "playcount", "plot", + "plotoutline", "resume", "runtime", "starttime", "streamurl", "title"] +FIELDS_CHANNELS = ["broadcastnow", "channeltype", "hidden", "locked", "lastplayed", "thumbnail", "channel"] + +FILTER_UNWATCHED = {"operator": "lessthan", "field": "playcount", "value": "1"} +FILTER_WATCHED = {"operator": "isnot", "field": "playcount", "value": "0"} +FILTER_RATING = {"operator": "greaterthan", "field": "rating", "value": "7"} +FILTER_RATING_MUSIC = {"operator": "greaterthan", "field": "rating", "value": "3"} +FILTER_INPROGRESS = {"operator": "true", "field": "inprogress", "value": ""} +SORT_RATING = {"method": "rating", "order": "descending"} +SORT_RANDOM = {"method": "random", "order": "descending"} +SORT_TITLE = {"method": "title", "order": "ascending"} +SORT_DATEADDED = {"method": "dateadded", "order": "descending"} +SORT_LASTPLAYED = {"method": "lastplayed", "order": "descending"} +SORT_EPISODE = {"method": "episode"} diff --git a/resources/lib/metadatautils/helpers/kodidb.py b/resources/lib/metadatautils/helpers/kodidb.py new file mode 100644 index 0000000..f844a01 --- /dev/null +++ b/resources/lib/metadatautils/helpers/kodidb.py @@ -0,0 +1,753 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +"""get metadata from the kodi DB""" + +import os, sys +import xbmc +import xbmcgui +import xbmcvfs +if sys.version_info.major == 3: + from .utils import json, try_encode, log_msg, log_exception, get_clean_image, KODI_VERSION + from .utils import try_parse_int, localdate_from_utc_string, localized_date_time + from .kodi_constants import * +else: + from utils import json, try_encode, log_msg, log_exception, get_clean_image, KODI_VERSION + from utils import try_parse_int, localdate_from_utc_string, localized_date_time + from kodi_constants import * +from operator import itemgetter +import arrow + + +class KodiDb(object): + """various methods and helpers to get data from kodi json api""" + + def movie(self, db_id): + """get moviedetails from kodi db""" + return self.get_json("VideoLibrary.GetMovieDetails", returntype="moviedetails", + fields=FIELDS_MOVIES, optparam=("movieid", try_parse_int(db_id))) + + def movies(self, sort=None, filters=None, limits=None, filtertype=None): + """get moviedetails from kodi db""" + return self.get_json("VideoLibrary.GetMovies", sort=sort, filters=filters, + fields=FIELDS_MOVIES, limits=limits, returntype="movies", filtertype=filtertype) + + def movie_by_imdbid(self, imdb_id): + """gets a movie from kodidb by imdbid.""" + # apparently you can't filter on imdb so we have to do this the complicated way + if KODI_VERSION > 16: + # from Kodi 17 we have a uniqueid field instead of imdbnumber + all_items = self.get_json('VideoLibrary.GetMovies', fields=["uniqueid"], returntype="movies") + for item in all_items: + if 'uniqueid' in item: + for item2 in item["uniqueid"].values(): + if item2 == imdb_id: + return self.movie(item["movieid"]) + else: + all_items = self.get_json('VideoLibrary.GetMovies', fields=["imdbnumber"], returntype="movies") + for item in all_items: + if item["imdbnumber"] == imdb_id: + return self.movie(item["movieid"]) + return {} + + def tvshow(self, db_id): + """get tvshow from kodi db""" + tvshow = self.get_json("VideoLibrary.GetTvShowDetails", returntype="tvshowdetails", + fields=FIELDS_TVSHOWS, optparam=("tvshowid", try_parse_int(db_id))) + return self.tvshow_watchedcounts(tvshow) + + def tvshows(self, sort=None, filters=None, limits=None, 
filtertype=None): + """get tvshows from kodi db""" + tvshows = self.get_json("VideoLibrary.GetTvShows", sort=sort, filters=filters, + fields=FIELDS_TVSHOWS, limits=limits, returntype="tvshows", filtertype=filtertype) + # append watched counters + for tvshow in tvshows: + self.tvshow_watchedcounts(tvshow) + return tvshows + + def tvshow_by_imdbid(self, imdb_id): + """gets a tvshow from kodidb by imdbid (or tvdbid).""" + # apparently you can't filter on imdb so we have to do this the complicated way + if KODI_VERSION > 16: + # from Kodi 17 we have a uniqueid field instead of imdbnumber + all_items = self.get_json('VideoLibrary.GetTvShows', fields=["uniqueid"], returntype="tvshows") + for item in all_items: + if 'uniqueid' in item: + for item2 in item["uniqueid"].values(): + if item2 == imdb_id: + return self.tvshow(item["tvshowid"]) + else: + # pre-kodi 17 approach + all_items = self.get_json('VideoLibrary.GetTvShows', fields=["imdbnumber"], returntype="tvshows") + for item in all_items: + if item["imdbnumber"] == imdb_id: + return self.tvshow(item["tvshowid"]) + return {} + + def episode(self, db_id): + """get episode from kodi db""" + return self.get_json("VideoLibrary.GetEpisodeDetails", returntype="episodedetails", + fields=FIELDS_EPISODES, optparam=("episodeid", try_parse_int(db_id))) + + def episodes(self, sort=None, filters=None, limits=None, filtertype=None, tvshowid=None, fields=FIELDS_EPISODES): + """get episodes from kodi db""" + if tvshowid: + params = ("tvshowid", try_parse_int(tvshowid)) + else: + params = None + return self.get_json("VideoLibrary.GetEpisodes", sort=sort, filters=filters, fields=fields, + limits=limits, returntype="episodes", filtertype=filtertype, optparam=params) + + def musicvideo(self, db_id): + """get musicvideo from kodi db""" + return self.get_json("VideoLibrary.GetMusicVideoDetails", returntype="musicvideodetails", + fields=FIELDS_MUSICVIDEOS, optparam=("musicvideoid", try_parse_int(db_id))) + + def musicvideos(self, sort=None, filters=None, limits=None, filtertype=None): + """get musicvideos from kodi db""" + return self.get_json("VideoLibrary.GetMusicVideos", sort=sort, filters=filters, + fields=FIELDS_MUSICVIDEOS, limits=limits, returntype="musicvideos", filtertype=filtertype) + + def movieset(self, db_id, include_set_movies_fields=""): + """get movieset from kodi db""" + if include_set_movies_fields: + optparams = [("setid", try_parse_int(db_id)), ("movies", {"properties": include_set_movies_fields})] + else: + optparams = ("setid", try_parse_int(db_id)) + return self.get_json("VideoLibrary.GetMovieSetDetails", returntype="", + fields=["title", "art", "playcount"], optparam=optparams) + + def moviesets(self, sort=None, limits=None, include_set_movies=False): + """get moviesetdetails from kodi db""" + if include_set_movies: + optparam = ("movies", {"properties": FIELDS_MOVIES}) + else: + optparam = None + return self.get_json("VideoLibrary.GetMovieSets", sort=sort, + fields=["title", "art", "playcount"], + limits=limits, returntype="", optparam=optparam) + + def files(self, vfspath, sort=None, limits=None): + """gets all items in a kodi vfs path""" + return self.get_json("Files.GetDirectory", returntype="", optparam=("directory", vfspath), + fields=FIELDS_FILES, sort=sort, limits=limits) + + def genres(self, media_type): + """return all genres for the given media type (movie/tvshow/musicvideo)""" + return self.get_json("VideoLibrary.GetGenres", fields=["thumbnail", "title"], + returntype="genres", optparam=("type", media_type)) + + def song(self, 
db_id): + """get songdetails from kodi db""" + return self.get_json("AudioLibrary.GetSongDetails", returntype="songdetails", + fields=FIELDS_SONGS, optparam=("songid", try_parse_int(db_id))) + + def songs(self, sort=None, filters=None, limits=None, filtertype=None): + """get songs from kodi db""" + return self.get_json("AudioLibrary.GetSongs", sort=sort, filters=filters, + fields=FIELDS_SONGS, limits=limits, returntype="songs", filtertype=filtertype) + + def album(self, db_id): + """get albumdetails from kodi db""" + album = self.get_json("AudioLibrary.GetAlbumDetails", returntype="albumdetails", + fields=FIELDS_ALBUMS, optparam=("albumid", try_parse_int(db_id))) + # override type as the kodi json api is returning the album type instead of mediatype + album["type"] = "album" + return album + + def albums(self, sort=None, filters=None, limits=None, filtertype=None): + """get albums from kodi db""" + albums = self.get_json("AudioLibrary.GetAlbums", sort=sort, filters=filters, + fields=FIELDS_ALBUMS, limits=limits, returntype="albums", filtertype=filtertype) + # override type as the kodi json api is returning the album type instead of mediatype + for album in albums: + album["type"] = "album" + return albums + + def artist(self, db_id): + """get artistdetails from kodi db""" + return self.get_json("AudioLibrary.GetArtistDetails", returntype="artistdetails", + fields=FIELDS_ARTISTS, optparam=("artistid", try_parse_int(db_id))) + + def artists(self, sort=None, filters=None, limits=None, filtertype=None): + """get artists from kodi db""" + return self.get_json("AudioLibrary.GetArtists", sort=sort, filters=filters, + fields=FIELDS_ARTISTS, limits=limits, returntype="artists", filtertype=filtertype) + + def recording(self, db_id): + """get pvr recording from kodi db""" + return self.get_json("PVR.GetRecordingDetails", returntype="recordingdetails", + fields=FIELDS_RECORDINGS, optparam=("recordingid", try_parse_int(db_id))) + + def recordings(self, limits=None): + """get pvr recordings from kodi db""" + return self.get_json("PVR.GetRecordings", fields=FIELDS_RECORDINGS, limits=limits, returntype="recordings") + + def channel(self, db_id): + """get pvr channel from kodi db""" + return self.get_json("PVR.GetChannelDetails", returntype="channeldetails", + fields=FIELDS_CHANNELS, optparam=("channelid", try_parse_int(db_id))) + + def channels(self, limits=None, channelgroupid="alltv"): + """get pvr channels from kodi db""" + return self.get_json("PVR.GetChannels", fields=FIELDS_CHANNELS, limits=limits, + returntype="channels", optparam=("channelgroupid", channelgroupid)) + + def channelgroups(self, limits=None, channeltype="tv"): + """get pvr channelgroups from kodi db""" + return self.get_json("PVR.GetChannelGroups", fields=[], limits=limits, + returntype="channelgroups", optparam=("channeltype", channeltype)) + + def timers(self, limits=None): + """get pvr recordings from kodi db""" + fields = ["title", "endtime", "starttime", "channelid", "summary", "file"] + return self.get_json("PVR.GetTimers", fields=fields, limits=limits, returntype="timers") + + def favourites(self): + """get kodi favourites""" + items = self.get_favourites_from_file() + if not items: + fields = ["path", "thumbnail", "window", "windowparameter"] + optparams = ("type", None) + items = self.get_json("Favourites.GetFavourites", fields=fields, optparam=optparams) + return items + + def castmedia(self, actorname): + """helper to display all media (movies/shows) for a specific actor""" + # use db counts as simple checksum + filters = 
[{"operator": "contains", "field": "actor", "value": actorname}] + all_items = self.movies(filters=filters) + for item in self.tvshows(filters=filters): + item["file"] = "videodb://tvshows/titles/%s" % item["tvshowid"] + item["isFolder"] = True + all_items.append(item) + return all_items + + def actors(self): + """return all actors""" + all_items = [] + all_actors = [] + result = self.files("videodb://movies/actors") + result += self.files("videodb://tvshows/actors") + for item in result: + if not item["label"] in all_actors: + all_actors.append(item["label"]) + item["type"] = "actor" + item["isFolder"] = True + if not item["art"].get("thumb"): + item["art"]["thumb"] = "DefaultActor.png" + all_items.append(item) + return sorted(all_items, key=itemgetter("label")) + + @staticmethod + def set_json(jsonmethod, params): + """method to set info in the kodi json api""" + kodi_json = {} + kodi_json["jsonrpc"] = "2.0" + kodi_json["method"] = jsonmethod + kodi_json["params"] = params + kodi_json["id"] = 1 + json_response = xbmc.executeJSONRPC(try_encode(json.dumps(kodi_json))) + if sys.version_info.major == 3: + return json.loads(json_response) + else: + return json.loads(json_response.decode('utf-8', 'replace')) + + @staticmethod + def get_json(jsonmethod, sort=None, filters=None, fields=None, limits=None, + returntype=None, optparam=None, filtertype=None): + """method to get details from the kodi json api""" + kodi_json = {} + kodi_json["jsonrpc"] = "2.0" + kodi_json["method"] = jsonmethod + kodi_json["params"] = {} + if optparam: + if isinstance(optparam, list): + for param in optparam: + kodi_json["params"][param[0]] = param[1] + else: + kodi_json["params"][optparam[0]] = optparam[1] + kodi_json["id"] = 1 + if sort: + kodi_json["params"]["sort"] = sort + if filters: + if not filtertype: + filtertype = "and" + if len(filters) > 1: + kodi_json["params"]["filter"] = {filtertype: filters} + else: + kodi_json["params"]["filter"] = filters[0] + if fields: + kodi_json["params"]["properties"] = fields + if limits: + kodi_json["params"]["limits"] = {"start": limits[0], "end": limits[1]} + json_response = xbmc.executeJSONRPC(try_encode(json.dumps(kodi_json))) + if sys.version_info.major == 3: + json_object = json.loads(json_response) + else: + json_object = json.loads(json_response.decode('utf-8', 'replace')) + # set the default returntype to prevent errors + if "details" in jsonmethod.lower(): + result = {} + else: + result = [] + if 'result' in json_object: + if returntype and returntype in json_object['result']: + # returntype specified, return immediately + result = json_object['result'][returntype] + else: + # no returntype specified, we'll have to look for it + if sys.version_info.major == 3: + for key, value in json_object['result'].items(): + if not key == "limits" and (isinstance(value, list) or isinstance(value, dict)): + result = value + else: + for key, value in json_object['result'].iteritems(): + if not key == "limits" and (isinstance(value, list) or isinstance(value, dict)): + result = value + else: + log_msg(json_response) + log_msg(kodi_json) + return result + + @staticmethod + def get_favourites_from_file(): + """json method for favourites doesn't return all items (such as android apps) so retrieve them from file""" + allfavourites = [] + try: + from xml.dom.minidom import parse + if sys.version_info.major == 3: + favourites_path = xbmcvfs.translatePath('special://profile/favourites.xml') + else: + favourites_path = xbmc.translatePath('special://profile/favourites.xml').decode("utf-8") 
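+ # each <favourite> node is parsed below; entries look roughly like
+ # <favourite name="..." thumb="...">ActivateWindow(...)</favourite>
+ # (structure inferred from the attribute/child reads that follow)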
+ if xbmcvfs.exists(favourites_path): + doc = parse(favourites_path) + result = doc.documentElement.getElementsByTagName('favourite') + for fav in result: + action = fav.childNodes[0].nodeValue + action = action.replace('"', '') + label = fav.attributes['name'].nodeValue + try: + thumb = fav.attributes['thumb'].nodeValue + except Exception: + thumb = "" + window = "" + windowparameter = "" + action_type = "unknown" + if action.startswith("StartAndroidActivity"): + action_type = "androidapp" + elif action.startswith("ActivateWindow"): + action_type = "window" + actionparts = action.replace("ActivateWindow(", "").replace(",return)", "").split(",") + window = actionparts[0] + if len(actionparts) > 1: + windowparameter = actionparts[1] + elif action.startswith("PlayMedia"): + action_type = "media" + action = action.replace("PlayMedia(", "")[:-1] + allfavourites.append({"label": label, "path": action, "thumbnail": thumb, "window": window, + "windowparameter": windowparameter, "type": action_type}) + except Exception as exc: + log_exception(__name__, exc) + return allfavourites + + @staticmethod + def create_listitem(item, as_tuple=True, offscreen=True): + """helper to create a kodi listitem from kodi compatible dict with mediainfo""" + try: + if KODI_VERSION > 17: + liz = xbmcgui.ListItem( + label=item.get("label", ""), + label2=item.get("label2", ""), + path=item['file'], + offscreen=offscreen) + else: + liz = xbmcgui.ListItem( + label=item.get("label", ""), + label2=item.get("label2", ""), + path=item['file']) + + # only set isPlayable prop if really needed + if item.get("isFolder", False): + liz.setProperty('IsPlayable', 'false') + elif "plugin://script.skin.helper" not in item['file']: + liz.setProperty('IsPlayable', 'true') + + nodetype = "Video" + if item["type"] in ["song", "album", "artist"]: + nodetype = "Music" + + # extra properties + if sys.version_info.major == 3: + for key, value in item["extraproperties"].items(): + liz.setProperty(key, value) + else: + for key, value in item["extraproperties"].iteritems(): + liz.setProperty(key, value) + + # video infolabels + if nodetype == "Video": + infolabels = { + "title": item.get("title"), + "size": item.get("size"), + "genre": item.get("genre"), + "year": item.get("year"), + "top250": item.get("top250"), + "tracknumber": item.get("tracknumber"), + "rating": item.get("rating"), + "playcount": item.get("playcount"), + "overlay": item.get("overlay"), + "cast": item.get("cast"), + "castandrole": item.get("castandrole"), + "director": item.get("director"), + "mpaa": item.get("mpaa"), + "plot": item.get("plot"), + "plotoutline": item.get("plotoutline"), + "originaltitle": item.get("originaltitle"), + "sorttitle": item.get("sorttitle"), + "duration": item.get("duration"), + "studio": item.get("studio"), + "tagline": item.get("tagline"), + "writer": item.get("writer"), + "tvshowtitle": item.get("tvshowtitle"), + "premiered": item.get("premiered"), + "status": item.get("status"), + "code": item.get("imdbnumber"), + "imdbnumber": item.get("imdbnumber"), + "aired": item.get("aired"), + "credits": item.get("credits"), + "album": item.get("album"), + "artist": item.get("artist"), + "votes": item.get("votes"), + "trailer": item.get("trailer") + } + #ERROR: NEWADDON Unknown Video Info Key "progress" in Kodi 19 ?! 
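+                # the "progress" info key is no longer accepted by setInfo in recent
+                # Kodi versions (see the error note above), so only set it on Kodi 17 and older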
+ if KODI_VERSION < 18: + infolabels["progress"] = item.get('progresspercentage') + if item["type"] == "episode": + infolabels["season"] = item["season"] + infolabels["episode"] = item["episode"] + + # streamdetails + if item.get("streamdetails"): + liz.addStreamInfo("video", item["streamdetails"].get("video", {})) + liz.addStreamInfo("audio", item["streamdetails"].get("audio", {})) + liz.addStreamInfo("subtitle", item["streamdetails"].get("subtitle", {})) + + if "dateadded" in item: + infolabels["dateadded"] = item["dateadded"] + if "date" in item: + infolabels["date"] = item["date"] + + # music infolabels + else: + infolabels = { + "title": item.get("title"), + "size": item.get("size"), + "genre": item.get("genre"), + "year": item.get("year"), + "tracknumber": item.get("track"), + "album": item.get("album"), + "artist": " / ".join(item.get('artist')), + "rating": str(item.get("rating", 0)), + "lyrics": item.get("lyrics"), + "playcount": item.get("playcount") + } + if "date" in item: + infolabels["date"] = item["date"] + if "duration" in item: + infolabels["duration"] = item["duration"] + if "lastplayed" in item: + infolabels["lastplayed"] = item["lastplayed"] + + # setting the dbtype and dbid is supported from kodi krypton and up + if KODI_VERSION > 16 and item["type"] not in ["recording", "channel", "favourite", "genre", "categorie"]: + infolabels["mediatype"] = item["type"] + # setting the dbid on music items is not supported ? + if nodetype == "Video" and "DBID" in item["extraproperties"]: + infolabels["dbid"] = item["extraproperties"]["DBID"] + + if "lastplayed" in item: + infolabels["lastplayed"] = item["lastplayed"] + + # assign the infolabels + liz.setInfo(type=nodetype, infoLabels=infolabels) + + # artwork + liz.setArt(item.get("art", {})) + if KODI_VERSION > 17: + if "icon" in item: + liz.setArt({"icon":item['icon']}) + if "thumbnail" in item: + liz.setArt({"thumb":item['thumbnail']}) + else: + if "icon" in item: + liz.setIconImage(item['icon']) + if "thumbnail" in item: + liz.setThumbnailImage(item['thumbnail']) + + # contextmenu + if item["type"] in ["episode", "season"] and "season" in item and "tvshowid" in item: + # add series and season level to widgets + if "contextmenu" not in item: + item["contextmenu"] = [] + item["contextmenu"] += [ + (xbmc.getLocalizedString(20364), "ActivateWindow(Video,videodb://tvshows/titles/%s/,return)" + % (item["tvshowid"])), + (xbmc.getLocalizedString(20373), "ActivateWindow(Video,videodb://tvshows/titles/%s/%s/,return)" + % (item["tvshowid"], item["season"]))] + if "contextmenu" in item: + liz.addContextMenuItems(item["contextmenu"]) + + if as_tuple: + return item["file"], liz, item.get("isFolder", False) + else: + return liz + except Exception as exc: + log_exception(__name__, exc) + log_msg(item) + return None + + @staticmethod + def prepare_listitem(item): + """helper to convert kodi output from json api to compatible format for listitems""" + try: + # fix values returned from json to be used as listitem values + properties = item.get("extraproperties", {}) + + # set type + for idvar in [ + ('episode', 'DefaultTVShows.png'), + ('tvshow', 'DefaultTVShows.png'), + ('movie', 'DefaultMovies.png'), + ('song', 'DefaultAudio.png'), + ('album', 'DefaultAudio.png'), + ('artist', 'DefaultArtist.png'), + ('musicvideo', 'DefaultMusicVideos.png'), + ('recording', 'DefaultTVShows.png'), + ('channel', 'DefaultAddonPVRClient.png')]: + dbid = item.get(idvar[0] + "id") + if dbid: + properties["DBID"] = str(dbid) + if not item.get("type"): + item["type"] = 
idvar[0] + if not item.get("icon"): + item["icon"] = idvar[1] + break + + # general properties + if "genre" in item and isinstance(item['genre'], list): + item["genre"] = " / ".join(item['genre']) + if "studio" in item and isinstance(item['studio'], list): + item["studio"] = " / ".join(item['studio']) + if "writer" in item and isinstance(item['writer'], list): + item["writer"] = " / ".join(item['writer']) + if 'director' in item and isinstance(item['director'], list): + item["director"] = " / ".join(item['director']) + if 'artist' in item and not isinstance(item['artist'], list): + item["artist"] = [item['artist']] + if 'artist' not in item: + item["artist"] = [] + if item['type'] == "album" and 'album' not in item and 'label' in item: + item['album'] = item['label'] + if "duration" not in item and "runtime" in item: + if (item["runtime"] / 60) > 300: + item["duration"] = item["runtime"] / 60 + else: + item["duration"] = item["runtime"] + if "plot" not in item and "comment" in item: + item["plot"] = item["comment"] + if "tvshowtitle" not in item and "showtitle" in item: + item["tvshowtitle"] = item["showtitle"] + if "premiered" not in item and "firstaired" in item: + item["premiered"] = item["firstaired"] + if "firstaired" in item and "aired" not in item: + item["aired"] = item["firstaired"] + if "imdbnumber" not in properties and "imdbnumber" in item: + properties["imdbnumber"] = item["imdbnumber"] + if "imdbnumber" not in properties and "uniqueid" in item: + for value in item["uniqueid"].values(): + if value.startswith("tt"): + properties["imdbnumber"] = value + + properties["dbtype"] = item["type"] + properties["DBTYPE"] = item["type"] + properties["type"] = item["type"] + properties["path"] = item.get("file") + + # cast + list_cast = [] + list_castandrole = [] + item["cast_org"] = item.get("cast", []) + if "cast" in item and isinstance(item["cast"], list): + for castmember in item["cast"]: + if isinstance(castmember, dict): + list_cast.append(castmember.get("name", "")) + list_castandrole.append((castmember["name"], castmember["role"])) + else: + list_cast.append(castmember) + list_castandrole.append((castmember, "")) + + item["cast"] = list_cast + item["castandrole"] = list_castandrole + + if "season" in item and "episode" in item: + properties["episodeno"] = "s%se%s" % (item.get("season"), item.get("episode")) + if "resume" in item: + properties["resumetime"] = str(item['resume']['position']) + properties["totaltime"] = str(item['resume']['total']) + properties['StartOffset'] = str(item['resume']['position']) + + # streamdetails + if "streamdetails" in item: + streamdetails = item["streamdetails"] + audiostreams = streamdetails.get('audio', []) + videostreams = streamdetails.get('video', []) + subtitles = streamdetails.get('subtitle', []) + if len(videostreams) > 0: + stream = videostreams[0] + height = stream.get("height", "") + width = stream.get("width", "") + if height and width: + resolution = "" + if width <= 720 and height <= 480: + resolution = "480" + elif width <= 768 and height <= 576: + resolution = "576" + elif width <= 960 and height <= 544: + resolution = "540" + elif width <= 1280 and height <= 720: + resolution = "720" + elif width <= 1920 and height <= 1080: + resolution = "1080" + elif width * height >= 6000000: + resolution = "4K" + properties["VideoResolution"] = resolution + if stream.get("codec", ""): + properties["VideoCodec"] = str(stream["codec"]) + if stream.get("aspect", ""): + properties["VideoAspect"] = str(round(stream["aspect"], 2)) + 
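+                    # the resolution buckets above map raw dimensions to the nearest
+                    # standard label, e.g. a 1916x1076 stream still counts as "1080" and
+                    # anything of roughly 6 megapixels or more is labelled "4K"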
item["streamdetails"]["video"] = stream + + # grab details of first audio stream + if len(audiostreams) > 0: + stream = audiostreams[0] + properties["AudioCodec"] = stream.get('codec', '') + properties["AudioChannels"] = str(stream.get('channels', '')) + properties["AudioLanguage"] = stream.get('language', '') + item["streamdetails"]["audio"] = stream + + # grab details of first subtitle + if len(subtitles) > 0: + properties["SubtitleLanguage"] = subtitles[0].get('language', '') + item["streamdetails"]["subtitle"] = subtitles[0] + else: + item["streamdetails"] = {} + item["streamdetails"]["video"] = {'duration': item.get('duration', 0)} + + # additional music properties + if 'album_description' in item: + properties["Album_Description"] = item.get('album_description') + + # pvr properties + if "starttime" in item: + # convert utc time to local time + item["starttime"] = localdate_from_utc_string(item["starttime"]) + item["endtime"] = localdate_from_utc_string(item["endtime"]) + # set some localized versions of the time and date as additional properties + startdate, starttime = localized_date_time(item['starttime']) + enddate, endtime = localized_date_time(item['endtime']) + properties["StartTime"] = starttime + properties["StartDate"] = startdate + properties["EndTime"] = endtime + properties["EndDate"] = enddate + properties["Date"] = "%s %s-%s" % (startdate, starttime, endtime) + properties["StartDateTime"] = "%s %s" % (startdate, starttime) + properties["EndDateTime"] = "%s %s" % (enddate, endtime) + # set date to startdate + item["date"] = arrow.get(item["starttime"]).format("DD.MM.YYYY") + if "channellogo" in item: + properties["channellogo"] = item["channellogo"] + properties["channelicon"] = item["channellogo"] + if "episodename" in item: + properties["episodename"] = item["episodename"] + if "channel" in item: + properties["channel"] = item["channel"] + properties["channelname"] = item["channel"] + item["label2"] = item["title"] + + # artwork + art = item.get("art", {}) + if item["type"] in ["episode", "season"]: + if not art.get("fanart") and art.get("season.fanart"): + art["fanart"] = art["season.fanart"] + if not art.get("poster") and art.get("season.poster"): + art["poster"] = art["season.poster"] + if not art.get("landscape") and art.get("season.landscape"): + art["poster"] = art["season.landscape"] + if not art.get("fanart") and art.get("tvshow.fanart"): + art["fanart"] = art.get("tvshow.fanart") + if not art.get("poster") and art.get("tvshow.poster"): + art["poster"] = art.get("tvshow.poster") + if not art.get("clearlogo") and art.get("tvshow.clearlogo"): + art["clearlogo"] = art.get("tvshow.clearlogo") + if not art.get("banner") and art.get("tvshow.banner"): + art["banner"] = art.get("tvshow.banner") + if not art.get("landscape") and art.get("tvshow.landscape"): + art["landscape"] = art.get("tvshow.landscape") + if not art.get("fanart") and item.get('fanart'): + art["fanart"] = item.get('fanart') + if not art.get("thumb") and item.get('thumbnail'): + art["thumb"] = get_clean_image(item.get('thumbnail')) + if not art.get("thumb") and art.get('poster'): + art["thumb"] = get_clean_image(art.get('poster')) + if not art.get("thumb") and item.get('icon'): + art["thumb"] = get_clean_image(item.get('icon')) + if not item.get("thumbnail") and art.get('thumb'): + item["thumbnail"] = art["thumb"] + + # clean art + if sys.version_info.major == 3: + for key, value in art.items(): + if not isinstance(value, str): + art[key] = "" + elif value: + art[key] = get_clean_image(value) + else: + 
if sys.version_info.major == 3: + for key, value in art.items(): + if not isinstance(value, str): + art[key] = "" + elif value: + art[key] = get_clean_image(value) + else: + for key, value in art.iteritems(): + if not isinstance(value, (str, unicode)): + art[key] = "" + elif value: + art[key] = get_clean_image(value) + item["art"] = art + + item["extraproperties"] = properties + + if "file" not in item: + log_msg("Item is missing file path ! --> %s" % item["label"], xbmc.LOGWARNING) + item["file"] = "" + + # return the result + return item + + except Exception as exc: + log_exception(__name__, exc) + log_msg(item) + return None + + @staticmethod + def tvshow_watchedcounts(tvshow): + """append watched counts to tvshow details""" + tvshow["extraproperties"] = {"totalseasons": str(tvshow["season"]), + "totalepisodes": str(tvshow["episode"]), + "watchedepisodes": str(tvshow["watchedepisodes"]), + "unwatchedepisodes": str(tvshow["episode"] - tvshow["watchedepisodes"]) + } + return tvshow diff --git a/resources/lib/metadatautils/helpers/lastfm.py b/resources/lib/metadatautils/helpers/lastfm.py new file mode 100644 index 0000000..0e6d772 --- /dev/null +++ b/resources/lib/metadatautils/helpers/lastfm.py @@ -0,0 +1,134 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +"""get metadata from the lastfm""" + +import os, sys +if sys.version_info.major == 3: + from .utils import get_json, strip_newlines, get_compare_string +else: + from utils import get_json, strip_newlines, get_compare_string +from simplecache import use_cache +import xbmcvfs + + +class LastFM(object): + """get metadata from the lastfm""" + api_key = "1869cecbff11c2715934b45b721e6fb0" + ignore_cache = False + + def __init__(self, simplecache=None): + """Initialize - optionaly provide simplecache object""" + if not simplecache: + from simplecache import SimpleCache + self.cache = SimpleCache() + else: + self.cache = simplecache + + def search(self, artist, album, track): + """get musicbrainz id's by query of artist, album and/or track""" + artistid = "" + albumid = "" + artist = artist.lower() + if artist and album: + params = {'method': 'album.getInfo', 'artist': artist, 'album': album} + data = self.get_data(params) + if data and data.get("album"): + lfmdetails = data["album"] + if lfmdetails.get("mbid"): + albumid = lfmdetails.get("mbid") + if lfmdetails.get("tracks") and lfmdetails["tracks"].get("track"): + for track in lfmdetails.get("tracks")["track"]: + found_artist = get_compare_string(track["artist"]["name"]) + if found_artist == get_compare_string(artist) and track["artist"].get("mbid"): + artistid = track["artist"]["mbid"] + break + if not (artistid or albumid) and artist and track: + params = {'method': 'track.getInfo', 'artist': artist, 'track': track} + data = self.get_data(params) + if data and data.get("track"): + lfmdetails = data["track"] + if lfmdetails.get('album position="1"'): + albumid = lfmdetails['album position="1"'].get("mbid") + if lfmdetails.get("artist") and lfmdetails["artist"].get("name"): + found_artist = get_compare_string(lfmdetails["artist"]["name"]) + if found_artist == get_compare_string(artist) and lfmdetails["artist"].get("mbid"): + artistid = lfmdetails["artist"]["mbid"] + return artistid, albumid + + def get_artist_id(self, artist, album, track): + """get musicbrainz id by query of artist, album and/or track""" + return self.search(artist, album, track)[0] + + def get_album_id(self, artist, album, track): + """get musicbrainz id by query of artist, album and/or track""" + return 
self.search(artist, album, track)[1]
+
+    def artist_info(self, artist_id):
+        """get artist metadata by musicbrainz id"""
+        details = {"art": {}}
+        params = {'method': 'artist.getInfo', 'mbid': artist_id}
+        data = self.get_data(params)
+        if data and data.get("artist"):
+            lfmdetails = data["artist"]
+            #if lfmdetails.get("image"):
+            #    for image in lfmdetails["image"]:
+            #        if image["size"] in ["mega", "extralarge"] and xbmcvfs.exists(image["#text"]):
+            #            details["art"]["thumbs"] = [image["#text"]]
+            #            details["art"]["thumb"] = image["#text"]
+            if lfmdetails.get("bio") and lfmdetails["bio"].get("content"):
+                details["plot"] = strip_newlines(lfmdetails["bio"]["content"].split('<a href')[0])
+                    if int(tag["count"]) > 1:
+                        if " / " in tag["name"]:
+                            taglst = tag["name"].split(" / ")
+                        elif "/" in tag["name"]:
+                            taglst = tag["name"].split("/")
+                        elif " - " in tag["name"]:
+                            taglst = tag["name"].split(" - ")
+                        elif "-" in tag["name"]:
+                            taglst = tag["name"].split("-")
+                        else:
+                            taglst = [tag["name"]]
+                        for item in taglst:
+                            if item not in result["tags"]:
+                                result["tags"].append(item)
+                            if item not in result["genre"] and int(tag["count"]) > 4:
+                                result["genre"].append(item)
+        except Exception as exc:
+            log_msg("Error in musicbrainz - get album details: %s" % str(exc), xbmc.LOGWARNING)
+        return result
+
+    @staticmethod
+    def get_albumthumb(albumid):
+        """get album thumb"""
+        thumb = ""
+        url = "http://coverartarchive.org/release-group/%s/front" % albumid
+        if xbmcvfs.exists(url):
+            thumb = url
+        return thumb
+
+    @use_cache(14)
+    def search_release_group_match(self, artist, album):
+        """try to get a match on releasegroup for given artist/album combination"""
+        artistid = ""
+        albumid = ""
+        mb_albums = self.mbrainz.search_release_groups(query=album,
+                                                       limit=20, offset=None, strict=False, artist=artist)
+        if mb_albums and mb_albums.get("release-group-list"):
+            for albumtype in ["Album", "Single", ""]:
+                if artistid and albumid:
+                    break
+                for mb_album in mb_albums["release-group-list"]:
+                    if artistid and albumid:
+                        break
+                    if mb_album and isinstance(mb_album, dict):
+                        if albumtype and albumtype != mb_album.get("primary-type", ""):
+                            continue
+                        if mb_album.get("artist-credit"):
+                            artistid = self.match_artistcredit(mb_album["artist-credit"], artist)
+                        if artistid:
+                            albumid = mb_album.get("id", "")
+                            break
+        return artistid, albumid
+
+    @staticmethod
+    def match_artistcredit(artist_credit, artist):
+        """find match for artist in artist-credits"""
+        artistid = ""
+        for mb_artist in artist_credit:
+            if artistid:
+                break
+            if isinstance(mb_artist, dict) and mb_artist.get("artist", ""):
+                # safety check - only allow exact artist match
+                foundartist = mb_artist["artist"].get("name")
+                if sys.version_info.major < 3:
+                    foundartist = foundartist.encode("utf-8").decode("utf-8")
+                if foundartist and get_compare_string(foundartist) == get_compare_string(artist):
+                    artistid = mb_artist.get("artist").get("id")
+                    break
+                if not artistid and mb_artist["artist"].get("alias-list"):
+                    alias_list = [get_compare_string(item["alias"])
+                                  for item in mb_artist["artist"]["alias-list"]]
+                    if get_compare_string(artist) in alias_list:
+                        artistid = mb_artist.get("artist").get("id")
+                        break
+                    for item in artist.split("&"):
+                        item = get_compare_string(item)
+                        if item in alias_list or item in get_compare_string(foundartist):
+                            artistid = mb_artist.get("artist").get("id")
+                            break
+        return artistid
+
+    @use_cache(14)
+    def search_recording_match(self, artist, track):
+        """
+        try to get the releasegroup (album) for the given artist/track combination
+        various-artists compilations are ignored
+        """
+        artistid = ""
+        albumid = ""
+        mb_albums = self.mbrainz.search_recordings(query=track,
+                                                   limit=20, offset=None, strict=False, artist=artist)
+        if mb_albums and mb_albums.get("recording-list"):
+            for mb_recording in mb_albums["recording-list"]:
+                if albumid and artistid:
+                    break
+                if mb_recording and isinstance(mb_recording, dict):
+                    # look for match on artist
+                    if mb_recording.get("artist-credit"):
+                        artistid = self.match_artistcredit(mb_recording["artist-credit"], artist)
+                    # if we have a match on artist, look for match in release list
+                    if artistid:
+                        if mb_recording.get("release-list"):
+                            for mb_release in mb_recording["release-list"]:
+                                if mb_release.get("artist-credit"):
+                                    if mb_release["artist-credit"][0].get("id", "") == artistid:
+                                        albumid = mb_release["release-group"]["id"]
+                                        break
+                                else:
+                                    continue
+                                if mb_release.get("artist-credit-phrase", "") == 'Various Artists':
+                                    continue
+                                # grab release group details to make sure we're
+                                # not looking at some various artists compilation
+                                mb_album = self.mbrainz.get_release_group_by_id(
+                                    mb_release["release-group"]["id"], includes=["artist-credits"])
+                                mb_album = mb_album["release-group"]
+                                if mb_album.get("artist-credit"):
+                                    artistid = self.match_artistcredit(mb_album["artist-credit"], artist)
+                                if artistid:
+                                    albumid = mb_release["release-group"]["id"]
+                                    break
+        return artistid, albumid
diff --git a/resources/lib/metadatautils/helpers/moviesetdetails.py b/resources/lib/metadatautils/helpers/moviesetdetails.py
new file mode 100644
index 0000000..3abd4bd
--- /dev/null
+++ b/resources/lib/metadatautils/helpers/moviesetdetails.py
@@ -0,0 +1,185 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""
+    Returns complete (nicely formatted) information about the movieset and its movies
+"""
+
+import os, sys
+if sys.version_info.major == 3:
+    from .kodi_constants import FIELDS_MOVIES
+    from .utils import get_duration, get_clean_image, extend_dict
+    from urllib.parse import quote_plus
+else:
+    from kodi_constants import FIELDS_MOVIES
+    from utils import get_duration, get_clean_image, extend_dict
+    from urllib import quote_plus
+from operator import itemgetter
+import xbmc
+
+
+def get_moviesetdetails(metadatautils, title, set_id):
+    """Returns complete (nicely formatted) information about the movieset and its movies"""
+    details = {}
+    # try to get from cache first
+    # use checksum compare based on playcounts because moviesets do not get refreshed automatically
+    movieset = metadatautils.kodidb.movieset(set_id, ["playcount"])
+    cache_str = "MovieSetDetails.%s" % set_id
+    cache_checksum = "%s.%s" % (set_id, metadatautils.studiologos_path)
+    if movieset and len(movieset["movies"]) < 50:
+        for movie in movieset["movies"]:
+            cache_checksum += "%s" % movie["playcount"]
+        cache = metadatautils.cache.get(cache_str, checksum=cache_checksum)
+        if cache:
+            return cache
+    # grab all details online and from kodi dbid
+    details = get_online_setdata(metadatautils, title)
+    details = extend_dict(details, get_kodidb_setdata(metadatautils, set_id))
+    if not details.get("plot"):
+        details["plot"] = details["plots"]
+    details["extendedplot"] = details["titles"] + u"[CR]" + details["plot"]
+    all_fanarts = details["art"]["fanarts"]
+    efa_path = "plugin://script.skin.helper.service/?action=extrafanart&fanarts=%s" % quote_plus(repr(all_fanarts))
+    details["art"]["extrafanart"] = efa_path
+    for count, fanart in enumerate(all_fanarts):
+        details["art"]["ExtraFanArt.%s" % count] = fanart
+    metadatautils.cache.set(cache_str, details,
checksum=cache_checksum) + return details + + +def get_online_setdata(metadatautils, title): + """get moviesetdetails from TMDB and fanart.tv""" + details = metadatautils.tmdb.search_movieset(title) + if details: + # append images from fanart.tv + details["art"] = extend_dict( + details["art"], + metadatautils.fanarttv.movie(details["tmdb_id"]), + ["poster", "fanart", "clearlogo", "clearart"]) + return details + +# pylint: disable-msg=too-many-local-variables + + +def get_kodidb_setdata(metadatautils, set_id): + """get moviesetdetails from Kodi DB""" + details = {} + movieset = metadatautils.kodidb.movieset(set_id, FIELDS_MOVIES) + count = 0 + runtime = 0 + unwatchedcount = 0 + watchedcount = 0 + runtime = 0 + writer = [] + director = [] + genre = [] + countries = [] + studio = [] + years = [] + plot = "" + title_list = "" + total_movies = len(movieset['movies']) + title_header = "[B]%s %s[/B][CR]" % (total_movies, xbmc.getLocalizedString(20342)) + all_fanarts = [] + details["art"] = movieset["art"] + movieset_movies = sorted(movieset['movies'], key=itemgetter("year")) + for count, item in enumerate(movieset_movies): + if item["playcount"] == 0: + unwatchedcount += 1 + else: + watchedcount += 1 + + # generic labels + for label in ["label", "plot", "year", "rating"]: + details['%s.%s' % (count, label)] = item[label] + details["%s.DBID" % count] = item["movieid"] + details["%s.duration" % count] = item['runtime'] / 60 + + # art labels + art = item['art'] + for label in ["poster", "fanart", "landscape", "clearlogo", "clearart", "banner", "discart"]: + if art.get(label): + details['%s.art.%s' % (count, label)] = get_clean_image(art[label]) + if not movieset["art"].get(label): + movieset["art"][label] = get_clean_image(art[label]) + all_fanarts.append(get_clean_image(art.get("fanart"))) + + # streamdetails + if item.get('streamdetails', ''): + streamdetails = item["streamdetails"] + audiostreams = streamdetails.get('audio', []) + videostreams = streamdetails.get('video', []) + subtitles = streamdetails.get('subtitle', []) + if len(videostreams) > 0: + stream = videostreams[0] + height = stream.get("height", "") + width = stream.get("width", "") + if height and width: + resolution = "" + if width <= 720 and height <= 480: + resolution = "480" + elif width <= 768 and height <= 576: + resolution = "576" + elif width <= 960 and height <= 544: + resolution = "540" + elif width <= 1280 and height <= 720: + resolution = "720" + elif width <= 1920 and height <= 1080: + resolution = "1080" + elif width * height >= 6000000: + resolution = "4K" + details["%s.resolution" % count] = resolution + details["%s.Codec" % count] = stream.get("codec", "") + if stream.get("aspect", ""): + details["%s.aspectratio" % count] = round(stream["aspect"], 2) + if len(audiostreams) > 0: + # grab details of first audio stream + stream = audiostreams[0] + details["%s.audiocodec" % count] = stream.get('codec', '') + details["%s.audiochannels" % count] = stream.get('channels', '') + details["%s.audiolanguage" % count] = stream.get('language', '') + if len(subtitles) > 0: + # grab details of first subtitle + details["%s.SubTitle" % count] = subtitles[0].get('language', '') + + title_list += "%s (%s)[CR]" % (item['label'], item['year']) + if item['plotoutline']: + plot += "[B]%s (%s)[/B][CR]%s[CR][CR]" % (item['label'], item['year'], item['plotoutline']) + else: + plot += "[B]%s (%s)[/B][CR]%s[CR][CR]" % (item['label'], item['year'], item['plot']) + runtime += item['runtime'] + if item.get("writer"): + writer += [w for w in 
item["writer"] if w and w not in writer] + if item.get("director"): + director += [d for d in item["director"] if d and d not in director] + if item.get("genre"): + genre += [g for g in item["genre"] if g and g not in genre] + if item.get("country"): + countries += [c for c in item["country"] if c and c not in countries] + if item.get("studio"): + studio += [s for s in item["studio"] if s and s not in studio] + years.append(str(item['year'])) + details["plots"] = plot + if total_movies > 1: + details["extendedplots"] = title_header + title_list + "[CR]" + plot + else: + details["extendedplots"] = plot + details["titles"] = title_list + details["runtime"] = runtime / 60 + details.update(get_duration(runtime / 60)) + details["writer"] = writer + details["director"] = director + details["genre"] = genre + details["studio"] = studio + details["years"] = years + if len(years) > 1: + details["year"] = "%s - %s" % (years[0], years[-1]) + else: + details["year"] = years[0] if years else "" + details["country"] = countries + details["watchedcount"] = str(watchedcount) + details["unwatchedcount"] = str(unwatchedcount) + details.update(metadatautils.studiologos.get_studio_logo(studio, metadatautils.studiologos_path)) + details["count"] = total_movies + details["art"]["fanarts"] = all_fanarts + return details diff --git a/resources/lib/metadatautils/helpers/musicartwork.py b/resources/lib/metadatautils/helpers/musicartwork.py new file mode 100644 index 0000000..9f9ecef --- /dev/null +++ b/resources/lib/metadatautils/helpers/musicartwork.py @@ -0,0 +1,703 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" + script.module.metadatautils + musicartwork.py + Get metadata for music +""" + +import os, sys +if sys.version_info.major == 3: + from .utils import log_msg, extend_dict, ADDON_ID, strip_newlines, download_artwork, try_decode, manual_set_artwork + from .mbrainz import MusicBrainz + from urllib.parse import quote_plus +else: + from utils import log_msg, extend_dict, ADDON_ID, strip_newlines, download_artwork, try_decode, manual_set_artwork + from mbrainz import MusicBrainz + from urllib import quote_plus +import xbmc +import xbmcvfs +import xbmcgui +from difflib import SequenceMatcher as SM +from simplecache import use_cache + + +class MusicArtwork(object): + """get metadata and artwork for music""" + + def __init__(self, metadatautils): + """Initialize - optionaly provide our base MetadataUtils class""" + self._mutils = metadatautils + self.cache = self._mutils.cache + self.lastfm = self._mutils.lastfm + self.mbrainz = MusicBrainz() + self.audiodb = self._mutils.audiodb + + def get_music_artwork(self, artist, album, track, disc, ignore_cache=False, flush_cache=False, manual=False): + """ + get music metadata by providing artist and/or album/track + returns combined result of artist and album metadata + """ + if artist == track or album == track: + track = "" + artists = self.get_all_artists(artist, track) + album = self.get_clean_title(album) + track = self.get_clean_title(track) + + # retrieve artist and album details + artist_details = self.get_artists_metadata(artists, album, track, + ignore_cache=ignore_cache, flush_cache=flush_cache, manual=manual) + album_artist = artist_details.get("albumartist", artists[0]) + if album or track: + album_details = self.get_album_metadata(album_artist, album, track, disc, + ignore_cache=ignore_cache, flush_cache=flush_cache, manual=manual) + else: + album_details = {"art": {}} + + # combine artist details and album details + details = 
extend_dict(album_details, artist_details) + + # combine artist plot and album plot as extended plot + if artist_details.get("plot") and album_details.get("plot"): + details["extendedplot"] = "%s -- %s" % try_decode((album_details["plot"], artist_details["plot"])) + else: + details["extendedplot"] = details.get("plot", "") + + # append track title to results + if track: + details["title"] = track + + # return the endresult + return details + + def music_artwork_options(self, artist, album, track, disc): + """show options for music artwork""" + options = [] + options.append(self._mutils.addon.getLocalizedString(32028)) # Refresh item (auto lookup) + options.append(self._mutils.addon.getLocalizedString(32036)) # Choose art + options.append(self._mutils.addon.getLocalizedString(32034)) # Open addon settings + header = self._mutils.addon.getLocalizedString(32015) + dialog = xbmcgui.Dialog() + ret = dialog.select(header, options) + del dialog + if ret == 0: + # Refresh item (auto lookup) + self.get_music_artwork(artist, album, track, disc, ignore_cache=True) + elif ret == 1: + # Choose art + self.get_music_artwork(artist, album, track, disc, ignore_cache=True, manual=True) + elif ret == 2: + # Open addon settings + xbmc.executebuiltin("Addon.OpenSettings(%s)" % ADDON_ID) + + def get_artists_metadata(self, artists, album, track, ignore_cache=False, flush_cache=False, manual=False): + """collect artist metadata for all artists""" + artist_details = {"art": {}} + # for multi artist songs/albums we grab details from all artists + if len(artists) == 1: + # single artist + artist_details = self.get_artist_metadata( + artists[0], album, track, ignore_cache=ignore_cache, flush_cache=flush_cache, manual=manual) + artist_details["albumartist"] = artists[0] + else: + # multi-artist track + # The first artist with details is considered the main artist + # all others are assumed as featuring artists + artist_details = {"art": {}} + feat_artist_details = [] + for artist in artists: + if not (artist_details.get("plot") or artist_details.get("art")): + # get main artist details + artist_details["albumartist"] = artist + artist_details = self.get_artist_metadata( + artist, album, track, ignore_cache=ignore_cache, manual=manual) + else: + # assume featuring artist + feat_artist_details.append(self.get_artist_metadata( + artist, album, track, ignore_cache=ignore_cache, manual=manual)) + + # combined images to use as multiimage (for all artists) + # append featuring artist details + for arttype in ["banners", "fanarts", "clearlogos", "thumbs"]: + combined_art = [] + for artist_item in [artist_details] + feat_artist_details: + art = artist_item["art"].get(arttype, []) + if isinstance(art, list): + for item in art: + if item not in combined_art: + combined_art.append(item) + else: + for item in self._mutils.kodidb.files(art): + if item["file"] not in combined_art: + combined_art.append(item["file"]) + if combined_art: + # use the extrafanart plugin entry to display multi images + artist_details["art"][arttype] = "plugin://script.skin.helper.service/"\ + "?action=extrafanart&fanarts=%s" % quote_plus(repr(combined_art)) + # also set extrafanart path + if arttype == "fanarts": + artist_details["art"]["extrafanart"] = artist_details["art"][arttype] + # return the result + return artist_details + + # pylint: disable-msg=too-many-local-variables + def get_artist_metadata(self, artist, album, track, ignore_cache=False, flush_cache=False, manual=False): + """collect artist metadata for given artist""" + details = {"art": 
{}} + cache_str = "music_artwork.artist.%s" % artist.lower() + # retrieve details from cache + cache = self._mutils.cache.get(cache_str) + if not cache and flush_cache: + # nothing to do - just return empty results + return details + elif cache and flush_cache: + # only update kodi metadata for updated counts etc + details = extend_dict(self.get_artist_kodi_metadata(artist), cache) + elif cache and not ignore_cache: + # we have a valid cache - return that + details = cache + elif cache and manual: + # user wants to manually override the artwork in the cache + details = self.manual_set_music_artwork(cache, "artist") + else: + # nothing in cache - start metadata retrieval + log_msg("get_artist_metadata --> artist: %s - album: %s - track: %s" % (artist, album, track)) + details["cachestr"] = cache_str + local_path = "" + local_path_custom = "" + # get metadata from kodi db + details = extend_dict(details, self.get_artist_kodi_metadata(artist)) + # get artwork from songlevel path + if details.get("diskpath") and self._mutils.addon.getSetting("music_art_musicfolders") == "true": + details["art"] = extend_dict(details["art"], self.lookup_artistart_in_folder(details["diskpath"])) + local_path = details["diskpath"] + # get artwork from custom folder + custom_path = None + if self._mutils.addon.getSetting("music_art_custom") == "true": + if sys.version_info.major == 3: + custom_path = self._mutils.addon.getSetting("music_art_custom_path") + else: + custom_path = self._mutils.addon.getSetting("music_art_custom_path").decode("utf-8") + local_path_custom = self.get_customfolder_path(custom_path, artist) + #log_msg("custom path on disk for artist: %s --> %s" % (artist, local_path_custom)) + details["art"] = extend_dict(details["art"], self.lookup_artistart_in_folder(local_path_custom)) + details["customartpath"] = local_path_custom + # lookup online metadata + if self._mutils.addon.getSetting("music_art_scraper") == "true": + if not album and not track: + album = details.get("ref_album") + track = details.get("ref_track") + # prefer the musicbrainzid that is already in the kodi database - only perform lookup if missing + mb_artistid = details.get("musicbrainzartistid", self.get_mb_artist_id(artist, album, track)) + details["musicbrainzartistid"] = mb_artistid + if mb_artistid: + # get artwork from fanarttv + if self._mutils.addon.getSetting("music_art_scraper_fatv") == "true": + details["art"] = extend_dict(details["art"], self._mutils.fanarttv.artist(mb_artistid)) + # get metadata from theaudiodb + if self._mutils.addon.getSetting("music_art_scraper_adb") == "true": + details = extend_dict(details, self.audiodb.artist_info(artist)) + # get metadata from lastfm + if self._mutils.addon.getSetting("music_art_scraper_lfm") == "true": + details = extend_dict(details, self.lastfm.artist_info(mb_artistid)) + # download artwork to music folder + if local_path and self._mutils.addon.getSetting("music_art_download") == "true": + details["art"] = download_artwork(local_path, details["art"]) + # download artwork to custom folder + if local_path_custom and self._mutils.addon.getSetting("music_art_download_custom") == "true": + details["art"] = download_artwork(local_path_custom, details["art"]) + # fix extrafanart + if details["art"].get("fanarts"): + for count, item in enumerate(details["art"]["fanarts"]): + details["art"]["fanart.%s" % count] = item + if not details["art"].get("extrafanart") and len(details["art"]["fanarts"]) > 1: + details["art"]["extrafanart"] = "plugin://script.skin.helper.service/"\ + 
"?action=extrafanart&fanarts=%s" % quote_plus(repr(details["art"]["fanarts"])) + # multi-image path for all images for each arttype + for arttype in ["banners", "clearlogos", "thumbs"]: + art = details["art"].get(arttype, []) + if len(art) > 1: + # use the extrafanart plugin entry to display multi images + details["art"][arttype] = "plugin://script.skin.helper.service/"\ + "?action=extrafanart&fanarts=%s" % quote_plus(repr(art)) + # set default details + if not details.get("artist"): + details["artist"] = artist + if details["art"].get("thumb"): + details["art"]["artistthumb"] = details["art"]["thumb"] + + # always store results in cache and return results + self._mutils.cache.set(cache_str, details) + return details + + def get_album_metadata(self, artist, album, track, disc, ignore_cache=False, flush_cache=False, manual=False): + """collect all album metadata""" + cache_str = "music_artwork.album.%s.%s.%s" % (artist.lower(), album.lower(), disc.lower()) + if not album: + cache_str = "music_artwork.album.%s.%s" % (artist.lower(), track.lower()) + details = {"art": {}, "cachestr": cache_str} + + log_msg("get_album_metadata --> artist: %s - album: %s - track: %s" % (artist, album, track)) + # retrieve details from cache + cache = self._mutils.cache.get(cache_str) + if not cache and flush_cache: + # nothing to do - just return empty results + return details + elif cache and flush_cache: + # only update kodi metadata for updated counts etc + details = extend_dict(self.get_album_kodi_metadata(artist, album, track, disc), cache) + elif cache and not ignore_cache: + # we have a valid cache - return that + details = cache + elif cache and manual: + # user wants to manually override the artwork in the cache + details = self.manual_set_music_artwork(cache, "album") + else: + # nothing in cache - start metadata retrieval + local_path = "" + local_path_custom = "" + # get metadata from kodi db + details = extend_dict(details, self.get_album_kodi_metadata(artist, album, track, disc)) + if not album and details.get("title"): + album = details["title"] + # get artwork from songlevel path + if details.get("local_path_custom") and self._mutils.addon.getSetting("music_art_musicfolders") == "true": + details["art"] = extend_dict(details["art"], self.lookup_albumart_in_folder(details["local_path_custom"])) + local_path = details["local_path_custom"] + # get artwork from custom folder + custom_path = None + if self._mutils.addon.getSetting("music_art_custom") == "true": + if sys.version_info.major == 3: + custom_path = self._mutils.addon.getSetting("music_art_custom_path") + else: + custom_path = self._mutils.addon.getSetting("music_art_custom_path").decode("utf-8") + local_path_custom = self.get_custom_album_path(custom_path, artist, album, disc) + details["art"] = extend_dict(details["art"], self.lookup_albumart_in_folder(local_path_custom)) + details["customartpath"] = local_path_custom + # lookup online metadata + if self._mutils.addon.getSetting("music_art_scraper") == "true": + # prefer the musicbrainzid that is already in the kodi database - only perform lookup if missing + mb_albumid = details.get("musicbrainzalbumid") + if not mb_albumid: + mb_albumid = self.get_mb_album_id(artist, album, track) + adb_album = self.audiodb.get_album_id(artist, album, track) + if mb_albumid: + # get artwork from fanarttv + if self._mutils.addon.getSetting("music_art_scraper_fatv") == "true": + details["art"] = extend_dict(details["art"], self._mutils.fanarttv.album(mb_albumid)) + # get metadata from theaudiodb + if 
self._mutils.addon.getSetting("music_art_scraper_adb") == "true": + details = extend_dict(details, self.audiodb.album_info(artist, adb_album)) + # get metadata from lastfm + if self._mutils.addon.getSetting("music_art_scraper_lfm") == "true": + details = extend_dict(details, self.lastfm.album_info(mb_albumid)) + # metadata from musicbrainz + if not details.get("year") or not details.get("genre"): + details = extend_dict(details, self.mbrainz.get_albuminfo(mb_albumid)) + # musicbrainz thumb as last resort + if not details["art"].get("thumb"): + details["art"]["thumb"] = self.mbrainz.get_albumthumb(mb_albumid) + # download artwork to music folder + # get artwork from custom folder + # (yes again, but this time we might have an album where we didnt have that before) + if custom_path and not album and details.get("title"): + album = details["title"] + diskpath = self.get_custom_album_path(custom_path, artist, album, disc) + if diskpath: + details["art"] = extend_dict(details["art"], self.lookup_albumart_in_folder(diskpath)) + local_path_custom = diskpath + details["customartpath"] = diskpath + # download artwork to custom folder + if custom_path and self._mutils.addon.getSetting("music_art_download_custom") == "true": + if local_path_custom: + # allow folder creation if we enabled downloads and the folder does not exist + artist_path = self.get_customfolder_path(custom_path, artist) + if artist_path: + local_path_custom = os.path.join(artist_path, album) + else: + local_path_custom = os.path.join(custom_path, artist, album) + details["customartpath"] = local_path_custom + details["art"] = download_artwork(local_path_custom, details["art"]) + # set default details + if not details.get("album") and details.get("title"): + details["album"] = details["title"] + if details["art"].get("thumb"): + details["art"]["albumthumb"] = details["art"]["thumb"] + + # store results in cache and return results + self._mutils.cache.set(cache_str, details) +# self.write_kodidb(details) + return details + + # pylint: enable-msg=too-many-local-variables + + def get_artist_kodi_metadata(self, artist): + """get artist details from the kodi database""" + details = {} + filters = [{"operator": "is", "field": "artist", "value": artist}] + result = self._mutils.kodidb.artists(filters=filters, limits=(0, 1)) + if result: + details = result[0] + details["title"] = details["artist"] + details["plot"] = strip_newlines(details["description"]) + if details["musicbrainzartistid"] and isinstance(details["musicbrainzartistid"], list): + details["musicbrainzartistid"] = details["musicbrainzartistid"][0] + filters = [{"artistid": details["artistid"]}] + artist_albums = self._mutils.kodidb.albums(filters=filters) + details["albums"] = [] + details["albumsartist"] = [] + details["albumscompilations"] = [] + details["tracks"] = [] + if sys.version_info.major == 3: + bullet = "•" + else: + bullet = "•".decode("utf-8") + details["albums.formatted"] = u"" + details["tracks.formatted"] = u"" + details["tracks.formatted2"] = u"" + details["albumsartist.formatted"] = u"" + details["albumscompilations.formatted"] = u"" + # enumerate albums for this artist + for artist_album in artist_albums: + details["albums"].append(artist_album["label"]) + details["albums.formatted"] += u"%s %s [CR]" % (bullet, artist_album["label"]) + if artist in artist_album["displayartist"]: + details["albumsartist"].append(artist_album["label"]) + details["albumsartist.formatted"] += u"%s %s [CR]" % (bullet, artist_album["label"]) + else: + 
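+                    # the artist is not listed in the album's display artist, so count
+                    # this album as a compilation/various-artists appearance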
details["albumscompilations"].append(artist_album["label"]) + details["albumscompilations.formatted"] += u"%s %s [CR]" % (bullet, artist_album["label"]) + # enumerate songs for this album + filters = [{"albumid": artist_album["albumid"]}] + album_tracks = self._mutils.kodidb.songs(filters=filters) + if album_tracks: + # retrieve path on disk by selecting one song for this artist + if not details.get("ref_track") and not len(artist_album["artistid"]) > 1: + song_path = album_tracks[0]["file"] + details["diskpath"] = self.get_artistpath_by_songpath(song_path, artist) + details["ref_album"] = artist_album["title"] + details["ref_track"] = album_tracks[0]["title"] + for album_track in album_tracks: + details["tracks"].append(album_track["title"]) + tr_title = album_track["title"] + if album_track["track"]: + tr_title = "%s. %s" % (album_track["track"], album_track["title"]) + details["tracks.formatted"] += u"%s %s [CR]" % (bullet, tr_title) + duration = album_track["duration"] + total_seconds = int(duration) + minutes = total_seconds // 60 % 60 + seconds = total_seconds - (minutes * 60) + duration = "%s:%s" % (minutes, str(seconds).zfill(2)) + details["tracks.formatted2"] += u"%s %s (%s)[CR]" % (bullet, tr_title, duration) + details["albumcount"] = len(details["albums"]) + details["albumsartistcount"] = len(details["albumsartist"]) + details["albumscompilationscount"] = len(details["albumscompilations"]) + # do not retrieve artwork from item as there's no way to write it back + # and it will already be retrieved if user enables to get the artwork from the song path + return details + + def get_album_kodi_metadata(self, artist, album, track, disc): + """get album details from the kodi database""" + details = {} + filters = [{"operator": "contains", "field": "artist", "value": artist}] + if artist and track and not album: + # get album by track + filters.append({"operator": "contains", "field": "title", "value": track}) + result = self._mutils.kodidb.songs(filters=filters) + for item in result: + album = item["album"] + break + if artist and album: + filters.append({"operator": "contains", "field": "album", "value": album}) + result = self._mutils.kodidb.albums(filters=filters) + if result: + details = result[0] + details["plot"] = strip_newlines(details["description"]) + filters = [{"albumid": details["albumid"]}] + album_tracks = self._mutils.kodidb.songs(filters=filters) + details["artistid"] = details["artistid"][0] + details["tracks"] = [] + if sys.version_info.major == 3: + bullet = "•" + else: + bullet = "•".decode("utf-8") + details["tracks.formatted"] = u"" + details["tracks.formatted2"] = "" + details["runtime"] = 0 + for item in album_tracks: + details["tracks"].append(item["title"]) + details["tracks.formatted"] += u"%s %s [CR]" % (bullet, item["title"]) + duration = item["duration"] + total_seconds = int(duration) + minutes = total_seconds // 60 % 60 + seconds = total_seconds - (minutes * 60) + duration = "%s:%s" % (minutes, str(seconds).zfill(2)) + details["runtime"] += item["duration"] + details["tracks.formatted2"] += u"%s %s (%s)[CR]" % (bullet, item["title"], duration) + if not details.get("diskpath"): + if not disc or item["disc"] == int(disc): + details["diskpath"] = self.get_albumpath_by_songpath(item["file"]) + details["art"] = {} + details["songcount"] = len(album_tracks) + # get album total duration pretty printed as mm:ss + total_seconds = int(details["runtime"]) + minutes = total_seconds // 60 % 60 + seconds = total_seconds - (minutes * 60) + details["duration"] = 
"%s:%s" % (minutes, str(seconds).zfill(2)) + # do not retrieve artwork from item as there's no way to write it back + # and it will already be retrieved if user enables to get the artwork from the song path + return details + + def get_mb_artist_id(self, artist, album, track): + """lookup musicbrainz artist id with query of artist and album/track""" + artistid = self.mbrainz.get_artist_id(artist, album, track) + if not artistid and self._mutils.addon.getSetting("music_art_scraper_lfm") == "true": + artistid = self.lastfm.get_artist_id(artist, album, track) + if not artistid and self._mutils.addon.getSetting("music_art_scraper_adb") == "true": + artistid = self.audiodb.get_artist_id(artist, album, track) + return artistid + + def get_mb_album_id(self, artist, album, track): + """lookup musicbrainz album id with query of artist and album/track""" + albumid = self.mbrainz.get_album_id(artist, album, track) + if not albumid and self._mutils.addon.getSetting("music_art_scraper_lfm") == "true": + albumid = self.lastfm.get_album_id(artist, album, track) + if not albumid and self._mutils.addon.getSetting("music_art_scraper_adb") == "true": + albumid = self.audiodb.get_album_id(artist, album, track) + return albumid + + def manual_set_music_artwork(self, details, mediatype): + """manual override artwork options""" + if mediatype == "artist" and "artist" in details: + header = "%s: %s" % (xbmc.getLocalizedString(13511), details["artist"]) + else: + header = "%s: %s" % (xbmc.getLocalizedString(13511), xbmc.getLocalizedString(558)) + changemade, artwork = manual_set_artwork(details["art"], mediatype, header) + # save results if any changes made + if changemade: + details["art"] = artwork + refresh_needed = False + download_art = self._mutils.addon.getSetting("music_art_download") == "true" + download_art_custom = self._mutils.addon.getSetting("music_art_download_custom") == "true" + # download artwork to music folder if needed + if details.get("custom_path") and download_art: + details["art"] = download_artwork(details["custom_path"], details["art"]) + refresh_needed = True + # download artwork to custom folder if needed + if details.get("customartpath") and download_art_custom: + details["art"] = download_artwork(details["customartpath"], details["art"]) + refresh_needed = True + # reload skin to make sure new artwork is visible + if refresh_needed: + xbmc.sleep(500) + xbmc.executebuiltin("ReloadSkin()") + # return endresult + return details + + @staticmethod + def get_artistpath_by_songpath(songpath, artist): + """get the artist path on disk by listing the song's path""" + result = "" + if "\\" in songpath: + delim = "\\" + else: + delim = "/" + # just move up the directory tree (max 3 levels) untill we find the directory + for trypath in [songpath.rsplit(delim, 2)[0] + delim, + songpath.rsplit(delim, 3)[0] + delim, songpath.rsplit(delim, 1)[0] + delim]: + if trypath.split(delim)[-2].lower() == artist.lower(): + result = trypath + break + return result + + @staticmethod + def get_albumpath_by_songpath(songpath): + """get the album path on disk by listing the song's path""" + if "\\" in songpath: + delim = "\\" + else: + delim = "/" + return songpath.rsplit(delim, 1)[0] + delim + + @staticmethod + def lookup_artistart_in_folder(folderpath): + """lookup artwork in given folder""" + artwork = {} + if not folderpath or not xbmcvfs.exists(folderpath): + return artwork + files = xbmcvfs.listdir(folderpath)[1] + for item in files: + if sys.version_info.major < 3: + item = item.decode("utf-8") + if item in 
["banner.jpg", "clearart.png", "poster.png", "fanart.jpg", "landscape.jpg"]: + key = item.split(".")[0] + artwork[key] = folderpath + item + elif item == "logo.png": + artwork["clearlogo"] = folderpath + item + elif item == "folder.jpg": + artwork["thumb"] = folderpath + item + # extrafanarts + efa_path = folderpath + "extrafanart/" + if xbmcvfs.exists(efa_path): + files = xbmcvfs.listdir(efa_path)[1] + artwork["fanarts"] = [] + if files: + artwork["extrafanart"] = efa_path + for item in files: + if sys.version_info.major == 3: + item = efa_path + item + else: + item = efa_path + item.decode("utf-8") + artwork["fanarts"].append(item) + return artwork + + @staticmethod + def lookup_albumart_in_folder(folderpath): + """lookup artwork in given folder""" + artwork = {} + if not folderpath or not xbmcvfs.exists(folderpath): + return artwork + files = xbmcvfs.listdir(folderpath)[1] + for item in files: + if sys.version_info.major < 3: + item = item.decode("utf-8") + if item in ["cdart.png", "disc.png"]: + artwork["discart"] = folderpath + item + if item in ["cdart2.png", "disc2.png"]: + artwork["discart2"] = folderpath + item + if item == "thumbback.jpg": + artwork["thumbback"] = folderpath + item + if item == "spine.jpg": + artwork["spine"] = folderpath + item + if item == "album3Dthumb.png": + artwork["album3Dthumb"] = folderpath + item + if item == "album3Dflat.png": + artwork["album3Dflat"] = folderpath + item + if item == "album3Dcase.png": + artwork["album3Dcase"] = folderpath + item + if item == "album3Dface.png": + artwork["album3Dface"] = folderpath + item + elif item == "folder.jpg": + artwork["thumb"] = folderpath + item + return artwork + + def get_custom_album_path(self, custom_path, artist, album, disc): + """try to locate the custom path for the album""" + artist_path = self.get_customfolder_path(custom_path, artist) + album_path = "" + if artist_path: + album_path = self.get_customfolder_path(artist_path, album) + if album_path and disc: + if "\\" in album_path: + delim = "\\" + else: + delim = "/" + dirs = xbmcvfs.listdir(album_path)[0] + for directory in dirs: + if sys.version_info.major < 3: + directory = directory.decode("utf-8") + if disc in directory: + return os.path.join(album_path, directory) + delim + return album_path + + def get_customfolder_path(self, customfolder, foldername, sublevel=False): + """search recursively (max 2 levels) for a specific folder""" + if sys.version_info.major == 3: + artistcustom_path = self._mutils.addon.getSetting("music_art_custom_path") + else: + artistcustom_path = self._mutils.addon.getSetting("music_art_custom_path").decode("utf-8") + cachestr = "customfolder_path.%s%s" % (customfolder, foldername) + folder_path = self.cache.get(cachestr) + if not folder_path: + if "\\" in customfolder: + delim = "\\" + else: + delim = "/" + dirs = xbmcvfs.listdir(customfolder)[0] + for strictness in [1, 0.95, 0.9, 0.85]: + for directory in dirs: + if sys.version_info.major < 3: + directory = directory.decode("utf-8") + curpath = os.path.join(customfolder, directory) + delim + match = SM(None, foldername.lower(), directory.lower()).ratio() + if match >= strictness: + folder_path = curpath + elif not sublevel: + # check if our requested path is in a sublevel of the current path + # restrict the number of sublevels to just one for now for performance reasons + folder_path = self.get_customfolder_path(curpath, foldername, True) + if folder_path: + break + if folder_path: + break + if not sublevel: + if not folder_path and 
self._mutils.addon.getSetting("music_art_download_custom") == "true": + # allow creation of folder if downloading is enabled + folder_path = os.path.join(customfolder, foldername) + delim + self.cache.set(cachestr, folder_path) + if not folder_path and self._mutils.addon.getSetting("music_art_download_custom") == "true": + folder_path = os.path.join(artistcustom_path, foldername) + delim + return folder_path + + @staticmethod + def get_clean_title(title): + """strip all unwanted characters from track name""" + title = title.split("f/")[0] + title = title.split("F/")[0] + title = title.split(" and ")[0] + title = title.split("(")[0] + title = title.split("[")[0] + title = title.split("ft.")[0] + title = title.split("Ft.")[0] + title = title.split("Feat.")[0] + title = title.split("feat")[0] + title = title.split("Featuring")[0] + title = title.split("featuring")[0] + title = title.split(" f/")[0] + title = title.split(" F/")[0] + title = title.split("/")[0] + title = title.split("Now On Air: ")[0] + title = title.split(" x ")[0] + title = title.split("vs.")[0] + title = title.split(" Ft ")[0] + title = title.split(" ft ")[0] + title = title.split(" & ")[0] + title = title.split(",")[0] + return title.strip() + + @staticmethod + def get_all_artists(artist, track): + """extract multiple artists from both artist and track string""" + artists = [] + feat_artists = [] + + # fix for band names which actually contain the kodi splitter (slash) in their name... + specials = ["AC/DC"] # to be completed with more artists + for special in specials: + if special in artist: + artist = artist.replace(special, special.replace("/", "")) + + for splitter in ["ft.", " ft ", "feat.", "Now On Air: ", " and ", "feat", "featuring", "Ft.", "Feat.", "F.", "F/", "f/", " Ft ", "Featuring", " x ", " & ", "vs.", ","]: + # replace splitter by kodi default splitter for easier split all later + artist = artist.replace(splitter, u"/") + + # extract any featuring artists from trackname + if splitter in track: + track_parts = track.split(splitter) + if len(track_parts) > 1: + feat_artist = track_parts[1].replace(")", "").replace("(", "").strip() + feat_artists.append(feat_artist) + + # break all artists string into list + all_artists = artist.split("/") + feat_artists + for item in all_artists: + item = item.strip() + if item not in artists: + artists.append(item) + # & can be a both a splitter or part of artist name + for item2 in item.split("&"): + item2 = item2.strip() + if item2 not in artists: + artists.append(item2) + return artists \ No newline at end of file diff --git a/resources/lib/metadatautils/helpers/omdb.py b/resources/lib/metadatautils/helpers/omdb.py new file mode 100644 index 0000000..d8e093b --- /dev/null +++ b/resources/lib/metadatautils/helpers/omdb.py @@ -0,0 +1,265 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +"""get metadata from omdb""" +import os, sys +if sys.version_info.major == 3: + from .utils import get_json, get_xml, formatted_number, int_with_commas, try_parse_int, KODI_LANGUAGE, ADDON_ID +else: + from utils import get_json, get_xml, formatted_number, int_with_commas, try_parse_int, KODI_LANGUAGE, ADDON_ID +from simplecache import use_cache +import arrow +import xbmc +import xbmcaddon + + +class Omdb(object): + """get metadata from omdb""" + api_key = None # public var to be set by the calling addon + + def __init__(self, simplecache=None): + """Initialize - optionaly provide simplecache object""" + if not simplecache: + from simplecache import SimpleCache + self.cache = SimpleCache() + 
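+            # no cache object was passed in - fall back to our own SimpleCache instance.
+            # minimal usage sketch of this class (hypothetical example values):
+            #   omdb = Omdb()
+            #   details = omdb.get_details_by_imdbid("tt0111161")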
else: + self.cache = simplecache + addon = xbmcaddon.Addon(id=ADDON_ID) + api_key = addon.getSetting("omdbapi_apikey") + if api_key: + self.api_key = api_key + del addon + + @use_cache(14) + def get_details_by_imdbid(self, imdb_id): + """get omdb details by providing an imdb id""" + params = {"i": imdb_id} + data = self.get_data(params) + return self.map_details(data) if data else None + + @use_cache(14) + def get_details_by_title(self, title, year="", media_type=""): + """ get omdb details by title + title --> The title of the media to look for (required) + year (str/int)--> The year of the media (optional, better results when provided) + media_type --> The type of the media: movie/tvshow (optional, better results of provided) + """ + if "movie" in media_type: + media_type = "movie" + elif media_type in ["tvshows", "tvshow"]: + media_type = "series" + params = {"t": title, "y": year, "type": media_type} + data = self.get_data(params) + return self.map_details(data) if data else None + + @use_cache(14) + def get_data(self, params): + """helper method to get data from omdb json API""" + base_url = 'http://www.omdbapi.com/' + params["plot"] = "full" + if self.api_key: + params["apikey"] = self.api_key + rate_limit = None + else: + # rate limited api key ! + params["apikey"] = "d4d53b9a" + rate_limit = ("omdbapi.com", 2) + params["r"] = "xml" + params["tomatoes"] = "true" + return get_xml(base_url, params, ratelimit=rate_limit) + + @staticmethod + def map_details(data): + """helper method to map the details received from omdb to kodi compatible format""" + result = {} + if sys.version_info.major == 3: + for key, value in data.items(): + # filter the N/A values + if value in ["N/A", "NA"] or not value: + continue + if key == "title": + result["title"] = value + elif key == "year": + try: + result["year"] = try_parse_int(value.split("-")[0]) + except Exception: + result["year"] = value + elif key == "rated": + result["mpaa"] = value.replace("Rated", "") + elif key == "released": + date_time = arrow.get(value, "DD MMM YYYY") + result["premiered"] = date_time.strftime(xbmc.getRegion("dateshort")) + try: + result["premiered.formatted"] = date_time.format('DD MMM YYYY', locale=KODI_LANGUAGE) + except Exception: + result["premiered.formatted"] = value + elif key == "runtime": + result["runtime"] = try_parse_int(value.replace(" min", "")) * 60 + elif key == "genre": + result["genre"] = value.split(", ") + elif key == "director": + result["director"] = value.split(", ") + elif key == "writer": + result["writer"] = value.split(", ") + elif key == "country": + result["country"] = value.split(", ") + elif key == "awards": + result["awards"] = value + elif key == "poster": + result["thumbnail"] = value + result["art"] = {} + result["art"]["thumb"] = value + elif key == "imdbVotes": + result["votes.imdb"] = value + result["votes"] = try_parse_int(value.replace(",", "")) + elif key == "imdbRating": + result["rating.imdb"] = value + result["rating"] = float(value) + result["rating.percent.imdb"] = "%s" % (try_parse_int(float(value) * 10)) + elif key == "metascore": + result["metacritic.rating"] = value + result["metacritic.rating.percent"] = "%s" % value + elif key == "imdbID": + result["imdbnumber"] = value + elif key == "BoxOffice": + result["boxoffice"] = value + elif key == "DVD": + date_time = arrow.get(value, "DD MMM YYYY") + result["dvdrelease"] = date_time.format('YYYY-MM-DD') + result["dvdrelease.formatted"] = date_time.format('DD MMM YYYY', locale=KODI_LANGUAGE) + elif key == "Production": + 
result["studio"] = value.split(", ") + elif key == "Website": + result["homepage"] = value + elif key == "plot": + result["plot"] = value + result["imdb.plot"] = value + elif key == "type": + if value == "series": + result["type"] = "tvshow" + else: + result["type"] = value + result["media_type"] = result["type"] + # rotten tomatoes + elif key == "tomatoMeter": + result["rottentomatoes.meter"] = value + result["rottentomatoesmeter"] = value + elif key == "tomatoImage": + result["rottentomatoes.image"] = value + elif key == "tomatoRating": + result["rottentomatoes.rating"] = value + result["rottentomatoes.rating.percent"] = "%s" % (try_parse_int(float(value) * 10)) + result["rating.rt"] = value + elif key == "tomatoReviews": + result["rottentomatoes.reviews"] = formatted_number(value) + elif key == "tomatoFresh": + result["rottentomatoes.fresh"] = value + elif key == "tomatoRotten": + result["rottentomatoes.rotten"] = value + elif key == "tomatoConsensus": + result["rottentomatoes.consensus"] = value + elif key == "tomatoUserMeter": + result["rottentomatoes.usermeter"] = value + elif key == "tomatoUserRating": + result["rottentomatoes.userrating"] = value + result["rottentomatoes.userrating.percent"] = "%s" % (try_parse_int(float(value) * 10)) + elif key == "tomatoUserReviews": + result["rottentomatoes.userreviews"] = int_with_commas(value) + elif key == "tomatoeURL": + result["rottentomatoes.url"] = value + else: + for key, value in data.iteritems(): + # filter the N/A values + if value in ["N/A", "NA"] or not value: + continue + if key == "title": + result["title"] = value + elif key == "Year": + try: + result["year"] = try_parse_int(value.split("-")[0]) + except Exception: + result["year"] = value + elif key == "rated": + result["mpaa"] = value.replace("Rated", "") + elif key == "released": + date_time = arrow.get(value, "DD MMM YYYY") + result["premiered"] = date_time.strftime(xbmc.getRegion("dateshort")) + try: + result["premiered.formatted"] = date_time.format('DD MMM YYYY', locale=KODI_LANGUAGE) + except Exception: + result["premiered.formatted"] = value + elif key == "runtime": + result["runtime"] = try_parse_int(value.replace(" min", "")) * 60 + elif key == "genre": + result["genre"] = value.split(", ") + elif key == "director": + result["director"] = value.split(", ") + elif key == "writer": + result["writer"] = value.split(", ") + elif key == "country": + result["country"] = value.split(", ") + elif key == "awards": + result["awards"] = value + elif key == "poster": + result["thumbnail"] = value + result["art"] = {} + result["art"]["thumb"] = value + elif key == "imdbVotes": + result["votes.imdb"] = value + result["votes"] = try_parse_int(value.replace(",", "")) + elif key == "imdbRating": + result["rating.imdb"] = value + result["rating"] = float(value) + result["rating.percent.imdb"] = "%s" % (try_parse_int(float(value) * 10)) + elif key == "metascore": + result["metacritic.rating"] = value + result["metacritic.rating.percent"] = "%s" % value + elif key == "imdbID": + result["imdbnumber"] = value + elif key == "BoxOffice": + result["boxoffice"] = value + elif key == "DVD": + date_time = arrow.get(value, "DD MMM YYYY") + result["dvdrelease"] = date_time.format('YYYY-MM-DD') + result["dvdrelease.formatted"] = date_time.format('DD MMM YYYY', locale=KODI_LANGUAGE) + elif key == "Production": + result["studio"] = value.split(", ") + elif key == "Website": + result["homepage"] = value + elif key == "plot": + result["plot"] = value + result["imdb.plot"] = value + elif key == "type": + 
if value == "series": + result["type"] = "tvshow" + else: + result["type"] = value + result["media_type"] = result["type"] + # rotten tomatoes + elif key == "tomatoMeter": + result["rottentomatoes.meter"] = value + result["rottentomatoesmeter"] = value + elif key == "tomatoImage": + result["rottentomatoes.image"] = value + elif key == "tomatoRating": + result["rottentomatoes.rating"] = value + result["rottentomatoes.rating.percent"] = "%s" % (try_parse_int(float(value) * 10)) + result["rating.rt"] = value + elif key == "tomatoReviews": + result["rottentomatoes.reviews"] = formatted_number(value) + elif key == "tomatoFresh": + result["rottentomatoes.fresh"] = value + elif key == "tomatoRotten": + result["rottentomatoes.rotten"] = value + elif key == "tomatoConsensus": + result["rottentomatoes.consensus"] = value + elif key == "tomatoUserMeter": + result["rottentomatoes.usermeter"] = value + elif key == "tomatoUserRating": + result["rottentomatoes.userrating"] = value + result["rottentomatoes.userrating.percent"] = "%s" % (try_parse_int(float(value) * 10)) + elif key == "tomatoUserReviews": + result["rottentomatoes.userreviews"] = int_with_commas(value) + elif key == "tomatoeURL": + result["rottentomatoes.url"] = value + return result diff --git a/resources/lib/metadatautils/helpers/pvrartwork.py b/resources/lib/metadatautils/helpers/pvrartwork.py new file mode 100644 index 0000000..43eb609 --- /dev/null +++ b/resources/lib/metadatautils/helpers/pvrartwork.py @@ -0,0 +1,544 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" + script.module.metadatautils + pvrartwork.py + Get metadata for Kodi PVR programs +""" + +import os, sys +if sys.version_info.major == 3: + from .utils import get_clean_image, DialogSelect, log_msg, extend_dict, ADDON_ID, download_artwork, normalize_string + from urllib.parse import quote_plus +else: + from utils import get_clean_image, DialogSelect, log_msg, extend_dict, ADDON_ID, download_artwork, normalize_string + from urllib import quote_plus +import xbmc +import xbmcgui +import xbmcvfs +from difflib import SequenceMatcher as SM +from operator import itemgetter +import re +from datetime import timedelta + + +class PvrArtwork(object): + """get artwork for kodi pvr""" + + def __init__(self, metadatautils): + """Initialize - optionaly provide our base MetadataUtils class""" + self._mutils = metadatautils + self.cache = self._mutils.cache + + def get_pvr_artwork(self, title, channel, genre="", manual_select=False, ignore_cache=False): + """ + collect full metadata and artwork for pvr entries + parameters: title (required) + channel: channel name (required) + year: year or date (optional) + genre: (optional) + the more optional parameters are supplied, the better the search results + """ + details = {"art": {}} + # try cache first + + # use searchtitle when searching cache + cache_title = title.lower() + cache_channel = channel.lower() + searchtitle = self.get_searchtitle(cache_title, cache_channel) + # original cache_str assignment cache_str = "pvr_artwork.%s.%s" % (title.lower(), channel.lower()) + cache_str = "pvr_artwork.%s.%s" % (searchtitle, channel.lower()) + cache = self._mutils.cache.get(cache_str) + if cache and not manual_select and not ignore_cache: + log_msg("get_pvr_artwork - return data from cache - %s" % cache_str) + details = cache + else: + # no cache - start our lookup adventure + log_msg("get_pvr_artwork - no data in cache - start lookup - %s" % cache_str) + + # workaround for recordings + recordingdetails = self.lookup_local_recording(title, 
channel) + if recordingdetails and not (channel and genre): + genre = recordingdetails["genre"] + channel = recordingdetails["channel"] + + details["pvrtitle"] = title + details["pvrchannel"] = channel + details["pvrgenre"] = genre + details["cachestr"] = cache_str + details["media_type"] = "" + details["art"] = {} + + # filter genre unknown/other + if not genre or genre.split(" / ")[0] in xbmc.getLocalizedString(19499).split(" / "): + details["genre"] = [] + genre = "" + log_msg("genre is unknown so ignore....") + else: + details["genre"] = genre.split(" / ") + details["media_type"] = self.get_mediatype_from_genre(genre) + searchtitle = self.get_searchtitle(title, channel) + + # only continue if we pass our basic checks + filterstr = self.pvr_proceed_lookup(title, channel, genre, recordingdetails) + proceed_lookup = False if filterstr else True + if not proceed_lookup and manual_select: + # warn user about active skip filter + proceed_lookup = xbmcgui.Dialog().yesno( + message=self._mutils.addon.getLocalizedString(32027), line2=filterstr, + heading=xbmc.getLocalizedString(750)) + + if proceed_lookup: + + # if manual lookup get the title from the user + if manual_select: + if sys.version_info.major == 3: + searchtitle = xbmcgui.Dialog().input(xbmc.getLocalizedString(16017), searchtitle, + type=xbmcgui.INPUT_ALPHANUM) + else: + searchtitle = xbmcgui.Dialog().input(xbmc.getLocalizedString(16017), searchtitle, + type=xbmcgui.INPUT_ALPHANUM).decode("utf-8") + if not searchtitle: + return + + # if manual lookup and no mediatype, ask the user + if manual_select and not details["media_type"]: + yesbtn = self._mutils.addon.getLocalizedString(32042) + nobtn = self._mutils.addon.getLocalizedString(32043) + header = self._mutils.addon.getLocalizedString(32041) + if xbmcgui.Dialog().yesno(header, header, yeslabel=yesbtn, nolabel=nobtn): + details["media_type"] = "movie" + else: + details["media_type"] = "tvshow" + + # append thumb from recordingdetails + if recordingdetails and recordingdetails.get("thumbnail"): + details["art"]["thumb"] = recordingdetails["thumbnail"] + # lookup custom path + details = extend_dict(details, self.lookup_custom_path(searchtitle, title)) + # lookup movie/tv library + details = extend_dict(details, self.lookup_local_library(searchtitle, details["media_type"])) + + # do internet scraping if enabled + if self._mutils.addon.getSetting("pvr_art_scraper") == "true": + + log_msg( + "pvrart start scraping metadata for title: %s - media_type: %s" % + (searchtitle, details["media_type"])) + + # prefer tmdb scraper + tmdb_result = self._mutils.get_tmdb_details( + "", "", searchtitle, "", "", details["media_type"], + manual_select=manual_select, ignore_cache=manual_select) + log_msg("pvrart lookup for title: %s - TMDB result: %s" % (searchtitle, tmdb_result)) + if tmdb_result: + details["media_type"] = tmdb_result["media_type"] + details = extend_dict(details, tmdb_result) + + # fallback to tvdb scraper + # following 3 lines added as part of "auto refresh" fix. ensure manual_select=true for TVDB lookup. No idea why this works + tempmanualselect = manual_select + manual_select="true" + log_msg("DEBUG INFO: TVDB lookup: searchtitle: %s channel: %s manual_select: %s" %(searchtitle, channel, manual_select)) + if (not tmdb_result or (tmdb_result and not tmdb_result.get("art")) or + details["media_type"] == "tvshow"): + # original code: tvdb_match = self.lookup_tvdb(searchtitle, channel, manual_select=manual_select). part of "auto refresh" fix. 
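+ # note: forcing manual_select to a truthy value above presumably works because + # lookup_tvdb only uses that flag to keep low-scoring candidates in its + # candidate list; tempmanualselect carries the caller's original choice and + # is what actually gates the manual selection dialog inside lookup_tvdb.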
+ tvdb_match = self.lookup_tvdb(searchtitle, channel, manual_select=manual_select, tempmanualselect=tempmanualselect) + log_msg("pvrart lookup for title: %s - TVDB result: %s" % (searchtitle, tvdb_match)) + if tvdb_match: + # get full tvdb results and extend with tmdb + if not details["media_type"]: + details["media_type"] = "tvshow" + details = extend_dict(details, self._mutils.thetvdb.get_series(tvdb_match)) + details = extend_dict(details, self._mutils.tmdb.get_videodetails_by_externalid( + tvdb_match, "tvdb_id"), ["poster", "fanart"]) + # part of "auto refresh" fix - revert manual_select to original value + manual_select = tempmanualselect + # fanart.tv scraping - append result to existing art + if details.get("imdbnumber") and details["media_type"] == "movie": + details["art"] = extend_dict( + details["art"], self._mutils.fanarttv.movie( + details["imdbnumber"]), [ + "poster", "fanart", "landscape"]) + elif details.get("tvdb_id") and details["media_type"] == "tvshow": + details["art"] = extend_dict( + details["art"], self._mutils.fanarttv.tvshow( + details["tvdb_id"]), [ + "poster", "fanart", "landscape"]) + + # append omdb details + if details.get("imdbnumber"): + details = extend_dict( + details, self._mutils.omdb.get_details_by_imdbid( + details["imdbnumber"]), [ + "rating", "votes"]) + + # set thumbnail - prefer scrapers + thumb = "" + if details.get("thumbnail"): + thumb = details["thumbnail"] + elif details["art"].get("landscape"): + thumb = details["art"]["landscape"] + elif details["art"].get("fanart"): + thumb = details["art"]["fanart"] + elif details["art"].get("poster"): + thumb = details["art"]["poster"] + # use google images as last-resort fallback for thumbs - if enabled + elif self._mutils.addon.getSetting("pvr_art_google") == "true": + if manual_select: + google_title = searchtitle + else: + google_title = '%s %s' % (searchtitle, "imdb") + thumb = self._mutils.google.search_image(google_title, manual_select) + if thumb: + details["thumbnail"] = thumb + details["art"]["thumb"] = thumb + # extrafanart + if details["art"].get("fanarts"): + for count, item in enumerate(details["art"]["fanarts"]): + details["art"]["fanart.%s" % count] = item + if not details["art"].get("extrafanart") and len(details["art"]["fanarts"]) > 1: + details["art"]["extrafanart"] = "plugin://script.skin.helper.service/"\ + "?action=extrafanart&fanarts=%s" % quote_plus(repr(details["art"]["fanarts"])) + + # download artwork to custom folder + if self._mutils.addon.getSetting("pvr_art_download") == "true": + details["art"] = download_artwork(self.get_custom_path(searchtitle, title), details["art"]) + + log_msg("pvrart lookup for title: %s - final result: %s" % (searchtitle, details)) + + # always store result in cache + # manual lookups should not expire too often + if manual_select: + self._mutils.cache.set(cache_str, details, expiration=timedelta(days=365)) + else: + self._mutils.cache.set(cache_str, details, expiration=timedelta(days=365)) + return details + + def manual_set_pvr_artwork(self, title, channel, genre): + """manual override artwork options""" + + details = self.get_pvr_artwork(title, channel, genre) + cache_str = details["cachestr"] + + # show dialogselect with all artwork option + from .utils import manual_set_artwork + changemade, artwork = manual_set_artwork(details["art"], "pvr") + if changemade: + details["art"] = artwork + # save results in cache + self._mutils.cache.set(cache_str, details, expiration=timedelta(days=365)) + + def pvr_artwork_options(self, title, channel, genre): 
+ """show options for pvr artwork""" + if not channel and genre: + channel, genre = self.get_pvr_channel_and_genre(title) + ignorechannels = self._mutils.addon.getSetting("pvr_art_ignore_channels").split("|") + ignoretitles = self._mutils.addon.getSetting("pvr_art_ignore_titles").split("|") + options = [] + options.append(self._mutils.addon.getLocalizedString(32028)) # Refresh item (auto lookup) + options.append(self._mutils.addon.getLocalizedString(32029)) # Refresh item (manual lookup) + options.append(self._mutils.addon.getLocalizedString(32036)) # Choose art + if channel in ignorechannels: + options.append(self._mutils.addon.getLocalizedString(32030)) # Remove channel from ignore list + else: + options.append(self._mutils.addon.getLocalizedString(32031)) # Add channel to ignore list + if title in ignoretitles: + options.append(self._mutils.addon.getLocalizedString(32032)) # Remove title from ignore list + else: + options.append(self._mutils.addon.getLocalizedString(32033)) # Add title to ignore list + options.append(self._mutils.addon.getLocalizedString(32034)) # Open addon settings + header = self._mutils.addon.getLocalizedString(32035) + dialog = xbmcgui.Dialog() + ret = dialog.select(header, options) + del dialog + if ret == 0: + # Refresh item (auto lookup) + self.get_pvr_artwork(title=title, channel=channel, genre=genre, ignore_cache=True, manual_select=False) + elif ret == 1: + # Refresh item (manual lookup) + self.get_pvr_artwork(title=title, channel=channel, genre=genre, ignore_cache=True, manual_select=True) + elif ret == 2: + # Choose art + self.manual_set_pvr_artwork(title, channel, genre) + elif ret == 3: + # Add/remove channel to ignore list + if channel in ignorechannels: + ignorechannels.remove(channel) + else: + ignorechannels.append(channel) + ignorechannels_str = "|".join(ignorechannels) + self._mutils.addon.setSetting("pvr_art_ignore_channels", ignorechannels_str) + self.get_pvr_artwork(title=title, channel=channel, genre=genre, ignore_cache=True, manual_select=False) + elif ret == 4: + # Add/remove title to ignore list + if title in ignoretitles: + ignoretitles.remove(title) + else: + ignoretitles.append(title) + ignoretitles_str = "|".join(ignoretitles) + self._mutils.addon.setSetting("pvr_art_ignore_titles", ignoretitles_str) + self.get_pvr_artwork(title=title, channel=channel, genre=genre, ignore_cache=True, manual_select=False) + elif ret == 5: + # Open addon settings + xbmc.executebuiltin("Addon.OpenSettings(%s)" % ADDON_ID) + + def pvr_proceed_lookup(self, title, channel, genre, recordingdetails): + """perform some checks if we can proceed with the lookup""" + filters = [] + if not title: + filters.append("Title is empty") + for item in self._mutils.addon.getSetting("pvr_art_ignore_titles").split("|"): + if item and item.lower() == title.lower(): + filters.append("Title is in list of titles to ignore") + for item in self._mutils.addon.getSetting("pvr_art_ignore_channels").split("|"): + if item and item.lower() == channel.lower(): + filters.append("Channel is in list of channels to ignore") + for item in self._mutils.addon.getSetting("pvr_art_ignore_genres").split("|"): + if genre and item and item.lower() in genre.lower(): + filters.append("Genre is in list of genres to ignore") + if self._mutils.addon.getSetting("pvr_art_ignore_commongenre") == "true": + # skip common genres like sports, weather, news etc. 
+ genre = genre.lower() + kodi_strings = [19516, 19517, 19518, 19520, 19548, 19549, 19551, + 19552, 19553, 19554, 19555, 19556, 19557, 19558, 19559] + for kodi_string in kodi_strings: + kodi_string = xbmc.getLocalizedString(kodi_string).lower() + if (genre and (genre in kodi_string or kodi_string in genre)) or kodi_string in title: + filters.append("Common genres like weather/sports are set to be ignored") + if self._mutils.addon.getSetting("pvr_art_recordings_only") == "true" and not recordingdetails: + filters.append("PVR Artwork is enabled for recordings only") + if filters: + filterstr = " - ".join(filters) + log_msg("PVR artwork - filter active for title: %s - channel %s --> %s" % (title, channel, filterstr)) + return filterstr + else: + return "" + + @staticmethod + def get_mediatype_from_genre(genre): + """guess media type from genre for better matching""" + media_type = "" + if "movie" in genre.lower() or "film" in genre.lower(): + media_type = "movie" + if "show" in genre.lower(): + media_type = "tvshow" + if not media_type: + # Kodi defined movie genres + kodi_genres = [19500, 19507, 19508, 19602, 19603] + for kodi_genre in kodi_genres: + if xbmc.getLocalizedString(kodi_genre) in genre: + media_type = "movie" + break + if not media_type: + # Kodi defined tvshow genres + kodi_genres = [19505, 19516, 19517, 19518, 19520, 19532, 19533, 19534, 19535, 19548, 19549, + 19550, 19551, 19552, 19553, 19554, 19555, 19556, 19557, 19558, 19559] + for kodi_genre in kodi_genres: + if xbmc.getLocalizedString(kodi_genre) in genre: + media_type = "tvshow" + break + return media_type + + def get_searchtitle(self, title, channel): + """common logic to get a proper searchtitle from crappy titles provided by pvr""" + if sys.version_info.major < 3: + if not isinstance(title, unicode): + title = title.decode("utf-8") + title = title.lower() + # split characters - split on common splitters + if sys.version_info.major == 3: + splitters = self._mutils.addon.getSetting("pvr_art_splittitlechar").split("|") + else: + splitters = self._mutils.addon.getSetting("pvr_art_splittitlechar").decode("utf-8").split("|") + if channel: + splitters.append(" %s" % channel.lower()) + for splitchar in splitters: + title = title.split(splitchar)[0] + # replace common chars and words + if sys.version_info.major == 3: + title = re.sub(self._mutils.addon.getSetting("pvr_art_replace_by_space"), ' ', title) + # following line removed as always seems to return blanks. 
also addon settings changed to replace ": " with " " + # title = re.sub(self._mutils.addon.getSetting("pvr_art_stripchars"), '', title) + else: + title = re.sub(self._mutils.addon.getSetting("pvr_art_replace_by_space").decode("utf-8"), ' ', title) + title = re.sub(self._mutils.addon.getSetting("pvr_art_stripchars").decode("utf-8"), '', title) + title = title.strip() + return title + + def lookup_local_recording(self, title, channel): + """lookup actual recordings to get details for grouped recordings + also grab a thumb provided by the pvr + """ + cache = self._mutils.cache.get("recordingdetails.%s%s" % (title, channel)) + if cache: + return cache + details = {} + recordings = self._mutils.kodidb.recordings() + for item in recordings: + if (title == item["title"] or title in item["file"]) and (channel == item["channel"] or not channel): + # grab thumb from pvr + if item.get("art"): + details["thumbnail"] = get_clean_image(item["art"].get("thumb")) + # ignore tvheadend thumb as it returns the channellogo + elif item.get("icon") and "imagecache" not in item["icon"]: + details["thumbnail"] = get_clean_image(item["icon"]) + details["channel"] = item["channel"] + details["genre"] = " / ".join(item["genre"]) + break + self._mutils.cache.set("recordingdetails.%s%s" % (title, channel), details) + return details + + # original code: def lookup_tvdb(self, searchtitle, channel, manual_select=False):. part of "auto refesh fix". + def lookup_tvdb(self, searchtitle, channel, manual_select=False, tempmanualselect=False): + """helper to select a match on tvdb""" + tvdb_match = None + searchtitle = searchtitle.lower() + tvdb_result = self._mutils.thetvdb.search_series(searchtitle, True) + searchchannel = channel.lower().split("hd")[0].replace(" ", "") + if " FHD" in channel: + searchchannel = channel.lower().split("fhd")[0].replace(" ", "") + if " HD" in channel: + searchchannel = channel.lower().split("hd")[0].replace(" ", "") + if " SD" in channel: + searchchannel = channel.lower().split("sd")[0].replace(" ", "") + match_results = [] + if tvdb_result: + for item in tvdb_result: + item["score"] = 0 + if not item["seriesName"]: + continue # seriesname can be None in some conditions + itemtitle = item["seriesName"].lower() + if not item["network"]: + continue # network can be None in some conditions + network = item["network"].lower().replace(" ", "") + # high score if channel name matches + if network in searchchannel or searchchannel in network: + item["score"] += 800 + # exact match on title - very high score + if searchtitle == itemtitle: + item["score"] += 1000 + # match title by replacing some characters + if re.sub('\*|,|.\"|\'| |:|;', '', searchtitle) == re.sub('\*|,|.\"|\'| |:|;', '', itemtitle): + item["score"] += 750 + # add SequenceMatcher score to the results + stringmatchscore = SM(None, searchtitle, itemtitle).ratio() + if stringmatchscore > 0.7: + item["score"] += stringmatchscore * 500 + # prefer items with artwork + if item["banner"]: + item["score"] += 1 + if item["score"] > 500 or manual_select: + match_results.append(item) + # sort our new list by score + match_results = sorted(match_results, key=itemgetter("score"), reverse=True) + # original code: if match_results and manual_select:. part of "auto refresh" fix. 
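+ # note: scores are cumulative per candidate (network match, exact and fuzzy + # title matches, artwork bonus); only candidates scoring above 500 are kept + # unless manual_select was passed, in which case every candidate is collected + # and offered in the selection dialog below.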
+ if match_results and manual_select and tempmanualselect: + # show selectdialog to manually select the item + listitems = [] + for item in match_results: + thumb = "http://thetvdb.com%s" % item["poster"] if item["poster"] else "" + listitem = xbmcgui.ListItem(label=item["seriesName"]) + listitem.setArt({'icon': thumb}) + listitems.append(listitem) + dialog = DialogSelect( + "DialogSelect.xml", + "", + listing=listitems, + window_title="%s - TVDB" % + xbmc.getLocalizedString(283)) + dialog.doModal() + selected_item = dialog.result + del dialog + if selected_item != -1: + tvdb_match = match_results[selected_item]["id"] + else: + match_results = [] + if not tvdb_match and match_results: + # just grab the first item as best match + tvdb_match = match_results[0]["id"] + return tvdb_match + + def get_custom_path(self, searchtitle, title): + """locate custom folder on disk as pvrart location""" + title_path = "" + custom_path = self._mutils.addon.getSetting("pvr_art_custom_path") + if custom_path and self._mutils.addon.getSetting("pvr_art_custom") == "true": + delim = "\\" if "\\" in custom_path else "/" + dirs = xbmcvfs.listdir(custom_path)[0] + for strictness in [1, 0.95, 0.9, 0.8]: + if title_path: + break + for directory in dirs: + if title_path: + break + if sys.version_info.major < 3: + directory = directory.decode("utf-8") + curpath = os.path.join(custom_path, directory) + delim + for item in [title, searchtitle]: + match = SM(None, item, directory).ratio() + if match >= strictness: + title_path = curpath + break + if not title_path and self._mutils.addon.getSetting("pvr_art_download") == "true": + title_path = os.path.join(custom_path, normalize_string(title)) + delim + return title_path + + def lookup_custom_path(self, searchtitle, title): + """looks up a custom directory if it contains a subdir for our title""" + details = {} + details["art"] = {} + title_path = self.get_custom_path(searchtitle, title) + if title_path and xbmcvfs.exists(title_path): + # we have found a folder for the title, look for artwork + files = xbmcvfs.listdir(title_path)[1] + for item in files: + if sys.version_info.major < 3: + item = item.decode("utf-8") + if item in ["banner.jpg", "clearart.png", "poster.jpg", "disc.png", "characterart.png", + "fanart.jpg", "landscape.jpg"]: + key = item.split(".")[0] + details["art"][key] = title_path + item + elif item == "logo.png": + details["art"]["clearlogo"] = title_path + item + elif item == "thumb.jpg": + details["art"]["thumb"] = title_path + item + # extrafanarts + efa_path = title_path + "extrafanart/" + if xbmcvfs.exists(title_path + "extrafanart"): + files = xbmcvfs.listdir(efa_path)[1] + details["art"]["fanarts"] = [] + if files: + details["art"]["extrafanart"] = efa_path + for item in files: + if sys.version_info.major == 3: + item = efa_path + item + else: + item = efa_path + item.decode("utf-8") + details["art"]["fanarts"].append(item) + return details + + def lookup_local_library(self, title, media_type): + """lookup the title in the local video db""" + details = {} + filters = [{"operator": "is", "field": "title", "value": title}] + if not media_type or media_type == "tvshow": + kodi_items = self._mutils.kodidb.tvshows(filters=filters, limits=(0, 1)) + if kodi_items: + details = kodi_items[0] + details["media_type"] = "tvshow" + if not details and (not media_type or media_type == "movie"): + kodi_items = self._mutils.kodidb.movies(filters=filters, limits=(0, 1)) + if kodi_items: + details = kodi_items[0] + details["media_type"] = "movie" + if details: + if 
sys.version_info.major == 3: + for artkey, artvalue in details["art"].items(): + details["art"][artkey] = get_clean_image(artvalue) + else: + for artkey, artvalue in details["art"].iteritems(): + details["art"][artkey] = get_clean_image(artvalue) + # todo: check extrafanart ? + return details diff --git a/resources/lib/metadatautils/helpers/streamdetails.py b/resources/lib/metadatautils/helpers/streamdetails.py new file mode 100644 index 0000000..60d39b5 --- /dev/null +++ b/resources/lib/metadatautils/helpers/streamdetails.py @@ -0,0 +1,103 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" + script.module.metadatautils + streamdetails.py + Get all streamdetails for a kodi media item in database +""" + +import os, sys + +def get_streamdetails(kodidb, db_id, media_type): + """helper to get all streamdetails from a video item in kodi db""" + streamdetails = {} + # get data from json + if "movie" in media_type and "movieset" not in media_type: + json_result = kodidb.movie(db_id) + elif "episode" in media_type: + json_result = kodidb.episode(db_id) + elif "musicvideo" in media_type: + json_result = kodidb.musicvideo(db_id) + else: + json_result = {} + + if json_result and json_result["streamdetails"]: + audio = json_result["streamdetails"]['audio'] + subtitles = json_result["streamdetails"]['subtitle'] + video = json_result["streamdetails"]['video'] + all_audio_str = [] + all_subs = [] + all_lang = [] + for count, item in enumerate(audio): + # audio codec + codec = item['codec'] + if "ac3" in codec: + codec = u"Dolby D" + elif "dca" in codec: + codec = u"DTS" + elif "dts-hd" in codec or "dtshd" in codec: + codec = u"DTS HD" + # audio channels + channels = item['channels'] + if channels == 1: + channels = u"1.0" + elif channels == 2: + channels = u"2.0" + elif channels == 3: + channels = u"2.1" + elif channels == 4: + channels = u"4.0" + elif channels == 5: + channels = u"5.0" + elif channels == 6: + channels = u"5.1" + elif channels == 7: + channels = u"6.1" + elif channels == 8: + channels = u"7.1" + elif channels == 9: + channels = u"8.1" + elif channels == 10: + channels = u"9.1" + else: + channels = str(channels) + # audio language + language = item.get('language', '') + if language and language not in all_lang: + all_lang.append(language) + if language: + streamdetails['AudioStreams.%d.Language' % count] = item['language'] + if item['codec']: + streamdetails['AudioStreams.%d.AudioCodec' % count] = item['codec'] + if item['channels']: + streamdetails['AudioStreams.%d.AudioChannels' % count] = str(item['channels']) + if sys.version_info.major == 3: + joinchar = " • " + else: + joinchar = " • ".decode("utf-8") + audio_str = joinchar.join([language, codec, channels]) + if audio_str: + streamdetails['AudioStreams.%d' % count] = audio_str + all_audio_str.append(audio_str) + subs_count = 0 + subs_count_unique = 0 + for item in subtitles: + subs_count += 1 + if item['language'] not in all_subs: + all_subs.append(item['language']) + streamdetails['Subtitles.%d' % subs_count_unique] = item['language'] + subs_count_unique += 1 + streamdetails['subtitles'] = all_subs + streamdetails['subtitles.count'] = str(subs_count) + streamdetails['allaudiostreams'] = all_audio_str + streamdetails['audioStreams.count'] = str(len(all_audio_str)) + streamdetails['languages'] = all_lang + streamdetails['languages.count'] = len(all_lang) + if len(video) > 0: + stream = video[0] + streamdetails['videoheight'] = stream.get("height", 0) + streamdetails['videowidth'] = stream.get("width", 0) + if 
json_result.get("tag"): + streamdetails["tags"] = json_result["tag"] + return streamdetails diff --git a/resources/lib/metadatautils/helpers/studiologos.py b/resources/lib/metadatautils/helpers/studiologos.py new file mode 100644 index 0000000..154a5a2 --- /dev/null +++ b/resources/lib/metadatautils/helpers/studiologos.py @@ -0,0 +1,112 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +"""Helper for studio logo images""" + +import xbmcvfs +import os, sys +from datetime import timedelta +from simplecache import use_cache +if sys.version_info.major == 3: + from .utils import try_decode +else: + from utils import try_decode + + +class StudioLogos(): + """Helper class for studio logo images""" + + def __init__(self, simplecache=None): + """Initialize - optionaly provide simplecache object""" + if not simplecache: + from simplecache import SimpleCache + self.cache = SimpleCache() + else: + self.cache = simplecache + + @use_cache(14) + def get_studio_logo(self, studios, lookup_path): + """get the studio logo for the given studio string(s)""" + if not studios: + return {} + result = {} + if not isinstance(studios, list): + studios = studios.split(" / ") + result["Studio"] = studios[0] + result['Studios'] = "[CR]".join(studios) + result['StudioLogo'] = self.match_studio_logo(studios, self.get_studio_logos(lookup_path)) + return result + + def get_studio_logos(self, lookup_path): + """get all studio logos""" + cache_str = u"SkinHelper.StudioLogos" + cache = self.cache.get(cache_str, checksum=lookup_path) + if cache: + return cache + # no cache - start lookup + all_logos = {} + if lookup_path.startswith("resource://"): + all_logos = self.get_resource_addon_files(lookup_path) + else: + if not (lookup_path.endswith("/") or lookup_path.endswith("\\")): + lookup_path = lookup_path + os.sep + all_logos = self.list_files_in_path(lookup_path) + # save in cache and return + self.cache.set(cache_str, all_logos, expiration=timedelta(days=14), checksum=lookup_path) + return all_logos + + @staticmethod + def match_studio_logo(studios, studiologos): + """try to find a matching studio logo""" + studiologo = "" + for studio in studios: + if studiologo: + break + studio = studio.lower() + # find logo normal + if studio in studiologos: + studiologo = studiologos[studio] + if not studiologo: + # find logo by substituting characters + if " (" in studio: + studio = studio.split(" (")[0] + if studio in studiologos: + studiologo = studiologos[studio] + if not studiologo: + # find logo by substituting characters for pvr channels + if " HD" in studio: + studio = studio.replace(" HD", "") + elif " " in studio: + studio = studio.replace(" ", "") + if studio in studiologos: + studiologo = studiologos[studio] + return studiologo + + @use_cache(90) + def get_resource_addon_files(self, resourcepath): + """get listing of all files (eg studio logos) inside a resource image addonName + read data from our permanent cache file to prevent that we have to query the resource addon""" + return self.list_files_in_path(resourcepath) + + @staticmethod + def list_files_in_path(filespath): + """used for easy matching of studio logos""" + all_files = {} + dirs, files = xbmcvfs.listdir(filespath) + if "/" in filespath: + sep = "/" + else: + sep = "\\" + for file in files: + file = try_decode(file) + name = file.split(".png")[0].lower() + all_files[name] = filespath + file + for directory in dirs: + directory = try_decode(directory) + files = xbmcvfs.listdir(os.path.join(filespath, directory) + sep)[1] + for file in files: + file = 
try_decode(file) + name = directory + "/" + file.split(".png")[0].lower() + all_files[name] = filespath + directory + sep + file + # return the list + return all_files diff --git a/resources/lib/metadatautils/helpers/theaudiodb.py b/resources/lib/metadatautils/helpers/theaudiodb.py new file mode 100644 index 0000000..4ca395c --- /dev/null +++ b/resources/lib/metadatautils/helpers/theaudiodb.py @@ -0,0 +1,200 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" + script.module.metadatautils + theaudiodb.py + Get metadata from theaudiodb +""" + +import os, sys +if sys.version_info.major == 3: + from .utils import get_json, strip_newlines, KODI_LANGUAGE, get_compare_string, ADDON_ID +else: + from utils import get_json, strip_newlines, KODI_LANGUAGE, get_compare_string, ADDON_ID +from simplecache import use_cache +import xbmcvfs +import xbmcaddon + +class TheAudioDb(object): + """get metadata from the audiodb""" + api_key = None # public var to be set by the calling addon + ignore_cache = False + + def __init__(self, simplecache=None): + """Initialize - optionaly provide simplecache object""" + if not simplecache: + from simplecache import SimpleCache + self.cache = SimpleCache() + else: + self.cache = simplecache + addon = xbmcaddon.Addon(id=ADDON_ID) + api_key = addon.getSetting("adb_apikey") + if api_key: + self.api_key = api_key + del addon + + def search(self, artist, album, track): + """get musicbrainz id by query of artist, album and/or track""" + artist = artist.lower() + params = {'s': artist, 'a': album} + data = self.get_data("searchalbum.php", params) + if not album and track: + params = {'t': track, 's': artist} + data = self.get_data("searchtrack.php", params) + if data and data.get("track") and len(data.get("track")) > 0: + adbdetails = data["track"][0] + # safety check - only allow exact artist match + foundartist = adbdetails.get("strArtist", "").lower() + if foundartist and get_compare_string(foundartist) == get_compare_string(artist): + album = adbdetails.get("strAlbum", "") + artist = adbdetails.get("strArtist", "") + if data and data.get("album") and len(data.get("album")) > 0: + adbdetails = data["album"][0] + # safety check - only allow exact artist match + foundartist = adbdetails.get("strArtist", "").lower() + if foundartist and get_compare_string(foundartist) == get_compare_string(artist): + album = adbdetails.get("strAlbum", "") + artist = adbdetails.get("strArtist", "") + return artist, album + + def get_artist_id(self, artist, album, track): + """get musicbrainz id by query of artist, album and/or track""" + return self.search(artist, album, track)[0] + + def get_album_id(self, artist, album, track): + """get musicbrainz id by query of artist, album and/or track""" + return self.search(artist, album, track)[1] + + def artist_info(self, artist): + """get artist metadata by artist""" + details = {"art": {}} + data = self.get_data("/search.php", {'s': artist}) + if data and data.get("artists"): + adbdetails = data["artists"][0] + if adbdetails.get("strArtistBanner") and xbmcvfs.exists(adbdetails.get("strArtistBanner")): + details["art"]["banner"] = adbdetails.get("strArtistBanner") + details["art"]["banners"] = [adbdetails.get("strArtistBanner")] + details["art"]["fanarts"] = [] + if adbdetails.get("strArtistFanart") and xbmcvfs.exists(adbdetails.get("strArtistFanart")): + details["art"]["fanart"] = adbdetails.get("strArtistFanart") + details["art"]["fanarts"].append(adbdetails.get("strArtistFanart")) + if adbdetails.get("strArtistFanart2") and 
xbmcvfs.exists(adbdetails.get("strArtistFanart2")): + details["art"]["fanarts"].append(adbdetails.get("strArtistFanart2")) + if adbdetails.get("strArtistFanart3") and xbmcvfs.exists(adbdetails.get("strArtistFanart3")): + details["art"]["fanarts"].append(adbdetails.get("strArtistFanart3")) + if adbdetails.get("strArtistWideThumb") and xbmcvfs.exists(adbdetails.get("strArtistWideThumb")): + details["art"]["landscape"] = adbdetails.get("strArtistWideThumb") + if adbdetails.get("strArtistLogo") and xbmcvfs.exists(adbdetails.get("strArtistLogo")): + details["art"]["clearlogo"] = adbdetails.get("strArtistLogo") + details["art"]["clearlogos"] = [adbdetails.get("strArtistLogo")] + if adbdetails.get("strArtistClearart") and xbmcvfs.exists(adbdetails.get("strArtistClearart")): + details["art"]["clearart"] = adbdetails.get("strArtistClearart") + details["art"]["cleararts"] = [adbdetails.get("strArtistClearart")] + if adbdetails.get("strArtistThumb") and xbmcvfs.exists(adbdetails.get("strArtistThumb")): + details["art"]["thumb"] = adbdetails["strArtistThumb"] + details["art"]["thumbs"] = [adbdetails["strArtistThumb"]] + if adbdetails.get("strBiography" + KODI_LANGUAGE.upper()): + details["plot"] = adbdetails["strBiography" + KODI_LANGUAGE.upper()] + if adbdetails.get("strBiographyEN") and not details.get("plot"): + details["plot"] = adbdetails.get("strBiographyEN") + if details.get("plot"): + details["plot"] = strip_newlines(details["plot"]) + if adbdetails.get("strArtistAlternate"): + details["alternamename"] = adbdetails["strArtistAlternate"] + if adbdetails.get("intFormedYear"): + details["formed"] = adbdetails["intFormedYear"] + if adbdetails.get("intBornYear"): + details["born"] = adbdetails["intBornYear"] + if adbdetails.get("intDiedYear"): + details["died"] = adbdetails["intDiedYear"] + if adbdetails.get("strDisbanded"): + details["disbanded"] = adbdetails["strDisbanded"] + if adbdetails.get("strStyle"): + details["style"] = adbdetails["strStyle"].split("/") + if adbdetails.get("strGenre"): + details["genre"] = adbdetails["strGenre"].split("/") + if adbdetails.get("strMood"): + details["mood"] = adbdetails["strMood"].split("/") + if adbdetails.get("strWebsite"): + details["homepage"] = adbdetails["strWebsite"] + if adbdetails.get("strFacebook"): + details["facebook"] = adbdetails["strFacebook"] + if adbdetails.get("strTwitter"): + details["twitter"] = adbdetails["strTwitter"] + if adbdetails.get("strGender"): + details["gender"] = adbdetails["strGender"] + if adbdetails.get("intMembers"): + details["members"] = adbdetails["intMembers"] + if adbdetails.get("strCountry"): + details["country"] = adbdetails["strCountry"].split(", ") + return details + + def album_info(self, artist, album): + """get album metadata by name""" + details = {"art": {}} + data = self.get_data("/searchalbum.php", {'s': artist, 'a': album}) + if data and data.get("album"): + adbdetails = data["album"][0] + if adbdetails.get("strAlbumThumb") and xbmcvfs.exists(adbdetails.get("strAlbumThumb")): + details["art"]["thumb"] = adbdetails.get("strAlbumThumb") + details["art"]["thumbs"] = [adbdetails.get("strAlbumThumb")] + if adbdetails.get("strAlbumCDart") and xbmcvfs.exists(adbdetails.get("strAlbumCDart")): + details["art"]["discart"] = adbdetails.get("strAlbumCDart") + details["art"]["discarts"] = [adbdetails.get("strAlbumCDart")] + if adbdetails.get("strAlbumSpine") and xbmcvfs.exists(adbdetails.get("strAlbumSpine")): + details["art"]["spine"] = adbdetails.get("strAlbumSpine") + if adbdetails.get("strAlbumThumbBack") and 
xbmcvfs.exists(adbdetails.get("strAlbumThumbBack")): + details["art"]["thumbback"] = adbdetails.get("strAlbumThumbBack") + if adbdetails.get("strAlbum3DCase") and xbmcvfs.exists(adbdetails.get("strAlbum3DCase")): + details["art"]["album3Dcase"] = adbdetails.get("strAlbum3DCase") + if adbdetails.get("strAlbum3DFlat") and xbmcvfs.exists(adbdetails.get("strAlbum3DFlat")): + details["art"]["album3Dflat"] = adbdetails.get("strAlbum3DFlat") + if adbdetails.get("strAlbum3DFace") and xbmcvfs.exists(adbdetails.get("strAlbum3DFace")): + details["art"]["album3Dface"] = adbdetails.get("strAlbum3DFace") + if adbdetails.get("strAlbum3DThumb") and xbmcvfs.exists(adbdetails.get("strAlbum3DThumb")): + details["art"]["album3Dthumb"] = adbdetails.get("strAlbum3DThumb") + if adbdetails.get("strDescription%s" % KODI_LANGUAGE.upper()): + details["plot"] = adbdetails.get("strDescription%s" % KODI_LANGUAGE.upper()) + if not details.get("plot") and adbdetails.get("strDescriptionEN"): + details["plot"] = adbdetails.get("strDescriptionEN") + if details.get("plot"): + details["plot"] = strip_newlines(details["plot"]) + if adbdetails.get("strGenre"): + details["genre"] = adbdetails["strGenre"].split("/") + if adbdetails.get("strStyle"): + details["style"] = adbdetails["strStyle"].split("/") + if adbdetails.get("strMood"): + details["mood"] = adbdetails["strMood"].split("/") + if adbdetails.get("intYearReleased"): + details["year"] = adbdetails["intYearReleased"] + if adbdetails.get("intScore"): + details["rating"] = adbdetails["intScore"] + if adbdetails.get("strAlbum"): + details["title"] = adbdetails["strAlbum"] + if adbdetails.get("strLabel"): + details["albumlabel"] = adbdetails["strLabel"] + if adbdetails.get("idAlbum"): + details["idalbum"] = adbdetails["idAlbum"] + if adbdetails.get("idAlbum"): + idalbum = adbdetails.get("idAlbum", "") + data = self.get_data("/track.php", {'m': idalbum}) + adbtrackdetails = data["track"] + if data.get("track"): + tracks = [] + for count, item in enumerate(adbtrackdetails): + tracks.append(item["strTrack"]) + details["tracks.formatted.%s" % count] = item["intTrackNumber"] + "." 
+ item["strTrack"] + details["tracks.clean.formatted.%s" % count] = item["strTrack"] + details["tracks.formatted"] = "[CR]".join(tracks) + return details + + @use_cache(60) + def get_data(self, endpoint, params): + """helper method to get data from theaudiodb json API""" + endpoint = 'https://www.theaudiodb.com/api/v1/json/%s/%s' % (self.api_key, endpoint) + data = get_json(endpoint, params) + if data: + return data + else: + return {} diff --git a/resources/lib/metadatautils/helpers/tmdb.py b/resources/lib/metadatautils/helpers/tmdb.py new file mode 100644 index 0000000..f6aa01e --- /dev/null +++ b/resources/lib/metadatautils/helpers/tmdb.py @@ -0,0 +1,468 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" + script.module.metadatautils + tmdb.py + Get metadata from The Movie Database +""" + +import os, sys +if sys.version_info.major == 3: + from .utils import get_json, KODI_LANGUAGE, try_parse_int, DialogSelect, get_compare_string, int_with_commas, ADDON_ID +else: + from utils import get_json, KODI_LANGUAGE, try_parse_int, DialogSelect, get_compare_string, int_with_commas, ADDON_ID +from difflib import SequenceMatcher as SM +from simplecache import use_cache +from operator import itemgetter +import xbmc +import xbmcgui +import xbmcaddon +import datetime + + +class Tmdb(object): + """get metadata from tmdb""" + api_key = None # public var to be set by the calling addon + + def __init__(self, simplecache=None, api_key=None): + """Initialize - optionaly provide simplecache object""" + if not simplecache: + from simplecache import SimpleCache + self.cache = SimpleCache() + else: + self.cache = simplecache + addon = xbmcaddon.Addon(id=ADDON_ID) + # personal api key (preferred over provided api key) + api_key = addon.getSetting("tmdb_apikey") + if api_key: + self.api_key = api_key + del addon + + def search_movie(self, title, year="", manual_select=False, ignore_cache=False): + """ + Search tmdb for a specific movie, returns full details of best match + parameters: + title: (required) the title of the movie to search for + year: (optional) the year of the movie to search for (enhances search result if supplied) + manual_select: (optional) if True will show select dialog with all results + """ + details = self.select_best_match(self.search_movies(title, year), manual_select=manual_select) + if details: + details = self.get_movie_details(details["id"]) + return details + + @use_cache(30) + def search_movieset(self, title): + """search for movieset details providing the title of the set""" + details = {} + params = {"query": title, "language": KODI_LANGUAGE} + result = self.get_data("search/collection", params) + if result: + set_id = result[0]["id"] + details = self.get_movieset_details(set_id) + return details + + @use_cache(4) + def search_tvshow(self, title, year="", manual_select=False, ignore_cache=False): + """ + Search tmdb for a specific movie, returns full details of best match + parameters: + title: (required) the title of the movie to search for + year: (optional) the year of the movie to search for (enhances search result if supplied) + manual_select: (optional) if True will show select dialog with all results + """ + details = self.select_best_match(self.search_tvshows(title, year), manual_select=manual_select) + if details: + details = self.get_tvshow_details(details["id"]) + return details + + @use_cache(4) + def search_video(self, title, prefyear="", preftype="", manual_select=False, ignore_cache=False): + """ + Search tmdb for a specific entry (can be movie or tvshow), returns 
full details of best match + parameters: + title: (required) the title of the movie/tvshow to search for + prefyear: (optional) prefer result if year matches + preftype: (optional) prefer result if type matches + manual_select: (optional) if True will show select dialog with all results + """ + results = self.search_videos(title) + details = self.select_best_match(results, prefyear=prefyear, preftype=preftype, + preftitle=title, manual_select=manual_select) + if details and details["media_type"] == "movie": + details = self.get_movie_details(details["id"]) + elif details and "tv" in details["media_type"]: + details = self.get_tvshow_details(details["id"]) + return details + + @use_cache(4) + def search_videos(self, title): + """ + Search tmdb for a specific entry (can be movie or tvshow), parameters: + title: (required) the title of the movie/tvshow to search for + """ + results = [] + page = 1 + maxpages = 5 + while page < maxpages: + params = {"query": title, "language": KODI_LANGUAGE, "page": page} + subresults = self.get_data("search/multi", params) + page += 1 + if subresults: + for item in subresults: + if item["media_type"] in ["movie", "tv"]: + results.append(item) + else: + break + return results + + @use_cache(4) + def search_movies(self, title, year=""): + """ + Search tmdb for a specific movie, returns a list of all closest matches + parameters: + title: (required) the title of the movie to search for + year: (optional) the year of the movie to search for (enhances search result if supplied) + """ + params = {"query": title, "language": KODI_LANGUAGE} + if year: + params["year"] = try_parse_int(year) + return self.get_data("search/movie", params) + + @use_cache(4) + def search_tvshows(self, title, year=""): + """ + Search tmdb for a specific tvshow, returns a list of all closest matches + parameters: + title: (required) the title of the tvshow to search for + year: (optional) the first air date year of the tvshow to search for (enhances search result if supplied) + """ + params = {"query": title, "language": KODI_LANGUAGE} + if year: + params["first_air_date_year"] = try_parse_int(year) + return self.get_data("search/tv", params) + + def get_actor(self, name): + """ + Search tmdb for a specific actor/person, returns the best match as kodi compatible dict + required parameter: name --> the name of the person + """ + params = {"query": name, "language": KODI_LANGUAGE} + result = self.get_data("search/person", params) + if result: + result = result[0] + cast_thumb = "https://image.tmdb.org/t/p/original%s" % result[ + "profile_path"] if result["profile_path"] else "" + item = {"name": result["name"], + "thumb": cast_thumb, + "roles": [item["title"] if item.get("title") else item["name"] for item in result["known_for"]]} + return item + else: + return {} + + def get_movie_details(self, movie_id): + """get all moviedetails""" + params = { + "append_to_response": "keywords,videos,credits,images", + "include_image_language": "%s,en" % KODI_LANGUAGE, + "language": KODI_LANGUAGE + } + return self.map_details(self.get_data("movie/%s" % movie_id, params), "movie") + + def get_movieset_details(self, movieset_id): + """get all moviesetdetails""" + details = {"art": {}} + params = {"language": KODI_LANGUAGE} + result = self.get_data("collection/%s" % movieset_id, params) + if result: + details["title"] = result["name"] + details["plot"] = result["overview"] + details["tmdb_id"] = result["id"] + details["art"]["poster"] = "https://image.tmdb.org/t/p/original%s" % result["poster_path"] + 
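# note: "original" is TMDB's full-resolution image preset; smaller presets + # such as w500 can be substituted in these URLs if bandwidth matters. +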
details["art"]["fanart"] = "https://image.tmdb.org/t/p/original%s" % result["backdrop_path"] + details["totalmovies"] = len(result["parts"]) + return details + + def get_tvshow_details(self, tvshow_id): + """get all tvshowdetails""" + params = { + "append_to_response": "keywords,videos,external_ids,credits,images", + "include_image_language": "%s,en" % KODI_LANGUAGE, + "language": KODI_LANGUAGE + } + return self.map_details(self.get_data("tv/%s" % tvshow_id, params), "tvshow") + + def get_videodetails_by_externalid(self, extid, extid_type): + """get metadata by external ID (like imdbid)""" + params = {"external_source": extid_type, "language": KODI_LANGUAGE} + results = self.get_data("find/%s" % extid, params) + if results and results["movie_results"]: + return self.get_movie_details(results["movie_results"][0]["id"]) + elif results and results["tv_results"]: + return self.get_tvshow_details(results["tv_results"][0]["id"]) + return {} + + def get_data(self, endpoint, params): + """helper method to get data from tmdb json API""" + if self.api_key: + # addon provided or personal api key + params["api_key"] = self.api_key + rate_limit = None + expiration = datetime.timedelta(days=7) + else: + # fallback api key (rate limited !) + params["api_key"] = "80246691939720672db3fc71c74e0ef2" + # without personal (or addon specific) api key = rate limiting and older info from cache + rate_limit = ("themoviedb.org", 5) + expiration = datetime.timedelta(days=60) + if sys.version_info.major == 3: + cachestr = "tmdb.%s" % params.values() + else: + cachestr = "tmdb.%s" % params.itervalues() + cache = self.cache.get(cachestr) + if cache: + # data obtained from cache + result = cache + else: + # no cache, grab data from API + url = u'https://api.themoviedb.org/3/%s' % endpoint + result = get_json(url, params, ratelimit=rate_limit) + # make sure that we have a plot value (if localized value fails, fallback to english) + if result and "language" in params and "overview" in result: + if not result["overview"] and params["language"] != "en": + params["language"] = "en" + result2 = get_json(url, params) + if result2 and result2.get("overview"): + result = result2 + self.cache.set(url, result, expiration=expiration) + return result + + def map_details(self, data, media_type): + """helper method to map the details received from tmdb to kodi compatible formatting""" + if not data: + return {} + details = {} + details["tmdb_id"] = data["id"] + details["rating"] = data["vote_average"] + details["votes"] = data["vote_count"] + details["rating.tmdb"] = data["vote_average"] + details["votes.tmdb"] = data["vote_count"] + details["popularity"] = data["popularity"] + details["popularity.tmdb"] = data["popularity"] + details["plot"] = data["overview"] + details["genre"] = [item["name"] for item in data["genres"]] + details["homepage"] = data["homepage"] + details["status"] = data["status"] + details["cast"] = [] + details["castandrole"] = [] + details["writer"] = [] + details["director"] = [] + details["media_type"] = media_type + # cast + if "credits" in data: + if "cast" in data["credits"]: + for cast_member in data["credits"]["cast"]: + cast_thumb = "" + if cast_member["profile_path"]: + cast_thumb = "https://image.tmdb.org/t/p/original%s" % cast_member["profile_path"] + details["cast"].append({"name": cast_member["name"], "role": cast_member["character"], + "thumbnail": cast_thumb}) + details["castandrole"].append((cast_member["name"], cast_member["character"])) + # crew (including writers and directors) + if "crew" in 
data["credits"]: + for crew_member in data["credits"]["crew"]: + cast_thumb = "" + if crew_member["profile_path"]: + cast_thumb = "https://image.tmdb.org/t/p/original%s" % crew_member["profile_path"] + if crew_member["job"] in ["Author", "Writer"]: + details["writer"].append(crew_member["name"]) + if crew_member["job"] in ["Producer", "Executive Producer"]: + details["director"].append(crew_member["name"]) + if crew_member["job"] in ["Producer", "Executive Producer", "Author", "Writer"]: + details["cast"].append({"name": crew_member["name"], "role": crew_member["job"], + "thumbnail": cast_thumb}) + # artwork + details["art"] = {} + if data.get("images"): + if data["images"].get("backdrops"): + fanarts = self.get_best_images(data["images"]["backdrops"]) + details["art"]["fanarts"] = fanarts + details["art"]["fanart"] = fanarts[0] if fanarts else "" + if data["images"].get("posters"): + posters = self.get_best_images(data["images"]["posters"]) + details["art"]["posters"] = posters + details["art"]["poster"] = posters[0] if posters else "" + if not details["art"].get("poster") and data.get("poster_path"): + details["art"]["poster"] = "https://image.tmdb.org/t/p/original%s" % data["poster_path"] + if not details["art"].get("fanart") and data.get("backdrop_path"): + details["art"]["fanart"] = "https://image.tmdb.org/t/p/original%s" % data["backdrop_path"] + # movies only + if media_type == "movie": + details["title"] = data["title"] + details["originaltitle"] = data["original_title"] + if data["belongs_to_collection"]: + details["set"] = data["belongs_to_collection"].get("name", "") + if data.get("release_date"): + details["premiered"] = data["release_date"] + details["year"] = try_parse_int(data["release_date"].split("-")[0]) + details["tagline"] = data["tagline"] + if data["runtime"]: + details["runtime"] = data["runtime"] * 60 + details["imdbnumber"] = data["imdb_id"] + details["budget"] = data["budget"] + details["budget.formatted"] = int_with_commas(data["budget"]) + details["revenue"] = data["revenue"] + details["revenue.formatted"] = int_with_commas(data["revenue"]) + if data.get("production_companies"): + details["studio"] = [item["name"] for item in data["production_companies"]] + if data.get("production_countries"): + details["country"] = [item["name"] for item in data["production_countries"]] + if data.get("keywords"): + details["tag"] = [item["name"] for item in data["keywords"]["keywords"]] + # tvshows only + if media_type == "tvshow": + details["title"] = data["name"] + details["originaltitle"] = data["original_name"] + if data.get("created_by"): + details["director"] += [item["name"] for item in data["created_by"]] + if data.get("episode_run_time"): + details["runtime"] = data["episode_run_time"][0] * 60 + if data.get("first_air_date"): + details["premiered"] = data["first_air_date"] + details["year"] = try_parse_int(data["first_air_date"].split("-")[0]) + if "last_air_date" in data: + details["lastaired"] = data["last_air_date"] + if data.get("networks"): + details["studio"] = [item["name"] for item in data["networks"]] + if "origin_country" in data: + details["country"] = data["origin_country"] + if "number_of_seasons" in data: + details["Seasons"] = data["number_of_seasons"] + if "number_of_episodes" in data: + details["Episodes"] = data["number_of_episodes"] + if data.get("seasons"): + tmdboverviewdetails = data["seasons"] + seasons = [] + for count, item in enumerate(tmdboverviewdetails): + seasons.append(item["overview"]) + details["seasons.formatted.%s" % count] = "%s 
%s[CR]%s[CR]" % (item["name"], item["air_date"], item["overview"]) + details["seasons.formatted"] = "[CR]".join(seasons) + if data.get("external_ids"): + details["imdbnumber"] = data["external_ids"].get("imdb_id", "") + details["tvdb_id"] = data["external_ids"].get("tvdb_id", "") + if "results" in data["keywords"]: + details["tag"] = [item["name"] for item in data["keywords"]["results"]] + # trailer + for video in data["videos"]["results"]: + if video["site"] == "YouTube" and video["type"] == "Trailer": + details["trailer"] = 'plugin://plugin.video.youtube/?action=play_video&videoid=%s' % video["key"] + break + return details + + @staticmethod + def get_best_images(images): + """get the best 5 images based on number of likes and the language""" + for image in images: + score = 0 + score += image["vote_count"] + score += image["vote_average"] * 10 + score += image["height"] + if "iso_639_1" in image: + if image["iso_639_1"] == KODI_LANGUAGE: + score += 1000 + image["score"] = score + if not image["file_path"].startswith("https"): + image["file_path"] = "https://image.tmdb.org/t/p/original%s" % image["file_path"] + images = sorted(images, key=itemgetter("score"), reverse=True) + return [image["file_path"] for image in images] + + @staticmethod + def select_best_match(results, prefyear="", preftype="", preftitle="", manual_select=False): + """helper to select best match or let the user manually select the best result from the search""" + details = {} + # score results if one or more preferences are given + if results and (prefyear or preftype or preftitle): + newdata = [] + preftitle = preftitle.lower() + for item in results: + item["score"] = 0 + itemtitle = item["title"] if item.get("title") else item["name"] + itemtitle = itemtitle.lower() + itemorgtitle = item["original_title"] if item.get("original_title") else item["original_name"] + itemorgtitle = itemorgtitle.lower() + + # high score if year matches + if prefyear: + if item.get("first_air_date") and prefyear in item["first_air_date"]: + item["score"] += 800 # matches preferred year + if item.get("release_date") and prefyear in item["release_date"]: + item["score"] += 800 # matches preferred year + + # find exact match on title + if preftitle and preftitle == itemtitle: + item["score"] += 1000 # exact match! + if preftitle and preftitle == itemorgtitle: + item["score"] += 1000 # exact match! + + # match title by replacing some characters + if preftitle and get_compare_string(preftitle) == get_compare_string(itemtitle): + item["score"] += 750 + if preftitle and get_compare_string(preftitle) == get_compare_string(itemorgtitle): + item["score"] += 750 + + # add SequenceMatcher score to the results + if preftitle: + stringmatchscore = SM(None, preftitle, itemtitle).ratio( + ) + SM(None, preftitle, itemorgtitle).ratio() + if stringmatchscore > 1.6: + item["score"] += stringmatchscore * 250 + + # higher score if result ALSO matches our preferred type or native language + # (only when we already have a score) + if item["score"]: + if preftype and (item["media_type"] in preftype) or (preftype in item["media_type"]): + item["score"] += 250 # matches preferred type + if item["original_language"] == KODI_LANGUAGE: + item["score"] += 500 # native language! + if KODI_LANGUAGE.upper() in item.get("origin_country", []): + item["score"] += 500 # native language! + if KODI_LANGUAGE in item.get("languages", []): + item["score"] += 500 # native language! 
+ + if item["score"] > 500 or manual_select: + newdata.append(item) + results = sorted(newdata, key=itemgetter("score"), reverse=True) + + if results and manual_select: + # show selectdialog to manually select the item + results_list = [] + for item in results: + title = item["name"] if "name" in item else item["title"] + if item.get("premiered"): + year = item["premiered"].split("-")[0] + else: + year = item.get("first_air_date", "").split("-")[0] + if item["poster_path"]: + thumb = "https://image.tmdb.org/t/p/original%s" % item["poster_path"] + else: + thumb = "" + label = "%s (%s) - %s" % (title, year, item["media_type"]) + listitem = xbmcgui.ListItem(label=label, label2=item["overview"]) + listitem.setArt({'icon': thumb}) + results_list.append(listitem) + if manual_select and results_list: + dialog = DialogSelect("DialogSelect.xml", "", listing=results_list, window_title="%s - TMDB" + % xbmc.getLocalizedString(283)) + dialog.doModal() + selected_item = dialog.result + del dialog + if selected_item != -1: + details = results[selected_item] + else: + results = [] + + if not details and results: + # just grab the first item as best match + details = results[0] + return details diff --git a/resources/lib/metadatautils/helpers/utils.py b/resources/lib/metadatautils/helpers/utils.py new file mode 100644 index 0000000..4cf5bc4 --- /dev/null +++ b/resources/lib/metadatautils/helpers/utils.py @@ -0,0 +1,885 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +"""Various generic helper methods""" + +import os, sys +import xbmcgui +import xbmc +import xbmcvfs +import xbmcaddon +import sys +import requests +import arrow +from requests.packages.urllib3.util.retry import Retry +from requests.adapters import HTTPAdapter +if sys.version_info.major == 3: + import traceback + from urllib.parse import unquote +else: + from traceback import format_exc + from urllib import unquote +import unicodedata +import datetime +import time +import xml.etree.ElementTree as ET + +try: + import simplejson as json +except Exception: + import json + +try: + from multiprocessing.pool import ThreadPool + SUPPORTS_POOL = True +except Exception: + SUPPORTS_POOL = False + + +ADDON_ID = "script.module.metadatautils" +KODI_LANGUAGE = xbmc.getLanguage(xbmc.ISO_639_1) +if not KODI_LANGUAGE: + KODI_LANGUAGE = "en" +KODI_VERSION = int(xbmc.getInfoLabel("System.BuildVersion").split(".")[0]) + +# setup requests with some additional options +requests.packages.urllib3.disable_warnings() +SESSION = requests.Session() +RETRIES = Retry(total=5, backoff_factor=5, status_forcelist=[500, 502, 503, 504]) +SESSION.mount('http://', HTTPAdapter(max_retries=RETRIES)) +SESSION.mount('https://', HTTPAdapter(max_retries=RETRIES)) + +FORCE_DEBUG_LOG = False +LIMIT_EXTRAFANART = 0 +try: + ADDON = xbmcaddon.Addon(ADDON_ID) + FORCE_DEBUG_LOG = ADDON.getSetting('debug_log') == 'true' + LIMIT_EXTRAFANART = int(ADDON.getSetting('max_extrafanarts')) + del ADDON +except Exception: + pass + + +def log_msg(msg, loglevel=xbmc.LOGDEBUG): + """log message to kodi logfile""" + if sys.version_info.major < 3: + if isinstance(msg, unicode): + msg = msg.encode('utf-8') + if loglevel == xbmc.LOGDEBUG and FORCE_DEBUG_LOG: + loglevel = xbmc.LOGINFO + xbmc.log("%s --> %s" % (ADDON_ID, msg), level=loglevel) + + +def log_exception(modulename, exceptiondetails): + '''helper to properly log an exception''' + if sys.version_info.major == 3: + exc_type, exc_value, exc_traceback = sys.exc_info() + lines = traceback.format_exception(exc_type, exc_value, exc_traceback) + 
log_msg("Exception details: Type: %s Value: %s Traceback: %s" % (exc_type.__name__, exc_value, ''.join(line for line in lines)), xbmc.LOGWARNING) + else: + log_msg(format_exc(sys.exc_info()), xbmc.LOGWARNING) + log_msg("Exception in %s ! --> %s" % (modulename, exceptiondetails), xbmc.LOGERROR) + + +def rate_limiter(rl_params): + """ A very basic rate limiter which limits to 1 request per X seconds to the api""" + # Please respect the parties providing these free api's to us and do not modify this code. + # If I suspect any abuse I will revoke all api keys and require all users + # to have a personal api key for all services. + # Thank you + if not rl_params: + return + monitor = xbmc.Monitor() + win = xbmcgui.Window(10000) + rl_name = rl_params[0] + rl_delay = rl_params[1] + cur_timestamp = int(time.mktime(datetime.datetime.now().timetuple())) + prev_timestamp = try_parse_int(win.getProperty("ratelimiter.%s" % rl_name)) + if (prev_timestamp + rl_delay) > cur_timestamp: + sec_to_wait = (prev_timestamp + rl_delay) - cur_timestamp + log_msg( + "Rate limiter active for %s - delaying request with %s seconds - " + "Configure a personal API key in the settings to get rid of this message and the delay." % + (rl_name, sec_to_wait), xbmc.LOGINFO) + while sec_to_wait and not monitor.abortRequested(): + monitor.waitForAbort(1) + # keep setting the timestamp to create some sort of queue + cur_timestamp = int(time.mktime(datetime.datetime.now().timetuple())) + win.setProperty("ratelimiter.%s" % rl_name, "%s" % cur_timestamp) + sec_to_wait -= 1 + # always set the timestamp + cur_timestamp = int(time.mktime(datetime.datetime.now().timetuple())) + win.setProperty("ratelimiter.%s" % rl_name, "%s" % cur_timestamp) + del monitor + del win + + +def get_json(url, params=None, retries=0, ratelimit=None): + """get info from a rest api""" + result = {} + if not params: + params = {} + # apply rate limiting if needed + rate_limiter(ratelimit) + try: + response = requests.get(url, params=params, timeout=20) + if response and response.content and response.status_code == 200: + if sys.version_info.major == 3: + result = json.loads(response.content) + else: + result = json.loads(response.content.decode('utf-8', 'replace')) + if "results" in result: + result = result["results"] + elif "result" in result: + result = result["result"] + elif response.status_code in (429, 503, 504): + raise Exception('Read timed out') + except Exception as exc: + result = None + if "Read timed out" in str(exc) and retries < 5 and not ratelimit: + # retry on connection error or http server limiting + monitor = xbmc.Monitor() + if not monitor.waitForAbort(2): + result = get_json(url, params, retries + 1) + del monitor + else: + log_exception(__name__, exc) + # return result + return result + + +def get_xml(url, params=None, retries=0, ratelimit=None): + '''get info from a rest api''' + result = {} + if not params: + params = {} + # apply rate limiting if needed + rate_limiter(ratelimit) + try: + response = requests.get(url, params=params, timeout=20) + if response and response.content and response.status_code == 200: + tree = ET.fromstring(response.content) + #child = tree.find('movie') + if(len(tree)): + child = tree[0] + #log_exception(__name__, child) + for attrName, attrValue in child.items(): + result.update({attrName : attrValue}) + elif response.status_code in (429, 503, 504): + raise Exception('Read timed out') + except Exception as exc: + result = None + if "Read timed out" in str(exc) and retries < 5 and not ratelimit: + # retry on 
connection error or http server limiting
+            monitor = xbmc.Monitor()
+            if not monitor.waitForAbort(2):
+                result = get_xml(url, params, retries + 1)
+            del monitor
+        else:
+            log_exception(__name__, exc)
+    return result
+
+
+def try_encode(text, encoding="utf-8"):
+    """helper to encode a string to utf-8"""
+    try:
+        if sys.version_info.major == 3:
+            return text
+        else:
+            return text.encode(encoding, "ignore")
+    except Exception:
+        return text
+
+
+def try_decode(text, encoding="utf-8"):
+    """helper to decode a string to unicode"""
+    try:
+        if sys.version_info.major == 3:
+            return text
+        else:
+            return text.decode(encoding, "ignore")
+    except Exception:
+        return text
+
+
+def urlencode(text):
+    """helper to properly urlencode a string"""
+    # urllib is not imported at module level, so import the right function locally
+    if sys.version_info.major == 3:
+        from urllib.parse import urlencode as _urlencode
+    else:
+        from urllib import urlencode as _urlencode
+    blah = _urlencode({'blahblahblah': try_encode(text)})
+    blah = blah[13:]
+    return blah
+
+
+def formatted_number(number):
+    """try to format a number to formatted string with thousands"""
+    try:
+        number = int(number)
+        if number < 0:
+            return '-' + formatted_number(-number)
+        result = ''
+        while number >= 1000:
+            number, number2 = divmod(number, 1000)
+            result = ",%03d%s" % (number2, result)
+        return "%d%s" % (number, result)
+    except Exception:
+        return ""
+
+
+def process_method_on_list(method_to_run, items):
+    """helper method that processes a method on each listitem with pooling if the system supports it"""
+    all_items = []
+    if items is not None:
+        if SUPPORTS_POOL:
+            pool = ThreadPool()
+            try:
+                all_items = pool.map(method_to_run, items)
+            except Exception:
+                # catch exception to prevent threadpool running forever
+                import traceback
+                log_msg(traceback.format_exc())
+                log_msg("Error in %s" % method_to_run)
+            pool.close()
+            pool.join()
+        else:
+            try:
+                all_items = [method_to_run(item) for item in list(items)]
+            except Exception:
+                import traceback
+                log_msg(traceback.format_exc())
+                log_msg("Error while executing %s with %s" % (method_to_run, items))
+        if sys.version_info.major == 3:
+            all_items = list(filter(None, all_items))
+        else:
+            all_items = filter(None, all_items)
+    return all_items
+
+
+def get_clean_image(image):
+    """helper to strip all kodi tags/formatting of an image path/url"""
+    if not image:
+        return ""
+    if "music@" in image:
+        # fix for embedded images
+        thumbcache = xbmc.getCacheThumbName(image).replace(".tbn", ".jpg")
+        thumbcache = "special://thumbnails/%s/%s" % (thumbcache[0], thumbcache)
+        if not xbmcvfs.exists(thumbcache):
+            xbmcvfs.copy(image, thumbcache)
+        image = thumbcache
+    if image and "image://" in image:
+        image = image.replace("image://", "")
+        if sys.version_info.major == 3:
+            image = unquote(image)
+        else:
+            image = unquote(image.encode("utf-8"))
+        if image.endswith("/"):
+            image = image[:-1]
+    if sys.version_info.major < 3:
+        if not isinstance(image, unicode):
+            image = image.decode("utf8")
+    return image
+
+
+def get_duration(duration):
+    """transform duration time in minutes to hours:minutes"""
+    if not duration:
+        return {}
+    # str.replace returns a new string, so the result must be kept
+    if sys.version_info.major == 3:
+        if isinstance(duration, str):
+            duration = duration.replace("min", "").replace(".", "")
+    else:
+        if isinstance(duration, (unicode, str)):
+            duration = duration.replace("min", "").replace(".", "")
+    try:
+        total_minutes = int(duration)
+        hours = total_minutes // 60
+        minutes = total_minutes - (hours * 60)
+        formatted_time = "%s:%s" % (hours, str(minutes).zfill(2))
+    except Exception as exc:
+        log_exception(__name__, exc)
+        return {}
+    return {
+        "Duration": formatted_time,
+        "Duration.Hours": hours,
+        "Duration.Minutes": minutes,
+        "Runtime": total_minutes,
+        "RuntimeExtended": "%s %s" % (total_minutes, xbmc.getLocalizedString(12391)),
+        "DurationAndRuntime": "%s (%s min.)" % (formatted_time, total_minutes),
+        "DurationAndRuntimeExtended": "%s (%s %s)" % (formatted_time, total_minutes, xbmc.getLocalizedString(12391))
+    }
+
+
+def int_with_commas(number):
+    """helper to pretty format a number"""
+    try:
+        number = int(number)
+        if number < 0:
+            return '-' + int_with_commas(-number)
+        result = ''
+        while number >= 1000:
+            number, number2 = divmod(number, 1000)
+            result = ",%03d%s" % (number2, result)
+        return "%d%s" % (number, result)
+    except Exception:
+        return ""
+
+
+def try_parse_int(string):
+    """helper to parse int from string without erroring on empty or misformed string"""
+    try:
+        return int(string)
+    except Exception:
+        return 0
+
+
+def extend_dict(org_dict, new_dict, allow_overwrite=None):
+    """Create a new dictionary with a's properties extended by b,
+    without overwriting existing values."""
+    if not new_dict:
+        return org_dict
+    if not org_dict:
+        return new_dict
+    # items() vs iteritems() is the only python 2/3 difference, the merge logic is shared
+    if sys.version_info.major == 3:
+        new_items = new_dict.items()
+        str_types = str
+    else:
+        new_items = new_dict.iteritems()
+        str_types = (str, unicode)
+    for key, value in new_items:
+        if value:
+            if not org_dict.get(key):
+                # original dict doesn't have this key (or no value), just overwrite
+                org_dict[key] = value
+            else:
+                # original dict already has this key, append results
+                if isinstance(value, list):
+                    # make sure that our original value also is a list
+                    if isinstance(org_dict[key], list):
+                        for item in value:
+                            if item not in org_dict[key]:
+                                org_dict[key].append(item)
+                    # previous value was str, combine both in list
+                    elif isinstance(org_dict[key], str_types):
+                        org_dict[key] = org_dict[key].split(" / ")
+                        for item in value:
+                            if item not in org_dict[key]:
+                                org_dict[key].append(item)
+                elif isinstance(value, dict):
+                    org_dict[key] = extend_dict(org_dict[key], value, allow_overwrite)
+                elif allow_overwrite and key in allow_overwrite:
+                    # value may be overwritten
+                    org_dict[key] = value
+                else:
+                    # conflict, leave alone
+                    pass
+    return org_dict
+
+
+def localdate_from_utc_string(timestring):
+    """helper to convert internal utc time (used in pvr) to local timezone"""
+    utc_datetime = arrow.get(timestring)
+    local_datetime = utc_datetime.to('local')
+    return local_datetime.format("YYYY-MM-DD HH:mm:ss")
+
+
+def localized_date_time(timestring):
+    """returns localized version of the timestring (used in pvr)"""
+    date_time = arrow.get(timestring)
+    local_date = date_time.strftime(xbmc.getRegion("dateshort"))
+    local_time = date_time.strftime(xbmc.getRegion("time").replace(":%S", ""))
+    return local_date, local_time
+
+
+def normalize_string(text):
+    """normalize string, strip all special chars"""
+    text = text.replace(":", "")
+    text = text.replace("/", "-")
+    text = text.replace("\\", "-")
+    text = text.replace("<", "")
+    text = text.replace(">", "")
+    text = text.replace("*", "")
+    text = text.replace("?", "")
+    text = text.replace('|', "")
+    text = text.replace('(', "")
+    text = text.replace(')', "")
+    text = text.replace("\"", "")
+    text = text.strip()
+    text = text.rstrip('.')
+    text = unicodedata.normalize('NFKD', try_decode(text))
+    return text
+
+
+def get_compare_string(text):
+    """strip all special chars in a string for better comparing of searchresults"""
+    if sys.version_info.major < 3:
+        if not isinstance(text, unicode):
+            # decode returns a new object, so keep the result
+            text = text.decode("utf-8")
+    text = text.lower()
+    text = ''.join(e for e in text if e.isalnum())
+    return text
+
+
+def strip_newlines(text):
+    """strip any newlines from a string"""
+    return text.replace('\n', ' ').replace('\r', '').rstrip()
+
+
+def detect_plugin_content(plugin_path):
+    """based on the properties of a vfspath we try to detect the content type"""
+    content_type = ""
+    if not plugin_path:
+        return ""
+    # detect content based on the path
+    if "listing" in plugin_path:
+        content_type = "folder"
+    elif "movie" in plugin_path.lower():
+        content_type = "movies"
+    elif "album" in plugin_path.lower():
+        content_type = "albums"
+    elif "show" in plugin_path.lower():
+        content_type = "tvshows"
+    elif "episode" in plugin_path.lower():
+        content_type = "episodes"
+    elif "song" in plugin_path.lower():
+        content_type = "songs"
+    elif "musicvideo" in plugin_path.lower():
+        content_type = "musicvideos"
+    elif "pvr" in plugin_path.lower():
+        content_type = "pvr"
+    elif "type=dynamic" in plugin_path.lower():
+        content_type = "movies"
+    elif "videos" in plugin_path.lower():
+        content_type = "movies"
+    elif "type=both" in plugin_path.lower():
+        content_type = "movies"
+    elif "media" in plugin_path.lower():
+        content_type = "movies"
+    elif "favourites" in plugin_path.lower():
+        content_type = "movies"
+    elif ("box" in plugin_path.lower() or "dvd" in plugin_path.lower() or
+          "rentals" in plugin_path.lower() or "incinemas" in plugin_path.lower() or
+          "comingsoon" in plugin_path.lower() or "upcoming" in plugin_path.lower() or
+          "opening" in plugin_path.lower() or "intheaters" in plugin_path.lower()):
+        content_type = "movies"
+    # if we didn't get the content based on the path, we need to probe the addon...
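A tiny worked example of the merge rules implemented by `extend_dict` above; the sample dicts are made up:

```python
org = {"plot": "", "genre": ["Drama"], "studio": "HBO"}
new = {"plot": "A plot.", "genre": ["Drama", "Crime"], "studio": ["HBO", "Sky"]}

merged = extend_dict(org, new)
print(merged["plot"])    # "A plot."          (empty original value is overwritten)
print(merged["genre"])   # ["Drama", "Crime"] (lists are unioned, no duplicates)
print(merged["studio"])  # ["HBO", "Sky"]     (str original is split, then extended)
```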
+ if not content_type and not xbmc.getCondVisibility("Window.IsMedia"): # safety check + from kodidb import KodiDb + media_array = KodiDb().files(plugin_path, limits=(0, 1)) + for item in media_array: + if item.get("filetype", "") == "directory": + content_type = "folder" + break + elif item.get("type") and item["type"] != "unknown": + content_type = item["type"] + "s" + break + elif "showtitle" not in item and "artist" not in item: + # these properties are only returned in the json response if we're looking at actual file content... + # if it's missing it means this is a main directory listing and no need to + # scan the underlying listitems. + content_type = "files" + break + if "showtitle" not in item and "artist" in item: + # AUDIO ITEMS + if item["type"] == "artist": + content_type = "artists" + break + elif (isinstance(item["artist"], list) and len(item["artist"]) > 0 and + item["artist"][0] == item["title"]): + content_type = "artists" + break + elif item["type"] == "album" or item["album"] == item["title"]: + content_type = "albums" + break + elif ((item["type"] == "song" and "play_album" not in item["file"]) or + (item["artist"] and item["album"])): + content_type = "songs" + break + else: + # VIDEO ITEMS + if item["showtitle"] and not item.get("artist"): + # this is a tvshow, episode or season... + if item["type"] == "season" or (item["season"] > -1 and item["episode"] == -1): + content_type = "seasons" + break + elif item["type"] == "episode" or item["season"] > -1 and item["episode"] > -1: + content_type = "episodes" + break + else: + content_type = "tvshows" + break + elif item.get("artist"): + # this is a musicvideo! + content_type = "musicvideos" + break + elif (item["type"] == "movie" or item.get("imdbnumber") or item.get("mpaa") or + item.get("trailer") or item.get("studio")): + content_type = "movies" + break + log_msg("detect_plugin_path_content for: %s - result: %s" % (plugin_path, content_type)) + return content_type + + +def download_artwork(folderpath, artwork): + """download artwork to local folder""" + efa_path = "" + new_dict = {} + if not xbmcvfs.exists(folderpath): + xbmcvfs.mkdir(folderpath) + if sys.version_info.major == 3: + for key, value in artwork.items(): + if key == "fanart": + new_dict[key] = download_image(os.path.join(folderpath, "fanart.jpg"), value) + elif key == "thumb": + new_dict[key] = download_image(os.path.join(folderpath, "folder.jpg"), value) + elif key == "discart": + new_dict[key] = download_image(os.path.join(folderpath, "disc.png"), value) + elif key == "banner": + new_dict[key] = download_image(os.path.join(folderpath, "banner.jpg"), value) + elif key == "clearlogo": + new_dict[key] = download_image(os.path.join(folderpath, "logo.png"), value) + elif key == "clearart": + new_dict[key] = download_image(os.path.join(folderpath, "clearart.png"), value) + elif key == "characterart": + new_dict[key] = download_image(os.path.join(folderpath, "characterart.png"), value) + elif key == "poster": + new_dict[key] = download_image(os.path.join(folderpath, "poster.jpg"), value) + elif key == "landscape": + new_dict[key] = download_image(os.path.join(folderpath, "landscape.jpg"), value) + elif key == "thumbback": + new_dict[key] = download_image(os.path.join(folderpath, "thumbback.jpg"), value) + elif key == "spine": + new_dict[key] = download_image(os.path.join(folderpath, "spine.jpg"), value) + elif key == "album3Dthumb": + new_dict[key] = download_image(os.path.join(folderpath, "album3Dthumb.png"), value) + elif key == "album3Dflat": + 
new_dict[key] = download_image(os.path.join(folderpath, "album3Dflat.png"), value) + elif key == "album3Dcase": + new_dict[key] = download_image(os.path.join(folderpath, "album3Dcase.png"), value) + elif key == "album3Dface": + new_dict[key] = download_image(os.path.join(folderpath, "album3Dface.png"), value) + elif key == "fanarts" and value: + # copy extrafanarts only if the directory doesn't exist at all + delim = "\\" if "\\" in folderpath else "/" + efa_path = "%sextrafanart" % folderpath + delim + if not xbmcvfs.exists(efa_path): + xbmcvfs.mkdir(efa_path) + images = [] + for count, image in enumerate(value): + image = download_image(os.path.join(efa_path, "fanart%s.jpg" % count), image) + images.append(image) + if LIMIT_EXTRAFANART and count == LIMIT_EXTRAFANART: + break + new_dict[key] = images + elif key == "posters" and value: + # copy extraposters only if the directory doesn't exist at all + delim = "\\" if "\\" in folderpath else "/" + efa_path = "%sextraposter" % folderpath + delim + if not xbmcvfs.exists(efa_path): + xbmcvfs.mkdir(efa_path) + images = [] + for count, image in enumerate(value): + image = download_image(os.path.join(efa_path, "poster%s.jpg" % count), image) + images.append(image) + if LIMIT_EXTRAFANART and count == LIMIT_EXTRAFANART: + break + new_dict[key] = images + else: + new_dict[key] = value + else: + for key, value in artwork.iteritems(): + if key == "fanart": + new_dict[key] = download_image(os.path.join(folderpath, "fanart.jpg"), value) + elif key == "thumb": + new_dict[key] = download_image(os.path.join(folderpath, "folder.jpg"), value) + elif key == "discart": + new_dict[key] = download_image(os.path.join(folderpath, "disc.png"), value) + elif key == "banner": + new_dict[key] = download_image(os.path.join(folderpath, "banner.jpg"), value) + elif key == "clearlogo": + new_dict[key] = download_image(os.path.join(folderpath, "logo.png"), value) + elif key == "clearart": + new_dict[key] = download_image(os.path.join(folderpath, "clearart.png"), value) + elif key == "characterart": + new_dict[key] = download_image(os.path.join(folderpath, "characterart.png"), value) + elif key == "poster": + new_dict[key] = download_image(os.path.join(folderpath, "poster.jpg"), value) + elif key == "landscape": + new_dict[key] = download_image(os.path.join(folderpath, "landscape.jpg"), value) + elif key == "thumbback": + new_dict[key] = download_image(os.path.join(folderpath, "thumbback.jpg"), value) + elif key == "spine": + new_dict[key] = download_image(os.path.join(folderpath, "spine.jpg"), value) + elif key == "fanarts" and value: + # copy extrafanarts only if the directory doesn't exist at all + delim = "\\" if "\\" in folderpath else "/" + efa_path = "%sextrafanart" % folderpath + delim + if not xbmcvfs.exists(efa_path): + xbmcvfs.mkdir(efa_path) + images = [] + for count, image in enumerate(value): + image = download_image(os.path.join(efa_path, "fanart%s.jpg" % count), image) + images.append(image) + if LIMIT_EXTRAFANART and count == LIMIT_EXTRAFANART: + break + new_dict[key] = images + elif key == "posters" and value: + # copy extraposters only if the directory doesn't exist at all + delim = "\\" if "\\" in folderpath else "/" + efa_path = "%sextraposter" % folderpath + delim + if not xbmcvfs.exists(efa_path): + xbmcvfs.mkdir(efa_path) + images = [] + for count, image in enumerate(value): + image = download_image(os.path.join(efa_path, "poster%s.jpg" % count), image) + images.append(image) + if LIMIT_EXTRAFANART and count == LIMIT_EXTRAFANART: + break + 
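The long if/elif chain in `download_artwork` above is essentially a key-to-filename table. As an illustration only (using the same filenames as the code and reusing the module's `download_image`), the single-image keys could be handled with a lookup; the hypothetical `download_simple_art` below is a sketch, not the add-on's implementation:

```python
import os

# artwork key -> canonical local filename, as used by download_artwork()
ARTWORK_FILENAMES = {
    "fanart": "fanart.jpg", "thumb": "folder.jpg", "discart": "disc.png",
    "banner": "banner.jpg", "clearlogo": "logo.png", "clearart": "clearart.png",
    "characterart": "characterart.png", "poster": "poster.jpg",
    "landscape": "landscape.jpg", "thumbback": "thumbback.jpg", "spine": "spine.jpg",
}

def download_simple_art(folderpath, artwork):
    """single-image keys only; the 'fanarts'/'posters' lists need the extra handling above"""
    new_dict = {}
    for key, value in artwork.items():
        if key in ARTWORK_FILENAMES:
            new_dict[key] = download_image(os.path.join(folderpath, ARTWORK_FILENAMES[key]), value)
        else:
            new_dict[key] = value
    return new_dict
```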
new_dict[key] = images + else: + new_dict[key] = value + if efa_path: + new_dict["extrafanart"] = efa_path + return new_dict + + +def download_image(filename, url): + """download specific image to local folder""" + if not url: + return url + refresh_needed = False + if xbmcvfs.exists(filename) and filename == url: + # only overwrite if new image is different + return filename + else: + if xbmcvfs.exists(filename): + xbmcvfs.delete(filename) + refresh_needed = True + if xbmcvfs.copy(url, filename): + if refresh_needed: + refresh_image(filename) + return filename + + return url + + +def refresh_image(imagepath): + """tell kodi texture cache to refresh a particular image""" + import sqlite3 + if sys.version_info.major == 3: + dbpath = xbmcvfs.translatePath("special://database/Textures13.db") + else: + dbpath = xbmc.translatePath("special://database/Textures13.db").decode('utf-8') + connection = sqlite3.connect(dbpath, timeout=30, isolation_level=None) + try: + cache_image = connection.execute('SELECT cachedurl FROM texture WHERE url = ?', (imagepath,)).fetchone() + if sys.version_info.major == 3: + if cache_image and isinstance(cache_image, str): + if xbmcvfs.exists(cache_image): + xbmcvfs.delete("special://profile/Thumbnails/%s" % cache_image) + connection.execute('DELETE FROM texture WHERE url = ?', (imagepath,)) + else: + if cache_image and isinstance(cache_image, (unicode, str)): + if xbmcvfs.exists(cache_image): + xbmcvfs.delete("special://profile/Thumbnails/%s" % cache_image) + connection.execute('DELETE FROM texture WHERE url = ?', (imagepath,)) + connection.close() + except Exception as exc: + log_exception(__name__, exc) + finally: + del connection + +# pylint: disable-msg=too-many-local-variables + + +def manual_set_artwork(artwork, mediatype, header=None): + """Allow user to manually select the artwork with a select dialog""" + changemade = False + if mediatype == "artist": + art_types = ["thumb", "poster", "fanart", "banner", "clearart", "clearlogo", "landscape"] + elif mediatype == "album": + art_types = ["thumb", "discart", "thumbback", "spine", "album3Dthumb", "album3Dflat", "album3Dcase", "album3Dface"] + else: + art_types = ["thumb", "poster", "fanart", "banner", "clearart", + "clearlogo", "discart", "landscape", "characterart"] + + if not header: + header = xbmc.getLocalizedString(13511) + + # show dialogselect with all artwork options + abort = False + while not abort: + listitems = [] + for arttype in art_types: + img = artwork.get(arttype, "") + listitem = xbmcgui.ListItem(label=arttype, label2=img) + listitem.setArt({'icon': img}) + listitem.setProperty("icon", img) + listitems.append(listitem) + dialog = DialogSelect("DialogSelect.xml", "", listing=listitems, + window_title=header, multiselect=False) + dialog.doModal() + selected_item = dialog.result + del dialog + if selected_item == -1: + abort = True + else: + # show results for selected art type + artoptions = [] + selected_item = listitems[selected_item] + if sys.version_info.major == 3: + image = selected_item.getProperty("icon") + label = selected_item.getLabel() + else: + image = selected_item.getProperty("icon").decode("utf-8") + label = selected_item.getLabel().decode("utf-8") + subheader = "%s: %s" % (header, label) + if image: + # current image + listitem = xbmcgui.ListItem(label=xbmc.getLocalizedString(13512), label2=image) + listitem.setArt({'icon': image}) + listitem.setProperty("icon", image) + artoptions.append(listitem) + # none option + listitem = xbmcgui.ListItem(label=xbmc.getLocalizedString(231)) + 
listitem.setArt({'icon': "DefaultAddonNone.png"}) + listitem.setProperty("icon", "DefaultAddonNone.png") + artoptions.append(listitem) + # browse option + listitem = xbmcgui.ListItem(label=xbmc.getLocalizedString(1024)) + listitem.setArt({'icon': "DefaultFolder.png"}) + listitem.setProperty("icon", "DefaultFolder.png") + artoptions.append(listitem) + + # add remaining images as option + allarts = artwork.get(label + "s", []) + for item in allarts: + listitem = xbmcgui.ListItem(label=item) + listitem.setArt({'icon': item}) + listitem.setProperty("icon", item) + artoptions.append(listitem) + + dialog = DialogSelect("DialogSelect.xml", "", listing=artoptions, window_title=subheader) + dialog.doModal() + selected_item = dialog.result + del dialog + if image and selected_item == 1: + # set image to None + artwork[label] = "" + changemade = True + elif (image and selected_item > 2) or (not image and selected_item > 0): + # one of the optional images is selected as new default + artwork[label] = artoptions[selected_item].getProperty("icon") + changemade = True + elif (image and selected_item == 2) or (not image and selected_item == 0): + # manual browse... + dialog = xbmcgui.Dialog() + if sys.version_info.major == 3: + image = dialog.browse(2, xbmc.getLocalizedString(1030), + 'files', mask='.gif|.png|.jpg') + else: + image = dialog.browse(2, xbmc.getLocalizedString(1030), + 'files', mask='.gif|.png|.jpg').decode("utf-8") + del dialog + if image: + artwork[label] = image + changemade = True + + # return endresult + return changemade, artwork + +# pylint: enable-msg=too-many-local-variables + + +class DialogSelect(xbmcgui.WindowXMLDialog): + """wrapper around Kodi dialogselect to present a list of items""" + + list_control = None + + def __init__(self, *args, **kwargs): + xbmcgui.WindowXMLDialog.__init__(self) + self.listing = kwargs.get("listing") + self.window_title = kwargs.get("window_title", "") + self.result = -1 + + def onInit(self): + """called when the dialog is drawn""" + self.list_control = self.getControl(6) + self.getControl(1).setLabel(self.window_title) + self.getControl(3).setVisible(False) + try: + self.getControl(7).setLabel(xbmc.getLocalizedString(222)) + except Exception: + pass + + self.getControl(5).setVisible(False) + + # add our items to the listing and focus the control + self.list_control.addItems(self.listing) + self.setFocus(self.list_control) + + def onAction(self, action): + """On kodi action""" + if action.getId() in (9, 10, 92, 216, 247, 257, 275, 61467, 61448, ): + self.result = -1 + self.close() + + def onClick(self, control_id): + """Triggers if our dialog is clicked""" + if control_id in (6, 3,): + num = self.list_control.getSelectedPosition() + self.result = num + else: + self.result = -1 + self.close() diff --git a/resources/lib/metadatautils/metadatautils.py b/resources/lib/metadatautils/metadatautils.py new file mode 100644 index 0000000..0eb98e3 --- /dev/null +++ b/resources/lib/metadatautils/metadatautils.py @@ -0,0 +1,460 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +''' + script.module.metadatautils + Provides all kind of mediainfo for kodi media, returned as dict with details +''' + +import os, sys +import helpers.kodi_constants as kodi_constants +from helpers.utils import log_msg, ADDON_ID +from simplecache import use_cache, SimpleCache +import xbmcvfs + +if sys.version_info.major == 3: + from urllib.parse import quote_plus +else: + from urllib import quote_plus + + +class MetadataUtils(object): + ''' + Provides all kind of mediainfo for kodi media, 
returned as dict with details + ''' + _audiodb, _addon, _close_called, _omdb, _kodidb, _tmdb, _fanarttv, _channellogos = [None] * 8 + _imdb, _google, _studiologos, _animatedart, _thetvdb, _musicart, _pvrart, _lastfm = [None] * 8 + _studiologos_path, _process_method_on_list, _detect_plugin_content, _get_streamdetails = [None] * 4 + _extend_dict, _get_clean_image, _get_duration, _get_extrafanart, _get_extraposter, _get_moviesetdetails = [None] * 6 + cache = None + + + def __init__(self): + '''Initialize and load all our helpers''' + self.cache = SimpleCache() + log_msg("Initialized") + + @use_cache(14) + def get_extrafanart(self, file_path): + '''helper to retrieve the extrafanart path for a kodi media item''' + log_msg("metadatautils get_extrafanart called for %s" % file_path) + if not self._get_extrafanart: + from helpers.extrafanart import get_extrafanart + self._get_extrafanart = get_extrafanart + return self._get_extrafanart(file_path) + + @use_cache(14) + def get_extraposter(self, file_path): + '''helper to retrieve the extraposter path for a kodi media item''' + if not self._get_extraposter: + from helpers.extraposter import get_extraposter + self._get_extraposter = get_extraposter + return self._get_extraposter(file_path) + + def get_music_artwork(self, artist, album="", track="", disc="", ignore_cache=False, flush_cache=False): + '''method to get music artwork for the goven artist/album/song''' + return self.musicart.get_music_artwork( + artist, album, track, disc, ignore_cache=ignore_cache, flush_cache=flush_cache) + + def music_artwork_options(self, artist, album="", track="", disc=""): + '''options for music metadata for specific item''' + return self.musicart.music_artwork_options(artist, album, track, disc) + + @use_cache(7) + def get_extended_artwork(self, imdb_id="", tvdb_id="", tmdb_id="", media_type=""): + '''get extended artwork for the given imdbid or tvdbid''' + result = None + if "movie" in media_type and tmdb_id: + result = self.fanarttv.movie(tmdb_id) + elif "movie" in media_type and imdb_id: + # prefer local artwork + local_details = self.kodidb.movie_by_imdbid(imdb_id) + if local_details: + result = local_details["art"] + result = self.extend_dict(result, self.fanarttv.movie(imdb_id)) + elif media_type in ["tvshow", "tvshows", "seasons", "episodes"]: + if not tvdb_id: + if imdb_id and not imdb_id.startswith("tt"): + tvdb_id = imdb_id + elif imdb_id: + tvdb_id = self.thetvdb.get_series_by_imdb_id(imdb_id).get("tvdb_id") + if tvdb_id: + # prefer local artwork + local_details = self.kodidb.tvshow_by_imdbid(tvdb_id) + if local_details: + result = local_details["art"] + elif imdb_id and imdb_id != tvdb_id: + local_details = self.kodidb.tvshow_by_imdbid(imdb_id) + if local_details: + result = local_details["art"] + result = self.extend_dict(result, self.fanarttv.tvshow(tvdb_id)) + # add additional art with special path + if result: + result = {"art": result} + for arttype in ["fanarts", "posters", "clearlogos", "banners", "discarts", "cleararts", "characterarts"]: + if result["art"].get(arttype): + result["art"][arttype] = "plugin://script.skin.helper.service/"\ + "?action=extrafanart&fanarts=%s" % quote_plus(repr(result["art"][arttype])) + return result + + @use_cache(90) + def get_tmdb_details(self, imdb_id="", tvdb_id="", title="", year="", media_type="", + preftype="", manual_select=False, ignore_cache=False): + '''returns details from tmdb''' + result = {} + if imdb_id: + result = self.tmdb.get_videodetails_by_externalid( + imdb_id, "imdb_id") + elif tvdb_id: + 
result = self.tmdb.get_videodetails_by_externalid( + tvdb_id, "tvdb_id") + elif title and media_type in ["movies", "setmovies", "movie"]: + result = self.tmdb.search_movie( + title, year, manual_select=manual_select, ignore_cache=ignore_cache) + elif title and media_type in ["tvshows", "tvshow"]: + result = self.tmdb.search_tvshow( + title, year, manual_select=manual_select, ignore_cache=ignore_cache) + elif title: + result = self.tmdb.search_video( + title, year, preftype=preftype, manual_select=manual_select, ignore_cache=ignore_cache) + if result and result.get("status"): + result["status"] = self.translate_string(result["status"]) + if result and result.get("runtime"): + result["runtime"] = result["runtime"] / 60 + result.update(self.get_duration(result["runtime"])) + return result + + @use_cache(90) + def get_moviesetdetails(self, title, set_id): + '''get a nicely formatted dict of the movieset details which we can for example set as window props''' + # get details from tmdb + if not self._get_moviesetdetails: + from helpers.moviesetdetails import get_moviesetdetails + self._get_moviesetdetails = get_moviesetdetails + return self._get_moviesetdetails(self, title, set_id) + + @use_cache(14) + def get_streamdetails(self, db_id, media_type, ignore_cache=False): + '''get a nicely formatted dict of the streamdetails''' + if not self._get_streamdetails: + from helpers.streamdetails import get_streamdetails + self._get_streamdetails = get_streamdetails + return self._get_streamdetails(self.kodidb, db_id, media_type) + + def get_pvr_artwork(self, title, channel="", genre="", manual_select=False, ignore_cache=False): + '''get artwork and mediadetails for PVR entries''' + return self.pvrart.get_pvr_artwork( + title, channel, genre, manual_select=manual_select, ignore_cache=ignore_cache) + + def pvr_artwork_options(self, title, channel="", genre=""): + '''options for pvr metadata for specific item''' + return self.pvrart.pvr_artwork_options(title, channel, genre) + + def get_channellogo(self, channelname): + '''get channellogo for the given channel name''' + return self.channellogos.get_channellogo(channelname) + + def get_studio_logo(self, studio): + '''get studio logo for the given studio''' + # dont use cache at this level because of changing logospath + return self.studiologos.get_studio_logo(studio, self.studiologos_path) + + @property + def studiologos_path(self): + '''path to use to lookup studio logos, must be set by the calling addon''' + return self._studiologos_path + + @studiologos_path.setter + def studiologos_path(self, value): + '''path to use to lookup studio logos, must be set by the calling addon''' + self._studiologos_path = value + + def get_animated_artwork(self, imdb_id, manual_select=False, ignore_cache=False): + '''get animated artwork, perform extra check if local version still exists''' + artwork = self.animatedart.get_animated_artwork( + imdb_id, manual_select=manual_select, ignore_cache=ignore_cache) + if not (manual_select or ignore_cache): + refresh_needed = False + if artwork.get("animatedposter") and not xbmcvfs.exists( + artwork["animatedposter"]): + refresh_needed = True + if artwork.get("animatedfanart") and not xbmcvfs.exists( + artwork["animatedfanart"]): + refresh_needed = True + + return {"art": artwork} + + @use_cache(90) + def get_omdb_info(self, imdb_id="", title="", year="", content_type=""): + '''Get (kodi compatible formatted) metadata from OMDB, including Rotten tomatoes details''' + title = title.split(" (")[0] # strip year appended to title + 
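The methods above and the `omdb`/`tmdb`/`fanarttv` properties further below all share one deferred-import pattern: import the heavy helper only on first use, then keep the instance. A minimal generic sketch of that pattern (class name and constructor argument hypothetical):

```python
class MetadataHelpers(object):
    """cut-down illustration of the lazy-loading pattern used in this class"""
    _tmdb = None

    @property
    def tmdb(self):
        # the import only happens the first time the helper is touched,
        # which keeps add-on startup cheap; afterwards the instance is reused
        if not self._tmdb:
            from helpers.tmdb import Tmdb  # same deferred-import style as above
            self._tmdb = Tmdb(None)  # the real code passes self.cache here
        return self._tmdb
```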
result = {} + if imdb_id: + result = self.omdb.get_details_by_imdbid(imdb_id) + elif title and content_type in ["seasons", "season", "episodes", "episode", "tvshows", "tvshow"]: + result = self.omdb.get_details_by_title(title, "", "tvshows") + elif title and year: + result = self.omdb.get_details_by_title(title, year, content_type) + if result and result.get("status"): + result["status"] = self.translate_string(result["status"]) + if result and result.get("runtime"): + result["runtime"] = result["runtime"] / 60 + result.update(self.get_duration(result["runtime"])) + return result + + def get_top250_rating(self, imdb_id): + '''get the position in the IMDB top250 for the given IMDB ID''' + return self.imdb.get_top250_rating(imdb_id) + + @use_cache(14) + def get_duration(self, duration): + '''helper to get a formatted duration''' + if not self._get_duration: + from helpers.utils import get_duration + self._get_duration = get_duration + if sys.version_info.major == 3: + if isinstance(duration, str) and ":" in duration: + dur_lst = duration.split(":") + return { + "Duration": "%s:%s" % (dur_lst[0], dur_lst[1]), + "Duration.Hours": dur_lst[0], + "Duration.Minutes": dur_lst[1], + "Runtime": int(dur_lst[0]) * 60 + int(dur_lst[1]), + } + else: + return self._get_duration(duration) + else: + if isinstance(duration, (str, unicode)) and ":" in duration: + dur_lst = duration.split(":") + return { + "Duration": "%s:%s" % (dur_lst[0], dur_lst[1]), + "Duration.Hours": dur_lst[0], + "Duration.Minutes": dur_lst[1], + "Runtime": str((int(dur_lst[0]) * 60) + int(dur_lst[1])), + } + else: + return self._get_duration(duration) + + @use_cache(2) + def get_tvdb_details(self, imdbid="", tvdbid=""): + '''get metadata from tvdb by providing a tvdbid or tmdbid''' + result = {} + self.thetvdb.days_ahead = 365 + if not tvdbid and imdbid and not imdbid.startswith("tt"): + # assume imdbid is actually a tvdbid... 
+ tvdbid = imdbid + if tvdbid: + result = self.thetvdb.get_series(tvdbid) + elif imdbid: + result = self.thetvdb.get_series_by_imdb_id(imdbid) + if result: + if result["status"] == "Continuing": + # include next episode info + result["nextepisode"] = self.thetvdb.get_nextaired_episode(result["tvdb_id"]) + # include last episode info + result["lastepisode"] = self.thetvdb.get_last_episode_for_series(result["tvdb_id"]) + result["status"] = self.translate_string(result["status"]) + if result.get("runtime"): + result["runtime"] = result["runtime"] / 60 + result.update(self.get_duration(result["runtime"])) + return result + + @use_cache(90) + def get_imdbtvdb_id(self, title, content_type, year="", imdbid="", tvshowtitle=""): + '''try to figure out the imdbnumber and/or tvdbid''' + tvdbid = "" + if content_type in ["seasons", "episodes"] or tvshowtitle: + title = tvshowtitle + content_type = "tvshows" + if imdbid and not imdbid.startswith("tt"): + if content_type in ["tvshows", "seasons", "episodes"]: + tvdbid = imdbid + imdbid = "" + if not imdbid and year: + omdb_info = self.get_omdb_info("", title, year, content_type) + if omdb_info: + imdbid = omdb_info.get("imdbnumber", "") + if not imdbid: + # repeat without year + omdb_info = self.get_omdb_info("", title, "", content_type) + if omdb_info: + imdbid = omdb_info.get("imdbnumber", "") + # return results + return (imdbid, tvdbid) + + def translate_string(self, _str): + '''translate the received english string from the various sources like tvdb, tmbd etc''' + translation = _str + _str = _str.lower() + if "continuing" in _str: + translation = self.addon.getLocalizedString(32037) + elif "ended" in _str: + translation = self.addon.getLocalizedString(32038) + elif "released" in _str: + translation = self.addon.getLocalizedString(32040) + return translation + + def process_method_on_list(self, *args, **kwargs): + '''expose our process_method_on_list method to public''' + if not self._process_method_on_list: + from helpers.utils import process_method_on_list + self._process_method_on_list = process_method_on_list + return self._process_method_on_list(*args, **kwargs) + + def detect_plugin_content(self, *args, **kwargs): + '''expose our detect_plugin_content method to public''' + if not self._detect_plugin_content: + from helpers.utils import detect_plugin_content + self._detect_plugin_content = detect_plugin_content + return self._detect_plugin_content(*args, **kwargs) + + def extend_dict(self, *args, **kwargs): + '''expose our extend_dict method to public''' + if not self._extend_dict: + from helpers.utils import extend_dict + self._extend_dict = extend_dict + return self._extend_dict(*args, **kwargs) + + def get_clean_image(self, *args, **kwargs): + '''expose our get_clean_image method to public''' + if not self._get_clean_image: + from helpers.utils import get_clean_image + self._get_clean_image = get_clean_image + return self._get_clean_image(*args, **kwargs) + + @property + def omdb(self): + '''public omdb object - for lazy loading''' + if not self._omdb: + from helpers.omdb import Omdb + self._omdb = Omdb(self.cache) + return self._omdb + + @property + def kodidb(self): + '''public kodidb object - for lazy loading''' + if not self._kodidb: + from helpers.kodidb import KodiDb + self._kodidb = KodiDb() + return self._kodidb + + @property + def tmdb(self): + '''public Tmdb object - for lazy loading''' + if not self._tmdb: + from helpers.tmdb import Tmdb + self._tmdb = Tmdb(self.cache) + return self._tmdb + + @property + def fanarttv(self): + 
'''public FanartTv object - for lazy loading''' + if not self._fanarttv: + from helpers.fanarttv import FanartTv + self._fanarttv = FanartTv(self.cache) + return self._fanarttv + + @property + def channellogos(self): + '''public ChannelLogos object - for lazy loading''' + if not self._channellogos: + from helpers.channellogos import ChannelLogos + self._channellogos = ChannelLogos(self.kodidb) + return self._channellogos + + @property + def imdb(self): + '''public Imdb object - for lazy loading''' + if not self._imdb: + from helpers.imdb import Imdb + self._imdb = Imdb(self.cache) + return self._imdb + + @property + def google(self): + '''public GoogleImages object - for lazy loading''' + if not self._google: + from helpers.google import GoogleImages + self._google = GoogleImages(self.cache) + return self._google + + @property + def studiologos(self): + '''public StudioLogos object - for lazy loading''' + if not self._studiologos: + from helpers.studiologos import StudioLogos + self._studiologos = StudioLogos(self.cache) + return self._studiologos + + @property + def animatedart(self): + '''public AnimatedArt object - for lazy loading''' + if not self._animatedart: + from helpers.animatedart import AnimatedArt + self._animatedart = AnimatedArt(self.cache, self.kodidb) + return self._animatedart + + @property + def thetvdb(self): + '''public TheTvDb object - for lazy loading''' + if not self._thetvdb: + from thetvdb import TheTvDb + self._thetvdb = TheTvDb() + return self._thetvdb + + @property + def musicart(self): + '''public MusicArtwork object - for lazy loading''' + if not self._musicart: + from helpers.musicartwork import MusicArtwork + self._musicart = MusicArtwork(self) + return self._musicart + + @property + def pvrart(self): + '''public PvrArtwork object - for lazy loading''' + if not self._pvrart: + from helpers.pvrartwork import PvrArtwork + self._pvrart = PvrArtwork(self) + return self._pvrart + + @property + def addon(self): + '''public Addon object - for lazy loading''' + if not self._addon: + import xbmcaddon + self._addon = xbmcaddon.Addon(ADDON_ID) + return self._addon + + @property + def lastfm(self): + '''public LastFM object - for lazy loading''' + if not self._lastfm: + from helpers.lastfm import LastFM + self._lastfm = LastFM() + return self._lastfm + + @property + def audiodb(self): + '''public TheAudioDb object - for lazy loading''' + if not self._audiodb: + from helpers.theaudiodb import TheAudioDb + self._audiodb = TheAudioDb() + return self._audiodb + + def close(self): + '''Cleanup instances''' + self._close_called = True + if self.cache: + self.cache.close() + del self.cache + if self._addon: + del self._addon + if self._thetvdb: + del self._thetvdb + log_msg("Exited") + + def __del__(self): + '''make sure close is called''' + if not self._close_called: + self.close() diff --git a/resources/lib/more_itertools/__init__.py b/resources/lib/more_itertools/__init__.py new file mode 100644 index 0000000..7921b1c --- /dev/null +++ b/resources/lib/more_itertools/__init__.py @@ -0,0 +1,4 @@ +from .more import * # noqa +from .recipes import * # noqa + +__version__ = '8.3.0' diff --git a/resources/lib/more_itertools/more.py b/resources/lib/more_itertools/more.py new file mode 100644 index 0000000..c338ecc --- /dev/null +++ b/resources/lib/more_itertools/more.py @@ -0,0 +1,3143 @@ +import warnings +from collections import Counter, defaultdict, deque, abc +from collections.abc import Sequence +from functools import partial, wraps +from heapq import merge, heapify, 
heapreplace, heappop +from itertools import ( + chain, + compress, + count, + cycle, + dropwhile, + groupby, + islice, + repeat, + starmap, + takewhile, + tee, + zip_longest, +) +from math import exp, floor, log +from random import random, randrange, uniform +from operator import itemgetter, sub, gt, lt +from sys import maxsize +from time import monotonic + +from .recipes import consume, flatten, powerset, take, unique_everseen + +__all__ = [ + 'adjacent', + 'always_iterable', + 'always_reversible', + 'bucket', + 'chunked', + 'circular_shifts', + 'collapse', + 'collate', + 'consecutive_groups', + 'consumer', + 'count_cycle', + 'difference', + 'distinct_combinations', + 'distinct_permutations', + 'distribute', + 'divide', + 'exactly_n', + 'filter_except', + 'first', + 'groupby_transform', + 'ilen', + 'interleave_longest', + 'interleave', + 'intersperse', + 'islice_extended', + 'iterate', + 'ichunked', + 'last', + 'locate', + 'lstrip', + 'make_decorator', + 'map_except', + 'map_reduce', + 'nth_or_last', + 'numeric_range', + 'one', + 'only', + 'padded', + 'partitions', + 'set_partitions', + 'peekable', + 'repeat_last', + 'replace', + 'rlocate', + 'rstrip', + 'run_length', + 'sample', + 'seekable', + 'SequenceView', + 'side_effect', + 'sliced', + 'sort_together', + 'split_at', + 'split_after', + 'split_before', + 'split_when', + 'split_into', + 'spy', + 'stagger', + 'strip', + 'substrings', + 'substrings_indexes', + 'time_limited', + 'unique_to_each', + 'unzip', + 'windowed', + 'with_iter', + 'UnequalIterablesError', + 'zip_equal', + 'zip_offset', +] + +_marker = object() + + +def chunked(iterable, n): + """Break *iterable* into lists of length *n*: + + >>> list(chunked([1, 2, 3, 4, 5, 6], 3)) + [[1, 2, 3], [4, 5, 6]] + + If the length of *iterable* is not evenly divisible by *n*, the last + returned list will be shorter: + + >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3)) + [[1, 2, 3], [4, 5, 6], [7, 8]] + + To use a fill-in value instead, see the :func:`grouper` recipe. + + :func:`chunked` is useful for splitting up a computation on a large number + of keys into batches, to be pickled and sent off to worker processes. One + example is operations on rows in MySQL, which does not implement + server-side cursors properly and would otherwise load the entire dataset + into RAM on the client. + + """ + return iter(partial(take, n, iter(iterable)), []) + + +def first(iterable, default=_marker): + """Return the first item of *iterable*, or *default* if *iterable* is + empty. + + >>> first([0, 1, 2, 3]) + 0 + >>> first([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + + :func:`first` is useful when you have a generator of expensive-to-retrieve + values and want any arbitrary one. It is marginally shorter than + ``next(iter(iterable), default)``. + + """ + try: + return next(iter(iterable)) + except StopIteration: + # I'm on the edge about raising ValueError instead of StopIteration. At + # the moment, ValueError wins, because the caller could conceivably + # want to do something different with flow control when I raise the + # exception, and it's weird to explicitly catch StopIteration. + if default is _marker: + raise ValueError( + 'first() was called on an empty iterable, and no ' + 'default value was provided.' + ) + return default + + +def last(iterable, default=_marker): + """Return the last item of *iterable*, or *default* if *iterable* is + empty. 
+ + >>> last([0, 1, 2, 3]) + 3 + >>> last([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + try: + try: + # Try to access the last item directly + return iterable[-1] + except (TypeError, AttributeError, KeyError): + # If not slice-able, iterate entirely using length-1 deque + return deque(iterable, maxlen=1)[0] + except IndexError: # If the iterable was empty + if default is _marker: + raise ValueError( + 'last() was called on an empty iterable, and no ' + 'default value was provided.' + ) + return default + + +def nth_or_last(iterable, n, default=_marker): + """Return the nth or the last item of *iterable*, + or *default* if *iterable* is empty. + + >>> nth_or_last([0, 1, 2, 3], 2) + 2 + >>> nth_or_last([0, 1], 2) + 1 + >>> nth_or_last([], 0, 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + return last(islice(iterable, n + 1), default=default) + + +class peekable: + """Wrap an iterator to allow lookahead and prepending elements. + + Call :meth:`peek` on the result to get the value that will be returned + by :func:`next`. This won't advance the iterator: + + >>> p = peekable(['a', 'b']) + >>> p.peek() + 'a' + >>> next(p) + 'a' + + Pass :meth:`peek` a default value to return that instead of raising + ``StopIteration`` when the iterator is exhausted. + + >>> p = peekable([]) + >>> p.peek('hi') + 'hi' + + peekables also offer a :meth:`prepend` method, which "inserts" items + at the head of the iterable: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> p.peek() + 11 + >>> list(p) + [11, 12, 1, 2, 3] + + peekables can be indexed. Index 0 is the item that will be returned by + :func:`next`, index 1 is the item after that, and so on: + The values up to the given index will be cached. + + >>> p = peekable(['a', 'b', 'c', 'd']) + >>> p[0] + 'a' + >>> p[1] + 'b' + >>> next(p) + 'a' + + Negative indexes are supported, but be aware that they will cache the + remaining items in the source iterator, which may require significant + storage. + + To check whether a peekable is exhausted, check its truth value: + + >>> p = peekable(['a', 'b']) + >>> if p: # peekable has items + ... list(p) + ['a', 'b'] + >>> if not p: # peekable is exhaused + ... list(p) + [] + + """ + + def __init__(self, iterable): + self._it = iter(iterable) + self._cache = deque() + + def __iter__(self): + return self + + def __bool__(self): + try: + self.peek() + except StopIteration: + return False + return True + + def peek(self, default=_marker): + """Return the item that will be next returned from ``next()``. + + Return ``default`` if there are no items left. If ``default`` is not + provided, raise ``StopIteration``. + + """ + if not self._cache: + try: + self._cache.append(next(self._it)) + except StopIteration: + if default is _marker: + raise + return default + return self._cache[0] + + def prepend(self, *items): + """Stack up items to be the next ones returned from ``next()`` or + ``self.peek()``. The items will be returned in + first in, first out order:: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> list(p) + [11, 12, 1, 2, 3] + + It is possible, by prepending items, to "resurrect" a peekable that + previously raised ``StopIteration``. + + >>> p = peekable([]) + >>> next(p) + Traceback (most recent call last): + ... 
+ StopIteration + >>> p.prepend(1) + >>> next(p) + 1 + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + + """ + self._cache.extendleft(reversed(items)) + + def __next__(self): + if self._cache: + return self._cache.popleft() + + return next(self._it) + + def _get_slice(self, index): + # Normalize the slice's arguments + step = 1 if (index.step is None) else index.step + if step > 0: + start = 0 if (index.start is None) else index.start + stop = maxsize if (index.stop is None) else index.stop + elif step < 0: + start = -1 if (index.start is None) else index.start + stop = (-maxsize - 1) if (index.stop is None) else index.stop + else: + raise ValueError('slice step cannot be zero') + + # If either the start or stop index is negative, we'll need to cache + # the rest of the iterable in order to slice from the right side. + if (start < 0) or (stop < 0): + self._cache.extend(self._it) + # Otherwise we'll need to find the rightmost index and cache to that + # point. + else: + n = min(max(start, stop) + 1, maxsize) + cache_len = len(self._cache) + if n >= cache_len: + self._cache.extend(islice(self._it, n - cache_len)) + + return list(self._cache)[index] + + def __getitem__(self, index): + if isinstance(index, slice): + return self._get_slice(index) + + cache_len = len(self._cache) + if index < 0: + self._cache.extend(self._it) + elif index >= cache_len: + self._cache.extend(islice(self._it, index + 1 - cache_len)) + + return self._cache[index] + + +def collate(*iterables, **kwargs): + """Return a sorted merge of the items from each of several already-sorted + *iterables*. + + >>> list(collate('ACDZ', 'AZ', 'JKL')) + ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z'] + + Works lazily, keeping only the next value from each iterable in memory. Use + :func:`collate` to, for example, perform a n-way mergesort of items that + don't fit in memory. + + If a *key* function is specified, the iterables will be sorted according + to its result: + + >>> key = lambda s: int(s) # Sort by numeric value, not by string + >>> list(collate(['1', '10'], ['2', '11'], key=key)) + ['1', '2', '10', '11'] + + + If the *iterables* are sorted in descending order, set *reverse* to + ``True``: + + >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True)) + [5, 4, 3, 2, 1, 0] + + If the elements of the passed-in iterables are out of order, you might get + unexpected results. + + On Python 3.5+, this function is an alias for :func:`heapq.merge`. + + """ + warnings.warn( + "collate is no longer part of more_itertools, use heapq.merge", + DeprecationWarning, + ) + return merge(*iterables, **kwargs) + + +def consumer(func): + """Decorator that automatically advances a PEP-342-style "reverse iterator" + to its first yield point so you don't have to call ``next()`` on it + manually. + + >>> @consumer + ... def tally(): + ... i = 0 + ... while True: + ... print('Thing number %s is %s.' % (i, (yield))) + ... i += 1 + ... + >>> t = tally() + >>> t.send('red') + Thing number 0 is red. + >>> t.send('fish') + Thing number 1 is fish. + + Without the decorator, you would have to call ``next(t)`` before + ``t.send()`` could be used. + + """ + + @wraps(func) + def wrapper(*args, **kwargs): + gen = func(*args, **kwargs) + next(gen) + return gen + + return wrapper + + +def ilen(iterable): + """Return the number of items in *iterable*. + + >>> ilen(x for x in range(1000000) if x % 3 == 0) + 333334 + + This consumes the iterable, so handle with care. 
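+
+    An empty iterable counts as zero (derivable from the implementation
+    below, since the counter is simply never advanced):
+
+    >>> ilen(iter([]))
+    0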
+ + """ + # This approach was selected because benchmarks showed it's likely the + # fastest of the known implementations at the time of writing. + # See GitHub tracker: #236, #230. + counter = count() + deque(zip(iterable, counter), maxlen=0) + return next(counter) + + +def iterate(func, start): + """Return ``start``, ``func(start)``, ``func(func(start))``, ... + + >>> from itertools import islice + >>> list(islice(iterate(lambda x: 2*x, 1), 10)) + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + + """ + while True: + yield start + start = func(start) + + +def with_iter(context_manager): + """Wrap an iterable in a ``with`` statement, so it closes once exhausted. + + For example, this will close the file when the iterator is exhausted:: + + upper_lines = (line.upper() for line in with_iter(open('foo'))) + + Any context manager which returns an iterable is a candidate for + ``with_iter``. + + """ + with context_manager as iterable: + yield from iterable + + +def one(iterable, too_short=None, too_long=None): + """Return the first item from *iterable*, which is expected to contain only + that item. Raise an exception if *iterable* is empty or has more than one + item. + + :func:`one` is useful for ensuring that an iterable contains only one item. + For example, it can be used to retrieve the result of a database query + that is expected to return a single row. + + If *iterable* is empty, ``ValueError`` will be raised. You may specify a + different exception with the *too_short* keyword: + + >>> it = [] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (expected 1)' + >>> too_short = IndexError('too few items') + >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + IndexError: too few items + + Similarly, if *iterable* contains more than one item, ``ValueError`` will + be raised. You may specify a different exception with the *too_long* + keyword: + + >>> it = ['too', 'many'] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 'too', + 'many', and perhaps more. + >>> too_long = RuntimeError + >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + RuntimeError + + Note that :func:`one` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check iterable + contents less destructively. + + """ + it = iter(iterable) + + try: + first_value = next(it) + except StopIteration: + raise too_short or ValueError('too few items in iterable (expected 1)') + + try: + second_value = next(it) + except StopIteration: + pass + else: + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value + + +def distinct_permutations(iterable, r=None): + """Yield successive distinct permutations of the elements in *iterable*. + + >>> sorted(distinct_permutations([1, 0, 1])) + [(0, 1, 1), (1, 0, 1), (1, 1, 0)] + + Equivalent to ``set(permutations(iterable))``, except duplicates are not + generated and thrown away. For larger input sequences this is much more + efficient. + + Duplicate permutations arise when there are duplicated elements in the + input iterable. The number of items returned is + `n! / (x_1! * x_2! * ... 
* x_n!)`, where `n` is the total number of + items input, and each `x_i` is the count of a distinct item in the input + sequence. + + If *r* is given, only the *r*-length permutations are yielded. + + >>> sorted(distinct_permutations([1, 0, 1], r=3)) + [(0, 1, 1), (1, 0, 1), (1, 1, 0)] + + """ + # Algorithm: https://w.wiki/Qai + def _full(A): + while True: + # Yield the permutation we have + yield tuple(A) + + # Find the largest index i such that A[i] < A[i + 1] + for i in range(size - 2, -1, -1): + if A[i] < A[i + 1]: + break + # If no such index exists, this permutation is the last one + else: + return + + # Find the largest index j greater than j such that A[i] < A[j] + for j in range(size - 1, i, -1): + if A[i] < A[j]: + break + + # Swap the value of A[i] with that of A[j], then reverse the + # sequence from A[i + 1] to form the new permutation + A[i], A[j] = A[j], A[i] + A[i + 1:] = A[:i - size:-1] # A[i + 1:][::-1] + + # Algorithm: modified from the above + def _partial(A, r): + # Split A into the first r items and the last r items + head, tail = A[:r], A[r:] + right_head_indexes = range(r - 1, -1, -1) + left_tail_indexes = range(len(tail)) + + while True: + # Yield the permutation we have + yield tuple(head) + + # Starting from the right, find the first index of the head with + # value smaller than the maximum value of the tail - call it i. + pivot = tail[-1] + for i in right_head_indexes: + if head[i] < pivot: + break + pivot = head[i] + else: + return + + # Starting from the left, find the first value of the tail + # with a value greater than head[i] and swap. + for j in left_tail_indexes: + if tail[j] > head[i]: + head[i], tail[j] = tail[j], head[i] + break + # If we didn't find one, start from the right and find the first + # index of the head with a value greater than head[i] and swap. + else: + for j in right_head_indexes: + if head[j] > head[i]: + head[i], head[j] = head[j], head[i] + break + + # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)] + tail += head[:i - r:-1] # head[i + 1:][::-1] + i += 1 + head[i:], tail[:] = tail[:r - i], tail[r - i:] + + items = sorted(iterable) + + size = len(items) + if r is None: + r = size + + if 0 < r <= size: + return _full(items) if (r == size) else _partial(items, r) + + return iter(() if r else ((),)) + + +def intersperse(e, iterable, n=1): + """Intersperse filler element *e* among the items in *iterable*, leaving + *n* items between each filler element. + + >>> list(intersperse('!', [1, 2, 3, 4, 5])) + [1, '!', 2, '!', 3, '!', 4, '!', 5] + + >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2)) + [1, 2, None, 3, 4, None, 5] + + """ + if n == 0: + raise ValueError('n must be > 0') + elif n == 1: + # interleave(repeat(e), iterable) -> e, x_0, e, e, x_1, e, x_2... + # islice(..., 1, None) -> x_0, e, e, x_1, e, x_2... + return islice(interleave(repeat(e), iterable), 1, None) + else: + # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]... + # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]... + # flatten(...) -> x_0, x_1, e, x_2, x_3... + filler = repeat([e]) + chunks = chunked(iterable, n) + return flatten(islice(interleave(filler, chunks), 1, None)) + + +def unique_to_each(*iterables): + """Return the elements from each of the input iterables that aren't in the + other input iterables. + + For example, suppose you have a set of packages, each with a set of + dependencies:: + + {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}} + + If you remove one package, which dependencies can also be removed? 
+ + If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not + associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for + ``pkg_2``, and ``D`` is only needed for ``pkg_3``:: + + >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'}) + [['A'], ['C'], ['D']] + + If there are duplicates in one input iterable that aren't in the others + they will be duplicated in the output. Input order is preserved:: + + >>> unique_to_each("mississippi", "missouri") + [['p', 'p'], ['o', 'u', 'r']] + + It is assumed that the elements of each iterable are hashable. + + """ + pool = [list(it) for it in iterables] + counts = Counter(chain.from_iterable(map(set, pool))) + uniques = {element for element in counts if counts[element] == 1} + return [list(filter(uniques.__contains__, it)) for it in pool] + + +def windowed(seq, n, fillvalue=None, step=1): + """Return a sliding window of width *n* over the given iterable. + + >>> all_windows = windowed([1, 2, 3, 4, 5], 3) + >>> list(all_windows) + [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + + When the window is larger than the iterable, *fillvalue* is used in place + of missing values: + + >>> list(windowed([1, 2, 3], 4)) + [(1, 2, 3, None)] + + Each window will advance in increments of *step*: + + >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2)) + [(1, 2, 3), (3, 4, 5), (5, 6, '!')] + + To slide into the iterable's items, use :func:`chain` to add filler items + to the left: + + >>> iterable = [1, 2, 3, 4] + >>> n = 3 + >>> padding = [None] * (n - 1) + >>> list(windowed(chain(padding, iterable), 3)) + [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)] + """ + if n < 0: + raise ValueError('n must be >= 0') + if n == 0: + yield tuple() + return + if step < 1: + raise ValueError('step must be >= 1') + + window = deque(maxlen=n) + i = n + for _ in map(window.append, seq): + i -= 1 + if not i: + i = step + yield tuple(window) + + size = len(window) + if size < n: + yield tuple(chain(window, repeat(fillvalue, n - size))) + elif 0 < i < min(step, n): + window += (fillvalue,) * i + yield tuple(window) + + +def substrings(iterable): + """Yield all of the substrings of *iterable*. + + >>> [''.join(s) for s in substrings('more')] + ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more'] + + Note that non-string iterables can also be subdivided. + + >>> list(substrings([0, 1, 2])) + [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)] + + """ + # The length-1 substrings + seq = [] + for item in iter(iterable): + seq.append(item) + yield (item,) + seq = tuple(seq) + item_count = len(seq) + + # And the rest + for n in range(2, item_count + 1): + for i in range(item_count - n + 1): + yield seq[i : i + n] + + +def substrings_indexes(seq, reverse=False): + """Yield all substrings and their positions in *seq* + + The items yielded will be a tuple of the form ``(substr, i, j)``, where + ``substr == seq[i:j]``. + + This function only works for iterables that support slicing, such as + ``str`` objects. + + >>> for item in substrings_indexes('more'): + ... print(item) + ('m', 0, 1) + ('o', 1, 2) + ('r', 2, 3) + ('e', 3, 4) + ('mo', 0, 2) + ('or', 1, 3) + ('re', 2, 4) + ('mor', 0, 3) + ('ore', 1, 4) + ('more', 0, 4) + + Set *reverse* to ``True`` to yield the same items in the opposite order. 
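+
+    For example, with *reverse* the longest substring comes first:
+
+    >>> next(substrings_indexes('more', reverse=True))
+    ('more', 0, 4)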
+ + + """ + r = range(1, len(seq) + 1) + if reverse: + r = reversed(r) + return ( + (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1) + ) + + +class bucket: + """Wrap *iterable* and return an object that buckets it iterable into + child iterables based on a *key* function. + + >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] + >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character + >>> sorted(list(s)) # Get the keys + ['a', 'b', 'c'] + >>> a_iterable = s['a'] + >>> next(a_iterable) + 'a1' + >>> next(a_iterable) + 'a2' + >>> list(s['b']) + ['b1', 'b2', 'b3'] + + The original iterable will be advanced and its items will be cached until + they are used by the child iterables. This may require significant storage. + + By default, attempting to select a bucket to which no items belong will + exhaust the iterable and cache all values. + If you specify a *validator* function, selected buckets will instead be + checked against it. + + >>> from itertools import count + >>> it = count(1, 2) # Infinite sequence of odd numbers + >>> key = lambda x: x % 10 # Bucket by last digit + >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only + >>> s = bucket(it, key=key, validator=validator) + >>> 2 in s + False + >>> list(s[2]) + [] + + """ + + def __init__(self, iterable, key, validator=None): + self._it = iter(iterable) + self._key = key + self._cache = defaultdict(deque) + self._validator = validator or (lambda x: True) + + def __contains__(self, value): + if not self._validator(value): + return False + + try: + item = next(self[value]) + except StopIteration: + return False + else: + self._cache[value].appendleft(item) + + return True + + def _get_values(self, value): + """ + Helper to yield items from the parent iterator that match *value*. + Items that don't match are stored in the local cache as they + are encountered. + """ + while True: + # If we've cached some items that match the target value, emit + # the first one and evict it from the cache. + if self._cache[value]: + yield self._cache[value].popleft() + # Otherwise we need to advance the parent iterator to search for + # a matching item, caching the rest. + else: + while True: + try: + item = next(self._it) + except StopIteration: + return + item_value = self._key(item) + if item_value == value: + yield item + break + elif self._validator(item_value): + self._cache[item_value].append(item) + + def __iter__(self): + for item in self._it: + item_value = self._key(item) + if self._validator(item_value): + self._cache[item_value].append(item) + + yield from self._cache.keys() + + def __getitem__(self, value): + if not self._validator(value): + return iter(()) + + return self._get_values(value) + + +def spy(iterable, n=1): + """Return a 2-tuple with a list containing the first *n* elements of + *iterable*, and an iterator with the same items as *iterable*. + This allows you to "look ahead" at the items in the iterable without + advancing it. 
+ + There is one item in the list by default: + + >>> iterable = 'abcdefg' + >>> head, iterable = spy(iterable) + >>> head + ['a'] + >>> list(iterable) + ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + + You may use unpacking to retrieve items instead of lists: + + >>> (head,), iterable = spy('abcdefg') + >>> head + 'a' + >>> (first, second), iterable = spy('abcdefg', 2) + >>> first + 'a' + >>> second + 'b' + + The number of items requested can be larger than the number of items in + the iterable: + + >>> iterable = [1, 2, 3, 4, 5] + >>> head, iterable = spy(iterable, 10) + >>> head + [1, 2, 3, 4, 5] + >>> list(iterable) + [1, 2, 3, 4, 5] + + """ + it = iter(iterable) + head = take(n, it) + + return head.copy(), chain(head, it) + + +def interleave(*iterables): + """Return a new iterable yielding from each iterable in turn, + until the shortest is exhausted. + + >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7] + + For a version that doesn't terminate after the shortest iterable is + exhausted, see :func:`interleave_longest`. + + """ + return chain.from_iterable(zip(*iterables)) + + +def interleave_longest(*iterables): + """Return a new iterable yielding from each iterable in turn, + skipping any that are exhausted. + + >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7, 3, 8] + + This function produces the same output as :func:`roundrobin`, but may + perform better for some inputs (in particular when the number of iterables + is large). + + """ + i = chain.from_iterable(zip_longest(*iterables, fillvalue=_marker)) + return (x for x in i if x is not _marker) + + +def collapse(iterable, base_type=None, levels=None): + """Flatten an iterable with multiple levels of nesting (e.g., a list of + lists of tuples) into non-iterable types. + + >>> iterable = [(1, 2), ([3, 4], [[5], [6]])] + >>> list(collapse(iterable)) + [1, 2, 3, 4, 5, 6] + + Binary and text strings are not considered iterable and + will not be collapsed. + + To avoid collapsing other types, specify *base_type*: + + >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']] + >>> list(collapse(iterable, base_type=tuple)) + ['ab', ('cd', 'ef'), 'gh', 'ij'] + + Specify *levels* to stop flattening after a certain level: + + >>> iterable = [('a', ['b']), ('c', ['d'])] + >>> list(collapse(iterable)) # Fully flattened + ['a', 'b', 'c', 'd'] + >>> list(collapse(iterable, levels=1)) # Only one level flattened + ['a', ['b'], 'c', ['d']] + + """ + + def walk(node, level): + if ( + ((levels is not None) and (level > levels)) + or isinstance(node, (str, bytes)) + or ((base_type is not None) and isinstance(node, base_type)) + ): + yield node + return + + try: + tree = iter(node) + except TypeError: + yield node + return + else: + for child in tree: + yield from walk(child, level + 1) + + yield from walk(iterable, 0) + + +def side_effect(func, iterable, chunk_size=None, before=None, after=None): + """Invoke *func* on each item in *iterable* (or on each *chunk_size* group + of items) before yielding the item. + + `func` must be a function that takes a single argument. Its return value + will be discarded. + + *before* and *after* are optional functions that take no arguments. They + will be executed before iteration starts and after it ends, respectively. + + `side_effect` can be used for logging, updating progress bars, or anything + that is not functionally "pure." 
+ + Emitting a status message: + + >>> from more_itertools import consume + >>> func = lambda item: print('Received {}'.format(item)) + >>> consume(side_effect(func, range(2))) + Received 0 + Received 1 + + Operating on chunks of items: + + >>> pair_sums = [] + >>> func = lambda chunk: pair_sums.append(sum(chunk)) + >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2)) + [0, 1, 2, 3, 4, 5] + >>> list(pair_sums) + [1, 5, 9] + + Writing to a file-like object: + + >>> from io import StringIO + >>> from more_itertools import consume + >>> f = StringIO() + >>> func = lambda x: print(x, file=f) + >>> before = lambda: print(u'HEADER', file=f) + >>> after = f.close + >>> it = [u'a', u'b', u'c'] + >>> consume(side_effect(func, it, before=before, after=after)) + >>> f.closed + True + + """ + try: + if before is not None: + before() + + if chunk_size is None: + for item in iterable: + func(item) + yield item + else: + for chunk in chunked(iterable, chunk_size): + func(chunk) + yield from chunk + finally: + if after is not None: + after() + + +def sliced(seq, n): + """Yield slices of length *n* from the sequence *seq*. + + >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) + [(1, 2, 3), (4, 5, 6)] + + If the length of the sequence is not divisible by the requested slice + length, the last slice will be shorter. + + >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) + [(1, 2, 3), (4, 5, 6), (7, 8)] + + This function will only work for iterables that support slicing. + For non-sliceable iterables, see :func:`chunked`. + + """ + return takewhile(len, (seq[i : i + n] for i in count(0, n))) + + +def split_at(iterable, pred, maxsplit=-1, keep_separator=False): + """Yield lists of items from *iterable*, where each list is delimited by + an item where callable *pred* returns ``True``. + + >>> list(split_at('abcdcba', lambda x: x == 'b')) + [['a'], ['c', 'd', 'c'], ['a']] + + >>> list(split_at(range(10), lambda n: n % 2 == 1)) + [[0], [2], [4], [6], [8], []] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2)) + [[0], [2], [4, 5, 6, 7, 8, 9]] + + By default, the delimiting items are not included in the output. + The include them, set *keep_separator* to ``True``. + + >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True)) + [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']] + + """ + if maxsplit == 0: + yield list(iterable) + return + + buf = [] + it = iter(iterable) + for item in it: + if pred(item): + yield buf + if keep_separator: + yield [item] + if maxsplit == 1: + yield list(it) + return + buf = [] + maxsplit -= 1 + else: + buf.append(item) + yield buf + + +def split_before(iterable, pred, maxsplit=-1): + """Yield lists of items from *iterable*, where each list ends just before + an item for which callable *pred* returns ``True``: + + >>> list(split_before('OneTwo', lambda s: s.isupper())) + [['O', 'n', 'e'], ['T', 'w', 'o']] + + >>> list(split_before(range(10), lambda n: n % 3 == 0)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + + At most *maxsplit* splits are done. 
If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] + """ + if maxsplit == 0: + yield list(iterable) + return + + buf = [] + it = iter(iterable) + for item in it: + if pred(item) and buf: + yield buf + if maxsplit == 1: + yield [item] + list(it) + return + buf = [] + maxsplit -= 1 + buf.append(item) + yield buf + + +def split_after(iterable, pred, maxsplit=-1): + """Yield lists of items from *iterable*, where each list ends with an + item where callable *pred* returns ``True``: + + >>> list(split_after('one1two2', lambda s: s.isdigit())) + [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']] + + >>> list(split_after(range(10), lambda n: n % 3 == 0)) + [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]] + + """ + if maxsplit == 0: + yield list(iterable) + return + + buf = [] + it = iter(iterable) + for item in it: + buf.append(item) + if pred(item) and buf: + yield buf + if maxsplit == 1: + yield list(it) + return + buf = [] + maxsplit -= 1 + if buf: + yield buf + + +def split_when(iterable, pred, maxsplit=-1): + """Split *iterable* into pieces based on the output of *pred*. + *pred* should be a function that takes successive pairs of items and + returns ``True`` if the iterable should be split in between them. + + For example, to find runs of increasing numbers, split the iterable when + element ``i`` is larger than element ``i + 1``: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y)) + [[1, 2, 3, 3], [2, 5], [2, 4], [2]] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], + ... lambda x, y: x > y, maxsplit=2)) + [[1, 2, 3, 3], [2, 5], [2, 4, 2]] + + """ + if maxsplit == 0: + yield list(iterable) + return + + it = iter(iterable) + try: + cur_item = next(it) + except StopIteration: + return + + buf = [cur_item] + for next_item in it: + if pred(cur_item, next_item): + yield buf + if maxsplit == 1: + yield [next_item] + list(it) + return + buf = [] + maxsplit -= 1 + + buf.append(next_item) + cur_item = next_item + + yield buf + + +def split_into(iterable, sizes): + """Yield a list of sequential items from *iterable* of length 'n' for each + integer 'n' in *sizes*. + + >>> list(split_into([1,2,3,4,5,6], [1,2,3])) + [[1], [2, 3], [4, 5, 6]] + + If the sum of *sizes* is smaller than the length of *iterable*, then the + remaining items of *iterable* will not be returned. + + >>> list(split_into([1,2,3,4,5,6], [2,3])) + [[1, 2], [3, 4, 5]] + + If the sum of *sizes* is larger than the length of *iterable*, fewer items + will be returned in the iteration that overruns *iterable* and further + lists will be empty: + + >>> list(split_into([1,2,3,4], [1,2,3,4])) + [[1], [2, 3], [4], []] + + When a ``None`` object is encountered in *sizes*, the returned list will + contain items up to the end of *iterable* the same way that itertools.slice + does: + + >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None])) + [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]] + + :func:`split_into` can be useful for grouping a series of items where the + sizes of the groups are not uniform. 
An example would be where in a row + from a table, multiple columns represent elements of the same feature + (e.g. a point represented by x,y,z) but, the format is not the same for + all columns. + """ + # convert the iterable argument into an iterator so its contents can + # be consumed by islice in case it is a generator + it = iter(iterable) + + for size in sizes: + if size is None: + yield list(it) + return + else: + yield list(islice(it, size)) + + +def padded(iterable, fillvalue=None, n=None, next_multiple=False): + """Yield the elements from *iterable*, followed by *fillvalue*, such that + at least *n* items are emitted. + + >>> list(padded([1, 2, 3], '?', 5)) + [1, 2, 3, '?', '?'] + + If *next_multiple* is ``True``, *fillvalue* will be emitted until the + number of items emitted is a multiple of *n*:: + + >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True)) + [1, 2, 3, 4, None, None] + + If *n* is ``None``, *fillvalue* will be emitted indefinitely. + + """ + it = iter(iterable) + if n is None: + yield from chain(it, repeat(fillvalue)) + elif n < 1: + raise ValueError('n must be at least 1') + else: + item_count = 0 + for item in it: + yield item + item_count += 1 + + remaining = (n - item_count) % n if next_multiple else n - item_count + for _ in range(remaining): + yield fillvalue + + +def repeat_last(iterable, default=None): + """After the *iterable* is exhausted, keep yielding its last element. + + >>> list(islice(repeat_last(range(3)), 5)) + [0, 1, 2, 2, 2] + + If the iterable is empty, yield *default* forever:: + + >>> list(islice(repeat_last(range(0), 42), 5)) + [42, 42, 42, 42, 42] + + """ + item = _marker + for item in iterable: + yield item + final = default if item is _marker else item + yield from repeat(final) + + +def distribute(n, iterable): + """Distribute the items from *iterable* among *n* smaller iterables. + + >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 3, 5] + >>> list(group_2) + [2, 4, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 4, 7], [2, 5], [3, 6]] + + If the length of *iterable* is smaller than *n*, then the last returned + iterables will be empty: + + >>> children = distribute(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function uses :func:`itertools.tee` and may require significant + storage. If you need the order items in the smaller iterables to match the + original iterable, see :func:`divide`. + + """ + if n < 1: + raise ValueError('n must be at least 1') + + children = tee(iterable, n) + return [islice(it, index, None, n) for index, it in enumerate(children)] + + +def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None): + """Yield tuples whose elements are offset from *iterable*. + The amount by which the `i`-th item in each tuple is offset is given by + the `i`-th item in *offsets*. + + >>> list(stagger([0, 1, 2, 3])) + [(None, 0, 1), (0, 1, 2), (1, 2, 3)] + >>> list(stagger(range(8), offsets=(0, 2, 4))) + [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)] + + By default, the sequence will end when the final element of a tuple is the + last item in the iterable. 
To continue until the first element of a tuple + is the last item in the iterable, set *longest* to ``True``:: + + >>> list(stagger([0, 1, 2, 3], longest=True)) + [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + children = tee(iterable, len(offsets)) + + return zip_offset( + *children, offsets=offsets, longest=longest, fillvalue=fillvalue + ) + + +class UnequalIterablesError(ValueError): + def __init__(self, details=None): + msg = 'Iterables have different lengths' + if details is not None: + msg += ( + ': index 0 has length {}; index {} has length {}' + ).format(*details) + + super().__init__(msg) + + +def zip_equal(*iterables): + """``zip`` the input *iterables* together, but raise + ``UnequalIterablesError`` if they aren't all the same length. + + >>> it_1 = range(3) + >>> it_2 = iter('abc') + >>> list(zip_equal(it_1, it_2)) + [(0, 'a'), (1, 'b'), (2, 'c')] + + >>> it_1 = range(3) + >>> it_2 = iter('abcd') + >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + more_itertools.more.UnequalIterablesError: Iterables have different + lengths + + """ + # Check whether the iterables are all the same size. + try: + first_size = len(iterables[0]) + for i, it in enumerate(iterables[1:], 1): + size = len(it) + if size != first_size: + break + else: + # If we didn't break out, we can use the built-in zip. + return zip(*iterables) + + # If we did break out, there was a mismatch. + raise UnequalIterablesError(details=(first_size, i, size)) + # If any one of the iterables didn't have a length, start reading + # them until one runs out. + except TypeError: + return _zip_equal_generator(iterables) + + +def _zip_equal_generator(iterables): + for combo in zip_longest(*iterables, fillvalue=_marker): + for val in combo: + if val is _marker: + raise UnequalIterablesError() + yield combo + + +def zip_offset(*iterables, offsets, longest=False, fillvalue=None): + """``zip`` the input *iterables* together, but offset the `i`-th iterable + by the `i`-th item in *offsets*. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1))) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')] + + This can be used as a lightweight alternative to SciPy or pandas to analyze + data sets in which some series have a lead or lag relationship. + + By default, the sequence will end when the shortest iterable is exhausted. + To continue until the longest iterable is exhausted, set *longest* to + ``True``. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True)) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + if len(iterables) != len(offsets): + raise ValueError("Number of iterables and offsets didn't match") + + staggered = [] + for it, n in zip(iterables, offsets): + if n < 0: + staggered.append(chain(repeat(fillvalue, -n), it)) + elif n > 0: + staggered.append(islice(it, n, None)) + else: + staggered.append(it) + + if longest: + return zip_longest(*staggered, fillvalue=fillvalue) + + return zip(*staggered) + + +def sort_together(iterables, key_list=(0,), reverse=False): + """Return the input iterables sorted together, with *key_list* as the + priority for sorting. All iterables are trimmed to the length of the + shortest one. 
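+    For example:
+
+    >>> sort_together([(3, 1, 2), ('a', 'b')])
+    [(1, 3), ('b', 'a')]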
+ + This can be used like the sorting function in a spreadsheet. If each + iterable represents a column of data, the key list determines which + columns are used for sorting. + + By default, all iterables are sorted using the ``0``-th iterable:: + + >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')] + >>> sort_together(iterables) + [(1, 2, 3, 4), ('d', 'c', 'b', 'a')] + + Set a different key list to sort according to another iterable. + Specifying multiple keys dictates how ties are broken:: + + >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')] + >>> sort_together(iterables, key_list=(1, 2)) + [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')] + + Set *reverse* to ``True`` to sort in descending order. + + >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) + [(3, 2, 1), ('a', 'b', 'c')] + + """ + return list( + zip( + *sorted( + zip(*iterables), key=itemgetter(*key_list), reverse=reverse + ) + ) + ) + + +def unzip(iterable): + """The inverse of :func:`zip`, this function disaggregates the elements + of the zipped *iterable*. + + The ``i``-th iterable contains the ``i``-th element from each element + of the zipped iterable. The first element is used to to determine the + length of the remaining elements. + + >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + >>> letters, numbers = unzip(iterable) + >>> list(letters) + ['a', 'b', 'c', 'd'] + >>> list(numbers) + [1, 2, 3, 4] + + This is similar to using ``zip(*iterable)``, but it avoids reading + *iterable* into memory. Note, however, that this function uses + :func:`itertools.tee` and thus may require significant storage. + + """ + head, iterable = spy(iter(iterable)) + if not head: + # empty iterable, e.g. zip([], [], []) + return () + # spy returns a one-length iterable as head + head = head[0] + iterables = tee(iterable, len(head)) + + def itemgetter(i): + def getter(obj): + try: + return obj[i] + except IndexError: + # basically if we have an iterable like + # iter([(1, 2, 3), (4, 5), (6,)]) + # the second unzipped iterable would fail at the third tuple + # since it would try to access tup[1] + # same with the third unzipped iterable and the second tuple + # to support these "improperly zipped" iterables, + # we create a custom itemgetter + # which just stops the unzipped iterables + # at first length mismatch + raise StopIteration + + return getter + + return tuple(map(itemgetter(i), it) for i, it in enumerate(iterables)) + + +def divide(n, iterable): + """Divide the elements from *iterable* into *n* parts, maintaining + order. + + >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 2, 3] + >>> list(group_2) + [4, 5, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 2, 3], [4, 5], [6, 7]] + + If the length of the iterable is smaller than n, then the last returned + iterables will be empty: + + >>> children = divide(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function will exhaust the iterable before returning and may require + significant storage. If order is not important, see :func:`distribute`, + which does not first pull the iterable into memory. 
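+    Iterators work too; their items are pulled into memory up front:
+
+    >>> [list(c) for c in divide(2, iter([1, 2, 3, 4, 5, 6]))]
+    [[1, 2, 3], [4, 5, 6]]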
+ + """ + if n < 1: + raise ValueError('n must be at least 1') + + try: + iterable[:0] + except TypeError: + seq = tuple(iterable) + else: + seq = iterable + + q, r = divmod(len(seq), n) + + ret = [] + stop = 0 + for i in range(1, n + 1): + start = stop + stop += q + 1 if i <= r else q + ret.append(iter(seq[start:stop])) + + return ret + + +def always_iterable(obj, base_type=(str, bytes)): + """If *obj* is iterable, return an iterator over its items:: + + >>> obj = (1, 2, 3) + >>> list(always_iterable(obj)) + [1, 2, 3] + + If *obj* is not iterable, return a one-item iterable containing *obj*:: + + >>> obj = 1 + >>> list(always_iterable(obj)) + [1] + + If *obj* is ``None``, return an empty iterable: + + >>> obj = None + >>> list(always_iterable(None)) + [] + + By default, binary and text strings are not considered iterable:: + + >>> obj = 'foo' + >>> list(always_iterable(obj)) + ['foo'] + + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. + + >>> obj = {'a': 1} + >>> list(always_iterable(obj)) # Iterate over the dict's keys + ['a'] + >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit + [{'a': 1}] + + Set *base_type* to ``None`` to avoid any special handling and treat objects + Python considers iterable as iterable: + + >>> obj = 'foo' + >>> list(always_iterable(obj, base_type=None)) + ['f', 'o', 'o'] + """ + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) + + +def adjacent(predicate, iterable, distance=1): + """Return an iterable over `(bool, item)` tuples where the `item` is + drawn from *iterable* and the `bool` indicates whether + that item satisfies the *predicate* or is adjacent to an item that does. + + For example, to find whether items are adjacent to a ``3``:: + + >>> list(adjacent(lambda x: x == 3, range(6))) + [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)] + + Set *distance* to change what counts as adjacent. For example, to find + whether items are two places away from a ``3``: + + >>> list(adjacent(lambda x: x == 3, range(6), distance=2)) + [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)] + + This is useful for contextualizing the results of a search function. + For example, a code comparison tool might want to identify lines that + have changed, but also surrounding lines to give the viewer of the diff + context. + + The predicate function will only be called once for each item in the + iterable. + + See also :func:`groupby_transform`, which can be used with this function + to group ranges of items with the same `bool` value. + + """ + # Allow distance=0 mainly for testing that it reproduces results with map() + if distance < 0: + raise ValueError('distance must be at least 0') + + i1, i2 = tee(iterable) + padding = [False] * distance + selected = chain(padding, map(predicate, i1), padding) + adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1)) + return zip(adjacent_to_selected, i2) + + +def groupby_transform(iterable, keyfunc=None, valuefunc=None): + """An extension of :func:`itertools.groupby` that transforms the values of + *iterable* after grouping them. + *keyfunc* is a function used to compute a grouping key for each item. + *valuefunc* is a function for transforming the items after grouping. 
+ + >>> iterable = 'AaaABbBCcA' + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: x.lower() + >>> grouper = groupby_transform(iterable, keyfunc, valuefunc) + >>> [(k, ''.join(g)) for k, g in grouper] + [('A', 'aaaa'), ('B', 'bbb'), ('C', 'cc'), ('A', 'a')] + + *keyfunc* and *valuefunc* default to identity functions if they are not + specified. + + :func:`groupby_transform` is useful when grouping elements of an iterable + using a separate iterable as the key. To do this, :func:`zip` the iterables + and pass a *keyfunc* that extracts the first element and a *valuefunc* + that extracts the second element:: + + >>> from operator import itemgetter + >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3] + >>> values = 'abcdefghi' + >>> iterable = zip(keys, values) + >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1)) + >>> [(k, ''.join(g)) for k, g in grouper] + [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')] + + Note that the order of items in the iterable is significant. + Only adjacent items are grouped together, so if you don't want any + duplicate groups, you should sort the iterable by the key function. + + """ + res = groupby(iterable, keyfunc) + return ((k, map(valuefunc, g)) for k, g in res) if valuefunc else res + + +class numeric_range(abc.Sequence, abc.Hashable): + """An extension of the built-in ``range()`` function whose arguments can + be any orderable numeric type. + + With only *stop* specified, *start* defaults to ``0`` and *step* + defaults to ``1``. The output items will match the type of *stop*: + + >>> list(numeric_range(3.5)) + [0.0, 1.0, 2.0, 3.0] + + With only *start* and *stop* specified, *step* defaults to ``1``. The + output items will match the type of *start*: + + >>> from decimal import Decimal + >>> start = Decimal('2.1') + >>> stop = Decimal('5.1') + >>> list(numeric_range(start, stop)) + [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')] + + With *start*, *stop*, and *step* specified the output items will match + the type of ``start + step``: + + >>> from fractions import Fraction + >>> start = Fraction(1, 2) # Start at 1/2 + >>> stop = Fraction(5, 2) # End at 5/2 + >>> step = Fraction(1, 2) # Count by 1/2 + >>> list(numeric_range(start, stop, step)) + [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)] + + If *step* is zero, ``ValueError`` is raised. Negative steps are supported: + + >>> list(numeric_range(3, -1, -1.0)) + [3.0, 2.0, 1.0, 0.0] + + Be aware of the limitations of floating point numbers; the representation + of the yielded numbers may be surprising. 
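+    For example:
+
+    >>> list(numeric_range(0, 1, 0.3))
+    [0.0, 0.3, 0.6, 0.8999999999999999]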
+ + ``datetime.datetime`` objects can be used for *start* and *stop*, if *step* + is a ``datetime.timedelta`` object: + + >>> import datetime + >>> start = datetime.datetime(2019, 1, 1) + >>> stop = datetime.datetime(2019, 1, 3) + >>> step = datetime.timedelta(days=1) + >>> items = iter(numeric_range(start, stop, step)) + >>> next(items) + datetime.datetime(2019, 1, 1, 0, 0) + >>> next(items) + datetime.datetime(2019, 1, 2, 0, 0) + + """ + _EMPTY_HASH = hash(range(0, 0)) + + def __init__(self, *args): + argc = len(args) + if argc == 1: + self._stop, = args + self._start = type(self._stop)(0) + self._step = type(self._stop - self._start)(1) + elif argc == 2: + self._start, self._stop = args + self._step = type(self._stop - self._start)(1) + elif argc == 3: + self._start, self._stop, self._step = args + elif argc == 0: + raise TypeError('numeric_range expected at least ' + '1 argument, got {}'.format(argc)) + else: + raise TypeError('numeric_range expected at most ' + '3 arguments, got {}'.format(argc)) + + self._zero = type(self._step)(0) + if self._step == self._zero: + raise ValueError('numeric_range() arg 3 must not be zero') + self._growing = self._step > self._zero + self._init_len() + + def __bool__(self): + if self._growing: + return self._start < self._stop + else: + return self._start > self._stop + + def __contains__(self, elem): + if self._growing: + if self._start <= elem < self._stop: + return (elem - self._start) % self._step == self._zero + else: + if self._start >= elem > self._stop: + return (self._start - elem) % (-self._step) == self._zero + + return False + + def __eq__(self, other): + if isinstance(other, numeric_range): + empty_self = not bool(self) + empty_other = not bool(other) + if empty_self or empty_other: + return empty_self and empty_other # True if both empty + else: + return (self._start == other._start + and self._step == other._step + and self._get_by_index(-1) == other._get_by_index(-1)) + else: + return False + + def __getitem__(self, key): + if isinstance(key, int): + return self._get_by_index(key) + elif isinstance(key, slice): + step = self._step if key.step is None else key.step * self._step + + if key.start is None or key.start <= -self._len: + start = self._start + elif key.start >= self._len: + start = self._stop + else: # -self._len < key.start < self._len + start = self._get_by_index(key.start) + + if key.stop is None or key.stop >= self._len: + stop = self._stop + elif key.stop <= -self._len: + stop = self._start + else: # -self._len < key.stop < self._len + stop = self._get_by_index(key.stop) + + return numeric_range(start, stop, step) + else: + raise TypeError( + 'numeric range indices must be ' + 'integers or slices, not {}'.format(type(key).__name__)) + + def __hash__(self): + if self: + return hash((self._start, self._get_by_index(-1), self._step)) + else: + return self._EMPTY_HASH + + def __iter__(self): + values = (self._start + (n * self._step) for n in count()) + if self._growing: + return takewhile(partial(gt, self._stop), values) + else: + return takewhile(partial(lt, self._stop), values) + + def __len__(self): + return self._len + + def _init_len(self): + if self._growing: + start = self._start + stop = self._stop + step = self._step + else: + start = self._stop + stop = self._start + step = -self._step + distance = stop - start + if distance <= self._zero: + self._len = 0 + else: # distance > 0 and step > 0: regular euclidean division + q, r = divmod(distance, step) + self._len = int(q) + int(r != self._zero) + + def 
__reduce__(self): + return numeric_range, (self._start, self._stop, self._step) + + def __repr__(self): + if self._step == 1: + return "numeric_range({}, {})".format(repr(self._start), + repr(self._stop)) + else: + return "numeric_range({}, {}, {})".format(repr(self._start), + repr(self._stop), + repr(self._step)) + + def __reversed__(self): + return iter(numeric_range(self._get_by_index(-1), + self._start - self._step, -self._step)) + + def count(self, value): + return int(value in self) + + def index(self, value): + if self._growing: + if self._start <= value < self._stop: + q, r = divmod(value - self._start, self._step) + if r == self._zero: + return int(q) + else: + if self._start >= value > self._stop: + q, r = divmod(self._start - value, -self._step) + if r == self._zero: + return int(q) + + raise ValueError("{} is not in numeric range".format(value)) + + def _get_by_index(self, i): + if i < 0: + i += self._len + if i < 0 or i >= self._len: + raise IndexError("numeric range object index out of range") + return self._start + i * self._step + + +def count_cycle(iterable, n=None): + """Cycle through the items from *iterable* up to *n* times, yielding + the number of completed cycles along with each item. If *n* is omitted the + process repeats indefinitely. + + >>> list(count_cycle('AB', 3)) + [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')] + + """ + iterable = tuple(iterable) + if not iterable: + return iter(()) + counter = count() if n is None else range(n) + return ((i, item) for i in counter for item in iterable) + + +def locate(iterable, pred=bool, window_size=None): + """Yield the index of each item in *iterable* for which *pred* returns + ``True``. + + *pred* defaults to :func:`bool`, which will select truthy items: + + >>> list(locate([0, 1, 1, 0, 1, 0, 0])) + [1, 2, 4] + + Set *pred* to a custom function to, e.g., find the indexes for a particular + item. + + >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b')) + [1, 3] + + If *window_size* is given, then the *pred* function will be called with + that many items. This enables searching for sub-sequences: + + >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + >>> pred = lambda *args: args == (1, 2, 3) + >>> list(locate(iterable, pred=pred, window_size=3)) + [1, 5, 9] + + Use with :func:`seekable` to find indexes and then retrieve the associated + items: + + >>> from itertools import count + >>> from more_itertools import seekable + >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count()) + >>> it = seekable(source) + >>> pred = lambda x: x > 100 + >>> indexes = locate(it, pred=pred) + >>> i = next(indexes) + >>> it.seek(i) + >>> next(it) + 106 + + """ + if window_size is None: + return compress(count(), map(pred, iterable)) + + if window_size < 1: + raise ValueError('window size must be at least 1') + + it = windowed(iterable, window_size, fillvalue=_marker) + return compress(count(), starmap(pred, it)) + + +def lstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the beginning + for which *pred* returns ``True``. + + For example, to remove a set of items from the start of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(lstrip(iterable, pred)) + [1, 2, None, 3, False, None] + + This function is analogous to to :func:`str.lstrip`, and is essentially + an wrapper for :func:`itertools.dropwhile`. 
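+
+    For example (the equivalence follows from the implementation below):
+
+    >>> from itertools import dropwhile
+    >>> list(lstrip('  abc', str.isspace)) == list(dropwhile(str.isspace, '  abc'))
+    True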
+ + """ + return dropwhile(pred, iterable) + + +def rstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the end + for which *pred* returns ``True``. + + For example, to remove a set of items from the end of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(rstrip(iterable, pred)) + [None, False, None, 1, 2, None, 3] + + This function is analogous to :func:`str.rstrip`. + + """ + cache = [] + cache_append = cache.append + cache_clear = cache.clear + for x in iterable: + if pred(x): + cache_append(x) + else: + yield from cache + cache_clear() + yield x + + +def strip(iterable, pred): + """Yield the items from *iterable*, but strip any from the + beginning and end for which *pred* returns ``True``. + + For example, to remove a set of items from both ends of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(strip(iterable, pred)) + [1, 2, None, 3] + + This function is analogous to :func:`str.strip`. + + """ + return rstrip(lstrip(iterable, pred), pred) + + +def islice_extended(iterable, *args): + """An extension of :func:`itertools.islice` that supports negative values + for *stop*, *start*, and *step*. + + >>> iterable = iter('abcdefgh') + >>> list(islice_extended(iterable, -4, -1)) + ['e', 'f', 'g'] + + Slices with negative values require some caching of *iterable*, but this + function takes care to minimize the amount of memory required. + + For example, you can use a negative step with an infinite iterator: + + >>> from itertools import count + >>> list(islice_extended(count(), 110, 99, -2)) + [110, 108, 106, 104, 102, 100] + + """ + s = slice(*args) + start = s.start + stop = s.stop + if s.step == 0: + raise ValueError('step argument must be a non-zero integer or None.') + step = s.step or 1 + + it = iter(iterable) + + if step > 0: + start = 0 if (start is None) else start + + if start < 0: + # Consume all but the last -start items + cache = deque(enumerate(it, 1), maxlen=-start) + len_iter = cache[-1][0] if cache else 0 + + # Adjust start to be positive + i = max(len_iter + start, 0) + + # Adjust stop to be positive + if stop is None: + j = len_iter + elif stop >= 0: + j = min(stop, len_iter) + else: + j = max(len_iter + stop, 0) + + # Slice the cache + n = j - i + if n <= 0: + return + + for index, item in islice(cache, 0, n, step): + yield item + elif (stop is not None) and (stop < 0): + # Advance to the start position + next(islice(it, start, start), None) + + # When stop is negative, we have to carry -stop items while + # iterating + cache = deque(islice(it, -stop), maxlen=-stop) + + for index, item in enumerate(it): + cached_item = cache.popleft() + if index % step == 0: + yield cached_item + cache.append(item) + else: + # When both start and stop are positive we have the normal case + yield from islice(it, start, stop, step) + else: + start = -1 if (start is None) else start + + if (stop is not None) and (stop < 0): + # Consume all but the last items + n = -stop - 1 + cache = deque(enumerate(it, 1), maxlen=n) + len_iter = cache[-1][0] if cache else 0 + + # If start and stop are both negative they are comparable and + # we can just slice. Otherwise we can adjust start to be negative + # and then slice. 
+ if start < 0: + i, j = start, stop + else: + i, j = min(start - len_iter, -1), None + + for index, item in list(cache)[i:j:step]: + yield item + else: + # Advance to the stop position + if stop is not None: + m = stop + 1 + next(islice(it, m, m), None) + + # stop is positive, so if start is negative they are not comparable + # and we need the rest of the items. + if start < 0: + i = start + n = None + # stop is None and start is positive, so we just need items up to + # the start index. + elif stop is None: + i = None + n = start + 1 + # Both stop and start are positive, so they are comparable. + else: + i = None + n = start - stop + if n <= 0: + return + + cache = list(islice(it, n)) + + yield from cache[i::step] + + +def always_reversible(iterable): + """An extension of :func:`reversed` that supports all iterables, not + just those which implement the ``Reversible`` or ``Sequence`` protocols. + + >>> print(*always_reversible(x for x in range(3))) + 2 1 0 + + If the iterable is already reversible, this function returns the + result of :func:`reversed()`. If the iterable is not reversible, + this function will cache the remaining items in the iterable and + yield them in reverse order, which may require significant storage. + """ + try: + return reversed(iterable) + except TypeError: + return reversed(list(iterable)) + + +def consecutive_groups(iterable, ordering=lambda x: x): + """Yield groups of consecutive items using :func:`itertools.groupby`. + The *ordering* function determines whether two items are adjacent by + returning their position. + + By default, the ordering function is the identity function. This is + suitable for finding runs of numbers: + + >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40] + >>> for group in consecutive_groups(iterable): + ... print(list(group)) + [1] + [10, 11, 12] + [20] + [30, 31, 32, 33] + [40] + + For finding runs of adjacent letters, try using the :meth:`index` method + of a string of letters: + + >>> from string import ascii_lowercase + >>> iterable = 'abcdfgilmnop' + >>> ordering = ascii_lowercase.index + >>> for group in consecutive_groups(iterable, ordering): + ... print(list(group)) + ['a', 'b', 'c', 'd'] + ['f', 'g'] + ['i'] + ['l', 'm', 'n', 'o', 'p'] + + Each group of consecutive items is an iterator that shares it source with + *iterable*. When an an output group is advanced, the previous group is + no longer available unless its elements are copied (e.g., into a ``list``). + + >>> iterable = [1, 2, 11, 12, 21, 22] + >>> saved_groups = [] + >>> for group in consecutive_groups(iterable): + ... saved_groups.append(list(group)) # Copy group elements + >>> saved_groups + [[1, 2], [11, 12], [21, 22]] + + """ + for k, g in groupby( + enumerate(iterable), key=lambda x: x[0] - ordering(x[1]) + ): + yield map(itemgetter(1), g) + + +def difference(iterable, func=sub, *, initial=None): + """By default, compute the first difference of *iterable* using + :func:`operator.sub`. + + >>> iterable = [0, 1, 3, 6, 10] + >>> list(difference(iterable)) + [0, 1, 2, 3, 4] + + This is the opposite of :func:`itertools.accumulate`'s default behavior: + + >>> from itertools import accumulate + >>> iterable = [0, 1, 2, 3, 4] + >>> list(accumulate(iterable)) + [0, 1, 3, 6, 10] + >>> list(difference(accumulate(iterable))) + [0, 1, 2, 3, 4] + + By default *func* is :func:`operator.sub`, but other functions can be + specified. They will be applied as follows:: + + A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ... 
+
+    For example, to do progressive division:
+
+    >>> iterable = [1, 2, 6, 24, 120]  # Factorial sequence
+    >>> func = lambda x, y: x // y
+    >>> list(difference(iterable, func))
+    [1, 2, 3, 4, 5]
+
+    Since Python 3.8, :func:`itertools.accumulate` can be supplied with an
+    *initial* keyword argument. If :func:`difference` is called with *initial*
+    set to something other than ``None``, it will skip the first element when
+    computing successive differences.
+
+    >>> iterable = [100, 101, 103, 106]  # accumulate([1, 2, 3], initial=100)
+    >>> list(difference(iterable, initial=100))
+    [1, 2, 3]
+
+    """
+    a, b = tee(iterable)
+    try:
+        first = [next(b)]
+    except StopIteration:
+        return iter([])
+
+    if initial is not None:
+        first = []
+
+    return chain(first, starmap(func, zip(b, a)))
+
+
+class SequenceView(Sequence):
+    """Return a read-only view of the sequence object *target*.
+
+    :class:`SequenceView` objects are analogous to Python's built-in
+    "dictionary view" types. They provide a dynamic view of a sequence's items,
+    meaning that when the sequence updates, so does the view.
+
+    >>> seq = ['0', '1', '2']
+    >>> view = SequenceView(seq)
+    >>> view
+    SequenceView(['0', '1', '2'])
+    >>> seq.append('3')
+    >>> view
+    SequenceView(['0', '1', '2', '3'])
+
+    Sequence views support indexing, slicing, and length queries. They act
+    like the underlying sequence, except they don't allow assignment:
+
+    >>> view[1]
+    '1'
+    >>> view[1:-1]
+    ['1', '2']
+    >>> len(view)
+    4
+
+    Sequence views are useful as an alternative to copying, as they don't
+    require (much) extra storage.
+
+    """
+
+    def __init__(self, target):
+        if not isinstance(target, Sequence):
+            raise TypeError
+        self._target = target
+
+    def __getitem__(self, index):
+        return self._target[index]
+
+    def __len__(self):
+        return len(self._target)
+
+    def __repr__(self):
+        return '{}({})'.format(self.__class__.__name__, repr(self._target))
+
+
+class seekable:
+    """Wrap an iterator to allow for seeking backward and forward. This
+    progressively caches the items in the source iterable so they can be
+    re-visited.
+
+    Call :meth:`seek` with an index to seek to that position in the source
+    iterable.
+
+    To "reset" an iterator, seek to ``0``:
+
+    >>> from itertools import count
+    >>> it = seekable((str(n) for n in count()))
+    >>> next(it), next(it), next(it)
+    ('0', '1', '2')
+    >>> it.seek(0)
+    >>> next(it), next(it), next(it)
+    ('0', '1', '2')
+    >>> next(it)
+    '3'
+
+    You can also seek forward:
+
+    >>> it = seekable((str(n) for n in range(20)))
+    >>> it.seek(10)
+    >>> next(it)
+    '10'
+    >>> it.seek(20)  # Seeking past the end of the source isn't a problem
+    >>> list(it)
+    []
+    >>> it.seek(0)  # Resetting works even after hitting the end
+    >>> next(it), next(it), next(it)
+    ('0', '1', '2')
+
+    You may view the contents of the cache with the :meth:`elements` method.
+    That returns a :class:`SequenceView`, a view that updates automatically:
+
+    >>> it = seekable((str(n) for n in range(10)))
+    >>> next(it), next(it), next(it)
+    ('0', '1', '2')
+    >>> elements = it.elements()
+    >>> elements
+    SequenceView(['0', '1', '2'])
+    >>> next(it)
+    '3'
+    >>> elements
+    SequenceView(['0', '1', '2', '3'])
+
+    By default, the cache grows as the source iterable progresses, so beware of
+    wrapping very large or infinite iterables. Supply *maxlen* to limit the
+    size of the cache (this of course limits how far back you can seek).
+ + >>> from itertools import count + >>> it = seekable((str(n) for n in count()), maxlen=2) + >>> next(it), next(it), next(it), next(it) + ('0', '1', '2', '3') + >>> list(it.elements()) + ['2', '3'] + >>> it.seek(0) + >>> next(it), next(it), next(it), next(it) + ('2', '3', '4', '5') + >>> next(it) + '6' + + """ + + def __init__(self, iterable, maxlen=None): + self._source = iter(iterable) + if maxlen is None: + self._cache = [] + else: + self._cache = deque([], maxlen) + self._index = None + + def __iter__(self): + return self + + def __next__(self): + if self._index is not None: + try: + item = self._cache[self._index] + except IndexError: + self._index = None + else: + self._index += 1 + return item + + item = next(self._source) + self._cache.append(item) + return item + + def elements(self): + return SequenceView(self._cache) + + def seek(self, index): + self._index = index + remainder = index - len(self._cache) + if remainder > 0: + consume(self, remainder) + + +class run_length: + """ + :func:`run_length.encode` compresses an iterable with run-length encoding. + It yields groups of repeated items with the count of how many times they + were repeated: + + >>> uncompressed = 'abbcccdddd' + >>> list(run_length.encode(uncompressed)) + [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + + :func:`run_length.decode` decompresses an iterable that was previously + compressed with run-length encoding. It yields the items of the + decompressed iterable: + + >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + >>> list(run_length.decode(compressed)) + ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd'] + + """ + + @staticmethod + def encode(iterable): + return ((k, ilen(g)) for k, g in groupby(iterable)) + + @staticmethod + def decode(iterable): + return chain.from_iterable(repeat(k, n) for k, n in iterable) + + +def exactly_n(iterable, n, predicate=bool): + """Return ``True`` if exactly ``n`` items in the iterable are ``True`` + according to the *predicate* function. + + >>> exactly_n([True, True, False], 2) + True + >>> exactly_n([True, True, False], 1) + False + >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3) + True + + The iterable will be advanced until ``n + 1`` truthy items are encountered, + so avoid calling it on infinite iterables. + + """ + return len(take(n + 1, filter(predicate, iterable))) == n + + +def circular_shifts(iterable): + """Return a list of circular shifts of *iterable*. + + >>> circular_shifts(range(4)) + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + """ + lst = list(iterable) + return take(len(lst), windowed(cycle(lst), len(lst))) + + +def make_decorator(wrapping_func, result_index=0): + """Return a decorator version of *wrapping_func*, which is a function that + modifies an iterable. *result_index* is the position in that function's + signature where the iterable goes. + + This lets you use itertools on the "production end," i.e. at function + definition. This can augment what the function returns without changing the + function's code. + + For example, to produce a decorator version of :func:`chunked`: + + >>> from more_itertools import chunked + >>> chunker = make_decorator(chunked, result_index=0) + >>> @chunker(3) + ... def iter_range(n): + ... return iter(range(n)) + ... + >>> list(iter_range(9)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + + To only allow truthy items to be returned: + + >>> truth_serum = make_decorator(filter, result_index=1) + >>> @truth_serum(bool) + ... def boolean_test(): + ... return [0, 1, '', ' ', False, True] + ... 
+ >>> list(boolean_test()) + [1, ' ', True] + + The :func:`peekable` and :func:`seekable` wrappers make for practical + decorators: + + >>> from more_itertools import peekable + >>> peekable_function = make_decorator(peekable) + >>> @peekable_function() + ... def str_range(*args): + ... return (str(x) for x in range(*args)) + ... + >>> it = str_range(1, 20, 2) + >>> next(it), next(it), next(it) + ('1', '3', '5') + >>> it.peek() + '7' + >>> next(it) + '7' + + """ + # See https://sites.google.com/site/bbayles/index/decorator_factory for + # notes on how this works. + def decorator(*wrapping_args, **wrapping_kwargs): + def outer_wrapper(f): + def inner_wrapper(*args, **kwargs): + result = f(*args, **kwargs) + wrapping_args_ = list(wrapping_args) + wrapping_args_.insert(result_index, result) + return wrapping_func(*wrapping_args_, **wrapping_kwargs) + + return inner_wrapper + + return outer_wrapper + + return decorator + + +def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None): + """Return a dictionary that maps the items in *iterable* to categories + defined by *keyfunc*, transforms them with *valuefunc*, and + then summarizes them by category with *reducefunc*. + + *valuefunc* defaults to the identity function if it is unspecified. + If *reducefunc* is unspecified, no summarization takes place: + + >>> keyfunc = lambda x: x.upper() + >>> result = map_reduce('abbccc', keyfunc) + >>> sorted(result.items()) + [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])] + + Specifying *valuefunc* transforms the categorized items: + + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: 1 + >>> result = map_reduce('abbccc', keyfunc, valuefunc) + >>> sorted(result.items()) + [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])] + + Specifying *reducefunc* summarizes the categorized items: + + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: 1 + >>> reducefunc = sum + >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc) + >>> sorted(result.items()) + [('A', 1), ('B', 2), ('C', 3)] + + You may want to filter the input iterable before applying the map/reduce + procedure: + + >>> all_items = range(30) + >>> items = [x for x in all_items if 10 <= x <= 20] # Filter + >>> keyfunc = lambda x: x % 2 # Evens map to 0; odds to 1 + >>> categories = map_reduce(items, keyfunc=keyfunc) + >>> sorted(categories.items()) + [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])] + >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum) + >>> sorted(summaries.items()) + [(0, 90), (1, 75)] + + Note that all items in the iterable are gathered into a list before the + summarization step, which may require significant storage. + + The returned object is a :obj:`collections.defaultdict` with the + ``default_factory`` set to ``None``, such that it behaves like a normal + dictionary. + + """ + valuefunc = (lambda x: x) if (valuefunc is None) else valuefunc + + ret = defaultdict(list) + for item in iterable: + key = keyfunc(item) + value = valuefunc(item) + ret[key].append(value) + + if reducefunc is not None: + for key, value_list in ret.items(): + ret[key] = reducefunc(value_list) + + ret.default_factory = None + return ret + + +def rlocate(iterable, pred=bool, window_size=None): + """Yield the index of each item in *iterable* for which *pred* returns + ``True``, starting from the right and moving left. 
+
+    *pred* defaults to :func:`bool`, which will select truthy items:
+
+    >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))  # Truthy at 1, 2, and 4
+    [4, 2, 1]
+
+    Set *pred* to a custom function to, e.g., find the indexes for a particular
+    item:
+
+    >>> iterable = iter('abcb')
+    >>> pred = lambda x: x == 'b'
+    >>> list(rlocate(iterable, pred))
+    [3, 1]
+
+    If *window_size* is given, then the *pred* function will be called with
+    that many items. This enables searching for sub-sequences:
+
+    >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
+    >>> pred = lambda *args: args == (1, 2, 3)
+    >>> list(rlocate(iterable, pred=pred, window_size=3))
+    [9, 5, 1]
+
+    Beware, this function won't return anything for infinite iterables.
+    If *iterable* is reversible, ``rlocate`` will reverse it and search from
+    the right. Otherwise, it will search from the left and return the results
+    in reverse order.
+
+    See :func:`locate` for other example applications.
+
+    """
+    if window_size is None:
+        try:
+            len_iter = len(iterable)
+            return (len_iter - i - 1 for i in locate(reversed(iterable), pred))
+        except TypeError:
+            pass
+
+    return reversed(list(locate(iterable, pred, window_size)))
+
+
+def replace(iterable, pred, substitutes, count=None, window_size=1):
+    """Yield the items from *iterable*, replacing the items for which *pred*
+    returns ``True`` with the items from the iterable *substitutes*.
+
+    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
+    >>> pred = lambda x: x == 0
+    >>> substitutes = (2, 3)
+    >>> list(replace(iterable, pred, substitutes))
+    [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]
+
+    If *count* is given, the number of replacements will be limited:
+
+    >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
+    >>> pred = lambda x: x == 0
+    >>> substitutes = [None]
+    >>> list(replace(iterable, pred, substitutes, count=2))
+    [1, 1, None, 1, 1, None, 1, 1, 0]
+
+    Use *window_size* to control the number of items passed as arguments to
+    *pred*. This allows for locating and replacing subsequences.
+
+    >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
+    >>> window_size = 3
+    >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
+    >>> substitutes = [3, 4]  # Splice in these items
+    >>> list(replace(iterable, pred, substitutes, window_size=window_size))
+    [3, 4, 5, 3, 4, 5]
+
+    """
+    if window_size < 1:
+        raise ValueError('window_size must be at least 1')
+
+    # Save the substitutes iterable, since it's used more than once
+    substitutes = tuple(substitutes)
+
+    # Add padding such that the number of windows matches the length of the
+    # iterable
+    it = chain(iterable, [_marker] * (window_size - 1))
+    windows = windowed(it, window_size)
+
+    n = 0
+    for w in windows:
+        # If the current window matches our predicate (and we haven't hit
+        # our maximum number of replacements), splice in the substitutes
+        # and then consume the following windows that overlap with this one.
+        # For example, if the iterable is (0, 1, 2, 3, 4...)
+        # and the window size is 2, we have (0, 1), (1, 2), (2, 3)...
+        # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2)
+        if pred(*w):
+            if (count is None) or (n < count):
+                n += 1
+                yield from substitutes
+                consume(windows, window_size - 1)
+                continue
+
+        # If there was no match (or we've reached the replacement limit),
+        # yield the first item from the window.
+        if w and (w[0] is not _marker):
+            yield w[0]
+
+
+def partitions(iterable):
+    """Yield all possible order-preserving partitions of *iterable*.
+
+    >>> iterable = 'abc'
+    >>> for part in partitions(iterable):
+    ...
print([''.join(p) for p in part]) + ['abc'] + ['a', 'bc'] + ['ab', 'c'] + ['a', 'b', 'c'] + + This is unrelated to :func:`partition`. + + """ + sequence = list(iterable) + n = len(sequence) + for i in powerset(range(1, n)): + yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))] + + +def set_partitions(iterable, k=None): + """ + Yield the set partitions of *iterable* into *k* parts. Set partitions are + not order-preserving. + + >>> iterable = 'abc' + >>> for part in set_partitions(iterable, 2): + ... print([''.join(p) for p in part]) + ['a', 'bc'] + ['ab', 'c'] + ['b', 'ac'] + + + If *k* is not given, every set partition is generated. + + >>> iterable = 'abc' + >>> for part in set_partitions(iterable): + ... print([''.join(p) for p in part]) + ['abc'] + ['a', 'bc'] + ['ab', 'c'] + ['b', 'ac'] + ['a', 'b', 'c'] + + """ + L = list(iterable) + n = len(L) + if k is not None: + if k < 1: + raise ValueError( + "Can't partition in a negative or zero number of groups" + ) + elif k > n: + return + + def set_partitions_helper(L, k): + n = len(L) + if k == 1: + yield [L] + elif n == k: + yield [[s] for s in L] + else: + e, *M = L + for p in set_partitions_helper(M, k - 1): + yield [[e], *p] + for p in set_partitions_helper(M, k): + for i in range(len(p)): + yield p[:i] + [[e] + p[i]] + p[i + 1 :] + + if k is None: + for k in range(1, n + 1): + yield from set_partitions_helper(L, k) + else: + yield from set_partitions_helper(L, k) + + +def time_limited(limit_seconds, iterable): + """ + Yield items from *iterable* until *limit_seconds* have passed. + + >>> from time import sleep + >>> def generator(): + ... yield 1 + ... yield 2 + ... sleep(0.2) + ... yield 3 + >>> iterable = generator() + >>> list(time_limited(0.1, iterable)) + [1, 2] + + Note that the time is checked before each item is yielded, and iteration + stops if the time elapsed is greater than *limit_seconds*. If your time + limit is 1 second, but it takes 2 seconds to generate the first item from + the iterable, the function will run for 2 seconds and not yield anything. + + """ + if limit_seconds < 0: + raise ValueError('limit_seconds must be positive') + + start_time = monotonic() + for item in iterable: + if monotonic() - start_time > limit_seconds: + break + yield item + + +def only(iterable, default=None, too_long=None): + """If *iterable* has only one item, return it. + If it has zero items, return *default*. + If it has more than one item, raise the exception given by *too_long*, + which is ``ValueError`` by default. + + >>> only([], default='missing') + 'missing' + >>> only([1]) + 1 + >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 1, 2, + and perhaps more.' + >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError + + Note that :func:`only` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check + iterable contents less destructively. + """ + it = iter(iterable) + first_value = next(it, default) + + try: + second_value = next(it) + except StopIteration: + pass + else: + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value + + +def ichunked(iterable, n): + """Break *iterable* into sub-iterables with *n* elements each. 
+    :func:`ichunked` is like :func:`chunked`, but it yields iterables
+    instead of lists.
+
+    If the sub-iterables are read in order, the elements of *iterable*
+    won't be stored in memory.
+    If they are read out of order, :func:`itertools.tee` is used to cache
+    elements as necessary.
+
+    >>> from itertools import count
+    >>> all_chunks = ichunked(count(), 4)
+    >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
+    >>> list(c_2)  # c_1's elements have been cached; c_3's haven't been
+    [4, 5, 6, 7]
+    >>> list(c_1)
+    [0, 1, 2, 3]
+    >>> list(c_3)
+    [8, 9, 10, 11]
+
+    """
+    source = iter(iterable)
+
+    while True:
+        # Check to see whether we're at the end of the source iterable
+        item = next(source, _marker)
+        if item is _marker:
+            return
+
+        # Clone the source and yield an n-length slice
+        source, it = tee(chain([item], source))
+        yield islice(it, n)
+
+        # Advance the source iterable
+        consume(source, n)
+
+
+def distinct_combinations(iterable, r):
+    """Yield the distinct combinations of *r* items taken from *iterable*.
+
+    >>> list(distinct_combinations([0, 0, 1], 2))
+    [(0, 0), (0, 1)]
+
+    Equivalent to ``set(combinations(iterable))``, except duplicates are not
+    generated and thrown away. For larger input sequences this is much more
+    efficient.
+
+    """
+    if r < 0:
+        raise ValueError('r must be non-negative')
+    elif r == 0:
+        yield ()
+    else:
+        pool = tuple(iterable)
+        for i, prefix in unique_everseen(enumerate(pool), key=itemgetter(1)):
+            for suffix in distinct_combinations(pool[i + 1 :], r - 1):
+                yield (prefix,) + suffix
+
+
+def filter_except(validator, iterable, *exceptions):
+    """Yield the items from *iterable* for which the *validator* function does
+    not raise one of the specified *exceptions*.
+
+    *validator* is called for each item in *iterable*.
+    It should be a function that accepts one argument and raises an exception
+    if that item is not valid.
+
+    >>> iterable = ['1', '2', 'three', '4', None]
+    >>> list(filter_except(int, iterable, ValueError, TypeError))
+    ['1', '2', '4']
+
+    If an exception other than one given by *exceptions* is raised by
+    *validator*, it is raised like normal.
+    """
+    exceptions = tuple(exceptions)
+    for item in iterable:
+        try:
+            validator(item)
+        except exceptions:
+            pass
+        else:
+            yield item
+
+
+def map_except(function, iterable, *exceptions):
+    """Transform each item from *iterable* with *function* and yield the
+    result, unless *function* raises one of the specified *exceptions*.
+
+    *function* is called to transform each item in *iterable*.
+    It should accept one argument.
+
+    >>> iterable = ['1', '2', 'three', '4', None]
+    >>> list(map_except(int, iterable, ValueError, TypeError))
+    [1, 2, 4]
+
+    If an exception other than one given by *exceptions* is raised by
+    *function*, it is raised like normal.
+    """
+    exceptions = tuple(exceptions)
+    for item in iterable:
+        try:
+            yield function(item)
+        except exceptions:
+            pass
+
+
+def _sample_unweighted(iterable, k):
+    # Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li:
+    # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
+
+    # Fill up the reservoir (collection of samples) with the first `k` samples
+    reservoir = take(k, iterable)
+
+    # Generate random number that's the largest in a sample of k U(0,1) numbers
+    # Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic
+    W = exp(log(random()) / k)
+
+    # The number of elements to skip before changing the reservoir is a random
+    # number with a geometric distribution.
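+    # (Each subsequent element independently replaces a reservoir slot with
+    # probability W, so the skip length follows a Geometric(W) distribution.)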
Sample it using random() and logs. + next_index = k + floor(log(random()) / log(1 - W)) + + for index, element in enumerate(iterable, k): + + if index == next_index: + reservoir[randrange(k)] = element + # The new W is the largest in a sample of k U(0, `old_W`) numbers + W *= exp(log(random()) / k) + next_index += floor(log(random()) / log(1 - W)) + 1 + + return reservoir + + +def _sample_weighted(iterable, k, weights): + # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. : + # "Weighted random sampling with a reservoir". + + # Log-transform for numerical stability for weights that are small/large + weight_keys = (log(random()) / weight for weight in weights) + + # Fill up the reservoir (collection of samples) with the first `k` + # weight-keys and elements, then heapify the list. + reservoir = take(k, zip(weight_keys, iterable)) + heapify(reservoir) + + # The number of jumps before changing the reservoir is a random variable + # with an exponential distribution. Sample it using random() and logs. + smallest_weight_key, _ = reservoir[0] + weights_to_skip = log(random()) / smallest_weight_key + + for weight, element in zip(weights, iterable): + if weight >= weights_to_skip: + # The notation here is consistent with the paper, but we store + # the weight-keys in log-space for better numerical stability. + smallest_weight_key, _ = reservoir[0] + t_w = exp(weight * smallest_weight_key) + r_2 = uniform(t_w, 1) # generate U(t_w, 1) + weight_key = log(r_2) / weight + heapreplace(reservoir, (weight_key, element)) + smallest_weight_key, _ = reservoir[0] + weights_to_skip = log(random()) / smallest_weight_key + else: + weights_to_skip -= weight + + # Equivalent to [element for weight_key, element in sorted(reservoir)] + return [heappop(reservoir)[1] for _ in range(k)] + + +def sample(iterable, k, weights=None): + """Return a *k*-length list of elements chosen (without replacement) + from the *iterable*. Like :func:`random.sample`, but works on iterables + of unknown length. + + >>> iterable = range(100) + >>> sample(iterable, 5) # doctest: +SKIP + [81, 60, 96, 16, 4] + + An iterable with *weights* may also be given: + + >>> iterable = range(100) + >>> weights = (i * i + 1 for i in range(100)) + >>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP + [79, 67, 74, 66, 78] + + The algorithm can also be used to generate weighted random permutations. + The relative weight of each item determines the probability that it + appears late in the permutation. + + >>> data = "abcdefgh" + >>> weights = range(1, len(data) + 1) + >>> sample(data, k=len(data), weights=weights) # doctest: +SKIP + ['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f'] + """ + if k == 0: + return [] + + iterable = iter(iterable) + if weights is None: + return _sample_unweighted(iterable, k) + else: + weights = iter(weights) + return _sample_weighted(iterable, k, weights) diff --git a/resources/lib/more_itertools/recipes.py b/resources/lib/more_itertools/recipes.py new file mode 100644 index 0000000..77cc2db --- /dev/null +++ b/resources/lib/more_itertools/recipes.py @@ -0,0 +1,572 @@ +"""Imported from the recipes section of the itertools documentation. + +All functions taken from the recipes section of the itertools library docs +[1]_. +Some backward-compatible usability improvements have been made. + +.. 
[1] http://docs.python.org/library/itertools.html#recipes
+
+"""
+import warnings
+from collections import deque
+from itertools import (
+    chain,
+    combinations,
+    count,
+    cycle,
+    groupby,
+    islice,
+    repeat,
+    starmap,
+    tee,
+    zip_longest,
+)
+import operator
+from random import randrange, sample, choice
+
+__all__ = [
+    'all_equal',
+    'consume',
+    'dotproduct',
+    'first_true',
+    'flatten',
+    'grouper',
+    'iter_except',
+    'ncycles',
+    'nth',
+    'nth_combination',
+    'padnone',
+    'pairwise',
+    'partition',
+    'powerset',
+    'prepend',
+    'quantify',
+    'random_combination_with_replacement',
+    'random_combination',
+    'random_permutation',
+    'random_product',
+    'repeatfunc',
+    'roundrobin',
+    'tabulate',
+    'tail',
+    'take',
+    'unique_everseen',
+    'unique_justseen',
+]
+
+
+def take(n, iterable):
+    """Return first *n* items of the iterable as a list.
+
+    >>> take(3, range(10))
+    [0, 1, 2]
+
+    If there are fewer than *n* items in the iterable, all of them are
+    returned.
+
+    >>> take(10, range(3))
+    [0, 1, 2]
+
+    """
+    return list(islice(iterable, n))
+
+
+def tabulate(function, start=0):
+    """Return an iterator over the results of ``function(start)``,
+    ``function(start + 1)``, ``function(start + 2)``...
+
+    *function* should be a function that accepts one integer argument.
+
+    If *start* is not specified it defaults to 0. It will be incremented each
+    time the iterator is advanced.
+
+    >>> square = lambda x: x ** 2
+    >>> iterator = tabulate(square, -3)
+    >>> take(4, iterator)
+    [9, 4, 1, 0]
+
+    """
+    return map(function, count(start))
+
+
+def tail(n, iterable):
+    """Return an iterator over the last *n* items of *iterable*.
+
+    >>> t = tail(3, 'ABCDEFG')
+    >>> list(t)
+    ['E', 'F', 'G']
+
+    """
+    return iter(deque(iterable, maxlen=n))
+
+
+def consume(iterator, n=None):
+    """Advance *iterator* by *n* steps. If *n* is ``None``, consume it
+    entirely.
+
+    Efficiently exhausts an iterator without returning values. Defaults to
+    consuming the whole iterator, but an optional second argument may be
+    provided to limit consumption.
+
+    >>> i = (x for x in range(10))
+    >>> next(i)
+    0
+    >>> consume(i, 3)
+    >>> next(i)
+    4
+    >>> consume(i)
+    >>> next(i)
+    Traceback (most recent call last):
+      File "<stdin>", line 1, in <module>
+    StopIteration
+
+    If the iterator has fewer items remaining than the provided limit, the
+    whole iterator will be consumed.
+
+    >>> i = (x for x in range(3))
+    >>> consume(i, 5)
+    >>> next(i)
+    Traceback (most recent call last):
+      File "<stdin>", line 1, in <module>
+    StopIteration
+
+    """
+    # Use functions that consume iterators at C speed.
+    if n is None:
+        # feed the entire iterator into a zero-length deque
+        deque(iterator, maxlen=0)
+    else:
+        # advance to the empty slice starting at position n
+        next(islice(iterator, n, n), None)
+
+
+def nth(iterable, n, default=None):
+    """Returns the nth item or a default value.
+
+    >>> l = range(10)
+    >>> nth(l, 3)
+    3
+    >>> nth(l, 20, "zebra")
+    'zebra'
+
+    """
+    return next(islice(iterable, n, None), default)
+
+
+def all_equal(iterable):
+    """
+    Returns ``True`` if all the elements are equal to each other.
+
+    >>> all_equal('aaaa')
+    True
+    >>> all_equal('aaab')
+    False
+
+    """
+    g = groupby(iterable)
+    return next(g, True) and not next(g, False)
+
+
+def quantify(iterable, pred=bool):
+    """Return how many times the predicate is true.
+
+    >>> quantify([True, False, True])
+    2
+
+    """
+    return sum(map(pred, iterable))
+
+
+def padnone(iterable):
+    """Returns the sequence of elements and then returns ``None`` indefinitely.
+ + >>> take(5, padnone(range(3))) + [0, 1, 2, None, None] + + Useful for emulating the behavior of the built-in :func:`map` function. + + See also :func:`padded`. + + """ + return chain(iterable, repeat(None)) + + +def ncycles(iterable, n): + """Returns the sequence elements *n* times + + >>> list(ncycles(["a", "b"], 3)) + ['a', 'b', 'a', 'b', 'a', 'b'] + + """ + return chain.from_iterable(repeat(tuple(iterable), n)) + + +def dotproduct(vec1, vec2): + """Returns the dot product of the two iterables. + + >>> dotproduct([10, 10], [20, 20]) + 400 + + """ + return sum(map(operator.mul, vec1, vec2)) + + +def flatten(listOfLists): + """Return an iterator flattening one level of nesting in a list of lists. + + >>> list(flatten([[0, 1], [2, 3]])) + [0, 1, 2, 3] + + See also :func:`collapse`, which can flatten multiple levels of nesting. + + """ + return chain.from_iterable(listOfLists) + + +def repeatfunc(func, times=None, *args): + """Call *func* with *args* repeatedly, returning an iterable over the + results. + + If *times* is specified, the iterable will terminate after that many + repetitions: + + >>> from operator import add + >>> times = 4 + >>> args = 3, 5 + >>> list(repeatfunc(add, times, *args)) + [8, 8, 8, 8] + + If *times* is ``None`` the iterable will not terminate: + + >>> from random import randrange + >>> times = None + >>> args = 1, 11 + >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP + [2, 4, 8, 1, 8, 4] + + """ + if times is None: + return starmap(func, repeat(args)) + return starmap(func, repeat(args, times)) + + +def pairwise(iterable): + """Returns an iterator of paired items, overlapping, from the original + + >>> take(4, pairwise(count())) + [(0, 1), (1, 2), (2, 3), (3, 4)] + + """ + a, b = tee(iterable) + next(b, None) + return zip(a, b) + + +def grouper(iterable, n, fillvalue=None): + """Collect data into fixed-length chunks or blocks. + + >>> list(grouper('ABCDEFG', 3, 'x')) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] + + """ + if isinstance(iterable, int): + warnings.warn( + "grouper expects iterable as first parameter", DeprecationWarning + ) + n, iterable = iterable, n + args = [iter(iterable)] * n + return zip_longest(fillvalue=fillvalue, *args) + + +def roundrobin(*iterables): + """Yields an item from each iterable, alternating between them. + + >>> list(roundrobin('ABC', 'D', 'EF')) + ['A', 'D', 'E', 'B', 'F', 'C'] + + This function produces the same output as :func:`interleave_longest`, but + may perform better for some inputs (in particular when the number of + iterables is small). + + """ + # Recipe credited to George Sakkis + pending = len(iterables) + nexts = cycle(iter(it).__next__ for it in iterables) + while pending: + try: + for next in nexts: + yield next() + except StopIteration: + pending -= 1 + nexts = cycle(islice(nexts, pending)) + + +def partition(pred, iterable): + """ + Returns a 2-tuple of iterables derived from the input iterable. + The first yields the items that have ``pred(item) == False``. + The second yields the items that have ``pred(item) == True``. + + >>> is_odd = lambda x: x % 2 != 0 + >>> iterable = range(10) + >>> even_items, odd_items = partition(is_odd, iterable) + >>> list(even_items), list(odd_items) + ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]) + + If *pred* is None, :func:`bool` is used. 
+
+    >>> iterable = [0, 1, False, True, '', ' ']
+    >>> false_items, true_items = partition(None, iterable)
+    >>> list(false_items), list(true_items)
+    ([0, False, ''], [1, True, ' '])
+
+    """
+    if pred is None:
+        pred = bool
+
+    evaluations = ((pred(x), x) for x in iterable)
+    t1, t2 = tee(evaluations)
+    return (
+        (x for (cond, x) in t1 if not cond),
+        (x for (cond, x) in t2 if cond),
+    )
+
+
+def powerset(iterable):
+    """Yields all possible subsets of the iterable.
+
+    >>> list(powerset([1, 2, 3]))
+    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
+
+    :func:`powerset` will operate on iterables that aren't :class:`set`
+    instances, so repeated elements in the input will produce repeated elements
+    in the output. Use :func:`unique_everseen` on the input to avoid generating
+    duplicates:
+
+    >>> seq = [1, 1, 0]
+    >>> list(powerset(seq))
+    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
+    >>> from more_itertools import unique_everseen
+    >>> list(powerset(unique_everseen(seq)))
+    [(), (1,), (0,), (1, 0)]
+
+    """
+    s = list(iterable)
+    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
+
+
+def unique_everseen(iterable, key=None):
+    """
+    Yield unique elements, preserving order.
+
+    >>> list(unique_everseen('AAAABBBCCDAABBB'))
+    ['A', 'B', 'C', 'D']
+    >>> list(unique_everseen('ABBCcAD', str.lower))
+    ['A', 'B', 'C', 'D']
+
+    Sequences with a mix of hashable and unhashable items can be used.
+    The function will be slower (i.e., `O(n^2)`) for unhashable items.
+
+    Remember that ``list`` objects are unhashable - you can use the *key*
+    parameter to transform the list to a tuple (which is hashable) to
+    avoid a slowdown.
+
+    >>> iterable = ([1, 2], [2, 3], [1, 2])
+    >>> list(unique_everseen(iterable))  # Slow
+    [[1, 2], [2, 3]]
+    >>> list(unique_everseen(iterable, key=tuple))  # Faster
+    [[1, 2], [2, 3]]
+
+    Similarly, you may want to convert unhashable ``set`` objects with
+    ``key=frozenset``. For ``dict`` objects,
+    ``key=lambda x: frozenset(x.items())`` can be used.
+
+    """
+    seenset = set()
+    seenset_add = seenset.add
+    seenlist = []
+    seenlist_add = seenlist.append
+    iterable, keys = tee(iterable)
+    for element, k in zip(iterable, map(key, keys) if key else keys):
+        try:
+            if k not in seenset:
+                seenset_add(k)
+                yield element
+        except TypeError:
+            if k not in seenlist:
+                seenlist_add(k)
+                yield element
+
+
+def unique_justseen(iterable, key=None):
+    """Yields elements in order, ignoring serial duplicates
+
+    >>> list(unique_justseen('AAAABBBCCDAABBB'))
+    ['A', 'B', 'C', 'D', 'A', 'B']
+    >>> list(unique_justseen('ABBCcAD', str.lower))
+    ['A', 'B', 'C', 'A', 'D']
+
+    """
+    return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
+
+
+def iter_except(func, exception, first=None):
+    """Yields results from a function repeatedly until an exception is raised.
+
+    Converts a call-until-exception interface to an iterator interface.
+    Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
+    to end the loop.
+
+    >>> l = [0, 1, 2]
+    >>> list(iter_except(l.pop, IndexError))
+    [2, 1, 0]
+
+    """
+    try:
+        if first is not None:
+            yield first()
+        while 1:
+            yield func()
+    except exception:
+        pass
+
+
+def first_true(iterable, default=None, pred=None):
+    """
+    Returns the first true value in the iterable.
+
+    If no true value is found, returns *default*.
+
+    If *pred* is not None, returns the first item for which
+    ``pred(item) == True``.
+
+    >>> first_true(range(10))
+    1
+    >>> first_true(range(10), pred=lambda x: x > 5)
+    6
+    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
+    'missing'
+
+    """
+    return next(filter(pred, iterable), default)
+
+
+def random_product(*args, repeat=1):
+    """Draw an item at random from each of the input iterables.
+
+    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
+    ('c', 3, 'Z')
+
+    If *repeat* is provided as a keyword argument, that many items will be
+    drawn from each iterable.
+
+    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
+    ('a', 2, 'd', 3)
+
+    This is equivalent to taking a random selection from
+    ``itertools.product(*args, **kwargs)``.
+
+    """
+    pools = [tuple(pool) for pool in args] * repeat
+    return tuple(choice(pool) for pool in pools)
+
+
+def random_permutation(iterable, r=None):
+    """Return a random *r* length permutation of the elements in *iterable*.
+
+    If *r* is not specified or is ``None``, then *r* defaults to the length of
+    *iterable*.
+
+    >>> random_permutation(range(5))  # doctest:+SKIP
+    (3, 4, 0, 1, 2)
+
+    This is equivalent to taking a random selection from
+    ``itertools.permutations(iterable, r)``.
+
+    """
+    pool = tuple(iterable)
+    r = len(pool) if r is None else r
+    return tuple(sample(pool, r))
+
+
+def random_combination(iterable, r):
+    """Return a random *r* length subsequence of the elements in *iterable*.
+
+    >>> random_combination(range(5), 3)  # doctest:+SKIP
+    (2, 3, 4)
+
+    This is equivalent to taking a random selection from
+    ``itertools.combinations(iterable, r)``.
+
+    """
+    pool = tuple(iterable)
+    n = len(pool)
+    indices = sorted(sample(range(n), r))
+    return tuple(pool[i] for i in indices)
+
+
+def random_combination_with_replacement(iterable, r):
+    """Return a random *r* length subsequence of elements in *iterable*,
+    allowing individual elements to be repeated.
+
+    >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
+    (0, 0, 1, 2, 2)
+
+    This is equivalent to taking a random selection from
+    ``itertools.combinations_with_replacement(iterable, r)``.
+
+    """
+    pool = tuple(iterable)
+    n = len(pool)
+    indices = sorted(randrange(n) for i in range(r))
+    return tuple(pool[i] for i in indices)
+
+
+def nth_combination(iterable, r, index):
+    """Equivalent to ``list(combinations(iterable, r))[index]``.
+
+    The subsequences of *iterable* that are of length *r* can be ordered
+    lexicographically. :func:`nth_combination` computes the subsequence at
+    sort position *index* directly, without computing the previous
+    subsequences.
+
+    """
+    pool = tuple(iterable)
+    n = len(pool)
+    if (r < 0) or (r > n):
+        raise ValueError
+
+    c = 1
+    k = min(r, n - r)
+    for i in range(1, k + 1):
+        c = c * (n - k + i) // i
+
+    if index < 0:
+        index += c
+
+    if (index < 0) or (index >= c):
+        raise IndexError
+
+    result = []
+    while r:
+        c, n, r = c * r // n, n - 1, r - 1
+        while index >= c:
+            index -= c
+            c, n = c * (n - r) // n, n - 1
+        result.append(pool[-1 - n])
+
+    return tuple(result)
+
+
+def prepend(value, iterator):
+    """Yield *value*, followed by the elements in *iterator*.
+
+    >>> value = '0'
+    >>> iterator = ['1', '2', '3']
+    >>> list(prepend(value, iterator))
+    ['0', '1', '2', '3']
+
+    To prepend multiple values, see :func:`itertools.chain`.
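+
+    As an illustrative sketch (reusing ``iterator`` from the example above):
+
+    >>> from itertools import chain
+    >>> list(chain(['a', 'b'], iterator))
+    ['a', 'b', '1', '2', '3']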
+ + """ + return chain([value], iterator) diff --git a/resources/lib/more_itertools/tests/__init__.py b/resources/lib/more_itertools/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/resources/lib/more_itertools/tests/test_more.py b/resources/lib/more_itertools/tests/test_more.py new file mode 100644 index 0000000..a1b1e43 --- /dev/null +++ b/resources/lib/more_itertools/tests/test_more.py @@ -0,0 +1,2074 @@ +from __future__ import division, print_function, unicode_literals + +from collections import OrderedDict +from decimal import Decimal +from doctest import DocTestSuite +from fractions import Fraction +from functools import partial, reduce +from heapq import merge +from io import StringIO +from itertools import ( + chain, + count, + groupby, + islice, + permutations, + product, + repeat, +) +from operator import add, mul, itemgetter +from unittest import TestCase + +from six.moves import filter, map, range, zip + +import more_itertools as mi + + +def load_tests(loader, tests, ignore): + # Add the doctests + tests.addTests(DocTestSuite('more_itertools.more')) + return tests + + +class CollateTests(TestCase): + """Unit tests for ``collate()``""" + # Also accidentally tests peekable, though that could use its own tests + + def test_default(self): + """Test with the default `key` function.""" + iterables = [range(4), range(7), range(3, 6)] + self.assertEqual( + sorted(reduce(list.__add__, [list(it) for it in iterables])), + list(mi.collate(*iterables)) + ) + + def test_key(self): + """Test using a custom `key` function.""" + iterables = [range(5, 0, -1), range(4, 0, -1)] + actual = sorted( + reduce(list.__add__, [list(it) for it in iterables]), reverse=True + ) + expected = list(mi.collate(*iterables, key=lambda x: -x)) + self.assertEqual(actual, expected) + + def test_empty(self): + """Be nice if passed an empty list of iterables.""" + self.assertEqual([], list(mi.collate())) + + def test_one(self): + """Work when only 1 iterable is passed.""" + self.assertEqual([0, 1], list(mi.collate(range(2)))) + + def test_reverse(self): + """Test the `reverse` kwarg.""" + iterables = [range(4, 0, -1), range(7, 0, -1), range(3, 6, -1)] + + actual = sorted( + reduce(list.__add__, [list(it) for it in iterables]), reverse=True + ) + expected = list(mi.collate(*iterables, reverse=True)) + self.assertEqual(actual, expected) + + def test_alias(self): + self.assertNotEqual(merge.__doc__, mi.collate.__doc__) + self.assertNotEqual(partial.__doc__, mi.collate.__doc__) + + +class ChunkedTests(TestCase): + """Tests for ``chunked()``""" + + def test_even(self): + """Test when ``n`` divides evenly into the length of the iterable.""" + self.assertEqual( + list(mi.chunked('ABCDEF', 3)), [['A', 'B', 'C'], ['D', 'E', 'F']] + ) + + def test_odd(self): + """Test when ``n`` does not divide evenly into the length of the + iterable. + + """ + self.assertEqual( + list(mi.chunked('ABCDE', 3)), [['A', 'B', 'C'], ['D', 'E']] + ) + + +class FirstTests(TestCase): + """Tests for ``first()``""" + + def test_many(self): + """Test that it works on many-item iterables.""" + # Also try it on a generator expression to make sure it works on + # whatever those return, across Python versions. 
self.assertEqual(mi.first(x for x in range(4)), 0)
+
+    def test_one(self):
+        """Test that it doesn't raise StopIteration prematurely."""
+        self.assertEqual(mi.first([3]), 3)
+
+    def test_empty_stop_iteration(self):
+        """It should raise ValueError for empty iterables."""
+        self.assertRaises(ValueError, lambda: mi.first([]))
+
+    def test_default(self):
+        """It should return the provided default arg for empty iterables."""
+        self.assertEqual(mi.first([], 'boo'), 'boo')
+
+
+class IterOnlyRange:
+    """User-defined iterable class which only supports __iter__.
+
+    It is not specified to inherit ``object``, so indexing on an instance will
+    raise an ``AttributeError`` rather than ``TypeError`` in Python 2.
+
+    >>> r = IterOnlyRange(5)
+    >>> r[0]
+    AttributeError: IterOnlyRange instance has no attribute '__getitem__'
+
+    Note: In Python 3, ``TypeError`` will be raised because ``object`` is
+    inherited implicitly by default.
+
+    >>> r[0]
+    TypeError: 'IterOnlyRange' object does not support indexing
+    """
+    def __init__(self, n):
+        """Set the length of the range."""
+        self.n = n
+
+    def __iter__(self):
+        """Works the same as range()."""
+        return iter(range(self.n))
+
+
+class LastTests(TestCase):
+    """Tests for ``last()``"""
+
+    def test_many_nonsliceable(self):
+        """Test that it works on many-item non-slice-able iterables."""
+        # Also try it on a generator expression to make sure it works on
+        # whatever those return, across Python versions.
+        self.assertEqual(mi.last(x for x in range(4)), 3)
+
+    def test_one_nonsliceable(self):
+        """Test that it doesn't raise StopIteration prematurely."""
+        self.assertEqual(mi.last(x for x in range(1)), 0)
+
+    def test_empty_stop_iteration_nonsliceable(self):
+        """It should raise ValueError for empty non-slice-able iterables."""
+        self.assertRaises(ValueError, lambda: mi.last(x for x in range(0)))
+
+    def test_default_nonsliceable(self):
+        """It should return the provided default arg for empty non-slice-able
+        iterables.
+        """
+        self.assertEqual(mi.last((x for x in range(0)), 'boo'), 'boo')
+
+    def test_many_sliceable(self):
+        """Test that it works on many-item slice-able iterables."""
+        self.assertEqual(mi.last([0, 1, 2, 3]), 3)
+
+    def test_one_sliceable(self):
+        """Test that it doesn't raise StopIteration prematurely."""
+        self.assertEqual(mi.last([3]), 3)
+
+    def test_empty_stop_iteration_sliceable(self):
+        """It should raise ValueError for empty slice-able iterables."""
+        self.assertRaises(ValueError, lambda: mi.last([]))
+
+    def test_default_sliceable(self):
+        """It should return the provided default arg for empty slice-able
+        iterables.
+ """ + self.assertEqual(mi.last([], 'boo'), 'boo') + + def test_dict(self): + """last(dic) and last(dic.keys()) should return same result.""" + dic = {'a': 1, 'b': 2, 'c': 3} + self.assertEqual(mi.last(dic), mi.last(dic.keys())) + + def test_ordereddict(self): + """last(dic) should return the last key.""" + od = OrderedDict() + od['a'] = 1 + od['b'] = 2 + od['c'] = 3 + self.assertEqual(mi.last(od), 'c') + + def test_customrange(self): + """It should work on custom class where [] raises AttributeError.""" + self.assertEqual(mi.last(IterOnlyRange(5)), 4) + + +class PeekableTests(TestCase): + """Tests for ``peekable()`` behavor not incidentally covered by testing + ``collate()`` + + """ + def test_peek_default(self): + """Make sure passing a default into ``peek()`` works.""" + p = mi.peekable([]) + self.assertEqual(p.peek(7), 7) + + def test_truthiness(self): + """Make sure a ``peekable`` tests true iff there are items remaining in + the iterable. + + """ + p = mi.peekable([]) + self.assertFalse(p) + + p = mi.peekable(range(3)) + self.assertTrue(p) + + def test_simple_peeking(self): + """Make sure ``next`` and ``peek`` advance and don't advance the + iterator, respectively. + + """ + p = mi.peekable(range(10)) + self.assertEqual(next(p), 0) + self.assertEqual(p.peek(), 1) + self.assertEqual(next(p), 1) + + def test_indexing(self): + """ + Indexing into the peekable shouldn't advance the iterator. + """ + p = mi.peekable('abcdefghijkl') + + # The 0th index is what ``next()`` will return + self.assertEqual(p[0], 'a') + self.assertEqual(next(p), 'a') + + # Indexing further into the peekable shouldn't advance the itertor + self.assertEqual(p[2], 'd') + self.assertEqual(next(p), 'b') + + # The 0th index moves up with the iterator; the last index follows + self.assertEqual(p[0], 'c') + self.assertEqual(p[9], 'l') + + self.assertEqual(next(p), 'c') + self.assertEqual(p[8], 'l') + + # Negative indexing should work too + self.assertEqual(p[-2], 'k') + self.assertEqual(p[-9], 'd') + self.assertRaises(IndexError, lambda: p[-10]) + + def test_slicing(self): + """Slicing the peekable shouldn't advance the iterator.""" + seq = list('abcdefghijkl') + p = mi.peekable(seq) + + # Slicing the peekable should just be like slicing a re-iterable + self.assertEqual(p[1:4], seq[1:4]) + + # Advancing the iterator moves the slices up also + self.assertEqual(next(p), 'a') + self.assertEqual(p[1:4], seq[1:][1:4]) + + # Implicit starts and stop should work + self.assertEqual(p[:5], seq[1:][:5]) + self.assertEqual(p[:], seq[1:][:]) + + # Indexing past the end should work + self.assertEqual(p[:100], seq[1:][:100]) + + # Steps should work, including negative + self.assertEqual(p[::2], seq[1:][::2]) + self.assertEqual(p[::-1], seq[1:][::-1]) + + def test_slicing_reset(self): + """Test slicing on a fresh iterable each time""" + iterable = ['0', '1', '2', '3', '4', '5'] + indexes = list(range(-4, len(iterable) + 4)) + [None] + steps = [1, 2, 3, 4, -1, -2, -3, 4] + for slice_args in product(indexes, indexes, steps): + it = iter(iterable) + p = mi.peekable(it) + next(p) + index = slice(*slice_args) + actual = p[index] + expected = iterable[1:][index] + self.assertEqual(actual, expected, slice_args) + + def test_slicing_error(self): + iterable = '01234567' + p = mi.peekable(iter(iterable)) + + # Prime the cache + p.peek() + old_cache = list(p._cache) + + # Illegal slice + with self.assertRaises(ValueError): + p[1:-1:0] + + # Neither the cache nor the iteration should be affected + self.assertEqual(old_cache, list(p._cache)) + 
+        self.assertEqual(list(p), list(iterable))
+
+    def test_passthrough(self):
+        """Iterating a peekable without using ``peek()`` or ``prepend()``
+        should just give the underlying iterable's elements (a trivial test but
+        useful to set a baseline in case something goes wrong)"""
+        expected = [1, 2, 3, 4, 5]
+        actual = list(mi.peekable(expected))
+        self.assertEqual(actual, expected)
+
+    # prepend() behavior tests
+
+    def test_prepend(self):
+        """Tests interspersed ``prepend()`` and ``next()`` calls"""
+        it = mi.peekable(range(2))
+        actual = []
+
+        # Test prepend() before next()
+        it.prepend(10)
+        actual += [next(it), next(it)]
+
+        # Test prepend() between next()s
+        it.prepend(11)
+        actual += [next(it), next(it)]
+
+        # Test prepend() after source iterable is consumed
+        it.prepend(12)
+        actual += [next(it)]
+
+        expected = [10, 0, 11, 1, 12]
+        self.assertEqual(actual, expected)
+
+    def test_multi_prepend(self):
+        """Tests prepending multiple items and getting them in proper order"""
+        it = mi.peekable(range(5))
+        actual = [next(it), next(it)]
+        it.prepend(10, 11, 12)
+        it.prepend(20, 21)
+        actual += list(it)
+        expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4]
+        self.assertEqual(actual, expected)
+
+    def test_empty(self):
+        """Tests prepending in front of an empty iterable"""
+        it = mi.peekable([])
+        it.prepend(10)
+        actual = list(it)
+        expected = [10]
+        self.assertEqual(actual, expected)
+
+    def test_prepend_truthiness(self):
+        """Tests that ``__bool__()`` or ``__nonzero__()`` works properly
+        with ``prepend()``"""
+        it = mi.peekable(range(5))
+        self.assertTrue(it)
+        actual = list(it)
+        self.assertFalse(it)
+        it.prepend(10)
+        self.assertTrue(it)
+        actual += [next(it)]
+        self.assertFalse(it)
+        expected = [0, 1, 2, 3, 4, 10]
+        self.assertEqual(actual, expected)
+
+    def test_multi_prepend_peek(self):
+        """Tests prepending multiple elements and getting them in reverse order
+        while peeking"""
+        it = mi.peekable(range(5))
+        actual = [next(it), next(it)]
+        self.assertEqual(it.peek(), 2)
+        it.prepend(10, 11, 12)
+        self.assertEqual(it.peek(), 10)
+        it.prepend(20, 21)
+        self.assertEqual(it.peek(), 20)
+        actual += list(it)
+        self.assertFalse(it)
+        expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4]
+        self.assertEqual(actual, expected)
+
+    def test_prepend_after_stop(self):
+        """Test resuming iteration after a previous exhaustion"""
+        it = mi.peekable(range(3))
+        self.assertEqual(list(it), [0, 1, 2])
+        self.assertRaises(StopIteration, lambda: next(it))
+        it.prepend(10)
+        self.assertEqual(next(it), 10)
+        self.assertRaises(StopIteration, lambda: next(it))
+
+    def test_prepend_slicing(self):
+        """Tests interaction between prepending and slicing"""
+        seq = list(range(20))
+        p = mi.peekable(seq)
+
+        p.prepend(30, 40, 50)
+        pseq = [30, 40, 50] + seq  # pseq for prepended_seq
+
+        # adapt the specific tests from test_slicing
+        self.assertEqual(p[0], 30)
+        self.assertEqual(p[1:8], pseq[1:8])
+        self.assertEqual(p[1:], pseq[1:])
+        self.assertEqual(p[:5], pseq[:5])
+        self.assertEqual(p[:], pseq[:])
+        self.assertEqual(p[:100], pseq[:100])
+        self.assertEqual(p[::2], pseq[::2])
+        self.assertEqual(p[::-1], pseq[::-1])
+
+    def test_prepend_indexing(self):
+        """Tests interaction between prepending and indexing"""
+        seq = list(range(20))
+        p = mi.peekable(seq)
+
+        p.prepend(30, 40, 50)
+
+        self.assertEqual(p[0], 30)
+        self.assertEqual(next(p), 30)
+        self.assertEqual(p[2], 0)
+        self.assertEqual(next(p), 40)
+        self.assertEqual(p[0], 50)
+        self.assertEqual(p[9], 8)
+        self.assertEqual(next(p), 50)
+        self.assertEqual(p[8], 8)
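+
+        # Negative indices count back from the end of the remaining items
+        # (the three prepended values have already been consumed).
+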
self.assertEqual(p[-2], 18) + self.assertEqual(p[-9], 11) + self.assertRaises(IndexError, lambda: p[-21]) + + def test_prepend_iterable(self): + """Tests prepending from an iterable""" + it = mi.peekable(range(5)) + # Don't directly use the range() object to avoid any range-specific + # optimizations + it.prepend(*(x for x in range(5))) + actual = list(it) + expected = list(chain(range(5), range(5))) + self.assertEqual(actual, expected) + + def test_prepend_many(self): + """Tests that prepending a huge number of elements works""" + it = mi.peekable(range(5)) + # Don't directly use the range() object to avoid any range-specific + # optimizations + it.prepend(*(x for x in range(20000))) + actual = list(it) + expected = list(chain(range(20000), range(5))) + self.assertEqual(actual, expected) + + def test_prepend_reversed(self): + """Tests prepending from a reversed iterable""" + it = mi.peekable(range(3)) + it.prepend(*reversed((10, 11, 12))) + actual = list(it) + expected = [12, 11, 10, 0, 1, 2] + self.assertEqual(actual, expected) + + +class ConsumerTests(TestCase): + """Tests for ``consumer()``""" + + def test_consumer(self): + @mi.consumer + def eater(): + while True: + x = yield # noqa + + e = eater() + e.send('hi') # without @consumer, would raise TypeError + + +class DistinctPermutationsTests(TestCase): + def test_distinct_permutations(self): + """Make sure the output for ``distinct_permutations()`` is the same as + set(permutations(it)). + + """ + iterable = ['z', 'a', 'a', 'q', 'q', 'q', 'y'] + test_output = sorted(mi.distinct_permutations(iterable)) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + def test_other_iterables(self): + """Make sure ``distinct_permutations()`` accepts a different type of + iterables. 
+ + """ + # a generator + iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) + test_output = sorted(mi.distinct_permutations(iterable)) + # "reload" it + iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + # an iterator + iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) + test_output = sorted(mi.distinct_permutations(iterable)) + # "reload" it + iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + +class IlenTests(TestCase): + def test_ilen(self): + """Sanity-checks for ``ilen()``.""" + # Non-empty + self.assertEqual( + mi.ilen(filter(lambda x: x % 10 == 0, range(101))), 11 + ) + + # Empty + self.assertEqual(mi.ilen((x for x in range(0))), 0) + + # Iterable with __len__ + self.assertEqual(mi.ilen(list(range(6))), 6) + + +class WithIterTests(TestCase): + def test_with_iter(self): + s = StringIO('One fish\nTwo fish') + initial_words = [line.split()[0] for line in mi.with_iter(s)] + + # Iterable's items should be faithfully represented + self.assertEqual(initial_words, ['One', 'Two']) + # The file object should be closed + self.assertEqual(s.closed, True) + + +class OneTests(TestCase): + def test_basic(self): + it = iter(['item']) + self.assertEqual(mi.one(it), 'item') + + def test_too_short(self): + it = iter([]) + self.assertRaises(ValueError, lambda: mi.one(it)) + self.assertRaises(IndexError, lambda: mi.one(it, too_short=IndexError)) + + def test_too_long(self): + it = count() + self.assertRaises(ValueError, lambda: mi.one(it)) # burn 0 and 1 + self.assertEqual(next(it), 2) + self.assertRaises( + OverflowError, lambda: mi.one(it, too_long=OverflowError) + ) + + +class IntersperseTest(TestCase): + """ Tests for intersperse() """ + + def test_even(self): + iterable = (x for x in '01') + self.assertEqual( + list(mi.intersperse(None, iterable)), ['0', None, '1'] + ) + + def test_odd(self): + iterable = (x for x in '012') + self.assertEqual( + list(mi.intersperse(None, iterable)), ['0', None, '1', None, '2'] + ) + + def test_nested(self): + element = ('a', 'b') + iterable = (x for x in '012') + actual = list(mi.intersperse(element, iterable)) + expected = ['0', ('a', 'b'), '1', ('a', 'b'), '2'] + self.assertEqual(actual, expected) + + def test_not_iterable(self): + self.assertRaises(TypeError, lambda: mi.intersperse('x', 1)) + + def test_n(self): + for n, element, expected in [ + (1, '_', ['0', '_', '1', '_', '2', '_', '3', '_', '4', '_', '5']), + (2, '_', ['0', '1', '_', '2', '3', '_', '4', '5']), + (3, '_', ['0', '1', '2', '_', '3', '4', '5']), + (4, '_', ['0', '1', '2', '3', '_', '4', '5']), + (5, '_', ['0', '1', '2', '3', '4', '_', '5']), + (6, '_', ['0', '1', '2', '3', '4', '5']), + (7, '_', ['0', '1', '2', '3', '4', '5']), + (3, ['a', 'b'], ['0', '1', '2', ['a', 'b'], '3', '4', '5']), + ]: + iterable = (x for x in '012345') + actual = list(mi.intersperse(element, iterable, n=n)) + self.assertEqual(actual, expected) + + def test_n_zero(self): + self.assertRaises( + ValueError, lambda: list(mi.intersperse('x', '012', n=0)) + ) + + +class UniqueToEachTests(TestCase): + """Tests for ``unique_to_each()``""" + + def test_all_unique(self): + """When all the input iterables are unique the output should match + the input.""" + iterables = [[1, 2], [3, 4, 5], [6, 7, 8]] + self.assertEqual(mi.unique_to_each(*iterables), iterables) + + def test_duplicates(self): + """When there are 
duplicates in any of the input iterables that aren't + in the rest, those duplicates should be emitted.""" + iterables = ["mississippi", "missouri"] + self.assertEqual( + mi.unique_to_each(*iterables), [['p', 'p'], ['o', 'u', 'r']] + ) + + def test_mixed(self): + """When the input iterables contain different types the function should + still behave properly""" + iterables = ['x', (i for i in range(3)), [1, 2, 3], tuple()] + self.assertEqual(mi.unique_to_each(*iterables), [['x'], [0], [3], []]) + + +class WindowedTests(TestCase): + """Tests for ``windowed()``""" + + def test_basic(self): + actual = list(mi.windowed([1, 2, 3, 4, 5], 3)) + expected = [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + self.assertEqual(actual, expected) + + def test_large_size(self): + """ + When the window size is larger than the iterable, and no fill value is + given,``None`` should be filled in. + """ + actual = list(mi.windowed([1, 2, 3, 4, 5], 6)) + expected = [(1, 2, 3, 4, 5, None)] + self.assertEqual(actual, expected) + + def test_fillvalue(self): + """ + When sizes don't match evenly, the given fill value should be used. + """ + iterable = [1, 2, 3, 4, 5] + + for n, kwargs, expected in [ + (6, {}, [(1, 2, 3, 4, 5, '!')]), # n > len(iterable) + (3, {'step': 3}, [(1, 2, 3), (4, 5, '!')]), # using ``step`` + ]: + actual = list(mi.windowed(iterable, n, fillvalue='!', **kwargs)) + self.assertEqual(actual, expected) + + def test_zero(self): + """When the window size is zero, an empty tuple should be emitted.""" + actual = list(mi.windowed([1, 2, 3, 4, 5], 0)) + expected = [tuple()] + self.assertEqual(actual, expected) + + def test_negative(self): + """When the window size is negative, ValueError should be raised.""" + with self.assertRaises(ValueError): + list(mi.windowed([1, 2, 3, 4, 5], -1)) + + def test_step(self): + """The window should advance by the number of steps provided""" + iterable = [1, 2, 3, 4, 5, 6, 7] + for n, step, expected in [ + (3, 2, [(1, 2, 3), (3, 4, 5), (5, 6, 7)]), # n > step + (3, 3, [(1, 2, 3), (4, 5, 6), (7, None, None)]), # n == step + (3, 4, [(1, 2, 3), (5, 6, 7)]), # line up nicely + (3, 5, [(1, 2, 3), (6, 7, None)]), # off by one + (3, 6, [(1, 2, 3), (7, None, None)]), # off by two + (3, 7, [(1, 2, 3)]), # step past the end + (7, 8, [(1, 2, 3, 4, 5, 6, 7)]), # step > len(iterable) + ]: + actual = list(mi.windowed(iterable, n, step=step)) + self.assertEqual(actual, expected) + + # Step must be greater than or equal to 1 + with self.assertRaises(ValueError): + list(mi.windowed(iterable, 3, step=0)) + + +class BucketTests(TestCase): + """Tests for ``bucket()``""" + + def test_basic(self): + iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] + D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) + + # In-order access + self.assertEqual(list(D[10]), [10, 11, 12]) + + # Out of order access + self.assertEqual(list(D[30]), [30, 31, 33]) + self.assertEqual(list(D[20]), [20, 21, 22, 23]) + + self.assertEqual(list(D[40]), []) # Nothing in here! 
+ + def test_in(self): + iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] + D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) + + self.assertTrue(10 in D) + self.assertFalse(40 in D) + self.assertTrue(20 in D) + self.assertFalse(21 in D) + + # Checking in-ness shouldn't advance the iterator + self.assertEqual(next(D[10]), 10) + + def test_validator(self): + iterable = count(0) + key = lambda x: int(str(x)[0]) # First digit of each number + validator = lambda x: 0 < x < 10 # No leading zeros + D = mi.bucket(iterable, key, validator=validator) + self.assertEqual(mi.take(3, D[1]), [1, 10, 11]) + self.assertNotIn(0, D) # Non-valid entries don't return True + self.assertNotIn(0, D._cache) # Don't store non-valid entries + self.assertEqual(list(D[0]), []) + + +class SpyTests(TestCase): + """Tests for ``spy()``""" + + def test_basic(self): + original_iterable = iter('abcdefg') + head, new_iterable = mi.spy(original_iterable) + self.assertEqual(head, ['a']) + self.assertEqual( + list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + ) + + def test_unpacking(self): + original_iterable = iter('abcdefg') + (first, second, third), new_iterable = mi.spy(original_iterable, 3) + self.assertEqual(first, 'a') + self.assertEqual(second, 'b') + self.assertEqual(third, 'c') + self.assertEqual( + list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + ) + + def test_too_many(self): + original_iterable = iter('abc') + head, new_iterable = mi.spy(original_iterable, 4) + self.assertEqual(head, ['a', 'b', 'c']) + self.assertEqual(list(new_iterable), ['a', 'b', 'c']) + + def test_zero(self): + original_iterable = iter('abc') + head, new_iterable = mi.spy(original_iterable, 0) + self.assertEqual(head, []) + self.assertEqual(list(new_iterable), ['a', 'b', 'c']) + + +class InterleaveTests(TestCase): + def test_even(self): + actual = list(mi.interleave([1, 4, 7], [2, 5, 8], [3, 6, 9])) + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_short(self): + actual = list(mi.interleave([1, 4], [2, 5, 7], [3, 6, 8])) + expected = [1, 2, 3, 4, 5, 6] + self.assertEqual(actual, expected) + + def test_mixed_types(self): + it_list = ['a', 'b', 'c', 'd'] + it_str = '12345' + it_inf = count() + actual = list(mi.interleave(it_list, it_str, it_inf)) + expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', 3] + self.assertEqual(actual, expected) + + +class InterleaveLongestTests(TestCase): + def test_even(self): + actual = list(mi.interleave_longest([1, 4, 7], [2, 5, 8], [3, 6, 9])) + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_short(self): + actual = list(mi.interleave_longest([1, 4], [2, 5, 7], [3, 6, 8])) + expected = [1, 2, 3, 4, 5, 6, 7, 8] + self.assertEqual(actual, expected) + + def test_mixed_types(self): + it_list = ['a', 'b', 'c', 'd'] + it_str = '12345' + it_gen = (x for x in range(3)) + actual = list(mi.interleave_longest(it_list, it_str, it_gen)) + expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', '5'] + self.assertEqual(actual, expected) + + +class TestCollapse(TestCase): + """Tests for ``collapse()``""" + + def test_collapse(self): + l = [[1], 2, [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l)), [1, 2, 3, 4, 5]) + + def test_collapse_to_string(self): + l = [["s1"], "s2", [["s3"], "s4"], [[["s5"]]]] + self.assertEqual(list(mi.collapse(l)), ["s1", "s2", "s3", "s4", "s5"]) + + def test_collapse_flatten(self): + l = [[1], [2], [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l, levels=1)), 
list(mi.flatten(l))) + + def test_collapse_to_level(self): + l = [[1], 2, [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l, levels=2)), [1, 2, 3, 4, [5]]) + self.assertEqual( + list(mi.collapse(mi.collapse(l, levels=1), levels=1)), + list(mi.collapse(l, levels=2)) + ) + + def test_collapse_to_list(self): + l = (1, [2], (3, [4, (5,)], 'ab')) + actual = list(mi.collapse(l, base_type=list)) + expected = [1, [2], 3, [4, (5,)], 'ab'] + self.assertEqual(actual, expected) + + +class SideEffectTests(TestCase): + """Tests for ``side_effect()``""" + + def test_individual(self): + # The function increments the counter for each call + counter = [0] + + def func(arg): + counter[0] += 1 + + result = list(mi.side_effect(func, range(10))) + self.assertEqual(result, list(range(10))) + self.assertEqual(counter[0], 10) + + def test_chunked(self): + # The function increments the counter for each call + counter = [0] + + def func(arg): + counter[0] += 1 + + result = list(mi.side_effect(func, range(10), 2)) + self.assertEqual(result, list(range(10))) + self.assertEqual(counter[0], 5) + + def test_before_after(self): + f = StringIO() + collector = [] + + def func(item): + print(item, file=f) + collector.append(f.getvalue()) + + def it(): + yield u'a' + yield u'b' + raise RuntimeError('kaboom') + + before = lambda: print('HEADER', file=f) + after = f.close + + try: + mi.consume(mi.side_effect(func, it(), before=before, after=after)) + except RuntimeError: + pass + + # The iterable should have been written to the file + self.assertEqual(collector, [u'HEADER\na\n', u'HEADER\na\nb\n']) + + # The file should be closed even though something bad happened + self.assertTrue(f.closed) + + def test_before_fails(self): + f = StringIO() + func = lambda x: print(x, file=f) + + def before(): + raise RuntimeError('ouch') + + try: + mi.consume( + mi.side_effect(func, u'abc', before=before, after=f.close) + ) + except RuntimeError: + pass + + # The file should be closed even though something bad happened in the + # before function + self.assertTrue(f.closed) + + +class SlicedTests(TestCase): + """Tests for ``sliced()``""" + + def test_even(self): + """Test when the length of the sequence is divisible by *n*""" + seq = 'ABCDEFGHI' + self.assertEqual(list(mi.sliced(seq, 3)), ['ABC', 'DEF', 'GHI']) + + def test_odd(self): + """Test when the length of the sequence is not divisible by *n*""" + seq = 'ABCDEFGHI' + self.assertEqual(list(mi.sliced(seq, 4)), ['ABCD', 'EFGH', 'I']) + + def test_not_sliceable(self): + seq = (x for x in 'ABCDEFGHI') + + with self.assertRaises(TypeError): + list(mi.sliced(seq, 3)) + + +class SplitAtTests(TestCase): + """Tests for ``split_at()``""" + + def comp_with_str_split(self, str_to_split, delim): + pred = lambda c: c == delim + actual = list(map(''.join, mi.split_at(str_to_split, pred))) + expected = str_to_split.split(delim) + self.assertEqual(actual, expected) + + def test_separators(self): + test_strs = ['', 'abcba', 'aaabbbcccddd', 'e'] + for s, delim in product(test_strs, 'abcd'): + self.comp_with_str_split(s, delim) + + +class SplitBeforeTest(TestCase): + """Tests for ``split_before()``""" + + def test_starts_with_sep(self): + actual = list(mi.split_before('xooxoo', lambda c: c == 'x')) + expected = [['x', 'o', 'o'], ['x', 'o', 'o']] + self.assertEqual(actual, expected) + + def test_ends_with_sep(self): + actual = list(mi.split_before('ooxoox', lambda c: c == 'x')) + expected = [['o', 'o'], ['x', 'o', 'o'], ['x']] + self.assertEqual(actual, expected) + + def test_no_sep(self): + actual =
list(mi.split_before('ooo', lambda c: c == 'x')) + expected = [['o', 'o', 'o']] + self.assertEqual(actual, expected) + + +class SplitAfterTest(TestCase): + """Tests for ``split_after()``""" + + def test_starts_with_sep(self): + actual = list(mi.split_after('xooxoo', lambda c: c == 'x')) + expected = [['x'], ['o', 'o', 'x'], ['o', 'o']] + self.assertEqual(actual, expected) + + def test_ends_with_sep(self): + actual = list(mi.split_after('ooxoox', lambda c: c == 'x')) + expected = [['o', 'o', 'x'], ['o', 'o', 'x']] + self.assertEqual(actual, expected) + + def test_no_sep(self): + actual = list(mi.split_after('ooo', lambda c: c == 'x')) + expected = [['o', 'o', 'o']] + self.assertEqual(actual, expected) + + +class PaddedTest(TestCase): + """Tests for ``padded()``""" + + def test_no_n(self): + seq = [1, 2, 3] + + # No fillvalue + self.assertEqual(mi.take(5, mi.padded(seq)), [1, 2, 3, None, None]) + + # With fillvalue + self.assertEqual( + mi.take(5, mi.padded(seq, fillvalue='')), [1, 2, 3, '', ''] + ) + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=-1))) + self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=0))) + + def test_valid_n(self): + seq = [1, 2, 3, 4, 5] + + # No need for padding: len(seq) <= n + self.assertEqual(list(mi.padded(seq, n=4)), [1, 2, 3, 4, 5]) + self.assertEqual(list(mi.padded(seq, n=5)), [1, 2, 3, 4, 5]) + + # No fillvalue + self.assertEqual( + list(mi.padded(seq, n=7)), [1, 2, 3, 4, 5, None, None] + ) + + # With fillvalue + self.assertEqual( + list(mi.padded(seq, fillvalue='', n=7)), [1, 2, 3, 4, 5, '', ''] + ) + + def test_next_multiple(self): + seq = [1, 2, 3, 4, 5, 6] + + # No need for padding: len(seq) % n == 0 + self.assertEqual( + list(mi.padded(seq, n=3, next_multiple=True)), [1, 2, 3, 4, 5, 6] + ) + + # Padding needed: len(seq) < n + self.assertEqual( + list(mi.padded(seq, n=8, next_multiple=True)), + [1, 2, 3, 4, 5, 6, None, None] + ) + + # No padding needed: len(seq) == n + self.assertEqual( + list(mi.padded(seq, n=6, next_multiple=True)), [1, 2, 3, 4, 5, 6] + ) + + # Padding needed: len(seq) > n + self.assertEqual( + list(mi.padded(seq, n=4, next_multiple=True)), + [1, 2, 3, 4, 5, 6, None, None] + ) + + # With fillvalue + self.assertEqual( + list(mi.padded(seq, fillvalue='', n=4, next_multiple=True)), + [1, 2, 3, 4, 5, 6, '', ''] + ) + + +class DistributeTest(TestCase): + """Tests for distribute()""" + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: mi.distribute(-1, [1, 2, 3])) + self.assertRaises(ValueError, lambda: mi.distribute(0, [1, 2, 3])) + + def test_basic(self): + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + for n, expected in [ + (1, [iterable]), + (2, [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), + (3, [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]]), + (10, [[n] for n in range(1, 10 + 1)]), + ]: + self.assertEqual( + [list(x) for x in mi.distribute(n, iterable)], expected + ) + + def test_large_n(self): + iterable = [1, 2, 3, 4] + self.assertEqual( + [list(x) for x in mi.distribute(6, iterable)], + [[1], [2], [3], [4], [], []] + ) + + +class StaggerTest(TestCase): + """Tests for ``stagger()``""" + + def test_default(self): + iterable = [0, 1, 2, 3] + actual = list(mi.stagger(iterable)) + expected = [(None, 0, 1), (0, 1, 2), (1, 2, 3)] + self.assertEqual(actual, expected) + + def test_offsets(self): + iterable = [0, 1, 2, 3] + for offsets, expected in [ + ((-2, 0, 2), [('', 0, 2), ('', 1, 3)]), + ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3)]), + ((1, 2), [(1, 2), (2, 
3)]), + ]: + all_groups = mi.stagger(iterable, offsets=offsets, fillvalue='') + self.assertEqual(list(all_groups), expected) + + def test_longest(self): + iterable = [0, 1, 2, 3] + for offsets, expected in [ + ( + (-1, 0, 1), + [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, ''), (3, '', '')] + ), + ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3), (3, '')]), + ((1, 2), [(1, 2), (2, 3), (3, '')]), + ]: + all_groups = mi.stagger( + iterable, offsets=offsets, fillvalue='', longest=True + ) + self.assertEqual(list(all_groups), expected) + + +class ZipOffsetTest(TestCase): + """Tests for ``zip_offset()``""" + + def test_shortest(self): + a_1 = [0, 1, 2, 3] + a_2 = [0, 1, 2, 3, 4, 5] + a_3 = [0, 1, 2, 3, 4, 5, 6, 7] + actual = list( + mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), fillvalue='') + ) + expected = [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)] + self.assertEqual(actual, expected) + + def test_longest(self): + a_1 = [0, 1, 2, 3] + a_2 = [0, 1, 2, 3, 4, 5] + a_3 = [0, 1, 2, 3, 4, 5, 6, 7] + actual = list( + mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), longest=True) + ) + expected = [ + (None, 0, 1), + (0, 1, 2), + (1, 2, 3), + (2, 3, 4), + (3, 4, 5), + (None, 5, 6), + (None, None, 7), + ] + self.assertEqual(actual, expected) + + def test_mismatch(self): + iterables = [0, 1, 2], [2, 3, 4] + offsets = (-1, 0, 1) + self.assertRaises( + ValueError, + lambda: list(mi.zip_offset(*iterables, offsets=offsets)) + ) + + +class SortTogetherTest(TestCase): + """Tests for sort_together()""" + + def test_key_list(self): + """tests `key_list` including default, iterables include duplicates""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertEqual( + mi.sort_together(iterables), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('June', 'July', 'July', 'May', 'Aug.', 'May'), + (70, 100, 20, 97, 20, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1)), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('July', 'July', 'June', 'Aug.', 'May', 'May'), + (100, 20, 70, 20, 97, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1, 2)), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('July', 'July', 'June', 'Aug.', 'May', 'May'), + (20, 100, 70, 20, 97, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(2,)), + [ + ('GA', 'CT', 'CT', 'GA', 'GA', 'CT'), + ('Aug.', 'July', 'June', 'May', 'May', 'July'), + (20, 20, 70, 97, 100, 100) + ] + ) + + def test_invalid_key_list(self): + """tests `key_list` for indexes not available in `iterables`""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertRaises( + IndexError, lambda: mi.sort_together(iterables, key_list=(5,)) + ) + + def test_reverse(self): + """tests `reverse` to ensure a reverse sort for `key_list` iterables""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1, 2), reverse=True), + [('GA', 'GA', 'GA', 'CT', 'CT', 'CT'), + ('May', 'May', 'Aug.', 'June', 'July', 'July'), + (100, 97, 20, 70, 100, 20)] + ) + + def test_uneven_iterables(self): + """tests trimming of iterables to the shortest length before sorting""" + iterables = [['GA', 'GA', 'GA', 'CT', 'CT', 'CT', 'MA'], + ['May', 'Aug.', 'May', 'June', 'July', 
'July'], + [97, 20, 100, 70, 100, 20, 0]] + + self.assertEqual( + mi.sort_together(iterables), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('June', 'July', 'July', 'May', 'Aug.', 'May'), + (70, 100, 20, 97, 20, 100) + ] + ) + + +class DivideTest(TestCase): + """Tests for divide()""" + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: mi.divide(-1, [1, 2, 3])) + self.assertRaises(ValueError, lambda: mi.divide(0, [1, 2, 3])) + + def test_basic(self): + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + for n, expected in [ + (1, [iterable]), + (2, [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), + (3, [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]), + (10, [[n] for n in range(1, 10 + 1)]), + ]: + self.assertEqual( + [list(x) for x in mi.divide(n, iterable)], expected + ) + + def test_large_n(self): + iterable = [1, 2, 3, 4] + self.assertEqual( + [list(x) for x in mi.divide(6, iterable)], + [[1], [2], [3], [4], [], []] + ) + + +class TestAlwaysIterable(TestCase): + """Tests for always_iterable()""" + def test_single(self): + self.assertEqual(list(mi.always_iterable(1)), [1]) + + def test_strings(self): + for obj in ['foo', b'bar', u'baz']: + actual = list(mi.always_iterable(obj)) + expected = [obj] + self.assertEqual(actual, expected) + + def test_base_type(self): + dict_obj = {'a': 1, 'b': 2} + str_obj = '123' + + # Default: dicts are iterable like they normally are + default_actual = list(mi.always_iterable(dict_obj)) + default_expected = list(dict_obj) + self.assertEqual(default_actual, default_expected) + + # Unitary types set: dicts are not iterable + custom_actual = list(mi.always_iterable(dict_obj, base_type=dict)) + custom_expected = [dict_obj] + self.assertEqual(custom_actual, custom_expected) + + # With unitary types set, strings are iterable + str_actual = list(mi.always_iterable(str_obj, base_type=None)) + str_expected = list(str_obj) + self.assertEqual(str_actual, str_expected) + + def test_iterables(self): + self.assertEqual(list(mi.always_iterable([0, 1])), [0, 1]) + self.assertEqual( + list(mi.always_iterable([0, 1], base_type=list)), [[0, 1]] + ) + self.assertEqual( + list(mi.always_iterable(iter('foo'))), ['f', 'o', 'o'] + ) + self.assertEqual(list(mi.always_iterable([])), []) + + def test_none(self): + self.assertEqual(list(mi.always_iterable(None)), []) + + def test_generator(self): + def _gen(): + yield 0 + yield 1 + + self.assertEqual(list(mi.always_iterable(_gen())), [0, 1]) + + +class AdjacentTests(TestCase): + def test_typical(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10))) + expected = [(True, 0), (True, 1), (False, 2), (False, 3), (True, 4), + (True, 5), (True, 6), (False, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_empty_iterable(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, [])) + expected = [] + self.assertEqual(actual, expected) + + def test_length_one(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, [0])) + expected = [(True, 0)] + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: x % 5 == 0, [1])) + expected = [(False, 1)] + self.assertEqual(actual, expected) + + def test_consecutive_true(self): + """Test that when the predicate matches multiple consecutive elements + it doesn't repeat elements in the output""" + actual = list(mi.adjacent(lambda x: x % 5 < 2, range(10))) + expected = [(True, 0), (True, 1), (True, 2), (False, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_distance(self): + 
actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=2)) + expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=3)) + expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (True, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_large_distance(self): + """Test distance larger than the length of the iterable""" + iterable = range(10) + actual = list(mi.adjacent(lambda x: x % 5 == 4, iterable, distance=20)) + expected = list(zip(repeat(True), iterable)) + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: False, iterable, distance=20)) + expected = list(zip(repeat(False), iterable)) + self.assertEqual(actual, expected) + + def test_zero_distance(self): + """Test that adjacent() reduces to zip+map when distance is 0""" + iterable = range(1000) + predicate = lambda x: x % 4 == 2 + actual = mi.adjacent(predicate, iterable, 0) + expected = zip(map(predicate, iterable), iterable) + self.assertTrue(all(a == e for a, e in zip(actual, expected))) + + def test_negative_distance(self): + """Test that adjacent() raises an error with negative distance""" + pred = lambda x: x + self.assertRaises( + ValueError, lambda: mi.adjacent(pred, range(1000), -1) + ) + self.assertRaises( + ValueError, lambda: mi.adjacent(pred, range(10), -10) + ) + + def test_grouping(self): + """Test interaction of adjacent() with groupby_transform()""" + iterable = mi.adjacent(lambda x: x % 5 == 0, range(10)) + grouper = mi.groupby_transform(iterable, itemgetter(0), itemgetter(1)) + actual = [(k, list(g)) for k, g in grouper] + expected = [ + (True, [0, 1]), + (False, [2, 3]), + (True, [4, 5, 6]), + (False, [7, 8, 9]), + ] + self.assertEqual(actual, expected) + + def test_call_once(self): + """Test that the predicate is only called once per item.""" + already_seen = set() + iterable = range(10) + + def predicate(item): + self.assertNotIn(item, already_seen) + already_seen.add(item) + return True + + actual = list(mi.adjacent(predicate, iterable)) + expected = [(True, x) for x in iterable] + self.assertEqual(actual, expected) + + +class GroupByTransformTests(TestCase): + def assertAllGroupsEqual(self, groupby1, groupby2): + """Compare two groupby objects for equality, both keys and groups.""" + for a, b in zip(groupby1, groupby2): + key1, group1 = a + key2, group2 = b + self.assertEqual(key1, key2) + self.assertListEqual(list(group1), list(group2)) + self.assertRaises(StopIteration, lambda: next(groupby1)) + self.assertRaises(StopIteration, lambda: next(groupby2)) + + def test_default_funcs(self): + """Test that groupby_transform() with default args mimics groupby()""" + iterable = [(x // 5, x) for x in range(1000)] + actual = mi.groupby_transform(iterable) + expected = groupby(iterable) + self.assertAllGroupsEqual(actual, expected) + + def test_valuefunc(self): + iterable = [(int(x / 5), int(x / 3), x) for x in range(10)] + + # Test the standard usage of grouping one iterable using another's keys + grouper = mi.groupby_transform( + iterable, keyfunc=itemgetter(0), valuefunc=itemgetter(-1) + ) + actual = [(k, list(g)) for k, g in grouper] + expected = [(0, [0, 1, 2, 3, 4]), (1, [5, 6, 7, 8, 9])] + self.assertEqual(actual, expected) + + grouper = mi.groupby_transform( + iterable, keyfunc=itemgetter(1), valuefunc=itemgetter(-1) + ) + actual = [(k, 
list(g)) for k, g in grouper] + expected = [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8]), (3, [9])] + self.assertEqual(actual, expected) + + # and now for something a little different + d = dict(zip(range(10), 'abcdefghij')) + grouper = mi.groupby_transform( + range(10), keyfunc=lambda x: x // 5, valuefunc=d.get + ) + actual = [(k, ''.join(g)) for k, g in grouper] + expected = [(0, 'abcde'), (1, 'fghij')] + self.assertEqual(actual, expected) + + def test_no_valuefunc(self): + iterable = range(1000) + + def key(x): + return x // 5 + + actual = mi.groupby_transform(iterable, key, valuefunc=None) + expected = groupby(iterable, key) + self.assertAllGroupsEqual(actual, expected) + + actual = mi.groupby_transform(iterable, key) # default valuefunc + expected = groupby(iterable, key) + self.assertAllGroupsEqual(actual, expected) + + +class NumericRangeTests(TestCase): + def test_basic(self): + for args, expected in [ + ((4,), [0, 1, 2, 3]), + ((4.0,), [0.0, 1.0, 2.0, 3.0]), + ((1.0, 4), [1.0, 2.0, 3.0]), + ((1, 4.0), [1, 2, 3]), + ((1.0, 5), [1.0, 2.0, 3.0, 4.0]), + ((0, 20, 5), [0, 5, 10, 15]), + ((0, 20, 5.0), [0.0, 5.0, 10.0, 15.0]), + ((0, 10, 3), [0, 3, 6, 9]), + ((0, 10, 3.0), [0.0, 3.0, 6.0, 9.0]), + ((0, -5, -1), [0, -1, -2, -3, -4]), + ((0.0, -5, -1), [0.0, -1.0, -2.0, -3.0, -4.0]), + ((1, 2, Fraction(1, 2)), [Fraction(1, 1), Fraction(3, 2)]), + ((0,), []), + ((0.0,), []), + ((1, 0), []), + ((1.0, 0.0), []), + ((Fraction(2, 1),), [Fraction(0, 1), Fraction(1, 1)]), + ((Decimal('2.0'),), [Decimal('0.0'), Decimal('1.0')]), + ]: + actual = list(mi.numeric_range(*args)) + self.assertEqual(actual, expected) + self.assertTrue( + all(type(a) == type(e) for a, e in zip(actual, expected)) + ) + + def test_arg_count(self): + self.assertRaises(TypeError, lambda: list(mi.numeric_range())) + self.assertRaises( + TypeError, lambda: list(mi.numeric_range(0, 1, 2, 3)) + ) + + def test_zero_step(self): + self.assertRaises( + ValueError, lambda: list(mi.numeric_range(1, 2, 0)) + ) + + +class CountCycleTests(TestCase): + def test_basic(self): + expected = [ + (0, 'a'), (0, 'b'), (0, 'c'), + (1, 'a'), (1, 'b'), (1, 'c'), + (2, 'a'), (2, 'b'), (2, 'c'), + ] + for actual in [ + mi.take(9, mi.count_cycle('abc')), # n=None + list(mi.count_cycle('abc', 3)), # n=3 + ]: + self.assertEqual(actual, expected) + + def test_empty(self): + self.assertEqual(list(mi.count_cycle('')), []) + self.assertEqual(list(mi.count_cycle('', 2)), []) + + def test_negative(self): + self.assertEqual(list(mi.count_cycle('abc', -3)), []) + + +class LocateTests(TestCase): + def test_default_pred(self): + iterable = [0, 1, 1, 0, 1, 0, 0] + actual = list(mi.locate(iterable)) + expected = [1, 2, 4] + self.assertEqual(actual, expected) + + def test_no_matches(self): + iterable = [0, 0, 0] + actual = list(mi.locate(iterable)) + expected = [] + self.assertEqual(actual, expected) + + def test_custom_pred(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda x: x == '0' + actual = list(mi.locate(iterable, pred)) + expected = [0, 3, 5, 6] + self.assertEqual(actual, expected) + + def test_window_size(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda *args: args == ('0', 1) + actual = list(mi.locate(iterable, pred, window_size=2)) + expected = [0, 3] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = [1, 2, 3, 4] + pred = lambda a, b, c, d, e: True + actual = list(mi.locate(iterable, pred, window_size=5)) + expected = [0] + self.assertEqual(actual, expected) + + def 
test_window_size_zero(self): + iterable = [1, 2, 3, 4] + pred = lambda: True + with self.assertRaises(ValueError): + list(mi.locate(iterable, pred, window_size=0)) + + +class StripFunctionTests(TestCase): + def test_hashable(self): + iterable = list('www.example.com') + pred = lambda x: x in set('cmowz.') + + self.assertEqual(list(mi.lstrip(iterable, pred)), list('example.com')) + self.assertEqual(list(mi.rstrip(iterable, pred)), list('www.example')) + self.assertEqual(list(mi.strip(iterable, pred)), list('example')) + + def test_not_hashable(self): + iterable = [ + list('http://'), list('www'), list('.example'), list('.com') + ] + pred = lambda x: x in [list('http://'), list('www'), list('.com')] + + self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[2:]) + self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:3]) + self.assertEqual(list(mi.strip(iterable, pred)), iterable[2: 3]) + + def test_math(self): + iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] + pred = lambda x: x <= 2 + + self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[3:]) + self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:-3]) + self.assertEqual(list(mi.strip(iterable, pred)), iterable[3:-3]) + + +class IsliceExtendedTests(TestCase): + def test_all(self): + iterable = ['0', '1', '2', '3', '4', '5'] + indexes = list(range(-4, len(iterable) + 4)) + [None] + steps = [1, 2, 3, 4, -1, -2, -3, 4] + for slice_args in product(indexes, indexes, steps): + try: + actual = list(mi.islice_extended(iterable, *slice_args)) + except Exception as e: + self.fail((slice_args, e)) + + expected = iterable[slice(*slice_args)] + self.assertEqual(actual, expected, slice_args) + + def test_zero_step(self): + with self.assertRaises(ValueError): + list(mi.islice_extended([1, 2, 3], 0, 1, 0)) + + +class ConsecutiveGroupsTest(TestCase): + def test_numbers(self): + iterable = [-10, -8, -7, -6, 1, 2, 4, 5, -1, 7] + actual = [list(g) for g in mi.consecutive_groups(iterable)] + expected = [[-10], [-8, -7, -6], [1, 2], [4, 5], [-1], [7]] + self.assertEqual(actual, expected) + + def test_custom_ordering(self): + iterable = ['1', '10', '11', '20', '21', '22', '30', '31'] + ordering = lambda x: int(x) + actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] + expected = [['1'], ['10', '11'], ['20', '21', '22'], ['30', '31']] + self.assertEqual(actual, expected) + + def test_exotic_ordering(self): + iterable = [ + ('a', 'b', 'c', 'd'), + ('a', 'c', 'b', 'd'), + ('a', 'c', 'd', 'b'), + ('a', 'd', 'b', 'c'), + ('d', 'b', 'c', 'a'), + ('d', 'c', 'a', 'b'), + ] + ordering = list(permutations('abcd')).index + actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] + expected = [ + [('a', 'b', 'c', 'd')], + [('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b'), ('a', 'd', 'b', 'c')], + [('d', 'b', 'c', 'a'), ('d', 'c', 'a', 'b')], + ] + self.assertEqual(actual, expected) + + +class DifferenceTest(TestCase): + def test_normal(self): + iterable = [10, 20, 30, 40, 50] + actual = list(mi.difference(iterable)) + expected = [10, 10, 10, 10, 10] + self.assertEqual(actual, expected) + + def test_custom(self): + iterable = [10, 20, 30, 40, 50] + actual = list(mi.difference(iterable, add)) + expected = [10, 30, 50, 70, 90] + self.assertEqual(actual, expected) + + def test_roundtrip(self): + original = list(range(100)) + accumulated = mi.accumulate(original) + actual = list(mi.difference(accumulated)) + self.assertEqual(actual, original) + + def test_one(self): + self.assertEqual(list(mi.difference([0])), [0]) + + def 
test_empty(self): + self.assertEqual(list(mi.difference([])), []) + + +class SeekableTest(TestCase): + def test_exhaustion_reset(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(list(s), iterable) # Normal iteration + self.assertEqual(list(s), []) # Iterable is exhausted + + s.seek(0) + self.assertEqual(list(s), iterable) # Back in action + + def test_partial_reset(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(5, s), iterable[:5]) # Normal iteration + + s.seek(1) + self.assertEqual(list(s), iterable[1:]) # Get the rest of the iterable + + def test_forward(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration + + s.seek(3) # Skip over index 2 + self.assertEqual(list(s), iterable[3:]) # Result is similar to slicing + + s.seek(0) # Back to 0 + self.assertEqual(list(s), iterable) # No difference in result + + def test_past_end(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration + + s.seek(20) + self.assertEqual(list(s), []) # Iterable is exhausted + + s.seek(0) # Back to 0 + self.assertEqual(list(s), iterable) # No difference in result + + def test_elements(self): + iterable = map(str, count()) + + s = mi.seekable(iterable) + mi.take(10, s) + + elements = s.elements() + self.assertEqual( + [elements[i] for i in range(10)], [str(n) for n in range(10)] + ) + self.assertEqual(len(elements), 10) + + mi.take(10, s) + self.assertEqual(list(elements), [str(n) for n in range(20)]) + + +class SequenceViewTests(TestCase): + def test_init(self): + view = mi.SequenceView((1, 2, 3)) + self.assertEqual(repr(view), "SequenceView((1, 2, 3))") + self.assertRaises(TypeError, lambda: mi.SequenceView({})) + + def test_update(self): + seq = [1, 2, 3] + view = mi.SequenceView(seq) + self.assertEqual(len(view), 3) + self.assertEqual(repr(view), "SequenceView([1, 2, 3])") + + seq.pop() + self.assertEqual(len(view), 2) + self.assertEqual(repr(view), "SequenceView([1, 2])") + + def test_indexing(self): + seq = ('a', 'b', 'c', 'd', 'e', 'f') + view = mi.SequenceView(seq) + for i in range(-len(seq), len(seq)): + self.assertEqual(view[i], seq[i]) + + def test_slicing(self): + seq = ('a', 'b', 'c', 'd', 'e', 'f') + view = mi.SequenceView(seq) + n = len(seq) + indexes = list(range(-n - 1, n + 1)) + [None] + steps = list(range(-n, n + 1)) + steps.remove(0) + for slice_args in product(indexes, indexes, steps): + i = slice(*slice_args) + self.assertEqual(view[i], seq[i]) + + def test_abc_methods(self): + # collections.Sequence should provide all of this functionality + seq = ('a', 'b', 'c', 'd', 'e', 'f', 'f') + view = mi.SequenceView(seq) + + # __contains__ + self.assertIn('b', view) + self.assertNotIn('g', view) + + # __iter__ + self.assertEqual(list(iter(view)), list(seq)) + + # __reversed__ + self.assertEqual(list(reversed(view)), list(reversed(seq))) + + # index + self.assertEqual(view.index('b'), 1) + + # count + self.assertEqual(seq.count('f'), 2) + + +class RunLengthTest(TestCase): + def test_encode(self): + iterable = (int(str(n)[0]) for n in count(800)) + actual = mi.take(4, mi.run_length.encode(iterable)) + expected = [(8, 100), (9, 100), (1, 1000), (2, 1000)] + self.assertEqual(actual, expected) + + def test_decode(self): + iterable = [('d', 4), ('c', 3), ('b', 2), ('a', 1)] + actual = ''.join(mi.run_length.decode(iterable)) + 
expected = 'ddddcccbba' + self.assertEqual(actual, expected) + + +class ExactlyNTests(TestCase): + """Tests for ``exactly_n()``""" + + def test_true(self): + """Iterable has ``n`` ``True`` elements""" + self.assertTrue(mi.exactly_n([True, False, True], 2)) + self.assertTrue(mi.exactly_n([1, 1, 1, 0], 3)) + self.assertTrue(mi.exactly_n([False, False], 0)) + self.assertTrue(mi.exactly_n(range(100), 10, lambda x: x < 10)) + + def test_false(self): + """Iterable does not have ``n`` ``True`` elements""" + self.assertFalse(mi.exactly_n([True, False, False], 2)) + self.assertFalse(mi.exactly_n([True, True, False], 1)) + self.assertFalse(mi.exactly_n([False], 1)) + self.assertFalse(mi.exactly_n([True], -1)) + self.assertFalse(mi.exactly_n(repeat(True), 100)) + + def test_empty(self): + """Return ``True`` if the iterable is empty and ``n`` is 0""" + self.assertTrue(mi.exactly_n([], 0)) + self.assertFalse(mi.exactly_n([], 1)) + + +class AlwaysReversibleTests(TestCase): + """Tests for ``always_reversible()``""" + + def test_regular_reversed(self): + self.assertEqual(list(reversed(range(10))), + list(mi.always_reversible(range(10)))) + self.assertEqual(list(reversed([1, 2, 3])), + list(mi.always_reversible([1, 2, 3]))) + self.assertEqual(reversed([1, 2, 3]).__class__, + mi.always_reversible([1, 2, 3]).__class__) + + def test_nonseq_reversed(self): + # Create a non-reversible generator from a sequence + with self.assertRaises(TypeError): + reversed(x for x in range(10)) + + self.assertEqual(list(reversed(range(10))), + list(mi.always_reversible(x for x in range(10)))) + self.assertEqual(list(reversed([1, 2, 3])), + list(mi.always_reversible(x for x in [1, 2, 3]))) + self.assertNotEqual(reversed((1, 2)).__class__, + mi.always_reversible(x for x in (1, 2)).__class__) + + +class CircularShiftsTests(TestCase): + def test_empty(self): + # empty iterable -> empty list + self.assertEqual(list(mi.circular_shifts([])), []) + + def test_simple_circular_shifts(self): + # test a simple iterator case + self.assertEqual( + mi.circular_shifts(range(4)), + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + ) + + def test_duplicates(self): + # test non-distinct entries + self.assertEqual( + mi.circular_shifts([0, 1, 0, 1]), + [(0, 1, 0, 1), (1, 0, 1, 0), (0, 1, 0, 1), (1, 0, 1, 0)] + ) + + +class MakeDecoratorTests(TestCase): + def test_basic(self): + slicer = mi.make_decorator(islice) + + @slicer(1, 10, 2) + def user_function(arg_1, arg_2, kwarg_1=None): + self.assertEqual(arg_1, 'arg_1') + self.assertEqual(arg_2, 'arg_2') + self.assertEqual(kwarg_1, 'kwarg_1') + return map(str, count()) + + it = user_function('arg_1', 'arg_2', kwarg_1='kwarg_1') + actual = list(it) + expected = ['1', '3', '5', '7', '9'] + self.assertEqual(actual, expected) + + def test_result_index(self): + def stringify(*args, **kwargs): + self.assertEqual(args[0], 'arg_0') + iterable = args[1] + self.assertEqual(args[2], 'arg_2') + self.assertEqual(kwargs['kwarg_1'], 'kwarg_1') + return map(str, iterable) + + stringifier = mi.make_decorator(stringify, result_index=1) + + @stringifier('arg_0', 'arg_2', kwarg_1='kwarg_1') + def user_function(n): + return count(n) + + it = user_function(1) + actual = mi.take(5, it) + expected = ['1', '2', '3', '4', '5'] + self.assertEqual(actual, expected) + + def test_wrap_class(self): + seeker = mi.make_decorator(mi.seekable) + + @seeker() + def user_function(n): + return map(str, range(n)) + + it = user_function(5) + self.assertEqual(list(it), ['0', '1', '2', '3', '4']) + + it.seek(0) +
self.assertEqual(list(it), ['0', '1', '2', '3', '4']) + + +class MapReduceTests(TestCase): + def test_default(self): + iterable = (str(x) for x in range(5)) + keyfunc = lambda x: int(x) // 2 + actual = sorted(mi.map_reduce(iterable, keyfunc).items()) + expected = [(0, ['0', '1']), (1, ['2', '3']), (2, ['4'])] + self.assertEqual(actual, expected) + + def test_valuefunc(self): + iterable = (str(x) for x in range(5)) + keyfunc = lambda x: int(x) // 2 + valuefunc = int + actual = sorted(mi.map_reduce(iterable, keyfunc, valuefunc).items()) + expected = [(0, [0, 1]), (1, [2, 3]), (2, [4])] + self.assertEqual(actual, expected) + + def test_reducefunc(self): + iterable = (str(x) for x in range(5)) + keyfunc = lambda x: int(x) // 2 + valuefunc = int + reducefunc = lambda value_list: reduce(mul, value_list, 1) + actual = sorted( + mi.map_reduce(iterable, keyfunc, valuefunc, reducefunc).items() + ) + expected = [(0, 0), (1, 6), (2, 4)] + self.assertEqual(actual, expected) + + def test_ret(self): + d = mi.map_reduce([1, 0, 2, 0, 1, 0], bool) + self.assertEqual(d, {False: [0, 0, 0], True: [1, 2, 1]}) + self.assertRaises(KeyError, lambda: d[None].append(1)) + + +class RlocateTests(TestCase): + def test_default_pred(self): + iterable = [0, 1, 1, 0, 1, 0, 0] + for it in (iterable[:], iter(iterable)): + actual = list(mi.rlocate(it)) + expected = [4, 2, 1] + self.assertEqual(actual, expected) + + def test_no_matches(self): + iterable = [0, 0, 0] + for it in (iterable[:], iter(iterable)): + actual = list(mi.rlocate(it)) + expected = [] + self.assertEqual(actual, expected) + + def test_custom_pred(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda x: x == '0' + for it in (iterable[:], iter(iterable)): + actual = list(mi.rlocate(it, pred)) + expected = [6, 5, 3, 0] + self.assertEqual(actual, expected) + + def test_efficient_reversal(self): + iterable = range(10 ** 10) # Is efficiently reversible + target = 10 ** 10 - 2 + pred = lambda x: x == target # Find-able from the right + actual = next(mi.rlocate(iterable, pred)) + self.assertEqual(actual, target) + + def test_window_size(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda *args: args == ('0', 1) + for it in (iterable, iter(iterable)): + actual = list(mi.rlocate(it, pred, window_size=2)) + expected = [3, 0] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = [1, 2, 3, 4] + pred = lambda a, b, c, d, e: True + for it in (iterable, iter(iterable)): + # Use the loop variable so both the list and the iterator + # variants are actually exercised + actual = list(mi.rlocate(it, pred, window_size=5)) + expected = [0] + self.assertEqual(actual, expected) + + def test_window_size_zero(self): + iterable = [1, 2, 3, 4] + pred = lambda: True + for it in (iterable, iter(iterable)): + with self.assertRaises(ValueError): + list(mi.rlocate(it, pred, window_size=0)) + + +class ReplaceTests(TestCase): + def test_basic(self): + iterable = range(10) + pred = lambda x: x % 2 == 0 + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes)) + expected = [1, 3, 5, 7, 9] + self.assertEqual(actual, expected) + + def test_count(self): + iterable = range(10) + pred = lambda x: x % 2 == 0 + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, count=4)) + expected = [1, 3, 5, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size(self): + iterable = range(10) + pred = lambda *args: args == (0, 1, 2) + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) + expected = [3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual,
expected) + + def test_window_size_end(self): + iterable = range(10) + pred = lambda *args: args == (7, 8, 9) + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) + expected = [0, 1, 2, 3, 4, 5, 6] + self.assertEqual(actual, expected) + + def test_window_size_count(self): + iterable = range(10) + pred = lambda *args: (args == (0, 1, 2)) or (args == (7, 8, 9)) + substitutes = [] + actual = list( + mi.replace(iterable, pred, substitutes, count=1, window_size=3) + ) + expected = [3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = range(4) + pred = lambda a, b, c, d, e: True + substitutes = [5, 6, 7] + actual = list(mi.replace(iterable, pred, substitutes, window_size=5)) + expected = [5, 6, 7] + self.assertEqual(actual, expected) + + def test_window_size_zero(self): + iterable = range(10) + pred = lambda *args: True + substitutes = [] + with self.assertRaises(ValueError): + list(mi.replace(iterable, pred, substitutes, window_size=0)) + + def test_iterable_substitutes(self): + iterable = range(5) + pred = lambda x: x % 2 == 0 + substitutes = iter('__') + actual = list(mi.replace(iterable, pred, substitutes)) + expected = ['_', '_', 1, '_', '_', 3, '_', '_'] + self.assertEqual(actual, expected) diff --git a/resources/lib/more_itertools/tests/test_recipes.py b/resources/lib/more_itertools/tests/test_recipes.py new file mode 100644 index 0000000..98981fe --- /dev/null +++ b/resources/lib/more_itertools/tests/test_recipes.py @@ -0,0 +1,616 @@ +from doctest import DocTestSuite +from unittest import TestCase + +from itertools import combinations +from six.moves import range + +import more_itertools as mi + + +def load_tests(loader, tests, ignore): + # Add the doctests + tests.addTests(DocTestSuite('more_itertools.recipes')) + return tests + + +class AccumulateTests(TestCase): + """Tests for ``accumulate()``""" + + def test_empty(self): + """Test that an empty input returns an empty output""" + self.assertEqual(list(mi.accumulate([])), []) + + def test_default(self): + """Test accumulate with the default function (addition)""" + self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6]) + + def test_bogus_function(self): + """Test accumulate with an invalid function""" + with self.assertRaises(TypeError): + list(mi.accumulate([1, 2, 3], func=lambda x: x)) + + def test_custom_function(self): + """Test accumulate with a custom function""" + self.assertEqual( + list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3] + ) + + +class TakeTests(TestCase): + """Tests for ``take()``""" + + def test_simple_take(self): + """Test basic usage""" + t = mi.take(5, range(10)) + self.assertEqual(t, [0, 1, 2, 3, 4]) + + def test_null_take(self): + """Check the null case""" + t = mi.take(0, range(10)) + self.assertEqual(t, []) + + def test_negative_take(self): + """Make sure taking negative items results in a ValueError""" + self.assertRaises(ValueError, lambda: mi.take(-3, range(10))) + + def test_take_too_much(self): + """Taking more than an iterator has remaining should return what the + iterator has remaining. 
+ + """ + t = mi.take(10, range(5)) + self.assertEqual(t, [0, 1, 2, 3, 4]) + + +class TabulateTests(TestCase): + """Tests for ``tabulate()``""" + + def test_simple_tabulate(self): + """Test the happy path""" + t = mi.tabulate(lambda x: x) + f = tuple([next(t) for _ in range(3)]) + self.assertEqual(f, (0, 1, 2)) + + def test_count(self): + """Ensure tabulate accepts specific count""" + t = mi.tabulate(lambda x: 2 * x, -1) + f = (next(t), next(t), next(t)) + self.assertEqual(f, (-2, 0, 2)) + + +class TailTests(TestCase): + """Tests for ``tail()``""" + + def test_greater(self): + """Length of iterable is greather than requested tail""" + self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G']) + + def test_equal(self): + """Length of iterable is equal to the requested tail""" + self.assertEqual( + list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] + ) + + def test_less(self): + """Length of iterable is less than requested tail""" + self.assertEqual( + list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] + ) + + +class ConsumeTests(TestCase): + """Tests for ``consume()``""" + + def test_sanity(self): + """Test basic functionality""" + r = (x for x in range(10)) + mi.consume(r, 3) + self.assertEqual(3, next(r)) + + def test_null_consume(self): + """Check the null case""" + r = (x for x in range(10)) + mi.consume(r, 0) + self.assertEqual(0, next(r)) + + def test_negative_consume(self): + """Check that negative consumsion throws an error""" + r = (x for x in range(10)) + self.assertRaises(ValueError, lambda: mi.consume(r, -1)) + + def test_total_consume(self): + """Check that iterator is totally consumed by default""" + r = (x for x in range(10)) + mi.consume(r) + self.assertRaises(StopIteration, lambda: next(r)) + + +class NthTests(TestCase): + """Tests for ``nth()``""" + + def test_basic(self): + """Make sure the nth item is returned""" + l = range(10) + for i, v in enumerate(l): + self.assertEqual(mi.nth(l, i), v) + + def test_default(self): + """Ensure a default value is returned when nth item not found""" + l = range(3) + self.assertEqual(mi.nth(l, 100, "zebra"), "zebra") + + def test_negative_item_raises(self): + """Ensure asking for a negative item raises an exception""" + self.assertRaises(ValueError, lambda: mi.nth(range(10), -3)) + + +class AllEqualTests(TestCase): + """Tests for ``all_equal()``""" + + def test_true(self): + """Everything is equal""" + self.assertTrue(mi.all_equal('aaaaaa')) + self.assertTrue(mi.all_equal([0, 0, 0, 0])) + + def test_false(self): + """Not everything is equal""" + self.assertFalse(mi.all_equal('aaaaab')) + self.assertFalse(mi.all_equal([0, 0, 0, 1])) + + def test_tricky(self): + """Not everything is identical, but everything is equal""" + items = [1, complex(1, 0), 1.0] + self.assertTrue(mi.all_equal(items)) + + def test_empty(self): + """Return True if the iterable is empty""" + self.assertTrue(mi.all_equal('')) + self.assertTrue(mi.all_equal([])) + + def test_one(self): + """Return True if the iterable is singular""" + self.assertTrue(mi.all_equal('0')) + self.assertTrue(mi.all_equal([0])) + + +class QuantifyTests(TestCase): + """Tests for ``quantify()``""" + + def test_happy_path(self): + """Make sure True count is returned""" + q = [True, False, True] + self.assertEqual(mi.quantify(q), 2) + + def test_custom_predicate(self): + """Ensure non-default predicates return as expected""" + q = range(10) + self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5) + + +class PadnoneTests(TestCase): + """Tests for 
``padnone()``""" + + def test_happy_path(self): + """wrapper iterator should return None indefinitely""" + r = range(2) + p = mi.padnone(r) + self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)]) + + +class NcyclesTests(TestCase): + """Tests for ``ncycles()``""" + + def test_happy_path(self): + """cycle a sequence three times""" + r = ["a", "b", "c"] + n = mi.ncycles(r, 3) + self.assertEqual( + ["a", "b", "c", "a", "b", "c", "a", "b", "c"], + list(n) + ) + + def test_null_case(self): + """asking for 0 cycles should return an empty iterator""" + n = mi.ncycles(range(100), 0) + self.assertRaises(StopIteration, lambda: next(n)) + + def test_pathological_case(self): + """asking for negative cycles should return an empty iterator""" + n = mi.ncycles(range(100), -10) + self.assertRaises(StopIteration, lambda: next(n)) + + +class DotproductTests(TestCase): + """Tests for ``dotproduct()``""" + + def test_happy_path(self): + """simple dotproduct example""" + self.assertEqual(400, mi.dotproduct([10, 10], [20, 20])) + + +class FlattenTests(TestCase): + """Tests for ``flatten()``""" + + def test_basic_usage(self): + """ensure list of lists is flattened one level""" + f = [[0, 1, 2], [3, 4, 5]] + self.assertEqual(list(range(6)), list(mi.flatten(f))) + + def test_single_level(self): + """ensure list of lists is flattened only one level""" + f = [[0, [1, 2]], [[3, 4], 5]] + self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f))) + + +class RepeatfuncTests(TestCase): + """Tests for ``repeatfunc()``""" + + def test_simple_repeat(self): + """test simple repeated functions""" + r = mi.repeatfunc(lambda: 5) + self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)]) + + def test_finite_repeat(self): + """ensure limited repeat when times is provided""" + r = mi.repeatfunc(lambda: 5, times=5) + self.assertEqual([5, 5, 5, 5, 5], list(r)) + + def test_added_arguments(self): + """ensure arguments are applied to the function""" + r = mi.repeatfunc(lambda x: x, 2, 3) + self.assertEqual([3, 3], list(r)) + + def test_null_times(self): + """repeat 0 should return an empty iterator""" + r = mi.repeatfunc(range, 0, 3) + self.assertRaises(StopIteration, lambda: next(r)) + + +class PairwiseTests(TestCase): + """Tests for ``pairwise()``""" + + def test_base_case(self): + """ensure an iterable will return pairwise""" + p = mi.pairwise([1, 2, 3]) + self.assertEqual([(1, 2), (2, 3)], list(p)) + + def test_short_case(self): + """ensure an empty iterator if there aren't enough values to pair""" + p = mi.pairwise("a") + self.assertRaises(StopIteration, lambda: next(p)) + + +class GrouperTests(TestCase): + """Tests for ``grouper()``""" + + def test_even(self): + """Test when group size divides evenly into the length of + the iterable. + + """ + self.assertEqual( + list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')] + ) + + def test_odd(self): + """Test when group size does not divide evenly into the length of the + iterable.
+ + """ + self.assertEqual( + list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)] + ) + + def test_fill_value(self): + """Test that the fill value is used to pad the final group""" + self.assertEqual( + list(mi.grouper(3, 'ABCDE', 'x')), + [('A', 'B', 'C'), ('D', 'E', 'x')] + ) + + +class RoundrobinTests(TestCase): + """Tests for ``roundrobin()``""" + + def test_even_groups(self): + """Ensure ordered output from evenly populated iterables""" + self.assertEqual( + list(mi.roundrobin('ABC', [1, 2, 3], range(3))), + ['A', 1, 0, 'B', 2, 1, 'C', 3, 2] + ) + + def test_uneven_groups(self): + """Ensure ordered output from unevenly populated iterables""" + self.assertEqual( + list(mi.roundrobin('ABCD', [1, 2], range(0))), + ['A', 1, 'B', 2, 'C', 'D'] + ) + + +class PartitionTests(TestCase): + """Tests for ``partition()``""" + + def test_bool(self): + """Test when pred() returns a boolean""" + lesser, greater = mi.partition(lambda x: x > 5, range(10)) + self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5]) + self.assertEqual(list(greater), [6, 7, 8, 9]) + + def test_arbitrary(self): + """Test when pred() returns an integer""" + divisibles, remainders = mi.partition(lambda x: x % 3, range(10)) + self.assertEqual(list(divisibles), [0, 3, 6, 9]) + self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8]) + + +class PowersetTests(TestCase): + """Tests for ``powerset()``""" + + def test_combinatorics(self): + """Ensure a proper enumeration""" + p = mi.powerset([1, 2, 3]) + self.assertEqual( + list(p), + [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] + ) + + +class UniqueEverseenTests(TestCase): + """Tests for ``unique_everseen()``""" + + def test_everseen(self): + """ensure duplicate elements are ignored""" + u = mi.unique_everseen('AAAABBBBCCDAABBB') + self.assertEqual( + ['A', 'B', 'C', 'D'], + list(u) + ) + + def test_custom_key(self): + """ensure the custom key comparison works""" + u = mi.unique_everseen('aAbACCc', key=str.lower) + self.assertEqual(list('abC'), list(u)) + + def test_unhashable(self): + """ensure things work for unhashable items""" + iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] + u = mi.unique_everseen(iterable) + self.assertEqual(list(u), ['a', [1, 2, 3]]) + + def test_unhashable_key(self): + """ensure things work for unhashable items with a custom key""" + iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] + u = mi.unique_everseen(iterable, key=lambda x: x) + self.assertEqual(list(u), ['a', [1, 2, 3]]) + + +class UniqueJustseenTests(TestCase): + """Tests for ``unique_justseen()``""" + + def test_justseen(self): + """ensure only last item is remembered""" + u = mi.unique_justseen('AAAABBBCCDABB') + self.assertEqual(list('ABCDAB'), list(u)) + + def test_custom_key(self): + """ensure the custom key comparison works""" + u = mi.unique_justseen('AABCcAD', str.lower) + self.assertEqual(list('ABCAD'), list(u)) + + +class IterExceptTests(TestCase): + """Tests for ``iter_except()``""" + + def test_exact_exception(self): + """ensure the exact specified exception is caught""" + l = [1, 2, 3] + i = mi.iter_except(l.pop, IndexError) + self.assertEqual(list(i), [3, 2, 1]) + + def test_generic_exception(self): + """ensure the generic exception can be caught""" + l = [1, 2] + i = mi.iter_except(l.pop, Exception) + self.assertEqual(list(i), [2, 1]) + + def test_uncaught_exception_is_raised(self): + """ensure a non-specified exception is raised""" + l = [1, 2, 3] + i = mi.iter_except(l.pop, KeyError) + self.assertRaises(IndexError, lambda: list(i)) + + def test_first(self): + """ensure 
first is run before the function""" + l = [1, 2, 3] + f = lambda: 25 + i = mi.iter_except(l.pop, IndexError, f) + self.assertEqual(list(i), [25, 3, 2, 1]) + + +class FirstTrueTests(TestCase): + """Tests for ``first_true()``""" + + def test_something_true(self): + """Test with no keywords""" + self.assertEqual(mi.first_true(range(10)), 1) + + def test_nothing_true(self): + """Test default return value.""" + self.assertEqual(mi.first_true([0, 0, 0]), False) + + def test_default(self): + """Test with a default keyword""" + self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!') + + def test_pred(self): + """Test with a custom predicate""" + self.assertEqual( + mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6 + ) + + +class RandomProductTests(TestCase): + """Tests for ``random_product()`` + + Since random.choice() has different results with the same seed across + python versions 2.x and 3.x, these tests use highly probable events to + create predictable outcomes across platforms. + """ + + def test_simple_lists(self): + """Ensure that one item is chosen from each list in each pair. + Also ensure that each item from each list eventually appears in + the chosen combinations. + + Odds are roughly 1 in 7.1 * 10e16 that one item from either list will + not be chosen after 100 samplings of one item from each list. Just to + be safe, it's better to use a known random seed, too. + + """ + nums = [1, 2, 3] + lets = ['a', 'b', 'c'] + n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)]) + n, m = set(n), set(m) + self.assertEqual(n, set(nums)) + self.assertEqual(m, set(lets)) + self.assertEqual(len(n), len(nums)) + self.assertEqual(len(m), len(lets)) + + def test_list_with_repeat(self): + """ensure multiple items are chosen, and that they appear to be chosen + from one list then the next, in proper order. + + """ + nums = [1, 2, 3] + lets = ['a', 'b', 'c'] + r = list(mi.random_product(nums, lets, repeat=100)) + self.assertEqual(2 * 100, len(r)) + n, m = set(r[::2]), set(r[1::2]) + self.assertEqual(n, set(nums)) + self.assertEqual(m, set(lets)) + self.assertEqual(len(n), len(nums)) + self.assertEqual(len(m), len(lets)) + + +class RandomPermutationTests(TestCase): + """Tests for ``random_permutation()``""" + + def test_full_permutation(self): + """ensure every item from the iterable is returned in a new ordering + + 15 elements have a 1 in 1.3 * 10e12 chance of appearing in sorted + order, so we fix a seed value just to be sure. + + """ + i = range(15) + r = mi.random_permutation(i) + self.assertEqual(set(i), set(r)) + if i == r: + raise AssertionError("Values were not permuted") + + def test_partial_permutation(self): + """ensure all returned items are from the iterable, that the returned + permutation is of the desired length, and that all items eventually + get returned. + + Sampling 100 permutations of length 5 from a set of 15 leaves a + (2/3)^100 chance that an item will not be chosen. Multiplied by 15 + items, there is a 1 in 2.6e16 chance that at least 1 item will not + show up in the resulting output. Using a random seed will fix that.
+ + """ + items = range(15) + item_set = set(items) + all_items = set() + for _ in range(100): + permutation = mi.random_permutation(items, 5) + self.assertEqual(len(permutation), 5) + permutation_set = set(permutation) + self.assertLessEqual(permutation_set, item_set) + all_items |= permutation_set + self.assertEqual(all_items, item_set) + + +class RandomCombinationTests(TestCase): + """Tests for ``random_combination()``""" + + def test_psuedorandomness(self): + """ensure different subsets of the iterable get returned over many + samplings of random combinations""" + items = range(15) + all_items = set() + for _ in range(50): + combination = mi.random_combination(items, 5) + all_items |= set(combination) + self.assertEqual(all_items, set(items)) + + def test_no_replacement(self): + """ensure that elements are sampled without replacement""" + items = range(15) + for _ in range(50): + combination = mi.random_combination(items, len(items)) + self.assertEqual(len(combination), len(set(combination))) + self.assertRaises( + ValueError, lambda: mi.random_combination(items, len(items) + 1) + ) + + +class RandomCombinationWithReplacementTests(TestCase): + """Tests for ``random_combination_with_replacement()``""" + + def test_replacement(self): + """ensure that elements are sampled with replacement""" + items = range(5) + combo = mi.random_combination_with_replacement(items, len(items) * 2) + self.assertEqual(2 * len(items), len(combo)) + if len(set(combo)) == len(combo): + raise AssertionError("Combination contained no duplicates") + + def test_pseudorandomness(self): + """ensure different subsets of the iterable get returned over many + samplings of random combinations""" + items = range(15) + all_items = set() + for _ in range(50): + combination = mi.random_combination_with_replacement(items, 5) + all_items |= set(combination) + self.assertEqual(all_items, set(items)) + + +class NthCombinationTests(TestCase): + def test_basic(self): + iterable = 'abcdefg' + r = 4 + for index, expected in enumerate(combinations(iterable, r)): + actual = mi.nth_combination(iterable, r, index) + self.assertEqual(actual, expected) + + def test_long(self): + actual = mi.nth_combination(range(180), 4, 2000000) + expected = (2, 12, 35, 126) + self.assertEqual(actual, expected) + + def test_invalid_r(self): + for r in (-1, 3): + with self.assertRaises(ValueError): + mi.nth_combination([], r, 0) + + def test_invalid_index(self): + with self.assertRaises(IndexError): + mi.nth_combination('abcdefg', 3, -36) + + +class PrependTests(TestCase): + def test_basic(self): + value = 'a' + iterator = iter('bcdefg') + actual = list(mi.prepend(value, iterator)) + expected = list('abcdefg') + self.assertEqual(actual, expected) + + def test_multiple(self): + value = 'ab' + iterator = iter('cdefg') + actual = tuple(mi.prepend(value, iterator)) + expected = ('ab',) + tuple('cdefg') + self.assertEqual(actual, expected) diff --git a/resources/lib/osd.py b/resources/lib/osd.py index dc28323..1f6baf2 100644 --- a/resources/lib/osd.py +++ b/resources/lib/osd.py @@ -8,7 +8,7 @@ ''' import threading -import thread +import _thread import xbmc import xbmcgui from utils import log_msg, log_exception, get_track_rating diff --git a/resources/lib/pkg_resources.py b/resources/lib/pkg_resources.py new file mode 100644 index 0000000..9ab28e9 --- /dev/null +++ b/resources/lib/pkg_resources.py @@ -0,0 +1,27 @@ +import uio + +c = {} + +def resource_stream(package, resource): + if package not in c: + try: + if package: + p = __import__(package + ".R", None, 
None, True) + else: + p = __import__("R") + c[package] = p.R + except ImportError: + if package: + p = __import__(package) + d = p.__path__ + else: + d = "." +# if d[0] != "/": +# import uos +# d = uos.getcwd() + "/" + d + c[package] = d + "/" + + p = c[package] + if isinstance(p, dict): + return uio.BytesIO(p[resource]) + return open(p + resource, "rb") diff --git a/resources/lib/player_monitor.py b/resources/lib/player_monitor.py index 1ef477f..aeec016 100644 --- a/resources/lib/player_monitor.py +++ b/resources/lib/player_monitor.py @@ -2,12 +2,12 @@ # -*- coding: utf-8 -*- -from utils import log_msg, log_exception, parse_spotify_track, PROXY_PORT +from utils import log_msg, log_exception, parse_spotify_track, PROXY_PORT, get_playername import xbmc import xbmcgui -from urllib import quote_plus +import urllib.parse import threading -import thread +import _thread class ConnectPlayer(xbmc.Player): @@ -66,6 +66,7 @@ def onPlayBackStarted(self): self.connect_local = True if "nexttrack" in filename: # next track requested for kodi player + log_msg("next track requested for kodi player") self.__sp.next_track() elif self.connect_playing: self.update_playlist() @@ -97,6 +98,7 @@ def onPlayBackStopped(self): def update_playlist(self): '''Update the playlist: add fake item at the end which allows us to skip''' + log_msg("Update the playlist: add fake item at the end which allows us to skip", xbmc.LOGDEBUG) if self.connect_local: url = "http://localhost:%s/nexttrack" % PROXY_PORT else: @@ -127,23 +129,30 @@ def start_playback(self, track_id): def update_info(self, force): cur_playback = self.__sp.current_playback() if cur_playback: - if cur_playback["is_playing"] and (not xbmc.getCondVisibility("Player.Paused") or force): + log_msg("Spotify Connect request received : %s" % cur_playback) + if cur_playback["device"]["name"] == get_playername() and (not xbmc.getCondVisibility("Player.Paused") and cur_playback["is_playing"] or force): player_title = None if self.isPlaying(): - player_title = self.getMusicInfoTag().getTitle().decode("utf-8") - + player_title = self.getMusicInfoTag().getTitle() trackdetails = cur_playback["item"] + # Set volume level + if cur_playback['device']['volume_percent'] != 50: + xbmc.executebuiltin("SetVolume(%s,true)" % cur_playback['device']['volume_percent'] ) if trackdetails is not None and (not player_title or player_title != trackdetails["name"]): - log_msg("Next track requested by Spotify Connect player") + log_msg("Next track requested by Spotify Connect player.") self.start_playback(trackdetails["id"]) - elif cur_playback["is_playing"] and xbmc.getCondVisibility("Player.Paused"): - log_msg("Playback resumed from pause requested by Spotify Connect") + elif cur_playback["device"]["name"] == get_playername() and xbmc.getCondVisibility("Player.Paused") and cur_playback["is_playing"]: + log_msg("Playback resumed from pause requested by Spotify Connect." 
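The "nexttrack" branch above is the counterpart of update_playlist(): the addon queues a dummy playlist entry pointing at its own web service, so that Kodi advancing onto it signals a skip to Spotify rather than playing audio. An illustrative sketch of the idea (Kodi only; the helper name is made up, PROXY_PORT mirrors utils.py):

import xbmc
import xbmcgui

PROXY_PORT = 52308

def append_skip_sentinel():
    # Hypothetical helper: queue a dummy entry behind the current track. When
    # Kodi advances onto it, the addon's proxy receives /nexttrack and forwards
    # a skip to Spotify Connect instead of playing real audio.
    url = "http://localhost:%s/nexttrack" % PROXY_PORT
    playlist = xbmc.PlayList(xbmc.PLAYLIST_MUSIC)
    playlist.add(url, xbmcgui.ListItem("spotify connect skip"),
                 index=playlist.getposition() + 1)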
) self.__skip_events = True - self.play() + # Set volume level + if cur_playback['device']['volume_percent'] != 50: + xbmc.executebuiltin("SetVolume(%s,true)" % cur_playback['device']['volume_percent'] ) + log_msg("Start position : %s" % cur_playback['progress_ms']) + self.play(startpos = cur_playback['progress_ms']) elif not xbmc.getCondVisibility("Player.Paused"): - log_msg("Pause requested by Spotify Connect") + log_msg("Pause requested by Spotify Connect.") self.__skip_events = True - self.pause() + self.pause() else: self.__skip_events = True self.stop() diff --git a/resources/lib/plugin_content.py b/resources/lib/plugin_content.py index a3e65e3..092543c 100644 --- a/resources/lib/plugin_content.py +++ b/resources/lib/plugin_content.py @@ -1,10 +1,10 @@ # -*- coding: utf8 -*- from __future__ import print_function, unicode_literals from utils import log_msg, log_exception, ADDON_ID, PROXY_PORT, get_chunks, get_track_rating, parse_spotify_track, get_playername, KODI_VERSION, request_token_web -import urlparse +from urllib.parse import urlparse import urllib import threading -import thread +import _thread import time import spotipy import xbmc @@ -47,8 +47,8 @@ def __init__(self): if auth_token: self.parse_params() self.sp = spotipy.Spotify(auth=auth_token) - self.userid = self.win.getProperty("spotify-username").decode("utf-8") - self.usercountry = self.win.getProperty("spotify-country").decode("utf-8") + self.userid = self.win.getProperty("spotify-username") + self.usercountry = self.win.getProperty("spotify-country") self.local_playback, self.playername, self.connect_id = self.active_playback_device() if self.action: action = "self." + self.action @@ -67,7 +67,7 @@ def get_authkey(self): auth_token = None count = 10 while not auth_token and count: # wait max 5 seconds for the token - auth_token = self.win.getProperty("spotify-token").decode("utf-8") + auth_token = self.win.getProperty("spotify-token") count -= 1 if not auth_token: xbmc.sleep(500) @@ -89,34 +89,34 @@ def get_authkey(self): def parse_params(self): '''parse parameters from the plugin entry path''' - self.params = urlparse.parse_qs(sys.argv[2][1:]) + self.params = urllib.parse.parse_qs(sys.argv[2][1:]) action = self.params.get("action", None) if action: - self.action = action[0].lower().decode("utf-8") + self.action = action[0].lower() playlistid = self.params.get("playlistid", None) if playlistid: - self.playlistid = playlistid[0].decode("utf-8") + self.playlistid = playlistid[0] ownerid = self.params.get("ownerid", None) if ownerid: - self.ownerid = ownerid[0].decode("utf-8") + self.ownerid = ownerid[0] trackid = self.params.get("trackid", None) if trackid: - self.trackid = trackid[0].decode("utf-8") + self.trackid = trackid[0] albumid = self.params.get("albumid", None) if albumid: - self.albumid = albumid[0].decode("utf-8") + self.albumid = albumid[0] artistid = self.params.get("artistid", None) if artistid: - self.artistid = artistid[0].decode("utf-8") + self.artistid = artistid[0] artistname = self.params.get("artistname", None) if artistname: - self.artistname = artistname[0].decode("utf-8") + self.artistname = artistname[0] offset = self.params.get("offset", None) if offset: self.offset = int(offset[0]) filter = self.params.get("applyfilter", None) if filter: - self.filter = filter[0].decode("utf-8") + self.filter = filter[0] # default settings self.append_artist_to_title = self.addon.getSetting("appendArtistToTitle") == "true" self.defaultview_songs = self.addon.getSetting("songDefaultView") @@ -141,13 +141,13 @@ 
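The reworked update_info() now reacts only when the device Spotify reports matches get_playername(), mirrors the Connect volume through the SetVolume builtin, and resumes with self.play(startpos=cur_playback['progress_ms']). One caveat worth verifying: in Kodi's Python API, Player.play()'s startpos parameter is a playlist position (an index), not a time offset, so passing milliseconds there deserves a test; a millisecond resume point is normally applied with seekTime(). An illustrative alternative (hypothetical helper, Kodi only):

import xbmc

def resume_at(progress_ms):
    player = xbmc.Player()
    player.pause()                         # toggles: resumes when currently paused
    player.seekTime(progress_ms / 1000.0)  # seekTime() takes seconds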
def cache_checksum(self, opt_value=None): def build_url(self, query): query_encoded = {} - for key, value in query.iteritems(): - if isinstance(key, unicode): + for key, value in query.items(): + if isinstance(key, str): key = key.encode("utf-8") - if isinstance(value, unicode): + if isinstance(value, str): value = value.encode("utf-8") query_encoded[key] = value - return self.base_url + '?' + urllib.urlencode(query_encoded) + return self.base_url + '?' + urllib.parse.urlencode(query_encoded) def refresh_listing(self): self.addon.setSetting("cache_checksum", time.strftime("%Y%m%d%H%M%S", time.gmtime())) @@ -186,7 +186,7 @@ def switch_user_multi(self): usernames = [] count = 1 while True: - username = self.addon.getSetting("username%s" % count).decode("utf-8") + username = self.addon.getSetting("username%s" % count) count += 1 if not username: break @@ -218,15 +218,13 @@ def switch_user_multi(self): def next_track(self): '''special entry which tells the remote connect player to move to the next track''' - + log_msg("Next track requested", xbmc.LOGDEBUG) cur_playlist_position = xbmc.PlayList(xbmc.PLAYLIST_MUSIC).getposition() - # prevent unintentional skipping when Kodi track ends before connect player - # playlist position will increse only when play next button is pressed - if cur_playlist_position > self.last_playlist_position: - # move to next track - self.sp.next_track() - # give time for connect player to update info - xbmc.sleep(100) + + self.sp.next_track() + # give time for connect player to update info + xbmc.sleep(100) + self.last_playlist_position = cur_playlist_position cur_playback = self.sp.current_playback() @@ -236,6 +234,7 @@ def next_track(self): def play_connect(self): '''start local connect playback - called from webservice when local connect player starts playback''' + log_msg("start local connect playback - called from webservice when local connect player starts playback", xbmc.LOGDEBUG) playlist = xbmc.PlayList(xbmc.PLAYLIST_MUSIC) trackdetails = None count = 0 @@ -316,7 +315,7 @@ def connect_playback(self): # launch our special OSD dialog from osd import SpotifyOSD osd = SpotifyOSD("plugin-audio-spotify-OSD.xml", - self.addon.getAddonInfo('path').decode("utf-8"), "Default", "1080i") + self.addon.getAddonInfo('path'), "Default", "1080i") osd.sp = self.sp osd.doModal() del osd @@ -361,8 +360,8 @@ def browse_main(self): for item in items: li = xbmcgui.ListItem( item[0], - path=item[1], - iconImage=item[2] + path=item[1] + # iconImage=item[2] ) li.setProperty('IsPlayable', 'false') li.setArt({"fanart": "special://home/addons/plugin.audio.spotify/fanart.jpg"}) @@ -411,7 +410,7 @@ def browse_playback_devices(self): if self.local_playback: label += " [%s]" % self.addon.getLocalizedString(11040) url = "plugin://plugin.audio.spotify/?action=set_playback_device&deviceid=local" - li = xbmcgui.ListItem(label, iconImage="DefaultMusicCompilations.png") + li = xbmcgui.ListItem(label) li.setProperty("isPlayable", "false") li.setArt({"fanart": "special://home/addons/plugin.audio.spotify/fanart.jpg"}) li.addContextMenuItems([], True) @@ -422,7 +421,7 @@ def browse_playback_devices(self): if self.addon.getSetting("playback_device") == "remote": label += " [%s]" % self.addon.getLocalizedString(11040) url = "plugin://plugin.audio.spotify/?action=set_playback_device&deviceid=remote" - li = xbmcgui.ListItem(label, iconImage="DefaultMusicCompilations.png") + li = xbmcgui.ListItem(label) li.setProperty("isPlayable", "false") li.setArt({"fanart": 
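The iconImage deletions in the hunks above and below are a Kodi 19 requirement: the Matrix release removed the iconImage keyword from the xbmcgui.ListItem() constructor. Rather than dropping icons entirely, they can move into setArt(), e.g. (inside Kodi):

import xbmcgui

li = xbmcgui.ListItem("Albums")  # Kodi 19: no iconImage= keyword anymore
li.setArt({
    "icon": "DefaultMusicAlbums.png",  # what iconImage= used to carry
    "fanart": "special://home/addons/plugin.audio.spotify/fanart.jpg",
})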
"special://home/addons/plugin.audio.spotify/fanart.jpg"}) li.addContextMenuItems([], True) @@ -434,7 +433,7 @@ def browse_playback_devices(self): label += " [%s]" % self.addon.getLocalizedString(11040) self.refresh_connected_device() url = "plugin://plugin.audio.spotify/?action=set_playback_device&deviceid=%s" % device["id"] - li = xbmcgui.ListItem(label, iconImage="DefaultMusicCompilations.png") + li = xbmcgui.ListItem(label) li.setProperty("isPlayable", "false") li.setArt({"fanart": "special://home/addons/plugin.audio.spotify/fanart.jpg"}) li.addContextMenuItems([], True) @@ -445,7 +444,7 @@ def browse_playback_devices(self): if self.addon.getSetting("playback_device") == "squeezebox": label += " [%s]" % self.addon.getLocalizedString(11040) url = "plugin://plugin.audio.spotify/?action=set_playback_device&deviceid=squeezebox" - li = xbmcgui.ListItem(label, iconImage="DefaultMusicCompilations.png") + li = xbmcgui.ListItem(label) li.setProperty("isPlayable", "false") li.setArt({"fanart": "special://home/addons/plugin.audio.spotify/fanart.jpg"}) li.addContextMenuItems([], True) @@ -516,8 +515,8 @@ def browse_main_library(self): for item in items: li = xbmcgui.ListItem( item[0], - path=item[1], - iconImage=item[2] + path=item[1] + # iconImage=item[2] ) li.setProperty('do_not_analyze', 'true') li.setProperty('IsPlayable', 'false') @@ -605,8 +604,8 @@ def browse_main_explore(self): for item in items: li = xbmcgui.ListItem( item[0], - path=item[1], - iconImage=item[2] + path=item[1] + # iconImage=item[2] ) li.setProperty('do_not_analyze', 'true') li.setProperty('IsPlayable', 'false') @@ -1090,6 +1089,7 @@ def add_track_listitems(self, tracks, append_artist_to_label=False): li.addContextMenuItems(track["contextitems"], True) li.setProperty('do_not_analyze', 'true') li.setMimeType("audio/wave") + li.setInfo('video', {}) list_items.append((url, li, False)) xbmcplugin.addDirectoryItems(self.addon_handle, list_items, totalItems=len(list_items)) @@ -1612,7 +1612,7 @@ def search(self): li = xbmcgui.ListItem( item[0], path=item[1], - iconImage="DefaultMusicAlbums.png" + # iconImage="DefaultMusicAlbums.png" ) li.setProperty('do_not_analyze', 'true') li.setProperty('IsPlayable', 'false') @@ -1626,7 +1626,7 @@ def add_next_button(self, listtotal): if listtotal > self.offset + self.limit: params["offset"] = self.offset + self.limit url = "plugin://plugin.audio.spotify/" - for key, value in params.iteritems(): + for key, value in params.items(): if key == "action": url += "?%s=%s" % (key, value[0]) elif key == "offset": @@ -1636,7 +1636,7 @@ def add_next_button(self, listtotal): li = xbmcgui.ListItem( xbmc.getLocalizedString(33078), path=url, - iconImage="DefaultMusicAlbums.png" + # iconImage="DefaultMusicAlbums.png" ) li.setProperty('do_not_analyze', 'true') li.setProperty('IsPlayable', 'false') @@ -1722,7 +1722,7 @@ def _fill_buffer(self): def _fetch(self): log_msg("Spotify radio track buffer invoking recommendations() via spotipy", xbmc.LOGDEBUG) try: - auth_token = xbmc.getInfoLabel("Window(Home).Property(spotify-token)").decode("utf-8") + auth_token = xbmc.getInfoLabel("Window(Home).Property(spotify-token)") client = spotipy.Spotify(auth_token) tracks = client.recommendations( seed_tracks=[t["id"] for t in self._buffer[0: 5]], diff --git a/resources/lib/portend.py b/resources/lib/portend.py new file mode 100644 index 0000000..4c39380 --- /dev/null +++ b/resources/lib/portend.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- + +""" +A simple library for managing the availability of ports. 
+""" + +from __future__ import print_function, division + +import time +import socket +import argparse +import sys +import itertools +import contextlib +import collections +import platform + +from tempora import timing + + +def client_host(server_host): + """Return the host on which a client can connect to the given listener.""" + if server_host == '0.0.0.0': + # 0.0.0.0 is INADDR_ANY, which should answer on localhost. + return '127.0.0.1' + if server_host in ('::', '::0', '::0.0.0.0'): + # :: is IN6ADDR_ANY, which should answer on localhost. + # ::0 and ::0.0.0.0 are non-canonical but common + # ways to write IN6ADDR_ANY. + return '::1' + return server_host + + +class Checker(object): + def __init__(self, timeout=1.0): + self.timeout = timeout + + def assert_free(self, host, port=None): + """ + Assert that the given addr is free + in that all attempts to connect fail within the timeout + or raise a PortNotFree exception. + + >>> free_port = find_available_local_port() + + >>> Checker().assert_free('localhost', free_port) + >>> Checker().assert_free('127.0.0.1', free_port) + >>> Checker().assert_free('::1', free_port) + + Also accepts an addr tuple + + >>> addr = '::1', free_port, 0, 0 + >>> Checker().assert_free(addr) + + Host might refer to a server bind address like '::', which + should use localhost to perform the check. + + >>> Checker().assert_free('::', free_port) + """ + if port is None and isinstance(host, collections.Sequence): + host, port = host[:2] + if platform.system() == 'Windows': + host = client_host(host) + info = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, + ) + list(itertools.starmap(self._connect, info)) + + def _connect(self, af, socktype, proto, canonname, sa): + s = socket.socket(af, socktype, proto) + # fail fast with a small timeout + s.settimeout(self.timeout) + + with contextlib.closing(s): + try: + s.connect(sa) + except socket.error: + return + + # the connect succeeded, so the port isn't free + port, host = sa[:2] + tmpl = "Port {port} is in use on {host}." + raise PortNotFree(tmpl.format(**locals())) + + +class Timeout(IOError): + pass + + +class PortNotFree(IOError): + pass + + +def free(host, port, timeout=float('Inf')): + """ + Wait for the specified port to become free (dropping or rejecting + requests). Return when the port is free or raise a Timeout if timeout has + elapsed. + + Timeout may be specified in seconds or as a timedelta. + If timeout is None or ∞, the routine will run indefinitely. + + >>> free('localhost', find_available_local_port()) + """ + if not host: + raise ValueError("Host values of '' or None are not allowed.") + + timer = timing.Timer(timeout) + + while not timer.expired(): + try: + # Expect a free port, so use a small timeout + Checker(timeout=0.1).assert_free(host, port) + return + except PortNotFree: + # Politely wait. + time.sleep(0.1) + + raise Timeout("Port {port} not free on {host}.".format(**locals())) +wait_for_free_port = free + + +def occupied(host, port, timeout=float('Inf')): + """ + Wait for the specified port to become occupied (accepting requests). + Return when the port is occupied or raise a Timeout if timeout has + elapsed. + + Timeout may be specified in seconds or as a timedelta. + If timeout is None or ∞, the routine will run indefinitely. + + >>> occupied('localhost', find_available_local_port(), .1) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + Timeout: Port ... not bound on localhost. 
+ """ + if not host: + raise ValueError("Host values of '' or None are not allowed.") + + timer = timing.Timer(timeout) + + while not timer.expired(): + try: + Checker(timeout=.5).assert_free(host, port) + # Politely wait + time.sleep(0.1) + except PortNotFree: + # port is occupied + return + + raise Timeout("Port {port} not bound on {host}.".format(**locals())) +wait_for_occupied_port = occupied + + +def find_available_local_port(): + """ + Find a free port on localhost. + + >>> 0 < find_available_local_port() < 65536 + True + """ + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + addr = '', 0 + sock.bind(addr) + addr, port = sock.getsockname()[:2] + sock.close() + return port + + +class HostPort(str): + """ + A simple representation of a host/port pair as a string + + >>> hp = HostPort('localhost:32768') + + >>> hp.host + 'localhost' + + >>> hp.port + 32768 + + >>> len(hp) + 15 + """ + + @property + def host(self): + host, sep, port = self.partition(':') + return host + + @property + def port(self): + host, sep, port = self.partition(':') + return int(port) + + +def _main(): + parser = argparse.ArgumentParser() + global_lookup = lambda key: globals()[key] + parser.add_argument('target', metavar='host:port', type=HostPort) + parser.add_argument('func', metavar='state', type=global_lookup) + parser.add_argument('-t', '--timeout', default=None, type=float) + args = parser.parse_args() + try: + args.func(args.target.host, args.target.port, timeout=args.timeout) + except Timeout as timeout: + print(timeout, file=sys.stderr) + raise SystemExit(1) + + +if __name__ == '__main__': + _main() diff --git a/resources/lib/runnow/__init__.py b/resources/lib/runnow/__init__.py new file mode 100644 index 0000000..dfc7663 --- /dev/null +++ b/resources/lib/runnow/__init__.py @@ -0,0 +1 @@ +from runnow.jobs import run diff --git a/resources/lib/runnow/jobs.py b/resources/lib/runnow/jobs.py new file mode 100644 index 0000000..e9adbe5 --- /dev/null +++ b/resources/lib/runnow/jobs.py @@ -0,0 +1,126 @@ +import os +import platform +import subprocess +import sys +import time + +from logless import logged, get_logger, flush_buffers + +logging = get_logger("runnow") + + +def _grep(full_text, match_with, insensitive=True, fn=any): + lines = full_text.splitlines() + if isinstance(match_with, str): + match_with = [match_with] + if insensitive: + return "\n".join( + [l for l in lines if fn([m.lower() in l.lower() for m in match_with])] + ) + else: + return "\n".join([l for l in lines if fn([m in l for m in match_with])]) + + +@logged("running command: {'(hidden)' if hide else cmd}") +def run( + cmd: str, + working_dir=None, + echo=True, + raise_error=True, + log_file_path: str = None, + shell=True, + daemon=False, + hide=False, + cwd=None, + wait_test=None, + wait_max=None, +): + """ Run a CLI command and return a tuple: (return_code, output_text) """ + loglines = [] + if working_dir: + prev_working_dir = os.getcwd() + os.chdir(working_dir) + if isinstance(cmd, list): + pass # cmd = " ".join(cmd) + elif platform.system() == "Windows": + cmd = " ".join(cmd.split()) + else: + cmd = cmd.replace("\n", " \\\n") + proc = subprocess.Popen( + cmd, + stderr=subprocess.STDOUT, + stdout=subprocess.PIPE, + universal_newlines=True, + shell=shell, + cwd=cwd, + ) + start_time = time.time() + if working_dir: + os.chdir(prev_working_dir) + if log_file_path: + logfile = open(log_file_path, "w", encoding="utf-8") + else: + logfile = None + line = proc.stdout.readline() + flush_buffers() + while (proc.poll() is None) or line: + 
if daemon: + if wait_max is None and wait_test is None: + logging.info("Daemon process is launched. Returning...") + break + if callable(wait_test) and wait_test(line): + logging.info(f"Returning. Wait test passed: {line}") + break + if wait_max and time.time() >= start_time + wait_max: + logging.info( + f"{line}\nMax timeout expired (wait_max={wait_max})." + f" Returning..." + ) + if callable(wait_test): + return_code = 1 + else: + return_code = 0 + break + if line: + line = line.rstrip() + loglines.append(line) + if echo: + for l in line.splitlines(): + sys.stdout.write(l.rstrip() + "\r\n") + if logfile: + logfile.write(line + "\n") + else: + time.sleep(0.5) # Sleep half a second if no new output + line = proc.stdout.readline() + flush_buffers() + output_text = chr(10).join(loglines) + if logfile: + logfile.close() + if not proc: + return_code = None + raise RuntimeError(f"Command failed: {cmd}\n\n") + else: + return_code = proc.returncode + if ( + return_code != 0 + and raise_error + and ((daemon == False) or (return_code is not None)) + ): + err_msg = f"Command failed (exit code {return_code}): {cmd}" + if not echo: + print_str = output_text + elif len(output_text.splitlines()) > 10: + print_str = _grep( + output_text, ["error", "exception", "warning", "fail", "deprecat"] + ) + else: + print_str = "" + if print_str: + err_msg += ( + f"{'-' * 80}\n" + f"SCRIPT ERRORS:\n{'-' * 80}\n" + f"{print_str}\n{'-' * 80}\n" + f"END OF SCRIPT OUTPUT\n{'-' * 80}" + ) + raise RuntimeError(err_msg) + return return_code, output_text diff --git a/resources/lib/spotipy/cache_handler.py b/resources/lib/spotipy/cache_handler.py new file mode 100644 index 0000000..3ba3987 --- /dev/null +++ b/resources/lib/spotipy/cache_handler.py @@ -0,0 +1,84 @@ +__all__ = ['CacheHandler', 'CacheFileHandler'] + +import errno +import json +import logging + +logger = logging.getLogger(__name__) + + +class CacheHandler(): + """ + An abstraction layer for handling the caching and retrieval of + authorization tokens. + + Custom extensions of this class must implement get_cached_token + and save_token_to_cache methods with the same input and output + structure as the CacheHandler class. + """ + + def get_cached_token(self): + """ + Get and return a token_info dictionary object. + """ + # return token_info + raise NotImplementedError() + + def save_token_to_cache(self, token_info): + """ + Save a token_info dictionary object to the cache and return None. + """ + raise NotImplementedError() + return None + + +class CacheFileHandler(CacheHandler): + """ + Handles reading and writing cached Spotify authorization tokens + as json files on disk. 
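CacheHandler is the extension point here: anything implementing the same two methods can stand in for the file-backed default, for instance a process-lifetime cache that never touches disk. A minimal sketch (later spotipy releases ship a MemoryCacheHandler of exactly this shape; the import path matches the vendored layout):

from spotipy.cache_handler import CacheHandler

class MemoryCacheHandler(CacheHandler):
    """Keeps the token for the life of the process; nothing touches disk."""

    def __init__(self, token_info=None):
        self.token_info = token_info

    def get_cached_token(self):
        return self.token_info

    def save_token_to_cache(self, token_info):
        self.token_info = token_info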
+ """ + + def __init__(self, + cache_path=None, + username=None): + """ + Parameters: + * cache_path: May be supplied, will otherwise be generated + (takes precedence over `username`) + * username: May be supplied or set as environment variable + (will set `cache_path` to `.cache-{username}`) + """ + + if cache_path: + self.cache_path = cache_path + else: + cache_path = ".cache" + if username: + cache_path += "-" + str(username) + self.cache_path = cache_path + + def get_cached_token(self): + token_info = None + + try: + f = open(self.cache_path) + token_info_string = f.read() + f.close() + token_info = json.loads(token_info_string) + + except IOError as error: + if error.errno == errno.ENOENT: + logger.debug("cache does not exist at: %s", self.cache_path) + else: + logger.warning("Couldn't read cache at: %s", self.cache_path) + + return token_info + + def save_token_to_cache(self, token_info): + try: + f = open(self.cache_path, "w") + f.write(json.dumps(token_info)) + f.close() + except IOError: + logger.warning('Couldn\'t write token to cache at: %s', + self.cache_path) diff --git a/resources/lib/spotipy/exceptions.py b/resources/lib/spotipy/exceptions.py new file mode 100644 index 0000000..df503f1 --- /dev/null +++ b/resources/lib/spotipy/exceptions.py @@ -0,0 +1,16 @@ +class SpotifyException(Exception): + + def __init__(self, http_status, code, msg, reason=None, headers=None): + self.http_status = http_status + self.code = code + self.msg = msg + self.reason = reason + # `headers` is used to support `Retry-After` in the event of a + # 429 status code. + if headers is None: + headers = {} + self.headers = headers + + def __str__(self): + return 'http status: {0}, code:{1} - {2}, reason: {3}'.format( + self.http_status, self.code, self.msg, self.reason) diff --git a/resources/lib/tempora/__init__.py b/resources/lib/tempora/__init__.py new file mode 100644 index 0000000..e0cdead --- /dev/null +++ b/resources/lib/tempora/__init__.py @@ -0,0 +1,505 @@ +# -*- coding: UTF-8 -*- + +"Objects and routines pertaining to date and time (tempora)" + +from __future__ import division, unicode_literals + +import datetime +import time +import re +import numbers +import functools + +import six + +__metaclass__ = type + + +class Parser: + """ + Datetime parser: parses a date-time string using multiple possible + formats. + + >>> p = Parser(('%H%M', '%H:%M')) + >>> tuple(p.parse('1319')) + (1900, 1, 1, 13, 19, 0, 0, 1, -1) + >>> dateParser = Parser(('%m/%d/%Y', '%Y-%m-%d', '%d-%b-%Y')) + >>> tuple(dateParser.parse('2003-12-20')) + (2003, 12, 20, 0, 0, 0, 5, 354, -1) + >>> tuple(dateParser.parse('16-Dec-1994')) + (1994, 12, 16, 0, 0, 0, 4, 350, -1) + >>> tuple(dateParser.parse('5/19/2003')) + (2003, 5, 19, 0, 0, 0, 0, 139, -1) + >>> dtParser = Parser(('%Y-%m-%d %H:%M:%S', '%a %b %d %H:%M:%S %Y')) + >>> tuple(dtParser.parse('2003-12-20 19:13:26')) + (2003, 12, 20, 19, 13, 26, 5, 354, -1) + >>> tuple(dtParser.parse('Tue Jan 20 16:19:33 2004')) + (2004, 1, 20, 16, 19, 33, 1, 20, -1) + + Be forewarned, a ValueError will be raised if more than one format + matches: + + >>> Parser(('%H%M', '%H%M%S')).parse('732') + Traceback (most recent call last): + ... + ValueError: More than one format string matched target 732. 
+ """ + + formats = ('%m/%d/%Y', '%m/%d/%y', '%Y-%m-%d', '%d-%b-%Y', '%d-%b-%y') + "some common default formats" + + def __init__(self, formats=None): + if formats: + self.formats = formats + + def parse(self, target): + self.target = target + results = tuple(filter(None, map(self._parse, self.formats))) + del self.target + if not results: + tmpl = "No format strings matched the target {target}." + raise ValueError(tmpl.format(**locals())) + if not len(results) == 1: + tmpl = "More than one format string matched target {target}." + raise ValueError(tmpl.format(**locals())) + return results[0] + + def _parse(self, format): + try: + result = time.strptime(self.target, format) + except ValueError: + result = False + return result + + +# some useful constants +osc_per_year = 290091329207984000 +""" +mean vernal equinox year expressed in oscillations of atomic cesium at the +year 2000 (see http://webexhibits.org/calendars/timeline.html for more info). +""" +osc_per_second = 9192631770 +seconds_per_second = 1 +seconds_per_year = 31556940 +seconds_per_minute = 60 +minutes_per_hour = 60 +hours_per_day = 24 +seconds_per_hour = seconds_per_minute * minutes_per_hour +seconds_per_day = seconds_per_hour * hours_per_day +days_per_year = seconds_per_year / seconds_per_day +thirty_days = datetime.timedelta(days=30) +# these values provide useful averages +six_months = datetime.timedelta(days=days_per_year / 2) +seconds_per_month = seconds_per_year / 12 +hours_per_month = hours_per_day * days_per_year / 12 + + +def strftime(fmt, t): + """A class to replace the strftime in datetime package or time module. + Identical to strftime behavior in those modules except supports any + year. + Also supports datetime.datetime times. + Also supports milliseconds using %s + Also supports microseconds using %u""" + if isinstance(t, (time.struct_time, tuple)): + t = datetime.datetime(*t[:6]) + assert isinstance(t, (datetime.datetime, datetime.time, datetime.date)) + try: + year = t.year + if year < 1900: + t = t.replace(year=1900) + except AttributeError: + year = 1900 + subs = ( + ('%Y', '%04d' % year), + ('%y', '%02d' % (year % 100)), + ('%s', '%03d' % (t.microsecond // 1000)), + ('%u', '%03d' % (t.microsecond % 1000)) + ) + + def doSub(s, sub): + return s.replace(*sub) + + def doSubs(s): + return functools.reduce(doSub, subs, s) + + fmt = '%%'.join(map(doSubs, fmt.split('%%'))) + return t.strftime(fmt) + + +def strptime(s, fmt, tzinfo=None): + """ + A function to replace strptime in the time module. Should behave + identically to the strptime function except it returns a datetime.datetime + object instead of a time.struct_time object. + Also takes an optional tzinfo parameter which is a time zone info object. + """ + res = time.strptime(s, fmt) + return datetime.datetime(tzinfo=tzinfo, *res[:6]) + + +class DatetimeConstructor: + """ + >>> cd = DatetimeConstructor.construct_datetime + >>> cd(datetime.datetime(2011,1,1)) + datetime.datetime(2011, 1, 1, 0, 0) + """ + @classmethod + def construct_datetime(cls, *args, **kwargs): + """Construct a datetime.datetime from a number of different time + types found in python and pythonwin""" + if len(args) == 1: + arg = args[0] + method = cls.__get_dt_constructor( + type(arg).__module__, + type(arg).__name__, + ) + result = method(arg) + try: + result = result.replace(tzinfo=kwargs.pop('tzinfo')) + except KeyError: + pass + if kwargs: + first_key = kwargs.keys()[0] + tmpl = ( + "{first_key} is an invalid keyword " + "argument for this function." 
+ ) + raise TypeError(tmpl.format(**locals())) + else: + result = datetime.datetime(*args, **kwargs) + return result + + @classmethod + def __get_dt_constructor(cls, moduleName, name): + try: + method_name = '__dt_from_{moduleName}_{name}__'.format(**locals()) + return getattr(cls, method_name) + except AttributeError: + tmpl = ( + "No way to construct datetime.datetime from " + "{moduleName}.{name}" + ) + raise TypeError(tmpl.format(**locals())) + + @staticmethod + def __dt_from_datetime_datetime__(source): + dtattrs = ( + 'year', 'month', 'day', 'hour', 'minute', 'second', + 'microsecond', 'tzinfo', + ) + attrs = map(lambda a: getattr(source, a), dtattrs) + return datetime.datetime(*attrs) + + @staticmethod + def __dt_from___builtin___time__(pyt): + "Construct a datetime.datetime from a pythonwin time" + fmtString = '%Y-%m-%d %H:%M:%S' + result = strptime(pyt.Format(fmtString), fmtString) + # get milliseconds and microseconds. The only way to do this is + # to use the __float__ attribute of the time, which is in days. + microseconds_per_day = seconds_per_day * 1000000 + microseconds = float(pyt) * microseconds_per_day + microsecond = int(microseconds % 1000000) + result = result.replace(microsecond=microsecond) + return result + + @staticmethod + def __dt_from_timestamp__(timestamp): + return datetime.datetime.utcfromtimestamp(timestamp) + __dt_from___builtin___float__ = __dt_from_timestamp__ + __dt_from___builtin___long__ = __dt_from_timestamp__ + __dt_from___builtin___int__ = __dt_from_timestamp__ + + @staticmethod + def __dt_from_time_struct_time__(s): + return datetime.datetime(*s[:6]) + + +def datetime_mod(dt, period, start=None): + """ + Find the time which is the specified date/time truncated to the time delta + relative to the start date/time. + By default, the start time is midnight of the same day as the specified + date/time. + + >>> datetime_mod(datetime.datetime(2004, 1, 2, 3), + ... datetime.timedelta(days = 1.5), + ... start = datetime.datetime(2004, 1, 1)) + datetime.datetime(2004, 1, 1, 0, 0) + >>> datetime_mod(datetime.datetime(2004, 1, 2, 13), + ... datetime.timedelta(days = 1.5), + ... start = datetime.datetime(2004, 1, 1)) + datetime.datetime(2004, 1, 2, 12, 0) + >>> datetime_mod(datetime.datetime(2004, 1, 2, 13), + ... datetime.timedelta(days = 7), + ... start = datetime.datetime(2004, 1, 1)) + datetime.datetime(2004, 1, 1, 0, 0) + >>> datetime_mod(datetime.datetime(2004, 1, 10, 13), + ... datetime.timedelta(days = 7), + ... start = datetime.datetime(2004, 1, 1)) + datetime.datetime(2004, 1, 8, 0, 0) + """ + if start is None: + # use midnight of the same day + start = datetime.datetime.combine(dt.date(), datetime.time()) + # calculate the difference between the specified time and the start date. + delta = dt - start + + # now aggregate the delta and the period into microseconds + # Use microseconds because that's the highest precision of these time + # pieces. Also, using microseconds ensures perfect precision (no floating + # point errors). + def get_time_delta_microseconds(td): + return (td.days * seconds_per_day + td.seconds) * 1000000 + td.microseconds + delta, period = map(get_time_delta_microseconds, (delta, period)) + offset = datetime.timedelta(microseconds=delta % period) + # the result is the original specified time minus the offset + result = dt - offset + return result + + +def datetime_round(dt, period, start=None): + """ + Find the nearest even period for the specified date/time. + + >>> datetime_round(datetime.datetime(2004, 11, 13, 8, 11, 13), + ... 
datetime.timedelta(hours = 1)) + datetime.datetime(2004, 11, 13, 8, 0) + >>> datetime_round(datetime.datetime(2004, 11, 13, 8, 31, 13), + ... datetime.timedelta(hours = 1)) + datetime.datetime(2004, 11, 13, 9, 0) + >>> datetime_round(datetime.datetime(2004, 11, 13, 8, 30), + ... datetime.timedelta(hours = 1)) + datetime.datetime(2004, 11, 13, 9, 0) + """ + result = datetime_mod(dt, period, start) + if abs(dt - result) >= period // 2: + result += period + return result + + +def get_nearest_year_for_day(day): + """ + Returns the nearest year to now inferred from a Julian date. + """ + now = time.gmtime() + result = now.tm_year + # if the day is far greater than today, it must be from last year + if day - now.tm_yday > 365 // 2: + result -= 1 + # if the day is far less than today, it must be for next year. + if now.tm_yday - day > 365 // 2: + result += 1 + return result + + +def gregorian_date(year, julian_day): + """ + Gregorian Date is defined as a year and a julian day (1-based + index into the days of the year). + + >>> gregorian_date(2007, 15) + datetime.date(2007, 1, 15) + """ + result = datetime.date(year, 1, 1) + result += datetime.timedelta(days=julian_day - 1) + return result + + +def get_period_seconds(period): + """ + return the number of seconds in the specified period + + >>> get_period_seconds('day') + 86400 + >>> get_period_seconds(86400) + 86400 + >>> get_period_seconds(datetime.timedelta(hours=24)) + 86400 + >>> get_period_seconds('day + os.system("rm -Rf *")') + Traceback (most recent call last): + ... + ValueError: period not in (second, minute, hour, day, month, year) + """ + if isinstance(period, six.string_types): + try: + name = 'seconds_per_' + period.lower() + result = globals()[name] + except KeyError: + msg = "period not in (second, minute, hour, day, month, year)" + raise ValueError(msg) + elif isinstance(period, numbers.Number): + result = period + elif isinstance(period, datetime.timedelta): + result = period.days * get_period_seconds('day') + period.seconds + else: + raise TypeError('period must be a string or integer') + return result + + +def get_date_format_string(period): + """ + For a given period (e.g. 'month', 'day', or some numeric interval + such as 3600 (in secs)), return the format string that can be + used with strftime to format that time to specify the times + across that interval, but no more detailed. + For example, + + >>> get_date_format_string('month') + '%Y-%m' + >>> get_date_format_string(3600) + '%Y-%m-%d %H' + >>> get_date_format_string('hour') + '%Y-%m-%d %H' + >>> get_date_format_string(None) + Traceback (most recent call last): + ... + TypeError: period must be a string or integer + >>> get_date_format_string('garbage') + Traceback (most recent call last): + ... 
+    ValueError: period not in (second, minute, hour, day, month, year)
+    """
+    # handle the special case of 'month' which doesn't have
+    # a static interval in seconds
+    if isinstance(period, six.string_types) and period.lower() == 'month':
+        return '%Y-%m'
+    file_period_secs = get_period_seconds(period)
+    format_pieces = ('%Y', '-%m-%d', ' %H', '-%M', '-%S')
+    seconds_per_second = 1
+    intervals = (
+        seconds_per_year,
+        seconds_per_day,
+        seconds_per_hour,
+        seconds_per_minute,
+        seconds_per_second,
+    )
+    mods = list(map(lambda interval: file_period_secs % interval, intervals))
+    format_pieces = format_pieces[: mods.index(0) + 1]
+    return ''.join(format_pieces)
+
+
+def divide_timedelta_float(td, divisor):
+    """
+    Divide a timedelta by a float value
+
+    >>> one_day = datetime.timedelta(days=1)
+    >>> half_day = datetime.timedelta(days=.5)
+    >>> divide_timedelta_float(one_day, 2.0) == half_day
+    True
+    >>> divide_timedelta_float(one_day, 2) == half_day
+    True
+    """
+    # td is comprised of days, seconds, microseconds
+    dsm = [getattr(td, attr) for attr in ('days', 'seconds', 'microseconds')]
+    dsm = map(lambda elem: elem / divisor, dsm)
+    return datetime.timedelta(*dsm)
+
+
+def calculate_prorated_values():
+    """
+    A utility function to prompt for a rate (a string in units per
+    unit time), and return that same rate for various time periods.
+    """
+    rate = six.moves.input("Enter the rate (3/hour, 50/month)> ")
+    res = re.match('(?P<value>[\d.]+)/(?P<period>\w+)$', rate).groupdict()
+    value = float(res['value'])
+    value_per_second = value / get_period_seconds(res['period'])
+    for period in ('minute', 'hour', 'day', 'month', 'year'):
+        period_value = value_per_second * get_period_seconds(period)
+        print("per {period}: {period_value}".format(**locals()))
+
+
+def parse_timedelta(str):
+    """
+    Take a string representing a span of time and parse it to a time delta.
+    Accepts any string of comma-separated numbers each with a unit indicator.
+
+    >>> parse_timedelta('1 day')
+    datetime.timedelta(days=1)
+
+    >>> parse_timedelta('1 day, 30 seconds')
+    datetime.timedelta(days=1, seconds=30)
+
+    >>> parse_timedelta('47.32 days, 20 minutes, 15.4 milliseconds')
+    datetime.timedelta(days=47, seconds=28848, microseconds=15400)
+
+    Supports weeks, months, years
+
+    >>> parse_timedelta('1 week')
+    datetime.timedelta(days=7)
+
+    >>> parse_timedelta('1 year, 1 month')
+    datetime.timedelta(days=395, seconds=58685)
+
+    Note that months and years are strict intervals, not aligned
+    to a calendar:
+
+    >>> now = datetime.datetime.now()
+    >>> later = now + parse_timedelta('1 year')
+    >>> later.replace(year=now.year) - now
+    datetime.timedelta(seconds=20940)
+    """
+    deltas = (_parse_timedelta_part(part.strip()) for part in str.split(','))
+    return sum(deltas, datetime.timedelta())
+
+
+def _parse_timedelta_part(part):
+    match = re.match('(?P<value>[\d.]+) (?P<unit>\w+)', part)
+    if not match:
+        msg = "Unable to parse {part!r} as a time delta".format(**locals())
+        raise ValueError(msg)
+    unit = match.group('unit').lower()
+    if not unit.endswith('s'):
+        unit += 's'
+    value = float(match.group('value'))
+    if unit == 'months':
+        unit = 'years'
+        value = value / 12
+    if unit == 'years':
+        unit = 'days'
+        value = value * days_per_year
+    return datetime.timedelta(**{unit: value})
+
+
+def divide_timedelta(td1, td2):
+    """
+    Get the ratio of two timedeltas
+
+    >>> one_day = datetime.timedelta(days=1)
+    >>> one_hour = datetime.timedelta(hours=1)
+    >>> divide_timedelta(one_hour, one_day) == 1 / 24
+    True
+    """
+    try:
+        return td1 / td2
+    except TypeError:
+        # Python 3.2 gets division
+        # http://bugs.python.org/issue2706
+        return td1.total_seconds() / td2.total_seconds()
+
+
+def date_range(start=None, stop=None, step=None):
+    """
+    Much like the built-in function range, but works with dates
+
+    >>> range_items = date_range(
+    ...     datetime.datetime(2005,12,21),
+    ...     datetime.datetime(2005,12,25),
+    ... )
+    >>> my_range = tuple(range_items)
+    >>> datetime.datetime(2005,12,21) in my_range
+    True
+    >>> datetime.datetime(2005,12,22) in my_range
+    True
+    >>> datetime.datetime(2005,12,25) in my_range
+    False
+    """
+    if step is None:
+        step = datetime.timedelta(days=1)
+    if start is None:
+        start = datetime.datetime.now()
+    while start < stop:
+        yield start
+        start += step
diff --git a/resources/lib/tempora/schedule.py b/resources/lib/tempora/schedule.py
new file mode 100644
index 0000000..1ad093b
--- /dev/null
+++ b/resources/lib/tempora/schedule.py
@@ -0,0 +1,202 @@
+# -*- coding: utf-8 -*-
+
+"""
+Classes for calling functions on a schedule.
+"""
+
+from __future__ import absolute_import
+
+import datetime
+import numbers
+import abc
+import bisect
+
+import pytz
+
+__metaclass__ = type
+
+
+def now():
+    """
+    Provide the current timezone-aware datetime.
+
+    A client may override this function to change the default behavior,
+    such as to use local time or timezone-naïve times.
+    """
+    return datetime.datetime.utcnow().replace(tzinfo=pytz.utc)
+
+
+def from_timestamp(ts):
+    """
+    Convert a numeric timestamp to a timezone-aware datetime.
+
+    A client may override this function to change the default behavior,
+    such as to use local time or timezone-naïve times.
+    """
+    return datetime.datetime.utcfromtimestamp(ts).replace(tzinfo=pytz.utc)
+
+
+class DelayedCommand(datetime.datetime):
+    """
+    A command to be executed after some delay (seconds or timedelta).
+ """ + + @classmethod + def from_datetime(cls, other): + return cls( + other.year, other.month, other.day, other.hour, + other.minute, other.second, other.microsecond, + other.tzinfo, + ) + + @classmethod + def after(cls, delay, target): + if not isinstance(delay, datetime.timedelta): + delay = datetime.timedelta(seconds=delay) + due_time = now() + delay + cmd = cls.from_datetime(due_time) + cmd.delay = delay + cmd.target = target + return cmd + + @staticmethod + def _from_timestamp(input): + """ + If input is a real number, interpret it as a Unix timestamp + (seconds sinc Epoch in UTC) and return a timezone-aware + datetime object. Otherwise return input unchanged. + """ + if not isinstance(input, numbers.Real): + return input + return from_timestamp(input) + + @classmethod + def at_time(cls, at, target): + """ + Construct a DelayedCommand to come due at `at`, where `at` may be + a datetime or timestamp. + """ + at = cls._from_timestamp(at) + cmd = cls.from_datetime(at) + cmd.delay = at - now() + cmd.target = target + return cmd + + def due(self): + return now() >= self + + +class PeriodicCommand(DelayedCommand): + """ + Like a delayed command, but expect this command to run every delay + seconds. + """ + def _next_time(self): + """ + Add delay to self, localized + """ + return self._localize(self + self.delay) + + @staticmethod + def _localize(dt): + """ + Rely on pytz.localize to ensure new result honors DST. + """ + try: + tz = dt.tzinfo + return tz.localize(dt.replace(tzinfo=None)) + except AttributeError: + return dt + + def next(self): + cmd = self.__class__.from_datetime(self._next_time()) + cmd.delay = self.delay + cmd.target = self.target + return cmd + + def __setattr__(self, key, value): + if key == 'delay' and not value > datetime.timedelta(): + raise ValueError( + "A PeriodicCommand must have a positive, " + "non-zero delay." + ) + super(PeriodicCommand, self).__setattr__(key, value) + + +class PeriodicCommandFixedDelay(PeriodicCommand): + """ + Like a periodic command, but don't calculate the delay based on + the current time. Instead use a fixed delay following the initial + run. + """ + + @classmethod + def at_time(cls, at, delay, target): + at = cls._from_timestamp(at) + cmd = cls.from_datetime(at) + if isinstance(delay, numbers.Number): + delay = datetime.timedelta(seconds=delay) + cmd.delay = delay + cmd.target = target + return cmd + + @classmethod + def daily_at(cls, at, target): + """ + Schedule a command to run at a specific time each day. + """ + daily = datetime.timedelta(days=1) + # convert when to the next datetime matching this time + when = datetime.datetime.combine(datetime.date.today(), at) + if when < now(): + when += daily + return cls.at_time(cls._localize(when), daily, target) + + +class Scheduler: + """ + A rudimentary abstract scheduler accepting DelayedCommands + and dispatching them on schedule. + """ + def __init__(self): + self.queue = [] + + def add(self, command): + assert isinstance(command, DelayedCommand) + bisect.insort(self.queue, command) + + def run_pending(self): + while self.queue: + command = self.queue[0] + if not command.due(): + break + self.run(command) + if isinstance(command, PeriodicCommand): + self.add(command.next()) + del self.queue[0] + + @abc.abstractmethod + def run(self, command): + """ + Run the command + """ + + +class InvokeScheduler(Scheduler): + """ + Command targets are functions to be invoked on schedule. 
+ """ + def run(self, command): + command.target() + + +class CallbackScheduler(Scheduler): + """ + Command targets are passed to a dispatch callable on schedule. + """ + def __init__(self, dispatch): + super(CallbackScheduler, self).__init__() + self.dispatch = dispatch + + def run(self, command): + self.dispatch(command.target) diff --git a/resources/lib/tempora/tests/test_schedule.py b/resources/lib/tempora/tests/test_schedule.py new file mode 100644 index 0000000..38eb8dc --- /dev/null +++ b/resources/lib/tempora/tests/test_schedule.py @@ -0,0 +1,118 @@ +import time +import random +import datetime + +import pytest +import pytz +import freezegun + +from tempora import schedule + +__metaclass__ = type + + +@pytest.fixture +def naive_times(monkeypatch): + monkeypatch.setattr( + 'irc.schedule.from_timestamp', + datetime.datetime.fromtimestamp) + monkeypatch.setattr('irc.schedule.now', datetime.datetime.now) + + +do_nothing = type(None) +try: + do_nothing() +except TypeError: + # Python 2 compat + def do_nothing(): + return None + + +def test_delayed_command_order(): + """ + delayed commands should be sorted by delay time + """ + delays = [random.randint(0, 99) for x in range(5)] + cmds = sorted([ + schedule.DelayedCommand.after(delay, do_nothing) + for delay in delays + ]) + assert [c.delay.seconds for c in cmds] == sorted(delays) + + +def test_periodic_command_delay(): + "A PeriodicCommand must have a positive, non-zero delay." + with pytest.raises(ValueError) as exc_info: + schedule.PeriodicCommand.after(0, None) + assert str(exc_info.value) == test_periodic_command_delay.__doc__ + + +def test_periodic_command_fixed_delay(): + """ + Test that we can construct a periodic command with a fixed initial + delay. + """ + fd = schedule.PeriodicCommandFixedDelay.at_time( + at=schedule.now(), + delay=datetime.timedelta(seconds=2), + target=lambda: None, + ) + assert fd.due() is True + assert fd.next().due() is False + + +class TestCommands: + def test_delayed_command_from_timestamp(self): + """ + Ensure a delayed command can be constructed from a timestamp. + """ + t = time.time() + schedule.DelayedCommand.at_time(t, do_nothing) + + def test_command_at_noon(self): + """ + Create a periodic command that's run at noon every day. + """ + when = datetime.time(12, 0, tzinfo=pytz.utc) + cmd = schedule.PeriodicCommandFixedDelay.daily_at(when, target=None) + assert cmd.due() is False + next_cmd = cmd.next() + daily = datetime.timedelta(days=1) + day_from_now = schedule.now() + daily + two_days_from_now = day_from_now + daily + assert day_from_now < next_cmd < two_days_from_now + + +class TestTimezones: + def test_alternate_timezone_west(self): + target_tz = pytz.timezone('US/Pacific') + target = schedule.now().astimezone(target_tz) + cmd = schedule.DelayedCommand.at_time(target, target=None) + assert cmd.due() + + def test_alternate_timezone_east(self): + target_tz = pytz.timezone('Europe/Amsterdam') + target = schedule.now().astimezone(target_tz) + cmd = schedule.DelayedCommand.at_time(target, target=None) + assert cmd.due() + + def test_daylight_savings(self): + """ + A command at 9am should always be 9am regardless of + a DST boundary. 
+ """ + with freezegun.freeze_time('2018-03-10 08:00:00'): + target_tz = pytz.timezone('US/Eastern') + target_time = datetime.time(9, tzinfo=target_tz) + cmd = schedule.PeriodicCommandFixedDelay.daily_at( + target_time, + target=lambda: None, + ) + + def naive(dt): + return dt.replace(tzinfo=None) + + assert naive(cmd) == datetime.datetime(2018, 3, 10, 9, 0, 0) + next_ = cmd.next() + assert naive(next_) == datetime.datetime(2018, 3, 11, 9, 0, 0) + assert next_ - cmd == datetime.timedelta(hours=23) diff --git a/resources/lib/tempora/timing.py b/resources/lib/tempora/timing.py new file mode 100644 index 0000000..03c2245 --- /dev/null +++ b/resources/lib/tempora/timing.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import + +import datetime +import functools +import numbers +import time + +__metaclass__ = type + + +class Stopwatch: + """ + A simple stopwatch which starts automatically. + + >>> w = Stopwatch() + >>> _1_sec = datetime.timedelta(seconds=1) + >>> w.split() < _1_sec + True + >>> import time + >>> time.sleep(1.0) + >>> w.split() >= _1_sec + True + >>> w.stop() >= _1_sec + True + >>> w.reset() + >>> w.start() + >>> w.split() < _1_sec + True + + It should be possible to launch the Stopwatch in a context: + + >>> with Stopwatch() as watch: + ... assert isinstance(watch.split(), datetime.timedelta) + + In that case, the watch is stopped when the context is exited, + so to read the elapsed time:: + + >>> watch.elapsed + datetime.timedelta(...) + >>> watch.elapsed.seconds + 0 + """ + def __init__(self): + self.reset() + self.start() + + def reset(self): + self.elapsed = datetime.timedelta(0) + if hasattr(self, 'start_time'): + del self.start_time + + def start(self): + self.start_time = datetime.datetime.utcnow() + + def stop(self): + stop_time = datetime.datetime.utcnow() + self.elapsed += stop_time - self.start_time + del self.start_time + return self.elapsed + + def split(self): + local_duration = datetime.datetime.utcnow() - self.start_time + return self.elapsed + local_duration + + # context manager support + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stop() + + +class IntervalGovernor: + """ + Decorate a function to only allow it to be called once per + min_interval. Otherwise, it returns None. + """ + def __init__(self, min_interval): + if isinstance(min_interval, numbers.Number): + min_interval = datetime.timedelta(seconds=min_interval) + self.min_interval = min_interval + self.last_call = None + + def decorate(self, func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + allow = ( + not self.last_call + or self.last_call.split() > self.min_interval + ) + if allow: + self.last_call = Stopwatch() + return func(*args, **kwargs) + return wrapper + + __call__ = decorate + + +class Timer(Stopwatch): + """ + Watch for a target elapsed time. + + >>> t = Timer(0.1) + >>> t.expired() + False + >>> __import__('time').sleep(0.15) + >>> t.expired() + True + """ + def __init__(self, target=float('Inf')): + self.target = self._accept(target) + super(Timer, self).__init__() + + def _accept(self, target): + "Accept None or ∞ or datetime or numeric for target" + if isinstance(target, datetime.timedelta): + target = target.total_seconds() + + if target is None: + # treat None as infinite target + target = float('Inf') + + return target + + def expired(self): + return self.split().total_seconds() > self.target + + +class BackoffDelay: + """ + Exponential backoff delay. 
+ + Useful for defining delays between retries. Consider for use + with ``jaraco.functools.retry_call`` as the cleanup. + + Default behavior has no effect; a delay or jitter must + be supplied for the call to be non-degenerate. + + >>> bd = BackoffDelay() + >>> bd() + >>> bd() + + The following instance will delay 10ms for the first call, + 20ms for the second, etc. + + >>> bd = BackoffDelay(delay=0.01, factor=2) + >>> bd() + >>> bd() + + Inspect and adjust the state of the delay anytime. + + >>> bd.delay + 0.04 + >>> bd.delay = 0.01 + + Set limit to prevent the delay from exceeding bounds. + + >>> bd = BackoffDelay(delay=0.01, factor=2, limit=0.015) + >>> bd() + >>> bd.delay + 0.015 + + Limit may be a callable taking a number and returning + the limited number. + + >>> at_least_one = lambda n: max(n, 1) + >>> bd = BackoffDelay(delay=0.01, factor=2, limit=at_least_one) + >>> bd() + >>> bd.delay + 1 + + Pass a jitter to add or subtract seconds to the delay. + + >>> bd = BackoffDelay(jitter=0.01) + >>> bd() + >>> bd.delay + 0.01 + + Jitter may be a callable. To supply a non-deterministic jitter + between -0.5 and 0.5, consider: + + >>> import random + >>> jitter=functools.partial(random.uniform, -0.5, 0.5) + >>> bd = BackoffDelay(jitter=jitter) + >>> bd() + >>> 0 <= bd.delay <= 0.5 + True + """ + + delay = 0 + + factor = 1 + "Multiplier applied to delay" + + jitter = 0 + "Number or callable returning extra seconds to add to delay" + + def __init__(self, delay=0, factor=1, limit=float('inf'), jitter=0): + self.delay = delay + self.factor = factor + if isinstance(limit, numbers.Number): + limit_ = limit + + def limit(n): + return max(0, min(limit_, n)) + self.limit = limit + if isinstance(jitter, numbers.Number): + jitter_ = jitter + + def jitter(): + return jitter_ + self.jitter = jitter + + def __call__(self): + time.sleep(self.delay) + self.delay = self.limit(self.delay * self.factor + self.jitter()) diff --git a/resources/lib/tempora/utc.py b/resources/lib/tempora/utc.py new file mode 100644 index 0000000..35bfdb0 --- /dev/null +++ b/resources/lib/tempora/utc.py @@ -0,0 +1,36 @@ +""" +Facilities for common time operations in UTC. + +Inspired by the `utc project `_. 
+
+>>> dt = now()
+>>> dt == fromtimestamp(dt.timestamp())
+True
+>>> dt.tzinfo
+datetime.timezone.utc
+
+>>> from time import time as timestamp
+>>> now().timestamp() - timestamp() < 0.1
+True
+
+>>> datetime(2018, 6, 26, 0).tzinfo
+datetime.timezone.utc
+
+>>> time(0, 0).tzinfo
+datetime.timezone.utc
+"""
+
+import datetime as std
+import functools
+
+
+__all__ = ['now', 'fromtimestamp', 'datetime', 'time']
+
+
+now = functools.partial(std.datetime.now, std.timezone.utc)
+fromtimestamp = functools.partial(
+    std.datetime.fromtimestamp,
+    tz=std.timezone.utc,
+)
+datetime = functools.partial(std.datetime, tzinfo=std.timezone.utc)
+time = functools.partial(std.time, tzinfo=std.timezone.utc)
diff --git a/resources/lib/termcolor.py b/resources/lib/termcolor.py
new file mode 100644
index 0000000..f11b824
--- /dev/null
+++ b/resources/lib/termcolor.py
@@ -0,0 +1,168 @@
+# coding: utf-8
+# Copyright (c) 2008-2011 Volvox Development Team
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# Author: Konstantin Lepa
+
+"""ANSI color formatting for output in terminal."""
+
+from __future__ import print_function
+import os
+
+
+__ALL__ = [ 'colored', 'cprint' ]
+
+VERSION = (1, 1, 0)
+
+ATTRIBUTES = dict(
+        list(zip([
+            'bold',
+            'dark',
+            '',
+            'underline',
+            'blink',
+            '',
+            'reverse',
+            'concealed'
+            ],
+            list(range(1, 9))
+            ))
+        )
+del ATTRIBUTES['']
+
+
+HIGHLIGHTS = dict(
+        list(zip([
+            'on_grey',
+            'on_red',
+            'on_green',
+            'on_yellow',
+            'on_blue',
+            'on_magenta',
+            'on_cyan',
+            'on_white'
+            ],
+            list(range(40, 48))
+            ))
+        )
+
+
+COLORS = dict(
+        list(zip([
+            'grey',
+            'red',
+            'green',
+            'yellow',
+            'blue',
+            'magenta',
+            'cyan',
+            'white',
+            ],
+            list(range(30, 38))
+            ))
+        )
+
+
+RESET = '\033[0m'
+
+
+def colored(text, color=None, on_color=None, attrs=None):
+    """Colorize text.
+
+    Available text colors:
+        red, green, yellow, blue, magenta, cyan, white.
+
+    Available text highlights:
+        on_red, on_green, on_yellow, on_blue, on_magenta, on_cyan, on_white.
+
+    Available attributes:
+        bold, dark, underline, blink, reverse, concealed.
+
+    Example:
+        colored('Hello, World!', 'red', 'on_grey', ['blue', 'blink'])
+        colored('Hello, World!', 'green')
+    """
+    if os.getenv('ANSI_COLORS_DISABLED') is None:
+        fmt_str = '\033[%dm%s'
+        if color is not None:
+            text = fmt_str % (COLORS[color], text)
+
+        if on_color is not None:
+            text = fmt_str % (HIGHLIGHTS[on_color], text)
+
+        if attrs is not None:
+            for attr in attrs:
+                text = fmt_str % (ATTRIBUTES[attr], text)
+
+        text += RESET
+    return text
+
+
+def cprint(text, color=None, on_color=None, attrs=None, **kwargs):
+    """Print colorized text.
+
+    It accepts the arguments of the print function.
+    """
+
+    print((colored(text, color, on_color, attrs)), **kwargs)
+
+
+if __name__ == '__main__':
+    print('Current terminal type: %s' % os.getenv('TERM'))
+    print('Test basic colors:')
+    cprint('Grey color', 'grey')
+    cprint('Red color', 'red')
+    cprint('Green color', 'green')
+    cprint('Yellow color', 'yellow')
+    cprint('Blue color', 'blue')
+    cprint('Magenta color', 'magenta')
+    cprint('Cyan color', 'cyan')
+    cprint('White color', 'white')
+    print(('-' * 78))
+
+    print('Test highlights:')
+    cprint('On grey color', on_color='on_grey')
+    cprint('On red color', on_color='on_red')
+    cprint('On green color', on_color='on_green')
+    cprint('On yellow color', on_color='on_yellow')
+    cprint('On blue color', on_color='on_blue')
+    cprint('On magenta color', on_color='on_magenta')
+    cprint('On cyan color', on_color='on_cyan')
+    cprint('On white color', color='grey', on_color='on_white')
+    print('-' * 78)
+
+    print('Test attributes:')
+    cprint('Bold grey color', 'grey', attrs=['bold'])
+    cprint('Dark red color', 'red', attrs=['dark'])
+    cprint('Underline green color', 'green', attrs=['underline'])
+    cprint('Blink yellow color', 'yellow', attrs=['blink'])
+    cprint('Reversed blue color', 'blue', attrs=['reverse'])
+    cprint('Concealed Magenta color', 'magenta', attrs=['concealed'])
+    cprint('Bold underline reverse cyan color', 'cyan',
+           attrs=['bold', 'underline', 'reverse'])
+    cprint('Dark blink concealed white color', 'white',
+           attrs=['dark', 'blink', 'concealed'])
+    print(('-' * 78))
+
+    print('Test mixing:')
+    cprint('Underline red on grey color', 'red', 'on_grey',
+           ['underline'])
+    cprint('Reversed green on red color', 'green', 'on_red', ['reverse'])
+
diff --git a/resources/lib/utils.py b/resources/lib/utils.py
index 7424fdf..6a9c779 100644
--- a/resources/lib/utils.py
+++ b/resources/lib/utils.py
@@ -21,13 +21,14 @@
 import xbmcaddon
 import struct
 import random
+import io
 import time
 import math
 from threading import Thread, Event

 PROXY_PORT = 52308

-DEBUG = False
+DEBUG = True

 try:
     import simplejson as json
@@ -37,8 +38,12 @@
 try:
     from cStringIO import StringIO
 except ImportError:
-    from StringIO import StringIO
+    from io import StringIO

+# There is no C-accelerated cBytesIO module to try first
+# (cStringIO was Python 2 only), so import io.BytesIO
+# directly instead of guarding the import with try/except.
+from io import BytesIO

 ADDON_ID = "plugin.audio.spotify"
 KODI_VERSION = int(xbmc.getInfoLabel("System.BuildVersion").split(".")[0])
@@ -73,10 +78,10 @@

 def log_msg(msg, loglevel=xbmc.LOGDEBUG):
     '''log message to kodi log'''
-    if isinstance(msg, unicode):
+    if isinstance(msg, str):
         msg = msg.encode('utf-8')
     if DEBUG:
-        loglevel = xbmc.LOGNOTICE
+        loglevel = xbmc.LOGINFO
     xbmc.log("%s --> %s" % (ADDON_ID, msg), level=loglevel)

@@ -92,17 +97,9 @@ def addon_setting(settingname, set_value=None):
     if set_value:
         addon.setSetting(settingname, set_value)
     else:
-        return addon.getSetting(settingname).decode("utf-8")
+        return addon.getSetting(settingname)


-def kill_spotty():
-    '''make sure we don't have any (remaining) spotty
processes running before we start one''' - if xbmc.getCondVisibility("System.Platform.Windows"): - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess._subprocess.STARTF_USESHOWWINDOW - subprocess.Popen(["taskkill", "/IM", "spotty.exe"], startupinfo=startupinfo, shell=True) - else: - os.system("killall spotty") def kill_on_timeout(done, timeout, proc): @@ -151,7 +148,7 @@ def request_token_spotty(spotty, use_creds=True): log_msg("request_token_spotty stdout: %s" % stdout) for line in stdout.split(): line = line.strip() - if line.startswith("{\"accessToken\""): + if line.startswith(b"{\"accessToken\""): result = eval(line) # transform token info to spotipy compatible format if result: @@ -173,7 +170,7 @@ def request_token_web(force=False): from spotipy import oauth2 xbmcvfs.mkdir("special://profile/addon_data/%s/" % ADDON_ID) cache_path = "special://profile/addon_data/%s/spotipy.cache" % ADDON_ID - cache_path = xbmc.translatePath(cache_path).decode("utf-8") + cache_path = xbmcvfs.translatePath(cache_path) scope = " ".join(SCOPE) redirect_url = 'http://localhost:%s/callback' % PROXY_PORT sp_oauth = oauth2.SpotifyOAuth(CLIENTID, CLIENT_SECRET, redirect_url, scope=scope, cache_path=cache_path) @@ -186,8 +183,8 @@ def request_token_web(force=False): # show message to user that the browser is going to be launched dialog = xbmcgui.Dialog() - header = xbmc.getInfoLabel("System.AddonTitle(%s)" % ADDON_ID).decode("utf-8") - msg = xbmc.getInfoLabel("$ADDON[%s 11049]" % ADDON_ID).decode("utf-8") + header = xbmc.getInfoLabel("System.AddonTitle(%s)" % ADDON_ID) + msg = xbmc.getInfoLabel("$ADDON[%s 11049]" % ADDON_ID) dialog.ok(header, msg) del dialog @@ -224,7 +221,7 @@ def request_token_web(force=False): def create_wave_header(duration): '''generate a wave header for the stream''' - file = StringIO() + file = BytesIO() numsamples = 44100 * duration channels = 2 samplerate = 44100 @@ -234,21 +231,21 @@ def create_wave_header(duration): format_chunk_spec = "<4sLHHLLHH" format_chunk = struct.pack( format_chunk_spec, - "fmt ", # Chunk id + "fmt ".encode(encoding='UTF-8'), # Chunk id 16, # Size of this chunk (excluding chunk id and this field) 1, # Audio format, 1 for PCM channels, # Number of channels samplerate, # Samplerate, 44100, 48000, etc. - samplerate * channels * (bitspersample / 8), # Byterate - channels * (bitspersample / 8), # Blockalign - bitspersample, # 16 bits for two byte samples, etc. + samplerate * channels * (bitspersample // 8), # Byterate + channels * (bitspersample // 8), # Blockalign + bitspersample, # 16 bits for two byte samples, etc. 
=> TO BE UPDATED - FOR TESTING
     )
     # Generate data chunk
     data_chunk_spec = "<4sL"
     datasize = numsamples * channels * (bitspersample / 8)
     data_chunk = struct.pack(
         data_chunk_spec,
-        "data",  # Chunk id
+        "data".encode(encoding='UTF-8'),  # Chunk id
         int(datasize),  # Chunk size (excluding chunk id and this field)
     )
     sum_items = [
@@ -264,9 +261,9 @@
     main_header_spec = "<4sL4s"
     main_header = struct.pack(
         main_header_spec,
-        "RIFF",
+        "RIFF".encode(encoding='UTF-8'),
         all_cunks_size,
-        "WAVE"
+        "WAVE".encode(encoding='UTF-8')
     )
     # Write all the contents in
     file.write(main_header)
@@ -313,10 +310,11 @@ def parse_spotify_track(track, is_album_track=True, silenced=False, is_connect=F
         thumb = "DefaultMusicSongs"
     duration = track['duration_ms'] / 1000
-    if silenced:
-        url = "http://localhost:%s/silence/%s" % (PROXY_PORT, duration)
-    else:
-        url = "http://localhost:%s/track/%s/%s" % (PROXY_PORT, track['id'], duration)
+    #if silenced:
+    #    url = "http://localhost:%s/silence/%s" % (PROXY_PORT, duration)
+    #else:
+    #    url = "http://localhost:%s/track/%s/%s" % (PROXY_PORT, track['id'], duration)
+    url = "http://localhost:%s/track/%s/%s" % (PROXY_PORT, track['id'], duration)
     if is_connect or silenced:
         url += "/?connect=true"
@@ -347,7 +345,7 @@ def parse_spotify_track(track, is_album_track=True, silenced=False, is_connect=F

 def get_chunks(data, chunksize):
-    return[data[x:x + chunksize] for x in xrange(0, len(data), chunksize)]
+    return [data[x:x + chunksize] for x in range(0, len(data), chunksize)]


 def try_encode(text, encoding="utf-8"):
@@ -384,7 +382,7 @@ def normalize_string(text):


 def get_playername():
-    playername = xbmc.getInfoLabel("System.FriendlyName").decode("utf-8")
+    playername = xbmc.getInfoLabel("System.FriendlyName")
     if playername == "Kodi":
         import socket
         playername = "Kodi - %s" % socket.gethostname()
@@ -404,7 +402,7 @@ class Spotty(object):

     def __init__(self):
         '''initialize with default values'''
-        self.__cache_path = xbmc.translatePath("special://profile/addon_data/%s/" % ADDON_ID).decode("utf-8")
+        self.__cache_path = xbmcvfs.translatePath("special://profile/addon_data/%s/" % ADDON_ID)
         self.playername = get_playername()
         self.__spotty_binary = self.get_spotty_binary()
@@ -422,12 +420,14 @@ def test_spotty(self, binary_path):
             args = [
                 binary_path,
                 "-n", "selftest",
-                "-x", "--disable-discovery"
+                "--disable-discovery",
+                "-x",
+                "-v"
             ]
             startupinfo = None
             if os.name == 'nt':
                 startupinfo = subprocess.STARTUPINFO()
-                startupinfo.dwFlags |= subprocess._subprocess.STARTF_USESHOWWINDOW
+                startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
             spotty = subprocess.Popen(
                 args,
                 startupinfo=startupinfo,
@@ -436,7 +436,7 @@ def test_spotty(self, binary_path):
                 bufsize=0)
             stdout, stderr = spotty.communicate()
             log_msg(stdout)
-            if "ok spotty" in stdout:
+            if "ok spotty".encode(encoding='UTF-8') in stdout:
                 return True
             elif xbmc.getCondVisibility("System.Platform.Windows"):
                 log_msg("Unable to initialize spotty binary for playback."
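For reference, the create_wave_header() port above boils down to the sketch below: on Python 3 the 4-byte chunk ids passed to struct.pack must be bytes, and the byte counts must stay integers, hence the b"..." literals and // division. The helper name make_wave_header and its keyword defaults are illustrative only, not part of the add-on:

    import struct
    from io import BytesIO

    def make_wave_header(duration, samplerate=44100, channels=2, bitspersample=16):
        '''build a minimal PCM wave header for `duration` seconds of audio'''
        numsamples = samplerate * duration
        bytes_per_frame = channels * (bitspersample // 8)
        format_chunk = struct.pack(
            "<4sLHHLLHH",
            b"fmt ",                       # chunk id (bytes, not str)
            16,                            # size of the chunk body
            1,                             # audio format: 1 = PCM
            channels,
            samplerate,
            samplerate * bytes_per_frame,  # byte rate
            bytes_per_frame,               # block align
            bitspersample,
        )
        datasize = numsamples * bytes_per_frame
        data_chunk = struct.pack("<4sL", b"data", datasize)
        # RIFF size = "WAVE" tag + fmt chunk + data chunk header + samples
        all_chunks_size = 4 + len(format_chunk) + len(data_chunk) + datasize
        header = BytesIO()
        header.write(struct.pack("<4sL4s", b"RIFF", all_chunks_size, b"WAVE"))
        header.write(format_chunk)
        header.write(data_chunk)
        return header.getvalue()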
@@ -445,20 +445,22 @@
             log_exception(__name__, exc)
         return False

-    def run_spotty(self, arguments=None, use_creds=False, disable_discovery=True, ap_port="443"):
+    def run_spotty(self, arguments=None, use_creds=False, disable_discovery=False, ap_port="54443"):
         '''On supported platforms we include spotty binary'''
         try:
             args = [
                 self.__spotty_binary,
                 "-c", self.__cache_path,
-##                "-b", "320"
-                "--ap-port",ap_port
+                "-b", "320",
+                "-v",
+                "--enable-audio-cache",
+                "--ap-port", ap_port
             ]
             if use_creds:
                 # use username/password login for spotty
                 addon = xbmcaddon.Addon(id=ADDON_ID)
-                username = addon.getSetting("username").decode("utf-8")
-                password = addon.getSetting("password").decode("utf-8")
+                username = addon.getSetting("username")
+                password = addon.getSetting("password")
                 del addon
                 if username and password:
                     args += ["-u", username, "-p", password]
@@ -471,13 +473,24 @@ def run_spotty(self, arguments=None, use_creds=False, disable_discovery=True, ap
             startupinfo = None
             if os.name == 'nt':
                 startupinfo = subprocess.STARTUPINFO()
-                startupinfo.dwFlags |= subprocess._subprocess.STARTF_USESHOWWINDOW
+                startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
             return subprocess.Popen(args, startupinfo=startupinfo, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
         except Exception as exc:
             log_exception(__name__, exc)
         return None

+    def kill_spotty(self):
+        '''make sure we don't have any (remaining) spotty processes running before we start one'''
+        if xbmc.getCondVisibility("System.Platform.Windows"):
+            startupinfo = subprocess.STARTUPINFO()
+            startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
+            subprocess.Popen(["taskkill", "/IM", "spotty.exe"], startupinfo=startupinfo, shell=True)
+        else:
+            if self.__spotty_binary is not None:
+                sp_binary_file = os.path.basename(self.__spotty_binary)
+                os.system("killall " + sp_binary_file)
+
     def get_spotty_binary(self):
         '''find the correct spotty binary belonging to the platform'''
         sp_binary = None
@@ -516,7 +529,7 @@ def get_spotty_binary(self):
     def get_username(self):
         ''' obtain/check (last) username of the credentials obtained by spotify connect'''
         username = ""
-        cred_file = xbmc.translatePath("special://profile/addon_data/%s/credentials.json" % ADDON_ID).decode("utf-8")
+        cred_file = xbmcvfs.translatePath("special://profile/addon_data/%s/credentials.json" % ADDON_ID)
         if xbmcvfs.exists(cred_file):
             with open(cred_file) as cred_file:
                 data = cred_file.read()
diff --git a/resources/lib/zc/__init__.py b/resources/lib/zc/__init__.py
new file mode 100644
index 0000000..146c336
--- /dev/null
+++ b/resources/lib/zc/__init__.py
@@ -0,0 +1 @@
+__namespace__ = 'zc'
\ No newline at end of file
diff --git a/resources/lib/zc/lockfile/README.txt b/resources/lib/zc/lockfile/README.txt
new file mode 100644
index 0000000..89ef33e
--- /dev/null
+++ b/resources/lib/zc/lockfile/README.txt
@@ -0,0 +1,70 @@
+Lock file support
+=================
+
+The ZODB lock_file module provides support for creating file system
+locks. These are locks that are implemented with lock files and
+OS-provided locking facilities. To create a lock, instantiate a
+LockFile object with a file name:
+
+    >>> import zc.lockfile
+    >>> lock = zc.lockfile.LockFile('lock')
+
+If we try to lock the same name, we'll get a lock error:
+
+    >>> import zope.testing.loggingsupport
+    >>> handler = zope.testing.loggingsupport.InstalledHandler('zc.lockfile')
+    >>> try:
+    ...     zc.lockfile.LockFile('lock')
+    ... except zc.lockfile.LockError:
+    ...     print("Can't lock file")
+    Can't lock file
+
+.. We don't log failure to acquire.
+
+    >>> for record in handler.records: # doctest: +ELLIPSIS
+    ...     print(record.levelname+' '+record.getMessage())
+
+To release the lock, use its close method:
+
+    >>> lock.close()
+
+The lock file is not removed. It is left behind:
+
+    >>> import os
+    >>> os.path.exists('lock')
+    True
+
+Of course, now that we've released the lock, we can create it again:
+
+    >>> lock = zc.lockfile.LockFile('lock')
+    >>> lock.close()
+
+.. Cleanup
+
+    >>> import os
+    >>> os.remove('lock')
+
+Hostname in lock file
+=====================
+
+In a container environment (e.g. Docker), the PID is typically identical
+even if multiple containers are running under the same operating
+system instance.
+
+Clearly, inspecting lock files doesn't then help much in debugging. To identify
+the container which created the lock file, we need information about the
+container in the lock file. Since Docker uses the container identifier or name
+as the hostname, this information can be stored in the lock file in addition to
+or instead of the PID.
+
+Use the ``content_template`` keyword argument to ``LockFile`` to specify a
+custom lock file content format:
+
+    >>> lock = zc.lockfile.LockFile('lock', content_template='{pid};{hostname}')
+    >>> lock.close()
+
+If you now inspected the lock file, you would see e.g.:
+
+    $ cat lock
+    123;myhostname
+
diff --git a/resources/lib/zc/lockfile/__init__.py b/resources/lib/zc/lockfile/__init__.py
new file mode 100644
index 0000000..a0ac2ff
--- /dev/null
+++ b/resources/lib/zc/lockfile/__init__.py
@@ -0,0 +1,104 @@
+##############################################################################
+#
+# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE
+#
+##############################################################################
+
+import os
+import errno
+import logging
+logger = logging.getLogger("zc.lockfile")
+
+class LockError(Exception):
+    """Couldn't get a lock
+    """
+
+try:
+    import fcntl
+except ImportError:
+    try:
+        import msvcrt
+    except ImportError:
+        def _lock_file(file):
+            raise TypeError('No file-locking support on this platform')
+        def _unlock_file(file):
+            raise TypeError('No file-locking support on this platform')
+
+    else:
+        # Windows
+        def _lock_file(file):
+            # Lock just the first byte
+            try:
+                msvcrt.locking(file.fileno(), msvcrt.LK_NBLCK, 1)
+            except IOError:
+                raise LockError("Couldn't lock %r" % file.name)
+
+        def _unlock_file(file):
+            try:
+                file.seek(0)
+                msvcrt.locking(file.fileno(), msvcrt.LK_UNLCK, 1)
+            except IOError:
+                raise LockError("Couldn't unlock %r" % file.name)
+
+else:
+    # Unix
+    _flags = fcntl.LOCK_EX | fcntl.LOCK_NB
+
+    def _lock_file(file):
+        try:
+            fcntl.flock(file.fileno(), _flags)
+        except IOError:
+            raise LockError("Couldn't lock %r" % file.name)
+
+    def _unlock_file(file):
+        fcntl.flock(file.fileno(), fcntl.LOCK_UN)
+
+class LazyHostName(object):
+    """Avoid importing socket and calling gethostname() unnecessarily"""
+    def __str__(self):
+        import socket
+        return socket.gethostname()
+
+
+class LockFile:
+
+    _fp = None
+
+    def __init__(self, path, content_template='{pid}'):
+        self._path = path
+        try:
+            # Try to open for writing without truncation:
+            fp = open(path, 'r+')
+        except IOError:
+            # If the file doesn't exist, we'll get an IO error, try a+
+            # Note that there may be a race here. Multiple processes
+            # could fail on the r+ open and open the file a+, but only
+            # one will get the lock and write a pid.
+            fp = open(path, 'a+')
+
+        try:
+            _lock_file(fp)
+        except:
+            fp.close()
+            raise
+
+        # We got the lock, record info in the file.
+        self._fp = fp
+        fp.write(" %s\n" % content_template.format(pid=os.getpid(),
+                                                   hostname=LazyHostName()))
+        fp.truncate()
+        fp.flush()
+
+    def close(self):
+        if self._fp is not None:
+            _unlock_file(self._fp)
+            self._fp.close()
+            self._fp = None
diff --git a/resources/lib/zc/lockfile/tests.py b/resources/lib/zc/lockfile/tests.py
new file mode 100644
index 0000000..e9fcbff
--- /dev/null
+++ b/resources/lib/zc/lockfile/tests.py
@@ -0,0 +1,193 @@
+##############################################################################
+#
+# Copyright (c) 2004 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+import os, re, sys, unittest, doctest
+import zc.lockfile, time, threading
+from zope.testing import renormalizing, setupstack
+import tempfile
+try:
+    from unittest.mock import Mock, patch
+except ImportError:
+    from mock import Mock, patch
+
+checker = renormalizing.RENormalizing([
+    # Python 3 adds module path to error class name.
+ (re.compile("zc\.lockfile\.LockError:"), + r"LockError:"), + ]) + +def inc(): + while 1: + try: + lock = zc.lockfile.LockFile('f.lock') + except zc.lockfile.LockError: + continue + else: + break + f = open('f', 'r+b') + v = int(f.readline().strip()) + time.sleep(0.01) + v += 1 + f.seek(0) + f.write(('%d\n' % v).encode('ASCII')) + f.close() + lock.close() + +def many_threads_read_and_write(): + r""" + >>> with open('f', 'w+b') as file: + ... _ = file.write(b'0\n') + >>> with open('f.lock', 'w+b') as file: + ... _ = file.write(b'0\n') + + >>> n = 50 + >>> threads = [threading.Thread(target=inc) for i in range(n)] + >>> _ = [thread.start() for thread in threads] + >>> _ = [thread.join() for thread in threads] + >>> with open('f', 'rb') as file: + ... saved = int(file.read().strip()) + >>> saved == n + True + + >>> os.remove('f') + + We should only have one pid in the lock file: + + >>> f = open('f.lock') + >>> len(f.read().strip().split()) + 1 + >>> f.close() + + >>> os.remove('f.lock') + + """ + +def pid_in_lockfile(): + r""" + >>> import os, zc.lockfile + >>> pid = os.getpid() + >>> lock = zc.lockfile.LockFile("f.lock") + >>> f = open("f.lock") + >>> _ = f.seek(1) + >>> f.read().strip() == str(pid) + True + >>> f.close() + + Make sure that locking twice does not overwrite the old pid: + + >>> lock = zc.lockfile.LockFile("f.lock") + Traceback (most recent call last): + ... + LockError: Couldn't lock 'f.lock' + + >>> f = open("f.lock") + >>> _ = f.seek(1) + >>> f.read().strip() == str(pid) + True + >>> f.close() + + >>> lock.close() + """ + + +def hostname_in_lockfile(): + r""" + hostname is correctly written into the lock file when it's included in the + lock file content template + + >>> import zc.lockfile + >>> with patch('socket.gethostname', Mock(return_value='myhostname')): + ... lock = zc.lockfile.LockFile("f.lock", content_template='{hostname}') + >>> f = open("f.lock") + >>> _ = f.seek(1) + >>> f.read().rstrip() + 'myhostname' + >>> f.close() + + Make sure that locking twice does not overwrite the old hostname: + + >>> lock = zc.lockfile.LockFile("f.lock", content_template='{hostname}') + Traceback (most recent call last): + ... 
+ LockError: Couldn't lock 'f.lock' + + >>> f = open("f.lock") + >>> _ = f.seek(1) + >>> f.read().rstrip() + 'myhostname' + >>> f.close() + + >>> lock.close() + """ + + +class TestLogger(object): + def __init__(self): + self.log_entries = [] + + def exception(self, msg, *args): + self.log_entries.append((msg,) + args) + + +class LockFileLogEntryTestCase(unittest.TestCase): + """Tests for logging in case of lock failure""" + def setUp(self): + self.here = os.getcwd() + self.tmp = tempfile.mkdtemp(prefix='zc.lockfile-test-') + os.chdir(self.tmp) + + def tearDown(self): + os.chdir(self.here) + setupstack.rmtree(self.tmp) + + def test_log_formatting(self): + # PID and hostname are parsed and logged from lock file on failure + with patch('os.getpid', Mock(return_value=123)): + with patch('socket.gethostname', Mock(return_value='myhostname')): + lock = zc.lockfile.LockFile('f.lock', + content_template='{pid}/{hostname}') + with open('f.lock') as f: + self.assertEqual(' 123/myhostname\n', f.read()) + + lock.close() + + def test_unlock_and_lock_while_multiprocessing_process_running(self): + import multiprocessing + + lock = zc.lockfile.LockFile('l') + q = multiprocessing.Queue() + p = multiprocessing.Process(target=q.get) + p.daemon = True + p.start() + + # release and re-acquire should work (obviously) + lock.close() + lock = zc.lockfile.LockFile('l') + self.assertTrue(p.is_alive()) + + q.put(0) + lock.close() + p.join() + + +def test_suite(): + suite = unittest.TestSuite() + suite.addTest(doctest.DocFileSuite( + 'README.txt', checker=checker, + setUp=setupstack.setUpDirectory, tearDown=setupstack.tearDown)) + suite.addTest(doctest.DocTestSuite( + setUp=setupstack.setUpDirectory, tearDown=setupstack.tearDown, + checker=checker)) + # Add unittest test cases from this module + suite.addTest(unittest.defaultTestLoader.loadTestsFromName(__name__)) + return suite
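For orientation, this is how callers are meant to use the vendored zc.lockfile module; a minimal sketch assuming only the behavior documented in README.txt above ('spotty.lock' is an arbitrary example path, not one the add-on actually creates):

    import zc.lockfile

    try:
        # Raises LockError if another process already holds the lock.
        lock = zc.lockfile.LockFile('spotty.lock',
                                    content_template='{pid};{hostname}')
    except zc.lockfile.LockError:
        print("another process holds the lock")
    else:
        try:
            pass  # ... work that must not run concurrently ...
        finally:
            # Releases the OS-level lock; the lock file itself is left behind.
            lock.close()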