From f77b4d155b5900f6b10bdb14cd7f56aa20e27e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20J=2E=20Est=C3=A9banez?= Date: Wed, 6 Mar 2024 11:06:17 +0100 Subject: [PATCH] Make shader binary alignment handling simpler and more robust Bonus: Also simplified the rounding to block size in image size calculations. --- .../d3d12/rendering_device_driver_d3d12.cpp | 47 ++++++----------- .../vulkan/rendering_device_driver_vulkan.cpp | 50 +++++++------------ .../rendering/rendering_device_commons.cpp | 7 +-- 3 files changed, 37 insertions(+), 67 deletions(-) diff --git a/drivers/d3d12/rendering_device_driver_d3d12.cpp b/drivers/d3d12/rendering_device_driver_d3d12.cpp index 381d022a55a9..7c7a73bbaaa1 100644 --- a/drivers/d3d12/rendering_device_driver_d3d12.cpp +++ b/drivers/d3d12/rendering_device_driver_d3d12.cpp @@ -3246,10 +3246,7 @@ Vector RenderingDeviceDriverD3D12::shader_compile_binary_from_spirv(Vec compressed_stages.push_back(zstd); uint32_t s = compressed_stages[i].size(); - if (s % 4 != 0) { - s += 4 - (s % 4); - } - stages_binary_size += s; + stages_binary_size += STEPIFY(s, 4); } CharString shader_name_utf = p_shader_name.utf8(); @@ -3259,10 +3256,7 @@ Vector RenderingDeviceDriverD3D12::shader_compile_binary_from_spirv(Vec uint32_t total_size = sizeof(uint32_t) * 3; // Header + version + main datasize;. total_size += sizeof(ShaderBinary::Data); - total_size += binary_data.shader_name_len; - if ((binary_data.shader_name_len % 4) != 0) { // Alignment rules are really strange. - total_size += 4 - (binary_data.shader_name_len % 4); - } + total_size += STEPIFY(binary_data.shader_name_len, 4); for (int i = 0; i < sets_bindings.size(); i++) { total_size += sizeof(uint32_t); @@ -3294,13 +3288,17 @@ Vector RenderingDeviceDriverD3D12::shader_compile_binary_from_spirv(Vec memcpy(binptr + offset, &binary_data, sizeof(ShaderBinary::Data)); offset += sizeof(ShaderBinary::Data); +#define ADVANCE_OFFSET_WITH_ALIGNMENT(m_bytes) \ + { \ + offset += m_bytes; \ + uint32_t padding = STEPIFY(m_bytes, 4) - m_bytes; \ + memset(binptr + offset, 0, padding); /* Avoid garbage data. */ \ + offset += padding; \ + } + if (binary_data.shader_name_len > 0) { memcpy(binptr + offset, shader_name_utf.ptr(), binary_data.shader_name_len); - offset += binary_data.shader_name_len; - - if ((binary_data.shader_name_len % 4) != 0) { // Alignment rules are really strange. - offset += 4 - (binary_data.shader_name_len % 4); - } + ADVANCE_OFFSET_WITH_ALIGNMENT(binary_data.shader_name_len); } for (int i = 0; i < sets_bindings.size(); i++) { @@ -3326,14 +3324,7 @@ Vector RenderingDeviceDriverD3D12::shader_compile_binary_from_spirv(Vec encode_uint32(zstd_size[i], binptr + offset); offset += sizeof(uint32_t); memcpy(binptr + offset, compressed_stages[i].ptr(), compressed_stages[i].size()); - - uint32_t s = compressed_stages[i].size(); - - if (s % 4 != 0) { - s += 4 - (s % 4); - } - - offset += s; + ADVANCE_OFFSET_WITH_ALIGNMENT(compressed_stages[i].size()); } memcpy(binptr + offset, root_sig_blob->GetBufferPointer(), root_sig_blob->GetBufferSize()); @@ -3382,10 +3373,7 @@ RDD::ShaderID RenderingDeviceDriverD3D12::shader_create_from_bytecode(const Vect if (binary_data.shader_name_len) { r_name.parse_utf8((const char *)(binptr + read_offset), binary_data.shader_name_len); - read_offset += binary_data.shader_name_len; - if ((binary_data.shader_name_len % 4) != 0) { // Alignment rules are really strange. - read_offset += 4 - (binary_data.shader_name_len % 4); - } + read_offset += STEPIFY(binary_data.shader_name_len, 4); } r_shader_desc.uniform_sets.resize(binary_data.set_count); @@ -3458,6 +3446,7 @@ RDD::ShaderID RenderingDeviceDriverD3D12::shader_create_from_bytecode(const Vect for (uint32_t i = 0; i < binary_data.stage_count; i++) { ERR_FAIL_COND_V(read_offset + sizeof(uint32_t) * 3 >= binsize, ShaderID()); + uint32_t stage = decode_uint32(binptr + read_offset); read_offset += sizeof(uint32_t); uint32_t dxil_size = decode_uint32(binptr + read_offset); @@ -3472,13 +3461,9 @@ RDD::ShaderID RenderingDeviceDriverD3D12::shader_create_from_bytecode(const Vect ERR_FAIL_COND_V(dec_dxil_size != (int32_t)dxil_size, ShaderID()); shader_info_in.stages_bytecode[ShaderStage(stage)] = dxil; - if (zstd_size % 4 != 0) { - zstd_size += 4 - (zstd_size % 4); - } - - ERR_FAIL_COND_V(read_offset + zstd_size > binsize, ShaderID()); - + zstd_size = STEPIFY(zstd_size, 4); read_offset += zstd_size; + ERR_FAIL_COND_V(read_offset > binsize, ShaderID()); } const uint8_t *root_sig_data_ptr = binptr + read_offset; diff --git a/drivers/vulkan/rendering_device_driver_vulkan.cpp b/drivers/vulkan/rendering_device_driver_vulkan.cpp index 21cf54b4be28..08bdb38a7891 100644 --- a/drivers/vulkan/rendering_device_driver_vulkan.cpp +++ b/drivers/vulkan/rendering_device_driver_vulkan.cpp @@ -2957,10 +2957,7 @@ Vector RenderingDeviceDriverVulkan::shader_compile_binary_from_spirv(Ve } } uint32_t s = compressed_stages[i].size(); - if (s % 4 != 0) { - s += 4 - (s % 4); - } - stages_binary_size += s; + stages_binary_size += STEPIFY(s, 4); } binary_data.specialization_constants_count = specialization_constants.size(); @@ -2974,11 +2971,7 @@ Vector RenderingDeviceDriverVulkan::shader_compile_binary_from_spirv(Ve uint32_t total_size = sizeof(uint32_t) * 3; // Header + version + main datasize;. total_size += sizeof(ShaderBinary::Data); - total_size += binary_data.shader_name_len; - - if ((binary_data.shader_name_len % 4) != 0) { // Alignment rules are really strange. - total_size += 4 - (binary_data.shader_name_len % 4); - } + total_size += STEPIFY(binary_data.shader_name_len, 4); for (int i = 0; i < uniforms.size(); i++) { total_size += sizeof(uint32_t); @@ -3007,13 +3000,17 @@ Vector RenderingDeviceDriverVulkan::shader_compile_binary_from_spirv(Ve memcpy(binptr + offset, &binary_data, sizeof(ShaderBinary::Data)); offset += sizeof(ShaderBinary::Data); +#define ADVANCE_OFFSET_WITH_ALIGNMENT(m_bytes) \ + { \ + offset += m_bytes; \ + uint32_t padding = STEPIFY(m_bytes, 4) - m_bytes; \ + memset(binptr + offset, 0, padding); /* Avoid garbage data. */ \ + offset += padding; \ + } + if (binary_data.shader_name_len > 0) { memcpy(binptr + offset, shader_name_utf.ptr(), binary_data.shader_name_len); - offset += binary_data.shader_name_len; - - if ((binary_data.shader_name_len % 4) != 0) { // Alignment rules are really strange. - offset += 4 - (binary_data.shader_name_len % 4); - } + ADVANCE_OFFSET_WITH_ALIGNMENT(binary_data.shader_name_len); } for (int i = 0; i < uniforms.size(); i++) { @@ -3039,14 +3036,7 @@ Vector RenderingDeviceDriverVulkan::shader_compile_binary_from_spirv(Ve encode_uint32(zstd_size[i], binptr + offset); offset += sizeof(uint32_t); memcpy(binptr + offset, compressed_stages[i].ptr(), compressed_stages[i].size()); - - uint32_t s = compressed_stages[i].size(); - - if (s % 4 != 0) { - s += 4 - (s % 4); - } - - offset += s; + ADVANCE_OFFSET_WITH_ALIGNMENT(compressed_stages[i].size()); } DEV_ASSERT(offset == (uint32_t)ret.size()); @@ -3090,10 +3080,7 @@ RDD::ShaderID RenderingDeviceDriverVulkan::shader_create_from_bytecode(const Vec if (binary_data.shader_name_len) { r_name.parse_utf8((const char *)(binptr + read_offset), binary_data.shader_name_len); - read_offset += binary_data.shader_name_len; - if ((binary_data.shader_name_len % 4) != 0) { // Alignment rules are really strange. - read_offset += 4 - (binary_data.shader_name_len % 4); - } + read_offset += STEPIFY(binary_data.shader_name_len, 4); } Vector> vk_set_bindings; @@ -3192,6 +3179,7 @@ RDD::ShaderID RenderingDeviceDriverVulkan::shader_create_from_bytecode(const Vec for (uint32_t i = 0; i < binary_data.stage_count; i++) { ERR_FAIL_COND_V(read_offset + sizeof(uint32_t) * 3 >= binsize, ShaderID()); + uint32_t stage = decode_uint32(binptr + read_offset); read_offset += sizeof(uint32_t); uint32_t smolv_size = decode_uint32(binptr + read_offset); @@ -3223,16 +3211,12 @@ RDD::ShaderID RenderingDeviceDriverVulkan::shader_create_from_bytecode(const Vec r_shader_desc.stages.set(i, ShaderStage(stage)); - if (buf_size % 4 != 0) { - buf_size += 4 - (buf_size % 4); - } - - DEV_ASSERT(read_offset + buf_size <= binsize); - + buf_size = STEPIFY(buf_size, 4); read_offset += buf_size; + ERR_FAIL_COND_V(read_offset > binsize, ShaderID()); } - DEV_ASSERT(read_offset == binsize); + ERR_FAIL_COND_V(read_offset != binsize, ShaderID()); // Modules. diff --git a/servers/rendering/rendering_device_commons.cpp b/servers/rendering/rendering_device_commons.cpp index c8b7980633fe..4dbd0e396476 100644 --- a/servers/rendering/rendering_device_commons.cpp +++ b/servers/rendering/rendering_device_commons.cpp @@ -711,12 +711,13 @@ uint32_t RenderingDeviceCommons::get_image_format_required_size(DataFormat p_for uint32_t pixel_size = get_image_format_pixel_size(p_format); uint32_t pixel_rshift = get_compressed_image_format_pixel_rshift(p_format); - uint32_t blockw, blockh; + uint32_t blockw = 0; + uint32_t blockh = 0; get_compressed_image_format_block_dimensions(p_format, blockw, blockh); for (uint32_t i = 0; i < p_mipmaps; i++) { - uint32_t bw = w % blockw != 0 ? w + (blockw - w % blockw) : w; - uint32_t bh = h % blockh != 0 ? h + (blockh - h % blockh) : h; + uint32_t bw = STEPIFY(w, blockw); + uint32_t bh = STEPIFY(h, blockh); uint32_t s = bw * bh;