Source code for tests.api.boilerplate
import asyncio
import functools
import json
from logging import Logger, StreamHandler
from os import environ as env
from pathlib import Path
from typing import Any, Awaitable, Callable, NoReturn, Optional
import pytest
from aiohttp import ClientConnectionError, ClientPayloadError, ClientSession
from novelai_api import NovelAIAPI, NovelAIError
from novelai_api.utils import get_encryption_key
class API:
    """
    Boilerplate object that builds a NovelAIAPI and logs it in for the tests

    The credentials are read from the NAI_USERNAME and NAI_PASSWORD environment variables.
    A ``.env`` file in the current working directory, if any, is loaded into the environment first.
    An optional NAI_PROXY variable is forwarded to the API object.

    The object must be used with ``async with``; a plain ``with`` raises TypeError.
    """

    _username: str
    _password: str
    _session: "ClientSession"
    _sync: bool

    logger: Logger
    api: "NovelAIAPI"

    def __init__(self, sync: bool = False):
        """
        :param sync: If True, no aiohttp session is created nor attached when entering the context

        :raises RuntimeError: If NAI_USERNAME or NAI_PASSWORD is missing from the environment
        """

        self._load_dotenv()

        if "NAI_USERNAME" not in env or "NAI_PASSWORD" not in env:
            raise RuntimeError("Please ensure that NAI_USERNAME and NAI_PASSWORD are set in your environment")

        self._username = env["NAI_USERNAME"]
        self._password = env["NAI_PASSWORD"]
        self._sync = sync

        self.logger = Logger("NovelAI")
        self.logger.addHandler(StreamHandler())

        self.api = NovelAIAPI(logger=self.logger)
        # None when no proxy is configured
        self.api.proxy = env.get("NAI_PROXY")

    @staticmethod
    def _load_dotenv(path: Path = Path(".env")) -> None:
        """
        Copy the ``KEY=VALUE`` pairs of a dotenv file into the process environment

        Blank lines, comment lines (starting with ``#``), lines without ``=`` and lines with an empty key
        are ignored. Keys and values are stripped of surrounding whitespace. Values from the file
        overwrite variables that are already set. A missing file is silently ignored.

        :param path: Path of the dotenv file
        """

        # is_file() rather than exists(): a directory called ".env" must not be opened
        if not path.is_file():
            return

        with path.open("r") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue

                # split on the first "=" only, the value may itself contain "="
                key, value = line.split("=", 1)
                key = key.strip()
                if key:
                    env[key] = value.strip()

    @property
    def encryption_key(self):
        """
        Result of get_encryption_key for the configured username and password
        """

        return get_encryption_key(self._username, self._password)

    def __enter__(self) -> NoReturn:
        raise TypeError("Use async with instead")

    async def __aenter__(self) -> "API":
        """
        Open and attach the aiohttp session (unless in sync mode), then log in

        :return: The API object itself
        """

        if not self._sync:
            self._session = ClientSession()
            await self._session.__aenter__()

        try:
            if not self._sync:
                self.api.attach_session(self._session)

            await self.api.high_level.login(self._username, self._password)
        except BaseException as e:
            # __aexit__ is not called when __aenter__ raises, so the session must be closed here
            if not self._sync:
                await self._session.__aexit__(type(e), e, e.__traceback__)
            raise

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Close the aiohttp session opened by __aenter__ (nothing to do in sync mode)
        """

        if not self._sync:
            await self._session.__aexit__(exc_type, exc_val, exc_tb)
async def _wait_for_connection() -> None:
    """
    Block until an HTTP request to https://www.google.com gets any response

    A connection error is retried after 5 minutes, a timeout is retried immediately
    (the request itself already waited up to 5 minutes).
    """

    async with ClientSession() as session:
        while True:
            try:
                # any HTTP response, whatever its status, proves that the connection is back
                async with session.get("https://www.google.com", timeout=5 * 60):
                    return
            except ClientConnectionError:
                await asyncio.sleep(5 * 60)
            except asyncio.TimeoutError:
                pass


def error_handler(func_ext: Optional[Callable[[Any, Any], Awaitable[Any]]] = None, *, attempts: int = 5, wait: int = 5):
    """
    Decorator to add error handling to the decorated function
    The function must accept an API object as first arguments

    Connection errors, timeouts, payload errors and NovelAIError with status 502, 520 or 524 are retried.
    Any other NovelAIError is raised immediately. Once all the attempts are used, the last error is raised.

    :param func_ext: Substitute for func if the decorator is run without argument. Do not provide it directly
    :param attempts: Number of attempts to do before raising the error
    :param wait: Time (in seconds) to wait after each call
    """

    retryable_statuses = (
        502,  # Bad Gateway
        520,  # Cloudflare Unknown Error
        524,  # Cloudflare Gateway Error
    )

    def decorator(func: Callable[[Any, Any], Awaitable[Any]]):
        @functools.wraps(func)
        async def wrap(*args, **kwargs):
            # only raised as is when attempts <= 0
            err: Exception = RuntimeError("Error placeholder. Shouldn't happen")

            for attempt in range(attempts):
                try:
                    res = await func(*args, **kwargs)
                except (ClientConnectionError, asyncio.TimeoutError, ClientPayloadError) as e:
                    err = e
                except NovelAIError as e:
                    err = e
                    if e.status not in retryable_statuses:
                        break
                else:
                    await asyncio.sleep(wait)
                    return res

                print(f"Error: {err}. Try {attempt + 1}/{attempts}")

                # nothing left to retry: don't delay the error with a useless wait and connectivity check
                if attempt + 1 >= attempts:
                    break

                # 10s wait between each retry
                await asyncio.sleep(10)

                # no internet: ping every 5 mins until connection is re-established
                await _wait_for_connection()

            raise err

        return wrap

    # allow to run the function without argument
    if func_ext is None:
        return decorator

    return decorator(func_ext)
class JSONEncoder(json.JSONEncoder):
    """
    JSON encoder that additionally serializes bytes objects, as their hexadecimal string
    """

    def default(self, o: Any) -> Any:
        """
        Return the hex representation of a bytes object, defer to the base encoder for anything else
        """

        if not isinstance(o, bytes):
            return super().default(o)

        return o.hex()
def dumps(e: Any) -> str:
    """
    Serialize e to JSON with the shared settings: 4-space indent, non-ASCII characters kept as is,
    and the JSONEncoder class as encoder
    """

    settings = {"indent": 4, "ensure_ascii": False, "cls": JSONEncoder}
    return json.dumps(e, **settings)
@pytest.fixture(scope="session")
async def api_handle():
    """
    Session-scoped pytest fixture yielding an entered API object built with the default arguments (Async Test)
    """

    handle = API()
    async with handle as api:
        yield api
@pytest.fixture(scope="session")
async def api_handle_sync():
    """
    Session-scoped pytest fixture yielding an entered API object built with sync=True (Sync Test)
    """

    handle = API(sync=True)
    async with handle as api:
        yield api