From 9a6fea431334dbb5d82716bcbe45c01f19523eb7 Mon Sep 17 00:00:00 2001 From: Felix Maier Date: Mon, 4 May 2020 12:12:23 +0200 Subject: [PATCH 1/6] Implement VK_KHR_ray_tracing in HLSL --- main.cpp | 24 +-- spirv_glsl.cpp | 128 ++++++------ spirv_glsl.hpp | 3 +- spirv_hlsl.cpp | 552 ++++++++++++++++++++++++++++++++++++++++++++++++- spirv_hlsl.hpp | 3 + 5 files changed, 625 insertions(+), 85 deletions(-) diff --git a/main.cpp b/main.cpp index b69c45e62..6424079df 100644 --- a/main.cpp +++ b/main.cpp @@ -345,18 +345,18 @@ static const char *execution_model_to_str(spv::ExecutionModel model) return "fragment"; case ExecutionModelGLCompute: return "compute"; - case ExecutionModelRayGenerationNV: - return "raygenNV"; - case ExecutionModelIntersectionNV: - return "intersectionNV"; - case ExecutionModelCallableNV: - return "callableNV"; - case ExecutionModelAnyHitNV: - return "anyhitNV"; - case ExecutionModelClosestHitNV: - return "closesthitNV"; - case ExecutionModelMissNV: - return "missNV"; + case ExecutionModelRayGenerationKHR: + return "raygenKHR"; + case ExecutionModelIntersectionKHR: + return "intersectionKHR"; + case ExecutionModelCallableKHR: + return "callableKHR"; + case ExecutionModelAnyHitKHR: + return "anyhitKHR"; + case ExecutionModelClosestHitKHR: + return "closesthitKHR"; + case ExecutionModelMissKHR: + return "missKHR"; default: return "???"; } diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 50cc79ab1..fa09e7604 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -445,15 +445,15 @@ void CompilerGLSL::find_static_extensions() require_extension_internal("GL_ARB_tessellation_shader"); break; - case ExecutionModelRayGenerationNV: - case ExecutionModelIntersectionNV: - case ExecutionModelAnyHitNV: - case ExecutionModelClosestHitNV: - case ExecutionModelMissNV: - case ExecutionModelCallableNV: + case ExecutionModelRayGenerationKHR: + case ExecutionModelIntersectionKHR: + case ExecutionModelAnyHitKHR: + case ExecutionModelClosestHitKHR: + case ExecutionModelMissKHR: + case ExecutionModelCallableKHR: if (options.es || options.version < 460) SPIRV_CROSS_THROW("Ray tracing shaders require non-es profile with version 460 or above."); - require_extension_internal("GL_NV_ray_tracing"); + require_extension_internal("GL_KHR_ray_tracing"); break; default: @@ -2089,25 +2089,25 @@ const char *CompilerGLSL::to_storage_qualifiers_glsl(const SPIRVariable &var) { return "uniform "; } - else if (var.storage == StorageClassRayPayloadNV) + else if (var.storage == StorageClassRayPayloadKHR) { - return "rayPayloadNV "; + return "rayPayloadEXT "; } - else if (var.storage == StorageClassIncomingRayPayloadNV) + else if (var.storage == StorageClassIncomingRayPayloadKHR) { - return "rayPayloadInNV "; + return "rayPayloadInEXT "; } - else if (var.storage == StorageClassHitAttributeNV) + else if (var.storage == StorageClassHitAttributeKHR) { - return "hitAttributeNV "; + return "hitAttributeEXT "; } - else if (var.storage == StorageClassCallableDataNV) + else if (var.storage == StorageClassCallableDataKHR) { - return "callableDataNV "; + return "callableDataEXT "; } - else if (var.storage == StorageClassIncomingCallableDataNV) + else if (var.storage == StorageClassIncomingCallableDataKHR) { - return "callableDataInNV "; + return "callableDataInEXT "; } return ""; @@ -3056,9 +3056,9 @@ void CompilerGLSL::emit_resources() if (var.storage != StorageClassFunction && type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter || - type.storage == StorageClassRayPayloadNV || type.storage == StorageClassIncomingRayPayloadNV || - type.storage == StorageClassCallableDataNV || type.storage == StorageClassIncomingCallableDataNV || - type.storage == StorageClassHitAttributeNV) && + type.storage == StorageClassRayPayloadKHR || type.storage == StorageClassIncomingRayPayloadKHR || + type.storage == StorageClassCallableDataKHR || type.storage == StorageClassIncomingCallableDataKHR || + type.storage == StorageClassHitAttributeKHR) && !is_hidden_variable(var)) { emit_uniform(var); @@ -4647,7 +4647,8 @@ bool CompilerGLSL::emit_complex_bitcast(uint32_t result_type, uint32_t id, uint3 if (output_type.basetype == SPIRType::Half && input_type.basetype == SPIRType::Float && input_type.vecsize == 1) expr = join("unpackFloat2x16(floatBitsToUint(", to_unpacked_expression(op0), "))"); - else if (output_type.basetype == SPIRType::Float && input_type.basetype == SPIRType::Half && input_type.vecsize == 2) + else if (output_type.basetype == SPIRType::Float && input_type.basetype == SPIRType::Half && + input_type.vecsize == 2) expr = join("uintBitsToFloat(packFloat2x16(", to_unpacked_expression(op0), "))"); else return false; @@ -6965,34 +6966,34 @@ string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) require_extension_internal("GL_KHR_shader_subgroup_ballot"); return "gl_SubgroupLtMask"; - case BuiltInLaunchIdNV: - return "gl_LaunchIDNV"; - case BuiltInLaunchSizeNV: - return "gl_LaunchSizeNV"; - case BuiltInWorldRayOriginNV: - return "gl_WorldRayOriginNV"; - case BuiltInWorldRayDirectionNV: - return "gl_WorldRayDirectionNV"; - case BuiltInObjectRayOriginNV: - return "gl_ObjectRayOriginNV"; - case BuiltInObjectRayDirectionNV: - return "gl_ObjectRayDirectionNV"; - case BuiltInRayTminNV: - return "gl_RayTminNV"; - case BuiltInRayTmaxNV: - return "gl_RayTmaxNV"; - case BuiltInInstanceCustomIndexNV: - return "gl_InstanceCustomIndexNV"; - case BuiltInObjectToWorldNV: - return "gl_ObjectToWorldNV"; - case BuiltInWorldToObjectNV: - return "gl_WorldToObjectNV"; - case BuiltInHitTNV: - return "gl_HitTNV"; - case BuiltInHitKindNV: - return "gl_HitKindNV"; - case BuiltInIncomingRayFlagsNV: - return "gl_IncomingRayFlagsNV"; + case BuiltInLaunchIdKHR: + return "gl_LaunchIDEXT"; + case BuiltInLaunchSizeKHR: + return "gl_LaunchSizeEXT"; + case BuiltInWorldRayOriginKHR: + return "gl_WorldRayOriginEXT"; + case BuiltInWorldRayDirectionKHR: + return "gl_WorldRayDirectionEXT"; + case BuiltInObjectRayOriginKHR: + return "gl_ObjectRayOriginEXT"; + case BuiltInObjectRayDirectionKHR: + return "gl_ObjectRayDirectionEXT"; + case BuiltInRayTminKHR: + return "gl_RayTminEXT"; + case BuiltInRayTmaxKHR: + return "gl_RayTmaxEXT"; + case BuiltInInstanceCustomIndexKHR: + return "gl_InstanceCustomIndexEXT"; + case BuiltInObjectToWorldKHR: + return "gl_ObjectToWorldEXT"; + case BuiltInWorldToObjectKHR: + return "gl_WorldToObjectEXT"; + case BuiltInHitTKHR: + return "gl_HitTEXT"; + case BuiltInHitKindKHR: + return "gl_HitKindEXT"; + case BuiltInIncomingRayFlagsKHR: + return "gl_IncomingRayFlagsEXT"; case BuiltInBaryCoordNV: { @@ -10777,23 +10778,23 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) break; } - case OpReportIntersectionNV: - statement("reportIntersectionNV(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");"); + case OpReportIntersectionKHR: + statement("reportIntersectionEXT(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");"); break; - case OpIgnoreIntersectionNV: - statement("ignoreIntersectionNV();"); + case OpIgnoreIntersectionKHR: + statement("ignoreIntersectionEXT();"); break; - case OpTerminateRayNV: - statement("terminateRayNV();"); + case OpTerminateRayKHR: + statement("terminateRayEXT();"); break; - case OpTraceNV: - statement("traceNV(", to_expression(ops[0]), ", ", to_expression(ops[1]), ", ", to_expression(ops[2]), ", ", + case OpTraceRayKHR: + statement("traceRayEXT(", to_expression(ops[0]), ", ", to_expression(ops[1]), ", ", to_expression(ops[2]), ", ", to_expression(ops[3]), ", ", to_expression(ops[4]), ", ", to_expression(ops[5]), ", ", to_expression(ops[6]), ", ", to_expression(ops[7]), ", ", to_expression(ops[8]), ", ", to_expression(ops[9]), ", ", to_expression(ops[10]), ");"); break; - case OpExecuteCallableNV: - statement("executeCallableNV(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");"); + case OpExecuteCallableKHR: + statement("executeCallableEXT(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");"); break; case OpConvertUToPtr: @@ -11540,7 +11541,7 @@ string CompilerGLSL::type_to_glsl(const SPIRType &type, uint32_t id) return comparison_ids.count(id) ? "samplerShadow" : "sampler"; case SPIRType::AccelerationStructure: - return "accelerationStructureNV"; + return "accelerationStructureKHR"; case SPIRType::Void: return "void"; @@ -13726,11 +13727,12 @@ const SPIRVariable *CompilerGLSL::find_subpass_input_by_attachment_index(uint32_ return ret; } -const SPIRVariable *CompilerGLSL::find_color_output_by_location(uint32_t location) const +const SPIRVariable *CompilerGLSL::find_storage_class_variable_by_location(spv::StorageClass storage_class, + uint32_t location) const { const SPIRVariable *ret = nullptr; ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { - if (var.storage == StorageClassOutput && get_decoration(var.self, DecorationLocation) == location) + if (var.storage == storage_class && get_decoration(var.self, DecorationLocation) == location) ret = &var; }); return ret; @@ -13741,7 +13743,7 @@ void CompilerGLSL::emit_inout_fragment_outputs_copy_to_subpass_inputs() for (auto &remap : subpass_to_framebuffer_fetch_attachment) { auto *subpass_var = find_subpass_input_by_attachment_index(remap.first); - auto *output_var = find_color_output_by_location(remap.second); + auto *output_var = find_storage_class_variable_by_location(StorageClassOutput, remap.second); if (!subpass_var) continue; if (!output_var) diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index 1eafc2cea..47aa79011 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -690,7 +690,8 @@ class CompilerGLSL : public Compiler bool subpass_input_is_framebuffer_fetch(uint32_t id) const; void emit_inout_fragment_outputs_copy_to_subpass_inputs(); const SPIRVariable *find_subpass_input_by_attachment_index(uint32_t index) const; - const SPIRVariable *find_color_output_by_location(uint32_t location) const; + const SPIRVariable *find_storage_class_variable_by_location(spv::StorageClass storage_class, + uint32_t location) const; // A variant which takes two sets of name. The secondary is only used to verify there are no collisions, // but the set is not updated when we have found a new name. diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index f27163c20..d08ea19ed 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -375,6 +375,9 @@ string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id) case SPIRType::Sampler: return comparison_ids.count(id) ? "SamplerComparisonState" : "SamplerState"; + case SPIRType::AccelerationStructure: + return "RaytracingAccelerationStructure"; + case SPIRType::Void: return "void"; @@ -582,6 +585,8 @@ void CompilerHLSL::emit_builtin_outputs_in_struct() void CompilerHLSL::emit_builtin_inputs_in_struct() { + auto &execution = get_entry_point(); + bool legacy = hlsl_options.shader_model <= 30; active_input_builtins.for_each_bit([&](uint32_t i) { const char *type = nullptr; @@ -603,11 +608,35 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() break; case BuiltInInstanceId: + if (legacy) + SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower."); + type = "uint"; + // Ignore semantic when in RT shader + if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelClosestHitKHR) + semantic = nullptr; + else + semantic = "SV_InstanceID"; + break; + case BuiltInPrimitiveId: + type = "uint"; + // Ignore semantic when in RT shader + if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelClosestHitKHR) + semantic = nullptr; + else + semantic = "SV_PrimitiveID"; + break; case BuiltInInstanceIndex: if (legacy) SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower."); type = "uint"; - semantic = "SV_InstanceID"; + // Ignore semantic when in RT shader + if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelClosestHitKHR) + semantic = nullptr; + else + semantic = "SV_InstanceID"; break; case BuiltInSampleId: @@ -692,6 +721,62 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() else SPIRV_CROSS_THROW("Unsupported builtin in HLSL."); + case BuiltInLaunchIdKHR: + type = "uint3"; + break; + + case BuiltInLaunchSizeKHR: + type = "uint2"; + break; + + case BuiltInWorldRayOriginKHR: + type = "float3"; + break; + + case BuiltInWorldRayDirectionKHR: + type = "float3"; + break; + + case BuiltInObjectRayOriginKHR: + type = "float3"; + break; + + case BuiltInObjectRayDirectionKHR: + type = "float3"; + break; + + case BuiltInRayTminKHR: + type = "float"; + break; + + case BuiltInRayTmaxKHR: + type = "float"; + break; + + case BuiltInInstanceCustomIndexKHR: + type = "uint"; + break; + + case BuiltInObjectToWorldKHR: + type = "float4x3"; + break; + + case BuiltInWorldToObjectKHR: + type = "float4x3"; + break; + + case BuiltInHitTKHR: + type = "float"; + break; + + case BuiltInHitKindKHR: + type = "uint"; + break; + + case BuiltInIncomingRayFlagsKHR: + type = "uint"; + break; + default: SPIRV_CROSS_THROW("Unsupported builtin in HLSL."); break; @@ -887,12 +972,27 @@ void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unord std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage) { + auto &execution = get_entry_point(); + switch (builtin) { case BuiltInVertexId: return "gl_VertexID"; case BuiltInInstanceId: - return "gl_InstanceID"; + // In RT shaders, this builtin gets overwritten + if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelClosestHitKHR) + return "InstanceIndex()"; + else + return "gl_InstanceID"; + case BuiltInPrimitiveId: + // In RT shaders, this builtin gets overwritten + if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelClosestHitKHR) + return "PrimitiveIndex()"; + else + return "SV_PrimitiveID"; + break; case BuiltInNumWorkgroups: { if (!num_workgroups_builtin) @@ -910,7 +1010,34 @@ std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClas return "WaveGetLaneIndex()"; case BuiltInSubgroupSize: return "WaveGetLaneCount()"; - + case BuiltInLaunchIdKHR: + return "DispatchRaysIndex()"; + case BuiltInLaunchSizeKHR: + return "DispatchRaysDimensions()"; + case BuiltInWorldRayOriginKHR: + return "WorldRayOrigin()"; + case BuiltInWorldRayDirectionKHR: + return "WorldRayDirection()"; + case BuiltInObjectRayOriginKHR: + return "ObjectRayOrigin()"; + case BuiltInObjectRayDirectionKHR: + return "ObjectRayDirection()"; + case BuiltInRayTminKHR: + return "RayTMin()"; + case BuiltInRayTmaxKHR: + return "RayTCurrent()"; + case BuiltInInstanceCustomIndexKHR: + return "InstanceID()"; + case BuiltInObjectToWorldKHR: + return "ObjectToWorld4x3()"; + case BuiltInWorldToObjectKHR: + return "WorldToObject4x3()"; + case BuiltInHitTKHR: + return "RayTCurrent()"; + case BuiltInHitKindKHR: + return "HitKind()"; + case BuiltInIncomingRayFlagsKHR: + return "RayFlags()"; default: return CompilerGLSL::builtin_to_glsl(builtin, storage); } @@ -918,6 +1045,8 @@ std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClas void CompilerHLSL::emit_builtin_variables() { + auto &execution = get_entry_point(); + Bitset builtins = active_input_builtins; builtins.merge_or(active_output_builtins); @@ -949,6 +1078,21 @@ void CompilerHLSL::emit_builtin_variables() break; case BuiltInInstanceId: + // Ignore when used in RT shaders, this is no longer a compile-time constant + if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelClosestHitKHR) + break; + type = "int"; + break; + + case BuiltInPrimitiveId: + // Ignore when used in RT shaders, this is no longer a compile-time constant + if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelClosestHitKHR) + break; + type = "uint"; + break; + case BuiltInSampleId: type = "int"; break; @@ -1008,6 +1152,23 @@ void CompilerHLSL::emit_builtin_variables() type = "float"; break; + case BuiltInLaunchIdKHR: + case BuiltInLaunchSizeKHR: + case BuiltInWorldRayOriginKHR: + case BuiltInWorldRayDirectionKHR: + case BuiltInObjectRayOriginKHR: + case BuiltInObjectRayDirectionKHR: + case BuiltInRayTminKHR: + case BuiltInRayTmaxKHR: + case BuiltInInstanceCustomIndexKHR: + case BuiltInObjectToWorldKHR: + case BuiltInWorldToObjectKHR: + case BuiltInHitTKHR: + case BuiltInHitKindKHR: + case BuiltInIncomingRayFlagsKHR: + // handled specially since they aren't compile time constants + break; + default: SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin))); } @@ -1196,6 +1357,51 @@ void CompilerHLSL::emit_resources() emitted = true; } + // HLSL requires ray payloads to be structs + if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelMissKHR) + { + auto *payload_var = get_ray_tracing_payload(); + if (payload_var) + { + std::string name = get_name(payload_var->self); + std::string type = type_to_glsl_constructor(get_type(payload_var->basetype)); + // Payload type is a primitive, create a shadow struct to wrap it + if (get_type(payload_var->basetype).basetype != SPIRType::Struct) + { + statement("struct _ShadowPayloadData_"); + statement("{"); + statement(" ", type_to_glsl(get_type(payload_var->basetype)), " data;"); + statement("};"); + } + statement("static ", type, " ", name, ";"); + } + // fallback payload struct + else { + statement("struct _ShadowPayloadData_ { float4 data; };"); + } + } + + // HLSL requires hit attributes to be structs + if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR) + { + auto *hitattrib_var = get_ray_tracing_hit_attrib(); + if (hitattrib_var) + { + std::string name = get_name(hitattrib_var->self); + std::string type = type_to_glsl_constructor(get_type(hitattrib_var->basetype)); + // Hit attribute type is a primitive, create a shadow struct to wrap it + if (get_type(hitattrib_var->basetype).basetype != SPIRType::Struct) + { + statement("struct _ShadowHitAttributeData_"); + statement("{"); + statement(" ", type_to_glsl(get_type(hitattrib_var->basetype)), " attribs;"); + statement("};"); + } + statement("static ", type, " ", name, ";"); + } + } + bool skip_separate_image_sampler = !combined_image_samplers.empty() || hlsl_options.shader_model <= 30; // Output Uniform Constants (values, samplers, images, etc). @@ -1214,7 +1420,9 @@ void CompilerHLSL::emit_resources() } if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable && - type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter)) + type.pointer && + (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter || + type.storage == StorageClassRayPayloadKHR)) { emit_uniform(var); emitted = true; @@ -1325,6 +1533,13 @@ void CompilerHLSL::emit_resources() }; auto input_builtins = active_input_builtins; + // only make non-compile-time constant when used in RT shader + if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelClosestHitKHR) + { + input_builtins.clear(BuiltInInstanceId); + input_builtins.clear(BuiltInPrimitiveId); + } input_builtins.clear(BuiltInNumWorkgroups); input_builtins.clear(BuiltInPointCoord); input_builtins.clear(BuiltInSubgroupSize); @@ -1334,6 +1549,21 @@ void CompilerHLSL::emit_resources() input_builtins.clear(BuiltInSubgroupLeMask); input_builtins.clear(BuiltInSubgroupGtMask); input_builtins.clear(BuiltInSubgroupGeMask); + // RT builtins + input_builtins.clear(BuiltInLaunchIdKHR); + input_builtins.clear(BuiltInLaunchSizeKHR); + input_builtins.clear(BuiltInWorldRayOriginKHR); + input_builtins.clear(BuiltInWorldRayDirectionKHR); + input_builtins.clear(BuiltInObjectRayOriginKHR); + input_builtins.clear(BuiltInObjectRayDirectionKHR); + input_builtins.clear(BuiltInRayTminKHR); + input_builtins.clear(BuiltInRayTmaxKHR); + input_builtins.clear(BuiltInInstanceCustomIndexKHR); + input_builtins.clear(BuiltInObjectToWorldKHR); + input_builtins.clear(BuiltInWorldToObjectKHR); + input_builtins.clear(BuiltInHitTKHR); + input_builtins.clear(BuiltInHitKindKHR); + input_builtins.clear(BuiltInIncomingRayFlagsKHR); if (!input_variables.empty() || !input_builtins.empty()) { @@ -2144,6 +2374,18 @@ void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &ret decl += "frag_main"; else if (execution.model == ExecutionModelGLCompute) decl += "comp_main"; + else if (execution.model == ExecutionModelRayGenerationKHR) + decl += "rgen_main"; + else if (execution.model == ExecutionModelIntersectionKHR) + decl += "rint_main"; + else if (execution.model == ExecutionModelAnyHitKHR) + decl += "rahit_main"; + else if (execution.model == ExecutionModelClosestHitKHR) + decl += "rchit_main"; + else if (execution.model == ExecutionModelMissKHR) + decl += "rmiss_main"; + else if (execution.model == ExecutionModelCallableKHR) + decl += "call_main"; else SPIRV_CROSS_THROW("Unsupported execution model."); processing_entry_point = true; @@ -2266,6 +2508,36 @@ void CompilerHLSL::emit_hlsl_entry_point() statement("[numthreads(", x_expr, ", ", y_expr, ", ", z_expr, ")]"); break; } + case ExecutionModelRayGenerationKHR: + { + statement("[shader(\"raygeneration\")]"); + break; + } + case ExecutionModelIntersectionKHR: + { + statement("[shader(\"intersection\")]"); + break; + } + case ExecutionModelAnyHitKHR: + { + statement("[shader(\"anyhit\")]"); + break; + } + case ExecutionModelClosestHitKHR: + { + statement("[shader(\"closesthit\")]"); + break; + } + case ExecutionModelMissKHR: + { + statement("[shader(\"miss\")]"); + break; + } + case ExecutionModelCallableKHR: + { + statement("[shader(\"callable\")]"); + break; + } case ExecutionModelFragment: if (execution.flags.get(ExecutionModeEarlyFragmentTests)) statement("[earlydepthstencil]"); @@ -2274,6 +2546,54 @@ void CompilerHLSL::emit_hlsl_entry_point() break; } + // Add incoming payload and hit attributes for Hit shaders + if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelMissKHR) + { + // Add incoming payload + { + string out_argument; + auto *payload_var = get_ray_tracing_payload(); + out_argument += "inout "; + if (payload_var) + { + if (get_type(payload_var->basetype).basetype != SPIRType::Struct) + out_argument += "_ShadowPayloadData_"; + else + out_argument += type_to_glsl_constructor(get_type(payload_var->basetype)); + } + // no payload used, fallback to default payload + else { + out_argument += "_ShadowPayloadData_"; + } + out_argument += " "; + out_argument += "_payloadOut_"; + arguments.push_back(move(out_argument)); + } + // Add incoming hit attribute + if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR) + { + string out_argument; + const SPIRVariable *hitattrib_var = get_ray_tracing_hit_attrib(); + out_argument += "in "; + // in case a primitive attribute is used, fallback to the shadow struct + if (hitattrib_var) + { + if (get_type(hitattrib_var->basetype).basetype != SPIRType::Struct) + out_argument += "_ShadowHitAttributeData_"; + else + out_argument += type_to_glsl_constructor(get_type(hitattrib_var->basetype)); + } + // no hit attributes used, fallback to default hit attributes + else { + out_argument += "BuiltInTriangleIntersectionAttributes"; + } + out_argument += " "; + out_argument += "_hitAttribsOut_"; + arguments.push_back(move(out_argument)); + } + } + statement(require_output ? "SPIRV_Cross_Output " : "void ", "main(", merge(arguments), ")"); begin_scope(); bool legacy = hlsl_options.shader_model <= 30; @@ -2309,8 +2629,22 @@ void CompilerHLSL::emit_hlsl_entry_point() break; case BuiltInInstanceId: - // D3D semantics are uint, but shader wants int. - statement(builtin, " = int(stage_input.", builtin, ");"); + // In RT shaders this is not a compile-time constant + if (execution.model != ExecutionModelIntersectionKHR && execution.model != ExecutionModelAnyHitKHR && + execution.model != ExecutionModelClosestHitKHR) + { + // D3D semantics are uint, but shader wants int. + statement(builtin, " = int(stage_input.", builtin, ");"); + } + break; + + case BuiltInPrimitiveId: + // In RT shaders this is not a compile-time constant + if (execution.model != ExecutionModelIntersectionKHR && execution.model != ExecutionModelAnyHitKHR && + execution.model != ExecutionModelClosestHitKHR) + { + statement(builtin, " = stage_input.", builtin, ";"); + } break; case BuiltInNumWorkgroups: @@ -2393,6 +2727,23 @@ void CompilerHLSL::emit_hlsl_entry_point() ";"); break; + case BuiltInLaunchIdKHR: + case BuiltInLaunchSizeKHR: + case BuiltInWorldRayOriginKHR: + case BuiltInWorldRayDirectionKHR: + case BuiltInObjectRayOriginKHR: + case BuiltInObjectRayDirectionKHR: + case BuiltInRayTminKHR: + case BuiltInRayTmaxKHR: + case BuiltInInstanceCustomIndexKHR: + case BuiltInObjectToWorldKHR: + case BuiltInWorldToObjectKHR: + case BuiltInHitTKHR: + case BuiltInHitKindKHR: + case BuiltInIncomingRayFlagsKHR: + // handled specially since they aren't compile time constants + break; + default: statement(builtin, " = stage_input.", builtin, ";"); break; @@ -2434,6 +2785,37 @@ void CompilerHLSL::emit_hlsl_entry_point() } }); + // Copy the payload result in + if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelMissKHR) + { + auto *payload_var = get_ray_tracing_payload(); + if (payload_var) + { + std::string name = get_name(payload_var->self); + if (get_type(payload_var->basetype).basetype != SPIRType::Struct) + // copy input data to shadow data + statement(name, " = _payloadOut_.data", ";"); + else + statement(name, " = _payloadOut_", ";"); + } + } + + // Copy the hit attribute result in + if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR) + { + const SPIRVariable *hitattrib_var = get_ray_tracing_hit_attrib(); + if (hitattrib_var) + { + std::string name = get_name(hitattrib_var->self); + if (get_type(hitattrib_var->basetype).basetype != SPIRType::Struct) + // copy input data to shadow data + statement(name, " = _hitAttribsOut_.attribs", ";"); + else + statement(name, " = _hitAttribsOut_", ";"); + } + } + // Run the shader. if (execution.model == ExecutionModelVertex) statement("vert_main();"); @@ -2441,9 +2823,37 @@ void CompilerHLSL::emit_hlsl_entry_point() statement("frag_main();"); else if (execution.model == ExecutionModelGLCompute) statement("comp_main();"); + else if (execution.model == ExecutionModelRayGenerationKHR) + statement("rgen_main();"); + else if (execution.model == ExecutionModelIntersectionKHR) + statement("rint_main();"); + else if (execution.model == ExecutionModelAnyHitKHR) + statement("rahit_main();"); + else if (execution.model == ExecutionModelClosestHitKHR) + statement("rchit_main();"); + else if (execution.model == ExecutionModelMissKHR) + statement("rmiss_main();"); + else if (execution.model == ExecutionModelCallableKHR) + statement("call_main();"); else SPIRV_CROSS_THROW("Unsupported shader stage."); + // Copy the payload result back + if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || + execution.model == ExecutionModelMissKHR) + { + auto *payload_var = get_ray_tracing_payload(); + if (payload_var) + { + std::string name = get_name(payload_var->self); + if (get_type(payload_var->basetype).basetype != SPIRType::Struct) + // copy input data to shadow data + statement("_payloadOut_.data = ", name, ";"); + else + statement("_payloadOut_ = ", name, ";"); + } + } + // Copy block outputs. ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { auto &type = this->get(var.basetype); @@ -3030,6 +3440,11 @@ string CompilerHLSL::to_resource_binding(const SPIRVariable &var) resource_flags = HLSL_BINDING_AUTO_SAMPLER_BIT; break; + case SPIRType::AccelerationStructure: + space = 't'; // SRV + resource_flags = HLSL_BINDING_AUTO_SRV_BIT; + break; + case SPIRType::Struct: { auto storage = type.storage; @@ -3153,6 +3568,7 @@ string CompilerHLSL::to_resource_register(HLSLBindingFlagBits flag, char space, void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var) { auto &type = get(var.basetype); + switch (type.basetype) { case SPIRType::SampledImage: @@ -3186,6 +3602,11 @@ void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var) statement("SamplerState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";"); break; + case SPIRType::AccelerationStructure: + statement("RaytracingAccelerationStructure ", to_name(var.self), type_to_array_glsl(type), + to_resource_binding(var), ";"); + break; + default: statement(variable_decl(var), to_resource_binding(var), ";"); break; @@ -3209,6 +3630,13 @@ void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var) void CompilerHLSL::emit_uniform(const SPIRVariable &var) { + // make rgen payload static + auto &type = this->get(var.basetype); + if (type.storage == StorageClassRayPayloadKHR) + { + statement_inner("static "); + } + add_resource_name(var.self); if (hlsl_options.shader_model >= 40) emit_modern_uniform(var); @@ -4174,13 +4602,14 @@ void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op) if (data_type.storage == StorageClassImage || !chain) { - statement(atomic_op, "(", to_expression(ops[0]), ", ", to_expression(ops[3]), ", ", to_expression(tmp_id), ");"); + statement(atomic_op, "(", to_expression(ops[0]), ", ", to_expression(ops[3]), ", ", to_expression(tmp_id), + ");"); } else { // RWByteAddress buffer is always uint in its underlying type. - statement(chain->base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", to_expression(ops[3]), - ", ", to_expression(tmp_id), ");"); + statement(chain->base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", + to_expression(ops[3]), ", ", to_expression(tmp_id), ");"); } } else @@ -5186,6 +5615,89 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1."); break; // Nothing to do in the body + case OpReportIntersectionKHR: + { + auto *hitattrib_var = get_ray_tracing_hit_attrib(); + if (!hitattrib_var) + SPIRV_CROSS_THROW("Failed to lookup hit attribute for OpReportIntersectionKHR"); + + bool is_primitive_attr = get_type(hitattrib_var->basetype).basetype != SPIRType::Struct; + + std::string hitattrib_name = get_name(hitattrib_var->self); + std::string hitattrib_type = type_to_glsl_constructor(get_type(hitattrib_var->basetype)); + + std::string hitattrib_uid = get_unique_identifier(); + + // Target attribute is either the user attribute struct, or our struct wrapped attribute + std::string target_attr_name = is_primitive_attr ? hitattrib_uid : hitattrib_name; + + if (is_primitive_attr) + { + // Attribute is a primitive and needs to be struct wrapped + statement("struct PrimitiveWrap", hitattrib_uid, " {"); + statement(" ", hitattrib_type, " data", ";"); + statement("} ", hitattrib_uid, ";"); + // Copy current payload value into struct wrap + statement(hitattrib_uid, ".data", " = ", hitattrib_name, ";"); + } + statement("ReportHit(", to_expression(ops[0]), ",", to_expression(ops[1]), ",", target_attr_name, ");"); + break; + } + case OpIgnoreIntersectionKHR: + statement("IgnoreHit();"); + break; + case OpTerminateRayKHR: + statement("AcceptHitAndEndSearch();"); + break; + + case OpTraceRayKHR: + { + // In GLSL, payload is passed as a number and is a compile-time constant + // In HLSL, payload is passed as a variable + // To simulate GLSL behavior, we lookup the payload name based on the + // input location's index, which is passed in the last param of TraceRay + uint32_t payload_index = std::stoi(to_expression(ops[10])); + + auto *payload_var = find_storage_class_variable_by_location(StorageClassRayPayloadKHR, payload_index); + if (!payload_var) + SPIRV_CROSS_THROW("Failed to lookup location of rayPayloadEXT"); + + std::string payload_name = get_name(payload_var->self); + std::string payload_type = type_to_glsl_constructor(get_type(payload_var->basetype)); + + bool is_primitive_payload = get_type(payload_var->basetype).basetype != SPIRType::Struct; + + std::string ray_uid = get_unique_identifier(); + std::string primitive_uid = get_unique_identifier(); + + // Target payload is either the user payload struct, or our struct wrapped payload + std::string target_payload_name = is_primitive_payload ? primitive_uid : payload_name; + + statement("RayDesc ", ray_uid, ";"); + statement(ray_uid, ".Origin = ", to_expression(ops[6]), ";"); + statement(ray_uid, ".Direction = ", to_expression(ops[8]), ";"); + statement(ray_uid, ".TMin = ", to_expression(ops[7]), ";"); + statement(ray_uid, ".TMax = ", to_expression(ops[9]), ";"); + if (is_primitive_payload) + { + // Payload is a primitive and needs to be struct wrapped + statement("struct PrimitiveWrap", primitive_uid, " {"); + statement(" ", payload_type, " data", ";"); + statement("} ", primitive_uid, ";"); + // Copy current payload value into struct wrap + statement(primitive_uid, ".data", " = ", payload_name, ";"); + } + statement("TraceRay(", to_expression(ops[0]), ", ", to_expression(ops[1]), ", ", to_expression(ops[2]), ", ", + to_expression(ops[3]), ", ", to_expression(ops[4]), ", ", to_expression(ops[5]), ", ", ray_uid, ", ", + target_payload_name, ");"); + if (is_primitive_payload) + { + // Copy payload result back from struct wrap to the payload primitive + statement(payload_name, " = ", primitive_uid, ".data", ";"); + } + break; + } + default: CompilerGLSL::emit_instruction(instruction); break; @@ -5426,6 +5938,28 @@ string CompilerHLSL::get_unique_identifier() return join("_", unique_identifier_count++, "ident"); } +const SPIRVariable *CompilerHLSL::get_ray_tracing_payload() +{ + const SPIRVariable *ret = nullptr; + // Find incoming payload + ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { + if (var.storage == StorageClassIncomingRayPayloadKHR) + ret = &var; + }); + return ret; +} + +const SPIRVariable *CompilerHLSL::get_ray_tracing_hit_attrib() +{ + const SPIRVariable *ret = nullptr; + // Find incoming hit attribute + ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { + if (var.storage == StorageClassHitAttributeKHR) + ret = &var; + }); + return ret; +} + void CompilerHLSL::add_hlsl_resource_binding(const HLSLResourceBinding &binding) { StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding }; diff --git a/spirv_hlsl.hpp b/spirv_hlsl.hpp index e98c6443a..b0fc0b411 100644 --- a/spirv_hlsl.hpp +++ b/spirv_hlsl.hpp @@ -321,6 +321,9 @@ class CompilerHLSL : public CompilerGLSL std::string get_unique_identifier(); uint32_t unique_identifier_count = 0; + const SPIRVariable *get_ray_tracing_payload(); + const SPIRVariable *get_ray_tracing_hit_attrib(); + std::unordered_map, InternalHasher> resource_bindings; void remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding); }; From 7c1bdaee0e4a2c2197ecc3217c64aa12c98f9af4 Mon Sep 17 00:00:00 2001 From: Felix Maier Date: Sun, 17 May 2020 11:46:53 +0200 Subject: [PATCH 2/6] Formattings --- spirv_glsl.cpp | 2 +- spirv_hlsl.cpp | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 6032c55b7..5ae569779 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -6805,7 +6805,7 @@ string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) default: SPIRV_CROSS_THROW( - "Cannot implement gl_InstanceID in Vulkan GLSL. This shader was created with GL semantics."); + "Cannot implement gl_InstanceID in Vulkan GLSL. This shader was created with GL semantics."); } } return "gl_InstanceID"; diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index d08ea19ed..73c14c91c 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -1377,7 +1377,8 @@ void CompilerHLSL::emit_resources() statement("static ", type, " ", name, ";"); } // fallback payload struct - else { + else + { statement("struct _ShadowPayloadData_ { float4 data; };"); } } @@ -2563,7 +2564,8 @@ void CompilerHLSL::emit_hlsl_entry_point() out_argument += type_to_glsl_constructor(get_type(payload_var->basetype)); } // no payload used, fallback to default payload - else { + else + { out_argument += "_ShadowPayloadData_"; } out_argument += " "; @@ -2585,7 +2587,8 @@ void CompilerHLSL::emit_hlsl_entry_point() out_argument += type_to_glsl_constructor(get_type(hitattrib_var->basetype)); } // no hit attributes used, fallback to default hit attributes - else { + else + { out_argument += "BuiltInTriangleIntersectionAttributes"; } out_argument += " "; From ef6e321473b971a916a68e9fcec3d5bd9f38b703 Mon Sep 17 00:00:00 2001 From: Felix Maier Date: Tue, 19 May 2020 13:56:58 +0200 Subject: [PATCH 3/6] Fix intersection hit attribute --- spirv_hlsl.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index 73c14c91c..75ab5de1a 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -1384,7 +1384,8 @@ void CompilerHLSL::emit_resources() } // HLSL requires hit attributes to be structs - if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR) + if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || + ExecutionModelIntersectionKHR) { auto *hitattrib_var = get_ray_tracing_hit_attrib(); if (hitattrib_var) From eee53da8834c3d10f805e91b22f392b39a10e1e4 Mon Sep 17 00:00:00 2001 From: Felix Maier Date: Tue, 19 May 2020 14:07:35 +0200 Subject: [PATCH 4/6] Fix ops indices --- spirv_hlsl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index 75ab5de1a..555cad758 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -5644,7 +5644,7 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) // Copy current payload value into struct wrap statement(hitattrib_uid, ".data", " = ", hitattrib_name, ";"); } - statement("ReportHit(", to_expression(ops[0]), ",", to_expression(ops[1]), ",", target_attr_name, ");"); + statement("ReportHit(", to_expression(ops[2]), ",", to_expression(ops[3]), ",", target_attr_name, ");"); break; } case OpIgnoreIntersectionKHR: From a5ec85d6bc894431985af5af05d0fd5c0fcf8d2e Mon Sep 17 00:00:00 2001 From: Felix Maier Date: Tue, 19 May 2020 18:22:50 +0200 Subject: [PATCH 5/6] KHR -> NV, Rename get_ray_tracing_payload to get_ray_tracing_in_payload --- main.cpp | 24 ++--- spirv_glsl.cpp | 76 ++++++------- spirv_hlsl.cpp | 281 ++++++++++++++++++++++++------------------------- spirv_hlsl.hpp | 2 +- 4 files changed, 191 insertions(+), 192 deletions(-) diff --git a/main.cpp b/main.cpp index 6424079df..b69c45e62 100644 --- a/main.cpp +++ b/main.cpp @@ -345,18 +345,18 @@ static const char *execution_model_to_str(spv::ExecutionModel model) return "fragment"; case ExecutionModelGLCompute: return "compute"; - case ExecutionModelRayGenerationKHR: - return "raygenKHR"; - case ExecutionModelIntersectionKHR: - return "intersectionKHR"; - case ExecutionModelCallableKHR: - return "callableKHR"; - case ExecutionModelAnyHitKHR: - return "anyhitKHR"; - case ExecutionModelClosestHitKHR: - return "closesthitKHR"; - case ExecutionModelMissKHR: - return "missKHR"; + case ExecutionModelRayGenerationNV: + return "raygenNV"; + case ExecutionModelIntersectionNV: + return "intersectionNV"; + case ExecutionModelCallableNV: + return "callableNV"; + case ExecutionModelAnyHitNV: + return "anyhitNV"; + case ExecutionModelClosestHitNV: + return "closesthitNV"; + case ExecutionModelMissNV: + return "missNV"; default: return "???"; } diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 5ae569779..2b484a9ab 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -445,15 +445,15 @@ void CompilerGLSL::find_static_extensions() require_extension_internal("GL_ARB_tessellation_shader"); break; - case ExecutionModelRayGenerationKHR: - case ExecutionModelIntersectionKHR: - case ExecutionModelAnyHitKHR: - case ExecutionModelClosestHitKHR: - case ExecutionModelMissKHR: - case ExecutionModelCallableKHR: + case ExecutionModelRayGenerationNV: + case ExecutionModelIntersectionNV: + case ExecutionModelAnyHitNV: + case ExecutionModelClosestHitNV: + case ExecutionModelMissNV: + case ExecutionModelCallableNV: if (options.es || options.version < 460) SPIRV_CROSS_THROW("Ray tracing shaders require non-es profile with version 460 or above."); - require_extension_internal("GL_KHR_ray_tracing"); + require_extension_internal("GL_NV_ray_tracing"); break; default: @@ -2089,23 +2089,23 @@ const char *CompilerGLSL::to_storage_qualifiers_glsl(const SPIRVariable &var) { return "uniform "; } - else if (var.storage == StorageClassRayPayloadKHR) + else if (var.storage == StorageClassRayPayloadNV) { return "rayPayloadEXT "; } - else if (var.storage == StorageClassIncomingRayPayloadKHR) + else if (var.storage == StorageClassIncomingRayPayloadNV) { return "rayPayloadInEXT "; } - else if (var.storage == StorageClassHitAttributeKHR) + else if (var.storage == StorageClassHitAttributeNV) { return "hitAttributeEXT "; } - else if (var.storage == StorageClassCallableDataKHR) + else if (var.storage == StorageClassCallableDataNV) { return "callableDataEXT "; } - else if (var.storage == StorageClassIncomingCallableDataKHR) + else if (var.storage == StorageClassIncomingCallableDataNV) { return "callableDataInEXT "; } @@ -3056,9 +3056,9 @@ void CompilerGLSL::emit_resources() if (var.storage != StorageClassFunction && type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter || - type.storage == StorageClassRayPayloadKHR || type.storage == StorageClassIncomingRayPayloadKHR || - type.storage == StorageClassCallableDataKHR || type.storage == StorageClassIncomingCallableDataKHR || - type.storage == StorageClassHitAttributeKHR) && + type.storage == StorageClassRayPayloadNV || type.storage == StorageClassIncomingRayPayloadNV || + type.storage == StorageClassCallableDataNV || type.storage == StorageClassIncomingCallableDataNV || + type.storage == StorageClassHitAttributeNV) && !is_hidden_variable(var)) { emit_uniform(var); @@ -6797,9 +6797,9 @@ string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) auto model = get_entry_point().model; switch (model) { - case spv::ExecutionModelIntersectionKHR: - case spv::ExecutionModelAnyHitKHR: - case spv::ExecutionModelClosestHitKHR: + case spv::ExecutionModelIntersectionNV: + case spv::ExecutionModelAnyHitNV: + case spv::ExecutionModelClosestHitNV: // gl_InstanceID is allowed in these shaders. break; @@ -6979,33 +6979,33 @@ string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) require_extension_internal("GL_KHR_shader_subgroup_ballot"); return "gl_SubgroupLtMask"; - case BuiltInLaunchIdKHR: + case BuiltInLaunchIdNV: return "gl_LaunchIDEXT"; - case BuiltInLaunchSizeKHR: + case BuiltInLaunchSizeNV: return "gl_LaunchSizeEXT"; - case BuiltInWorldRayOriginKHR: + case BuiltInWorldRayOriginNV: return "gl_WorldRayOriginEXT"; - case BuiltInWorldRayDirectionKHR: + case BuiltInWorldRayDirectionNV: return "gl_WorldRayDirectionEXT"; - case BuiltInObjectRayOriginKHR: + case BuiltInObjectRayOriginNV: return "gl_ObjectRayOriginEXT"; - case BuiltInObjectRayDirectionKHR: + case BuiltInObjectRayDirectionNV: return "gl_ObjectRayDirectionEXT"; - case BuiltInRayTminKHR: + case BuiltInRayTminNV: return "gl_RayTminEXT"; - case BuiltInRayTmaxKHR: + case BuiltInRayTmaxNV: return "gl_RayTmaxEXT"; - case BuiltInInstanceCustomIndexKHR: + case BuiltInInstanceCustomIndexNV: return "gl_InstanceCustomIndexEXT"; - case BuiltInObjectToWorldKHR: + case BuiltInObjectToWorldNV: return "gl_ObjectToWorldEXT"; - case BuiltInWorldToObjectKHR: + case BuiltInWorldToObjectNV: return "gl_WorldToObjectEXT"; - case BuiltInHitTKHR: + case BuiltInHitTNV: return "gl_HitTEXT"; - case BuiltInHitKindKHR: + case BuiltInHitKindNV: return "gl_HitKindEXT"; - case BuiltInIncomingRayFlagsKHR: + case BuiltInIncomingRayFlagsNV: return "gl_IncomingRayFlagsEXT"; case BuiltInBaryCoordNV: @@ -10791,22 +10791,22 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) break; } - case OpReportIntersectionKHR: + case OpReportIntersectionNV: statement("reportIntersectionEXT(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");"); break; - case OpIgnoreIntersectionKHR: + case OpIgnoreIntersectionNV: statement("ignoreIntersectionEXT();"); break; - case OpTerminateRayKHR: + case OpTerminateRayNV: statement("terminateRayEXT();"); break; - case OpTraceRayKHR: + case OpTraceNV: statement("traceRayEXT(", to_expression(ops[0]), ", ", to_expression(ops[1]), ", ", to_expression(ops[2]), ", ", to_expression(ops[3]), ", ", to_expression(ops[4]), ", ", to_expression(ops[5]), ", ", to_expression(ops[6]), ", ", to_expression(ops[7]), ", ", to_expression(ops[8]), ", ", to_expression(ops[9]), ", ", to_expression(ops[10]), ");"); break; - case OpExecuteCallableKHR: + case OpExecuteCallableNV: statement("executeCallableEXT(", to_expression(ops[0]), ", ", to_expression(ops[1]), ");"); break; @@ -11554,7 +11554,7 @@ string CompilerGLSL::type_to_glsl(const SPIRType &type, uint32_t id) return comparison_ids.count(id) ? "samplerShadow" : "sampler"; case SPIRType::AccelerationStructure: - return "accelerationStructureKHR"; + return "accelerationStructureNV"; case SPIRType::Void: return "void"; diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index 555cad758..279d739a0 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -612,8 +612,8 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower."); type = "uint"; // Ignore semantic when in RT shader - if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelClosestHitKHR) + if (execution.model == ExecutionModelIntersectionNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelClosestHitNV) semantic = nullptr; else semantic = "SV_InstanceID"; @@ -621,8 +621,8 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() case BuiltInPrimitiveId: type = "uint"; // Ignore semantic when in RT shader - if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelClosestHitKHR) + if (execution.model == ExecutionModelIntersectionNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelClosestHitNV) semantic = nullptr; else semantic = "SV_PrimitiveID"; @@ -632,8 +632,8 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower."); type = "uint"; // Ignore semantic when in RT shader - if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelClosestHitKHR) + if (execution.model == ExecutionModelIntersectionNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelClosestHitNV) semantic = nullptr; else semantic = "SV_InstanceID"; @@ -721,59 +721,59 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() else SPIRV_CROSS_THROW("Unsupported builtin in HLSL."); - case BuiltInLaunchIdKHR: + case BuiltInLaunchIdNV: type = "uint3"; break; - case BuiltInLaunchSizeKHR: + case BuiltInLaunchSizeNV: type = "uint2"; break; - case BuiltInWorldRayOriginKHR: + case BuiltInWorldRayOriginNV: type = "float3"; break; - case BuiltInWorldRayDirectionKHR: + case BuiltInWorldRayDirectionNV: type = "float3"; break; - case BuiltInObjectRayOriginKHR: + case BuiltInObjectRayOriginNV: type = "float3"; break; - case BuiltInObjectRayDirectionKHR: + case BuiltInObjectRayDirectionNV: type = "float3"; break; - case BuiltInRayTminKHR: + case BuiltInRayTminNV: type = "float"; break; - case BuiltInRayTmaxKHR: + case BuiltInRayTmaxNV: type = "float"; break; - case BuiltInInstanceCustomIndexKHR: + case BuiltInInstanceCustomIndexNV: type = "uint"; break; - case BuiltInObjectToWorldKHR: + case BuiltInObjectToWorldNV: type = "float4x3"; break; - case BuiltInWorldToObjectKHR: + case BuiltInWorldToObjectNV: type = "float4x3"; break; - case BuiltInHitTKHR: + case BuiltInHitTNV: type = "float"; break; - case BuiltInHitKindKHR: + case BuiltInHitKindNV: type = "uint"; break; - case BuiltInIncomingRayFlagsKHR: + case BuiltInIncomingRayFlagsNV: type = "uint"; break; @@ -980,15 +980,15 @@ std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClas return "gl_VertexID"; case BuiltInInstanceId: // In RT shaders, this builtin gets overwritten - if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelClosestHitKHR) + if (execution.model == ExecutionModelIntersectionNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelClosestHitNV) return "InstanceIndex()"; else return "gl_InstanceID"; case BuiltInPrimitiveId: // In RT shaders, this builtin gets overwritten - if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelClosestHitKHR) + if (execution.model == ExecutionModelIntersectionNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelClosestHitNV) return "PrimitiveIndex()"; else return "SV_PrimitiveID"; @@ -1010,33 +1010,33 @@ std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClas return "WaveGetLaneIndex()"; case BuiltInSubgroupSize: return "WaveGetLaneCount()"; - case BuiltInLaunchIdKHR: + case BuiltInLaunchIdNV: return "DispatchRaysIndex()"; - case BuiltInLaunchSizeKHR: + case BuiltInLaunchSizeNV: return "DispatchRaysDimensions()"; - case BuiltInWorldRayOriginKHR: + case BuiltInWorldRayOriginNV: return "WorldRayOrigin()"; - case BuiltInWorldRayDirectionKHR: + case BuiltInWorldRayDirectionNV: return "WorldRayDirection()"; - case BuiltInObjectRayOriginKHR: + case BuiltInObjectRayOriginNV: return "ObjectRayOrigin()"; - case BuiltInObjectRayDirectionKHR: + case BuiltInObjectRayDirectionNV: return "ObjectRayDirection()"; - case BuiltInRayTminKHR: + case BuiltInRayTminNV: return "RayTMin()"; - case BuiltInRayTmaxKHR: + case BuiltInRayTmaxNV: return "RayTCurrent()"; - case BuiltInInstanceCustomIndexKHR: + case BuiltInInstanceCustomIndexNV: return "InstanceID()"; - case BuiltInObjectToWorldKHR: + case BuiltInObjectToWorldNV: return "ObjectToWorld4x3()"; - case BuiltInWorldToObjectKHR: + case BuiltInWorldToObjectNV: return "WorldToObject4x3()"; - case BuiltInHitTKHR: + case BuiltInHitTNV: return "RayTCurrent()"; - case BuiltInHitKindKHR: + case BuiltInHitKindNV: return "HitKind()"; - case BuiltInIncomingRayFlagsKHR: + case BuiltInIncomingRayFlagsNV: return "RayFlags()"; default: return CompilerGLSL::builtin_to_glsl(builtin, storage); @@ -1079,16 +1079,16 @@ void CompilerHLSL::emit_builtin_variables() case BuiltInInstanceId: // Ignore when used in RT shaders, this is no longer a compile-time constant - if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelClosestHitKHR) + if (execution.model == ExecutionModelIntersectionNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelClosestHitNV) break; type = "int"; break; case BuiltInPrimitiveId: // Ignore when used in RT shaders, this is no longer a compile-time constant - if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelClosestHitKHR) + if (execution.model == ExecutionModelIntersectionNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelClosestHitNV) break; type = "uint"; break; @@ -1152,20 +1152,20 @@ void CompilerHLSL::emit_builtin_variables() type = "float"; break; - case BuiltInLaunchIdKHR: - case BuiltInLaunchSizeKHR: - case BuiltInWorldRayOriginKHR: - case BuiltInWorldRayDirectionKHR: - case BuiltInObjectRayOriginKHR: - case BuiltInObjectRayDirectionKHR: - case BuiltInRayTminKHR: - case BuiltInRayTmaxKHR: - case BuiltInInstanceCustomIndexKHR: - case BuiltInObjectToWorldKHR: - case BuiltInWorldToObjectKHR: - case BuiltInHitTKHR: - case BuiltInHitKindKHR: - case BuiltInIncomingRayFlagsKHR: + case BuiltInLaunchIdNV: + case BuiltInLaunchSizeNV: + case BuiltInWorldRayOriginNV: + case BuiltInWorldRayDirectionNV: + case BuiltInObjectRayOriginNV: + case BuiltInObjectRayDirectionNV: + case BuiltInRayTminNV: + case BuiltInRayTmaxNV: + case BuiltInInstanceCustomIndexNV: + case BuiltInObjectToWorldNV: + case BuiltInWorldToObjectNV: + case BuiltInHitTNV: + case BuiltInHitKindNV: + case BuiltInIncomingRayFlagsNV: // handled specially since they aren't compile time constants break; @@ -1358,10 +1358,10 @@ void CompilerHLSL::emit_resources() } // HLSL requires ray payloads to be structs - if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelMissKHR) + if (execution.model == ExecutionModelClosestHitNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelMissNV) { - auto *payload_var = get_ray_tracing_payload(); + auto *payload_var = get_ray_tracing_in_payload(); if (payload_var) { std::string name = get_name(payload_var->self); @@ -1384,8 +1384,8 @@ void CompilerHLSL::emit_resources() } // HLSL requires hit attributes to be structs - if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || - ExecutionModelIntersectionKHR) + if (execution.model == ExecutionModelClosestHitNV || execution.model == ExecutionModelAnyHitNV || + ExecutionModelIntersectionNV) { auto *hitattrib_var = get_ray_tracing_hit_attrib(); if (hitattrib_var) @@ -1424,8 +1424,14 @@ void CompilerHLSL::emit_resources() if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable && type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter || - type.storage == StorageClassRayPayloadKHR)) + type.storage == StorageClassRayPayloadNV)) { + // make rgen payload static + auto &type = this->get(var.basetype); + if (type.storage == StorageClassRayPayloadNV) + { + statement_inner("static "); + } emit_uniform(var); emitted = true; } @@ -1536,8 +1542,8 @@ void CompilerHLSL::emit_resources() auto input_builtins = active_input_builtins; // only make non-compile-time constant when used in RT shader - if (execution.model == ExecutionModelIntersectionKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelClosestHitKHR) + if (execution.model == ExecutionModelIntersectionNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelClosestHitNV) { input_builtins.clear(BuiltInInstanceId); input_builtins.clear(BuiltInPrimitiveId); @@ -1552,20 +1558,20 @@ void CompilerHLSL::emit_resources() input_builtins.clear(BuiltInSubgroupGtMask); input_builtins.clear(BuiltInSubgroupGeMask); // RT builtins - input_builtins.clear(BuiltInLaunchIdKHR); - input_builtins.clear(BuiltInLaunchSizeKHR); - input_builtins.clear(BuiltInWorldRayOriginKHR); - input_builtins.clear(BuiltInWorldRayDirectionKHR); - input_builtins.clear(BuiltInObjectRayOriginKHR); - input_builtins.clear(BuiltInObjectRayDirectionKHR); - input_builtins.clear(BuiltInRayTminKHR); - input_builtins.clear(BuiltInRayTmaxKHR); - input_builtins.clear(BuiltInInstanceCustomIndexKHR); - input_builtins.clear(BuiltInObjectToWorldKHR); - input_builtins.clear(BuiltInWorldToObjectKHR); - input_builtins.clear(BuiltInHitTKHR); - input_builtins.clear(BuiltInHitKindKHR); - input_builtins.clear(BuiltInIncomingRayFlagsKHR); + input_builtins.clear(BuiltInLaunchIdNV); + input_builtins.clear(BuiltInLaunchSizeNV); + input_builtins.clear(BuiltInWorldRayOriginNV); + input_builtins.clear(BuiltInWorldRayDirectionNV); + input_builtins.clear(BuiltInObjectRayOriginNV); + input_builtins.clear(BuiltInObjectRayDirectionNV); + input_builtins.clear(BuiltInRayTminNV); + input_builtins.clear(BuiltInRayTmaxNV); + input_builtins.clear(BuiltInInstanceCustomIndexNV); + input_builtins.clear(BuiltInObjectToWorldNV); + input_builtins.clear(BuiltInWorldToObjectNV); + input_builtins.clear(BuiltInHitTNV); + input_builtins.clear(BuiltInHitKindNV); + input_builtins.clear(BuiltInIncomingRayFlagsNV); if (!input_variables.empty() || !input_builtins.empty()) { @@ -2376,17 +2382,17 @@ void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &ret decl += "frag_main"; else if (execution.model == ExecutionModelGLCompute) decl += "comp_main"; - else if (execution.model == ExecutionModelRayGenerationKHR) + else if (execution.model == ExecutionModelRayGenerationNV) decl += "rgen_main"; - else if (execution.model == ExecutionModelIntersectionKHR) + else if (execution.model == ExecutionModelIntersectionNV) decl += "rint_main"; - else if (execution.model == ExecutionModelAnyHitKHR) + else if (execution.model == ExecutionModelAnyHitNV) decl += "rahit_main"; - else if (execution.model == ExecutionModelClosestHitKHR) + else if (execution.model == ExecutionModelClosestHitNV) decl += "rchit_main"; - else if (execution.model == ExecutionModelMissKHR) + else if (execution.model == ExecutionModelMissNV) decl += "rmiss_main"; - else if (execution.model == ExecutionModelCallableKHR) + else if (execution.model == ExecutionModelCallableNV) decl += "call_main"; else SPIRV_CROSS_THROW("Unsupported execution model."); @@ -2510,32 +2516,32 @@ void CompilerHLSL::emit_hlsl_entry_point() statement("[numthreads(", x_expr, ", ", y_expr, ", ", z_expr, ")]"); break; } - case ExecutionModelRayGenerationKHR: + case ExecutionModelRayGenerationNV: { statement("[shader(\"raygeneration\")]"); break; } - case ExecutionModelIntersectionKHR: + case ExecutionModelIntersectionNV: { statement("[shader(\"intersection\")]"); break; } - case ExecutionModelAnyHitKHR: + case ExecutionModelAnyHitNV: { statement("[shader(\"anyhit\")]"); break; } - case ExecutionModelClosestHitKHR: + case ExecutionModelClosestHitNV: { statement("[shader(\"closesthit\")]"); break; } - case ExecutionModelMissKHR: + case ExecutionModelMissNV: { statement("[shader(\"miss\")]"); break; } - case ExecutionModelCallableKHR: + case ExecutionModelCallableNV: { statement("[shader(\"callable\")]"); break; @@ -2549,13 +2555,13 @@ void CompilerHLSL::emit_hlsl_entry_point() } // Add incoming payload and hit attributes for Hit shaders - if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelMissKHR) + if (execution.model == ExecutionModelClosestHitNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelMissNV) { // Add incoming payload { string out_argument; - auto *payload_var = get_ray_tracing_payload(); + auto *payload_var = get_ray_tracing_in_payload(); out_argument += "inout "; if (payload_var) { @@ -2574,7 +2580,7 @@ void CompilerHLSL::emit_hlsl_entry_point() arguments.push_back(move(out_argument)); } // Add incoming hit attribute - if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR) + if (execution.model == ExecutionModelClosestHitNV || execution.model == ExecutionModelAnyHitNV) { string out_argument; const SPIRVariable *hitattrib_var = get_ray_tracing_hit_attrib(); @@ -2634,8 +2640,8 @@ void CompilerHLSL::emit_hlsl_entry_point() case BuiltInInstanceId: // In RT shaders this is not a compile-time constant - if (execution.model != ExecutionModelIntersectionKHR && execution.model != ExecutionModelAnyHitKHR && - execution.model != ExecutionModelClosestHitKHR) + if (execution.model != ExecutionModelIntersectionNV && execution.model != ExecutionModelAnyHitNV && + execution.model != ExecutionModelClosestHitNV) { // D3D semantics are uint, but shader wants int. statement(builtin, " = int(stage_input.", builtin, ");"); @@ -2644,8 +2650,8 @@ void CompilerHLSL::emit_hlsl_entry_point() case BuiltInPrimitiveId: // In RT shaders this is not a compile-time constant - if (execution.model != ExecutionModelIntersectionKHR && execution.model != ExecutionModelAnyHitKHR && - execution.model != ExecutionModelClosestHitKHR) + if (execution.model != ExecutionModelIntersectionNV && execution.model != ExecutionModelAnyHitNV && + execution.model != ExecutionModelClosestHitNV) { statement(builtin, " = stage_input.", builtin, ";"); } @@ -2731,20 +2737,20 @@ void CompilerHLSL::emit_hlsl_entry_point() ";"); break; - case BuiltInLaunchIdKHR: - case BuiltInLaunchSizeKHR: - case BuiltInWorldRayOriginKHR: - case BuiltInWorldRayDirectionKHR: - case BuiltInObjectRayOriginKHR: - case BuiltInObjectRayDirectionKHR: - case BuiltInRayTminKHR: - case BuiltInRayTmaxKHR: - case BuiltInInstanceCustomIndexKHR: - case BuiltInObjectToWorldKHR: - case BuiltInWorldToObjectKHR: - case BuiltInHitTKHR: - case BuiltInHitKindKHR: - case BuiltInIncomingRayFlagsKHR: + case BuiltInLaunchIdNV: + case BuiltInLaunchSizeNV: + case BuiltInWorldRayOriginNV: + case BuiltInWorldRayDirectionNV: + case BuiltInObjectRayOriginNV: + case BuiltInObjectRayDirectionNV: + case BuiltInRayTminNV: + case BuiltInRayTmaxNV: + case BuiltInInstanceCustomIndexNV: + case BuiltInObjectToWorldNV: + case BuiltInWorldToObjectNV: + case BuiltInHitTNV: + case BuiltInHitKindNV: + case BuiltInIncomingRayFlagsNV: // handled specially since they aren't compile time constants break; @@ -2790,10 +2796,10 @@ void CompilerHLSL::emit_hlsl_entry_point() }); // Copy the payload result in - if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelMissKHR) + if (execution.model == ExecutionModelClosestHitNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelMissNV) { - auto *payload_var = get_ray_tracing_payload(); + auto *payload_var = get_ray_tracing_in_payload(); if (payload_var) { std::string name = get_name(payload_var->self); @@ -2806,7 +2812,7 @@ void CompilerHLSL::emit_hlsl_entry_point() } // Copy the hit attribute result in - if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR) + if (execution.model == ExecutionModelClosestHitNV || execution.model == ExecutionModelAnyHitNV) { const SPIRVariable *hitattrib_var = get_ray_tracing_hit_attrib(); if (hitattrib_var) @@ -2827,26 +2833,26 @@ void CompilerHLSL::emit_hlsl_entry_point() statement("frag_main();"); else if (execution.model == ExecutionModelGLCompute) statement("comp_main();"); - else if (execution.model == ExecutionModelRayGenerationKHR) + else if (execution.model == ExecutionModelRayGenerationNV) statement("rgen_main();"); - else if (execution.model == ExecutionModelIntersectionKHR) + else if (execution.model == ExecutionModelIntersectionNV) statement("rint_main();"); - else if (execution.model == ExecutionModelAnyHitKHR) + else if (execution.model == ExecutionModelAnyHitNV) statement("rahit_main();"); - else if (execution.model == ExecutionModelClosestHitKHR) + else if (execution.model == ExecutionModelClosestHitNV) statement("rchit_main();"); - else if (execution.model == ExecutionModelMissKHR) + else if (execution.model == ExecutionModelMissNV) statement("rmiss_main();"); - else if (execution.model == ExecutionModelCallableKHR) + else if (execution.model == ExecutionModelCallableNV) statement("call_main();"); else SPIRV_CROSS_THROW("Unsupported shader stage."); // Copy the payload result back - if (execution.model == ExecutionModelClosestHitKHR || execution.model == ExecutionModelAnyHitKHR || - execution.model == ExecutionModelMissKHR) + if (execution.model == ExecutionModelClosestHitNV || execution.model == ExecutionModelAnyHitNV || + execution.model == ExecutionModelMissNV) { - auto *payload_var = get_ray_tracing_payload(); + auto *payload_var = get_ray_tracing_in_payload(); if (payload_var) { std::string name = get_name(payload_var->self); @@ -3634,13 +3640,6 @@ void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var) void CompilerHLSL::emit_uniform(const SPIRVariable &var) { - // make rgen payload static - auto &type = this->get(var.basetype); - if (type.storage == StorageClassRayPayloadKHR) - { - statement_inner("static "); - } - add_resource_name(var.self); if (hlsl_options.shader_model >= 40) emit_modern_uniform(var); @@ -5619,11 +5618,11 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1."); break; // Nothing to do in the body - case OpReportIntersectionKHR: + case OpReportIntersectionNV: { auto *hitattrib_var = get_ray_tracing_hit_attrib(); if (!hitattrib_var) - SPIRV_CROSS_THROW("Failed to lookup hit attribute for OpReportIntersectionKHR"); + SPIRV_CROSS_THROW("Failed to lookup hit attribute for OpReportIntersectionNV"); bool is_primitive_attr = get_type(hitattrib_var->basetype).basetype != SPIRType::Struct; @@ -5647,14 +5646,14 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) statement("ReportHit(", to_expression(ops[2]), ",", to_expression(ops[3]), ",", target_attr_name, ");"); break; } - case OpIgnoreIntersectionKHR: + case OpIgnoreIntersectionNV: statement("IgnoreHit();"); break; - case OpTerminateRayKHR: + case OpTerminateRayNV: statement("AcceptHitAndEndSearch();"); break; - case OpTraceRayKHR: + case OpTraceNV: { // In GLSL, payload is passed as a number and is a compile-time constant // In HLSL, payload is passed as a variable @@ -5662,7 +5661,7 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) // input location's index, which is passed in the last param of TraceRay uint32_t payload_index = std::stoi(to_expression(ops[10])); - auto *payload_var = find_storage_class_variable_by_location(StorageClassRayPayloadKHR, payload_index); + auto *payload_var = find_storage_class_variable_by_location(StorageClassRayPayloadNV, payload_index); if (!payload_var) SPIRV_CROSS_THROW("Failed to lookup location of rayPayloadEXT"); @@ -5942,12 +5941,12 @@ string CompilerHLSL::get_unique_identifier() return join("_", unique_identifier_count++, "ident"); } -const SPIRVariable *CompilerHLSL::get_ray_tracing_payload() +const SPIRVariable *CompilerHLSL::get_ray_tracing_in_payload() { const SPIRVariable *ret = nullptr; // Find incoming payload ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { - if (var.storage == StorageClassIncomingRayPayloadKHR) + if (var.storage == StorageClassIncomingRayPayloadNV) ret = &var; }); return ret; @@ -5958,7 +5957,7 @@ const SPIRVariable *CompilerHLSL::get_ray_tracing_hit_attrib() const SPIRVariable *ret = nullptr; // Find incoming hit attribute ir.for_each_typed_id([&](uint32_t, const SPIRVariable &var) { - if (var.storage == StorageClassHitAttributeKHR) + if (var.storage == StorageClassHitAttributeNV) ret = &var; }); return ret; diff --git a/spirv_hlsl.hpp b/spirv_hlsl.hpp index b0fc0b411..9191154ed 100644 --- a/spirv_hlsl.hpp +++ b/spirv_hlsl.hpp @@ -321,7 +321,7 @@ class CompilerHLSL : public CompilerGLSL std::string get_unique_identifier(); uint32_t unique_identifier_count = 0; - const SPIRVariable *get_ray_tracing_payload(); + const SPIRVariable *get_ray_tracing_in_payload(); const SPIRVariable *get_ray_tracing_hit_attrib(); std::unordered_map, InternalHasher> resource_bindings; From 8979b0bf4d72ee275d9cb22ee6445503634595c6 Mon Sep 17 00:00:00 2001 From: Felix Maier Date: Wed, 20 May 2020 12:27:59 +0200 Subject: [PATCH 6/6] Fix comparison --- spirv_hlsl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index 279d739a0..0f83f7a04 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -1385,7 +1385,7 @@ void CompilerHLSL::emit_resources() // HLSL requires hit attributes to be structs if (execution.model == ExecutionModelClosestHitNV || execution.model == ExecutionModelAnyHitNV || - ExecutionModelIntersectionNV) + execution.model == ExecutionModelIntersectionNV) { auto *hitattrib_var = get_ray_tracing_hit_attrib(); if (hitattrib_var)