Source code for novelai_api.BanList

from typing import Dict, Iterable, List, Union

from novelai_api.Preset import Model
from novelai_api.utils import tokenize_if_not


[docs]class BanList: _sequences: List[Union[List[int], str]] enabled: bool
[docs] def __init__(self, *sequences: Union[List[int], str], enabled: bool = True): """ Create a ban list with the given elements. Elements can be string or tokenized strings Using tokenized strings is not recommended, for flexibility between tokenizers :param enabled: Is the ban list enabled """ self.enabled = enabled self._sequences = [] if sequences: self.add(*sequences)
[docs] def add( self, *sequences: Union[Dict[str, List[List[int]]], Dict[str, List[int]], List[int], str], ) -> "BanList": """ Add elements to the ban list. Elements can be string or tokenized strings Using tokenized strings is not recommended, for flexibility between tokenizers """ for i, sequence in enumerate(sequences): if "sequence" in sequence: sequence = sequence["sequence"] elif "sequences" in sequence: sequence = sequence["sequences"][0] if not isinstance(sequence, str): if not isinstance(sequence, list): raise ValueError( f"Expected type 'List[int]' for sequence #{i} of 'sequences', " f"but got '{type(sequence)}'" ) for j, s in enumerate(sequence): if not isinstance(s, int): raise ValueError( f"Expected type 'int' for item #{j} of sequence #{i} of 'sequences', " f"but got '{type(s)}': {sequence}" ) self._sequences.append(sequence) return self
def __iadd__(self, o: Union[List[int], str]) -> "BanList": """ Add elements to the ban list. Elements can be string or tokenized strings Using tokenized strings is not recommended, for flexibility between tokenizers """ self.add(o) return self def __iter__(self): """ Return an iterator on the stored sequences """ return self._sequences.__iter__()
[docs] def get_tokenized_entries(self, model: Model) -> Iterable[List[int]]: """ Return the tokenized sequences for the ban list, if it is enabled :param model: Model to use for tokenization """ return (tokenize_if_not(model, s) for s in self._sequences if self.enabled)
def __str__(self) -> str: return self._sequences.__str__()