| | |
| | |
| |
|
| | import os |
| | import re |
| | from typing import Optional, List, Type |
| | import torch |
| | from library import sdxl_original_unet |
| | from library.utils import setup_logging |
| |
|
setup_logging()
import logging

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Switches deciding which U-Net layers receive LLLite adapters.
# All of them are read by SdxlUNet2DConditionModelControlNetLLLite.apply_lllite().
# ---------------------------------------------------------------------------

# Do not attach adapters to layers under "input_blocks".
SKIP_INPUT_BLOCKS: bool = False

# Do not attach adapters to layers under "output_blocks".
SKIP_OUTPUT_BLOCKS: bool = True

# Ignore LLLiteConv2d layers, so only LLLiteLinear layers can get adapters.
SKIP_CONV2D: bool = False

# Search only inside "Transformer2DModel" modules. When False, "ResnetBlock2D",
# "Downsample2D" and "Upsample2D" modules are searched as well.
TRANSFORMER_ONLY: bool = True

# Keep only layers whose flattened name contains "attn1" or "attn2".
ATTN1_2_ONLY: bool = True

# Additionally drop attention output projections ("to_out"); consulted in
# apply_lllite together with ATTN1_2_ONLY.
ATTN_QKV_ONLY: bool = True

# Restrict adapters to "proj_out", the "attn1" to_k / to_v / to_out projections
# and "ff_net_2"; every other candidate layer is skipped.
ATTN1_ETC_ONLY: bool = False

# Highest "transformer_blocks" index that still receives adapters; None = no limit.
TRANSFORMER_MAX_BLOCK_INDEX: Optional[int] = None

# References to the un-patched classes, captured before
# replace_unet_linear_and_conv2d() rebinds Linear / Conv2d on
# sdxl_original_unet.torch.nn (presumably the shared torch module — confirm).
# The LLLite subclasses derive from these originals.
ORIGINAL_LINEAR = torch.nn.Linear
ORIGINAL_CONV2D = torch.nn.Conv2d
| |
|
| |
|
def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None:
    """Attach the ControlNet-LLLite sub-networks to ``module`` as ``lllite_*`` attributes.

    Created attributes (all ``torch.nn.Sequential``, registered as sub-modules):

    * ``lllite_conditioning1``: strided convolutions mapping a 3-channel conditioning
      image to ``cond_emb_dim`` channels. The total stride is 8, 16 or 32 for
      ``depth`` 1, 2 or 3, so ``depth`` must match the spatial resolution of the
      layer being adapted.
    * ``lllite_down``: ``Linear(in_dim, mlp_dim)`` + ReLU, applied to the wrapped
      layer's input.
    * ``lllite_mid``: ``Linear(mlp_dim + cond_emb_dim, mlp_dim)`` + ReLU, applied to
      the concatenation of the conditioning embedding and the ``lllite_down`` output.
    * ``lllite_up``: ``Linear(mlp_dim, in_dim)``. Its weight is zero-initialised; its
      bias keeps the default initialisation.

    Args:
        module: layer to extend.
        in_dim: feature/channel count of the wrapped layer's input.
        depth: 1, 2 or 3; selects the conditioning downsampling factor (8/16/32).
        cond_emb_dim: output channels of the conditioning network.
        mlp_dim: hidden width of the down/mid/up MLP.

    Raises:
        ValueError: if ``depth`` is not 1, 2 or 3. Any other value would produce a
            conditioning embedding with ``cond_emb_dim // 2`` channels, which does
            not fit ``lllite_mid``. Nothing is attached to ``module`` in that case.
    """
    # Validate before building anything, so a bad depth leaves `module` untouched.
    if depth not in (1, 2, 3):
        raise ValueError(f"LLLite depth must be 1, 2 or 3, got {depth!r}")

    # First stage is shared by every depth: 3 -> cond_emb_dim // 2 channels, stride 4.
    modules = [ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)]
    if depth == 1:
        # total stride 4 * 2 = 8
        modules.append(torch.nn.ReLU(inplace=True))
        modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
    elif depth == 2:
        # total stride 4 * 4 = 16
        modules.append(torch.nn.ReLU(inplace=True))
        modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
    else:
        # depth == 3: total stride 4 * 4 * 2 = 32
        modules.append(torch.nn.ReLU(inplace=True))
        modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
        modules.append(torch.nn.ReLU(inplace=True))
        modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))

    module.lllite_conditioning1 = torch.nn.Sequential(*modules)

    module.lllite_down = torch.nn.Sequential(
        ORIGINAL_LINEAR(in_dim, mlp_dim),
        torch.nn.ReLU(inplace=True),
    )
    module.lllite_mid = torch.nn.Sequential(
        ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim),
        torch.nn.ReLU(inplace=True),
    )
    module.lllite_up = torch.nn.Sequential(
        ORIGINAL_LINEAR(mlp_dim, in_dim),
    )

    # Zero the up-projection weight so that, at initialisation, the residual
    # added to the wrapped layer's input is just the up-projection bias.
    torch.nn.init.zeros_(module.lllite_up[0].weight)
