| """ | |
| Adapted from comfyui CLIP code. | |
| https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/sd1_clip.py | |
| """ | |
| import os | |
| from transformers import T5Tokenizer, T5EncoderModel, T5Config, modeling_utils | |
| import torch | |
| import traceback | |
| import zipfile | |
| from comfy import model_management | |
| from comfy.sd1_clip import parse_parentheses, token_weights, escape_important, unescape_important, safe_load_embed_zip, expand_directory_list, load_embed | |
class T5v11Model(torch.nn.Module):
    def __init__(self, textmodel_ver="xxl", textmodel_json_config=None, textmodel_path=None, device="cpu", max_length=120, freeze=True, dtype=None):
        super().__init__()
        self.num_layers = 24
        self.max_length = max_length
        self.bnb = False
        if textmodel_path is not None:
            model_args = {}
            model_args["low_cpu_mem_usage"] = True  # Don't take 2x system ram on cpu
            if dtype == "bnb8bit":
                self.bnb = True
                model_args["load_in_8bit"] = True
            elif dtype == "bnb4bit":
                self.bnb = True
                model_args["load_in_4bit"] = True
            else:
                if dtype:
                    model_args["torch_dtype"] = dtype
                self.bnb = False
            # second GPU offload hack part 2
            if device.startswith("cuda"):
                model_args["device_map"] = device
            print(f"Loading T5 from '{textmodel_path}'")
            self.transformer = T5EncoderModel.from_pretrained(textmodel_path, **model_args)
        else:
            if textmodel_json_config is None:
                textmodel_json_config = os.path.join(
                    os.path.dirname(os.path.realpath(__file__)),
                    f"t5v11-{textmodel_ver}_config.json"
                )
            config = T5Config.from_json_file(textmodel_json_config)
            self.num_layers = config.num_hidden_layers
            with modeling_utils.no_init_weights():
                self.transformer = T5EncoderModel(config)
        if freeze:
            self.freeze()
        self.empty_tokens = [[0] * self.max_length]  # <pad> token

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
    def forward(self, tokens):
        device = self.transformer.get_input_embeddings().weight.device
        tokens = torch.LongTensor(tokens).to(device)
        attention_mask = torch.zeros_like(tokens)
        max_token = 1  # </s> token
        # attend to every token up to and including the first </s>
        for x in range(attention_mask.shape[0]):
            for y in range(attention_mask.shape[1]):
                attention_mask[x, y] = 1
                if tokens[x, y] == max_token:
                    break
        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask)
        z = outputs['last_hidden_state']
        return z
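
    # Illustrative sketch, not called anywhere in this module: the mask-building
    # loop in forward() can also be expressed as a vectorized computation. The
    # name `_build_attention_mask` is a hypothetical helper added only to
    # document the equivalent calculation.
    def _build_attention_mask(self, tokens, end_token=1):
        is_end = (tokens == end_token).long()
        # number of end tokens seen strictly before each position
        seen_before = torch.cumsum(is_end, dim=1) - is_end
        # positions up to and including the first </s> get attention
        return (seen_before == 0).long()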
    def encode(self, tokens):
        return self(tokens)

    def load_sd(self, sd):
        return self.transformer.load_state_dict(sd, strict=False)

    def to(self, *args, **kwargs):
        """BNB complains if you try to change the device or dtype"""
        if self.bnb:
            print("Thanks to BitsAndBytes, T5 becomes an immovable rock.", args, kwargs)
        else:
            self.transformer.to(*args, **kwargs)
        return self
    def encode_token_weights(self, token_weight_pairs, return_padded=False):
        to_encode = list(self.empty_tokens)
        for x in token_weight_pairs:
            tokens = list(map(lambda a: a[0], x))
            to_encode.append(tokens)
        out = self.encode(to_encode)
        z_empty = out[0:1]
        output = []
        for k in range(1, out.shape[0]):
            z = out[k:k+1]
            for i in range(len(z)):
                for j in range(len(z[i])):
                    weight = token_weight_pairs[k - 1][j][1]
                    z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
            output.append(z)
        if len(output) == 0:
            return z_empty.cpu()
        out = torch.cat(output, dim=-2)
        if not return_padded:
            # Count number of tokens that aren't <pad>, then use that number as an index.
            keep_index = sum([sum([1 for y in x if y[0] != 0]) for x in token_weight_pairs])
            out = out[:, :keep_index, :]
        return out
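

# A minimal sketch of the prompt-weighting math used in encode_token_weights():
# each token's hidden state is interpolated about the empty-prompt (all-<pad>)
# baseline, so weight 1.0 leaves it untouched and weight 0.0 collapses it onto
# the baseline. `_weighting_example` is a hypothetical helper for illustration
# only; nothing in this module calls it.
def _weighting_example():
    z = torch.randn(1, 3, 8)        # stand-in hidden states for 3 tokens
    z_empty = torch.randn(1, 3, 8)  # stand-in hidden states of the all-<pad> prompt
    weights = torch.tensor([1.0, 1.3, 0.7]).view(1, 3, 1)
    return (z - z_empty) * weights + z_empty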
class T5v11Tokenizer:
    """
    This is largely just based on the ComfyUI CLIP code.
    """
    def __init__(self, tokenizer_path=None, max_length=120, embedding_directory=None, embedding_size=4096, embedding_key='t5'):
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
        self.max_length = max_length
        self.max_tokens_per_section = self.max_length - 1  # </s> but no <BOS>
        self.pad_token = self.tokenizer("<pad>", add_special_tokens=False)["input_ids"][0]
        self.end_token = self.tokenizer("</s>", add_special_tokens=False)["input_ids"][0]
        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.embedding_directory = embedding_directory
        self.max_word_length = 8  # haven't verified this
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key

    def _try_get_embedding(self, embedding_name: str):
        '''
        Takes a potential embedding name and tries to retrieve it.
        Returns a tuple of (embedding, leftover string); the embedding can be None.
        '''
        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None:
            stripped = embedding_name.strip(',')
            if len(stripped) < len(embedding_name):
                embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
                return (embed, embedding_name[len(stripped):])
        return (embed, "")
    def tokenize_with_weights(self, text: str, return_word_ids=False):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can be either integer token ids or precomputed T5 tensors.
        Word id values are unique per word and embedding, with id 0 reserved for non-word tokens.
        The returned list has dimensions NxM, where M is the input size of T5.
        '''
        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # if we find an embedding, deal with the embedding
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
                    # if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                # parse word
                tokens.append([(t, weight) for t in self.tokenizer(word, add_special_tokens=False)["input_ids"]])

        # reshape token array to T5 input size
        batched_tokens = []
        batch = []
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length
            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    # break the word in two and add the end token
                    if is_large:
                        batch.extend([(t, w, i+1) for t, w in t_group[:remaining_length]])
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    # add the end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        batch.extend([(self.pad_token, 1.0, 0)] * remaining_length)
                    # start a new batch
                    batch = []
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i+1) for t, w in t_group])
                    t_group = []

        # fill the last batch
        batch.extend([(self.end_token, 1.0, 0)] + [(self.pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
        # instead of filling, just add EOS (DEBUG)
        # batch.extend([(self.end_token, 1.0, 0)])

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
        return batched_tokens
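
    # For illustration (the token ids below are placeholders, not real vocabulary
    # entries): with max_length=120 and return_word_ids=True, a prompt such as
    # "(photo:1.2) of a dog" comes back as one 120-entry batch of
    # (token, weight, word_id) triples, roughly
    #   [[(1391, 1.2, 1), (13, 1.0, 2), (3, 1.0, 3), (1712, 1.0, 4),
    #     (1, 1.0, 0), (0, 1.0, 0), ...]]
    # where 1 is </s>, 0 is <pad>, and word_id 0 marks the special tokens.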
    def untokenize(self, token_weight_pair):
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
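

# A minimal usage sketch, assuming the bundled "t5_tokenizer" directory sits next
# to this file and a local T5 v1.1 checkpoint is available; the checkpoint path
# below is a placeholder, not something this module provides.
if __name__ == "__main__":
    tokenizer = T5v11Tokenizer()
    pairs = tokenizer.tokenize_with_weights("a (weighted:1.3) prompt")
    model = T5v11Model(textmodel_path="/path/to/t5-v1_1-xxl", dtype=torch.float16)  # hypothetical local path
    cond = model.encode_token_weights(pairs)
    print(cond.shape)  # -> [1, <non-pad token count>, hidden_size]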