Skip to content

Commit e1dfa44

Browse files
committed
Add 64 bit type support, cleanup value clutter
Next step would be the constant evaluator
1 parent 4937236 commit e1dfa44

File tree

3 files changed

+150
-81
lines changed

3 files changed

+150
-81
lines changed

common/output_stream.cpp

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,18 +1047,53 @@ void StreamWriteSpecializationConstant(std::ostream& os, const SpvReflectSpecial
10471047
os << t << "constant id: " << obj.constant_id << "\n";
10481048
os << t << "name : " << (obj.name != NULL ? obj.name : "") << '\n';
10491049
os << t << "type : ";
1050-
switch (obj.constant_type) {
1051-
case SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL:
1050+
switch (obj.default_value.type) {
1051+
case SPV_REFLECT_SCALAR_TYPE_BOOL:
10521052
os << "boolean\n";
1053-
os << t << "default : " << obj.default_value.int_bool_value;
1053+
os << t << "default : " << obj.default_value.value.int_bool_value;
10541054
break;
1055-
case SPV_REFLECT_SPECIALIZATION_CONSTANT_INT:
1056-
os << "integer\n";
1057-
os << t << "default : "<<obj.default_value.int_bool_value;
1055+
case SPV_REFLECT_SCALAR_TYPE_INT:
1056+
if (obj.default_value.is_signed) {
1057+
os << "signed ";
1058+
}
1059+
else {
1060+
os << "unsigned ";
1061+
}
1062+
os<<obj.default_value.bit_size <<" bit integer\n";
1063+
os << t << "default : ";
1064+
// let's assume only 32 bit and 64 bit types (no 8 and 16 bit types here)
1065+
if (obj.default_value.bit_size == 32) {
1066+
if (obj.default_value.is_signed) {
1067+
os << (int32_t)obj.default_value.value.int_bool_value;
1068+
}
1069+
else {
1070+
os << (uint32_t)obj.default_value.value.int_bool_value;
1071+
}
1072+
}
1073+
else if(obj.default_value.bit_size == 64){
1074+
if (obj.default_value.is_signed) {
1075+
os << (int64_t)obj.default_value.value.int64_value;
1076+
}
1077+
else {
1078+
os << (uint32_t)obj.default_value.value.int64_value;
1079+
}
1080+
}
1081+
else {
1082+
os << "default value not native in c/cpp";
1083+
}
10581084
break;
1059-
case SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT:
1060-
os << "float\n";
1061-
os << t << "default : " << obj.default_value.float_value;
1085+
case SPV_REFLECT_SCALAR_TYPE_FLOAT:
1086+
os << obj.default_value.bit_size << " bit floating point\n";
1087+
os << t << "default : ";
1088+
if (obj.default_value.bit_size == 32) {
1089+
os << obj.default_value.value.float_value;
1090+
}
1091+
else if (obj.default_value.bit_size == 64) {
1092+
os << obj.default_value.value.float64_value;
1093+
}
1094+
else {
1095+
os << "default value not native in c/cpp";
1096+
}
10621097
break;
10631098
default:
10641099
os << "unknown type";

spirv_reflect.c

Lines changed: 87 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,12 +1818,6 @@ static SpvReflectResult ParseType(
18181818
p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_EXTERNAL_ACCELERATION_STRUCTURE;
18191819
}
18201820
break;
1821-
1822-
case SpvOpSpecConstantTrue:
1823-
case SpvOpSpecConstantFalse:
1824-
case SpvOpSpecConstant: {
1825-
}
1826-
break;
18271821
}
18281822

18291823
if (result == SPV_REFLECT_RESULT_SUCCESS) {
@@ -3383,6 +3377,50 @@ static SpvReflectResult ParseExecutionModes(
33833377
return SPV_REFLECT_RESULT_SUCCESS;
33843378
}
33853379

3380+
SpvReflectResult GetTypeByTypeId(SpvReflectShaderModule* p_module, uint32_t type_id, SpvReflectTypeDescription** pp_type)
3381+
{
3382+
SpvReflectResult res = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
3383+
for (uint32_t i = 0; i < p_module->_internal->type_description_count; ++i) {
3384+
if (p_module->_internal->type_descriptions[i].id == type_id) {
3385+
*pp_type = &p_module->_internal->type_descriptions[i];
3386+
return SPV_REFLECT_RESULT_SUCCESS;
3387+
}
3388+
}
3389+
return res;
3390+
}
3391+
3392+
static SpvReflectResult GetScalarConstant(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module, SpvReflectPrvNode* p_node, SpvReflectScalarValue* result)
3393+
{
3394+
SpvReflectTypeDescription* type;
3395+
SpvReflectResult res = GetTypeByTypeId(p_module, p_node->result_type_id, &type);
3396+
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;
3397+
if(type->type_flags & SPV_REFLECT_TYPE_FLAG_INT){
3398+
result->type = SPV_REFLECT_SCALAR_TYPE_INT;
3399+
}
3400+
else if (type->type_flags & SPV_REFLECT_TYPE_FLAG_FLOAT) {
3401+
result->type = SPV_REFLECT_SCALAR_TYPE_FLOAT;
3402+
}
3403+
else{
3404+
result->type = SPV_REFLECT_SCALAR_TYPE_UNKNOWN;
3405+
}
3406+
result->bit_size = type->traits.numeric.scalar.width;
3407+
uint32_t low_word;
3408+
CHECKED_READU32(p_parser, p_node->word_offset + 3, low_word);
3409+
if (type->traits.numeric.scalar.width == 32) {
3410+
result->value.int_bool_value = low_word;
3411+
return SPV_REFLECT_RESULT_SUCCESS;
3412+
}
3413+
else if (type->traits.numeric.scalar.width ==64) {
3414+
uint32_t high_word;
3415+
CHECKED_READU32(p_parser, p_node->word_offset + 4, high_word);
3416+
result->value.int64_value = low_word | (((uint64_t)high_word) << 32);
3417+
return SPV_REFLECT_RESULT_SUCCESS;
3418+
}
3419+
else {
3420+
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE;
3421+
}
3422+
}
3423+
33863424
static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module)
33873425
{
33883426
p_module->specialization_constant_count = 0;
@@ -3408,37 +3446,26 @@ static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_pars
34083446
switch(p_node->op) {
34093447
default: continue;
34103448
case SpvOpSpecConstantTrue: {
3411-
p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL;
3412-
p_module->specialization_constants[index].default_value.int_bool_value = 1;
3413-
p_module->specialization_constants[index].current_value.int_bool_value = 1;
3449+
p_module->specialization_constants[index].default_value.type = SPV_REFLECT_SCALAR_TYPE_BOOL;
3450+
p_module->specialization_constants[index].default_value.value.int_bool_value = 1;
3451+
p_module->specialization_constants[index].default_value.bit_size = 1;
3452+
p_module->specialization_constants[index].default_value.is_signed = 0;
3453+
p_module->specialization_constants[index].current_value = p_module->specialization_constants[index].default_value;
34143454
} break;
34153455
case SpvOpSpecConstantFalse: {
3416-
p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL;
3417-
p_module->specialization_constants[index].default_value.int_bool_value = 0;
3418-
p_module->specialization_constants[index].current_value.int_bool_value = 0;
3456+
p_module->specialization_constants[index].default_value.type = SPV_REFLECT_SCALAR_TYPE_BOOL;
3457+
p_module->specialization_constants[index].default_value.value.int_bool_value = 1;
3458+
p_module->specialization_constants[index].default_value.bit_size = 1;
3459+
p_module->specialization_constants[index].default_value.is_signed = 0;
3460+
p_module->specialization_constants[index].current_value = p_module->specialization_constants[index].default_value;
34193461
} break;
34203462
case SpvOpSpecConstant: {
34213463
SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS;
3422-
uint32_t element_type_id = (uint32_t)INVALID_VALUE;
3423-
uint32_t default_value = 0;
3424-
IF_READU32(result, p_parser, p_node->word_offset + 1, element_type_id);
3425-
// only support 32 bit arguments here...
3426-
IF_READU32(result, p_parser, p_node->word_offset + 3, default_value);
3427-
3428-
SpvReflectPrvNode* p_next_node = FindNode(p_parser, element_type_id);
3429-
if(IsNull(p_next_node)){
3430-
return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
3431-
}
3432-
if (p_next_node->op == SpvOpTypeInt) {
3433-
p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_INT;
3434-
} else if (p_next_node->op == SpvOpTypeFloat) {
3435-
p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT;
3436-
} else {
3437-
return SPV_REFLECT_RESULT_ERROR_PARSE_FAILED;
3438-
}
3439-
3440-
p_module->specialization_constants[index].default_value.int_bool_value = default_value; //bits are the same for int and float
3441-
p_module->specialization_constants[index].current_value.int_bool_value = default_value;
3464+
SpvReflectScalarValue default_value = { 0 };
3465+
result = GetScalarConstant(p_parser,p_module, p_node, &default_value);
3466+
if (result != SPV_REFLECT_RESULT_SUCCESS) return result;
3467+
p_module->specialization_constants[index].default_value = default_value;
3468+
p_module->specialization_constants[index].current_value = p_module->specialization_constants[index].default_value;
34423469
} break;
34433470
}
34443471
// spec constant id cannot be the same, at least for valid values. (invalid value is just constant?)
@@ -5235,57 +5262,40 @@ SpvReflectResult GetSpecContantById(SpvReflectShaderModule* p_module, uint32_t c
52355262
}
52365263

52375264
// used for calculating specialization constants.
5238-
// maybe check for recursion?
5239-
SpvReflectResult EvaluateResult(SpvReflectShaderModule* p_module, uint32_t result_id, SpvReflectScalarValue* result)
5265+
SpvReflectResult EvaluateResultImpl(SpvReflectShaderModule* p_module, uint32_t result_id, SpvReflectScalarValue* result, uint32_t maxRecursion)
52405266
{
5241-
if (!result || !p_module) {
5242-
return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
5243-
}
5244-
if((p_module->_internal->module_flags & SPV_REFLECT_MODULE_FLAG_EVALUATE_SPEC_CONSTANT)==0){
5245-
return SPV_REFLECT_RESULT_ERROR_PARSE_FAILED;
5246-
}
5267+
if(!maxRecursion) return SPV_REFLECT_RESULT_ERROR_SPIRV_RECURSION;
52475268
SpvReflectResult res;
52485269
SpvReflectPrvParser* p_parser = p_module->_internal->parser;
52495270
SpvReflectPrvNode* p_node = FindNode(p_parser, result_id);
5250-
if(!p_node) return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
5271+
if (!p_node) return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
52515272
switch (p_node->op) {
52525273
default:
52535274
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
52545275
case SpvOpConstant:
52555276
CONSTANT_RESULT:
5256-
/*switch(){
5257-
5258-
}*/
5259-
// read constant value
5260-
return SPV_REFLECT_RESULT_SUCCESS;
5261-
case SpvOpSpecConstant: {
5277+
return GetScalarConstant(p_parser, p_module, p_node, result);
5278+
5279+
case SpvOpSpecConstant:
5280+
{
52625281
if (p_node->decorations.specialization_constant.value == (uint32_t)INVALID_VALUE) {
52635282
goto CONSTANT_RESULT;
52645283
}
52655284
SpvReflectSpecializationConstant* p_constant;
52665285
res = GetSpecContantById(p_module, p_node->decorations.specialization_constant.value, &p_constant);
5267-
if(res != SPV_REFLECT_RESULT_SUCCESS) return res;
5268-
if(p_constant->constant_type== SPV_REFLECT_SPECIALIZATION_CONSTANT_INT ||
5269-
p_constant->constant_type==SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL){
5270-
result->int_bool_value = p_constant->current_value.int_bool_value;
5271-
}
5272-
else if (p_constant->constant_type == SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT) {
5273-
result->float_value = p_constant->current_value.float_value;
5274-
}
5275-
else {
5276-
return SPV_REFLECT_RESULT_ERROR_INTERNAL_ERROR;
5277-
}
5286+
if (res != SPV_REFLECT_RESULT_SUCCESS) return res;
5287+
*result = p_constant->current_value;
52785288
}
52795289
return SPV_REFLECT_RESULT_SUCCESS;
52805290
case SpvOpSpecConstantComposite:
52815291
{
5282-
5292+
// only support scalar types for now...
52835293
}
52845294
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
52855295
case SpvOpUndef:
52865296
{
52875297
// invalid data should be handled separately?
5288-
result->int_bool_value = (uint32_t)INVALID_VALUE;
5298+
result->value.int_bool_value = (uint32_t)INVALID_VALUE;
52895299
}
52905300
case SpvOpSConvert: case SpvOpUConvert: case SpvOpFConvert:
52915301
case SpvOpSNegate: case SpvOpNot: case SpvOpIAdd: case SpvOpISub:
@@ -5305,19 +5315,19 @@ SpvReflectResult EvaluateResult(SpvReflectShaderModule* p_module, uint32_t resul
53055315
// add implementations here.
53065316
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
53075317

5308-
// check shader capability... vulkan should assume this...
5318+
// check shader capability... vulkan should assume this...
53095319
case SpvOpQuantizeToF16:
53105320
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
53115321
break;
53125322

5313-
// check kernel capability... vulkan have none currently...
5323+
// check kernel capability... vulkan have none currently...
53145324
case SpvOpConvertFToS: case SpvOpConvertSToF:
53155325
case SpvOpConvertFToU: case SpvOpConvertUToF:
53165326
case SpvOpConvertPtrToU: case SpvOpConvertUToPtr:
53175327
case SpvOpGenericCastToPtr: case SpvOpPtrCastToGeneric:
5318-
case SpvOpBitcast: case SpvOpFNegate:
5328+
case SpvOpBitcast: case SpvOpFNegate:
53195329
case SpvOpFAdd: case SpvOpFSub: case SpvOpFMul: case SpvOpFDiv:
5320-
case SpvOpFRem: case SpvOpFMod:
5330+
case SpvOpFRem: case SpvOpFMod:
53215331
case SpvOpAccessChain: case SpvOpInBoundsAccessChain:
53225332
case SpvOpPtrAccessChain: case SpvOpInBoundsPtrAccessChain:
53235333
return SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION;
@@ -5326,3 +5336,16 @@ SpvReflectResult EvaluateResult(SpvReflectShaderModule* p_module, uint32_t resul
53265336
}
53275337

53285338

5339+
SpvReflectResult EvaluateResult(SpvReflectShaderModule* p_module, uint32_t result_id, SpvReflectScalarValue* result)
5340+
{
5341+
if (!result || !p_module) {
5342+
return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
5343+
}
5344+
if((p_module->_internal->module_flags & SPV_REFLECT_MODULE_FLAG_EVALUATE_SPEC_CONSTANT)==0){
5345+
return SPV_REFLECT_RESULT_ERROR_PARSE_FAILED;
5346+
}
5347+
// compute at most 100 instruction levels, maybe defining somewhere is better.
5348+
return EvaluateResultImpl(p_module, result_id, result, 100);
5349+
}
5350+
5351+

spirv_reflect.h

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ typedef enum SpvReflectResult {
8080
SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ENTRY_POINT,
8181
SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_EXECUTION_MODE,
8282
SPV_REFLECT_RESULT_ERROR_SPIRV_DUPLICATE_SPEC_CONSTANT_NAME,
83-
SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION
83+
SPV_REFLECT_RESULT_ERROR_SPIRV_UNRESOLVED_EVALUATION,
84+
SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_TYPE
8485
} SpvReflectResult;
8586

8687
/*! @enum SpvReflectModuleFlagBits
@@ -340,21 +341,31 @@ typedef struct SpvReflectTypeDescription {
340341
341342
*/
342343
typedef enum SpvReflectSpecializationConstantType {
343-
SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL = 0,
344-
SPV_REFLECT_SPECIALIZATION_CONSTANT_INT = 1,
345-
SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT = 2,
344+
SPV_REFLECT_SCALAR_TYPE_UNKNOWN = 0,
345+
SPV_REFLECT_SCALAR_TYPE_BOOL = 1,
346+
SPV_REFLECT_SCALAR_TYPE_INT = 2,
347+
SPV_REFLECT_SCALAR_TYPE_FLOAT = 3,
346348
} SpvReflectSpecializationConstantType;
347349

348-
typedef union SpvReflectScalarValue {
349-
float float_value;
350-
uint32_t int_bool_value;
350+
// using union may have alignment issues on certain platforms
351+
// having type info here helps evaluating results
352+
typedef struct SpvReflectScalarValue {
353+
union {
354+
float float_value;
355+
uint32_t int_bool_value;
356+
// c/cpp doesn't have alignment requirements
357+
double float64_value;
358+
uint64_t int64_value;
359+
} value ;
360+
SpvReflectSpecializationConstantType type;
361+
int is_signed;
362+
int bit_size;
351363
} SpvReflectScalarValue;
352364

353365
typedef struct SpvReflectSpecializationConstant {
354366
const char* name;
355367
uint32_t spirv_id;
356368
uint32_t constant_id;
357-
SpvReflectSpecializationConstantType constant_type;
358369
SpvReflectScalarValue default_value;
359370
SpvReflectScalarValue current_value;
360371
} SpvReflectSpecializationConstant;

0 commit comments

Comments
 (0)