From d40697c46acbc89bf669ff5ff8d839a1e37c6fd6 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 1 May 2026 07:49:49 +0200 Subject: [PATCH] add device group test --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 318 ++++++++++++++++++++++----- 1 file changed, 267 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 55f751c98c..57262d96b9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12881,6 +12881,90 @@ static void ggml_vk_bench_pair( } #endif + // Probe for device group containing both devices + vk::Device dg_device{}; + vk::Queue dg_queue{}; + vk::CommandPool dg_cmd_pool{}; + vk::Fence dg_fence{}; + uint32_t dg_idx0 = UINT32_MAX, dg_idx1 = UINT32_MAX; + uint32_t dg_heap_idx = UINT32_MAX; + bool has_devgroup = false; + + { + auto groups = vk_instance.instance.enumeratePhysicalDeviceGroups(); + for (auto & group : groups) { + uint32_t i0 = UINT32_MAX, i1 = UINT32_MAX; + for (uint32_t i = 0; i < group.physicalDeviceCount; i++) { + if (group.physicalDevices[i] == dev0->physical_device) i0 = i; + if (group.physicalDevices[i] == dev1->physical_device) i1 = i; + } + if (i0 == UINT32_MAX || i1 == UINT32_MAX) continue; + + dg_idx0 = i0; + dg_idx1 = i1; + + // Find a queue family with transfer support + auto qf_props = dev0->physical_device.getQueueFamilyProperties(); + uint32_t qf_idx = UINT32_MAX; + for (uint32_t q = 0; q < (uint32_t)qf_props.size(); q++) { + if (qf_props[q].queueFlags & vk::QueueFlagBits::eTransfer) { + qf_idx = q; + break; + } + } + if (qf_idx == UINT32_MAX) break; + + float priority = 1.0f; + vk::DeviceQueueCreateInfo qci{}; + qci.queueFamilyIndex = qf_idx; + qci.queueCount = 1; + qci.pQueuePriorities = &priority; + + vk::DeviceGroupDeviceCreateInfo dgci{}; + dgci.physicalDeviceCount = group.physicalDeviceCount; + dgci.pPhysicalDevices = group.physicalDevices; + + vk::DeviceCreateInfo dci{}; + dci.queueCreateInfoCount = 1; + 
dci.pQueueCreateInfos = &qci; + dci.setPNext(&dgci); + + try { + dg_device = dev0->physical_device.createDevice(dci); + } catch (vk::SystemError& e) { + std::cerr << " devgroup: device creation failed: " << e.what() << std::endl; + break; + } + + dg_queue = dg_device.getQueue(qf_idx, 0); + dg_cmd_pool = dg_device.createCommandPool({ vk::CommandPoolCreateFlagBits::eResetCommandBuffer, qf_idx }); + dg_fence = dg_device.createFence({}); + + // Find device-local heap and check peer memory features + auto mem_props = dev0->physical_device.getMemoryProperties(); + for (uint32_t m = 0; m < mem_props.memoryTypeCount; m++) { + if (!(mem_props.memoryTypes[m].propertyFlags & vk::MemoryPropertyFlagBits::eDeviceLocal)) continue; + uint32_t heap = mem_props.memoryTypes[m].heapIndex; + + vk::PeerMemoryFeatureFlags peer_flags = dg_device.getGroupPeerMemoryFeatures(heap, dg_idx0, dg_idx1); + if (peer_flags & vk::PeerMemoryFeatureFlagBits::eCopySrc) { + dg_heap_idx = m; + has_devgroup = true; + break; + } + } + + if (!has_devgroup) { + std::cerr << " devgroup: no peer copy support between devices" << std::endl; + dg_device.destroyFence(dg_fence); + dg_device.destroyCommandPool(dg_cmd_pool); + dg_device.destroy(); + dg_device = vk::Device{}; + } + break; + } + } + // Helper to record a result auto record = [&](const std::string & method, size_t size, std::vector<double> & times) { std::sort(times.begin(), times.end()); @@ -13289,60 +13373,76 @@ static void ggml_vk_bench_pair( vk::Buffer imported_buffer{}; vk::DeviceMemory imported_mem{}; - if (setup_ok) { + if (setup_ok) do { + vk::MemoryGetFdInfoKHR gi{}; + gi.memory = exp_mem; + gi.handleType = vk::ExternalMemoryHandleTypeFlagBits::eDmaBufEXT; + int dmabuf_fd = -1; try { - vk::MemoryGetFdInfoKHR gi{}; - gi.memory = exp_mem; - gi.handleType = vk::ExternalMemoryHandleTypeFlagBits::eDmaBufEXT; - int dmabuf_fd = dev0->device.getMemoryFdKHR(gi); - - vk::MemoryFdPropertiesKHR fd_props = dev1->device.getMemoryFdPropertiesKHR( - 
vk::ExternalMemoryHandleTypeFlagBits::eDmaBufEXT, dmabuf_fd); - - if (fd_props.memoryTypeBits == 0) { - close(dmabuf_fd); - throw vk::SystemError(vk::make_error_code(vk::Result::eErrorFormatNotSupported)); - } - - vk::ExternalMemoryBufferCreateInfo imp_ext_bci{}; - imp_ext_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eDmaBufEXT; - vk::BufferCreateInfo imp_bci{}; - imp_bci.size = size; - imp_bci.usage = vk::BufferUsageFlagBits::eTransferSrc; - imp_bci.setPNext(&imp_ext_bci); - imported_buffer = dev1->device.createBuffer(imp_bci); - - vk::MemoryRequirements mem_req = dev1->device.getBufferMemoryRequirements(imported_buffer); - uint32_t mem_type_idx = UINT32_MAX; - for (uint32_t m = 0; m < 32; m++) { - if ((fd_props.memoryTypeBits & (1u << m)) && (mem_req.memoryTypeBits & (1u << m))) { - mem_type_idx = m; - break; - } - } - if (mem_type_idx == UINT32_MAX) { - close(dmabuf_fd); - throw vk::SystemError(vk::make_error_code(vk::Result::eErrorFormatNotSupported)); - } - - vk::ImportMemoryFdInfoKHR import_info{}; - import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eDmaBufEXT; - import_info.fd = dmabuf_fd; - vk::MemoryAllocateInfo alloc_info{}; - alloc_info.allocationSize = mem_req.size; - alloc_info.memoryTypeIndex = mem_type_idx; - alloc_info.setPNext(&import_info); - imported_mem = dev1->device.allocateMemory(alloc_info); - dev1->device.bindBufferMemory(imported_buffer, imported_mem, 0); + dmabuf_fd = dev0->device.getMemoryFdKHR(gi); } catch (vk::SystemError& e) { - std::cerr << " dmabuf_p2p : SKIPPED (import: " << e.what() << ")" << std::endl; - if (imported_buffer) dev1->device.destroyBuffer(imported_buffer); - if (imported_mem) dev1->device.freeMemory(imported_mem); - imported_buffer = vk::Buffer{}; - imported_mem = vk::DeviceMemory{}; - setup_ok = false; + std::cerr << " dmabuf_p2p : SKIPPED (export fd: " << e.what() << ")" << std::endl; + setup_ok = false; break; } - } + + vk::MemoryFdPropertiesKHR fd_props; + try { + fd_props = 
dev1->device.getMemoryFdPropertiesKHR( + vk::ExternalMemoryHandleTypeFlagBits::eDmaBufEXT, dmabuf_fd); + } catch (vk::SystemError& e) { + std::cerr << " dmabuf_p2p : SKIPPED (fd props: " << e.what() << ")" << std::endl; + close(dmabuf_fd); + setup_ok = false; break; + } + + if (fd_props.memoryTypeBits == 0) { + std::cerr << " dmabuf_p2p : SKIPPED (fd has no importable memory types on dest)" << std::endl; + close(dmabuf_fd); + setup_ok = false; break; + } + + vk::ExternalMemoryBufferCreateInfo imp_ext_bci{}; + imp_ext_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eDmaBufEXT; + vk::BufferCreateInfo imp_bci{}; + imp_bci.size = size; + imp_bci.usage = vk::BufferUsageFlagBits::eTransferSrc; + imp_bci.setPNext(&imp_ext_bci); + imported_buffer = dev1->device.createBuffer(imp_bci); + + vk::MemoryRequirements mem_req = dev1->device.getBufferMemoryRequirements(imported_buffer); + uint32_t mem_type_idx = UINT32_MAX; + for (uint32_t m = 0; m < 32; m++) { + if ((fd_props.memoryTypeBits & (1u << m)) && (mem_req.memoryTypeBits & (1u << m))) { + mem_type_idx = m; + break; + } + } + if (mem_type_idx == UINT32_MAX) { + std::cerr << " dmabuf_p2p : SKIPPED (fd_props=0x" << std::hex << fd_props.memoryTypeBits + << " buf_req=0x" << mem_req.memoryTypeBits << std::dec << " — no overlap)" << std::endl; + close(dmabuf_fd); + dev1->device.destroyBuffer(imported_buffer); + imported_buffer = vk::Buffer{}; + setup_ok = false; break; + } + + vk::ImportMemoryFdInfoKHR import_info{}; + import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eDmaBufEXT; + import_info.fd = dmabuf_fd; + vk::MemoryAllocateInfo alloc_info{}; + alloc_info.allocationSize = mem_req.size; + alloc_info.memoryTypeIndex = mem_type_idx; + alloc_info.setPNext(&import_info); + try { + imported_mem = dev1->device.allocateMemory(alloc_info); + } catch (vk::SystemError& e) { + std::cerr << " dmabuf_p2p : SKIPPED (import alloc type " << mem_type_idx << ": " << e.what() << ")" << std::endl; + 
dev1->device.destroyBuffer(imported_buffer); + imported_buffer = vk::Buffer{}; + setup_ok = false; break; + } + dev1->device.bindBufferMemory(imported_buffer, imported_mem, 0); + } while (false); if (setup_ok) { std::vector<double> times; @@ -13370,6 +13470,122 @@ static void ggml_vk_bench_pair( if (exp_mem) dev0->device.freeMemory(exp_mem); } #endif + + // ================================================================= + // 7. Device group P2P: direct peer memory access via VkDeviceGroup + // ================================================================= + if (has_devgroup) { + std::vector<double> times; + bool run_ok = true; + + // Create src buffer + memory on device 0 + vk::BufferCreateInfo bci{}; + bci.size = size; + bci.usage = vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst; + + vk::Buffer dg_src_buf{}, dg_dst_buf{}; + vk::DeviceMemory dg_src_mem{}, dg_dst_mem{}; + + try { + dg_src_buf = dg_device.createBuffer(bci); + dg_dst_buf = dg_device.createBuffer(bci); + + vk::MemoryRequirements src_req = dg_device.getBufferMemoryRequirements(dg_src_buf); + vk::MemoryRequirements dst_req = dg_device.getBufferMemoryRequirements(dg_dst_buf); + + vk::MemoryAllocateFlagsInfo flags_src{}; + flags_src.flags = vk::MemoryAllocateFlagBits::eDeviceMask; + flags_src.deviceMask = 1u << dg_idx0; + + vk::MemoryAllocateFlagsInfo flags_dst{}; + flags_dst.flags = vk::MemoryAllocateFlagBits::eDeviceMask; + flags_dst.deviceMask = 1u << dg_idx1; + + dg_src_mem = dg_device.allocateMemory({ src_req.size, dg_heap_idx, &flags_src }); + dg_dst_mem = dg_device.allocateMemory({ dst_req.size, dg_heap_idx, &flags_dst }); + + dg_device.bindBufferMemory(dg_src_buf, dg_src_mem, 0); + dg_device.bindBufferMemory(dg_dst_buf, dg_dst_mem, 0); + + // Fill src on device 0 + vk::CommandBuffer fill_cb = dg_device.allocateCommandBuffers( + { dg_cmd_pool, vk::CommandBufferLevel::ePrimary, 1 })[0]; + + vk::DeviceGroupCommandBufferBeginInfo dg_begin{}; + dg_begin.deviceMask = 1u << dg_idx0; + 
vk::CommandBufferBeginInfo cbi{}; + cbi.flags = vk::CommandBufferUsageFlagBits::eOneTimeSubmit; + cbi.setPNext(&dg_begin); + fill_cb.begin(cbi); + fill_cb.setDeviceMask(1u << dg_idx0); + fill_cb.fillBuffer(dg_src_buf, 0, size, 0xDEADBEEF); + fill_cb.end(); + + vk::DeviceGroupSubmitInfo dg_submit_info{}; + uint32_t fill_mask = 1u << dg_idx0; + dg_submit_info.commandBufferCount = 1; + dg_submit_info.pCommandBufferDeviceMasks = &fill_mask; + + vk::SubmitInfo si{}; + si.commandBufferCount = 1; + si.pCommandBuffers = &fill_cb; + si.setPNext(&dg_submit_info); + dg_queue.submit({ si }, dg_fence); + VK_CHECK(dg_device.waitForFences({ dg_fence }, true, UINT64_MAX), "devgroup fill"); + dg_device.resetFences({ dg_fence }); + dg_device.resetCommandPool(dg_cmd_pool); + } catch (vk::SystemError& e) { + std::cerr << " devgroup_p2p : SKIPPED (setup: " << e.what() << ")" << std::endl; + run_ok = false; + } + + for (size_t i = 0; i < num_it + warmup && run_ok; i++) { + auto begin = std::chrono::high_resolution_clock::now(); + + vk::CommandBuffer cb = dg_device.allocateCommandBuffers( + { dg_cmd_pool, vk::CommandBufferLevel::ePrimary, 1 })[0]; + + vk::DeviceGroupCommandBufferBeginInfo dg_begin{}; + dg_begin.deviceMask = 1u << dg_idx1; + vk::CommandBufferBeginInfo cbi{}; + cbi.flags = vk::CommandBufferUsageFlagBits::eOneTimeSubmit; + cbi.setPNext(&dg_begin); + cb.begin(cbi); + cb.setDeviceMask(1u << dg_idx1); + VkBufferCopy bc{ 0, 0, size }; + vkCmdCopyBuffer(cb, dg_src_buf, dg_dst_buf, 1, &bc); + cb.end(); + + vk::DeviceGroupSubmitInfo dg_submit_info{}; + uint32_t copy_mask = 1u << dg_idx1; + dg_submit_info.commandBufferCount = 1; + dg_submit_info.pCommandBufferDeviceMasks = &copy_mask; + + vk::SubmitInfo si{}; + si.commandBufferCount = 1; + si.pCommandBuffers = &cb; + si.setPNext(&dg_submit_info); + dg_queue.submit({ si }, dg_fence); + VK_CHECK(dg_device.waitForFences({ dg_fence }, true, UINT64_MAX), "devgroup_p2p"); + dg_device.resetFences({ dg_fence }); + 
dg_device.resetCommandPool(dg_cmd_pool); + + auto end = std::chrono::high_resolution_clock::now(); + if (i >= warmup) times.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1000.0); + } + if (run_ok) record("devgroup_p2p", size, times); + + if (dg_dst_buf) dg_device.destroyBuffer(dg_dst_buf); + if (dg_src_buf) dg_device.destroyBuffer(dg_src_buf); + if (dg_dst_mem) dg_device.freeMemory(dg_dst_mem); + if (dg_src_mem) dg_device.freeMemory(dg_src_mem); + } + } + + if (has_devgroup) { + dg_device.destroyFence(dg_fence); + dg_device.destroyCommandPool(dg_cmd_pool); + dg_device.destroy(); } ggml_vk_destroy_buffer(buf_src);