Source code for tests.api.test_decrypt_encrypt_integrity_check

"""
Test if the content decryption/decompression is consistent with encryption/compression for downloaded content
"""

from asyncio import run
from pathlib import Path
from subprocess import PIPE, CalledProcessError, Popen
from typing import Any, List

from novelai_api import utils
from novelai_api.utils import compress_user_data, decompress_user_data, decrypt_user_data, encrypt_user_data
from tests.api.boilerplate import API, api_handle, api_handle_sync, error_handler  # noqa: F401  # pylint: disable=W0611


def compare_in_out(type_name: str, items_in: List[Any], items_out: List[Any]):
    """
    Compare two lists item by item and raise a detailed report if they differ

    :param type_name: Label used in the report (e.g. "stories")
    :param items_in: Reference items, captured before the round-trip
    :param items_out: Items captured after the round-trip

    :raises AssertionError: If the lists differ in length or any pair of items is not equal
    """

    flags_list = []
    fail_items = []

    for i, (item_in, item_out) in enumerate(zip(items_in, items_out)):
        if item_in == item_out:
            flags_list.append("O")
        else:
            flags_list.append("X")
            fail_items.append((i, item_in, item_out))

    # zip() stops at the shorter list, so a missing/extra item would otherwise go unnoticed
    len_in, len_out = len(items_in), len(items_out)
    length_mismatch = len_in != len_out

    if fail_items or length_mismatch:
        flags = "".join(flags_list)

        errors = [f"{flags.count('X')}/{len(flags)} integrity checks failed for {type_name}", flags, ""]

        if length_mismatch:
            errors.extend((f"Length mismatch for {type_name}: {len_in} items in vs {len_out} items out", ""))

        for i, item_in, item_out in fail_items:
            # items may be any type; str() keeps the join below from raising TypeError
            errors.extend((f"Item {i} failed:", str(item_in), "=" * 20 + " vs " + "=" * 20, str(item_out), ""))

        raise AssertionError("\n".join(errors))

    print(f"All {len(flags_list)} integrity checks succeeded for {type_name}\n")


# Absolute path of the node helper script that lives next to this test file
fflate_path = Path(__file__).parent.absolute().joinpath("fflate_inflate.js")


def inflate_js(data: bytes, _) -> bytes:
    """
    Inflate ``data`` by piping it through ``node fflate_inflate.js <len(data)>``

    Used in this file as a temporary replacement for ``utils.inflate``.

    :param data: Compressed bytes, written to the script's stdin
    :param _: Unused; presumably only present to match ``utils.inflate``'s signature

    :return: The bytes the script writes to stdout

    :raises CalledProcessError: If the node process exits with a non-zero status
    """

    args = ["node", fflate_path, str(len(data))]

    with Popen(args, stdin=PIPE, stdout=PIPE) as p:
        out, _err = p.communicate(data)

    # A failing script would otherwise hand back empty/partial output that only surfaces
    # later as a confusing decryption or comparison failure
    if p.returncode != 0:
        raise CalledProcessError(p.returncode, args, output=out)

    return out


@error_handler(wait=0)
async def keystore_integrity(handle: API):
    """
    Check that re-encrypting the downloaded keystore reproduces its original ``data``
    """

    session = handle.api
    encryption_key = handle.encryption_key

    store = await session.high_level.get_keystore(encryption_key)

    before = [str(store.data)]
    store.encrypt(encryption_key)
    after = [str(store.data)]

    compare_in_out("keystore", before, after)


async def test_keystore_integrity_sync(api_handle_sync):  # noqa: F811  # pylint: disable=W0621
    """Run the keystore round-trip check with the ``api_handle_sync`` fixture"""
    await keystore_integrity(api_handle_sync)


async def test_keystore_integrity_async(api_handle):  # noqa: F811  # pylint: disable=W0621
    """Run the keystore round-trip check with the ``api_handle`` fixture"""
    await keystore_integrity(api_handle)


@error_handler(wait=0)
async def stories_integrity(handle: API):
    """
    Check that decrypting then re-encrypting the user's 'stories' objects leaves their string form unchanged
    """

    session = handle.api
    encryption_key = handle.encryption_key

    store = await session.high_level.get_keystore(encryption_key)
    stories = await session.high_level.download_user_stories()

    original = list(map(str, stories))
    decrypt_user_data(stories, store)
    encrypt_user_data(stories, store)
    roundtripped = list(map(str, stories))

    compare_in_out("stories", original, roundtripped)


async def test_stories_integrity_sync(api_handle_sync):  # noqa: F811  # pylint: disable=W0621
    """Run the stories round-trip check with the ``api_handle_sync`` fixture"""
    await stories_integrity(api_handle_sync)


async def test_stories_integrity_async(api_handle):  # noqa: F811  # pylint: disable=W0621
    """Run the stories round-trip check with the ``api_handle`` fixture"""
    await stories_integrity(api_handle)


