Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ namespace Flux {
int nerf_depth = 4;
int nerf_max_freqs = 8;
bool use_x0 = false;
bool use_patch_size_32 = false;
bool fake_patch_size_x2 = false;
};

struct FluxParams {
Expand Down Expand Up @@ -786,8 +786,11 @@ namespace Flux {
Flux(FluxParams params)
: params(params) {
if (params.version == VERSION_CHROMA_RADIANCE) {
std::pair<int, int> kernel_size = {16, 16};
std::pair<int, int> stride = kernel_size;
std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
if (params.chroma_radiance_params.fake_patch_size_x2) {
kernel_size = {params.patch_size / 2, params.patch_size / 2};
}
std::pair<int, int> stride = kernel_size;

blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
params.hidden_size,
Expand Down Expand Up @@ -1082,7 +1085,7 @@ namespace Flux {
auto img = pad_to_patch_size(ctx, x);
auto orig_img = img;

if (params.chroma_radiance_params.use_patch_size_32) {
if (params.chroma_radiance_params.fake_patch_size_x2) {
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
Expand Down Expand Up @@ -1303,7 +1306,8 @@ namespace Flux {
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
}
int64_t head_dim = 0;
int64_t head_dim = 0;
int64_t actual_radiance_patch_size = -1;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (!starts_with(tensor_name, prefix))
Expand All @@ -1316,9 +1320,12 @@ namespace Flux {
flux_params.chroma_radiance_params.use_x0 = true;
}
if (tensor_name.find("__32x32__") != std::string::npos) {
LOG_DEBUG("using patch size 32 prediction");
flux_params.chroma_radiance_params.use_patch_size_32 = true;
flux_params.patch_size = 32;
LOG_DEBUG("using patch size 32");
flux_params.patch_size = 32;
}
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
actual_radiance_patch_size = pair.second.ne[0];
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
Expand Down Expand Up @@ -1351,6 +1358,11 @@ namespace Flux {
head_dim = pair.second.ne[0];
}
}
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
LOG_DEBUG("using fake x2 patch size");
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
}

flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);

Expand Down
Loading