From ae506ba9cf8fbe882fcdb159aa54fc57e0efe8f0 Mon Sep 17 00:00:00 2001 From: lilydjwg Date: Tue, 8 Jun 2021 14:55:57 +0800 Subject: [PATCH] add a combiner source and the underlying mechanism to wait for other entries' results --- docs/api.rst | 5 +++++ nvchecker/__main__.py | 8 +++++--- nvchecker/api.py | 4 ++-- nvchecker/core.py | 22 +++++++++++++++------- nvchecker/ctxvars.py | 6 +++++- nvchecker/util.py | 22 ++++++++++++++++++++++ nvchecker_source/combiner.py | 21 +++++++++++++++++++++ tests/conftest.py | 5 +++-- tests/test_combiner.py | 24 ++++++++++++++++++++++++ 9 files changed, 102 insertions(+), 15 deletions(-) create mode 100644 nvchecker_source/combiner.py create mode 100644 tests/test_combiner.py diff --git a/docs/api.rst b/docs/api.rst index 5404103..8cc4c0f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -18,3 +18,8 @@ .. autodata:: nvchecker.api.proxy .. autodata:: nvchecker.api.user_agent .. autodata:: nvchecker.api.tries + +.. py:data:: nvchecker.api.entry_waiter + :type: contextvars.ContextVar + + This :class:`ContextVar ` contains an :class:`EntryWaiter ` instance for waiting on other entries. diff --git a/nvchecker/__main__.py b/nvchecker/__main__.py index a98a85a..3f25ff5 100755 --- a/nvchecker/__main__.py +++ b/nvchecker/__main__.py @@ -13,7 +13,7 @@ from pathlib import Path import structlog from . import core -from .util import VersData, RawResult, KeyManager +from .util import VersData, RawResult, KeyManager, EntryWaiter from .ctxvars import proxy as ctx_proxy logger = structlog.get_logger(logger_name=__name__) @@ -58,10 +58,12 @@ def main() -> None: options.httplib, options.http_timeout, ) + entry_waiter = EntryWaiter() try: futures = dispatcher.dispatch( entries, task_sem, result_q, - keymanager, args.tries, + keymanager, entry_waiter, + args.tries, options.source_configs, ) except ModuleNotFoundError as e: @@ -71,7 +73,7 @@ def main() -> None: oldvers = core.read_verfile(options.ver_files[0]) else: oldvers = {} - result_coro = core.process_result(oldvers, result_q) + result_coro = core.process_result(oldvers, result_q, entry_waiter) runner_coro = core.run_tasks(futures) # asyncio.run doesn't work because it always creates new eventloops diff --git a/nvchecker/api.py b/nvchecker/api.py index dc31324..049de04 100644 --- a/nvchecker/api.py +++ b/nvchecker/api.py @@ -4,7 +4,7 @@ from .httpclient import session, TemporaryError, HTTPError from .util import ( Entry, BaseWorker, RawResult, VersionResult, - AsyncCache, KeyManager, GetVersionError, + AsyncCache, KeyManager, GetVersionError, EntryWaiter, ) from .sortversion import sort_version_keys -from .ctxvars import tries, proxy, user_agent +from .ctxvars import tries, proxy, user_agent, entry_waiter diff --git a/nvchecker/core.py b/nvchecker/core.py index ae182e4..7fe9d5e 100644 --- a/nvchecker/core.py +++ b/nvchecker/core.py @@ -29,11 +29,12 @@ from . import slogconf from .util import ( Entry, Entries, KeyManager, RawResult, Result, VersData, FunctionWorker, GetVersionError, - FileLoadError, + FileLoadError, EntryWaiter, ) from . import __version__ from .sortversion import sort_version_keys from .ctxvars import tries as ctx_tries +from .ctxvars import entry_waiter as ctx_entry_waiter from . import httpclient logger = structlog.get_logger(logger_name=__name__) @@ -219,11 +220,13 @@ class Dispatcher: task_sem: asyncio.Semaphore, result_q: Queue[RawResult], keymanager: KeyManager, + entry_waiter: EntryWaiter, tries: int, source_configs: Dict[str, Dict[str, Any]], ) -> List[asyncio.Future]: mods: Dict[str, Tuple[types.ModuleType, List]] = {} ctx_tries.set(tries) + ctx_entry_waiter.set(entry_waiter) root_ctx = contextvars.copy_context() for name, entry in entries.items(): @@ -311,7 +314,7 @@ def apply_list_options( return versions[-1] -def _process_result(r: RawResult) -> Optional[Result]: +def _process_result(r: RawResult) -> Union[Result, Exception]: version = r.version conf = r.conf name = r.name @@ -320,11 +323,11 @@ def _process_result(r: RawResult) -> Optional[Result]: kw = version.kwargs kw['name'] = name logger.error(version.msg, **kw) - return None + return version elif isinstance(version, Exception): logger.error('unexpected error happened', name=r.name, exc_info=r.version) - return None + return version elif isinstance(version, list): version_str = apply_list_options(version, conf) else: @@ -336,10 +339,12 @@ def _process_result(r: RawResult) -> Optional[Result]: try: version_str = substitute_version(version_str, conf) return Result(name, version_str, conf) - except (ValueError, re.error): + except (ValueError, re.error) as e: logger.exception('error occurred in version substitutions', name=name) + return e - return None + else: + return ValueError('no version returned') def check_version_update( oldvers: VersData, name: str, version: str, @@ -353,15 +358,18 @@ def check_version_update( async def process_result( oldvers: VersData, result_q: Queue[RawResult], + entry_waiter: EntryWaiter, ) -> VersData: ret = {} try: while True: r = await result_q.get() r1 = _process_result(r) - if r1 is None: + if isinstance(r1, Exception): + entry_waiter.set_exception(r.name, r1) continue check_version_update(oldvers, r1.name, r1.version) + entry_waiter.set_result(r1.name, r1.version) ret[r1.name] = r1.version except asyncio.CancelledError: return ret diff --git a/nvchecker/ctxvars.py b/nvchecker/ctxvars.py index e9daf71..c68c9db 100644 --- a/nvchecker/ctxvars.py +++ b/nvchecker/ctxvars.py @@ -4,12 +4,16 @@ from __future__ import annotations from contextvars import ContextVar -from typing import Optional +from typing import Optional, TYPE_CHECKING from . import __version__ DEFAULT_USER_AGENT = f'lilydjwg/nvchecker {__version__}' +if TYPE_CHECKING: + from .util import EntryWaiter + tries = ContextVar('tries', default=1) proxy: ContextVar[Optional[str]] = ContextVar('proxy', default=None) user_agent = ContextVar('user_agent', default=DEFAULT_USER_AGENT) +entry_waiter: ContextVar[EntryWaiter] = ContextVar('entry_waiter') diff --git a/nvchecker/util.py b/nvchecker/util.py index b443a43..9a9ee33 100644 --- a/nvchecker/util.py +++ b/nvchecker/util.py @@ -65,6 +65,28 @@ class KeyManager: '''Get the named key (token) in the keyfile.''' return self.keys.get(name) +class EntryWaiter: + def __init__(self) -> None: + self._waiting: Dict[str, asyncio.Future] = {} + + async def wait(self, name: str) -> str: + '''Wait on the ``name`` entry and return its result (the version string)''' + fu = self._waiting.get(name) + if fu is None: + fu = asyncio.Future() + self._waiting[name] = fu + return await fu + + def set_result(self, name: str, value: str) -> None: + fu = self._waiting.get(name) + if fu is not None: + fu.set_result(value) + + def set_exception(self, name: str, e: Exception) -> None: + fu = self._waiting.get(name) + if fu is not None: + fu.set_exception(e) + class RawResult(NamedTuple): '''The unprocessed result from a check.''' name: str diff --git a/nvchecker_source/combiner.py b/nvchecker_source/combiner.py new file mode 100644 index 0000000..9e7dc6c --- /dev/null +++ b/nvchecker_source/combiner.py @@ -0,0 +1,21 @@ +# MIT licensed +# Copyright (c) 2021 lilydjwg , et al. + +import asyncio +import string + +from nvchecker.api import entry_waiter + +class CombineFormat(string.Template): + idpattern = '[0-9]+' + +async def get_version( + name, conf, *, cache, keymanager=None +): + t = CombineFormat(conf['format']) + from_ = conf['from'] + waiter = entry_waiter.get() + entries = [waiter.wait(name) for name in from_] + vers = await asyncio.gather(*entries) + versdict = {str(i+1): v for i, v in enumerate(vers)} + return t.substitute(versdict) diff --git a/tests/conftest.py b/tests/conftest.py index 2f2ba4e..1b0e2e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,13 +28,14 @@ async def run( keymanager = core.KeyManager(None) dispatcher = core.setup_httpclient() + entry_waiter = core.EntryWaiter() futures = dispatcher.dispatch( entries, task_sem, result_q, - keymanager, 1, {}, + keymanager, entry_waiter, 1, {}, ) oldvers: VersData = {} - result_coro = core.process_result(oldvers, result_q) + result_coro = core.process_result(oldvers, result_q, entry_waiter) runner_coro = core.run_tasks(futures) return await main.run(result_coro, runner_coro) diff --git a/tests/test_combiner.py b/tests/test_combiner.py new file mode 100644 index 0000000..21fcbf0 --- /dev/null +++ b/tests/test_combiner.py @@ -0,0 +1,24 @@ +# MIT licensed +# Copyright (c) 2021 lilydjwg , et al. + +import pytest +pytestmark = pytest.mark.asyncio + +async def test_combiner(run_str_multi): + conf = r''' +[entry-1] +source = "cmd" +cmd = "echo 1" + +[entry-2] +source = "cmd" +cmd = "echo 2" + +[entry-3] +source = "combiner" +from = ["entry-1", "entry-2", "entry-2"] +format = "$1-$2-$3" +''' + + r = await run_str_multi(conf) + assert r['entry-3'] == '1-2-2'