@error_handler(wait=0)
async def storycontent_integrity(handle: API):
    """
    Verify the integrity of 'storycontent' objects on decryption - encryption

    The contents are decrypted and snapshotted, re-encrypted, then decrypted again with
    ``utils.inflate`` temporarily replaced by :func:`inflate_js` (node + fflate_inflate.js).
    The two decrypted snapshots must match.
    """

    api = handle.api
    key = handle.encryption_key

    keystore = await api.high_level.get_keystore(key)
    story_contents = await api.high_level.download_user_story_contents()

    decrypt_user_data(story_contents, keystore)
    decrypted_storycontent_in = [str(story_content) for story_content in story_contents]
    encrypt_user_data(story_contents, keystore)

    # Swap in the JS inflate for the second decryption only. Restore it in `finally` so a
    # failure here cannot leak the patched function into every test that runs afterwards
    inflate_backup = utils.inflate
    utils.inflate = inflate_js
    try:
        decrypt_user_data(story_contents, keystore)
    finally:
        utils.inflate = inflate_backup

    decrypted_storycontent_out = [str(story_content) for story_content in story_contents]

    compare_in_out("storycontent", decrypted_storycontent_in, decrypted_storycontent_out)


async def test_storycontent_integrity_sync(api_handle_sync):  # noqa: F811  # pylint: disable=W0621
    """Run the storycontent round-trip check with the ``api_handle_sync`` fixture"""
    await storycontent_integrity(api_handle_sync)


async def test_storycontent_integrity_async(api_handle):  # noqa: F811  # pylint: disable=W0621
    """Run the storycontent round-trip check with the ``api_handle`` fixture"""
    await storycontent_integrity(api_handle)


@error_handler(wait=0)
async def presets_integrity(handle: API):
    """
    Check that decompressing then recompressing the user's 'presets' objects leaves their string form unchanged
    """

    presets = await handle.api.high_level.download_user_presets()

    compressed_before = list(map(str, presets))
    decompress_user_data(presets)
    compress_user_data(presets)
    compressed_after = list(map(str, presets))

    compare_in_out("presets", compressed_before, compressed_after)


async def test_presets_integrity_sync(api_handle_sync):  # noqa: F811  # pylint: disable=W0621
    """Run the presets round-trip check with the ``api_handle_sync`` fixture"""
    await presets_integrity(api_handle_sync)


async def test_presets_integrity_async(api_handle):  # noqa: F811  # pylint: disable=W0621
    """Run the presets round-trip check with the ``api_handle`` fixture"""
    await presets_integrity(api_handle)


@error_handler(wait=0)
async def aimodules_integrity(handle: API):
    """
    Check that decrypting then re-encrypting the user's 'aimodules' objects leaves their string form unchanged
    """

    session = handle.api
    encryption_key = handle.encryption_key

    store = await session.high_level.get_keystore(encryption_key)
    modules = await session.high_level.download_user_modules()

    original = list(map(str, modules))
    decrypt_user_data(modules, store)
    encrypt_user_data(modules, store)
    roundtripped = list(map(str, modules))

    compare_in_out("aimodules", original, roundtripped)


async def test_aimodules_integrity_sync(api_handle_sync):  # noqa: F811  # pylint: disable=W0621
    """Run the aimodules round-trip check with the ``api_handle_sync`` fixture"""
    await aimodules_integrity(api_handle_sync)


async def test_aimodules_integrity_async(api_handle):  # noqa: F811  # pylint: disable=W0621
    """Run the aimodules round-trip check with the ``api_handle`` fixture"""
    await aimodules_integrity(api_handle)


@error_handler(wait=0)
async def shelves_integrity(handle: API):
    """
    Check that decompressing then recompressing the user's 'shelves' objects leaves their string form unchanged
    """

    shelves = await handle.api.high_level.download_user_shelves()

    compressed_before = list(map(str, shelves))
    decompress_user_data(shelves)
    compress_user_data(shelves)
    compressed_after = list(map(str, shelves))

    compare_in_out("shelves", compressed_before, compressed_after)


async def test_shelves_integrity_sync(api_handle_sync):  # noqa: F811  # pylint: disable=W0621
    """Run the shelves round-trip check with the ``api_handle_sync`` fixture"""
    await shelves_integrity(api_handle_sync)


async def test_shelves_integrity_async(api_handle):  # noqa: F811  # pylint: disable=W0621
    """Run the shelves round-trip check with the ``api_handle`` fixture"""
    await shelves_integrity(api_handle)


if __name__ == "__main__":

    async def main():
        """Manually run every integrity check, first with ``API()`` then with ``API(sync=True)``"""

        async with API() as api:
            await test_keystore_integrity_async(api)
            await test_stories_integrity_async(api)
            await test_storycontent_integrity_async(api)
            await test_presets_integrity_async(api)
            await test_aimodules_integrity_async(api)
            await test_shelves_integrity_async(api)

        async with API(sync=True) as api:
            await test_keystore_integrity_sync(api)
            await test_stories_integrity_sync(api)
            await test_storycontent_integrity_sync(api)
            await test_presets_integrity_sync(api)
            await test_aimodules_integrity_sync(api)
            await test_shelves_integrity_sync(api)

    run(main())