from typing import TYPE_CHECKING, Any, Dict
from novelai_api.BiasGroup import BiasGroup
from novelai_api.Preset import Model
from novelai_api.python_utils import expand_kwargs
from novelai_api.Tokenizer import Tokenizer
class GlobalSettings:
"""
Object used to store global settings for the account
"""
# TODO: store bracket ban in a file
_BRACKETS = {
"gpt2": [
[58],
[60],
[90],
[92],
[685],
[1391],
[1782],
[2361],
[3693],
[4083],
[4357],
[4895],
[5512],
[5974],
[7131],
[8183],
[8351],
[8762],
[8964],
[8973],
[9063],
[11208],
[11709],
[11907],
[11919],
[12878],
[12962],
[13018],
[13412],
[14631],
[14692],
[14980],
[15090],
[15437],
[16151],
[16410],
[16589],
[17241],
[17414],
[17635],
[17816],
[17912],
[18083],
[18161],
[18477],
[19629],
[19779],
[19953],
[20520],
[20598],
[20662],
[20740],
[21476],
[21737],
[22133],
[22241],
[22345],
[22935],
[23330],
[23785],
[23834],
[23884],
[25295],
[25597],
[25719],
[25787],
[25915],
[26076],
[26358],
[26398],
[26894],
[26933],
[27007],
[27422],
[28013],
[29164],
[29225],
[29342],
[29565],
[29795],
[30072],
[30109],
[30138],
[30866],
[31161],
[31478],
[32092],
[32239],
[32509],
[33116],
[33250],
[33761],
[34171],
[34758],
[34949],
[35944],
[36338],
[36463],
[36563],
[36786],
[36796],
[36937],
[37250],
[37913],
[37981],
[38165],
[38362],
[38381],
[38430],
[38892],
[39850],
[39893],
[41832],
[41888],
[42535],
[42669],
[42785],
[42924],
[43839],
[44438],
[44587],
[44926],
[45144],
[45297],
[46110],
[46570],
[46581],
[46956],
[47175],
[47182],
[47527],
[47715],
[48600],
[48683],
[48688],
[48874],
[48999],
[49074],
[49082],
[49146],
[49946],
[10221],
[4841],
[1427],
[2602, 834],
[29343],
[37405],
[35780],
[2602],
[50256],
],
"gpt2-genji": [],
"pile": [
[60],
[62],
[544],
[683],
[696],
[880],
[905],
[1008],
[1019],
[1084],
[1092],
[1181],
[1184],
[1254],
[1447],
[1570],
[1656],
[2194],
[2470],
[2479],
[2498],
[2947],
[3138],
[3291],
[3455],
[3725],
[3851],
[3891],
[3921],
[3951],
[4207],
[4299],
[4622],
[4681],
[5013],
[5032],
[5180],
[5218],
[5290],
[5413],
[5456],
[5709],
[5749],
[5774],
[6038],
[6257],
[6334],
[6660],
[6904],
[7082],
[7086],
[7254],
[7444],
[7748],
[8001],
[8088],
[8168],
[8562],
[8605],
[8795],
[8850],
[9014],
[9102],
[9259],
[9318],
[9336],
[9502],
[9686],
[9793],
[9855],
[9899],
[9955],
[10148],
[10174],
[10943],
[11326],
[11337],
[11661],
[12004],
[12084],
[12159],
[12520],
[12977],
[13380],
[13488],
[13663],
[13811],
[13976],
[14412],
[14598],
[14767],
[15640],
[15707],
[15775],
[15830],
[16079],
[16354],
[16369],
[16445],
[16595],
[16614],
[16731],
[16943],
[17278],
[17281],
[17548],
[17555],
[17981],
[18022],
[18095],
[18297],
[18413],
[18736],
[18772],
[18990],
[19181],
[20095],
[20197],
[20481],
[20629],
[20871],
[20879],
[20924],
[20977],
[21375],
[21382],
[21391],
[21687],
[21810],
[21828],
[21938],
[22367],
[22372],
[22734],
[23405],
[23505],
[23734],
[23741],
[23781],
[24237],
[24254],
[24345],
[24430],
[25416],
[25896],
[26119],
[26635],
[26842],
[26991],
[26997],
[27075],
[27114],
[27468],
[27501],
[27618],
[27655],
[27720],
[27829],
[28052],
[28118],
[28231],
[28532],
[28571],
[28591],
[28653],
[29013],
[29547],
[29650],
[29925],
[30522],
[30537],
[30996],
[31011],
[31053],
[31096],
[31148],
[31258],
[31350],
[31379],
[31422],
[31789],
[31830],
[32214],
[32666],
[32871],
[33094],
[33376],
[33440],
[33805],
[34368],
[34398],
[34417],
[34418],
[34419],
[34476],
[34494],
[34607],
[34758],
[34761],
[34904],
[34993],
[35117],
[35138],
[35237],
[35487],
[35830],
[35869],
[36033],
[36134],
[36320],
[36399],
[36487],
[36586],
[36676],
[36692],
[36786],
[37077],
[37594],
[37596],
[37786],
[37982],
[38475],
[38791],
[39083],
[39258],
[39487],
[39822],
[40116],
[40125],
[41000],
[41018],
[41256],
[41305],
[41361],
[41447],
[41449],
[41512],
[41604],
[42041],
[42274],
[42368],
[42696],
[42767],
[42804],
[42854],
[42944],
[42989],
[43134],
[43144],
[43189],
[43521],
[43782],
[44082],
[44162],
[44270],
[44308],
[44479],
[44524],
[44965],
[45114],
[45301],
[45382],
[45443],
[45472],
[45488],
[45507],
[45564],
[45662],
[46265],
[46267],
[46275],
[46295],
[46462],
[46468],
[46576],
[46694],
[47093],
[47384],
[47389],
[47446],
[47552],
[47686],
[47744],
[47916],
[48064],
[48167],
[48392],
[48471],
[48664],
[48701],
[49021],
[49193],
[49236],
[49550],
[49694],
[49806],
[49824],
[50001],
[50256],
[0],
[1],
],
"nerdstash_v1": [
[3],
[49356],
[1431],
[31715],
[34387],
[20765],
[30702],
[10691],
[49333],
[1266],
[19438],
[43145],
[26523],
[41471],
[2936],
[85, 85],
[49332],
[7286],
[1115],
],
"nerdstash_v2": [
[3],
[49356],
[1431],
[31715],
[34387],
[20765],
[30702],
[10691],
[49333],
[1266],
[19438],
[43145],
[26523],
[41471],
[2936],
[85, 85],
[49332],
[7286],
[1115],
],
"llama3": [
[16067],
[933, 11144],
[25106, 11144],
[58, 106901, 16073, 33710, 25, 109933],
[933, 58, 11144],
[128030],
[58, 30591, 33503, 17663, 100204, 25, 11144],
],
}
# this ban list is effectively mandatory: without it, Genji throws errors all the time
# NOTE: the list may cover more tokens than strictly necessary, but better safe than sorry
_GENJI_AMBIGUOUS_TOKENS = [
[58],
[60],
[90],
[92],
[685],
[1391],
[1782],
[2361],
[3693],
[4083],
[4357],
[4895],
[5512],
[5974],
[7131],
[8183],
[8964],
[8973],
[11208],
[11709],
[11907],
[12878],
[12962],
[13412],
[14692],
[14980],
[15090],
[15437],
[16151],
[16589],
[17241],
[17414],
[17635],
[17816],
[17912],
[18083],
[18161],
[18477],
[19629],
[19953],
[20520],
[20662],
[20740],
[21737],
[22241],
[23330],
[23834],
[26076],
[26398],
[26894],
[26933],
[27007],
[27422],
[29164],
[29795],
[30072],
[30109],
[30866],
[31478],
[32239],
[33250],
[34758],
[34949],
[36463],
[36786],
[36796],
[36937],
[37981],
[38165],
[38362],
[38381],
[38430],
[38892],
[39850],
[41832],
[41888],
[42535],
[42669],
[42924],
[43839],
[44587],
[45297],
[46956],
[47527],
[47715],
[48688],
[48874],
[48999],
[49146],
[49946],
[50256],
[8162],
[198, 198, 198],
[49564, 198, 49564],
[37605],
[22522],
[40265],
[5099],
[39752],
[32368],
[49564, 49564, 49564, 49564, 49564, 49564],
[31857, 31857, 31857, 31857, 31857, 31857],
[17992],
[39187],
[40367],
[15790],
[47571],
[27032],
[628, 198],
[628, 628],
[49564, 198, 49564],
[28156],
[30298],
[34650],
]
# repetition penalty whitelist, per tokenizer
_REP_PEN_WHITELIST = {
"gpt2": [],
"gpt2-genji": [],
"pile": [],
"nerdstash_v1": [
"'",
'"',
",",
".",
":",
"\n",
"ve",
"s",
"t",
"n",
"d",
"ll",
"re",
"m",
"-",
"*",
")",
" the",
" a",
" an",
" and",
" or",
" not",
" no",
" is",
" was",
" were",
" did",
" does",
" isn",
" wasn",
" weren",
" didn",
" doesn",
" him",
" her",
" his",
" hers",
" their",
" its",
" could",
" couldn",
" should",
" shouldn",
" would",
" wouldn",
" have",
" haven",
" had",
" hadn",
" has",
" hasn",
" can",
" cannot",
" are",
" aren",
" will",
" won",
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'."',
',"',
"====",
" ",
],
"nerdstash_v2": [
"'",
'"',
",",
".",
":",
"\n",
"-",
"*",
")",
" the",
" a",
" an",
" and",
" or",
" not",
" no",
" is",
" was",
" were",
" did",
" does",
" isn",
" wasn",
" weren",
" didn",
" doesn",
" him",
" her",
" his",
" hers",
" their",
" its",
" could",
" couldn",
" should",
" shouldn",
" would",
" wouldn",
" have",
" haven",
" had",
" hadn",
" has",
" hasn",
" can",
" cannot",
" are",
" aren",
" will",
" won",
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
'."',
',"',
"====",
" ",
"'t've",
"'s",
"'t",
"'ve",
"'n",
"'ll",
"'d",
"'re",
"'m",
],
"llama3": [
'"',
"'",
")",
"*",
",",
"-",
".",
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
":",
"\n",
" ",
" a",
" the",
" in",
" to",
" of",
" and",
" is",
" on",
" that",
" an",
" or",
" from",
" at",
" are",
" not",
" was",
"'s",
"====",
" have",
" can",
" will",
" has",
" his",
" their",
" no",
"'t",
" had",
" were",
" would",
" her",
" its",
'."',
" should",
',"',
" could",
" him",
" did",
" does",
"'re",
" won",
"'m",
"'ve",
" doesn",
" didn",
"'ll",
" cannot",
"'d",
" isn",
" wasn",
" aren",
" couldn",
" wouldn",
" haven",
" hers",
" hasn",
" shouldn",
" weren",
" hadn",
"'n",
],
}
_DINKUS_ASTERISM = {
"gpt2": BiasGroup(-0.12).add("***", "⁂"),
"gpt2-genji": None,
"pile": BiasGroup(-0.12).add("***"),
"nerdstash_v1": BiasGroup(-0.08).add("***", "⁂"),
"nerdstash_v2": BiasGroup(-0.08).add("***", "⁂"),
"llama3": None,
}
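# A short sketch of how these bias groups behave (assuming only what is used
# in this file: BiasGroup(bias).add(*phrases) registers phrases biased by
# `bias`, and get_tokenized_entries(model) tokenizes them per model):
#
#   dinkus_bias = BiasGroup(-0.08).add("***", "⁂")
#   # to_settings() below turns the group into per-model token ids via
#   # dinkus_bias.get_tokenized_entries(model)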
_DEFAULT_SETTINGS = {
"generate_until_sentence": False,
"num_logprobs": 10,
"ban_brackets": True,
"bias_dinkus_asterism": False,
"ban_ambiguous_genji_tokens": True,
"rep_pen_whitelist": True,
}
# type hints so editors can complete the keys of __setitem__ and __getitem__
if TYPE_CHECKING:
#: Generate up to 20 tokens past max_length if an end of sentence is found within those 20 tokens
generate_until_sentence: bool
#: Number of logprobs to return for each token. Set to NO_LOGPROBS to disable
num_logprobs: int
#: Apply the BRACKET biases
ban_brackets: bool
#: Apply the DINKUS_ASTERISM biases
bias_dinkus_asterism: bool
#: Apply the GENJI_AMBIGUOUS_TOKENS if model is Genji
ban_ambiguous_genji_tokens: bool
#: Apply the REP_PEN_WHITELIST (repetition penalty whitelist)
rep_pen_whitelist: bool
#: Value to assign to num_logprobs to disable logprobs
NO_LOGPROBS = -1
_settings: Dict[str, Any]
@expand_kwargs(_DEFAULT_SETTINGS.keys(), (type(e) for e in _DEFAULT_SETTINGS.values()))
def __init__(self, **kwargs):
object.__setattr__(self, "_settings", {})
for setting, default in self._DEFAULT_SETTINGS.items():
self._settings[setting] = kwargs.pop(setting, default)
if kwargs:
raise ValueError(f"Invalid global setting name: {', '.join(kwargs)}")
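# Example usage (a sketch): settings are passed as keyword overrides of
# _DEFAULT_SETTINGS; any unknown name raises ValueError.
#
#   settings = GlobalSettings(ban_brackets=False, num_logprobs=GlobalSettings.NO_LOGPROBS)
#   GlobalSettings(not_a_setting=True)  # ValueError: Invalid global setting name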
def __setitem__(self, key: str, value: Any) -> None:
if key not in self._DEFAULT_SETTINGS:
raise ValueError(f"Invalid setting: '{key}'")
self._settings[key] = value
def __getitem__(self, key: str) -> Any:
if key not in self._DEFAULT_SETTINGS:
raise ValueError(f"Invalid setting: '{key}'")
return self._settings[key]
# give dot access capabilities to the object
def __setattr__(self, key, value):
if key in self._DEFAULT_SETTINGS:
self[key] = value
else:
object.__setattr__(self, key, value)
def __getattr__(self, key):
if key in self._DEFAULT_SETTINGS:
return self[key]
return object.__getattribute__(self, key)
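# Example usage (a sketch): item access and dot access are interchangeable
# for known settings; unknown keys raise ValueError either way.
#
#   settings = GlobalSettings()
#   settings["bias_dinkus_asterism"] = True
#   settings.rep_pen_whitelist = False
#   assert settings.bias_dinkus_asterism is True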
def copy(self):
"""
Create a new GlobalSettings object from the current one
"""
return GlobalSettings(**self._settings)
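# Example usage (a sketch): the copy is independent, since each instance
# owns its own _settings dict.
#
#   a = GlobalSettings(ban_brackets=False)
#   b = a.copy()
#   b.ban_brackets = True
#   assert a.ban_brackets is False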
def to_settings(self, model: Model) -> Dict[str, Any]:
"""
Create text generation settings from the GlobalSettings object
:param model: Model to build the settings for
"""
settings = {
"generate_until_sentence": self._settings["generate_until_sentence"],
"num_logprobs": self._settings["num_logprobs"],
"bad_words_ids": [],
"logit_bias_exp": [],
"repetition_penalty_whitelist": [],
"return_full_text": False,
"use_string": False,
"use_cache": False,
}
# NO_LOGPROBS disables logprobs (the key is omitted from the request)
if self._settings["num_logprobs"] == self.NO_LOGPROBS:
del settings["num_logprobs"]
tokenizer_name = Tokenizer.get_tokenizer_name(model)
if self._settings["ban_brackets"]:
settings["bad_words_ids"].extend(self._BRACKETS[tokenizer_name])
settings["bracket_ban"] = True
if self._settings["ban_ambiguous_genji_tokens"] and tokenizer_name == "gpt2-genji":
settings["bad_words_ids"].extend(self._GENJI_AMBIGUOUS_TOKENS)
if self._settings["bias_dinkus_asterism"]:
assert (
tokenizer_name in self._DINKUS_ASTERISM
), f"Tokenizer {tokenizer_name} not supported with bias_dinkus_asterism"
bias = self._DINKUS_ASTERISM[tokenizer_name]
if bias is not None:
settings["logit_bias_exp"].extend(bias.get_tokenized_entries(model))
if self._settings["rep_pen_whitelist"]:
settings["repetition_penalty_whitelist"].extend(
Tokenizer.encode(model, tok) for tok in self._REP_PEN_WHITELIST[tokenizer_name]
)
return settings
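# Example usage (a sketch; Model.Kayra is assumed to be a valid member of
# novelai_api.Preset.Model):
#
#   settings = GlobalSettings(bias_dinkus_asterism=True)
#   payload = settings.to_settings(Model.Kayra)
#   # payload["bad_words_ids"] now holds the bracket ban list for the model's
#   # tokenizer; payload["logit_bias_exp"] the tokenized dinkus/asterism bias.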