diff --git a/common/arg.cpp b/common/arg.cpp index 8f54ee38c1..bd1a745e6a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3794,7 +3794,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); add_opt(common_arg( {"--diffusion-algorithm"}, "N", - string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm), + string_format( + "diffusion algorithm: 0=DIFFUSION_ALGORITHM_ORIGIN, 1=DIFFUSION_ALGORITHM_ENTROPY_BASED, " + "2=DIFFUSION_ALGORITHM_MARGIN_BASED, 3=DIFFUSION_ALGORITHM_RANDOM, " + "4=DIFFUSION_ALGORITHM_CONFIDENCE_BASED (default: %d)", params.diffusion.algorithm), [](common_params & params, int value) { params.diffusion.algorithm = value; } ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); add_opt(common_arg( diff --git a/examples/diffusion/CMakeLists.txt b/examples/diffusion/CMakeLists.txt index 70228d4079..42a84b2dfe 100644 --- a/examples/diffusion/CMakeLists.txt +++ b/examples/diffusion/CMakeLists.txt @@ -1,5 +1,10 @@ +set(TARGET llama-diffusion) +add_library(${TARGET} STATIC diffusion.cpp diffusion.h) +target_link_libraries(${TARGET} PUBLIC llama llama-common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PUBLIC cxx_std_17) + set(TARGET llama-diffusion-cli) add_executable(${TARGET} diffusion-cli.cpp) install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama llama-common ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE llama-diffusion llama llama-common ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md index b394200214..6d2fffd64e 100644 --- a/examples/diffusion/README.md +++ b/examples/diffusion/README.md @@ -12,11 +12,11 @@ The diffusion CLI supports various parameters to control the generation process: ### Core Diffusion Parameters - `--diffusion-steps`: Number of diffusion steps (default: 256) - `--diffusion-algorithm`: Algorithm for token selection - - `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006. - - `1`: ENTROPY_BASED - Entropy-based selection - - `2`: MARGIN_BASED - Margin-based selection - - `3`: RANDOM - Random selection - - `4`: CONFIDENCE_BASED - Confidence-based selection (default) + - `0`: DIFFUSION_ALGORITHM_ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006. + - `1`: DIFFUSION_ALGORITHM_ENTROPY_BASED - Entropy-based selection + - `2`: DIFFUSION_ALGORITHM_MARGIN_BASED - Margin-based selection + - `3`: DIFFUSION_ALGORITHM_RANDOM - Random selection + - `4`: DIFFUSION_ALGORITHM_CONFIDENCE_BASED - Confidence-based selection (default) - More documentation here https://github.com/DreamLM/Dream - `--diffusion-visual`: Enable live visualization during generation diff --git a/examples/diffusion/diffusion-cli.cpp b/examples/diffusion/diffusion-cli.cpp index 403b9b4744..86ebbf88c9 100644 --- a/examples/diffusion/diffusion-cli.cpp +++ b/examples/diffusion/diffusion-cli.cpp @@ -1,127 +1,23 @@ #include "arg.h" #include "chat.h" #include "common.h" +#include "diffusion.h" #include "llama.h" #include "log.h" #include -#include #include -#include #include -#include -#include #include #include -enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 }; - -// Unified transfer scheduling methods -enum transfer_schedule { - TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining - BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens -}; - -typedef bool (*diffusion_step_callback_t)(int32_t step, - int32_t total_steps, - const llama_token * tokens, - int32_t n_tokens, - void * user_data); - -struct diffusion_params { - int32_t steps = 0; - float temperature = 0; - llama_token mask_token_id = LLAMA_TOKEN_NULL; - diffusion_step_callback_t step_callback = nullptr; - void * step_callback_user_data = nullptr; - int32_t seed = 0; - bool visual_mode = false; - bool shift_logits = false; // Shift logits by -1 after decode - - float top_p = 0.; - int32_t top_k = 0.; - - diffusion_algorithm algorithm = CONFIDENCE_BASED; - transfer_schedule schedule = TIMESTEP_BASED; - - float cfg_scale = 0.; // Config scale for classifier-free guidance - float eps = 0.; // Timestep scheduling - int32_t block_length = 0; // Block size (for block scheduling) - float alg_temp = 0; // algorithm temperature (0.0 = deterministic) - bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0 - - int32_t max_length = 0; // Maximum sequence length -}; - struct callback_data { diffusion_params * diff_params; const llama_vocab * vocab; int32_t n_input; }; -static float calculate_confidence(const llama_token_data_array & cur_p, - diffusion_algorithm algorithm, - std::mt19937 & rng) { - switch (algorithm) { - case CONFIDENCE_BASED: - return cur_p.data[cur_p.selected].p; // Selected token probability - - case ENTROPY_BASED: - { - float entropy = 0.0f; - const float epsilon = 1e-10f; - for (size_t i = 0; i < cur_p.size; i++) { - float prob = cur_p.data[i].p; - entropy += prob * logf(prob + epsilon); - } - return -entropy; // Higher entropy = lower confidence - } - - case MARGIN_BASED: - return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p; - - case RANDOM: - { - std::uniform_real_distribution uniform(0.0f, 1.0f); - return uniform(rng); // Random confidence - } - - case ORIGIN: - return cur_p.data[cur_p.selected].p; - - default: - return 0.0f; - } -} - -// Unified transfer count calculation function -static int32_t calculate_transfer_count(int32_t step, - int32_t total_steps, - int32_t remaining_masked, - transfer_schedule schedule, - float eps, - const std::vector & num_transfer_tokens = {}) { - switch (schedule) { - case TIMESTEP_BASED: - { - float t = 1.0f - (float) step / total_steps * (1.0f - eps); - float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps); - float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f; - return (int32_t) (remaining_masked * p_transfer); - } - - case BLOCK_BASED: - if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) { - return num_transfer_tokens[step]; - } - return remaining_masked / (total_steps - step); // Fallback - - default: - return remaining_masked / (total_steps - step); - } -} - static bool diffusion_step_callback(int32_t step, int32_t total_steps, const llama_token * tokens, @@ -176,341 +72,6 @@ static bool diffusion_step_callback(int32_t step, return true; } -static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) { - if (temperature == 0.0f) { - return; - } - - std::uniform_real_distribution uniform(0.0, 1.0); - for (int32_t i = 0; i < n_vocab; i++) { - double noise = uniform(rng); - // Prevent log(0) - noise = std::max(noise, 1e-20); - double gumbel_noise = std::pow(-std::log(noise), temperature); - logits[i] = std::exp(logits[i]) / gumbel_noise; - } -} - -static std::vector get_num_transfer_tokens(int32_t mask_count, int32_t steps) { - std::vector num_transfer_tokens(steps); - - int32_t base = mask_count / steps; - int32_t remainder = mask_count % steps; - - for (int32_t i = 0; i < steps; i++) { - num_transfer_tokens[i] = base + (i < remainder ? 1 : 0); - } - - return num_transfer_tokens; -} - -static void diffusion_generate(llama_context * ctx, - const llama_token * input_tokens, - llama_token * output_tokens, - int32_t n_input, - const diffusion_params & params, - int32_t & n_generated) { - n_generated = 0; - if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) { - return; - } - - const llama_model * model = llama_get_model(ctx); - - // Initialize with input and pad with mask tokens - std::copy(input_tokens, input_tokens + n_input, output_tokens); - std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id); - - std::mt19937 rng(params.seed); - - llama_set_causal_attn(ctx, false); - - int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); - - std::vector candidates(n_vocab); - std::vector conf_candidates; - conf_candidates.reserve(params.max_length); - std::vector mask_positions; - mask_positions.reserve(params.max_length); - - // Setup sampler chain - struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()); - if (params.top_k > 0) { - llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k)); - } - if (params.top_p < 1.0f) { - llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1)); - } - if (params.temperature > 0.0f) { - llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature)); - } - llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed)); - - struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed); - - llama_batch batch = llama_batch_init(params.max_length, 0, 1); - batch.n_tokens = params.max_length; - - // Pre-allocate buffers for CFG if needed - int32_t logits_size = n_vocab * params.max_length; - std::vector cond_logits_buffer; - std::vector un_x_buffer; - if (params.cfg_scale > 0.0f) { - cond_logits_buffer.resize(logits_size); - un_x_buffer.resize(params.max_length); - } - - // For block-based processing - std::vector num_transfer_tokens; - int32_t num_blocks = 1; - int32_t steps_per_block = params.steps; - - if (params.schedule == BLOCK_BASED) { - GGML_ASSERT(params.max_length % params.block_length == 0); - num_blocks = params.max_length / params.block_length; - GGML_ASSERT(params.steps % num_blocks == 0); - steps_per_block = params.steps / num_blocks; - } - - std::vector confidence(params.max_length); - - int64_t total_sampling_time = 0; - int64_t total_time = 0; - int64_t time_start = ggml_time_us(); - - for (int block_num = 0; block_num < num_blocks; block_num++) { - int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0; - int32_t block_end = (params.schedule == BLOCK_BASED) ? - std::min(n_input + (block_num + 1) * params.block_length, params.max_length) : - params.max_length; - - // Count masked tokens in current block for block-based processing - if (params.schedule == BLOCK_BASED) { - int32_t block_mask_count = 0; - for (int i = block_start; i < block_end; i++) { - if (output_tokens[i] == params.mask_token_id) { - block_mask_count++; - } - } - num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block); - } - - for (int32_t step = 0; step < steps_per_block; step++) { - int32_t global_step = block_num * steps_per_block + step; - - if (params.step_callback) { - if (!params.step_callback( - global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) { - break; - } - } - - // Setup batch - for (int32_t i = 0; i < params.max_length; i++) { - batch.token[i] = output_tokens[i]; - batch.pos[i] = i; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = 0; - batch.logits[i] = 1; - } - - float * logits = nullptr; - - if (params.cfg_scale > 0.0f) { - int ret = llama_decode(ctx, batch); - if (ret != 0) { - LOG_ERR("Failed to generate conditional"); - break; - } - float * cond_logits_ptr = llama_get_logits(ctx); - std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float)); - - // Unconditional generation (mask input) - std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin()); - for (int32_t i = 0; i < n_input; i++) { - un_x_buffer[i] = params.mask_token_id; - } - - for (int32_t i = 0; i < params.max_length; i++) { - batch.token[i] = un_x_buffer[i]; - } - ret = llama_decode(ctx, batch); - if (ret != 0) { - LOG_ERR("Failed to generate unconditional"); - break; - } - float * uncond_logits = llama_get_logits(ctx); - - // Apply CFG - for (int32_t i = 0; i < logits_size; i++) { - cond_logits_buffer[i] = - uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]); - } - logits = cond_logits_buffer.data(); - } else { - int ret = llama_decode(ctx, batch); - if (ret != 0) { - LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret); - break; - } - logits = llama_get_logits(ctx); - } - - if (!logits) { - LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step); - break; - } - - auto get_logits_for_pos = [&](int32_t pos) -> const float * { - if (params.shift_logits) { - return pos == 0 ? logits : logits + (pos - 1) * n_vocab; - } - return logits + (pos) *n_vocab; - }; - - int64_t time_start_sampling = ggml_time_us(); - - mask_positions.clear(); - for (int32_t i = 0; i < params.max_length; i++) { - if (output_tokens[i] == params.mask_token_id) { - // For block-based, only consider current block - if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) { - mask_positions.push_back(i); - } - } - } - - if (mask_positions.empty()) { - break; - } - - if (params.add_gumbel_noise && params.temperature > 0.0f) { - add_gumbel_noise(logits, n_vocab, params.temperature, rng); - } - - if (params.algorithm == ORIGIN) { - int32_t transfer_count = calculate_transfer_count( - step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); - float p_transfer = (float) transfer_count / mask_positions.size(); - - for (int32_t pos : mask_positions) { - if (std::uniform_real_distribution(0.0f, 1.0f)(rng) < p_transfer) { - const float * pos_logits = get_logits_for_pos(pos); - for (int32_t token_id = 0; token_id < n_vocab; token_id++) { - candidates[token_id].id = token_id; - candidates[token_id].logit = pos_logits[token_id]; - candidates[token_id].p = 0.0f; - } - - llama_token_data_array cur_p = { - candidates.data(), - (size_t) n_vocab, - -1, - false, - }; - - llama_sampler_apply(sampler, &cur_p); - output_tokens[pos] = cur_p.data[cur_p.selected].id; - } - } - } else { - std::vector> confidences; - std::vector sampled_tokens(mask_positions.size()); - - for (size_t i = 0; i < mask_positions.size(); i++) { - int32_t pos = mask_positions[i]; - const float * pos_logits = get_logits_for_pos(pos); - - for (int32_t token_id = 0; token_id < n_vocab; token_id++) { - candidates[token_id].logit = pos_logits[token_id]; - candidates[token_id].p = 0.0f; - candidates[token_id].id = token_id; - } - - llama_token_data_array cur_p = { - candidates.data(), - candidates.size(), - -1, - false, - }; - - llama_sampler_apply(sampler, &cur_p); - llama_token sampled_token = cur_p.data[cur_p.selected].id; - - float conf = calculate_confidence(cur_p, params.algorithm, rng); - - sampled_tokens[i] = sampled_token; - confidences.emplace_back(conf, i); - } - - int32_t transfer_count = calculate_transfer_count( - step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); - - if (transfer_count > 0) { - if (params.alg_temp == 0.0f) { - std::partial_sort(confidences.begin(), - confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()), - confidences.end(), - [](const std::pair & a, const std::pair & b) { - if (a.first != b.first) { - return a.first > b.first; - } - return a.second < b.second; - }); - - for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { - int32_t mask_idx = confidences[i].second; - int32_t pos = mask_positions[mask_idx]; - output_tokens[pos] = sampled_tokens[mask_idx]; - } - } else { - conf_candidates.clear(); - for (size_t i = 0; i < confidences.size(); i++) { - float conf_logit = confidences[i].first / params.alg_temp; - conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f }); - } - - llama_token_data_array conf_array = { - conf_candidates.data(), - conf_candidates.size(), - -1, - false, - }; - - for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { - llama_sampler_apply(dist_sampler, &conf_array); - int32_t selected_idx = conf_array.selected; - int32_t mask_idx = selected_idx; - int32_t pos = mask_positions[mask_idx]; - output_tokens[pos] = sampled_tokens[mask_idx]; - - conf_candidates[selected_idx].p = 0.0f; - conf_array.selected = -1; - } - } - } - } - - int64_t time_end_sampling = ggml_time_us(); - total_sampling_time += time_end_sampling - time_start_sampling; - } - } - - int64_t time_end = ggml_time_us(); - total_time += time_end - time_start; - - LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n", - total_time / 1000.0, - total_time / 1000.0 / params.steps, - total_sampling_time / 1000.0 / params.steps); - - llama_batch_free(batch); - llama_sampler_free(sampler); - llama_sampler_free(dist_sampler); - - n_generated = params.max_length; -} - static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) { if (!use_chat_template) { return prompt; @@ -631,10 +192,10 @@ int main(int argc, char ** argv) { GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0)); if (params.diffusion.eps) { - diff_params.schedule = TIMESTEP_BASED; + diff_params.schedule = DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED; diff_params.eps = params.diffusion.eps; } else if (params.diffusion.block_length) { - diff_params.schedule = BLOCK_BASED; + diff_params.schedule = DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED; diff_params.block_length = params.diffusion.block_length; } @@ -653,8 +214,17 @@ int main(int argc, char ** argv) { callback_data cb_data = { &diff_params, vocab, n_input }; diff_params.step_callback_user_data = &cb_data; - const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" }; - const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" }; + const char * alg_names[] = { + "DIFFUSION_ALGORITHM_ORIGIN", + "DIFFUSION_ALGORITHM_ENTROPY_BASED", + "DIFFUSION_ALGORITHM_MARGIN_BASED", + "DIFFUSION_ALGORITHM_RANDOM", + "DIFFUSION_ALGORITHM_CONFIDENCE_BASED", + }; + const char * sched_names[] = { + "DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED", + "DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED", + }; const char * alg_name = (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN"; const char * sched_name = @@ -666,11 +236,11 @@ int main(int argc, char ** argv) { LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name); LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name); LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature); - if (diff_params.schedule == TIMESTEP_BASED) { + if (diff_params.schedule == DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED) { LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps); LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp); } - if (diff_params.schedule == BLOCK_BASED) { + if (diff_params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) { LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length); LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale); } diff --git a/examples/diffusion/diffusion.cpp b/examples/diffusion/diffusion.cpp new file mode 100644 index 0000000000..97d6b69449 --- /dev/null +++ b/examples/diffusion/diffusion.cpp @@ -0,0 +1,408 @@ +#include "diffusion.h" + +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include + +static float calculate_confidence(const llama_token_data_array & cur_p, + diffusion_algorithm algorithm, + std::mt19937 & rng) { + switch (algorithm) { + case DIFFUSION_ALGORITHM_CONFIDENCE_BASED: + return cur_p.data[cur_p.selected].p; // Selected token probability + + case DIFFUSION_ALGORITHM_ENTROPY_BASED: + { + float entropy = 0.0f; + const float epsilon = 1e-10f; + for (size_t i = 0; i < cur_p.size; i++) { + float prob = cur_p.data[i].p; + entropy += prob * logf(prob + epsilon); + } + return -entropy; // Higher entropy = lower confidence + } + + case DIFFUSION_ALGORITHM_MARGIN_BASED: + return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p; + + case DIFFUSION_ALGORITHM_RANDOM: + { + std::uniform_real_distribution uniform(0.0f, 1.0f); + return uniform(rng); // Random confidence + } + + case DIFFUSION_ALGORITHM_ORIGIN: + return cur_p.data[cur_p.selected].p; + + default: + return 0.0f; + } +} + +// Unified transfer count calculation function +static int32_t calculate_transfer_count(int32_t step, + int32_t total_steps, + int32_t remaining_masked, + diffusion_transfer_schedule schedule, + float eps, + const std::vector & num_transfer_tokens = {}) { + switch (schedule) { + case DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED: + { + float t = 1.0f - (float) step / total_steps * (1.0f - eps); + float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps); + float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f; + return (int32_t) (remaining_masked * p_transfer); + } + + case DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED: + if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) { + return num_transfer_tokens[step]; + } + return remaining_masked / (total_steps - step); // Fallback + + default: + return remaining_masked / (total_steps - step); + } +} + +static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) { + if (temperature == 0.0f) { + return; + } + + std::uniform_real_distribution uniform(0.0, 1.0); + for (int32_t i = 0; i < n_vocab; i++) { + double noise = uniform(rng); + // Prevent log(0) + noise = std::max(noise, 1e-20); + double gumbel_noise = std::pow(-std::log(noise), temperature); + logits[i] = std::exp(logits[i]) / gumbel_noise; + } +} + +static std::vector get_num_transfer_tokens(int32_t mask_count, int32_t steps) { + std::vector num_transfer_tokens(steps); + + int32_t base = mask_count / steps; + int32_t remainder = mask_count % steps; + + for (int32_t i = 0; i < steps; i++) { + num_transfer_tokens[i] = base + (i < remainder ? 1 : 0); + } + + return num_transfer_tokens; +} + +void diffusion_generate(llama_context * ctx, + const llama_token * input_tokens, + llama_token * output_tokens, + int32_t n_input, + const diffusion_params & params, + int32_t & n_generated) { + n_generated = 0; + if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) { + return; + } + + const llama_model * model = llama_get_model(ctx); + + // Initialize with input and pad with mask tokens + std::copy(input_tokens, input_tokens + n_input, output_tokens); + std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id); + + std::mt19937 rng(params.seed); + + llama_set_causal_attn(ctx, false); + + int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); + + std::vector candidates(n_vocab); + std::vector conf_candidates; + conf_candidates.reserve(params.max_length); + std::vector mask_positions; + mask_positions.reserve(params.max_length); + + // Setup sampler chain + struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()); + if (params.top_k > 0) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k)); + } + if (params.top_p < 1.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1)); + } + if (params.temperature > 0.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature)); + } + llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed)); + + struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed); + + llama_batch batch = llama_batch_init(params.max_length, 0, 1); + batch.n_tokens = params.max_length; + + // Pre-allocate buffers for CFG if needed + int32_t logits_size = n_vocab * params.max_length; + std::vector cond_logits_buffer; + std::vector un_x_buffer; + if (params.cfg_scale > 0.0f) { + cond_logits_buffer.resize(logits_size); + un_x_buffer.resize(params.max_length); + } + + // For block-based processing + std::vector num_transfer_tokens; + int32_t num_blocks = 1; + int32_t steps_per_block = params.steps; + + if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) { + GGML_ASSERT(params.max_length % params.block_length == 0); + num_blocks = params.max_length / params.block_length; + GGML_ASSERT(params.steps % num_blocks == 0); + steps_per_block = params.steps / num_blocks; + } + + std::vector confidence(params.max_length); + + int64_t total_sampling_time = 0; + int64_t total_time = 0; + int64_t time_start = ggml_time_us(); + + for (int block_num = 0; block_num < num_blocks; block_num++) { + int32_t block_start = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) ? n_input + block_num * params.block_length : 0; + int32_t block_end = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) ? + std::min(n_input + (block_num + 1) * params.block_length, params.max_length) : + params.max_length; + + // Count masked tokens in current block for block-based processing + if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) { + int32_t block_mask_count = 0; + for (int i = block_start; i < block_end; i++) { + if (output_tokens[i] == params.mask_token_id) { + block_mask_count++; + } + } + num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block); + } + + for (int32_t step = 0; step < steps_per_block; step++) { + int32_t global_step = block_num * steps_per_block + step; + + if (params.step_callback) { + if (!params.step_callback( + global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) { + break; + } + } + + // Setup batch + for (int32_t i = 0; i < params.max_length; i++) { + batch.token[i] = output_tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 1; + } + + float * logits = nullptr; + + if (params.cfg_scale > 0.0f) { + int ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("Failed to generate conditional"); + break; + } + float * cond_logits_ptr = llama_get_logits(ctx); + std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float)); + + // Unconditional generation (mask input) + std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin()); + for (int32_t i = 0; i < n_input; i++) { + un_x_buffer[i] = params.mask_token_id; + } + + for (int32_t i = 0; i < params.max_length; i++) { + batch.token[i] = un_x_buffer[i]; + } + ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("Failed to generate unconditional"); + break; + } + float * uncond_logits = llama_get_logits(ctx); + + // Apply CFG + for (int32_t i = 0; i < logits_size; i++) { + cond_logits_buffer[i] = + uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]); + } + logits = cond_logits_buffer.data(); + } else { + int ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret); + break; + } + logits = llama_get_logits(ctx); + } + + if (!logits) { + LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step); + break; + } + + auto get_logits_for_pos = [&](int32_t pos) -> const float * { + if (params.shift_logits) { + return pos == 0 ? logits : logits + (pos - 1) * n_vocab; + } + return logits + pos * n_vocab; + }; + + int64_t time_start_sampling = ggml_time_us(); + + mask_positions.clear(); + for (int32_t i = 0; i < params.max_length; i++) { + if (output_tokens[i] == params.mask_token_id) { + // For block-based, only consider current block + if (params.schedule != DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED || (i >= block_start && i < block_end)) { + mask_positions.push_back(i); + } + } + } + + if (mask_positions.empty()) { + break; + } + + if (params.add_gumbel_noise && params.temperature > 0.0f) { + add_gumbel_noise(logits, n_vocab, params.temperature, rng); + } + + if (params.algorithm == DIFFUSION_ALGORITHM_ORIGIN) { + int32_t transfer_count = calculate_transfer_count( + step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); + float p_transfer = (float) transfer_count / mask_positions.size(); + + for (int32_t pos : mask_positions) { + if (std::uniform_real_distribution(0.0f, 1.0f)(rng) < p_transfer) { + const float * pos_logits = get_logits_for_pos(pos); + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].id = token_id; + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + } + + llama_token_data_array cur_p = { + candidates.data(), + (size_t) n_vocab, + -1, + false, + }; + + llama_sampler_apply(sampler, &cur_p); + output_tokens[pos] = cur_p.data[cur_p.selected].id; + } + } + } else { + std::vector> confidences; + std::vector sampled_tokens(mask_positions.size()); + + for (size_t i = 0; i < mask_positions.size(); i++) { + int32_t pos = mask_positions[i]; + const float * pos_logits = get_logits_for_pos(pos); + + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + candidates[token_id].id = token_id; + } + + llama_token_data_array cur_p = { + candidates.data(), + candidates.size(), + -1, + false, + }; + + llama_sampler_apply(sampler, &cur_p); + llama_token sampled_token = cur_p.data[cur_p.selected].id; + + float conf = calculate_confidence(cur_p, params.algorithm, rng); + + sampled_tokens[i] = sampled_token; + confidences.emplace_back(conf, i); + } + + int32_t transfer_count = calculate_transfer_count( + step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); + + if (transfer_count > 0) { + if (params.alg_temp == 0.0f) { + std::partial_sort(confidences.begin(), + confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()), + confidences.end(), + [](const std::pair & a, const std::pair & b) { + if (a.first != b.first) { + return a.first > b.first; + } + return a.second < b.second; + }); + + for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { + int32_t mask_idx = confidences[i].second; + int32_t pos = mask_positions[mask_idx]; + output_tokens[pos] = sampled_tokens[mask_idx]; + } + } else { + conf_candidates.clear(); + for (size_t i = 0; i < confidences.size(); i++) { + float conf_logit = confidences[i].first / params.alg_temp; + conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f }); + } + + llama_token_data_array conf_array = { + conf_candidates.data(), + conf_candidates.size(), + -1, + false, + }; + + for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { + llama_sampler_apply(dist_sampler, &conf_array); + int32_t selected_idx = conf_array.selected; + int32_t mask_idx = selected_idx; + int32_t pos = mask_positions[mask_idx]; + output_tokens[pos] = sampled_tokens[mask_idx]; + + conf_candidates[selected_idx].p = 0.0f; + conf_array.selected = -1; + } + } + } + } + + int64_t time_end_sampling = ggml_time_us(); + total_sampling_time += time_end_sampling - time_start_sampling; + } + } + + int64_t time_end = ggml_time_us(); + total_time += time_end - time_start; + + LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n", + total_time / 1000.0, + total_time / 1000.0 / params.steps, + total_sampling_time / 1000.0 / params.steps); + + llama_batch_free(batch); + llama_sampler_free(sampler); + llama_sampler_free(dist_sampler); + + n_generated = params.max_length; +} diff --git a/examples/diffusion/diffusion.h b/examples/diffusion/diffusion.h new file mode 100644 index 0000000000..7831445224 --- /dev/null +++ b/examples/diffusion/diffusion.h @@ -0,0 +1,57 @@ +#pragma once + +#include "llama.h" + +#include + +enum diffusion_algorithm { + DIFFUSION_ALGORITHM_ORIGIN = 0, + DIFFUSION_ALGORITHM_ENTROPY_BASED = 1, + DIFFUSION_ALGORITHM_MARGIN_BASED = 2, + DIFFUSION_ALGORITHM_RANDOM = 3, + DIFFUSION_ALGORITHM_CONFIDENCE_BASED = 4, +}; + +// Unified transfer scheduling methods +enum diffusion_transfer_schedule { + DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining + DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens +}; + +typedef bool (*diffusion_step_callback_t)(int32_t step, + int32_t total_steps, + const llama_token * tokens, + int32_t n_tokens, + void * user_data); + +struct diffusion_params { + int32_t steps = 0; + float temperature = 0; + llama_token mask_token_id = LLAMA_TOKEN_NULL; + diffusion_step_callback_t step_callback = nullptr; + void * step_callback_user_data = nullptr; + int32_t seed = 0; + bool visual_mode = false; + bool shift_logits = false; // Shift logits by -1 after decode + + float top_p = 0.; + int32_t top_k = 0.; + + diffusion_algorithm algorithm = DIFFUSION_ALGORITHM_CONFIDENCE_BASED; + diffusion_transfer_schedule schedule = DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED; + + float cfg_scale = 0.; // Config scale for classifier-free guidance + float eps = 0.; // Timestep scheduling + int32_t block_length = 0; // Block size (for block scheduling) + float alg_temp = 0; // algorithm temperature (0.0 = deterministic) + bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0 + + int32_t max_length = 0; // Maximum sequence length +}; + +void diffusion_generate(llama_context * ctx, + const llama_token * input_tokens, + llama_token * output_tokens, + int32_t n_input, + const diffusion_params & params, + int32_t & n_generated);