Source code for novelai_api.StoryHandler

from asyncio import run
from copy import deepcopy
from json import loads
from time import time
from typing import Any, Dict, Iterator, List, Optional

from aiohttp import ClientSession

from novelai_api import NovelAIAPI
from novelai_api.BanList import BanList
from novelai_api.BiasGroup import BiasGroup
from novelai_api.GlobalSettings import GlobalSettings
from novelai_api.Idstore import Idstore
from novelai_api.Keystore import Keystore
from novelai_api.Preset import Model, Preset
from novelai_api.Tokenizer import Tokenizer
from novelai_api.utils import b64_to_tokens, decrypt_user_data, encrypt_user_data


def _get_time() -> int:
    """
    Get the current time, as formatted for createdAt and lastUpdatedAt

    :return: Current time with millisecond precision
    """

    return int(time() * 1000)


def _get_short_time() -> int:
    """
    Get the current time, as some lastUpdatedAt fields are only precise to the second

    :return: Current time with second precision
    """

    return int(time())


def _set_nested_item(item: Dict[str, Any], val: Any, path: str):
    """
    Set a nested dictionary item from a dot-separated path (e.g. "data.createdAt")
    """

    keys = path.split(".")

    for key in keys[:-1]:
        item = item[key]

    item[keys[-1]] = val
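
# A quick illustration of the dotted-path semantics above (hypothetical data;
# intermediate keys must already exist, only the leaf is created or overwritten):
#
#     >>> d = {"data": {"createdAt": 0}}
#     >>> _set_nested_item(d, 1234, "data.createdAt")
#     >>> d["data"]["createdAt"]
#     1234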


class NovelAIStory:
    TEXT_GENERATION_SETTINGS_VERSION = 2
    DEFAULT_MODEL = Model.Euterpe

    api: NovelAIAPI
    keystore: Keystore
    key: bytes
    story: Dict[str, Any]
    storycontent: Dict[str, Any]
    tree: List[int]

    global_settings: GlobalSettings
    banlists: List[BanList]
    biases: List[BiasGroup]
    model: Model
    preset: Preset
    prefix: str
    context_size: int

    def _handle_banlist(self, data: Dict[str, Any]):
        if "bannedSequenceGroups" not in data:
            data["bannedSequenceGroups"] = []

        ban_seq = data["bannedSequenceGroups"]
        self.banlists = [BanList(*seq["sequences"], enabled=seq["enabled"]) for seq in ban_seq]

    def _handle_biasgroups(self, data: Dict[str, Any]):
        if "phraseBiasGroups" not in data:
            data["phraseBiasGroups"] = []

        self.biases = []
        for bias in data["phraseBiasGroups"]:
            self.biases.append(BiasGroup.from_data(bias))

    def _handle_preset(self, data: Dict[str, Any]):
        settings = data["settings"]

        if "textGenerationSettingsVersion" not in settings:
            settings["textGenerationSettingsVersion"] = self.TEXT_GENERATION_SETTINGS_VERSION

        if "prefix" not in settings:
            settings["prefix"] = "vanilla"
        self.prefix = settings["prefix"]

        if "model" not in settings:
            settings["model"] = self.DEFAULT_MODEL.value
        self.model = Model(settings["model"])

        if "preset" not in settings:
            settings["preset"] = ""

        parameters = settings["parameters"]

        if "bad_words_ids" in parameters:
            self.banlists.append(BanList(*parameters["bad_words_ids"]))
            del parameters["bad_words_ids"]

        if "logit_bias_groups" in parameters:
            for bias in parameters["logit_bias_groups"]:
                self.biases.append(BiasGroup.from_data(bias))
            del parameters["logit_bias_groups"]

        self.preset = Preset.from_preset_data(settings)
        self.preset.name = settings["preset"]
        self.preset.model = self.model

    def __init__(
        self,
        api: NovelAIAPI,
        keystore: Keystore,
        meta: str,
        global_settings: GlobalSettings,
        story: Dict[str, Any],
        storycontent: Dict[str, Any],
    ):
        self.api = api
        # keep the keystore around, as save() needs it for re-encryption
        self.keystore = keystore
        self.key = keystore[meta]
        self.story = story
        self.storycontent = storycontent
        self.tree = []

        data = storycontent["data"]

        self.global_settings = global_settings.copy()

        self._handle_banlist(data)
        self._handle_biasgroups(data)
        self._handle_preset(data)

        # FIXME: variable context size ? From global settings ?
        self.context_size = 2048

    # TODO: trimResponses
    # TODO: banBrackets
    # TODO: dynamicPenaltyRange
    # TODO: remember
    # TODO: AN
    # TODO: Lorebook

    def _create_datablock(self, fragment: Dict[str, str], end_offset: int):
        """
        Append a fragment to the story as a new datablock, linked to the current block
        """

        story = self.storycontent["data"]["story"]
        blocks = story["datablocks"]
        fragments = story["fragments"]

        cur_index = story["currentBlock"]
        cur_block = blocks[cur_index]

        story["step"] += 1

        frag_index = len(fragments)
        fragments.append(fragment)

        start = cur_block["endIndex"] + len(cur_block["dataFragment"]["data"])

        block = {
            "nextBlock": [],
            "prevBlock": cur_index,
            "origin": fragment["origin"],
            "startIndex": start,
            "endIndex": start + end_offset,
            "dataFragment": fragment,
            "fragmentIndex": frag_index,
            "removedFragments": [],
            "chain": False,
        }

        new_index = len(blocks)
        blocks.append(block)
        cur_block["nextBlock"].append(new_index)
        story["currentBlock"] = new_index

        self.tree.append(new_index)

    def __str__(self) -> str:
        story_fragments = self.storycontent["data"]["story"]["fragments"]
        story_content = "".join(fragment["data"] for fragment in story_fragments)

        # FIXME: handle edit

        return story_content
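
    # The datablocks form a tree: each generation or edit appends a block whose
    # "prevBlock" points to its parent, and registers itself in the parent's
    # "nextBlock" list. "currentBlock" is the cursor that undo/redo/choose move
    # around. A hypothetical history, where block 3 is a retry of block 2:
    #
    #     blocks[0] ──> blocks[1] ──> blocks[2]
    #                            └──> blocks[3]
    #
    # From block 3, undo() moves the cursor to block 1; redo() then follows the
    # most recent child back to block 3, while choose(0) selects block 2 instead.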

    def build_context(self) -> List[int]:
        tokens = []

        # TODO: Remember tokens
        # TODO: AN tokens

        # TODO: optimize for large stories ?
        # edits are a pain to handle in token form, so we use the story's string representation instead
        story_content = str(self)
        story_content_size = self.context_size
        # TODO: add option to remove superfluous spaces at the end

        # only tokenize the tail to handle large stories: double the character window
        # until it yields at least context_size tokens, or covers the whole story
        story_tokens = []
        while len(story_tokens) < self.context_size:
            story_content_size *= 2
            story_tokens = Tokenizer.encode(self.model, story_content[-story_content_size:])

            # whole story content is tokenized
            if len(story_content) < story_content_size:
                break

        story_tokens = story_tokens[-self.context_size :]

        # TODO: LB tokens
        # TODO: Order and cut everything to fit

        tokens.extend(story_tokens)

        # internal invariant, should never fail
        assert len(tokens) <= self.context_size

        return tokens

    async def generate(self):
        prompt = self.build_context()

        # FIXME: find why the output is garbage
        rsp = await self.api.high_level.generate(
            prompt,
            self.model,
            self.preset,
            self.global_settings,
            self.banlists,
            self.biases,
            self.prefix,
        )
        output = Tokenizer.decode(self.model, b64_to_tokens(rsp["output"]))

        fragment = {"data": output, "origin": "ai"}
        self._create_datablock(fragment, 0)

        return self

    async def edit(self, start: int, end: int, replace: str):
        # FIXME: redo edit implementation
        fragment = {"data": replace, "origin": "edit"}
        self._create_datablock(fragment, end - start)

    async def undo(self):
        """
        Move the current block back to its parent (undo the last action)
        """

        story = self.storycontent["data"]["story"]
        cur_index = story["currentBlock"]
        blocks = story["datablocks"]
        cur_block = blocks[cur_index]

        story["currentBlock"] = cur_block["prevBlock"]

    async def redo(self):
        """
        Move the current block to its most recent child (redo the last undone action)
        """

        story = self.storycontent["data"]["story"]
        cur_index = story["currentBlock"]
        blocks = story["datablocks"]
        cur_block = blocks[cur_index]

        story["currentBlock"] = cur_block["nextBlock"][-1]

    async def save(self, upload: bool = False) -> bool:
        encrypted_story = encrypt_user_data(deepcopy(self.story), self.keystore)
        encrypted_storycontent = encrypt_user_data(deepcopy(self.storycontent), self.keystore)

        success = True

        # TODO: keep local copy if upload ?
        if upload:
            success = success and await self.api.high_level.upload_user_content(encrypted_storycontent)
            success = success and await self.api.high_level.upload_user_content(encrypted_story)

        return success

    async def choose(self, index: int):
        """
        Select which child of the current block to follow (e.g. after an undo)
        """

        story = self.storycontent["data"]["story"]
        cur_index = story["currentBlock"]
        blocks = story["datablocks"]
        cur_block = blocks[cur_index]
        next_blocks = cur_block["nextBlock"]

        if not (0 <= index < len(next_blocks)):
            raise ValueError(f"Expected index between 0 and {len(next_blocks) - 1}, but got {index}")

        story["currentBlock"] = next_blocks[index]

    async def flatten(self):
        """
        Discard every datablock that is not part of the current tree
        """

        story = self.storycontent["data"]["story"]
        blocks = story["datablocks"]

        new_datablocks = [blocks[i] for i in self.tree]
        self.tree = list(range(len(new_datablocks)))

        # FIXME: prevBlock/nextBlock/currentBlock still index into the old datablocks list
        story["datablocks"] = new_datablocks

    async def delete(self):
        pass

    async def get_current_tree(self) -> List[Dict[str, Any]]:
        story = self.storycontent["data"]["story"]
        blocks = story["datablocks"]

        return [blocks[i] for i in self.tree]


class NovelAIStoryStorage:
    """
    General storage for NovelAIStory objects. Story instances should be created or
    loaded through this class.
    """

    _story_instances: Dict[str, NovelAIStory]

    api: NovelAIAPI
    keystore: Keystore
    idstore: Idstore
    global_settings: GlobalSettings

    def __init__(self, api: NovelAIAPI, keystore: Keystore, global_settings: Optional[GlobalSettings] = None):
        self.api = api
        self.keystore = keystore
        self.idstore = Idstore()
        self.global_settings = global_settings or GlobalSettings()

        self._story_instances = {}

    def __iter__(self) -> Iterator[NovelAIStory]:
        return iter(self._story_instances.values())

    def __getitem__(self, story_id: str) -> NovelAIStory:
        return self._story_instances[story_id]

    def __len__(self) -> int:
        return len(self._story_instances)

    def load(self, story: Dict[str, Any], storycontent: Dict[str, Any]) -> NovelAIStory:
        """
        Load a story proxy from a story and storycontent object
        """

        story_meta = story["meta"]
        story_id = story["data"]["remoteStoryId"]

        assert (
            story_meta == storycontent["meta"]
        ), f"Expected meta {story_meta} for storycontent, but got meta {storycontent['meta']}"
        assert story_id == storycontent["id"], f"Mismatched id: expected {story_id}, but got {storycontent['id']}"

        story = NovelAIStory(self.api, self.keystore, story_meta, self.global_settings, story, storycontent)

        # FIXME: ignore or overwrite if id exists ?
        self._story_instances[story_id] = story

        return story

    def loads(self, stories: List[Dict[str, Any]], storycontents: List[Dict[str, Any]]) -> List[NovelAIStory]:
        mapping = {}
        for story in stories:
            if story.get("decrypted"):
                mapping[story["data"]["remoteStoryId"]] = story

        loaded = []
        for storycontent in storycontents:
            if storycontent.get("decrypted"):
                story_id = storycontent["id"]

                if story_id not in mapping:
                    self.api.logger.warning(f"Storycontent {story_id} has no associated story")
                else:
                    proxy = self.load(mapping[story_id], storycontent)
                    del mapping[story_id]
                    loaded.append(proxy)

        for story_id in mapping.keys():
            self.api.logger.warning(f"Story {story_id} has no associated storycontent")

        return loaded

    async def load_from_remote(self) -> List[NovelAIStory]:
        stories = await self.api.high_level.download_user_stories()
        storycontents = await self.api.high_level.download_user_story_contents()

        decrypt_user_data(stories, self.keystore)
        decrypt_user_data(storycontents, self.keystore)

        return self.loads(stories, storycontents)

    def create(self) -> NovelAIStory:
        meta = self.keystore.create()
        current_time = _get_time()
        current_time_short = _get_short_time()

        with open("templates/template_empty_story.txt") as f:
            story = loads(f.read())

        # local overwrites
        id_story = self.idstore.create()
        for path, val in (
            ("id", id_story),
            ("meta", meta),
            ("data.id", meta),
            ("data.remoteStoryId", id_story),
            ("data.createdAt", current_time),
            ("data.lastUpdatedAt", current_time),
            ("lastUpdatedAt", current_time_short),
        ):
            _set_nested_item(story, val, path)

        with open("templates/template_empty_storycontent.txt") as f:
            storycontent = loads(f.read())

        # local overwrites
        id_storycontent = self.idstore.create()
        id_lore_default = ""  # FIXME: get id

        for path, val in (
            ("id", id_storycontent),
            ("meta", meta),
            ("lastUpdatedAt", current_time_short),
            ("data.contextDefaults.loreDefaults.id", id_lore_default),
            ("data.contextDefaults.loreDefaults.lastUpdatedAt", current_time),
        ):
            _set_nested_item(storycontent, val, path)

        proxy = self.load(story, storycontent)

        return proxy

    def select(self, story_id: str) -> Optional[NovelAIStory]:
        """
        Select a story proxy from the previously created/loaded ones

        :param story_id: Id of the selected story

        :return: Story, or None if the story doesn't exist in the storage
        """

        if story_id not in self._story_instances:
            return None

        return self._story_instances[story_id]

    def unload(self, story_id: str):
        """
        Unload a previously created/loaded story, freeing the NovelAIStory object
        """

        if story_id in self._story_instances:
            del self._story_instances[story_id]
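

# A minimal usage sketch, assuming an already-authenticated NovelAIAPI client and a
# decrypted Keystore (how those are obtained is outside this module's scope; everything
# below only uses names defined or imported in this module):
#
#     async def main(api: NovelAIAPI, keystore: Keystore):
#         storage = NovelAIStoryStorage(api, keystore)
#
#         for story in await storage.load_from_remote():
#             print(story)                   # current story text, via __str__
#
#         story = storage.create()           # new story from the bundled templates
#         await story.generate()             # ask the model to extend the story
#         await story.save(upload=True)      # re-encrypt and push to the account
#
#     # run(main(api, keystore))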