add a combiner source and the underlying mechanism to wait for other entries' results

This commit is contained in:
lilydjwg 2021-06-08 14:55:57 +08:00
parent d83d8d5367
commit ae506ba9cf
9 changed files with 102 additions and 15 deletions

View File

@ -18,3 +18,8 @@
.. autodata:: nvchecker.api.proxy .. autodata:: nvchecker.api.proxy
.. autodata:: nvchecker.api.user_agent .. autodata:: nvchecker.api.user_agent
.. autodata:: nvchecker.api.tries .. autodata:: nvchecker.api.tries
.. py:data:: nvchecker.api.entry_waiter
:type: contextvars.ContextVar
This :class:`ContextVar <contextvars.ContextVar>` contains an :class:`EntryWaiter <nvchecker.api.EntryWaiter>` instance for waiting on other entries.

View File

@ -13,7 +13,7 @@ from pathlib import Path
import structlog import structlog
from . import core from . import core
from .util import VersData, RawResult, KeyManager from .util import VersData, RawResult, KeyManager, EntryWaiter
from .ctxvars import proxy as ctx_proxy from .ctxvars import proxy as ctx_proxy
logger = structlog.get_logger(logger_name=__name__) logger = structlog.get_logger(logger_name=__name__)
@ -58,10 +58,12 @@ def main() -> None:
options.httplib, options.httplib,
options.http_timeout, options.http_timeout,
) )
entry_waiter = EntryWaiter()
try: try:
futures = dispatcher.dispatch( futures = dispatcher.dispatch(
entries, task_sem, result_q, entries, task_sem, result_q,
keymanager, args.tries, keymanager, entry_waiter,
args.tries,
options.source_configs, options.source_configs,
) )
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
@ -71,7 +73,7 @@ def main() -> None:
oldvers = core.read_verfile(options.ver_files[0]) oldvers = core.read_verfile(options.ver_files[0])
else: else:
oldvers = {} oldvers = {}
result_coro = core.process_result(oldvers, result_q) result_coro = core.process_result(oldvers, result_q, entry_waiter)
runner_coro = core.run_tasks(futures) runner_coro = core.run_tasks(futures)
# asyncio.run doesn't work because it always creates new eventloops # asyncio.run doesn't work because it always creates new eventloops

View File

@ -4,7 +4,7 @@
from .httpclient import session, TemporaryError, HTTPError from .httpclient import session, TemporaryError, HTTPError
from .util import ( from .util import (
Entry, BaseWorker, RawResult, VersionResult, Entry, BaseWorker, RawResult, VersionResult,
AsyncCache, KeyManager, GetVersionError, AsyncCache, KeyManager, GetVersionError, EntryWaiter,
) )
from .sortversion import sort_version_keys from .sortversion import sort_version_keys
from .ctxvars import tries, proxy, user_agent from .ctxvars import tries, proxy, user_agent, entry_waiter

View File

