diff --git a/dlls/wined3d/shader_spirv.c b/dlls/wined3d/shader_spirv.c index d729bf5c390..eb2647db12d 100644 --- a/dlls/wined3d/shader_spirv.c +++ b/dlls/wined3d/shader_spirv.c @@ -54,6 +54,7 @@ struct shader_spirv_resource_bindings SIZE_T vk_bindings_size, vk_binding_count; size_t binding_base[WINED3D_SHADER_TYPE_COUNT]; + enum wined3d_shader_type so_stage; }; struct shader_spirv_priv @@ -80,6 +81,7 @@ struct shader_spirv_compile_arguments struct shader_spirv_graphics_program_variant_vk { struct shader_spirv_compile_arguments compile_args; + const struct wined3d_stream_output_desc *so_desc; size_t binding_base; VkShaderModule vk_module; @@ -276,17 +278,14 @@ static const char *get_line(const char **ptr) } static void shader_spirv_init_shader_interface_vk(struct wined3d_shader_spirv_shader_interface *iface, - struct wined3d_shader *shader, const struct shader_spirv_resource_bindings *b) + struct wined3d_shader *shader, const struct shader_spirv_resource_bindings *b, + const struct wined3d_stream_output_desc *so_desc) { - enum wined3d_shader_type shader_type = shader->reg_maps.shader_version.type; - memset(iface, 0, sizeof(*iface)); iface->vkd3d_interface.type = VKD3D_SHADER_STRUCTURE_TYPE_INTERFACE_INFO; - if (shader_type == WINED3D_SHADER_TYPE_GEOMETRY && shader->u.gs.so_desc) + if (so_desc) { - const struct wined3d_stream_output_desc *so_desc = shader->u.gs.so_desc; - iface->xfb_info.type = VKD3D_SHADER_STRUCTURE_TYPE_TRANSFORM_FEEDBACK_INFO; iface->xfb_info.next = NULL; @@ -307,7 +306,7 @@ static void shader_spirv_init_shader_interface_vk(struct wined3d_shader_spirv_sh static VkShaderModule shader_spirv_compile(struct wined3d_context_vk *context_vk, struct wined3d_shader *shader, const struct shader_spirv_compile_arguments *args, - const struct shader_spirv_resource_bindings *bindings) + const struct shader_spirv_resource_bindings *bindings, const struct wined3d_stream_output_desc *so_desc) { struct wined3d_shader_spirv_compile_args compile_args; struct wined3d_shader_spirv_shader_interface iface; @@ -322,7 +321,7 @@ static VkShaderModule shader_spirv_compile(struct wined3d_context_vk *context_vk VkResult vr; int ret; - shader_spirv_init_shader_interface_vk(&iface, shader, bindings); + shader_spirv_init_shader_interface_vk(&iface, shader, bindings, so_desc); shader_type = shader->reg_maps.shader_version.type; shader_spirv_init_compile_args(&compile_args, &iface.vkd3d_interface, VKD3D_SHADER_SPIRV_ENVIRONMENT_VULKAN_1_0, shader_type, args); @@ -383,13 +382,17 @@ static struct shader_spirv_graphics_program_variant_vk *shader_spirv_find_graphi struct shader_spirv_priv *priv, struct wined3d_context_vk *context_vk, struct wined3d_shader *shader, const struct wined3d_state *state, const struct shader_spirv_resource_bindings *bindings) { - size_t binding_base = bindings->binding_base[shader->reg_maps.shader_version.type]; + enum wined3d_shader_type shader_type = shader->reg_maps.shader_version.type; struct shader_spirv_graphics_program_variant_vk *variant_vk; + size_t binding_base = bindings->binding_base[shader_type]; + const struct wined3d_stream_output_desc *so_desc = NULL; struct shader_spirv_graphics_program_vk *program_vk; struct shader_spirv_compile_arguments args; size_t variant_count, i; shader_spirv_compile_arguments_init(&args, &context_vk->c, shader, state, context_vk->sample_count); + if (bindings->so_stage == shader_type) + so_desc = state->shader[WINED3D_SHADER_TYPE_GEOMETRY]->u.gs.so_desc; if (!(program_vk = shader->backend_data)) return NULL; @@ -398,7 +401,7 @@ static struct shader_spirv_graphics_program_variant_vk *shader_spirv_find_graphi for (i = 0; i < variant_count; ++i) { variant_vk = &program_vk->variants[i]; - if (variant_vk->binding_base == binding_base + if (variant_vk->so_desc == so_desc && variant_vk->binding_base == binding_base && !memcmp(&variant_vk->compile_args, &args, sizeof(args))) return variant_vk; } @@ -410,7 +413,7 @@ static struct shader_spirv_graphics_program_variant_vk *shader_spirv_find_graphi variant_vk = &program_vk->variants[variant_count]; variant_vk->compile_args = args; variant_vk->binding_base = binding_base; - if (!(variant_vk->vk_module = shader_spirv_compile(context_vk, shader, &args, bindings))) + if (!(variant_vk->vk_module = shader_spirv_compile(context_vk, shader, &args, bindings, so_desc))) return NULL; ++program_vk->variant_count; @@ -434,7 +437,7 @@ static struct shader_spirv_compute_program_vk *shader_spirv_find_compute_program if (program->vk_module) return program; - if (!(program->vk_module = shader_spirv_compile(context_vk, shader, NULL, bindings))) + if (!(program->vk_module = shader_spirv_compile(context_vk, shader, NULL, bindings, NULL))) return NULL; if (!(layout = wined3d_context_vk_get_pipeline_layout(context_vk, @@ -671,6 +674,7 @@ static bool shader_spirv_resource_bindings_init(struct shader_spirv_resource_bin bindings->binding_count = 0; bindings->uav_counter_count = 0; bindings->vk_binding_count = 0; + bindings->so_stage = WINED3D_SHADER_TYPE_GEOMETRY; wined3d_bindings->count = 0; for (shader_type = 0; shader_type < WINED3D_SHADER_TYPE_COUNT; ++shader_type) @@ -681,9 +685,15 @@ static bool shader_spirv_resource_bindings_init(struct shader_spirv_resource_bin continue; if (shader_type == WINED3D_SHADER_TYPE_COMPUTE) + { descriptor_info = &((struct shader_spirv_compute_program_vk *)shader->backend_data)->descriptor_info; + } else + { descriptor_info = &((struct shader_spirv_graphics_program_vk *)shader->backend_data)->descriptor_info; + if (shader_type == WINED3D_SHADER_TYPE_GEOMETRY && !shader->function) + bindings->so_stage = WINED3D_SHADER_TYPE_VERTEX; + } vk_stage = vk_shader_stage_from_wined3d(shader_type); shader_visibility = vkd3d_shader_visibility_from_wined3d(shader_type); @@ -829,6 +839,8 @@ static void shader_spirv_select(void *shader_priv, struct wined3d_context *conte ERR("Failed to initialise shader resource bindings.\n"); goto fail; } + if (context->shader_update_mask & (1u << WINED3D_SHADER_TYPE_GEOMETRY)) + context->shader_update_mask |= 1u << bindings->so_stage; layout_vk = wined3d_context_vk_get_pipeline_layout(context_vk, bindings->vk_bindings, bindings->vk_binding_count); context_vk->graphics.vk_set_layout = layout_vk->vk_set_layout; @@ -840,7 +852,7 @@ static void shader_spirv_select(void *shader_priv, struct wined3d_context *conte || binding_base[shader_type] == bindings->binding_base[shader_type])) continue; - if (!(shader = state->shader[shader_type])) + if (!(shader = state->shader[shader_type]) || !shader->function) { context_vk->graphics.vk_modules[shader_type] = VK_NULL_HANDLE; continue;