#include "diffusion.h"
#include "log.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <random>
#include <utility>
#include <vector>

// Per-position confidence score used to decide which masked positions are
// committed ("unmasked") first. Higher confidence => committed earlier.
// cur_p must already have been run through the sampler chain so that
// cur_p.selected and the .p fields are valid.
static float calculate_confidence(const llama_token_data_array & cur_p,
                                  diffusion_algorithm            algorithm,
                                  std::mt19937 &                 rng) {
    switch (algorithm) {
        case DIFFUSION_ALGORITHM_CONFIDENCE_BASED:
            return cur_p.data[cur_p.selected].p;  // Selected token probability
        case DIFFUSION_ALGORITHM_ENTROPY_BASED:
            {
                float       entropy = 0.0f;
                const float epsilon = 1e-10f;  // guards logf(0)
                for (size_t i = 0; i < cur_p.size; i++) {
                    float prob = cur_p.data[i].p;
                    entropy += prob * logf(prob + epsilon);
                }
                return -entropy;  // Higher entropy = lower confidence
            }
        case DIFFUSION_ALGORITHM_MARGIN_BASED:
            // Gap between the two most likely tokens; assumes the sampler chain
            // left cur_p sorted by probability (true after top-k/top-p/dist).
            return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
        case DIFFUSION_ALGORITHM_RANDOM:
            {
                std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
                return uniform(rng);  // Random confidence
            }
        case DIFFUSION_ALGORITHM_ORIGIN:
            return cur_p.data[cur_p.selected].p;
        default:
            return 0.0f;
    }
}

// Unified transfer count calculation function.
// Returns how many masked positions should be committed at this step.
//  - TIMESTEP_BASED: fraction derived from the diffusion time discretization
//    (commits everything on the final step).
//  - BLOCK_BASED: uses the precomputed per-step quota when available,
//    otherwise falls back to an even split over the remaining steps.
static int32_t calculate_transfer_count(int32_t step, int32_t total_steps, int32_t remaining_masked,
                                        diffusion_transfer_schedule schedule, float eps,
                                        const std::vector<int32_t> & num_transfer_tokens = {}) {
    switch (schedule) {
        case DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED:
            {
                float t = 1.0f - (float) step / total_steps * (1.0f - eps);
                float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
                float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
                return (int32_t) (remaining_masked * p_transfer);
            }
        case DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED:
            if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
                return num_transfer_tokens[step];
            }
            return remaining_masked / (total_steps - step);  // Fallback
        default:
            return remaining_masked / (total_steps - step);
    }
}

// Apply Gumbel-style noise to n_logits entries in place so that sampling the
// maximum corresponds to temperature-controlled stochastic sampling.
// No-op when temperature == 0 (greedy).
// NOTE(review): this transforms logits into exp(logit) / (-log u)^T rather
// than adding classic Gumbel noise — presumably intentional (LLaDA-style);
// confirm against the reference implementation.
static void add_gumbel_noise(float * logits, int32_t n_logits, float temperature, std::mt19937 & rng) {
    if (temperature == 0.0f) {
        return;
    }
    std::uniform_real_distribution<double> uniform(0.0, 1.0);
    for (int32_t i = 0; i < n_logits; i++) {
        double noise = uniform(rng);
        // Prevent log(0)
        noise = std::max(noise, 1e-20);
        double gumbel_noise = std::pow(-std::log(noise), temperature);
        logits[i] = std::exp(logits[i]) / gumbel_noise;
    }
}

// Evenly split mask_count commits over `steps` steps; the first
// (mask_count % steps) steps get one extra token.
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
    std::vector<int32_t> num_transfer_tokens(steps);

    int32_t base      = mask_count / steps;
    int32_t remainder = mask_count % steps;

    for (int32_t i = 0; i < steps; i++) {
        num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
    }

    return num_transfer_tokens;
}

