Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
321 changes: 321 additions & 0 deletions examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,327 @@ int main(int argc, const char** argv) {
}
});

// sdapi endpoints (AUTOMATIC1111 / Forge)

auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) {
try {
if (req.body.empty()) {
res.status = 400;
res.set_content(R"({"error":"empty body"})", "application/json");
return;
}

json j = json::parse(req.body);

std::string prompt = j.value("prompt", "");
std::string negative_prompt = j.value("negative_prompt", "");
int width = j.value("width", 512);
int height = j.value("height", 512);
int steps = j.value("steps", -1);
float cfg_scale = j.value("cfg_scale", 7.f);
int64_t seed = j.value("seed", -1);
int batch_size = j.value("batch_size", 1);
int clip_skip = j.value("clip_skip", -1);
std::string sampler_name = j.value("sampler_name", "");
std::string scheduler_name = j.value("scheduler", "");

auto bad = [&](const std::string& msg) {
res.status = 400;
res.set_content("{\"error\":\"" + msg + "\"}", "application/json");
return;
};

if (width <= 0 || height <= 0) {
return bad("width and height must be positive");
}

if (steps < 1 || steps > 150) {
return bad("steps must be in range [1, 150]");
}

if (batch_size < 1 || batch_size > 8) {
return bad("batch_size must be in range [1, 8]");
}

if (cfg_scale < 0.f) {
return bad("cfg_scale must be positive");
}

if (prompt.empty()) {
return bad("prompt required");
}

auto get_sample_method = [](std::string name) -> enum sample_method_t {
enum sample_method_t result = str_to_sample_method(name.c_str());
if (result != SAMPLE_METHOD_COUNT) return result;
// some applications use a hardcoded sampler list
std::transform(name.begin(), name.end(), name.begin(),
[](unsigned char c) { return std::tolower(c); });
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
{"euler a", EULER_A_SAMPLE_METHOD},
{"k_euler_a", EULER_A_SAMPLE_METHOD},
{"euler", EULER_SAMPLE_METHOD},
{"k_euler", EULER_SAMPLE_METHOD},
{"heun", HEUN_SAMPLE_METHOD},
{"k_heun", HEUN_SAMPLE_METHOD},
{"dpm2", DPM2_SAMPLE_METHOD},
{"k_dpm_2", DPM2_SAMPLE_METHOD},
{"lcm", LCM_SAMPLE_METHOD},
{"ddim", DDIM_TRAILING_SAMPLE_METHOD},
{"dpm++ 2m", DPMPP2M_SAMPLE_METHOD},
{"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD}};
auto it = hardcoded.find(name);
if (it != hardcoded.end()) return it->second;
return SAMPLE_METHOD_COUNT;
};

enum sample_method_t sample_method = get_sample_method(sampler_name);

enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());

// avoid excessive resource usage

SDGenerationParams gen_params = default_gen_params;
gen_params.prompt = prompt;
gen_params.negative_prompt = negative_prompt;
gen_params.width = width;
gen_params.height = height;
gen_params.seed = seed;
gen_params.sample_params.sample_steps = steps;
gen_params.batch_count = batch_size;

if (clip_skip > 0) {
gen_params.clip_skip = clip_skip;
}

if (sample_method != SAMPLE_METHOD_COUNT) {
gen_params.sample_params.sample_method = sample_method;
}

if (scheduler != SCHEDULER_COUNT) {
gen_params.sample_params.scheduler = scheduler;
}

LOG_DEBUG("%s\n", gen_params.to_string().c_str());

sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
std::vector<uint8_t> mask_data;
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images;

if (img2img) {
auto decode_image = [](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1);
}
std::vector<uint8_t> img_data = base64_decode(encoded);
if (!img_data.empty()) {
int img_w = image.width;
int img_h = image.height;
uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(),
img_w, img_h,
image.width, image.height, image.channel);
if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
return true;
}
}
return false;
};

if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
std::string encoded = j["init_images"][0].get<std::string>();
decode_image(init_image, encoded);
}

