llama : fix device state save/load

Author: Georgi Gerganov
Date:   2026-05-07 16:40:24 +03:00
parent 803627f121
commit f29df1173b
3 changed files with 78 additions and 24 deletions

@@ -87,17 +87,17 @@ static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer,
}
static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {
/* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
/* .init_tensor = */ NULL,
/* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
/* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
/* .clear = */ ggml_backend_metal_buffer_shared_clear,
/* .reset = */ NULL,
/* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
/* .init_tensor = */ NULL,
/* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
/* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
/* .clear = */ ggml_backend_metal_buffer_shared_clear,
/* .reset = */ NULL,
};
// private buffer
@@ -163,17 +163,17 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer
}
static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
/* .init_tensor = */ NULL,
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
/* .clear = */ ggml_backend_metal_buffer_private_clear,
/* .reset = */ NULL,
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
/* .init_tensor = */ NULL,
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
/* .clear = */ ggml_backend_metal_buffer_private_clear,
/* .reset = */ NULL,
};
static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) {

@@ -2451,7 +2451,30 @@ public:
for (auto & [buft, mbuf] : mbufs_new) {
auto & mbuf_cur = mbufs[buft];
if (!mbuf_cur.buf || mbuf_cur.org.size() != mbuf.org.size() || mbuf_cur.total_size != mbuf.total_size) {
bool need_alloc = false;
need_alloc = need_alloc || (!mbuf_cur.buf);
need_alloc = need_alloc || (mbuf_cur.org.size() != mbuf.org.size());
need_alloc = need_alloc || (mbuf_cur.total_size != mbuf.total_size);
if (!need_alloc) {
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
auto * org0 = mbuf_cur.org[i];
auto * org1 = mbuf.org[i];
if (!ggml_are_same_shape(org0, org1)) {
need_alloc = true;
break;
}
if (org0->view_src != org1->view_src || org0->view_offs != org1->view_offs) {
need_alloc = true;
break;
}
}
}
if (need_alloc) {
mbuf_cur = std::move(mbuf);
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
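
Note (not part of the commit): the need_alloc logic above boils down to a reuse predicate over the cached views. A minimal sketch of that predicate, assuming only the public ggml API; the helper name can_reuse_views and the std::vector-based signature are illustrative, not code from the repository:

#include "ggml.h"

#include <vector>

// sketch: true if the cached views (org_cur) can be reused for the incoming ones (org_new),
// mirroring the need_alloc checks in the hunk above
static bool can_reuse_views(const std::vector<ggml_tensor *> & org_cur,
                            const std::vector<ggml_tensor *> & org_new) {
    if (org_cur.size() != org_new.size()) {
        return false;
    }
    for (size_t i = 0; i < org_cur.size(); ++i) {
        const ggml_tensor * t0 = org_cur[i];
        const ggml_tensor * t1 = org_new[i];
        // reuse requires the same shape and the same view origin/offset
        if (!ggml_are_same_shape(t0, t1)) {
            return false;
        }
        if (t0->view_src != t1->view_src || t0->view_offs != t1->view_offs) {
            return false;
        }
    }
    return true;
}
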
@@ -2515,6 +2538,31 @@ public:
mbufs_new[buft].total_size += rinfo.size;
}
for (auto & [buft, mbuf] : mbufs_new) {
ggml_init_params params = {
/*.mem_size =*/ mbuf.n_tensors*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
mbuf.ctx.reset(ggml_init(params));
mbuf.org.reserve(mbuf.n_tensors);
}
for (const auto & rinfo : rinfos) {
auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer);
const int64_t n = rinfo.size/ggml_element_size(rinfo.tensor);
auto & mbuf = mbufs_new[buft];
mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset));
auto & view = mbuf.org.back();
view->buffer = rinfo.tensor->buffer;
}
for (auto & [buft, mbuf] : mbufs_new) {
const auto & mbuf_cur = mbufs.at(buft);
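
Note (not part of the commit): the hunk above sizes a header-only ggml context by tensor overhead and wraps byte ranges of existing tensors in 1-D views. A self-contained sketch of the same pattern; tensor sizes and names are made up, and the data is allocated inside the context so the example runs on its own:

#include "ggml.h"

#include <cstdio>

int main() {
    // room for a few tensor headers plus the data of one small tensor;
    // the loader above uses no_alloc = true because its views never own data
    ggml_init_params params = {
        /*.mem_size   =*/ 8*ggml_tensor_overhead() + 1024*sizeof(float),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 256);

    // wrap the byte range [offset, offset + size) of t in a 1-D view,
    // converting the byte size to an element count as the loader does
    const size_t offset = 64*sizeof(float);
    const size_t size   = 128*sizeof(float);
    const int64_t n     = size/ggml_element_size(t);

    ggml_tensor * view = ggml_view_1d(ctx, t, n, offset);

    printf("view: %lld elements starting at byte offset %zu\n", (long long) view->ne[0], offset);

    ggml_free(ctx);
    return 0;
}
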
@@ -2523,9 +2571,11 @@ public:
}
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf_cur.org[i]);
ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf.org[i]);
}
}
GGML_ASSERT(buf_size == 0);
}
void read(void * dst, size_t size) override {
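
Note (not part of the commit): the actual data movement in the last hunk is a plain ggml_backend_tensor_copy between the saved copies and the original views. A standalone sketch of that call using the CPU backend; the buffer setup via ggml_backend_alloc_ctx_tensors is an assumption about typical ggml usage, not taken from this change:

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"

#include <cstdio>
#include <vector>

int main() {
    ggml_backend_t backend = ggml_backend_cpu_init();

    ggml_init_params params = {
        /*.mem_size   =*/ 4*ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data lives in the backend buffer below
    };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);

    // place both tensors in a backend buffer
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    std::vector<float> data(16, 1.0f);
    ggml_backend_tensor_set(src, data.data(), 0, ggml_nbytes(src));

    // same call as in the hunk above: copy src into dst
    ggml_backend_tensor_copy(src, dst);

    std::vector<float> out(16, 0.0f);
    ggml_backend_tensor_get(dst, out.data(), 0, ggml_nbytes(dst));
    printf("dst[0] = %f\n", out[0]);

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}
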

@@ -726,6 +726,10 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
cell_ranges.emplace_back(cell_range_begin, size);
}
if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) {
GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n");
}
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
uint32_t cell_count_check = 0;
for (const auto & range : cell_ranges) {