use asyncio.Semaphore instead of self-made queue
This commit is contained in:
parent
f4983eaea3
commit
185a7e88a9
|
@ -44,11 +44,11 @@ def main() -> None:
|
|||
if options.proxy is not None:
|
||||
ctx_proxy.set(options.proxy)
|
||||
|
||||
token_q = core.token_queue(options.max_concurrency)
|
||||
task_sem = asyncio.Semaphore(options.max_concurrency)
|
||||
result_q: asyncio.Queue[RawResult] = asyncio.Queue()
|
||||
try:
|
||||
futures = core.dispatch(
|
||||
entries, token_q, result_q,
|
||||
entries, task_sem, result_q,
|
||||
keymanager, args.tries,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
|
|
|
@ -193,17 +193,9 @@ def load_file(
|
|||
return cast(Entries, config), Options(
|
||||
ver_files, max_concurrency, proxy, keymanager)
|
||||
|
||||
def token_queue(maxsize: int) -> Queue[bool]:
|
||||
token_q: Queue[bool] = Queue(maxsize=maxsize)
|
||||
|
||||
for _ in range(maxsize):
|
||||
token_q.put_nowait(True)
|
||||
|
||||
return token_q
|
||||
|
||||
def dispatch(
|
||||
entries: Entries,
|
||||
token_q: Queue[bool],
|
||||
task_sem: asyncio.Semaphore,
|
||||
result_q: Queue[RawResult],
|
||||
keymanager: KeyManager,
|
||||
tries: int,
|
||||
|
@ -232,7 +224,7 @@ def dispatch(
|
|||
ctx = root_ctx.copy()
|
||||
worker = ctx.run(
|
||||
worker_cls,
|
||||
token_q, result_q, tasks, keymanager,
|
||||
task_sem, result_q, tasks, keymanager,
|
||||
)
|
||||
if worker_cls is FunctionWorker:
|
||||
func = mod.get_version # type: ignore
|
||||
|
|
|
@ -5,9 +5,8 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from asyncio import Queue
|
||||
import contextlib
|
||||
from typing import (
|
||||
Dict, Optional, List, AsyncGenerator, NamedTuple, Union,
|
||||
Dict, Optional, List, NamedTuple, Union,
|
||||
Any, Tuple, Callable, Coroutine, Hashable,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
@ -72,10 +71,10 @@ class Result(NamedTuple):
|
|||
class BaseWorker:
|
||||
'''The base class for defining `Worker` classes for source plugins.
|
||||
|
||||
.. py:attribute:: token_q
|
||||
:type: Queue[bool]
|
||||
.. py:attribute:: task_sem
|
||||
:type: asyncio.Semaphore
|
||||
|
||||
This is the rate-limiting queue. Workers should obtain one token before doing one unit of work.
|
||||
This is the rate-limiting semaphore. Workers should acquire it while doing one unit of work.
|
||||
|
||||
.. py:attribute:: result_q
|
||||
:type: Queue[RawResult]
|
||||
|
@ -96,28 +95,16 @@ class BaseWorker:
|
|||
'''
|
||||
def __init__(
|
||||
self,
|
||||
token_q: Queue[bool],
|
||||
task_sem: asyncio.Semaphore,
|
||||
result_q: Queue[RawResult],
|
||||
tasks: List[Tuple[str, Entry]],
|
||||
keymanager: KeyManager,
|
||||
) -> None:
|
||||
self.token_q = token_q
|
||||
self.task_sem = task_sem
|
||||
self.result_q = result_q
|
||||
self.keymanager = keymanager
|
||||
self.tasks = tasks
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def acquire_token(self) -> AsyncGenerator[None, None]:
|
||||
'''A context manager to obtain a token from the `token_q` on entrance and
|
||||
release it on exit.'''
|
||||
token = await self.token_q.get()
|
||||
logger.debug('got token')
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await self.token_q.put(token)
|
||||
logger.debug('return token')
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(self) -> None:
|
||||
'''Run the `tasks`. Subclasses should implement this method.'''
|
||||
|
@ -227,7 +214,7 @@ class FunctionWorker(BaseWorker):
|
|||
ctx_ua.set(ua)
|
||||
|
||||
try:
|
||||
async with self.acquire_token():
|
||||
async with self.task_sem:
|
||||
version = await self.func(
|
||||
name, entry,
|
||||
cache = self.cache,
|
||||
|
|
|
@ -20,8 +20,6 @@ def _decompress_data(url: str, data: bytes) -> str:
|
|||
elif url.endswith(".gz"):
|
||||
import gzip
|
||||
data = gzip.decompress(data)
|
||||
else:
|
||||
raise NotImplementedError(url)
|
||||
|
||||
return data.decode('utf-8')
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ class Worker(BaseWorker):
|
|||
) -> None:
|
||||
task_by_name: Dict[str, Entry] = dict(self.tasks)
|
||||
|
||||
async with self.acquire_token():
|
||||
async with self.task_sem:
|
||||
results = await _run_batch_impl(batch, aur_results)
|
||||
for name, version in results.items():
|
||||
r = RawResult(name, version, task_by_name[name])
|
||||
|
|
|
@ -10,7 +10,7 @@ from nvchecker.api import (
|
|||
class Worker(BaseWorker):
|
||||
async def run(self) -> None:
|
||||
exc = GetVersionError('no source specified')
|
||||
async with self.acquire_token():
|
||||
async with self.task_sem:
|
||||
for name, conf in self.tasks:
|
||||
await self.result_q.put(
|
||||
RawResult(name, exc, conf))
|
||||
|
|
|
@ -18,7 +18,7 @@ use_keyfile = False
|
|||
async def run(
|
||||
entries: Entries, max_concurrency: int = 20,
|
||||
) -> VersData:
|
||||
token_q = core.token_queue(max_concurrency)
|
||||
task_sem = asyncio.Semaphore(max_concurrency)
|
||||
result_q: asyncio.Queue[RawResult] = asyncio.Queue()
|
||||
keyfile = os.environ.get('KEYFILE')
|
||||
if use_keyfile and keyfile:
|
||||
|
@ -28,7 +28,7 @@ async def run(
|
|||
keymanager = core.KeyManager(None)
|
||||
|
||||
futures = core.dispatch(
|
||||
entries, token_q, result_q,
|
||||
entries, task_sem, result_q,
|
||||
keymanager, 1,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue