Skip to content

Image preview #522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ test/
*.gguf
output*.png
models*
*.log
*.log
preview.png
54 changes: 53 additions & 1 deletion examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ const char* modes_str[] = {
};
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert"

// Command-line spellings of the preview methods. Entries are looked up with
// sd_preview_t values (e.g. previews_str[SD_PREVIEW_TAE]), so the order of
// this table has to follow that enum.
const char* previews_str[] = {"none", "proj", "tae", "vae"};

enum SDMode {
IMG_GEN,
VID_GEN,
Expand Down Expand Up @@ -109,6 +116,11 @@ struct SDParams {
bool chroma_use_dit_mask = true;
bool chroma_use_t5_mask = false;
int chroma_t5_mask_pad = 1;

sd_preview_t preview_method = SD_PREVIEW_NONE;
int preview_interval = 1;
std::string preview_path = "preview.png";
bool taesd_preview = false;
};

void print_params(SDParams params) {
Expand Down Expand Up @@ -166,6 +178,8 @@ void print_params(SDParams params) {
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
printf(" preview_mode: %s\n", previews_str[params.preview_method]);
printf(" preview_interval: %d\n", params.preview_interval);
}

void print_usage(int argc, const char* argv[]) {
Expand All @@ -182,7 +196,8 @@ void print_usage(int argc, const char* argv[]) {
printf(" --clip_g path to the clip-g text encoder\n");
printf(" --t5xxl path to the t5xxl text encoder\n");
printf(" --vae [VAE] path to vae\n");
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --taesd [TAESD] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview %s)\n", previews_str[SD_PREVIEW_TAE]);
printf(" --control-net [CONTROL_PATH] path to control net model\n");
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n");
Expand Down Expand Up @@ -234,6 +249,10 @@ void print_usage(int argc, const char* argv[]) {
printf(" This might crash if it is not supported by the backend.\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --preview {%s,%s,%s,%s} preview method. (default is %s(disabled))\n", previews_str[0], previews_str[1], previews_str[2], previews_str[3], previews_str[SD_PREVIEW_NONE]);
printf(" %s is the fastest\n", previews_str[SD_PREVIEW_PROJ]);
printf(" --preview-interval [N] How often to save the image preview");
printf(" --preview-path [PATH} path to write preview image to (default: ./preview.png)\n");
printf(" --color colors the logging tags according to level\n");
printf(" --chroma-disable-dit-mask disable dit mask for chroma\n");
printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n");
Expand Down Expand Up @@ -386,6 +405,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"-o", "--output", "", &params.output_path},
{"-p", "--prompt", "", &params.prompt},
{"-n", "--negative-prompt", "", &params.negative_prompt},
{"", "--preview-path", "", &params.preview_path},

{"", "--upscale-model", "", &params.esrgan_path},
};
Expand All @@ -399,6 +419,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--clip-skip", "", &params.clip_skip},
{"-b", "--batch-count", "", &params.batch_count},
{"", "--chroma-t5-mask-pad", "", &params.chroma_t5_mask_pad},
{"", "--preview-interval", "", &params.preview_interval},
};

options.float_options = {
Expand Down Expand Up @@ -427,6 +448,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--color", "", true, &params.color},
{"", "--chroma-disable-dit-mask", "", false, &params.chroma_use_dit_mask},
{"", "--chroma-enable-t5-mask", "", true, &params.chroma_use_t5_mask},
{"", "--taesd-preview-only", "", false, &params.taesd_preview},
};

auto on_mode_arg = [&](int argc, const char** argv, int index) {
Expand Down Expand Up @@ -557,6 +579,26 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1;
};

auto on_preview_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* preview = argv[index];
int preview_method = -1;
for (int m = 0; m < N_PREVIEWS; m++) {
if (!strcmp(preview, previews_str[m])) {
preview_method = m;
}
}
if (preview_method == -1) {
fprintf(stderr, "error: preview method %s\n",
preview);
return -1;
}
params.preview_method = (sd_preview_t)preview_method;
return 1;
};

options.manual_options = {
{"-M", "--mode", "", on_mode_arg},
{"", "--type", "", on_type_arg},
Expand All @@ -567,6 +609,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--skip-layers", "", on_skip_layers_arg},
{"-r", "--ref-image", "", on_ref_image_arg},
{"-h", "--help", "", on_help_arg},
{"", "--preview", "", on_preview_arg},
};

