mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-14 21:14:10 +00:00
llama : fix device state save/load
This commit is contained in:
@@ -87,17 +87,17 @@ static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer,
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
/* .get_tensor_2d = */ NULL,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_shared_clear,
|
||||
/* .reset = */ NULL,
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
/* .get_tensor_2d = */ NULL,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_shared_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
// private buffer
|
||||
@@ -163,17 +163,17 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
/* .get_tensor_2d = */ NULL,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_private_clear,
|
||||
/* .reset = */ NULL,
|
||||
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
|
||||
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
|
||||
/* .init_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
|
||||
/* .set_tensor_2d = */ NULL,
|
||||
/* .get_tensor_2d = */ NULL,
|
||||
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_metal_buffer_private_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) {
|
||||
|
||||
@@ -2451,7 +2451,30 @@ public:
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
auto & mbuf_cur = mbufs[buft];
|
||||
|
||||
if (!mbuf_cur.buf || mbuf_cur.org.size() != mbuf.org.size() || mbuf_cur.total_size != mbuf.total_size) {
|
||||
bool need_alloc = false;
|
||||
|
||||
need_alloc = need_alloc || (!mbuf_cur.buf);
|
||||
need_alloc = need_alloc || (mbuf_cur.org.size() != mbuf.org.size());
|
||||
need_alloc = need_alloc || (mbuf_cur.total_size != mbuf.total_size);
|
||||
|
||||
if (!need_alloc) {
|
||||
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
|
||||
auto * org0 = mbuf_cur.org[i];
|
||||
auto * org1 = mbuf.org[i];
|
||||
|
||||
if (!ggml_are_same_shape(org0, org1)) {
|
||||
need_alloc = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (org0->view_src != org1->view_src || org0->view_offs != org1->view_offs) {
|
||||
need_alloc = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (need_alloc) {
|
||||
mbuf_cur = std::move(mbuf);
|
||||
|
||||
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
|
||||
@@ -2515,6 +2538,31 @@ public:
|
||||
mbufs_new[buft].total_size += rinfo.size;
|
||||
}
|
||||
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ mbuf.n_tensors*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
mbuf.ctx.reset(ggml_init(params));
|
||||
|
||||
mbuf.org.reserve(mbuf.n_tensors);
|
||||
}
|
||||
|
||||
for (const auto & rinfo : rinfos) {
|
||||
auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer);
|
||||
|
||||
const int64_t n = rinfo.size/ggml_element_size(rinfo.tensor);
|
||||
|
||||
auto & mbuf = mbufs_new[buft];
|
||||
|
||||
mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset));
|
||||
|
||||
auto & view = mbuf.org.back();
|
||||
view->buffer = rinfo.tensor->buffer;
|
||||
}
|
||||
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
const auto & mbuf_cur = mbufs.at(buft);
|
||||
|
||||
@@ -2523,9 +2571,11 @@ public:
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
|
||||
ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf_cur.org[i]);
|
||||
ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf.org[i]);
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(buf_size == 0);
|
||||
}
|
||||
|
||||
void read(void * dst, size_t size) override {
|
||||
|
||||
@@ -726,6 +726,10 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
|
||||
cell_ranges.emplace_back(cell_range_begin, size);
|
||||
}
|
||||
|
||||
// Device-memory save/load handles only a single contiguous range of cells.
// LLAMA_STATE_SEQ_FLAGS_ON_DEVICE is a bit flag, so it must be tested with
// bitwise AND; the previous `%` (modulo) did not test that bit at all.
if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) != 0 && cell_ranges.size() > 1) {
    GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n");
}
|
||||
|
||||
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||
uint32_t cell_count_check = 0;
|
||||
for (const auto & range : cell_ranges) {
|
||||
|
||||
Reference in New Issue
Block a user