mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-12 03:54:06 +00:00
Compare commits
6 Commits
b9062
...
maxk/sched
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ba72d4d287 | ||
|
|
44dbe8c521 | ||
|
|
05ff59cb57 | ||
|
|
aaf4a4d5e0 | ||
|
|
e43431b381 | ||
|
|
ceb7e14b96 |
@@ -13684,6 +13684,27 @@ class DotsOCRVisionModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Sarashina2VisionForCausalLM")
class Sarashina2VLTextModel(LlamaModel):
    """Converter registered for ``Sarashina2VisionForCausalLM`` that writes the LLAMA architecture.

    Tensor names are normalised before being handed to the ``LlamaModel`` filter:
    a leading ``llm.`` is stripped, and tensors whose name starts with ``norm.``
    are dropped.
    """

    model_arch = gguf.MODEL_ARCH.LLAMA

    @classmethod
    def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
        """Rename or drop a ``(name, loader)`` pair, then defer to the parent filter.

        Returns ``None`` for tensors named ``norm.*``; otherwise returns whatever
        the parent ``filter_tensors`` returns for the (possibly renamed) pair.
        """
        tensor_name, loader = item

        # NOTE(review): presumably top-level "norm.*" tensors are not part of the
        # language model - confirm against the checkpoint layout.
        if tensor_name.startswith("norm."):
            return None

        llm_prefix = "llm."
        if tensor_name.startswith(llm_prefix):
            # Drop only the leading prefix; later "llm." occurrences are kept.
            tensor_name = tensor_name[len(llm_prefix):]

        return super().filter_tensors((tensor_name, loader))
