Source code for novelai_api.GlobalSettings

from typing import TYPE_CHECKING, Any, Dict

from novelai_api.BiasGroup import BiasGroup
from novelai_api.Preset import Model
from novelai_api.python_utils import expand_kwargs
from novelai_api.Tokenizer import Tokenizer


class GlobalSettings:
    """
    Object used to store global settings for the account
    """

    # TODO: store bracket ban in a file
    _BRACKETS = {
        "gpt2": [
            [58], [60], [90], [92], [685], [1391], [1782], [2361], [3693], [4083], [4357], [4895],
            [5512], [5974], [7131], [8183], [8351], [8762], [8964], [8973], [9063], [11208], [11709], [11907],
            [11919], [12878], [12962], [13018], [13412], [14631], [14692], [14980], [15090], [15437], [16151], [16410],
            [16589], [17241], [17414], [17635], [17816], [17912], [18083], [18161], [18477], [19629], [19779], [19953],
            [20520], [20598], [20662], [20740], [21476], [21737], [22133], [22241], [22345], [22935], [23330], [23785],
            [23834], [23884], [25295], [25597], [25719], [25787], [25915], [26076], [26358], [26398], [26894], [26933],
            [27007], [27422], [28013], [29164], [29225], [29342], [29565], [29795], [30072], [30109], [30138], [30866],
            [31161], [31478], [32092], [32239], [32509], [33116], [33250], [33761], [34171], [34758], [34949], [35944],
            [36338], [36463], [36563], [36786], [36796], [36937], [37250], [37913], [37981], [38165], [38362], [38381],
            [38430], [38892], [39850], [39893], [41832], [41888], [42535], [42669], [42785], [42924], [43839], [44438],
            [44587], [44926], [45144], [45297], [46110], [46570], [46581], [46956], [47175], [47182], [47527], [47715],
            [48600], [48683], [48688], [48874], [48999], [49074], [49082], [49146], [49946], [10221], [4841], [1427],
            [2602, 834], [29343], [37405], [35780], [2602], [50256],
        ],
        "gpt2-genji": [],
        "pile": [
            [60], [62], [544], [683], [696], [880], [905], [1008], [1019], [1084], [1092], [1181],
            [1184], [1254], [1447], [1570], [1656], [2194], [2470], [2479], [2498], [2947], [3138], [3291],
            [3455], [3725], [3851], [3891], [3921], [3951], [4207], [4299], [4622], [4681], [5013], [5032],
            [5180], [5218], [5290], [5413], [5456], [5709], [5749], [5774], [6038], [6257], [6334], [6660],
            [6904], [7082], [7086], [7254], [7444], [7748], [8001], [8088], [8168], [8562], [8605], [8795],
            [8850], [9014], [9102], [9259], [9318], [9336], [9502], [9686], [9793], [9855], [9899], [9955],
            [10148], [10174], [10943], [11326], [11337], [11661], [12004], [12084], [12159], [12520], [12977], [13380],
            [13488], [13663], [13811], [13976], [14412], [14598], [14767], [15640], [15707], [15775], [15830], [16079],
            [16354], [16369], [16445], [16595], [16614], [16731], [16943], [17278], [17281], [17548], [17555], [17981],
            [18022], [18095], [18297], [18413], [18736], [18772], [18990], [19181], [20095], [20197], [20481], [20629],
            [20871], [20879], [20924], [20977], [21375], [21382], [21391], [21687], [21810], [21828], [21938], [22367],
            [22372], [22734], [23405], [23505], [23734], [23741], [23781], [24237], [24254], [24345], [24430], [25416],
            [25896], [26119], [26635], [26842], [26991], [26997], [27075], [27114], [27468], [27501], [27618], [27655],
            [27720], [27829], [28052], [28118], [28231], [28532], [28571], [28591], [28653], [29013], [29547], [29650],
            [29925], [30522], [30537], [30996], [31011], [31053], [31096], [31148], [31258], [31350], [31379], [31422],
            [31789], [31830], [32214], [32666], [32871], [33094], [33376], [33440], [33805], [34368], [34398], [34417],
            [34418], [34419], [34476], [34494], [34607], [34758], [34761], [34904], [34993], [35117], [35138], [35237],
            [35487], [35830], [35869], [36033], [36134], [36320], [36399], [36487], [36586], [36676], [36692], [36786],
            [37077], [37594], [37596], [37786], [37982], [38475], [38791], [39083], [39258], [39487], [39822], [40116],
            [40125], [41000], [41018], [41256], [41305], [41361], [41447], [41449], [41512], [41604], [42041], [42274],
            [42368], [42696], [42767], [42804], [42854], [42944], [42989], [43134], [43144], [43189], [43521], [43782],
            [44082], [44162], [44270], [44308], [44479], [44524], [44965], [45114], [45301], [45382], [45443], [45472],
            [45488], [45507], [45564], [45662], [46265], [46267], [46275], [46295], [46462], [46468], [46576], [46694],
            [47093], [47384], [47389], [47446], [47552], [47686], [47744], [47916], [48064], [48167], [48392], [48471],
            [48664], [48701], [49021], [49193], [49236], [49550], [49694], [49806], [49824], [50001], [50256], [0],
            [1],
        ],
        "nerdstash_v1": [
            [3], [49356], [1431], [31715], [34387], [20765], [30702], [10691], [49333], [1266],
            [19438], [43145], [26523], [41471], [2936], [85, 85], [49332], [7286], [1115],
        ],
        "nerdstash_v2": [
            [3], [49356], [1431], [31715], [34387], [20765], [30702], [10691], [49333], [1266],
            [19438], [43145], [26523], [41471], [2936], [85, 85], [49332], [7286], [1115],
        ],
        "llama3": [
            [16067],
            [933, 11144],
            [25106, 11144],
            [58, 106901, 16073, 33710, 25, 109933],
            [933, 58, 11144],
            [128030],
            [58, 30591, 33503, 17663, 100204, 25, 11144],
        ],
    }

    # this one is pretty much mandatory, otherwise Genji throws errors all the time
    # NOTE: the list might contain more tokens than necessary, but better safe than sorry
    _GENJI_AMBIGUOUS_TOKENS = [
        [58], [60], [90], [92], [685], [1391], [1782], [2361], [3693], [4083], [4357], [4895],
        [5512], [5974], [7131], [8183], [8964], [8973], [11208], [11709], [11907], [12878], [12962], [13412],
        [14692], [14980], [15090], [15437], [16151], [16589], [17241], [17414], [17635], [17816], [17912], [18083],
        [18161], [18477], [19629], [19953], [20520], [20662], [20740], [21737], [22241], [23330], [23834], [26076],
        [26398], [26894], [26933], [27007], [27422], [29164], [29795], [30072], [30109], [30866], [31478], [32239],
        [33250], [34758], [34949], [36463], [36786], [36796], [36937], [37981], [38165], [38362], [38381], [38430],
        [38892], [39850], [41832], [41888], [42535], [42669], [42924], [43839], [44587], [45297], [46956], [47527],
        [47715], [48688], [48874], [48999], [49146], [49946], [50256], [8162],
        [198, 198, 198], [49564, 198, 49564], [37605], [22522], [40265], [5099], [39752], [32368],
        [49564, 49564, 49564, 49564, 49564, 49564], [31857, 31857, 31857, 31857, 31857, 31857],
        [17992], [39187], [40367], [15790], [47571], [27032], [628, 198], [628, 628],
        [49564, 198, 49564], [28156], [30298], [34650],
    ]

    # whitelist
    _REP_PEN_WHITELIST = {
        "gpt2": [],
        "gpt2-genji": [],
        "pile": [],
        "nerdstash_v1": [
            "'", '"', ",", ".", ":", "\n", "ve", "s", "t", "n", "d", "ll", "re", "m",
            "-", "*", ")", " the", " a", " an", " and", " or", " not", " no", " is", " was",
            " were", " did", " does", " isn", " wasn", " weren", " didn", " doesn", " him", " her",
            " his", " hers", " their", " its", " could", " couldn", " should", " shouldn", " would", " wouldn",
            " have", " haven", " had", " hadn", " has", " hasn", " can", " cannot", " are", " aren",
            " will", " won", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
            '."', ',"', "====", " ",
        ],
        "nerdstash_v2": [
            "'", '"', ",", ".", ":", "\n", "-", "*", ")", " the", " a", " an",
            " and", " or", " not", " no", " is", " was", " were", " did", " does", " isn",
            " wasn", " weren", " didn", " doesn", " him", " her", " his", " hers", " their", " its",
            " could", " couldn", " should", " shouldn", " would", " wouldn", " have", " haven", " had", " hadn",
            " has", " hasn", " can", " cannot", " are", " aren", " will", " won", "0", "1",
            "2", "3", "4", "5", "6", "7", "8", "9", '."', ',"', "====", " ",
            "'t've", "'s", "'t", "'ve", "'n", "'ll", "'d", "'re", "'m",
        ],
        "llama3": [
            '"', "'", ")", "*", ",", "-", ".", "0", "1", "2", "3", "4",
            "5", "6", "7", "8", "9", ":", "\n", " ", " a", " the", " in", " to",
            " of", " and", " is", " on", " that", " an", " or", " from", " at", " are",
            " not", " was", "'s", "====", " have", " can", " will", " has", " his", " their",
            " no", "'t", " had", " were", " would", " her", " its", '."', " should", ',"',
            " could", " him", " did", " does", "'re", " won", "'m", "'ve", " doesn", " didn",
            "'ll", " cannot", "'d", " isn", " wasn", " aren", " couldn", " wouldn", " haven", " hers",
            " hasn", " shouldn", " weren", " hadn", "'n",
        ],
    }

    _DINKUS_ASTERISM = {
        "gpt2": BiasGroup(-0.12).add("***", "⁂"),
        "gpt2-genji": None,
        "pile": BiasGroup(-0.12).add("***"),
        "nerdstash_v1": BiasGroup(-0.08).add("***", "⁂"),
        "nerdstash_v2": BiasGroup(-0.08).add("***", "⁂"),
        "llama3": None,
    }

    _DEFAULT_SETTINGS = {
        "generate_until_sentence": False,
        "num_logprobs": 10,
        "ban_brackets": True,
        "bias_dinkus_asterism": False,
        "ban_ambiguous_genji_tokens": True,
        "rep_pen_whitelist": True,
    }

    # type completion for __setitem__ and __getitem__
    if TYPE_CHECKING:
        #: Generate up to 20 tokens after max_length if an end of sentence is found within these 20 tokens
        generate_until_sentence: bool
        #: Number of logprobs to return for each token. Set to NO_LOGPROBS to disable
        num_logprobs: int
        #: Apply the BRACKET biases
        ban_brackets: bool
        #: Apply the DINKUS_ASTERISM biases
        bias_dinkus_asterism: bool
        #: Apply the GENJI_AMBIGUOUS_TOKENS if model is Genji
        ban_ambiguous_genji_tokens: bool
        #: Apply the REP_PEN_WHITELIST (repetition penalty whitelist)
        rep_pen_whitelist: bool

    #: Value to set num_logprobs to, to disable logprobs
    NO_LOGPROBS = -1

    _settings: Dict[str, Any]
    @expand_kwargs(_DEFAULT_SETTINGS.keys(), (type(e) for e in _DEFAULT_SETTINGS.values()))
    def __init__(self, **kwargs):
        object.__setattr__(self, "_settings", {})

        for setting, default in self._DEFAULT_SETTINGS.items():
            self._settings[setting] = kwargs.pop(setting, default)

        if kwargs:
            raise ValueError(f"Invalid global setting name: {', '.join(kwargs)}")
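    # Usage sketch (illustrative comment, not part of the original source): every
    # keyword accepted by __init__ must be a key of _DEFAULT_SETTINGS; anything
    # left over in kwargs raises.
    #
    #   gs = GlobalSettings(ban_brackets=False, num_logprobs=GlobalSettings.NO_LOGPROBS)
    #   GlobalSettings(foo=1)  # ValueError: Invalid global setting name: foo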
    def __setitem__(self, key: str, value: Any) -> None:
        if key not in self._DEFAULT_SETTINGS:
            raise ValueError(f"Invalid setting: '{key}'")

        self._settings[key] = value

    def __getitem__(self, key: str) -> Any:
        if key not in self._DEFAULT_SETTINGS:
            raise ValueError(f"Invalid setting: '{key}'")

        return self._settings[key]

    # give dot access capabilities to the object
    def __setattr__(self, key, value):
        if key in self._DEFAULT_SETTINGS:
            self[key] = value
        else:
            object.__setattr__(self, key, value)

    def __getattr__(self, key):
        if key in self._DEFAULT_SETTINGS:
            return self[key]

        return object.__getattribute__(self, key)
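    # Illustration (added comment, not in the original source): the dunder methods
    # above make dict-style and attribute-style access interchangeable, both going
    # through the same validation against _DEFAULT_SETTINGS.
    #
    #   gs = GlobalSettings()
    #   gs["bias_dinkus_asterism"] = True  # __setitem__, validated
    #   gs.bias_dinkus_asterism            # __getattr__ -> True
    #   gs.unrelated_attr = 1              # not a setting: stored as a plain attribute
    #   gs["unrelated_attr"]               # ValueError: Invalid setting: 'unrelated_attr'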
    def copy(self):
        """
        Create a new GlobalSettings object from the current one
        """

        return GlobalSettings(**self._settings)
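    # Example (illustrative comment, not part of the original source): copy()
    # builds an independent object, so mutating the copy leaves the original intact.
    #
    #   base = GlobalSettings(ban_brackets=False)
    #   derived = base.copy()
    #   derived.ban_brackets = True  # base.ban_brackets is still False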
    def to_settings(self, model: Model) -> Dict[str, Any]:
        """
        Create text generation settings from the GlobalSettings object

        :param model: Model the settings will be used with
        """

        settings = {
            "generate_until_sentence": self._settings["generate_until_sentence"],
            "num_logprobs": self._settings["num_logprobs"],
            "bad_words_ids": [],
            "logit_bias_exp": [],
            "repetition_penalty_whitelist": [],
            "return_full_text": False,
            "use_string": False,
            "use_cache": False,
        }

        # NO_LOGPROBS is used to disable logprobs (=> not in the call)
        if self._settings["num_logprobs"] == self.NO_LOGPROBS:
            del settings["num_logprobs"]

        tokenizer_name = Tokenizer.get_tokenizer_name(model)

        if self._settings["ban_brackets"]:
            settings["bad_words_ids"].extend(self._BRACKETS[tokenizer_name])
            settings["bracket_ban"] = True

        if self._settings["ban_ambiguous_genji_tokens"] and tokenizer_name == "gpt2-genji":
            settings["bad_words_ids"].extend(self._GENJI_AMBIGUOUS_TOKENS)

        if self._settings["bias_dinkus_asterism"]:
            assert (
                tokenizer_name in self._DINKUS_ASTERISM
            ), f"Tokenizer {tokenizer_name} not supported with bias_dinkus_asterism"

            bias = self._DINKUS_ASTERISM[tokenizer_name]
            if bias is not None:
                settings["logit_bias_exp"].extend(bias.get_tokenized_entries(model))

        if self._settings["rep_pen_whitelist"]:
            settings["repetition_penalty_whitelist"].extend(
                Tokenizer.encode(model, tok) for tok in self._REP_PEN_WHITELIST[tokenizer_name]
            )

        return settings
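
A short end-to-end sketch of how this class is typically consumed follows. It is illustrative only and not part of the module; in particular, Model.Kayra is assumed to be an available member of novelai_api.Preset.Model.

from novelai_api.GlobalSettings import GlobalSettings
from novelai_api.Preset import Model

# disable logprobs and enable the dinkus/asterism bias
global_settings = GlobalSettings(
    num_logprobs=GlobalSettings.NO_LOGPROBS,
    bias_dinkus_asterism=True,
)

# resolve the account-wide settings against a concrete model/tokenizer
payload = global_settings.to_settings(Model.Kayra)  # Model.Kayra assumed available

assert "num_logprobs" not in payload   # dropped because of NO_LOGPROBS
assert payload["bracket_ban"] is True  # ban_brackets defaults to True

The banned-token lists and the repetition penalty whitelist in the payload are resolved per tokenizer, which is why to_settings() needs the model as an argument.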