expose HTTPError, and raise it if status >= 400 for aiohttp

This commit is contained in:
lilydjwg 2018-05-08 18:34:32 +08:00
parent c86f8820b7
commit 6c15ee8517
3 changed files with 16 additions and 4 deletions

View File

@ -5,14 +5,23 @@ import atexit
import aiohttp import aiohttp
connector = aiohttp.TCPConnector(limit=20) connector = aiohttp.TCPConnector(limit=20)
__all__ = ['session'] __all__ = ['session', 'HTTPError']
class HTTPError(Exception):
def __init__(self, code, message):
self.code = code
self.message = message
class BetterClientSession(aiohttp.ClientSession): class BetterClientSession(aiohttp.ClientSession):
async def _request(self, *args, **kwargs): async def _request(self, *args, **kwargs):
if hasattr(self, "nv_config") and self.nv_config.get("proxy"): if hasattr(self, "nv_config") and self.nv_config.get("proxy"):
kwargs.setdefault("proxy", self.nv_config.get("proxy")) kwargs.setdefault("proxy", self.nv_config.get("proxy"))
return await super(BetterClientSession, self)._request(*args, **kwargs) res = await super(BetterClientSession, self)._request(
*args, **kwargs)
if res.status >= 400:
raise HTTPError(res.status, res.reason)
return res
session = BetterClientSession(connector=connector, read_timeout=10, conn_timeout=5) session = BetterClientSession(connector=connector, read_timeout=10, conn_timeout=5)
atexit.register(session.close) atexit.register(session.close)

View File

@ -5,6 +5,7 @@ import json
from urllib.parse import urlencode from urllib.parse import urlencode
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPResponse from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPResponse
from tornado.httpclient import HTTPError
from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future from tornado.platform.asyncio import AsyncIOMainLoop, to_asyncio_future
AsyncIOMainLoop().install() AsyncIOMainLoop().install()
@ -14,7 +15,7 @@ try:
except ImportError: except ImportError:
pycurl = None pycurl = None
__all__ = ['session'] __all__ = ['session', 'HTTPError']
client = AsyncHTTPClient() client = AsyncHTTPClient()
HTTP2_AVAILABLE = None if pycurl else False HTTP2_AVAILABLE = None if pycurl else False

View File

@ -3,6 +3,8 @@
import tempfile import tempfile
from nvchecker.source import HTTPError
import pytest import pytest
pytestmark = [pytest.mark.asyncio] pytestmark = [pytest.mark.asyncio]
@ -31,5 +33,5 @@ keyfile = {f.name}
try: try:
await run_source(test_conf) await run_source(test_conf)
except Exception as e: except HTTPError as e:
assert e.code == 401 assert e.code == 401