if (j.contains("mask") && j["mask"].is_string()) {
std::string encoded = j["mask"].get<std::string>();
decode_image(mask_image, encoded);
bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0;
if (inpainting_mask_invert && mask_image.data != nullptr) {
for (uint32_t i = 0; i < mask_image.width * mask_image.height; i++) {
mask_image.data[i] = 255 - mask_image.data[i];
}
}
} else {
mask_data = std::vector<uint8_t>(width * height, 255);
mask_image.width = width;
mask_image.height = height;
mask_image.channel = 1;
mask_image.data = mask_data.data();
}

if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>();
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
if (decode_image(tmp_image, encoded)) {
ref_images.push_back(tmp_image);
}
}
}

float denoising_strength = j.value("denoising_strength", -1.f);
if (denoising_strength >= 0.f) {
denoising_strength = std::min(denoising_strength, 1.0f);
gen_params.strength = denoising_strength;
}
}

sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
init_image,
ref_images.data(),
(int)ref_images.size(),
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.width,
gen_params.height,
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
gen_params.batch_count,
control_image,
gen_params.control_strength,
{
pmid_images.data(),
(int)pmid_images.size(),
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
ctx_params.vae_tiling_params,
gen_params.cache_params,
};

sd_image_t* results = nullptr;
int num_results = 0;

{
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}

json out;
out["images"] = json::array();
out["parameters"] = j; // TODO should return changed defaults
out["info"] = "";

for (int i = 0; i < num_results; i++) {
if (results[i].data == nullptr) {
continue;
}

auto image_bytes = write_image_to_vector(ImageFormat::PNG,
results[i].data,
results[i].width,
results[i].height,
results[i].channel);

if (image_bytes.empty()) {
LOG_ERROR("write image to mem failed");
continue;
}

std::string b64 = base64_encode(image_bytes);
out["images"].push_back(b64);
}

res.set_content(out.dump(), "application/json");
res.status = 200;

if (init_image.data) {
stbi_image_free(init_image.data);
}
if (mask_image.data && mask_data.empty()) {
stbi_image_free(mask_image.data);
}
for (auto ref_image : ref_images) {
stbi_image_free(ref_image.data);
}

} catch (const std::exception& e) {
res.status = 500;
json err;
err["error"] = "server_error";
err["message"] = e.what();
res.set_content(err.dump(), "application/json");
}
};

svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) {
sdapi_any2img(req, res, false);
});

svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) {
sdapi_any2img(req, res, true);
});

svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
std::vector<std::string> sampler_names;
sampler_names.push_back("default");
for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
sampler_names.push_back(sd_sample_method_name((sample_method_t)i));
}
json r = json::array();
for (auto name : sampler_names) {
json entry;
entry["name"] = name;
entry["aliases"] = json::array({name});
entry["options"] = json::object();
r.push_back(entry);
}
res.set_content(r.dump(), "application/json");
});

svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) {
std::vector<std::string> scheduler_names;
scheduler_names.push_back("default");
for (int i = 0; i < SCHEDULER_COUNT; i++) {
scheduler_names.push_back(sd_scheduler_name((scheduler_t)i));
}
json r = json::array();
for (auto name : scheduler_names) {
json entry;
entry["name"] = name;
entry["label"] = name;
r.push_back(entry);
}
res.set_content(r.dump(), "application/json");
});

svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) {
fs::path model_path = ctx_params.model_path;
json entry;
entry["title"] = model_path.stem();
entry["model_name"] = model_path.stem();
entry["filename"] = model_path.filename();
entry["hash"] = "8888888888";
entry["sha256"] = "8888888888888888888888888888888888888888888888888888888888888888";
entry["config"] = nullptr;
json r = json::array();
r.push_back(entry);
res.set_content(r.dump(), "application/json");
});

svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) {
fs::path model_path = ctx_params.model_path;
json r;
r["samples_format"] = "png";
r["sd_model_checkpoint"] = model_path.stem();
res.set_content(r.dump(), "application/json");
});

LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
svr.listen(svr_params.listen_ip, svr_params.listen_port);

Expand Down
Loading