// Diffusion-style generation: output_tokens[0..max_length) starts as the
// prompt followed by mask tokens, and masked positions are iteratively
// replaced with sampled tokens over params.steps denoising steps.
// On success n_generated is set to params.max_length; it stays 0 on invalid
// arguments.
void diffusion_generate(llama_context * ctx, const llama_token * input_tokens, llama_token * output_tokens,
                        int32_t n_input, const diffusion_params & params, int32_t & n_generated) {
    n_generated = 0;
    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
        return;
    }

    const llama_model * model = llama_get_model(ctx);

    // Initialize with input and pad with mask tokens
    std::copy(input_tokens, input_tokens + n_input, output_tokens);
    std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);

    std::mt19937 rng(params.seed);

    // Diffusion attends to the whole sequence, not just the prefix
    llama_set_causal_attn(ctx, false);

    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    std::vector<llama_token_data> candidates(n_vocab);        // per-position sampling workspace
    std::vector<llama_token_data> conf_candidates;            // workspace for alg_temp > 0 position sampling
    conf_candidates.reserve(params.max_length);

    std::vector<int32_t> mask_positions;                      // positions still masked this step
    mask_positions.reserve(params.max_length);

    // Setup sampler chain (top-k -> top-p -> temperature -> dist)
    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (params.top_k > 0) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
    }
    if (params.top_p < 1.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
    }
    if (params.temperature > 0.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
    }
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));

    // Separate sampler used to pick WHICH position to commit when alg_temp > 0
    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);

    llama_batch batch = llama_batch_init(params.max_length, 0, 1);
    batch.n_tokens    = params.max_length;

    // Pre-allocate buffers for CFG if needed
    int32_t                  logits_size = n_vocab * params.max_length;
    std::vector<float>       cond_logits_buffer;
    std::vector<llama_token> un_x_buffer;
    if (params.cfg_scale > 0.0f) {
        cond_logits_buffer.resize(logits_size);
        un_x_buffer.resize(params.max_length);
    }

    // For block-based processing
    std::vector<int32_t> num_transfer_tokens;
    int32_t              num_blocks      = 1;
    int32_t              steps_per_block = params.steps;
    if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
        GGML_ASSERT(params.max_length % params.block_length == 0);
        num_blocks = params.max_length / params.block_length;
        GGML_ASSERT(params.steps % num_blocks == 0);
        steps_per_block = params.steps / num_blocks;
    }

    int64_t total_sampling_time = 0;
    int64_t total_time          = 0;
    int64_t time_start          = ggml_time_us();

    for (int block_num = 0; block_num < num_blocks; block_num++) {
        // Non-block schedules treat the whole sequence as one block
        int32_t block_start = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED)
                                  ? n_input + block_num * params.block_length
                                  : 0;
        int32_t block_end   = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED)
                                  ? std::min(n_input + (block_num + 1) * params.block_length, params.max_length)
                                  : params.max_length;

        // Count masked tokens in current block for block-based processing
        if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
            int32_t block_mask_count = 0;
            for (int i = block_start; i < block_end; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    block_mask_count++;
                }
            }
            num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
        }

        for (int32_t step = 0; step < steps_per_block; step++) {
            int32_t global_step = block_num * steps_per_block + step;

            // Callback may abort this block early by returning false
            if (params.step_callback) {
                if (!params.step_callback(global_step, params.steps, output_tokens, params.max_length,
                                          params.step_callback_user_data)) {
                    break;
                }
            }

            // Setup batch: decode the full sequence, request logits everywhere
            for (int32_t i = 0; i < params.max_length; i++) {
                batch.token[i]     = output_tokens[i];
                batch.pos[i]       = i;
                batch.n_seq_id[i]  = 1;
                batch.seq_id[i][0] = 0;
                batch.logits[i]    = 1;
            }

            float * logits = nullptr;

            if (params.cfg_scale > 0.0f) {
                // Conditional pass
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("Failed to generate conditional");
                    break;
                }
                float * cond_logits_ptr = llama_get_logits(ctx);
                std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));

                // Unconditional generation (mask input)
                std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
                for (int32_t i = 0; i < n_input; i++) {
                    un_x_buffer[i] = params.mask_token_id;
                }
                for (int32_t i = 0; i < params.max_length; i++) {
                    batch.token[i] = un_x_buffer[i];
                }
                ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("Failed to generate unconditional");
                    break;
                }
                float * uncond_logits = llama_get_logits(ctx);

                // Apply CFG: uncond + (scale + 1) * (cond - uncond)
                for (int32_t i = 0; i < logits_size; i++) {
                    cond_logits_buffer[i] =
                        uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
                }
                logits = cond_logits_buffer.data();
            } else {
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                logits = llama_get_logits(ctx);
            }

            if (!logits) {
                LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
                break;
            }

            // Some models predict token i from the logits of position i-1
            auto get_logits_for_pos = [&](int32_t pos) -> const float * {
                if (params.shift_logits) {
                    return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
                }
                return logits + pos * n_vocab;
            };

            int64_t time_start_sampling = ggml_time_us();

            mask_positions.clear();
            for (int32_t i = 0; i < params.max_length; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    // For block-based, only consider current block
                    if (params.schedule != DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED ||
                        (i >= block_start && i < block_end)) {
                        mask_positions.push_back(i);
                    }
                }
            }

            if (mask_positions.empty()) {
                break;
            }

            if (params.add_gumbel_noise && params.temperature > 0.0f) {
                // BUGFIX: the buffer holds max_length * n_vocab logits (one row
                // per position); passing just n_vocab only perturbed position 0.
                add_gumbel_noise(logits, params.max_length * n_vocab, params.temperature, rng);
            }

            if (params.algorithm == DIFFUSION_ALGORITHM_ORIGIN) {
                // Each masked position is independently committed with
                // probability transfer_count / |mask_positions|.
                int32_t transfer_count = calculate_transfer_count(step, steps_per_block,
                                                                  (int32_t) mask_positions.size(),
                                                                  params.schedule, params.eps, num_transfer_tokens);
                float p_transfer = (float) transfer_count / mask_positions.size();

                std::uniform_real_distribution<float> uniform(0.0f, 1.0f);  // hoisted out of the loop
                for (int32_t pos : mask_positions) {
                    if (uniform(rng) < p_transfer) {
                        const float * pos_logits = get_logits_for_pos(pos);
                        for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                            candidates[token_id].id    = token_id;
                            candidates[token_id].logit = pos_logits[token_id];
                            candidates[token_id].p     = 0.0f;
                        }

                        llama_token_data_array cur_p = {
                            candidates.data(),
                            (size_t) n_vocab,
                            -1,
                            false,
                        };
                        llama_sampler_apply(sampler, &cur_p);
                        output_tokens[pos] = cur_p.data[cur_p.selected].id;
                    }
                }
            } else {
                // Sample a token for every masked position, then commit the
                // transfer_count most confident ones (or sample positions by
                // confidence when alg_temp > 0).
                std::vector<std::pair<float, int32_t>> confidences;
                std::vector<llama_token>               sampled_tokens(mask_positions.size());

                for (size_t i = 0; i < mask_positions.size(); i++) {
                    int32_t       pos        = mask_positions[i];
                    const float * pos_logits = get_logits_for_pos(pos);

                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p     = 0.0f;
                        candidates[token_id].id    = token_id;
                    }

                    llama_token_data_array cur_p = {
                        candidates.data(),
                        candidates.size(),
                        -1,
                        false,
                    };
                    llama_sampler_apply(sampler, &cur_p);

                    llama_token sampled_token = cur_p.data[cur_p.selected].id;
                    float       conf          = calculate_confidence(cur_p, params.algorithm, rng);

                    sampled_tokens[i] = sampled_token;
                    confidences.emplace_back(conf, i);
                }

                int32_t transfer_count = calculate_transfer_count(step, steps_per_block,
                                                                  (int32_t) mask_positions.size(),
                                                                  params.schedule, params.eps, num_transfer_tokens);

                if (transfer_count > 0) {
                    if (params.alg_temp == 0.0f) {
                        // Deterministic: take the top-confidence positions,
                        // breaking ties by lower index.
                        std::partial_sort(confidences.begin(),
                                          confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
                                          confidences.end(),
                                          [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                              if (a.first != b.first) {
                                                  return a.first > b.first;
                                              }
                                              return a.second < b.second;
                                          });

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            int32_t mask_idx   = confidences[i].second;
                            int32_t pos        = mask_positions[mask_idx];
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    } else {
                        // Stochastic: sample positions with probability
                        // proportional to softmax(confidence / alg_temp).
                        conf_candidates.clear();
                        for (size_t i = 0; i < confidences.size(); i++) {
                            float conf_logit = confidences[i].first / params.alg_temp;
                            conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
                        }

                        llama_token_data_array conf_array = {
                            conf_candidates.data(),
                            conf_candidates.size(),
                            -1,
                            false,
                        };

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            llama_sampler_apply(dist_sampler, &conf_array);
                            int32_t selected_idx = conf_array.selected;
                            int32_t mask_idx     = selected_idx;
                            int32_t pos          = mask_positions[mask_idx];
                            output_tokens[pos]   = sampled_tokens[mask_idx];

                            // BUGFIX: the dist sampler re-softmaxes from the
                            // logits, so zeroing .p alone cannot exclude an
                            // already-committed position from later draws;
                            // kill its logit as well.
                            conf_candidates[selected_idx].logit = -std::numeric_limits<float>::infinity();
                            conf_candidates[selected_idx].p     = 0.0f;
                            conf_array.selected                 = -1;
                        }
                    }
                }
            }

            int64_t time_end_sampling = ggml_time_us();
            total_sampling_time += time_end_sampling - time_start_sampling;
        }
    }

    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);

    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = params.max_length;
}