@ -29,11 +29,12 @@ from . import slogconf
from .util import ( from .util import (
Entry, Entries, KeyManager, RawResult, Result, VersData, Entry, Entries, KeyManager, RawResult, Result, VersData,
FunctionWorker, GetVersionError, FunctionWorker, GetVersionError,
FileLoadError, FileLoadError, EntryWaiter,
) )
from . import __version__ from . import __version__
from .sortversion import sort_version_keys from .sortversion import sort_version_keys
from .ctxvars import tries as ctx_tries from .ctxvars import tries as ctx_tries
from .ctxvars import entry_waiter as ctx_entry_waiter
from . import httpclient from . import httpclient
logger = structlog.get_logger(logger_name=__name__) logger = structlog.get_logger(logger_name=__name__)
@ -219,11 +220,13 @@ class Dispatcher:
task_sem: asyncio.Semaphore, task_sem: asyncio.Semaphore,
result_q: Queue[RawResult], result_q: Queue[RawResult],
keymanager: KeyManager, keymanager: KeyManager,
entry_waiter: EntryWaiter,
tries: int, tries: int,
source_configs: Dict[str, Dict[str, Any]], source_configs: Dict[str, Dict[str, Any]],
) -> List[asyncio.Future]: ) -> List[asyncio.Future]:
mods: Dict[str, Tuple[types.ModuleType, List]] = {} mods: Dict[str, Tuple[types.ModuleType, List]] = {}
ctx_tries.set(tries) ctx_tries.set(tries)
ctx_entry_waiter.set(entry_waiter)
root_ctx = contextvars.copy_context() root_ctx = contextvars.copy_context()
for name, entry in entries.items(): for name, entry in entries.items():
@ -311,7 +314,7 @@ def apply_list_options(
return versions[-1] return versions[-1]
def _process_result(r: RawResult) -> Optional[Result]: def _process_result(r: RawResult) -> Union[Result, Exception]:
version = r.version version = r.version
conf = r.conf conf = r.conf
name = r.name name = r.name
@ -320,11 +323,11 @@ def _process_result(r: RawResult) -> Optional[Result]:
kw = version.kwargs kw = version.kwargs
kw['name'] = name kw['name'] = name
logger.error(version.msg, **kw) logger.error(version.msg, **kw)
return None return version
elif isinstance(version, Exception): elif isinstance(version, Exception):
logger.error('unexpected error happened', logger.error('unexpected error happened',
name=r.name, exc_info=r.version) name=r.name, exc_info=r.version)
return None return version
elif isinstance(version, list): elif isinstance(version, list):
version_str = apply_list_options(version, conf) version_str = apply_list_options(version, conf)
else: else:
@ -336,10 +339,12 @@ def _process_result(r: RawResult) -> Optional[Result]:
try: try:
version_str = substitute_version(version_str, conf) version_str = substitute_version(version_str, conf)
return Result(name, version_str, conf) return Result(name, version_str, conf)
except (ValueError, re.error): except (ValueError, re.error) as e:
logger.exception('error occurred in version substitutions', name=name) logger.exception('error occurred in version substitutions', name=name)
return e
return None else:
return ValueError('no version returned')
def check_version_update( def check_version_update(
oldvers: VersData, name: str, version: str, oldvers: VersData, name: str, version: str,
@ -353,15 +358,18 @@ def check_version_update(
async def process_result( async def process_result(
oldvers: VersData, oldvers: VersData,
result_q: Queue[RawResult], result_q: Queue[RawResult],
entry_waiter: EntryWaiter,
) -> VersData: ) -> VersData:
ret = {} ret = {}
try: try:
while True: while True:
r = await result_q.get() r = await result_q.get()
r1 = _process_result(r) r1 = _process_result(r)
if r1 is None: if isinstance(r1, Exception):
entry_waiter.set_exception(r.name, r1)
continue continue
check_version_update(oldvers, r1.name, r1.version) check_version_update(oldvers, r1.name, r1.version)
entry_waiter.set_result(r1.name, r1.version)
ret[r1.name] = r1.version ret[r1.name] = r1.version
except asyncio.CancelledError: except asyncio.CancelledError:
return ret return ret

View File

@ -4,12 +4,16 @@
from __future__ import annotations from __future__ import annotations
from contextvars import ContextVar from contextvars import ContextVar
from typing import Optional from typing import Optional, TYPE_CHECKING
from . import __version__ from . import __version__
DEFAULT_USER_AGENT = f'lilydjwg/nvchecker {__version__}' DEFAULT_USER_AGENT = f'lilydjwg/nvchecker {__version__}'
if TYPE_CHECKING:
from .util import EntryWaiter
tries = ContextVar('tries', default=1) tries = ContextVar('tries', default=1)
proxy: ContextVar[Optional[str]] = ContextVar('proxy', default=None) proxy: ContextVar[Optional[str]] = ContextVar('proxy', default=None)
user_agent = ContextVar('user_agent', default=DEFAULT_USER_AGENT) user_agent = ContextVar('user_agent', default=DEFAULT_USER_AGENT)
entry_waiter: ContextVar[EntryWaiter] = ContextVar('entry_waiter')

View File

@ -65,6 +65,28 @@ class KeyManager:
'''Get the named key (token) in the keyfile.''' '''Get the named key (token) in the keyfile.'''
return self.keys.get(name) return self.keys.get(name)
class EntryWaiter:
def __init__(self) -> None:
self._waiting: Dict[str, asyncio.Future] = {}
async def wait(self, name: str) -> str:
'''Wait on the ``name`` entry and return its result (the version string)'''
fu = self._waiting.get(name)
if fu is None:
fu = asyncio.Future()
self._waiting[name] = fu
return await fu
def set_result(self, name: str, value: str) -> None:
fu = self._waiting.get(name)
if fu is not None:
fu.set_result(value)
def set_exception(self, name: str, e: Exception) -> None:
fu = self._waiting.get(name)
if fu is not None:
fu.set_exception(e)
class RawResult(NamedTuple): class RawResult(NamedTuple):
'''The unprocessed result from a check.''' '''The unprocessed result from a check.'''
name: str name: str

View File

@ -0,0 +1,21 @@
# MIT licensed
# Copyright (c) 2021 lilydjwg <lilydjwg@gmail.com>, et al.
import asyncio
import string
from nvchecker.api import entry_waiter
class CombineFormat(string.Template):
idpattern = '[0-9]+'
async def get_version(
name, conf, *, cache, keymanager=None
):
t = CombineFormat(conf['format'])
from_ = conf['from']
waiter = entry_waiter.get()
entries = [waiter.wait(name) for name in from_]
vers = await asyncio.gather(*entries)
versdict = {str(i+1): v for i, v in enumerate(vers)}
return t.substitute(versdict)

View File

@ -28,13 +28,14 @@ async def run(
keymanager = core.KeyManager(None) keymanager = core.KeyManager(None)
dispatcher = core.setup_httpclient() dispatcher = core.setup_httpclient()
entry_waiter = core.EntryWaiter()
futures = dispatcher.dispatch( futures = dispatcher.dispatch(
entries, task_sem, result_q, entries, task_sem, result_q,
keymanager, 1, {}, keymanager, entry_waiter, 1, {},
) )
oldvers: VersData = {} oldvers: VersData = {}
result_coro = core.process_result(oldvers, result_q) result_coro = core.process_result(oldvers, result_q, entry_waiter)
runner_coro = core.run_tasks(futures) runner_coro = core.run_tasks(futures)
return await main.run(result_coro, runner_coro) return await main.run(result_coro, runner_coro)

24
tests/test_combiner.py Normal file
View File

@ -0,0 +1,24 @@
# MIT licensed
# Copyright (c) 2021 lilydjwg <lilydjwg@gmail.com>, et al.
import pytest
pytestmark = pytest.mark.asyncio
async def test_combiner(run_str_multi):
conf = r'''
[entry-1]
source = "cmd"
cmd = "echo 1"
[entry-2]
source = "cmd"
cmd = "echo 2"
[entry-3]
source = "combiner"
from = ["entry-1", "entry-2", "entry-2"]
format = "$1-$2-$3"
'''
r = await run_str_multi(conf)
assert r['entry-3'] == '1-2-2'