|
||||
|
||||
|
||||
@ModelBase.register("Sarashina2VisionForCausalLM")
class Sarashina2VLVisionModel(Qwen2VLVisionModel):
    """Vision-side converter registered for ``Sarashina2VisionForCausalLM``.

    Reuses the ``Qwen2VLVisionModel`` implementation unchanged and only rewrites
    the ``model_type`` entry of the global config after the parent has initialised.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Overwrite the checkpoint's own model_type so it reads as "qwen2_vl".
        self.global_config["model_type"] = "qwen2_vl"
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
@@ -13940,7 +13961,7 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
|
||||
# Step3-VL keeps text config under text_config but uses a custom top-level architecture.
|
||||
# For text conversion we route to a dedicated text-only class.
|
||||
# TODO: refactor this later to avoid adding exception here
|
||||
if model_type == ModelType.TEXT and arch == "StepVLForConditionalGeneration":
|
||||
if model_type == ModelType.TEXT and arch in ("StepVLForConditionalGeneration", "Sarashina2VisionForCausalLM"):
|
||||
return arch
|
||||
|
||||
# if "architectures" is found in the sub-config, use that instead
|
||||
|
||||
@@ -965,7 +965,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
|
||||
}
|
||||
if (sched->debug > 1) {
|
||||
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
|
||||
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
|
||||
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_desc(node), node->name,
|
||||
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
|
||||
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
|
||||
@@ -54,15 +54,31 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int64_t dps2 = ne2 / ne02;
|
||||
const int64_t dps3 = ne3 / ne03;
|
||||
|
||||
// TODO batched matrix multiplication
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||||
if (dps2 == 1 && ne2 > 1) {
|
||||
// src0 has uniform stride s02 along dim 2; batch the inner loop with a strided GEMM
|
||||
GGML_ASSERT(ne2 <= std::numeric_limits<int>::max());
|
||||
const int batch_count = (int) ne2;
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
cublasSgemmStridedBatched(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
&alpha, src0_d + (i3/dps3)*s03, lda, s02,
|
||||
src1_d + i3 *s13, ldb, s12,
|
||||
&beta, dst_d + i3 *s3, ldc, s2,
|
||||
batch_count));
|
||||
}
|
||||
} else {
|
||||
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
|
||||
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
1
ggml/src/ggml-cuda/vendors/hip.h
vendored
1
ggml/src/ggml-cuda/vendors/hip.h
vendored
@@ -48,6 +48,7 @@
|
||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetStream hipblasSetStream
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
|
||||
|
||||
1
ggml/src/ggml-cuda/vendors/musa.h
vendored
1
ggml/src/ggml-cuda/vendors/musa.h
vendored
@@ -32,6 +32,7 @@
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasSgemmStridedBatched mublasSgemmStridedBatched
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cublasOperation_t mublasOperation_t
|
||||
#define cublasGetStatusString mublasGetStatusString
|
||||
|
||||
@@ -87,17 +87,17 @@ static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer,
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
/* .get_tensor_2d = */ NULL,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_shared_clear,
|
||||
/* .reset = */ NULL,
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
/* .get_tensor_2d = */ NULL,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_shared_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
// private buffer
|
||||
@@ -163,17 +163,17 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
/* .get_tensor_2d = */ NULL,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_private_clear,
|
||||
/* .reset = */ NULL,
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
/* .get_tensor_2d = */ NULL,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_private_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) {
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
#include <memory>
|
||||
#include <charconv>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
|
||||
#undef MIN
|
||||
#undef MAX
|
||||
@@ -396,6 +397,8 @@ struct ggml_backend_opencl_context {
|
||||
bool has_vector_subgroup_broadcast;
|
||||
bool disable_fusion;
|
||||
|
||||
std::regex *opfilter = nullptr; // regex of ops to not claim
|
||||
|
||||
bool adreno_has_large_buffer;
|
||||
bool adreno_use_large_buffer;
|
||||
ggml_cl_compiler_version adreno_cl_compiler_version;
|
||||
@@ -3494,6 +3497,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
|
||||
backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
|
||||
|
||||
const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER");
|
||||
if (str_opfilter) {
|
||||
backend_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase);
|
||||
GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", str_opfilter);
|
||||
}
|
||||
|
||||
dev_ctx->backend_ctx = backend_ctx.release();
|
||||
return dev_ctx->backend_ctx;
|
||||
}
|
||||
@@ -4143,6 +4152,11 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
|
||||
ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx;
|
||||
|
||||
// reject ops that match the opfilter regex
|
||||
if (backend_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *backend_ctx->opfilter)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_NONE:
|
||||
return true;
|
||||
|
||||
@@ -2451,7 +2451,30 @@ public:
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
auto & mbuf_cur = mbufs[buft];
|
||||
|
||||
if (!mbuf_cur.buf || mbuf_cur.org.size() != mbuf.org.size() || mbuf_cur.total_size != mbuf.total_size) {
|
||||
bool need_alloc = false;
|
||||
|
||||
need_alloc = need_alloc || (!mbuf_cur.buf);
|
||||
need_alloc = need_alloc || (mbuf_cur.org.size() != mbuf.org.size());
|
||||
need_alloc = need_alloc || (mbuf_cur.total_size != mbuf.total_size);
|
||||
|
||||
if (!need_alloc) {
|
||||
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
|
||||
auto * org0 = mbuf_cur.org[i];
|
||||
auto * org1 = mbuf.org[i];
|
||||
|
||||
if (!ggml_are_same_shape(org0, org1)) {
|
||||
need_alloc = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (org0->view_src != org1->view_src || org0->view_offs != org1->view_offs) {
|
||||
need_alloc = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (need_alloc) {
|
||||
mbuf_cur = std::move(mbuf);
|
||||
|
||||
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
|
||||
@@ -2515,6 +2538,31 @@ public:
|
||||
mbufs_new[buft].total_size += rinfo.size;
|
||||
}
|
||||
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ mbuf.n_tensors*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
mbuf.ctx.reset(ggml_init(params));
|
||||
|
||||
mbuf.org.reserve(mbuf.n_tensors);
|
||||
}
|
||||
|
||||
for (const auto & rinfo : rinfos) {
|
||||
auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer);
|
||||
|
||||
const int64_t n = rinfo.size/ggml_element_size(rinfo.tensor);
|
||||
|
||||
auto & mbuf = mbufs_new[buft];
|
||||
|
||||
mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset));
|
||||
|
||||
auto & view = mbuf.org.back();
|
||||
view->buffer = rinfo.tensor->buffer;
|
||||
}
|
||||
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
const auto & mbuf_cur = mbufs.at(buft);
|
||||
|
||||
@@ -2523,9 +2571,11 @@ public:
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
|
||||
ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf_cur.org[i]);
|
||||
ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf.org[i]);
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(buf_size == 0);
|
||||
}
|
||||
|
||||
void read(void * dst, size_t size) override {
|
||||
|
||||
@@ -726,6 +726,10 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
|
||||
cell_ranges.emplace_back(cell_range_begin, size);
|
||||
}
|
||||
|
||||
// On-device save/load only supports a single contiguous range of cells.
// `flags` is a bitmask, so the ON_DEVICE bit must be tested with `&`; the previous
// `%` (modulo) computed a remainder instead of testing the bit, so the guard did
// not track whether ON_DEVICE was actually set.
if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) {
    GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n");
}
|
||||
|
||||
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||
uint32_t cell_count_check = 0;
|
||||
for (const auto & range : cell_ranges) {
|
||||
|
||||
@@ -8385,6 +8385,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
// ne2 sweep to cover the cublasSgemmStridedBatched path (dps2 == 1, ne2 > 1)
|
||||
for (int64_t ne2 : {1, 8, 16, 32}) {
|
||||
test_cases.emplace_back(new test_out_prod(GGML_TYPE_F32, GGML_TYPE_F32,
|
||||
256, 16, 16, {ne2, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
// add_id
|
||||
for (ggml_type type_a : {GGML_TYPE_F32}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@
|
||||
import { FORK_TREE_DEPTH_PADDING } from '$lib/constants';
|
||||
import { getAllLoadingChats } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { TruncatedText } from '$lib/components/app';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -148,9 +149,7 @@
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
|
||||
<span class="truncate text-sm font-medium">
|
||||
{conversation.name}
|
||||
</span>
|
||||
<TruncatedText text={conversation.name} class="text-sm font-medium" showTooltip={false} />
|
||||
</div>
|
||||
|
||||
{#if renderActionsDropdown}
|
||||
|
||||
@@ -104,13 +104,15 @@
|
||||
</p>
|
||||
{/if}
|
||||
{:else if field.type === SettingsFieldType.TEXTAREA}
|
||||
<Label for={field.key} class="block flex items-center gap-1.5 text-sm font-medium">
|
||||
{field.label}
|
||||
{#if field.label}
|
||||
<Label for={field.key} class="block flex items-center gap-1.5 text-sm font-medium">
|
||||
{field.label}
|
||||
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</Label>
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</Label>
|
||||
{/if}
|
||||
|
||||
<Textarea
|
||||
id={field.key}
|
||||
|
||||
@@ -35,6 +35,7 @@ export * from './settings-keys';
|
||||
export * from './settings-sections';
|
||||
export * from './supported-file-types';
|
||||
export * from './table-html-restorer';
|
||||
export * from './title-generation';
|
||||
export * from './tools';
|
||||
export * from './tooltip-config';
|
||||
export * from './ui';
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { ColorMode } from '$lib/enums/ui';
|
||||
import { Monitor, Moon, Sun } from '@lucide/svelte';
|
||||
import { TITLE } from './title-generation';
|
||||
|
||||
export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean | undefined> = {
|
||||
// Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value.
|
||||
@@ -16,6 +17,8 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean |
|
||||
showMessageStats: true,
|
||||
askForTitleConfirmation: false,
|
||||
titleGenerationUseFirstLine: false,
|
||||
titleGenerationUseLLM: false,
|
||||
titleGenerationPrompt: TITLE.DEFAULT_PROMPT,
|
||||
pasteLongTextToFileLen: 2500,
|
||||
copyTextAttachmentsAsPlainText: false,
|
||||
pdfAsImage: false,
|
||||
@@ -121,6 +124,10 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
||||
titleGenerationUseFirstLine:
|
||||
'Use only the first non-empty line of the prompt to generate the conversation title.',
|
||||
titleGenerationUseLLM:
|
||||
'Use the LLM to automatically generate conversation titles based on the first message exchange.',
|
||||
titleGenerationPrompt:
|
||||
'Optional template for the title generation prompt. Use {{USER}} for the user message and {{ASSISTANT}} for the assistant message.',
|
||||
pdfAsImage:
|
||||
'Parse PDF as image instead of text. Automatically falls back to text processing for non-vision models.',
|
||||
disableAutoScroll:
|
||||
|
||||
@@ -16,6 +16,8 @@ export const SETTINGS_KEYS = {
|
||||
PDF_AS_IMAGE: 'pdfAsImage',
|
||||
ASK_FOR_TITLE_CONFIRMATION: 'askForTitleConfirmation',
|
||||
TITLE_GENERATION_USE_FIRST_LINE: 'titleGenerationUseFirstLine',
|
||||
TITLE_GENERATION_USE_LLM: 'titleGenerationUseLLM',
|
||||
TITLE_GENERATION_PROMPT: 'titleGenerationPrompt',
|
||||
// Display
|
||||
SHOW_MESSAGE_STATS: 'showMessageStats',
|
||||
SHOW_THOUGHT_IN_PROGRESS: 'showThoughtInProgress',
|
||||
|
||||
@@ -95,6 +95,17 @@ export const SETTINGS_CHAT_SECTIONS: SettingsSection[] = [
|
||||
key: SETTINGS_KEYS.TITLE_GENERATION_USE_FIRST_LINE,
|
||||
label: 'Use first non-empty line for conversation title',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.TITLE_GENERATION_USE_LLM,
|
||||
label: 'Use LLM to generate conversation title',
|
||||
type: SettingsFieldType.CHECKBOX,
|
||||
isExperimental: true
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.TITLE_GENERATION_PROMPT,
|
||||
type: SettingsFieldType.TEXTAREA,
|
||||
isExperimental: true
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
9
tools/server/webui/src/lib/constants/title-generation.ts
Normal file
9
tools/server/webui/src/lib/constants/title-generation.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
/**
 * Title generation constants.
 */
export const TITLE = {
	/** Minimum number of characters a title must have. */
	MIN_LENGTH: 3,

	/** Title text used when nothing better is available. */
	FALLBACK: 'New Chat',

	/**
	 * Default prompt template. `{{USER}}` and `{{ASSISTANT}}` are placeholders
	 * for the user message and the assistant message.
	 */
	DEFAULT_PROMPT:
		'Based on the following interaction, generate a short, concise title (maximum 6-8 words) that captures the main topic. Return ONLY the title text, nothing else. Do not use quotes.\n\nUser: {{USER}}\n\nAssistant: {{ASSISTANT}}\n\nTitle:',

	/** Matches a leading "Title:", "Subject:" or "Topic:" label (case-insensitive) plus trailing whitespace. */
	PREFIX_PATTERN: /^(Title:|Subject:|Topic:)\s*/i,

	/** Matches one double quote at the very start and/or very end of the string. */
	QUOTE_PATTERN: /^["]|["]$/g
} as const;
|
||||
@@ -14,11 +14,59 @@ import {
|
||||
ReasoningFormat,
|
||||
UrlProtocol
|
||||
} from '$lib/enums';
|
||||
import type { ApiChatMessageContentPart, ApiChatCompletionToolCall } from '$lib/types/api';
|
||||
import type {
|
||||
ApiChatMessageContentPart,
|
||||
ApiChatMessageData,
|
||||
ApiChatCompletionToolCall
|
||||
} from '$lib/types/api';
|
||||
import type { DatabaseMessageExtraMcpPrompt, DatabaseMessageExtraMcpResource } from '$lib/types';
|
||||
import { modelsStore } from '$lib/stores/models.svelte';
|
||||
|
||||
export class ChatService {
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Title Generation
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
 * Requests a chat title from the model and returns the streamed text.
 *
 * The request goes through `sendMessage` with streaming enabled and
 * `chat_template_kwargs.enable_thinking` set to `false`; every chunk passed to
 * `onChunk` is appended to the result in arrival order.
 *
 * @param message - The single message to send (a user message containing the title generation prompt)
 * @param model - Optional model name; `null` or an empty string is sent as `undefined`
 * @param signal - Optional AbortSignal forwarded to `sendMessage`
 * @returns The concatenated chunks, or an empty string if `sendMessage` throws
 *          (any chunks received before the failure are discarded)
 * @static
 */
static async generateTitle(
	message: ApiChatMessageData,
	model?: string | null,
	signal?: AbortSignal
): Promise<string> {
	let collected = '';

	const appendChunk = (chunk: string) => {
		collected += chunk;
	};

	try {
		await ChatService.sendMessage(
			[message],
			{
				model: model || undefined,
				stream: true,
				custom: { chat_template_kwargs: { enable_thinking: false } },
				onChunk: appendChunk
			},
			undefined,
			signal
		);

		return collected;
	} catch {
		// Best-effort: title generation failures are reported as "no title".
		return '';
	}
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
@@ -122,7 +170,11 @@ export class ChatService {
|
||||
return true;
|
||||
});
|
||||
// If only text remains and it's a single part, simplify to string
|
||||
if (msg.content.length === 1 && msg.content[0].type === ContentPartType.TEXT) {
|
||||
if (
|
||||
msg.content.length === 1 &&
|
||||
msg.content[0].type === ContentPartType.TEXT &&
|
||||
typeof msg.content[0].text === 'string'
|
||||
) {
|
||||
msg.content = msg.content[0].text;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,8 @@ import {
|
||||
import {
|
||||
MAX_INACTIVE_CONVERSATION_STATES,
|
||||
INACTIVE_CONVERSATION_STATE_MAX_AGE_MS,
|
||||
SYSTEM_MESSAGE_PLACEHOLDER
|
||||
SYSTEM_MESSAGE_PLACEHOLDER,
|
||||
TITLE
|
||||
} from '$lib/constants';
|
||||
import type {
|
||||
ChatMessageTimings,
|
||||
@@ -44,7 +45,12 @@ import type {
|
||||
ChatStreamCallbacks,
|
||||
ErrorDialogState
|
||||
} from '$lib/types/chat';
|
||||
import type { ApiProcessingState, DatabaseMessage, DatabaseMessageExtra } from '$lib/types';
|
||||
import type {
|
||||
ApiChatMessageData,
|
||||
ApiProcessingState,
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra
|
||||
} from '$lib/types';
|
||||
import { ErrorDialogType, MessageRole, MessageType } from '$lib/enums';
|
||||
|
||||
interface ConversationStateEntry {
|
||||
@@ -572,7 +578,11 @@ class ChatStore {
|
||||
conversationsStore.addMessageToActive(assistantMessage);
|
||||
await this.streamChatCompletion(
|
||||
conversationsStore.activeMessages.slice(0, -1),
|
||||
assistantMessage
|
||||
assistantMessage,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
config().titleGenerationUseLLM && isNewConversation ? content : undefined
|
||||
);
|
||||
} catch (error) {
|
||||
if (isAbortError(error)) {
|
||||
@@ -601,7 +611,8 @@ class ChatStore {
|
||||
assistantMessage: DatabaseMessage,
|
||||
onComplete?: (content: string) => Promise<void>,
|
||||
onError?: (error: Error) => void,
|
||||
modelOverride?: string | null
|
||||
modelOverride?: string | null,
|
||||
firstUserMessageContent?: string
|
||||
): Promise<void> {
|
||||
let effectiveModel = modelOverride;
|
||||
|
||||
@@ -894,6 +905,12 @@ class ChatStore {
|
||||
if (onComplete) await onComplete(content);
|
||||
if (isRouterMode()) modelsStore.fetchRouterModels().catch(console.error);
|
||||
|
||||
// Generate LLM based title for new conversations (avoids stale reference
|
||||
// issue when user switches conversations while streaming)
|
||||
if (firstUserMessageContent) {
|
||||
await this.generateTitleWithLLM(firstUserMessageContent, streamedContent, convId);
|
||||
}
|
||||
|
||||
// Check if there's a pending message queued during streaming
|
||||
const pending = this.consumePendingMessage(convId);
|
||||
if (pending) {
|
||||
@@ -921,6 +938,49 @@ class ChatStore {
|
||||
this.setProcessingState(convId, null);
|
||||
this.clearPendingMessage(convId);
|
||||
}
|
||||
|
||||
/**
 * Asks the LLM for a conversation title and stores it on the conversation.
 *
 * The prompt is built from the configured `titleGenerationPrompt` (or
 * `TITLE.DEFAULT_PROMPT` when that setting is not a non-blank string) by
 * substituting `{{USER}}` and `{{ASSISTANT}}`. The response is trimmed and
 * stripped of a leading "Title:"-style label and surrounding double quotes.
 * If the cleaned response is shorter than `TITLE.MIN_LENGTH`, the first
 * non-empty line of the user message (or `TITLE.FALLBACK`) is used instead.
 * The conversation is only renamed when the final title meets the minimum length.
 *
 * @param userContent - Text of the first user message
 * @param assistantContent - Text of the assistant reply
 * @param convId - Id of the conversation to rename
 */
private async generateTitleWithLLM(
	userContent: string,
	assistantContent: string,
	convId: string
): Promise<void> {
	const effectiveModel = isRouterMode() && selectedModelName() ? selectedModelName() : undefined;
	const configValue = config();
	const titlePromptTemplate =
		typeof configValue.titleGenerationPrompt === 'string' &&
		configValue.titleGenerationPrompt.trim()
			? configValue.titleGenerationPrompt
			: TITLE.DEFAULT_PROMPT;

	const userText = String(userContent || '');
	const assistantText = String(assistantContent || '');

	// Substitute both placeholders in a single pass using a replacer function:
	// - a function's return value is inserted verbatim, so `$&`, `$1`, `$$` etc.
	//   occurring in chat text are not interpreted as replacement patterns;
	// - text that has just been inserted is never re-scanned, so a literal
	//   `{{ASSISTANT}}` inside the user message is left untouched.
	const titlePrompt = titlePromptTemplate.replace(
		/\{\{(USER|ASSISTANT)\}\}/g,
		(_placeholder: string, key: string) => (key === 'USER' ? userText : assistantText)
	);

	const titleMessage: ApiChatMessageData = {
		role: MessageRole.USER,
		content: titlePrompt
	};

	const titleResponse = await ChatService.generateTitle(titleMessage, effectiveModel);

	// Empty string means the request failed or produced nothing; keep the current name.
	if (!titleResponse) {
		return;
	}

	let cleanTitle = titleResponse
		.trim()
		.replace(TITLE.PREFIX_PATTERN, '')
		.replace(TITLE.QUOTE_PATTERN, '')
		.trim();

	if (!cleanTitle || cleanTitle.length < TITLE.MIN_LENGTH) {
		// Model output unusable: fall back to the first non-empty line of the user message.
		const firstLine = userContent.split('\n').find((l) => l.trim().length > 0);
		cleanTitle = firstLine ? firstLine.trim() : TITLE.FALLBACK;
	}

	// The fallback line can itself be too short, so re-check before renaming.
	if (cleanTitle && cleanTitle.length >= TITLE.MIN_LENGTH) {
		await conversationsStore.updateConversationName(convId, cleanTitle);
	}
}
|
||||
|
||||
private async savePartialResponseIfNeeded(convId?: string): Promise<void> {
|
||||
const conversationId = convId || conversationsStore.activeConversation?.id;
|
||||
if (!conversationId) return;
|
||||
|
||||
Reference in New Issue
Block a user