if (!parse_options(argc, argv, options)) {
Expand Down Expand Up @@ -728,10 +771,17 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
fflush(out_stream);
}

// Destination file for preview images. main() points this at
// params.preview_path before registering the preview callback.
const char* preview_path = nullptr;

// Preview callback: writes the supplied image to preview_path as a PNG.
// The same path is used on every call, so each preview replaces the previous
// one. A stride of 0 is passed to stbi_write_png.
// Nothing is written when no path has been set or the image has no pixel
// data; a failed write is reported on stderr rather than silently ignored.
void step_callback(int step, sd_image_t image) {
    (void)step;  // the step number is not needed to write the preview
    if (preview_path == nullptr || image.data == nullptr) {
        return;
    }
    if (!stbi_write_png(preview_path, image.width, image.height, image.channel, image.data, 0)) {
        fprintf(stderr, "warning: failed to write preview image to %s\n", preview_path);
    }
}

int main(int argc, const char* argv[]) {
SDParams params;

parse_args(argc, argv, params);
preview_path = params.preview_path.c_str();

sd_guidance_params_t guidance_params = {params.cfg_scale,
params.img_cfg_scale,
Expand All @@ -746,6 +796,7 @@ int main(int argc, const char* argv[]) {
}};

sd_set_log_callback(sd_log_cb, (void*)&params);
sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval);