| |
|
| |
|
class LLLiteLinear(ORIGINAL_LINEAR):
    """Linear layer with an optional ControlNet-LLLite residual added to its input.

    Until :meth:`set_lllite` is called this behaves exactly like the un-patched
    linear class it derives from (``ORIGINAL_LINEAR``). Once enabled, ``forward``
    computes ``linear(x + multiplier * up(mid(cat[cond_emb, down(x)])))``.
    """

    def __init__(self, in_features: int, out_features: int, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.enabled = False  # adapter stays off until set_lllite() is called

    def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
        """Build the adapter sub-modules and switch the adapter on.

        Args:
            depth: conditioning depth, forwarded to add_lllite_modules.
            cond_emb_dim: channel count of the conditioning embedding.
            name: flattened module name assigned by the caller.
            mlp_dim: hidden width of the adapter MLP.
            dropout: dropout probability used only in training mode; None disables it.
            multiplier: scale applied to the adapter output.
        """
        self.enabled = True
        self.lllite_name = name
        self.cond_emb_dim = cond_emb_dim
        self.dropout = dropout
        self.multiplier = multiplier

        add_lllite_modules(self, self.in_features, depth, cond_emb_dim, mlp_dim)

        # Has to be supplied through set_cond_image() before the next forward pass.
        self.cond_image = None

    def set_cond_image(self, cond_image):
        """Remember the conditioning image; it is re-encoded on every forward call."""
        self.cond_image = cond_image

    def forward(self, x):
        if self.enabled:
            # Encode the conditioning image and flatten it to token layout:
            # (batch, channels, height, width) -> (batch, height * width, channels).
            cond = self.lllite_conditioning1(self.cond_image)
            batch, channels, height, width = cond.shape
            cond = cond.view(batch, channels, height * width).permute(0, 2, 1)

            # Join conditioning tokens with the projected input along the feature axis.
            hidden = self.lllite_mid(torch.cat([cond, self.lllite_down(x)], dim=2))

            if self.training and self.dropout is not None:
                hidden = torch.nn.functional.dropout(hidden, p=self.dropout)

            residual = self.lllite_up(hidden) * self.multiplier
            x = x + residual

        # Either path ends in the original linear projection.
        return super().forward(x)
| |
|
| |
|
class LLLiteConv2d(ORIGINAL_CONV2D):
    """Conv2d layer with an optional ControlNet-LLLite residual added to its input.

    Until :meth:`set_lllite` is called this behaves exactly like the un-patched
    convolution class it derives from (``ORIGINAL_CONV2D``).
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.enabled = False  # adapter stays off until set_lllite() is called

    def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
        """Build the adapter sub-modules and switch the adapter on.

        Args:
            depth: conditioning depth, forwarded to add_lllite_modules.
            cond_emb_dim: channel count of the conditioning embedding.
            name: flattened module name assigned by the caller.
            mlp_dim: hidden width of the adapter MLP.
            dropout: dropout probability used only in training mode; None disables it.
            multiplier: scale applied to the adapter output.
        """
        self.enabled = True
        self.lllite_name = name
        self.cond_emb_dim = cond_emb_dim
        self.dropout = dropout
        self.multiplier = multiplier

        in_dim = self.in_channels
        add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)

        self.cond_image = None
        self.cond_emb = None  # reset alongside cond_image; not read by this class

    def set_cond_image(self, cond_image):
        """Remember the conditioning image; it is re-encoded on every forward call."""
        self.cond_image = cond_image
        self.cond_emb = None

    def forward(self, x):
        if not self.enabled:
            return super().forward(x)

        # Encode the conditioning image: (n, cond_emb_dim, h, w).
        # NOTE(review): assumes its h, w equal x's spatial size, i.e. the depth chosen
        # by the caller matches this layer's resolution — confirm per target layer.
        cx = self.lllite_conditioning1(self.cond_image)

        # add_lllite_modules builds the down/mid/up stages from Linear layers, which act
        # on the LAST dimension, so work channels-last: (n, c, h, w) -> (n, h, w, c).
        # (The previous code called non-existent self.down/self.mid/self.up attributes
        # and concatenated along the channels-first axis.)
        cx = cx.permute(0, 2, 3, 1)
        hx = self.lllite_down(x.permute(0, 2, 3, 1))
        cx = torch.cat([cx, hx], dim=3)
        cx = self.lllite_mid(cx)

        if self.dropout is not None and self.training:
            cx = torch.nn.functional.dropout(cx, p=self.dropout)

        cx = self.lllite_up(cx) * self.multiplier

        # Back to channels-first so the residual lines up with x.
        cx = cx.permute(0, 3, 1, 2)

        return super().forward(x + cx)
| |
|
| |
|
class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel):
    """SDXL U-Net with trainable ControlNet-LLLite adapters on selected layers.

    Candidate layers are found by class name ("LLLiteLinear" / "LLLiteConv2d") in
    :meth:`apply_lllite`. Presumably the U-Net therefore has to be constructed after
    replace_unet_linear_and_conv2d(), as the ``__main__`` demo does. Which candidates
    actually get an adapter is controlled by the module-level SKIP_* / *_ONLY flags.
    """

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    LLLITE_PREFIX = "lllite_unet"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def apply_lllite(
        self,
        cond_emb_dim: int = 16,
        mlp_dim: int = 16,
        dropout: Optional[float] = None,
        varbose: Optional[bool] = False,
        multiplier: Optional[float] = 1.0,
    ) -> None:
        """Enable LLLite on every eligible layer and store them in ``self.lllite_modules``.

        Args:
            cond_emb_dim: channels of the conditioning-image embedding.
            mlp_dim: hidden width of each adapter MLP.
            dropout: adapter dropout probability in training mode; None disables it.
            varbose: unused; kept with its original spelling for call compatibility.
            multiplier: scale applied to each adapter's output.
        """
        # Same prefix that save_lllite_weights / load_lllite_weights use for keys.
        prefix = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX

        def apply_to_modules(
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
        ) -> List[torch.nn.Module]:
            """Call set_lllite on each eligible LLLite layer under the target modules."""
            modules = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "LLLiteLinear"
                    is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d"
                    if not (is_linear or (is_conv2d and not SKIP_CONV2D)):
                        continue

                    # The first three dotted components of the full path are
                    # (block name, block index, sub-module index).
                    block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
                    index1 = int(index1)

                    # Deeper blocks get a deeper conditioning network
                    # (depth 1/2/3 -> conditioning downsampled by 8/16/32).
                    if block_name == "input_blocks":
                        if SKIP_INPUT_BLOCKS:
                            continue
                        depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
                    elif block_name == "middle_block":
                        depth = 3
                    elif block_name == "output_blocks":
                        if SKIP_OUTPUT_BLOCKS:
                            continue
                        depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
                        if int(index2) >= 2:
                            depth -= 1
                    else:
                        raise NotImplementedError()

                    lllite_name = prefix + "." + name + "." + child_name
                    lllite_name = lllite_name.replace(".", "_")

                    if TRANSFORMER_MAX_BLOCK_INDEX is not None:
                        p = lllite_name.find("transformer_blocks")
                        if p >= 0:
                            # "transformer_blocks_<i>_..." -> split("_")[2] is "<i>"
                            tf_index = int(lllite_name[p:].split("_")[2])
                            if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
                                continue

                    # Never adapt "emb_layers", nor attn2's to_k / to_v (presumably the
                    # projections of the cross-attention context, whose token count does
                    # not match the image tokens — confirm).
                    if "emb_layers" in lllite_name or (
                        "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
                    ):
                        continue

                    if ATTN1_2_ONLY:
                        if not ("attn1" in lllite_name or "attn2" in lllite_name):
                            continue
                        if ATTN_QKV_ONLY:
                            if "to_out" in lllite_name:
                                continue

                    if ATTN1_ETC_ONLY:
                        if "proj_out" in lllite_name:
                            pass
                        elif "attn1" in lllite_name and (
                            "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
                        ):
                            pass
                        elif "ff_net_2" in lllite_name:
                            pass
                        else:
                            continue

                    child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier)
                    modules.append(child_module)

            return modules

        target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE
        if not TRANSFORMER_ONLY:
            target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.lllite_modules = apply_to_modules(self, target_modules)
        logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")

    def prepare_params(self):
        """Freeze the base U-Net and return the trainable LLLite parameters.

        Parameters whose name contains "lllite" get requires_grad=True. All others are
        frozen, except the first one, which is deliberately re-enabled (see below).
        """
        train_params = []
        non_train_params = []
        for name, p in self.named_parameters():
            if "lllite" in name:
                train_params.append(p)
            else:
                non_train_params.append(p)
        logger.info(f"count of trainable parameters: {len(train_params)}")
        logger.info(f"count of non-trainable parameters: {len(non_train_params)}")

        for p in non_train_params:
            p.requires_grad_(False)

        # NOTE(review): one frozen parameter is re-enabled on purpose; this looks like a
        # workaround so the graph always contains a grad-requiring leaf (e.g. with
        # gradient checkpointing) — confirm before removing. It is not returned, so an
        # optimizer built from the return value does not update it.
        if non_train_params:
            non_train_params[0].requires_grad_(True)

        for p in train_params:
            p.requires_grad_(True)

        return train_params

    def get_trainable_params(self):
        """Return all parameters whose name contains "lllite"."""
        return [p[1] for p in self.named_parameters() if "lllite" in p[0]]

    def save_lllite_weights(self, file, dtype, metadata):
        """Save only the LLLite adapter weights.

        Keys ``<module.path>.lllite_<sub>.<param>`` are stored as
        ``lllite_unet_<module_path>.<sub>.<param>`` (module-path dots become "_").

        Args:
            file: destination; ``.safetensors`` uses safetensors, anything else torch.save.
            dtype: if not None, tensors are detached, copied to CPU and cast to it.
            metadata: safetensors metadata; an empty mapping is treated as None.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        org_state_dict = self.state_dict()
        prefix = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX

        state_dict = {}
        for key, value in org_state_dict.items():
            pos = key.find(".lllite")
            if pos < 0:
                continue  # not an LLLite parameter
            lllite_key = (prefix + "." + key[:pos]).replace(".", "_") + key[pos:]
            lllite_key = lllite_key.replace(".lllite_", ".")
            state_dict[lllite_key] = value

        if dtype is not None:
            for key in list(state_dict.keys()):
                state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype)

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def load_lllite_weights(self, file, non_lllite_unet_sd=None):
        r"""Load LLLite adapter weights and, optionally, the base U-Net weights.

        To keep the freshly initialised LLLite weights, pass a falsy ``file`` (e.g.
        None). In that case ``non_lllite_unet_sd`` should hold the U-Net state_dict to
        load; if it is None as well, the current weights are simply kept.

        Args:
            file: ``.safetensors`` file, or any file readable by ``torch.load``, in the
                key format written by save_lllite_weights.
            non_lllite_unet_sd: optional state_dict for the base (non-LLLite) weights.

        Returns:
            The result of ``load_state_dict(state_dict, False)`` (non-strict load).
        """
        if not file:
            state_dict = self.state_dict()
            if non_lllite_unet_sd is not None:
                for key in non_lllite_unet_sd:
                    if key in state_dict:
                        state_dict[key] = non_lllite_unet_sd[key]
            return self.load_state_dict(state_dict, False)

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            # NOTE(review): torch.load unpickles the file and can run arbitrary code;
            # only load trusted checkpoints through this path.
            weights_sd = torch.load(file, map_location="cpu")

        # Saved module paths were flattened with "_", which is ambiguous because U-Net
        # module names contain "_" themselves. The underscores inside these known
        # fragments are protected with "@" before the remaining "_" become ".".
        pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)")
        prefix = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX

        state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {}
        for key in weights_sd.keys():
            # "<flattened module path>.<weight name>"
            pos = key.find(".")
            if pos < 0:
                continue

            module_name = key[:pos]
            weight_name = key[pos + 1 :]
            module_name = module_name.replace(prefix + "_", "")

            # findall always returns a list, so no None check is needed.
            for m in pattern.findall(module_name):
                logger.debug(f"{module_name} {m}")
                module_name = module_name.replace(m, m.replace("_", "@"))
            module_name = module_name.replace("_", ".")
            module_name = module_name.replace("@", "_")

            lllite_key = module_name + ".lllite_" + weight_name
            state_dict[lllite_key] = weights_sd[key]

        return self.load_state_dict(state_dict, False)

    def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs):
        """Hand ``cond_image`` to every LLLite layer, then run the base U-Net.

        Before apply_lllite() has been called there are no LLLite layers, and this
        behaves like the base model.
        """
        for m in getattr(self, "lllite_modules", []):
            m.set_cond_image(cond_image)
        return super().forward(x, timesteps, context, y, **kwargs)
