mirror of
https://github.com/lilydjwg/nvchecker
synced 2024-12-24 15:42:46 +00:00
add a combiner source and the underlying mechanism to wait for other entries' results
This commit is contained in:
parent
d83d8d5367
commit
ae506ba9cf
@ -18,3 +18,8 @@
|
||||
.. autodata:: nvchecker.api.proxy
|
||||
.. autodata:: nvchecker.api.user_agent
|
||||
.. autodata:: nvchecker.api.tries
|
||||
|
||||
.. py:data:: nvchecker.api.entry_waiter
|
||||
:type: contextvars.ContextVar
|
||||
|
||||
This :class:`ContextVar <contextvars.ContextVar>` contains an :class:`EntryWaiter <nvchecker.api.EntryWaiter>` instance for waiting on other entries.
|
||||
|
@ -13,7 +13,7 @@ from pathlib import Path
|
||||
import structlog
|
||||
|
||||
from . import core
|
||||
from .util import VersData, RawResult, KeyManager
|
||||
from .util import VersData, RawResult, KeyManager, EntryWaiter
|
||||
from .ctxvars import proxy as ctx_proxy
|
||||
|
||||
logger = structlog.get_logger(logger_name=__name__)
|
||||
@ -58,10 +58,12 @@ def main() -> None:
|
||||
options.httplib,
|
||||
options.http_timeout,
|
||||
)
|
||||
entry_waiter = EntryWaiter()
|
||||
try:
|
||||
futures = dispatcher.dispatch(
|
||||
entries, task_sem, result_q,
|
||||
keymanager, args.tries,
|
||||
keymanager, entry_waiter,
|
||||
args.tries,
|
||||
options.source_configs,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
@ -71,7 +73,7 @@ def main() -> None:
|
||||
oldvers = core.read_verfile(options.ver_files[0])
|
||||
else:
|
||||
oldvers = {}
|
||||
result_coro = core.process_result(oldvers, result_q)
|
||||
result_coro = core.process_result(oldvers, result_q, entry_waiter)
|
||||
runner_coro = core.run_tasks(futures)
|
||||
|
||||
# asyncio.run doesn't work because it always creates new eventloops
|
||||
|
@ -4,7 +4,7 @@
|
||||
from .httpclient import session, TemporaryError, HTTPError
|
||||
from .util import (
|
||||
Entry, BaseWorker, RawResult, VersionResult,
|
||||
AsyncCache, KeyManager, GetVersionError,
|
||||
AsyncCache, KeyManager, GetVersionError, EntryWaiter,
|
||||
)
|
||||
from .sortversion import sort_version_keys
|
||||
from .ctxvars import tries, proxy, user_agent
|
||||
from .ctxvars import tries, proxy, user_agent, entry_waiter
|
||||
|
@ -29,11 +29,12 @@ from . import slogconf
|
||||
from .util import (
|
||||
Entry, Entries, KeyManager, RawResult, Result, VersData,
|
||||
FunctionWorker, GetVersionError,
|
||||
FileLoadError,
|
||||
FileLoadError, EntryWaiter,
|
||||
)
|
||||
from . import __version__
|
||||
from .sortversion import sort_version_keys
|
||||
from .ctxvars import tries as ctx_tries
|
||||
from .ctxvars import entry_waiter as ctx_entry_waiter
|
||||
from . import httpclient
|
||||
|
||||
logger = structlog.get_logger(logger_name=__name__)
|
||||
@ -219,11 +220,13 @@ class Dispatcher:
|
||||
task_sem: asyncio.Semaphore,
|
||||
result_q: Queue[RawResult],
|
||||
keymanager: KeyManager,
|
||||
entry_waiter: EntryWaiter,
|
||||
tries: int,
|
||||
source_configs: Dict[str, Dict[str, Any]],
|
||||
) -> List[asyncio.Future]:
|
||||
mods: Dict[str, Tuple[types.ModuleType, List]] = {}
|
||||
ctx_tries.set(tries)
|
||||
ctx_entry_waiter.set(entry_waiter)
|
||||
root_ctx = contextvars.copy_context()
|
||||
|
||||
for name, entry in entries.items():
|
||||
@ -311,7 +314,7 @@ def apply_list_options(
|
||||
|
||||
return versions[-1]
|
||||
|
||||
def _process_result(r: RawResult) -> Optional[Result]:
|
||||
def _process_result(r: RawResult) -> Union[Result, Exception]:
|
||||
version = r.version
|
||||
conf = r.conf
|
||||
name = r.name
|
||||
@ -320,11 +323,11 @@ def _process_result(r: RawResult) -> Optional[Result]:
|
||||
kw = version.kwargs
|
||||
kw['name'] = name
|
||||
logger.error(version.msg, **kw)
|
||||
return None
|
||||
return version
|
||||
elif isinstance(version, Exception):
|
||||
logger.error('unexpected error happened',
|
||||
name=r.name, exc_info=r.version)
|
||||
return None
|
||||
return version
|
||||
elif isinstance(version, list):
|
||||
version_str = apply_list_options(version, conf)
|
||||
else:
|
||||
@ -336,10 +339,12 @@ def _process_result(r: RawResult) -> Optional[Result]:
|
||||
try:
|
||||
version_str = substitute_version(version_str, conf)
|
||||
return Result(name, version_str, conf)
|
||||
except (ValueError, re.error):
|
||||
except (ValueError, re.error) as e:
|
||||
logger.exception('error occurred in version substitutions', name=name)
|
||||
return e
|
||||
|
||||
return None
|
||||
else:
|
||||
return ValueError('no version returned')
|
||||
|
||||
def check_version_update(
|
||||
oldvers: VersData, name: str, version: str,
|
||||
@ -353,15 +358,18 @@ def check_version_update(
|
||||
async def process_result(
|
||||
oldvers: VersData,
|
||||
result_q: Queue[RawResult],
|
||||
entry_waiter: EntryWaiter,
|
||||
) -> VersData:
|
||||
ret = {}
|
||||
try:
|
||||
while True:
|
||||
r = await result_q.get()
|
||||
r1 = _process_result(r)
|
||||
if r1 is None:
|
||||
if isinstance(r1, Exception):
|
||||
entry_waiter.set_exception(r.name, r1)
|
||||
continue
|
||||
check_version_update(oldvers, r1.name, r1.version)
|
||||
entry_waiter.set_result(r1.name, r1.version)
|
||||
ret[r1.name] = r1.version
|
||||
except asyncio.CancelledError:
|
||||
return ret
|
||||
|
@ -4,12 +4,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from . import __version__
|
||||
|
||||
DEFAULT_USER_AGENT = f'lilydjwg/nvchecker {__version__}'
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .util import EntryWaiter
|
||||
|
||||
tries = ContextVar('tries', default=1)
|
||||
proxy: ContextVar[Optional[str]] = ContextVar('proxy', default=None)
|
||||
user_agent = ContextVar('user_agent', default=DEFAULT_USER_AGENT)
|
||||
entry_waiter: ContextVar[EntryWaiter] = ContextVar('entry_waiter')
|
||||
|
@ -65,6 +65,28 @@ class KeyManager:
|
||||
'''Get the named key (token) in the keyfile.'''
|
||||
return self.keys.get(name)
|
||||
|
||||
class EntryWaiter:
|
||||
def __init__(self) -> None:
|
||||
self._waiting: Dict[str, asyncio.Future] = {}
|
||||
|
||||
async def wait(self, name: str) -> str:
|
||||
'''Wait on the ``name`` entry and return its result (the version string)'''
|
||||
fu = self._waiting.get(name)
|
||||
if fu is None:
|
||||
fu = asyncio.Future()
|
||||
self._waiting[name] = fu
|
||||
return await fu
|
||||
|
||||
def set_result(self, name: str, value: str) -> None:
|
||||
fu = self._waiting.get(name)
|
||||
if fu is not None:
|
||||
fu.set_result(value)
|
||||
|
||||
def set_exception(self, name: str, e: Exception) -> None:
|
||||
fu = self._waiting.get(name)
|
||||
if fu is not None:
|
||||
fu.set_exception(e)
|
||||
|
||||
class RawResult(NamedTuple):
|
||||
'''The unprocessed result from a check.'''
|
||||
name: str
|
||||
|
21
nvchecker_source/combiner.py
Normal file
21
nvchecker_source/combiner.py
Normal file
@ -0,0 +1,21 @@
|
||||
# MIT licensed
|
||||
# Copyright (c) 2021 lilydjwg <lilydjwg@gmail.com>, et al.
|
||||
|
||||
import asyncio
|
||||
import string
|
||||
|
||||
from nvchecker.api import entry_waiter
|
||||
|
||||
class CombineFormat(string.Template):
|
||||
idpattern = '[0-9]+'
|
||||
|
||||
async def get_version(
|
||||
name, conf, *, cache, keymanager=None
|
||||
):
|
||||
t = CombineFormat(conf['format'])
|
||||
from_ = conf['from']
|
||||
waiter = entry_waiter.get()
|
||||
entries = [waiter.wait(name) for name in from_]
|
||||
vers = await asyncio.gather(*entries)
|
||||
versdict = {str(i+1): v for i, v in enumerate(vers)}
|
||||
return t.substitute(versdict)
|
@ -28,13 +28,14 @@ async def run(
|
||||
keymanager = core.KeyManager(None)
|
||||
|
||||
dispatcher = core.setup_httpclient()
|
||||
entry_waiter = core.EntryWaiter()
|
||||
futures = dispatcher.dispatch(
|
||||
entries, task_sem, result_q,
|
||||
keymanager, 1, {},
|
||||
keymanager, entry_waiter, 1, {},
|
||||
)
|
||||
|
||||
oldvers: VersData = {}
|
||||
result_coro = core.process_result(oldvers, result_q)
|
||||
result_coro = core.process_result(oldvers, result_q, entry_waiter)
|
||||
runner_coro = core.run_tasks(futures)
|
||||
|
||||
return await main.run(result_coro, runner_coro)
|
||||
|
24
tests/test_combiner.py
Normal file
24
tests/test_combiner.py
Normal file
@ -0,0 +1,24 @@
|
||||
# MIT licensed
|
||||
# Copyright (c) 2021 lilydjwg <lilydjwg@gmail.com>, et al.
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
async def test_combiner(run_str_multi):
|
||||
conf = r'''
|
||||
[entry-1]
|
||||
source = "cmd"
|
||||
cmd = "echo 1"
|
||||
|
||||
[entry-2]
|
||||
source = "cmd"
|
||||
cmd = "echo 2"
|
||||
|
||||
[entry-3]
|
||||
source = "combiner"
|
||||
from = ["entry-1", "entry-2", "entry-2"]
|
||||
format = "$1-$2-$3"
|
||||
'''
|
||||
|
||||
r = await run_str_multi(conf)
|
||||
assert r['entry-3'] == '1-2-2'
|
Loading…
Reference in New Issue
Block a user