cephadm: refactor call() using asyncio.StreamReader

simpler this way; it also fixes a couple of issues:

* create a child watcher explicitly, see
  https://bugs.python.org/issue35621
* collect output in a StringIO for better performance,
  instead of repeatedly appending lines to an existing str
* catch ValueError when reading from the stream reader:
  StreamReader.readline() can raise ValueError when it hits
  the buffer limit while looking for a separator. In that case
  we should try again, in the hope that the spawned process
  feeds the reader more data containing the separator
  (i.e., b'\n'); an illustrative sketch follows this list.
* backport ThreadedChildWatcher from Python 3.8 so we can
  run create_subprocess_exec() in non-main threads.
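
For illustration only (this helper is hypothetical and is not part of the diff
below): a minimal sketch of the retry described in the third bullet, assuming a
plain readline() loop over an already-connected asyncio.StreamReader.

import asyncio
from io import StringIO

async def read_all_lines(reader: asyncio.StreamReader) -> str:
    # Hypothetical helper; illustrates the bullet above, not this commit's code.
    collected = StringIO()
    while True:
        try:
            line = await reader.readline()
        except ValueError:
            # readline() hit its buffer limit before seeing b'\n';
            # try again and hope the child sends the rest of the line.
            continue
        if not line:
            break  # EOF: the child closed its end of the pipe
        collected.write(line.decode('utf-8'))
    return collected.getvalue()

The commit itself funnels subprocess output through the tee() coroutine added
in the call() hunk further down.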

Signed-off-by: Kefu Chai <kchai@redhat.com>
Author: Kefu Chai <kchai@redhat.com>
Date:   2021-01-24 14:58:51 +08:00
parent 24e858ec31
commit 30070be248


@@ -38,6 +38,7 @@ You can invoke cephadm in two ways:
     injected_stdin = '...'
 """
 import asyncio
+import asyncio.subprocess
 import argparse
 import datetime
 import fcntl
@@ -64,17 +65,17 @@ from socketserver import ThreadingMixIn
 from http.server import BaseHTTPRequestHandler, HTTPServer
 import signal
 import io
-from contextlib import closing, redirect_stdout
+from contextlib import redirect_stdout
 import ssl
 from enum import Enum
-from typing import cast, Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO
+from typing import Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO
 import re
 import uuid
-from functools import partial, wraps
+from functools import wraps
 from glob import glob
 from threading import Thread, RLock
@@ -1185,38 +1186,108 @@ class CallVerbosity(Enum):
     VERBOSE = 3
 
 
-class StreamReaderProto(asyncio.SubprocessProtocol):
-    def __init__(self,
-                 exited: asyncio.Future,
-                 desc: str,
-                 verbosity: CallVerbosity) -> None:
-        self.exited = exited
-        self.desc = desc
-        self.verbosity = verbosity
-        self.stdout = ''
-        self.stderr = ''
-
-    def pipe_data_received(self, fd: int, data: bytes) -> None:
-        prefix = ''
-        lines = data.decode('utf-8')
-        if fd == sys.stdout.fileno():
-            prefix = self.desc + 'stdout'
-            self.stdout += lines
-        elif fd == sys.stderr.fileno():
-            prefix = self.desc + 'stderr'
-            self.stderr += lines
-        else:
-            assert False, f"unknown data received from fd: {fd}"
-        for line in lines.split('\n'):
-            if self.verbosity == CallVerbosity.VERBOSE:
-                logger.info(prefix + line)
-            elif self.verbosity != CallVerbosity.SILENT:
-                logger.debug(prefix + line)
-
-    def process_exited(self) -> None:
-        self.exited.set_result(True)
+if sys.version_info < (3, 8):
+    import itertools
+    import threading
+    import warnings
+    from asyncio import events
+
+    class ThreadedChildWatcher(asyncio.AbstractChildWatcher):
+        """Threaded child watcher implementation.
+
+        The watcher uses a thread per process
+        for waiting for the process finish.
+
+        It doesn't require subscription on POSIX signal
+        but a thread creation is not free.
+
+        The watcher has O(1) complexity, its performance doesn't depend
+        on amount of spawn processes.
+        """
+
+        def __init__(self):
+            self._pid_counter = itertools.count(0)
+            self._threads = {}
+
+        def is_active(self):
+            return True
+
+        def close(self):
+            self._join_threads()
+
+        def _join_threads(self):
+            """Internal: Join all non-daemon threads"""
+            threads = [thread for thread in list(self._threads.values())
+                       if thread.is_alive() and not thread.daemon]
+            for thread in threads:
+                thread.join()
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            pass
+
+        def __del__(self, _warn=warnings.warn):
+            threads = [thread for thread in list(self._threads.values())
+                       if thread.is_alive()]
+            if threads:
+                _warn(f"{self.__class__} has registered but not finished child processes",
+                      ResourceWarning,
+                      source=self)
+
+        def add_child_handler(self, pid, callback, *args):
+            loop = events.get_event_loop()
+            thread = threading.Thread(target=self._do_waitpid,
+                                      name=f"waitpid-{next(self._pid_counter)}",
+                                      args=(loop, pid, callback, args),
+                                      daemon=True)
+            self._threads[pid] = thread
+            thread.start()
+
+        def remove_child_handler(self, pid):
+            # asyncio never calls remove_child_handler() !!!
+            # The method is no-op but is implemented because
+            # abstract base classe requires it
+            return True
+
+        def attach_loop(self, loop):
+            pass
+
+        def _do_waitpid(self, loop, expected_pid, callback, args):
+            assert expected_pid > 0
+
+            try:
+                pid, status = os.waitpid(expected_pid, 0)
+            except ChildProcessError:
+                # The child process is already reaped
+                # (may happen if waitpid() is called elsewhere).
+                pid = expected_pid
+                returncode = 255
+                logger.warning(
+                    "Unknown child process pid %d, will report returncode 255",
+                    pid)
+            else:
+                if os.WIFEXITED(status):
+                    returncode = os.WEXITSTATUS(status)
+                elif os.WIFSIGNALED(status):
+                    returncode = -os.WTERMSIG(status)
+                else:
+                    raise ValueError(f'unknown wait status {status}')
+                if loop.get_debug():
+                    logger.debug('process %s exited with returncode %s',
+                                 expected_pid, returncode)
+
+            if loop.is_closed():
+                logger.warning("Loop %r that handles pid %r is closed", loop, pid)
+            else:
+                loop.call_soon_threadsafe(callback, pid, returncode, *args)
+
+            self._threads.pop(expected_pid)
+
+    # unlike SafeChildWatcher which handles SIGCHLD in the main thread,
+    # ThreadedChildWatcher runs in a separated thread, hence allows us to
+    # run create_subprocess_exec() in non-main thread, see
+    # https://bugs.python.org/issue35621
+    asyncio.set_child_watcher(ThreadedChildWatcher())
 
 
 try:
     from asyncio import run as async_run  # type: ignore[attr-defined]
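
For context only (not part of this commit): a minimal sketch of what the
backported watcher enables, assuming Python < 3.8 and that the
asyncio.set_child_watcher(ThreadedChildWatcher()) call above has already run at
import time. Without it, the default SafeChildWatcher typically refuses to
spawn children from an event loop running outside the main thread (see
https://bugs.python.org/issue35621).

import asyncio
import sys
import threading

def spawn_from_worker_thread() -> None:
    # Run an event loop in a non-main thread and spawn a subprocess,
    # mirroring what cephadm's call() relies on.
    async def run_child() -> int:
        proc = await asyncio.create_subprocess_exec(
            sys.executable, '-c', 'pass',
            stdout=asyncio.subprocess.DEVNULL)
        return await proc.wait()

    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        print('returncode:', loop.run_until_complete(run_child()))
    finally:
        asyncio.set_event_loop(None)
        loop.close()

t = threading.Thread(target=spawn_from_worker_thread)
t.start()
t.join()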
@@ -1227,8 +1298,11 @@ except ImportError:
         asyncio.set_event_loop(loop)
         return loop.run_until_complete(coro)
     finally:
-        asyncio.set_event_loop(None)
-        loop.close()
+        try:
+            loop.run_until_complete(loop.shutdown_asyncgens())
+        finally:
+            asyncio.set_event_loop(None)
+            loop.close()
 
 
 def call(ctx: CephadmContext,
          command: List[str],
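
Putting the hunk above in context: a sketch of the complete async_run()
fallback after this change, assuming the unchanged lines the hunk does not show
(the def line and the loop creation) and an `import asyncio` at module scope.
It mirrors what asyncio.run() does on Python >= 3.7: run the coroutine on a
fresh loop, then shut down async generators before closing the loop.

try:
    from asyncio import run as async_run  # type: ignore[attr-defined]
except ImportError:
    def async_run(coro):  # type: ignore
        # Fallback for Python < 3.7, where asyncio.run() does not exist.
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coro)
        finally:
            try:
                # flush pending async generators, as asyncio.run() does
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                asyncio.set_event_loop(None)
                loop.close()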
@@ -1253,47 +1327,43 @@ def call(ctx: CephadmContext,
     logger.debug("Running command: %s" % ' '.join(command))
 
-    async def run_with_timeout():
-        loop = asyncio.get_event_loop()
-        proc_exited = loop.create_future()
-        protocol_factory = partial(StreamReaderProto,
-                                   proc_exited,
-                                   prefix, verbosity)
-        transport, protocol = await loop.subprocess_exec(
-            protocol_factory,
+    async def tee(reader: asyncio.StreamReader) -> str:
+        collected = StringIO()
+        async for line in reader:
+            message = line.decode('utf-8')
+            collected.write(message)
+            if verbosity == CallVerbosity.VERBOSE:
+                logger.info(prefix + message.rstrip())
+            elif verbosity != CallVerbosity.SILENT:
+                logger.debug(prefix + message.rstrip())
+        return collected.getvalue()
+
+    async def run_with_timeout() -> Tuple[str, str, int]:
+        process = await asyncio.create_subprocess_exec(
             *command,
-            close_fds=True,
-            **kwargs)
-        proc_transport = cast(asyncio.SubprocessTransport, transport)
-        proc_protocol = cast(StreamReaderProto, protocol)
-        returncode = 0
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.PIPE)
+        assert process.stdout
+        assert process.stderr
         try:
-            if timeout:
-                await asyncio.wait_for(proc_exited, timeout)
-            else:
-                await proc_exited
+            stdout, stderr = await asyncio.gather(tee(process.stdout),
+                                                  tee(process.stderr))
+            returncode = await asyncio.wait_for(process.wait(), timeout)
         except asyncio.TimeoutError:
             logger.info(prefix + f'timeout after {timeout} seconds')
-            returncode = 124
+            return '', '', 124
         else:
-            returncode = cast(int, proc_transport.get_returncode())
-        finally:
-            proc_transport.close()
-        return (returncode,
-                proc_protocol.stdout,
-                proc_protocol.stderr)
+            return stdout, stderr, returncode
 
-    returncode, out, err = async_run(run_with_timeout())
+    stdout, stderr, returncode = async_run(run_with_timeout())
 
     if returncode != 0 and verbosity == CallVerbosity.VERBOSE_ON_FAILURE:
         # dump stdout + stderr
         logger.info('Non-zero exit code %d from %s',
                     returncode, ' '.join(command))
-        for line in out.splitlines():
+        for line in stdout.splitlines():
             logger.info(prefix + 'stdout ' + line)
-        for line in err.splitlines():
+        for line in stderr.splitlines():
             logger.info(prefix + 'stderr ' + line)
 
-    return out, err, returncode
+    return stdout, stderr, returncode
 
 
 def call_throws(
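
Finally, a hypothetical call site, only to show the tuple ordering that the
refactor preserves (stdout, stderr, returncode); the full signature of call()
(timeout, verbosity, etc.) is not visible in this diff.

out, err, returncode = call(ctx, ['ls', '/var/lib/ceph'])
if returncode != 0:
    logger.error('ls failed with %d: %s', returncode, err.strip())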