3 changes: 3 additions & 0 deletions merge_models.py
@@ -112,6 +112,7 @@
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False),
default="INFO",
)
@click.option("-xl", "--sdxl", "sdxl", is_flag=True)
def main(
model_a,
model_b,
@@ -138,6 +139,7 @@ def main(
presets_alpha_lambda,
presets_beta_lambda,
logging_level,
sdxl,
):
if logging_level:
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging_level)
@@ -158,6 +160,7 @@ def main(
block_weights_preset_beta_b,
presets_alpha_lambda,
presets_beta_lambda,
sdxl,
)

merged = merge_models(
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sd-meh"
version = "0.9.4"
version = "0.10.0"
description = "stable diffusion merging execution helper"
authors = ["s1dlx <[email protected]>"]
license = "MIT"
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,4 +4,5 @@ safetensors
torch
tqdm
tensordict
cupy
scipy
2 changes: 1 addition & 1 deletion sd_meh/__init__.py
@@ -1 +1 @@
__version__ = "0.9.4"
__version__ = "0.10.0"
106 changes: 101 additions & 5 deletions sd_meh/merge.py
@@ -2,10 +2,12 @@
import logging
import os
import re
from typing import Tuple
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Optional, Tuple
import numpy as np

import safetensors.torch
import torch
@@ -21,13 +23,18 @@
weight_matching,
)


logging.getLogger("sd_meh").addHandler(logging.NullHandler())
MAX_TOKENS = 77
NUM_INPUT_BLOCKS = 12
NUM_MID_BLOCK = 1
NUM_OUTPUT_BLOCKS = 12
NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS

NUM_INPUT_BLOCKS_XL = 9
NUM_OUTPUT_BLOCKS_XL = 9
NUM_TOTAL_BLOCKS_XL = NUM_INPUT_BLOCKS_XL + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS_XL

KEY_POSITION_IDS = ".".join(
[
"cond_stage_model",
@@ -141,9 +148,16 @@ def merge_models(
work_device: Optional[str] = None,
prune: bool = False,
threads: int = 1,
sum_mode: str = "normal",
diff_mode: str = "normal",
) -> Dict:
thetas = load_thetas(models, prune, device, precision)

sdxl = (
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight"
in thetas["model_a"].keys()
)

logging.info(f"start merging with {merge_mode} method")
if re_basin:
merged = rebasin_merge(
@@ -157,6 +171,9 @@
device=device,
work_device=work_device,
threads=threads,
sdxl=sdxl,
sum_mode=sum_mode,
diff_mode=diff_mode,
)
else:
merged = simple_merge(
@@ -169,6 +186,9 @@
device=device,
work_device=work_device,
threads=threads,
sdxl=sdxl,
sum_mode=sum_mode,
diff_mode=diff_mode,
)

return un_prune_model(merged, thetas, models, device, prune, precision)
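
The sdxl flag is not taken from the caller here but inferred from the checkpoint itself: only SDXL models carry a second text encoder under conditioner.embedders.1, so probing the state dict for one of its keys distinguishes the architectures. A minimal sketch of the same check, with state_dict standing in for a loaded theta dict:

SDXL_PROBE_KEY = (
    "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight"
)

def is_sdxl(state_dict: dict) -> bool:
    # SD1.x/2.x checkpoints have a single text encoder and never contain
    # conditioner.embedders.1.* keys, so membership alone is a reliable signal
    return SDXL_PROBE_KEY in state_dict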
@@ -221,8 +241,16 @@ def simple_merge(
device: str = "cpu",
work_device: Optional[str] = None,
threads: int = 1,
sdxl: bool = False,
sum_mode: str = "normal",
diff_mode: str = "normal",
) -> Dict:
futures = []
sim = None
sims = None
if sum_mode in ["cos_a", "cos_b"]:
sim, sims = get_cos_similarity(thetas, sum_mode)

with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress:
with ThreadPoolExecutor(max_workers=threads) as executor:
for key in thetas["model_a"].keys():
@@ -238,6 +266,11 @@
weights_clip,
device,
work_device,
sdxl,
sum_mode,
diff_mode,
sim,
sims,
)
futures.append(future)

@@ -270,6 +303,9 @@ def rebasin_merge(
device="cpu",
work_device=None,
threads: int = 1,
sdxl: bool = False,
sum_mode: str = "normal",
diff_mode: str = "normal",
):
# WARNING: not sure how this behaves when 3 models are involved...

@@ -299,6 +335,9 @@
device,
work_device,
threads,
sdxl,
sum_mode,
diff_mode,
)

log_vram("simple merge done")
@@ -367,6 +406,11 @@ def merge_key(
weights_clip: bool = False,
device: str = "cpu",
work_device: Optional[str] = None,
sdxl: bool = False,
sum_mode: str = "normal",
diff_mode: str = "normal",
sim=None,
sims=None,
) -> Optional[Tuple[str, Dict]]:
if work_device is None:
work_device = device
@@ -391,16 +433,22 @@
if "time_embed" in key:
weight_index = 0 # before input blocks
elif ".out." in key:
weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks
weight_index = (
NUM_TOTAL_BLOCKS_XL - 1 if sdxl else NUM_TOTAL_BLOCKS - 1
) # after output blocks
elif m := re_inp.search(key):
weight_index = int(m.groups()[0])
elif re_mid.search(key):
weight_index = NUM_INPUT_BLOCKS
weight_index = NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS
elif m := re_out.search(key):
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + int(m.groups()[0])
weight_index = (
(NUM_INPUT_BLOCKS_XL if sdxl else NUM_INPUT_BLOCKS)
+ NUM_MID_BLOCK
+ int(m.groups()[0])
)

if weight_index >= NUM_TOTAL_BLOCKS:
raise ValueError(f"illegal block index {key}")
if weight_index >= (NUM_TOTAL_BLOCKS_XL if sdxl else NUM_TOTAL_BLOCKS):
raise ValueError(f"illegal block index {weight_index} for key {key}")

if weight_index >= 0:
current_bases = {k: w[weight_index] for k, w in weights.items()}
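
The indexing above flattens the U-Net into input blocks, one mid block, then output blocks: 12 + 1 + 12 = 25 per-block weights for SD1.x versus 9 + 1 + 9 = 19 for SDXL. A worked sketch of the arithmetic for an output-block key (the key layout follows the regexes in the code above):

NUM_MID_BLOCK = 1

def output_block_index(block: int, sdxl: bool) -> int:
    # e.g. "model.diffusion_model.output_blocks.3.0..." gives block = 3
    num_input_blocks = 9 if sdxl else 12
    return num_input_blocks + NUM_MID_BLOCK + block

assert output_block_index(3, sdxl=False) == 16  # SD1.x: 12 + 1 + 3
assert output_block_index(3, sdxl=True) == 13   # SDXL:   9 + 1 + 3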
@@ -460,6 +508,8 @@ def get_merge_method_args(
thetas: Dict,
key: str,
work_device: str,
sim=None,
sims=None,
) -> Dict:
merge_method_args = {
"a": thetas["model_a"][key].to(work_device),
@@ -470,6 +520,10 @@
if "model_c" in thetas:
merge_method_args["c"] = thetas["model_c"][key].to(work_device)

if sim is not None:
merge_method_args["sim"] = sim
merge_method_args["sims"] = sims

return merge_method_args
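
The diff never shows how a merge method consumes sim and sims, so the following is only a guess at the intended shape: a hypothetical cosine-weighted sum that rescales alpha by where the current tensor pair sits inside the stage-0 similarity band. The function name and formula are assumptions for illustration, not part of this PR:

import torch

def weighted_sum_cos(a, b, alpha, sim, sims, **kwargs):
    # sim is the CosineSimilarity module from stage 0; sims holds the
    # percentile-clipped distribution of per-key similarities (see below)
    span = sims.max() - sims.min() + 1e-12
    k = (sim(a.float(), b.float()).mean().item() - sims.min()) / span
    k = min(max(k, 0.0), 1.0)  # clamp into [0, 1]
    return a * (1 - alpha * k) + b * alpha * k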


@@ -483,3 +537,45 @@ def save_model(model, output_file, file_format) -> None:
)
else:
torch.save({"state_dict": model}, f"{output_file}.ckpt")
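
save_model picks the container from file_format: the safetensors branch (elided above) goes through safetensors.torch, and anything else falls through to a classic checkpoint. A usage sketch with assumed path and argument values:

save_model(merged, "output/merged_model", "safetensors")
# -> writes output/merged_model.safetensors; any other file_format value
#    (e.g. "ckpt") writes output/merged_model.ckpt via torch.save instead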


def get_cos_similarity(thetas, sum_mode):
sim = torch.nn.CosineSimilarity(dim=0)
sims = np.array([], dtype=np.float64)
for key in tqdm(thetas["model_a"].keys(), desc="stage 0"):
# skip VAE model parameters to get better results
if "first_stage_model" in key:
continue
if "model" in key and key in thetas["model_b"].keys():
if sum_mode == "cos_a":
theta_A_norm = torch.nn.functional.normalize(
thetas["model_a"][key].to(torch.float32), p=2, dim=0
)
theta_B_norm = torch.nn.functional.normalize(
thetas["model_b"][key].to(torch.float32), p=2, dim=0
)
simab = sim(theta_A_norm, theta_B_norm)
sims = np.append(sims, simab.cpu().numpy())
elif sum_mode == "cos_b":
simab = sim(
thetas["model_a"][key].to(torch.float32),
thetas["model_b"][key].to(torch.float32),
)
dot_product = torch.dot(
thetas["model_a"][key].view(-1).to(torch.float32),
thetas["model_b"][key].view(-1).to(torch.float32),
)
magnitude_similarity = dot_product / (
torch.norm(thetas["model_a"][key].to(torch.float32))
* torch.norm(thetas["model_b"][key].to(torch.float32))
)
combined_similarity = (simab + magnitude_similarity) / 2.0
sims = np.append(sims, combined_similarity.cpu().numpy())
sims = np.delete(
sims, np.where(sims < np.percentile(sims, 1, method="midpoint"))
)
sims = np.delete(
sims, np.where(sims > np.percentile(sims, 99, method="midpoint"))
)
log_vram("after stage 0")
return sim, sims
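
For orientation, the two np.delete calls clip the collected similarities to the 1st-99th percentile band so outlier layers do not dominate whatever downstream scaling uses the range (that consumer is not shown in this diff). A minimal usage sketch mirroring the stage-0 call in simple_merge:

sim, sims = get_cos_similarity(thetas, sum_mode="cos_a")
print(f"similarity band after clipping: {sims.min():.4f} .. {sims.max():.4f}")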