sandbox
Loading...
Searching...
No Matches
compute_pipeline.hpp
1#ifndef LIBSBX_GRAPHICS_COMPUTE_PIPELINE_HPP_
2#define LIBSBX_GRAPHICS_COMPUTE_PIPELINE_HPP_
3
4#include <vector>
5
6#include <vulkan/vulkan.hpp>
7
8#include <fmt/format.h>
9
10#include <libsbx/math/vector3.hpp>
11
12#include <libsbx/graphics/pipeline/pipeline.hpp>
13#include <libsbx/graphics/pipeline/shader.hpp>
14
15namespace sbx::graphics {
16
17class compute_pipeline : public pipeline {
18
19 using base = pipeline;
20
21public:
22
23 compute_pipeline(const std::filesystem::path& path, const render_graph::compute_pass& pass);
24
25 ~compute_pipeline() override;
26
27 auto handle() const noexcept -> VkPipeline override {
28 return _handle;
29 }
30
31 auto has_variable_descriptors() const noexcept -> bool override {
32 return false;
33 }
34
35 auto descriptor_counts(std::uint32_t set) const noexcept -> std::vector<std::uint32_t> override {
36 auto counts = std::vector<std::uint32_t>{};
37
38 for (const auto& binding_data : _set_data[set].binding_data) {
39 counts.push_back(binding_data.descriptor_count);
40 }
41
42 return counts;
43 }
44
45 auto descriptor_set_layout(std::uint32_t set) const noexcept -> VkDescriptorSetLayout override {
46 return _set_data[set].layout;
47 }
48
49 auto descriptor_pool() const noexcept -> VkDescriptorPool override {
50 return _descriptor_pool;
51 }
52
53 auto layout() const noexcept -> VkPipelineLayout override {
54 return _layout;
55 }
56
57 auto bind_point() const noexcept -> VkPipelineBindPoint override {
58 return _bind_point;
59 }
60
61 auto descriptor_block(const std::string& name, std::uint32_t set) const -> const shader::uniform_block& override {
62 if (auto it = _set_data[set].uniform_blocks.find(name); it != _set_data[set].uniform_blocks.end()) {
63 return it->second;
64 }
65
66 throw std::runtime_error(fmt::format("Failed to find descriptor block '{}' in graphics pipeline '{}'", name, _name));
67 }
68
69 auto push_constant() const noexcept -> const std::optional<shader::uniform_block>& override {
70 return _push_constant;
71 }
72
73 auto find_descriptor_binding(const std::string& name, std::uint32_t set) const -> std::optional<std::uint32_t> override {
74 if (auto it = _set_data[set].descriptor_bindings.find(name); it != _set_data[set].descriptor_bindings.end()) {
75 return it->second;
76 }
77
78 return std::nullopt;
79 }
80
81 auto find_descriptor_type_at_binding(std::uint32_t set, std::uint32_t binding) const -> std::optional<VkDescriptorType> override {
82 if (_set_data[set].binding_data.size() <= binding) {
83 return std::nullopt;
84 }
85
86 return _set_data[set].binding_data[binding].descriptor_type;
87 }
88
89 auto dispatch(command_buffer& command_buffer, const math::vector3u& groups) -> void {
90 vkCmdDispatch(command_buffer, groups.x(), groups.y(), groups.z());
91 }
92
93private:
94
95 struct per_binding_data {
96 VkDescriptorType descriptor_type;
97 std::uint32_t descriptor_count;
98 }; // struct per_binding_data
99
100 struct per_set_data {
101 std::unordered_map<std::string, shader::uniform> uniforms;
102 std::unordered_map<std::string, shader::uniform_block> uniform_blocks;
103 std::unordered_map<std::string, std::uint32_t> descriptor_bindings;
104 std::unordered_map<std::string, std::uint32_t> descriptor_sizes;
105 std::vector<per_binding_data> binding_data;
106 VkDescriptorSetLayout layout;
107 }; // struct per_set_data
108
109 auto _get_stage_from_name(const std::string& name) const noexcept -> VkShaderStageFlagBits;
110
112
113 std::unique_ptr<shader> _shader;
114
115 std::vector<per_set_data> _set_data;
116 std::optional<shader::uniform_block> _push_constant;
117
118 std::string _name;
119 VkPipelineLayout _layout;
120 VkPipeline _handle;
121 VkPipelineBindPoint _bind_point;
122
123 VkDescriptorPool _descriptor_pool;
124
125}; // class compute_pipeline
126
127} // namespace sbx::graphics
128
129#endif // LIBSBX_GRAPHICS_COMPUTE_PIPELINE_HPP_
Definition: command_buffer.hpp:14
Definition: compute_pipeline.hpp:17
Definition: render_graph.hpp:242
Definition: pipeline.hpp:20
Definition: shader.hpp:125
Definition: vector3.hpp:22