Source code for novelai_api.Preset

import os
import pathlib
import warnings
from copy import deepcopy
from enum import Enum, EnumMeta, IntEnum
from json import loads
from random import choice
from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union


class Order(IntEnum):
    Temperature = 0
    Top_K = 1
    Top_P = 2
    TFS = 3
    Top_A = 4
    Typical_P = 5
    CFG = 6
    Top_G = 7
    Mirostat = 8
    Unified = 9
    Min_p = 10


NAME_TO_ORDER = {
    "temperature": Order.Temperature,
    "top_k": Order.Top_K,
    "top_p": Order.Top_P,
    "tfs": Order.TFS,
    "top_a": Order.Top_A,
    "typical_p": Order.Typical_P,
    "cfg": Order.CFG,
    "top_g": Order.Top_G,
    "mirostat": Order.Mirostat,
    "math1": Order.Unified,
    "min_p": Order.Min_p,
}

ORDER_TO_NAME = {
    Order.Temperature: "temperature",
    Order.Top_K: "top_k",
    Order.Top_P: "top_p",
    Order.TFS: "tfs",
    Order.Top_A: "top_a",
    Order.Typical_P: "typical_p",
    Order.CFG: "cfg",
    Order.Top_G: "top_g",
    Order.Mirostat: "mirostat",
    Order.Unified: "math1",
    Order.Min_p: "min_p",
}


def enum_contains(enum_class: EnumMeta, value: str) -> bool:
    """
    Check if the value provided is valid for the enum

    :param enum_class: Class of the Enum
    :param value: Value to check
    """

    if not hasattr(enum_class, "enum_member_values"):
        enum_class.enum_member_values = list(e.value for e in enum_class)

    values = enum_class.enum_member_values
    if len(values) == 0:
        raise ValueError(f"Empty enum class: '{enum_class}'")

    return value in values


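# Usage sketch for enum_contains (illustrative only; the values come from the
# Model enum defined further down):
#
#     enum_contains(Model, "kayra-v1")  # True, matches Model.Kayra
#     enum_contains(Model, "gpt-4")     # False, not a NovelAI model value

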
def _strip_model_version(value: str) -> str:
    parts = value.split("-")

    if parts[-1].startswith("v") and parts[-1][1:].isdecimal():
        parts = parts[:-1]

    return "-".join(parts)


def collapse_model(enum_class: EnumMeta, value: str):
    """
    Collapse multiple versions of a model to the latest model value

    :param enum_class: Class of the Enum
    :param value: Value of the model to collapse
    """

    if not hasattr(enum_class, "enum_member_values"):
        enum_class.enum_member_values = {_strip_model_version(e.value): e for e in enum_class}

    values = enum_class.enum_member_values
    if len(values) == 0:
        raise ValueError(f"Empty enum class: '{enum_class}'")

    return values.get(_strip_model_version(value))


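# Illustrative sketch of the version collapsing: _strip_model_version drops a
# trailing "-v<digits>" suffix, so every version of a model maps to the same
# enum member (values assume the Model enum defined below):
#
#     _strip_model_version("kayra-v1")   # -> "kayra"
#     collapse_model(Model, "kayra-v2")  # -> Model.Kayra, despite the enum holding "kayra-v1"
#     collapse_model(Model, "unknown")   # -> None

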
class StrEnum(str, Enum):
    pass


class Model(StrEnum):
    # Calliope = "2.7B"
    Sigurd = "6B-v4"
    Euterpe = "euterpe-v2"
    Krake = "krake-v2"
    Clio = "clio-v1"
    Kayra = "kayra-v1"
    Erato = "llama-3-erato-v1"

    Genji = "genji-jp-6b-v2"
    Snek = "genji-python-6b"

    HypeBot = "hypebot"
    Inline = "infillmodel"


class PhraseRepPen(StrEnum):
    Off = "off"
    VeryLight = "very_light"
    Light = "light"
    Medium = "medium"
    Aggressive = "aggressive"
    VeryAggressive = "very_aggressive"


#: Prompt sent to the model when the context is empty
PREAMBLE = {
    # Model.Calliope: "⁂\n",
    Model.Sigurd: "⁂\n",
    Model.Genji: [60, 198, 198],  # "]\n\n" - impossible combination, so it is pre-tokenized
    Model.Snek: "<|endoftext|>\n",
    Model.Euterpe: "\n***\n",
    Model.Krake: "<|endoftext|>[ Prologue ]\n",
    Model.Clio: "[ Author: Various ]\n[ Prologue ]\n",
    Model.Kayra: "",  # no preamble, it uses the "special_openings" module instead
    Model.Erato: "<|endoftext|>",  # <|reserved_special_token_81|> if context isn't full
}


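# The preamble is looked up by model, e.g. PREAMBLE[Model.Euterpe] == "\n***\n";
# note that Model.Genji maps to a pre-tokenized token list rather than a string.

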
class PresetView:
    model: Model
    _officials_values: Dict[str, List["Preset"]]

    def __init__(self, model: Model, officials_values: Dict[str, List["Preset"]]):
        self.model = model
        self._officials_values = officials_values

    def __iter__(self):
        return self._officials_values[self.model.value].__iter__()


class _PresetMetaclass(type):
    _officials_values: Dict[str, List["Preset"]]

    def __getitem__(cls, model: Model):
        if not isinstance(model, Model):
            raise ValueError(f"Expected instance of {Model}, got type '{type(model)}'")

        return PresetView(model, cls._officials_values)


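# The metaclass makes the Preset class itself subscriptable by model, yielding
# a PresetView over the official presets. A minimal sketch (preset names vary
# by model):
#
#     for preset in Preset[Model.Kayra]:
#         print(preset.name)

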
class Preset(metaclass=_PresetMetaclass):
    # TODO
    #: Similar to logprobs, but seems to return something different. Only returns one token's worth of data
    # next_word                boolean
    #: ???
    # output_nonzero_probs     boolean

    _TYPE_MAPPING = {
        "textGenerationSettingsVersion": int,
        "stop_sequences": list,
        "temperature": (int, float),
        "max_length": int,
        "min_length": int,
        "top_k": int,
        "top_a": (int, float),
        "top_p": (int, float),
        "typical_p": (int, float),
        "tail_free_sampling": (int, float),
        "repetition_penalty": (int, float),
        "repetition_penalty_range": int,
        "repetition_penalty_slope": (int, float),
        "repetition_penalty_frequency": (int, float),
        "repetition_penalty_presence": (int, float),
        "repetition_penalty_whitelist": list,
        "repetition_penalty_default_whitelist": bool,
        "phrase_rep_pen": (str, PhraseRepPen),
        "length_penalty": (int, float),
        "diversity_penalty": (int, float),
        "order": list,
        "cfg_scale": (int, float),
        "cfg_uc": str,
        "top_g": int,
        "mirostat_lr": (int, float),
        "mirostat_tau": (int, float),
        "math1_quad": (int, float),
        "math1_quad_entropy_scale": (int, float),
        "math1_temp": (int, float),
        "min_p": (int, float),
        "pad_token_id": int,
        "bos_token_id": int,
        "eos_token_id": int,
        "max_time": int,
        "no_repeat_ngram_size": int,
        "encoder_no_repeat_ngram_size": int,
        "num_return_sequences": int,
        "get_hidden_states": bool,
    }

    DEFAULTS = {
        "stop_sequences": [],
        "temperature": 1.0,
        "max_length": 40,
        "min_length": 1,
        "top_k": 0,
        "top_a": 1.0,
        "top_p": 0.0,
        "typical_p": 0.0,
        "math1_quad": 0.0,
        "math1_quad_entropy_scale": 0.0,
        "math1_temp": 1.0,
        "min_p": 0.0,
        "tail_free_sampling": 1.0,
        "repetition_penalty": 1.0,
        "repetition_penalty_range": 0,
        "repetition_penalty_slope": 0.0,
        "repetition_penalty_frequency": 0.0,
        "repetition_penalty_presence": 0.0,
        "repetition_penalty_whitelist": [],
        "repetition_penalty_default_whitelist": False,
        "length_penalty": 1.0,
        "diversity_penalty": 0.0,
        "order": list(Order),
        "phrase_rep_pen": PhraseRepPen.Off,
    }

    # type completion for __setitem__ and __getitem__
    if TYPE_CHECKING:
        #: Preset version, only relevant for .preset files
        textGenerationSettingsVersion: int
        #: List of tokenized strings that should stop the generation early
        # TODO: add possibility for late tokenization
        stop_sequences: List[List[int]]
        #: https://naidb.miraheze.org/wiki/Generation_Settings#Randomness_(Temperature)
        temperature: float
        #: Response length, if not interrupted by a Stop Sequence
        max_length: int
        #: Minimum number of tokens, if interrupted by a Stop Sequence
        min_length: int
        #: https://naidb.miraheze.org/wiki/Generation_Settings#Top-K_Sampling
        top_k: int
        #: https://naidb.miraheze.org/wiki/Generation_Settings#Top-A_Sampling
        top_a: float
        #: https://naidb.miraheze.org/wiki/Generation_Settings#Nucleus_Sampling
        top_p: float
        #: https://naidb.miraheze.org/wiki/Generation_Settings#Typical_Sampling (https://arxiv.org/pdf/2202.00666.pdf)
        typical_p: float
        #: https://naidb.miraheze.org/wiki/Generation_Settings#Tail-Free_Sampling
        tail_free_sampling: float
        #: https://arxiv.org/pdf/1909.05858.pdf
        repetition_penalty: float
        #: Range (in tokens) the repetition penalty covers (https://arxiv.org/pdf/1909.05858.pdf)
        repetition_penalty_range: int
        #: https://arxiv.org/pdf/1909.05858.pdf
        repetition_penalty_slope: float
        #: https://platform.openai.com/docs/api-reference/parameter-details
        repetition_penalty_frequency: float
        #: https://platform.openai.com/docs/api-reference/parameter-details
        repetition_penalty_presence: float
        #: List of tokens that are excluded from the repetition penalty (useful for colors and the like)
        repetition_penalty_whitelist: list
        #: Whether to use the default whitelist. Used for preset compatibility, as this setting is saved in presets
        repetition_penalty_default_whitelist: bool
        #: https://docs.novelai.net/text/phrasereppen.html
        phrase_rep_pen: Union[str, PhraseRepPen]
        #: https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig
        length_penalty: float
        #: https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig
        diversity_penalty: float
        #: list of Order to set the sampling order
        order: List[Union[Order, int]]
        #: https://docs.novelai.net/text/cfg.html
        cfg_scale: float
        #: [DEPRECATED] https://docs.novelai.net/text/cfg.html
        # cfg_uc: str
        #: [DEPRECATED] https://docs.novelai.net/text/Editor/slidersettings.html#advanced-options
        # top_g: int
        #: https://docs.novelai.net/text/Editor/slidersettings.html#advanced-options
        mirostat_lr: float
        #: https://docs.novelai.net/text/Editor/slidersettings.html#advanced-options
        mirostat_tau: float
        #: https://docs.novelai.net/text/Editor/slidersettings.html#advanced-options (Unified quad)
        math1_quad: float
        #: https://docs.novelai.net/text/Editor/slidersettings.html#advanced-options (Unified conf)
        math1_quad_entropy_scale: float
        #: https://docs.novelai.net/text/Editor/slidersettings.html#advanced-options (Unified linear)
        math1_temp: float
        #: https://docs.novelai.net/text/Editor/slidersettings.html#advanced-options
        min_p: float
        #: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
        pad_token_id: int
        #: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
        bos_token_id: int
        #: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
        eos_token_id: int
        #: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
        max_time: int
        #: https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig
        no_repeat_ngram_size: int
        #: https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig
        encoder_no_repeat_ngram_size: int
        #: https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig
        num_return_sequences: int
        #: PretrainedConfig.output_hidden_states
        get_hidden_states: bool

    _officials: Dict[str, Dict[str, "Preset"]]
    _officials_values: Dict[str, List["Preset"]]
    _defaults: Dict[str, str]

    _settings: Dict[str, Any]

    #: Name of the preset
    name: str
    #: Model the preset is for
    model: Model
    #: Enable state of sampling options
    sampling_options: List[bool]

    def __init__(self, name: str, model: Model, settings: Optional[Dict[str, Any]] = None):
        object.__setattr__(self, "name", name)
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "_settings", {})
        self.update(settings)

        if "order" in self._settings:
            self.set_sampling_options_state([True] * len(self._settings["order"]))

    def set_sampling_options_state(self, sampling_options_state: List[bool]):
        """
        Set the enabled/disabled state of the sampling options. Call it after setting the order setting;
        the states should come in the same order as the order setting.
        """

        if "order" not in self._settings:
            raise ValueError("The order setting must be set before setting the sampling options state")

        if len(sampling_options_state) != len(self._settings["order"]):
            raise ValueError(
                "The length of the sampling options state list must be equal to the length "
                "of the sampling options list"
            )

        object.__setattr__(self, "sampling_options", sampling_options_state)

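    # A minimal sketch of toggling samplers (illustrative; assumes "order" was
    # set on the preset first, since the booleans align index-by-index with it):
    #
    #     preset["order"] = [Order.Temperature, Order.Top_P, Order.Top_K]
    #     preset.set_sampling_options_state([True, True, False])  # disable Top_K
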
    def __setitem__(self, key: str, value: Any):
        if key not in self._TYPE_MAPPING:
            raise ValueError(f"'{key}' is not a valid setting")

        if not isinstance(value, self._TYPE_MAPPING[key]):  # noqa (pycharm PY-36317)
            raise ValueError(f"Expected type '{self._TYPE_MAPPING[key]}' for {key}, but got type '{type(value)}'")

        if key == "order":
            if not isinstance(value, list):
                raise ValueError(f"Expected type 'List[int | Order]' for order, but got type '{type(value)}'")

            for i, e in enumerate(value):
                if not isinstance(e, (int, Order)):
                    raise ValueError(f"Expected type 'int' or 'Order' for order #{i}, but got type '{type(value[i])}'")

                if isinstance(e, int):
                    value[i] = Order(e)

        self._settings[key] = value

    def __contains__(self, key: str) -> bool:
        return key in self._settings

    def __getitem__(self, key: str) -> Optional[Any]:
        return self._settings.get(key)

    def __delitem__(self, key):
        del self._settings[key]

    # give dot access capabilities to the object
    def __setattr__(self, key, value):
        if key in self._TYPE_MAPPING:
            self[key] = value
        else:
            object.__setattr__(self, key, value)

    def __getattr__(self, key):
        if key in self._TYPE_MAPPING:
            return self[key]

        return object.__getattribute__(self, key)

    def __delattr__(self, name):
        if name in self._TYPE_MAPPING:
            del self[name]
        else:
            object.__delattr__(self, name)

    def __repr__(self) -> str:
        model = self.model.value if self.model is not None else "<?>"
        enabled_order = [o for o, enabled in zip(self._settings["order"], self.sampling_options) if enabled]
        enabled_keys = ", ".join(f"{ORDER_TO_NAME[o]} = {o in enabled_order}" for o in Order)

        return f"Preset: '{self.name} ({model}, {enabled_keys})'"

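    # The dunder methods above give a preset both mapping-style and
    # attribute-style access for every key in _TYPE_MAPPING. A brief sketch:
    #
    #     preset["temperature"] = 0.8  # __setitem__, type-checked
    #     preset.top_p = 0.9           # __setattr__, routed to __setitem__
    #     assert "top_p" in preset     # __contains__
    #     del preset.top_p             # __delattr__, routed to __delitem__
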
    def to_settings(self) -> Dict[str, Any]:
        """
        Return the values stored in the preset, for a generate function
        """

        settings = deepcopy(self._settings)

        if "textGenerationSettingsVersion" in settings:
            del settings["textGenerationSettingsVersion"]  # not API relevant

        # remove disabled sampling options
        if "order" in settings:
            order = [
                (Order(o) if isinstance(o, int) else o)
                for o, enabled in zip(settings["order"], self.sampling_options)
                if enabled
            ]

            for o in Order:
                if o not in order:
                    name = ORDER_TO_NAME[o]

                    # special handling for samplers with multiple keys
                    if o is Order.Mirostat:
                        keys = ["mirostat_tau", "mirostat_lr"]
                    elif o is Order.Unified:
                        keys = ["math1_quad", "math1_quad_entropy_scale", "math1_temp"]
                    else:
                        keys = [name]

                    for key in keys:
                        settings.pop(key, None)

            settings["order"] = [e.value for e in order]

        # sanitize Phrase Repetition Penalty
        if settings.get("phrase_rep_pen", None) is not None:
            prp = settings.pop("phrase_rep_pen")
            if not isinstance(prp, PhraseRepPen):
                prp = PhraseRepPen(prp)

            settings["phrase_rep_pen"] = prp.value

        # it seems that 0 doesn't disable it, but does weird things
        if settings.get("repetition_penalty_range", None) == 0:
            del settings["repetition_penalty_range"]

        # delete the options that return an unknown error (success status code, but server error)
        if settings.get("repetition_penalty_slope", None) == 0:
            del settings["repetition_penalty_slope"]

        return settings

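    # to_settings() returns a dict ready for a generate call, with disabled
    # samplers and their keys stripped. Illustrative sketch (exact keys depend
    # on the preset's settings):
    #
    #     settings = preset.to_settings()
    #     # settings["order"] holds ints, e.g. [0, 2] for Temperature + Top_P
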
    def __str__(self):
        settings = self.to_settings()  # use the sanitized settings
        is_default = {k: " (default)" if v == self.DEFAULTS.get(k, None) else "" for k, v in settings.items()}
        values = "\n".join(f"    {k} = {v}{is_default[k]}" for k, v in settings.items())

        return f"Preset<{self.name}, {self.model}> {{\n{values}\n}}"

    def to_file(self, path: str) -> NoReturn:
        """
        Write the current preset to a file

        :param path: Path to the preset file to write
        """

        raise NotImplementedError()

    def copy(self) -> "Preset":
        """
        Instantiate a new preset object from the current one
        """

        return Preset(self.name, self.model, deepcopy(self._settings))

    def set(self, name: str, value: Any) -> "Preset":
        """
        Set a preset value. Same as `preset[name] = value`
        """

        self[name] = value

        return self

    def update(self, values: Optional[Dict[str, Any]] = None, **kwargs) -> "Preset":
        """
        Update the settings stored in the preset. Works like dict.update()
        """

        if values is not None:
            for k, v in values.items():
                self[k] = v

        for k, v in kwargs.items():
            self[k] = v

        return self

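    # set() and update() both return the preset itself, so calls chain. Sketch:
    #
    #     preset.set("temperature", 1.1).update({"top_p": 0.9}, max_length=60)
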
    @classmethod
    def from_preset_data(cls, data: Optional[Dict[str, Any]] = None, **kwargs) -> "Preset":
        """
        Instantiate a preset from preset data; the data should be the same as in a preset file.
        Works like dict.update()
        """

        if data is None:
            data = {}
        data.update(kwargs)

        name = data["name"] if "name" in data else "<?>"

        model_name = data["model"] if "model" in data else ""
        model = collapse_model(Model, model_name)

        settings = data["parameters"] if "parameters" in data else {}

        order = settings["order"] if "order" in settings else []
        settings["order"] = [NAME_TO_ORDER[o["id"]] for o in order]

        # TODO: add support for token banning and bias in preset
        settings.pop("bad_words_ids", None)  # get rid of unsupported option
        settings.pop("logit_bias_exp", None)  # get rid of unsupported option
        settings.pop("logit_bias_groups", None)  # get rid of unsupported option

        c = cls(name, model, settings)
        c.set_sampling_options_state([o["enabled"] for o in order])

        return c

    @classmethod
    def from_file(cls, path: Union[str, bytes, os.PathLike, int]) -> "Preset":
        """
        Instantiate a preset from the given file

        :param path: Path to the preset file
        """

        with open(path, encoding="utf-8") as f:
            data = loads(f.read())

        return cls.from_preset_data(data)

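    # Sketch of loading a preset exported from the NovelAI frontend (the file
    # name here is hypothetical):
    #
    #     preset = Preset.from_file("my_story.preset")
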
    @classmethod
    def from_official(cls, model: Model, name: Optional[str] = None) -> Union["Preset", None]:
        """
        Return a copy of an official preset

        :param model: Model to get the preset of
        :param name: Name of the preset. None means a random official preset should be returned

        :return: The chosen preset, or None if the name was not found in the list of official presets
        """

        model_value: str = model.value

        if name is None:
            preset = choice(cls._officials_values[model_value])
        else:
            preset = cls._officials[model_value].get(name)

        if preset is not None:
            preset = deepcopy(preset)

        return preset

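    # Usage sketch ("Carefree" is an official Kayra preset name, but
    # availability depends on the preset files bundled with the package):
    #
    #     random_preset = Preset.from_official(Model.Kayra)
    #     named_preset = Preset.from_official(Model.Kayra, "Carefree")
    #     if named_preset is None:
    #         ...  # name not found among the official presets
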
    @classmethod
    def from_default(cls, model: Model) -> Union["Preset", None]:
        """
        Return a copy of the default preset for the given model

        :param model: Model to get the default preset of

        :return: The chosen preset, or None if the default preset was not found for the model
        """

        model_value: str = model.value

        default = cls._defaults.get(model_value)
        if default is None:
            return None

        preset = cls._officials[model_value].get(default)
        if preset is not None:
            preset = deepcopy(preset)

        return preset


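# A common fallback chain, as a sketch (assumes official presets are bundled
# for the model, since from_official(model) picks one at random):
#
#     preset = Preset.from_default(Model.Kayra) or Preset.from_official(Model.Kayra)

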
def _import_officials():
    """
    Import the official presets under the 'presets' directory. Performed once, at import
    """

    cls = Preset

    cls._officials_values = {}
    cls._officials = {}
    cls._defaults = {}

    for model in Model:
        model: Model

        path = pathlib.Path(__file__).parent / "presets" / f"presets_{model.value.replace('-', '_')}"
        if not path.exists():
            warnings.warn(f"Missing preset folder for model {model.value}")
            cls._officials_values[model.value] = []
            cls._officials[model.value] = {}
            continue

        if (path / "default.txt").exists():
            with open(path / "default.txt", encoding="utf-8") as f:
                cls._defaults[model.value] = f.read().splitlines()[0]

        officials = {}
        for filename in path.iterdir():
            if filename.suffix == ".preset":
                preset = cls.from_file(str(path / filename))
                officials[preset.name] = preset

        cls._officials_values[model.value] = list(officials.values())
        cls._officials[model.value] = officials


if not hasattr(Preset, "_officials"):
    _import_officials()