diff --git a/nvchecker/__main__.py b/nvchecker/__main__.py index 602e7b9..76d0c59 100755 --- a/nvchecker/__main__.py +++ b/nvchecker/__main__.py @@ -44,11 +44,11 @@ def main() -> None: if options.proxy is not None: ctx_proxy.set(options.proxy) - token_q = core.token_queue(options.max_concurrency) + task_sem = asyncio.Semaphore(options.max_concurrency) result_q: asyncio.Queue[RawResult] = asyncio.Queue() try: futures = core.dispatch( - entries, token_q, result_q, + entries, task_sem, result_q, keymanager, args.tries, ) except ModuleNotFoundError as e: diff --git a/nvchecker/core.py b/nvchecker/core.py index 8b320d5..0dec714 100644 --- a/nvchecker/core.py +++ b/nvchecker/core.py @@ -193,17 +193,9 @@ def load_file( return cast(Entries, config), Options( ver_files, max_concurrency, proxy, keymanager) -def token_queue(maxsize: int) -> Queue[bool]: - token_q: Queue[bool] = Queue(maxsize=maxsize) - - for _ in range(maxsize): - token_q.put_nowait(True) - - return token_q - def dispatch( entries: Entries, - token_q: Queue[bool], + task_sem: asyncio.Semaphore, result_q: Queue[RawResult], keymanager: KeyManager, tries: int, @@ -232,7 +224,7 @@ def dispatch( ctx = root_ctx.copy() worker = ctx.run( worker_cls, - token_q, result_q, tasks, keymanager, + task_sem, result_q, tasks, keymanager, ) if worker_cls is FunctionWorker: func = mod.get_version # type: ignore diff --git a/nvchecker/util.py b/nvchecker/util.py index d35ff70..7c167af 100644 --- a/nvchecker/util.py +++ b/nvchecker/util.py @@ -5,9 +5,8 @@ from __future__ import annotations import asyncio from asyncio import Queue -import contextlib from typing import ( - Dict, Optional, List, AsyncGenerator, NamedTuple, Union, + Dict, Optional, List, NamedTuple, Union, Any, Tuple, Callable, Coroutine, Hashable, TYPE_CHECKING, ) @@ -72,10 +71,10 @@ class Result(NamedTuple): class BaseWorker: '''The base class for defining `Worker` classes for source plugins. - .. py:attribute:: token_q - :type: Queue[bool] + .. py:attribute:: task_sem + :type: asyncio.Semaphore - This is the rate-limiting queue. Workers should obtain one token before doing one unit of work. + This is the rate-limiting semaphore. Workers should acquire it while doing one unit of work. .. py:attribute:: result_q :type: Queue[RawResult] @@ -96,28 +95,16 @@ class BaseWorker: ''' def __init__( self, - token_q: Queue[bool], + task_sem: asyncio.Semaphore, result_q: Queue[RawResult], tasks: List[Tuple[str, Entry]], keymanager: KeyManager, ) -> None: - self.token_q = token_q + self.task_sem = task_sem self.result_q = result_q self.keymanager = keymanager self.tasks = tasks - @contextlib.asynccontextmanager - async def acquire_token(self) -> AsyncGenerator[None, None]: - '''A context manager to obtain a token from the `token_q` on entrance and - release it on exit.''' - token = await self.token_q.get() - logger.debug('got token') - try: - yield - finally: - await self.token_q.put(token) - logger.debug('return token') - @abc.abstractmethod async def run(self) -> None: '''Run the `tasks`. Subclasses should implement this method.''' @@ -227,7 +214,7 @@ class FunctionWorker(BaseWorker): ctx_ua.set(ua) try: - async with self.acquire_token(): + async with self.task_sem: version = await self.func( name, entry, cache = self.cache, diff --git a/nvchecker_source/apt.py b/nvchecker_source/apt.py index 147e6d1..b1bedb9 100644 --- a/nvchecker_source/apt.py +++ b/nvchecker_source/apt.py @@ -20,8 +20,6 @@ def _decompress_data(url: str, data: bytes) -> str: elif url.endswith(".gz"): import gzip data = gzip.decompress(data) - else: - raise NotImplementedError(url) return data.decode('utf-8') diff --git a/nvchecker_source/aur.py b/nvchecker_source/aur.py index 29c87da..6373cbe 100644 --- a/nvchecker_source/aur.py +++ b/nvchecker_source/aur.py @@ -67,7 +67,7 @@ class Worker(BaseWorker): ) -> None: task_by_name: Dict[str, Entry] = dict(self.tasks) - async with self.acquire_token(): + async with self.task_sem: results = await _run_batch_impl(batch, aur_results) for name, version in results.items(): r = RawResult(name, version, task_by_name[name]) diff --git a/nvchecker_source/none.py b/nvchecker_source/none.py index 5427acd..20bf3fd 100644 --- a/nvchecker_source/none.py +++ b/nvchecker_source/none.py @@ -10,7 +10,7 @@ from nvchecker.api import ( class Worker(BaseWorker): async def run(self) -> None: exc = GetVersionError('no source specified') - async with self.acquire_token(): + async with self.task_sem: for name, conf in self.tasks: await self.result_q.put( RawResult(name, exc, conf)) diff --git a/tests/conftest.py b/tests/conftest.py index b99e69d..741d989 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ use_keyfile = False async def run( entries: Entries, max_concurrency: int = 20, ) -> VersData: - token_q = core.token_queue(max_concurrency) + task_sem = asyncio.Semaphore(max_concurrency) result_q: asyncio.Queue[RawResult] = asyncio.Queue() keyfile = os.environ.get('KEYFILE') if use_keyfile and keyfile: @@ -28,7 +28,7 @@ async def run( keymanager = core.KeyManager(None) futures = core.dispatch( - entries, token_q, result_q, + entries, task_sem, result_q, keymanager, 1, )