| |
|
| |
|
def replace_unet_linear_and_conv2d():
    """Rebind Linear / Conv2d, as seen by the U-Net module, to the LLLite subclasses.

    The ``__main__`` demo calls this before constructing the U-Net, so the U-Net's
    layers are created as LLLiteLinear / LLLiteConv2d. The patch targets
    ``sdxl_original_unet.torch.nn``. NOTE(review): that is presumably the shared
    torch module, in which case the patch is process-wide — confirm.
    """
    logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
    patched_nn = sdxl_original_unet.torch.nn
    patched_nn.Linear = LLLiteLinear
    patched_nn.Conv2d = LLLiteConv2d
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build the patched U-Net, enable LLLite and run a few optimizer
    # steps on random data (requires CUDA and bitsandbytes).

    # Patch first, so the U-Net is built from LLLite-capable layers.
    replace_unet_linear_and_conv2d()

    logger.info("create unet")
    unet = SdxlUNet2DConditionModelControlNetLLLite()

    logger.info("enable ControlNet-LLLite")
    unet.apply_lllite(cond_emb_dim=32, mlp_dim=64, dropout=None, varbose=False, multiplier=1.0)
    unet.to("cuda")

    params = unet.prepare_params()
    logger.info(f"number of parameters {sum(p.numel() for p in params)}")

    unet.set_use_memory_efficient_attention(True, False)
    unet.set_gradient_checkpointing(True)
    unet.train()

    import bitsandbytes

    optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    logger.info("start training")
    steps = 10
    batch_size = 1

    # One (name, parameter) pair from an up-projection, logged after every step.
    lllite_up_params = [item for item in unet.named_parameters() if ".lllite_up." in item[0]]
    sample_param = lllite_up_params[0]

    for step in range(steps):
        logger.info(f"step {step}")

        # Random inputs; the draw order is kept fixed.
        cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
        latents = torch.randn(batch_size, 4, 128, 128).cuda()
        timesteps = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
        text_ctx = torch.randn(batch_size, 77, 2048).cuda()
        vector_cond = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()

        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
            prediction = unet(latents, timesteps, text_ctx, vector_cond, cond_image=cond_img)
            noise_target = torch.randn_like(prediction)
            loss = torch.nn.functional.mse_loss(prediction, noise_target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        logger.info(sample_param)
| |
|
| | |
| |
|
| | |
| | |
| |
|