"""Source code for tests.api.boilerplate"""
import asyncio
import functools
import json
from logging import Logger, StreamHandler
from os import environ as env
from pathlib import Path
from typing import Any, Awaitable, Callable, NoReturn, Optional
import pytest
from aiohttp import ClientConnectionError, ClientPayloadError, ClientSession
from novelai_api import NovelAIAPI, NovelAIError
from novelai_api.utils import get_encryption_key
class API:
    """
    Test helper bundling the NovelAI credentials, a logger and the API client.

    Credentials are read from the ``NAI_USERNAME`` and ``NAI_PASSWORD``
    environment variables. If a ``.env`` file exists in the current working
    directory, its ``KEY=VALUE`` lines are copied into the environment first.
    ``NAI_PROXY`` is optional and, when set, is assigned to ``api.proxy``.

    The object must be used with ``async with``; entering it logs in through
    ``api.high_level.login``.
    """

    _username: str
    _password: str
    _session: ClientSession
    _sync: bool

    logger: Logger
    api: NovelAIAPI

    def __init__(self, sync: bool = False):
        """
        Load the configuration and build the API client.

        :param sync: When true, no aiohttp session is opened on ``__aenter__``
        :raises RuntimeError: If NAI_USERNAME or NAI_PASSWORD is not set
        """
        self._load_dotenv(Path(".env"))

        if "NAI_USERNAME" not in env or "NAI_PASSWORD" not in env:
            raise RuntimeError("Please ensure that NAI_USERNAME and NAI_PASSWORD are set in your environment")

        self._username = env["NAI_USERNAME"]
        self._password = env["NAI_PASSWORD"]
        self._sync = sync

        # A Logger instantiated directly (not through getLogger) has no handler,
        # so attach one explicitly or records below WARNING are dropped.
        self.logger = Logger("NovelAI")
        self.logger.addHandler(StreamHandler())

        self.api = NovelAIAPI(logger=self.logger)
        self.api.proxy = env.get("NAI_PROXY")

    @staticmethod
    def _load_dotenv(path: Path) -> None:
        """Copy the ``KEY=VALUE`` lines of *path* into the process environment, if the file exists."""
        if not path.is_file():
            return

        with path.open("r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                # blank lines, comments and lines without an assignment carry no setting
                if not line or line.startswith("#") or "=" not in line:
                    continue

                key, _, value = line.partition("=")
                key = key.strip()
                # an empty variable name cannot be stored in the environment
                if key:
                    env[key] = value.strip()

    def encryption_key(self):
        """Return the result of ``get_encryption_key`` for the configured username and password."""
        return get_encryption_key(self._username, self._password)

    def __enter__(self) -> NoReturn:
        """Reject the synchronous context manager protocol."""
        raise TypeError("Use async with instead")

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Counterpart of ``__enter__``, never reached in practice.

        The ``with`` statement requires both methods to exist before it calls
        ``__enter__``; without this one the hint raised there is never shown.
        """
        return None

    async def __aenter__(self):
        """
        Open the aiohttp session (unless in sync mode) and log in.

        :return: This object
        """
        if not self._sync:
            self._session = ClientSession()
            await self._session.__aenter__()
            # NOTE(review): hands the session to the client so requests go through it;
            # attach_session is not visible in this file — confirm against novelai-api
            self.api.attach_session(self._session)

        try:
            await self.api.high_level.login(self._username, self._password)
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises: close the session here
            if not self._sync:
                await self._session.close()
            raise

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the aiohttp session opened by ``__aenter__`` (nothing to do in sync mode)."""
        if not self._sync:
            await self._session.__aexit__(exc_type, exc_val, exc_tb)
def error_handler(func_ext: Optional[Callable[[Any, Any], Awaitable[Any]]] = None, *, attempts: int = 5, wait: int = 5):
    """
    Decorator to add error handling to the decorated function
    The function must accept an API object as first arguments

    Connection errors, timeouts, payload errors and NovelAIError with status
    502, 520 or 524 are retried. Any other NovelAIError is raised right away,
    and exceptions of other types propagate untouched.

    :param func_ext: Substitute for func if the decorator is run without argument. Do not provide it directly
    :param attempts: Number of attempts to do before raising the error
    :param wait: Time (in seconds) to wait after each call
    :return: The wrapped coroutine function, or the decorator itself when called with keyword arguments only
    :raises Exception: The last error caught once the attempts are exhausted or the error is not retryable
    """

    # 502: Bad Gateway, 520: Cloudflare Unknown Error, 524: Cloudflare timeout
    retryable_statuses = (502, 520, 524)
    # URL fetched to find out whether the network is back; the original call had an empty, unusable URL
    connectivity_probe_url = "https://www.google.com"

    def decorator(func: Callable[[Any, Any], Awaitable[Any]]):
        @functools.wraps(func)  # keep the name, docstring and signature of the decorated function
        async def wrap(*args, **kwargs):
            err: Exception = RuntimeError("Error placeholder. Shouldn't happen")

            for attempt in range(attempts):
                try:
                    res = await func(*args, **kwargs)
                    await asyncio.sleep(wait)

                    return res
                except (ClientConnectionError, asyncio.TimeoutError, ClientPayloadError) as e:
                    err = e
                    retry = True
                except NovelAIError as e:
                    err = e
                    retry = e.status in retryable_statuses

                if not retry:
                    break

                print(f"Error: {err}. Try {attempt + 1}/{attempts}")

                # nothing left to retry: raise now instead of waiting for the network
                if attempt + 1 >= attempts:
                    break

                # 10s wait between each retry
                await asyncio.sleep(10)

                # no internet: ping every 5 mins until connection is re-established
                async with ClientSession() as session:
                    while True:
                        try:
                            async with session.get(connectivity_probe_url, timeout=5 * 60):
                                break
                        except ClientConnectionError:
                            await asyncio.sleep(5 * 60)
                        except asyncio.TimeoutError:
                            # the probe itself timed out: probe again immediately
                            pass

            raise err

        return wrap

    # allow to run the function without argument
    if func_ext is None:
        return decorator

    return decorator(func_ext)
class JSONEncoder(json.JSONEncoder):
    """
    Extended JSON encoder to support bytes
    """

    def default(self, o: Any) -> Any:
        """
        Serialize binary values as their hexadecimal string.

        :param o: Object the stock encoder could not serialize
        :return: ``o.hex()`` for bytes and bytearray objects
        :raises TypeError: From the parent implementation, for any other type
        """
        if isinstance(o, (bytes, bytearray)):
            return o.hex()

        return super().default(o)
def dumps(e: Any) -> str:
    """
    Shortcut to a configuration of json.dumps for consistency

    :param e: Object to serialize
    :return: JSON text indented by 4 spaces, with non-ASCII characters kept as-is, encoded through JSONEncoder
    """
    return json.dumps(e, indent=4, ensure_ascii=False, cls=JSONEncoder)
async def api_handle():
    """
    API handle for an Async Test. Use it as a pytest fixture

    Yields an ``API`` object entered with ``async with``; the context is
    exited when the generator is resumed or closed.

    NOTE(review): no fixture decorator is visible on this function here —
    confirm how it is registered with pytest.
    """
    async with API() as api:
        yield api
async def api_handle_sync():
    """
    API handle for a Sync Test. Use it as a pytest fixture

    Yields an ``API`` object built with ``sync=True`` and entered with
    ``async with``; the context is exited when the generator is resumed or
    closed.

    NOTE(review): no fixture decorator is visible on this function here —
    confirm how it is registered with pytest.
    """
    async with API(sync=True) as api:
        yield api