if (params.verbose) {
print_params(params);
Expand Down Expand Up @@ -887,6 +938,7 @@ int main(int argc, const char* argv[]) {
params.control_net_cpu,
params.vae_on_cpu,
params.diffusion_flash_attn,
params.taesd_preview,
params.chroma_use_dit_mask,
params.chroma_use_t5_mask,
params.chroma_t5_mask_pad,
Expand Down
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated from 9e4bee to b6d2eb
2 changes: 1 addition & 1 deletion ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1);
on_processing(input_tile, NULL, true);
int num_tiles = ceil((float)input_width / non_tile_overlap) * ceil((float)input_height / non_tile_overlap);
LOG_INFO("processing %i tiles", num_tiles);
LOG_DEBUG("processing %i tiles", num_tiles);
pretty_progress(1, num_tiles, 0.0f);
int tile_count = 1;
bool last_y = false, last_x = false;
Expand Down
83 changes: 83 additions & 0 deletions latent-preview.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@

// Latent -> RGB projection weights for Flux: 16 rows of 3 weights.
// When passed as latent_rgb_proj to preview_latent_image, row d scales the
// d-th latent value and columns 0/1/2 feed the r/g/b accumulators.
// Source: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
const float flux_latent_rgb_proj[16][3] = {
    {-0.0346f, 0.0244f, 0.0681f},    // row 0
    {0.0034f, 0.0210f, 0.0687f},     // row 1
    {0.0275f, -0.0668f, -0.0433f},   // row 2
    {-0.0174f, 0.0160f, 0.0617f},    // row 3
    {0.0859f, 0.0721f, 0.0329f},     // row 4
    {0.0004f, 0.0383f, 0.0115f},     // row 5
    {0.0405f, 0.0861f, 0.0915f},     // row 6
    {-0.0236f, -0.0185f, -0.0259f},  // row 7
    {-0.0245f, 0.0250f, 0.1180f},    // row 8
    {0.1008f, 0.0755f, -0.0421f},    // row 9
    {-0.0515f, 0.0201f, 0.0011f},    // row 10
    {0.0428f, -0.0012f, -0.0036f},   // row 11
    {0.0817f, 0.0765f, 0.0749f},     // row 12
    {-0.1264f, -0.0522f, -0.1103f},  // row 13
    {-0.0280f, -0.0881f, -0.0499f},  // row 14
    {-0.1262f, -0.0982f, -0.0778f},  // row 15
};

// Latent -> RGB projection weights for SD3: 16 rows of 3 weights.
// When passed as latent_rgb_proj to preview_latent_image, row d scales the
// d-th latent value and columns 0/1/2 feed the r/g/b accumulators.
// Source: https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L228-L246
const float sd3_latent_rgb_proj[16][3] = {
    {-0.0645f, 0.0177f, 0.1052f},    // row 0
    {0.0028f, 0.0312f, 0.0650f},     // row 1
    {0.1848f, 0.0762f, 0.0360f},     // row 2
    {0.0944f, 0.0360f, 0.0889f},     // row 3
    {0.0897f, 0.0506f, -0.0364f},    // row 4
    {-0.0020f, 0.1203f, 0.0284f},    // row 5
    {0.0855f, 0.0118f, 0.0283f},     // row 6
    {-0.0539f, 0.0658f, 0.1047f},    // row 7
    {-0.0057f, 0.0116f, 0.0700f},    // row 8
    {-0.0412f, 0.0281f, -0.0039f},   // row 9
    {0.1106f, 0.1171f, 0.1220f},     // row 10
    {-0.0248f, 0.0682f, -0.0481f},   // row 11
    {0.0815f, 0.0846f, 0.1207f},     // row 12
    {-0.0120f, -0.0055f, -0.0867f},  // row 13
    {-0.0749f, -0.0634f, -0.0456f},  // row 14
    {-0.1418f, -0.1457f, -0.1259f},  // row 15
};

// Latent -> RGB projection weights for SDXL: 4 rows of 3 weights.
// When passed as latent_rgb_proj to preview_latent_image, row d scales the
// d-th latent value and columns 0/1/2 feed the r/g/b accumulators.
// Source: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
const float sdxl_latent_rgb_proj[4][3] = {
    {0.3651f, 0.4232f, 0.4341f},     // row 0
    {-0.2533f, -0.0042f, 0.1068f},   // row 1
    {0.1076f, 0.1111f, -0.0362f},    // row 2
    {-0.3165f, -0.2492f, -0.2188f},  // row 3
};

// Latent -> RGB projection weights for SD: 4 rows of 3 weights.
// When passed as latent_rgb_proj to preview_latent_image, row d scales the
// d-th latent value and columns 0/1/2 feed the r/g/b accumulators.
// Source: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
// NOTE(review): this link and line range are identical to the ones given for
// sdxl_latent_rgb_proj although the values differ - confirm the reference.
const float sd_latent_rgb_proj[4][3] = {
    {0.3512f, 0.2297f, 0.3227f},     // row 0
    {0.3250f, 0.4974f, 0.2350f},     // row 1
    {-0.2829f, 0.1762f, 0.2721f},    // row 2
    {-0.2120f, -0.2616f, -0.7177f},  // row 3
};

// Builds an 8-bit RGB image from a latent tensor by taking, for every (i, j)
// position, a weighted sum of `dim` latent values with the rows of
// latent_rgb_proj.
//
//   buffer           output; width * height * 3 bytes are written, pixels in
//                    j-major / i-minor order, 3 bytes (r, g, b) per pixel
//   latents          source tensor; the value for (i, j, d) is read as a float
//                    at byte offset i * nb[0] + j * nb[1] + d * nb[2] of data
//   latent_rgb_proj  `dim` rows of {r, g, b} weights
//   width, height    number of positions iterated along nb[0] / nb[1]
//   dim              number of values summed along nb[2]
//
// Each weighted sum is remapped with x * 0.5 + 0.5, clamped to [0, 1] and
// scaled to 0..255 (truncating).
//
// Declared `static inline` because the definition lives in a header: without
// it, including this header from more than one translation unit produces
// duplicate-symbol link errors.
static inline void preview_latent_image(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], int width, int height, int dim) {
    // Remap, clamp and quantise one colour component. The comparison order
    // makes a NaN input come out as 0.
    auto to_byte = [](float v) -> uint8_t {
        v = v * .5f + .5f;
        v = v >= 0 ? (v <= 1 ? v : 1) : 0;
        return (uint8_t)(v * 255);
    };

    const char* base = (const char*)latents->data;
    const size_t nb0 = latents->nb[0];
    const size_t nb1 = latents->nb[1];
    const size_t nb2 = latents->nb[2];

    uint8_t* out = buffer;
    for (int j = 0; j < height; j++) {
        for (int i = 0; i < width; i++) {
            // Byte offset of the first latent value for position (i, j).
            const char* pixel = base + (i * nb0 + j * nb1);

            float r = 0, g = 0, b = 0;
            for (int d = 0; d < dim; d++) {
                const float value = *(const float*)(pixel + d * nb2);
                r += value * latent_rgb_proj[d][0];
                g += value * latent_rgb_proj[d][1];
                b += value * latent_rgb_proj[d][2];
            }

            *out++ = to_byte(r);
            *out++ = to_byte(g);
            *out++ = to_byte(b);
        }
    }
}
Loading
Loading