@@ -1818,12 +1818,6 @@ static SpvReflectResult ParseType(
18181818 p_type -> type_flags |= SPV_REFLECT_TYPE_FLAG_EXTERNAL_ACCELERATION_STRUCTURE ;
18191819 }
18201820 break ;
1821-
1822- case SpvOpSpecConstantTrue :
1823- case SpvOpSpecConstantFalse :
1824- case SpvOpSpecConstant : {
1825- }
1826- break ;
18271821 }
18281822
18291823 if (result == SPV_REFLECT_RESULT_SUCCESS ) {
@@ -3383,6 +3377,50 @@ static SpvReflectResult ParseExecutionModes(
33833377 return SPV_REFLECT_RESULT_SUCCESS ;
33843378}
33853379
3380+ SpvReflectResult GetTypeByTypeId (SpvReflectShaderModule * p_module , uint32_t type_id , SpvReflectTypeDescription * * pp_type )
3381+ {
3382+ SpvReflectResult res = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE ;
3383+ for (uint32_t i = 0 ; i < p_module -> _internal -> type_description_count ; ++ i ) {
3384+ if (p_module -> _internal -> type_descriptions [i ].id == type_id ) {
3385+ * pp_type = & p_module -> _internal -> type_descriptions [i ];
3386+ return SPV_REFLECT_RESULT_SUCCESS ;
3387+ }
3388+ }
3389+ return res ;
3390+ }
3391+
3392+ static SpvReflectResult GetScalarConstant (SpvReflectPrvParser * p_parser , SpvReflectShaderModule * p_module , SpvReflectPrvNode * p_node , SpvReflectScalarValue * result )
3393+ {
3394+ SpvReflectTypeDescription * type ;
3395+ SpvReflectResult res = GetTypeByTypeId (p_module , p_node -> result_type_id , & type );
3396+ if (res != SPV_REFLECT_RESULT_SUCCESS ) return res ;
3397+ if (type -> type_flags & SPV_REFLECT_TYPE_FLAG_INT ){
3398+ result -> type = SPV_REFLECT_SCALAR_TYPE_INT ;
3399+ }
3400+ else if (type -> type_flags & SPV_REFLECT_TYPE_FLAG_FLOAT ) {
3401+ result -> type = SPV_REFLECT_SCALAR_TYPE_FLOAT ;
3402+ }
3403+ else {
3404+ result -> type = SPV_REFLECT_SCALAR_TYPE_UNKNOWN ;
3405+ }
3406+ result -> bit_size = type -> traits .numeric .scalar .width ;
3407+ uint32_t low_word ;
3408+ CHECKED_READU32 (p_parser , p_node -> word_offset + 3 , low_word );
3409+ if (type -> traits .numeric .scalar .width == 32 ) {
3410+ result -> value .int_bool_value = low_word ;
3411+ return SPV_REFLECT_RESULT_SUCCESS ;
3412+ }
3413+ else if (type -> traits .numeric .scalar .width == 64 ) {
3414+ uint32_t high_word ;
3415+ CHECKED_READU32 (p_parser , p_node -> word_offset + 4 , high_word );
3416+ result -> value .int64_value = low_word | (((uint64_t )high_word ) << 32 );
3417+ return SPV_REFLECT_RESULT_SUCCESS ;
3418+ }
3419+ else {
3420+ return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE ;
3421+ }
3422+ }
3423+
33863424static SpvReflectResult ParseSpecializationConstants (SpvReflectPrvParser * p_parser , SpvReflectShaderModule * p_module )
33873425{
33883426 p_module -> specialization_constant_count = 0 ;
@@ -3408,37 +3446,26 @@ static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_pars
34083446 switch (p_node -> op ) {
34093447 default : continue ;
34103448 case SpvOpSpecConstantTrue : {
3411- p_module -> specialization_constants [index ].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL ;
3412- p_module -> specialization_constants [index ].default_value .int_bool_value = 1 ;
3413- p_module -> specialization_constants [index ].current_value .int_bool_value = 1 ;
3449+ p_module -> specialization_constants [index ].default_value .type = SPV_REFLECT_SCALAR_TYPE_BOOL ;
3450+ p_module -> specialization_constants [index ].default_value .value .int_bool_value = 1 ;
3451+ p_module -> specialization_constants [index ].default_value .bit_size = 1 ;
3452+ p_module -> specialization_constants [index ].default_value .is_signed = 0 ;
3453+ p_module -> specialization_constants [index ].current_value = p_module -> specialization_constants [index ].default_value ;
34143454 } break ;
34153455 case SpvOpSpecConstantFalse : {
3416- p_module -> specialization_constants [index ].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL ;
3417- p_module -> specialization_constants [index ].default_value .int_bool_value = 0 ;
3418- p_module -> specialization_constants [index ].current_value .int_bool_value = 0 ;
3456+ p_module -> specialization_constants [index ].default_value .type = SPV_REFLECT_SCALAR_TYPE_BOOL ;
3457+ p_module -> specialization_constants [index ].default_value .value .int_bool_value = 1 ;
3458+ p_module -> specialization_constants [index ].default_value .bit_size = 1 ;
3459+ p_module -> specialization_constants [index ].default_value .is_signed = 0 ;
3460+ p_module -> specialization_constants [index ].current_value = p_module -> specialization_constants [index ].default_value ;
34193461 } break ;
34203462 case SpvOpSpecConstant : {
34213463 SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS ;
3422- uint32_t element_type_id = (uint32_t )INVALID_VALUE ;
3423- uint32_t default_value = 0 ;
3424- IF_READU32 (result , p_parser , p_node -> word_offset + 1 , element_type_id );
3425- // only support 32 bit arguments here...
3426- IF_READU32 (result , p_parser , p_node -> word_offset + 3 , default_value );
3427-
3428- SpvReflectPrvNode * p_next_node = FindNode (p_parser , element_type_id );
3429- if (IsNull (p_next_node )){
3430- return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE ;
3431- }
3432- if (p_next_node -> op == SpvOpTypeInt ) {
3433- p_module -> specialization_constants [index ].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_INT ;
3434- } else if (p_next_node -> op == SpvOpTypeFloat ) {
3435- p_module -> specialization_constants [index ].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT ;
3436- } else {
3437- return SPV_REFLECT_RESULT_ERROR_PARSE_FAILED ;
3438- }
3439-
3440- p_module -> specialization_constants [index ].default_value .int_bool_value = default_value ; //bits are the same for int and float
3441- p_module -> specialization_constants [index ].current_value .int_bool_value = default_value ;
3464+ SpvReflectScalarValue default_value = { 0 };
3465+ result = GetScalarConstant (p_parser ,p_module , p_node , & default_value );
3466+ if (result != SPV_REFLECT_RESULT_SUCCESS ) return result ;
3467+ p_module -> specialization_constants [index ].default_value = default_value ;
3468+ p_module -> specialization_constants [index ].current_value = p_module -> specialization_constants [index ].default_value ;
34423469 } break ;
34433470 }
34443471 // spec constant id cannot be the same, at least for valid values. (invalid value is just constant?)
@@ -5235,57 +5262,40 @@ SpvReflectResult GetSpecContantById(SpvReflectShaderModule* p_module, uint32_t c
52355262}
52365263
52375264// used for calculating specialization constants.
5238- // maybe check for recursion?
5239- SpvReflectResult EvaluateResult (SpvReflectShaderModule * p_module , uint32_t result_id , SpvReflectScalarValue * result )
5265+ SpvReflectResult EvaluateResultImpl (SpvReflectShaderModule * p_module , uint32_t result_id , SpvReflectScalarValue * result , uint32_t maxRecursion )
52405266{
5241- if (!result || !p_module ) {
5242- return SPV_REFLECT_RESULT_ERROR_NULL_POINTER ;
5243- }
5244- if ((p_module -> _internal -> module_flags & SPV_REFLECT_MODULE_FLAG_EVALUATE_SPEC_CONSTANT )== 0 ){
5245- return SPV_REFLECT_RESULT_ERROR_PARSE_FAILED ;
5246- }
5267+ if (!maxRecursion ) return SPV_REFLECT_RESULT_ERROR_SPIRV_RECURSION ;
52475268 SpvReflectResult res ;
52485269 SpvReflectPrvParser * p_parser = p_module -> _internal -> parser ;
52495270 SpvReflectPrvNode * p_node = FindNode (p_parser , result_id );
5250- if (!p_node ) return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE ;
5271+ if (!p_node ) return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE ;
52515272 switch (p_node -> op ) {
52525273 default :
52535274 return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION ;
52545275 case SpvOpConstant :
52555276 CONSTANT_RESULT :
5256- /*switch(){
5257-
5258- }*/
5259- // read constant value
5260- return SPV_REFLECT_RESULT_SUCCESS ;
5261- case SpvOpSpecConstant : {
5277+ return GetScalarConstant (p_parser , p_module , p_node , result );
5278+
5279+ case SpvOpSpecConstant :
5280+ {
52625281 if (p_node -> decorations .specialization_constant .value == (uint32_t )INVALID_VALUE ) {
52635282 goto CONSTANT_RESULT ;
52645283 }
52655284 SpvReflectSpecializationConstant * p_constant ;
52665285 res = GetSpecContantById (p_module , p_node -> decorations .specialization_constant .value , & p_constant );
5267- if (res != SPV_REFLECT_RESULT_SUCCESS ) return res ;
5268- if (p_constant -> constant_type == SPV_REFLECT_SPECIALIZATION_CONSTANT_INT ||
5269- p_constant -> constant_type == SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL ){
5270- result -> int_bool_value = p_constant -> current_value .int_bool_value ;
5271- }
5272- else if (p_constant -> constant_type == SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT ) {
5273- result -> float_value = p_constant -> current_value .float_value ;
5274- }
5275- else {
5276- return SPV_REFLECT_RESULT_ERROR_INTERNAL_ERROR ;
5277- }
5286+ if (res != SPV_REFLECT_RESULT_SUCCESS ) return res ;
5287+ * result = p_constant -> current_value ;
52785288 }
52795289 return SPV_REFLECT_RESULT_SUCCESS ;
52805290 case SpvOpSpecConstantComposite :
52815291 {
5282-
5292+ // only support scalar types for now...
52835293 }
52845294 return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION ;
52855295 case SpvOpUndef :
52865296 {
52875297 // invalid data should be handled separately?
5288- result -> int_bool_value = (uint32_t )INVALID_VALUE ;
5298+ result -> value . int_bool_value = (uint32_t )INVALID_VALUE ;
52895299 }
52905300 case SpvOpSConvert : case SpvOpUConvert : case SpvOpFConvert :
52915301 case SpvOpSNegate : case SpvOpNot : case SpvOpIAdd : case SpvOpISub :
@@ -5305,19 +5315,19 @@ SpvReflectResult EvaluateResult(SpvReflectShaderModule* p_module, uint32_t resul
53055315 // add implementations here.
53065316 return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION ;
53075317
5308- // check shader capability... vulkan should assume this...
5318+ // check shader capability... vulkan should assume this...
53095319 case SpvOpQuantizeToF16 :
53105320 return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION ;
53115321 break ;
53125322
5313- // check kernel capability... vulkan have none currently...
5323+ // check kernel capability... vulkan have none currently...
53145324 case SpvOpConvertFToS : case SpvOpConvertSToF :
53155325 case SpvOpConvertFToU : case SpvOpConvertUToF :
53165326 case SpvOpConvertPtrToU : case SpvOpConvertUToPtr :
53175327 case SpvOpGenericCastToPtr : case SpvOpPtrCastToGeneric :
5318- case SpvOpBitcast : case SpvOpFNegate :
5328+ case SpvOpBitcast : case SpvOpFNegate :
53195329 case SpvOpFAdd : case SpvOpFSub : case SpvOpFMul : case SpvOpFDiv :
5320- case SpvOpFRem : case SpvOpFMod :
5330+ case SpvOpFRem : case SpvOpFMod :
53215331 case SpvOpAccessChain : case SpvOpInBoundsAccessChain :
53225332 case SpvOpPtrAccessChain : case SpvOpInBoundsPtrAccessChain :
53235333 return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION ;
@@ -5326,3 +5336,16 @@ SpvReflectResult EvaluateResult(SpvReflectShaderModule* p_module, uint32_t resul
53265336}
53275337
53285338
5339+ SpvReflectResult EvaluateResult (SpvReflectShaderModule * p_module , uint32_t result_id , SpvReflectScalarValue * result )
5340+ {
5341+ if (!result || !p_module ) {
5342+ return SPV_REFLECT_RESULT_ERROR_NULL_POINTER ;
5343+ }
5344+ if ((p_module -> _internal -> module_flags & SPV_REFLECT_MODULE_FLAG_EVALUATE_SPEC_CONSTANT )== 0 ){
5345+ return SPV_REFLECT_RESULT_ERROR_PARSE_FAILED ;
5346+ }
5347+ // compute at most 100 instruction levels, maybe defining somewhere is better.
5348+ return EvaluateResultImpl (p_module , result_id , result , 100 );
5349+ }
5350+
5351+
0 commit comments