@@ -45,143 +45,154 @@ void DxilPIXAddTidToAmplificationShaderPayload::applyOptions(PassOptions O) {
4545}
4646
4747void AddValueToExpandedPayload (OP *HlslOP, llvm::IRBuilder<> &B,
48- ExpandedStruct &expanded,
4948 AllocaInst *NewStructAlloca,
5049 unsigned int expandedValueIndex, Value *value) {
5150 Constant *Zero32Arg = HlslOP->GetU32Const (0 );
5251 SmallVector<Value *, 2 > IndexToAppendedValue;
5352 IndexToAppendedValue.push_back (Zero32Arg);
5453 IndexToAppendedValue.push_back (HlslOP->GetU32Const (expandedValueIndex));
5554 auto *PointerToEmbeddedNewValue = B.CreateInBoundsGEP (
56- expanded. ExpandedPayloadStructType , NewStructAlloca, IndexToAppendedValue,
55+ NewStructAlloca, IndexToAppendedValue,
5756 " PointerToEmbeddedNewValue" + std::to_string (expandedValueIndex));
5857 B.CreateStore (value, PointerToEmbeddedNewValue);
5958}
6059
61- bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule (Module &M) {
60+ void CopyAggregate (IRBuilder<> &B, Type *Ty, Value *Source, Value *Dest,
61+ ArrayRef<Value *> GEPIndices) {
62+ if (StructType *ST = dyn_cast<StructType>(Ty)) {
63+ SmallVector<Value *, 16 > StructIndices;
64+ StructIndices.append (GEPIndices.begin (), GEPIndices.end ());
65+ StructIndices.push_back (nullptr );
66+ for (unsigned j = 0 ; j < ST->getNumElements (); ++j) {
67+ StructIndices.back () = B.getInt32 (j);
68+ CopyAggregate (B, ST->getElementType (j), Source, Dest, StructIndices);
69+ }
70+ } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
71+ SmallVector<Value *, 16 > StructIndices;
72+ StructIndices.append (GEPIndices.begin (), GEPIndices.end ());
73+ StructIndices.push_back (nullptr );
74+ for (unsigned j = 0 ; j < AT->getNumElements (); ++j) {
75+ StructIndices.back () = B.getInt32 (j);
76+ CopyAggregate (B, AT->getArrayElementType (), Source, Dest, StructIndices);
77+ }
78+ } else {
79+ auto *SourceGEP = B.CreateGEP (Source, GEPIndices, " CopyStructSourceGEP" );
80+ Value *Val = B.CreateLoad (SourceGEP, " CopyStructLoad" );
81+ auto *DestGEP = B.CreateGEP (Dest, GEPIndices, " CopyStructDestGEP" );
82+ B.CreateStore (Val, DestGEP, " CopyStructStore" );
83+ }
84+ }
6285
86+ bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule (Module &M) {
6387 DxilModule &DM = M.GetOrCreateDxilModule ();
6488 LLVMContext &Ctx = M.getContext ();
6589 OP *HlslOP = DM.GetOP ();
66-
67- Type *OriginalPayloadStructPointerType = nullptr ;
68- Type *OriginalPayloadStructType = nullptr ;
69- ExpandedStruct expanded;
7090 llvm::Function *entryFunction = PIXPassHelpers::GetEntryFunction (DM);
7191 for (inst_iterator I = inst_begin (entryFunction), E = inst_end (entryFunction);
7292 I != E; ++I) {
73- if (auto *Instr = llvm::cast<Instruction>(&*I)) {
74- if (hlsl::OP::IsDxilOpFuncCallInst (Instr,
75- hlsl::OP::OpCode::DispatchMesh)) {
76- DxilInst_DispatchMesh DispatchMesh (Instr);
77- OriginalPayloadStructPointerType =
78- DispatchMesh.get_payload ()->getType ();
79- OriginalPayloadStructType =
80- OriginalPayloadStructPointerType->getPointerElementType ();
81- expanded = ExpandStructType (Ctx, OriginalPayloadStructType);
82- }
83- }
84- }
85-
86- AllocaInst *OldStructAlloca = nullptr ;
87- AllocaInst *NewStructAlloca = nullptr ;
88- std::vector<AllocaInst *> allocasOfPayloadType;
89- for (inst_iterator I = inst_begin (entryFunction), E = inst_end (entryFunction);
90- I != E; ++I) {
91- auto *Inst = &*I;
92- if (llvm::isa<AllocaInst>(Inst)) {
93- auto *Alloca = llvm::cast<AllocaInst>(Inst);
94- if (Alloca->getType () == OriginalPayloadStructPointerType) {
95- allocasOfPayloadType.push_back (Alloca);
96- }
93+ if (hlsl::OP::IsDxilOpFuncCallInst (&*I, hlsl::OP::OpCode::DispatchMesh)) {
94+ DxilInst_DispatchMesh DispatchMesh (&*I);
95+ Type *OriginalPayloadStructPointerType =
96+ DispatchMesh.get_payload ()->getType ();
97+ Type *OriginalPayloadStructType =
98+ OriginalPayloadStructPointerType->getPointerElementType ();
99+ ExpandedStruct expanded =
100+ ExpandStructType (Ctx, OriginalPayloadStructType);
101+
102+ llvm::IRBuilder<> B (&*I);
103+
104+ auto *NewStructAlloca =
105+ B.CreateAlloca (expanded.ExpandedPayloadStructType ,
106+ HlslOP->GetU32Const (1 ), " NewPayload" );
107+ NewStructAlloca->setAlignment (4 );
108+ auto PayloadType =
109+ llvm::dyn_cast<PointerType>(DispatchMesh.get_payload ()->getType ());
110+ SmallVector<Value *, 16 > GEPIndices;
111+ GEPIndices.push_back (B.getInt32 (0 ));
112+ CopyAggregate (B, PayloadType->getPointerElementType (),
113+ DispatchMesh.get_payload (), NewStructAlloca, GEPIndices);
114+
115+ Constant *Zero32Arg = HlslOP->GetU32Const (0 );
116+ Constant *One32Arg = HlslOP->GetU32Const (1 );
117+ Constant *Two32Arg = HlslOP->GetU32Const (2 );
118+
119+ auto GroupIdFunc =
120+ HlslOP->GetOpFunc (DXIL::OpCode::GroupId, Type::getInt32Ty (Ctx));
121+ Constant *GroupIdOpcode =
122+ HlslOP->GetU32Const ((unsigned )DXIL::OpCode::GroupId);
123+ auto *GroupIdX =
124+ B.CreateCall (GroupIdFunc, {GroupIdOpcode, Zero32Arg}, " GroupIdX" );
125+ auto *GroupIdY =
126+ B.CreateCall (GroupIdFunc, {GroupIdOpcode, One32Arg}, " GroupIdY" );
127+ auto *GroupIdZ =
128+ B.CreateCall (GroupIdFunc, {GroupIdOpcode, Two32Arg}, " GroupIdZ" );
129+
130+ // FlatGroupID = z + y*numZ + x*numY*numZ
131+ // Where x,y,z are the group ID components, and numZ and numY are the
132+ // corresponding AS group-count arguments to the DispatchMesh Direct3D API
133+ auto *GroupYxNumZ = B.CreateMul (
134+ GroupIdY, HlslOP->GetU32Const (m_DispatchArgumentZ), " GroupYxNumZ" );
135+ auto *FlatGroupNumZY =
136+ B.CreateAdd (GroupIdZ, GroupYxNumZ, " FlatGroupNumZY" );
137+ auto *GroupXxNumYZ = B.CreateMul (
138+ GroupIdX,
139+ HlslOP->GetU32Const (m_DispatchArgumentY * m_DispatchArgumentZ),
140+ " GroupXxNumYZ" );
141+ auto *FlatGroupID =
142+ B.CreateAdd (GroupXxNumYZ, FlatGroupNumZY, " FlatGroupID" );
143+
144+ // The ultimate goal is a single unique thread ID for this AS thread.
145+ // So take the flat group number, multiply it by the number of
146+ // threads per group...
147+ auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul (
148+ FlatGroupID,
149+ HlslOP->GetU32Const (DM.GetNumThreads (0 ) * DM.GetNumThreads (1 ) *
150+ DM.GetNumThreads (2 )),
151+ " FlatGroupIDWithSpaceForThreadInGroupId" );
152+
153+ auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc (
154+ DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty (Ctx));
155+ Constant *FlattenedThreadIdInGroupOpcode =
156+ HlslOP->GetU32Const ((unsigned )DXIL::OpCode::FlattenedThreadIdInGroup);
157+ auto FlatThreadIdInGroup = B.CreateCall (FlattenedThreadIdInGroupFunc,
158+ {FlattenedThreadIdInGroupOpcode},
159+ " FlattenedThreadIdInGroup" );
160+
161+ // ...and add the flat thread id:
162+ auto *FlatId = B.CreateAdd (FlatGroupIDWithSpaceForThreadInGroupId,
163+ FlatThreadIdInGroup, " FlatId" );
164+
165+ AddValueToExpandedPayload (
166+ HlslOP, B, NewStructAlloca,
167+ expanded.ExpandedPayloadStructType ->getStructNumElements () - 3 ,
168+ FlatId);
169+ AddValueToExpandedPayload (
170+ HlslOP, B, NewStructAlloca,
171+ expanded.ExpandedPayloadStructType ->getStructNumElements () - 2 ,
172+ DispatchMesh.get_threadGroupCountY ());
173+ AddValueToExpandedPayload (
174+ HlslOP, B, NewStructAlloca,
175+ expanded.ExpandedPayloadStructType ->getStructNumElements () - 1 ,
176+ DispatchMesh.get_threadGroupCountZ ());
177+
178+ auto DispatchMeshFn = HlslOP->GetOpFunc (
179+ DXIL::OpCode::DispatchMesh, expanded.ExpandedPayloadStructPtrType );
180+ Constant *DispatchMeshOpcode =
181+ HlslOP->GetU32Const ((unsigned )DXIL::OpCode::DispatchMesh);
182+ B.CreateCall (DispatchMeshFn,
183+ {DispatchMeshOpcode, DispatchMesh.get_threadGroupCountX (),
184+ DispatchMesh.get_threadGroupCountY (),
185+ DispatchMesh.get_threadGroupCountZ (), NewStructAlloca});
186+ I->removeFromParent ();
187+ delete &*I;
188+ // Validation requires exactly one DispatchMesh in an AS, so we can exit
189+ // after the first one:
190+ DM.ReEmitDxilResources ();
191+ return true ;
97192 }
98193 }
99- for (auto &Alloca : allocasOfPayloadType) {
100- OldStructAlloca = Alloca;
101- llvm::IRBuilder<> B (Alloca->getContext ());
102- NewStructAlloca = B.CreateAlloca (expanded.ExpandedPayloadStructType ,
103- HlslOP->GetU32Const (1 ), " NewPayload" );
104- NewStructAlloca->setAlignment (Alloca->getAlignment ());
105- NewStructAlloca->insertAfter (Alloca);
106-
107- ReplaceAllUsesOfInstructionWithNewValueAndDeleteInstruction (
108- Alloca, NewStructAlloca, expanded.ExpandedPayloadStructType );
109- }
110-
111- auto F = HlslOP->GetOpFunc (DXIL::OpCode::DispatchMesh,
112- expanded.ExpandedPayloadStructPtrType );
113- for (auto FI = F->user_begin (); FI != F->user_end ();) {
114- auto *FunctionUser = *FI++;
115- auto *UserInstruction = llvm::cast<Instruction>(FunctionUser);
116- DxilInst_DispatchMesh DispatchMesh (UserInstruction);
117-
118- llvm::IRBuilder<> B (UserInstruction);
119-
120- Constant *Zero32Arg = HlslOP->GetU32Const (0 );
121- Constant *One32Arg = HlslOP->GetU32Const (1 );
122- Constant *Two32Arg = HlslOP->GetU32Const (2 );
123-
124- auto GroupIdFunc =
125- HlslOP->GetOpFunc (DXIL::OpCode::GroupId, Type::getInt32Ty (Ctx));
126- Constant *GroupIdOpcode =
127- HlslOP->GetU32Const ((unsigned )DXIL::OpCode::GroupId);
128- auto *GroupIdX =
129- B.CreateCall (GroupIdFunc, {GroupIdOpcode, Zero32Arg}, " GroupIdX" );
130- auto *GroupIdY =
131- B.CreateCall (GroupIdFunc, {GroupIdOpcode, One32Arg}, " GroupIdY" );
132- auto *GroupIdZ =
133- B.CreateCall (GroupIdFunc, {GroupIdOpcode, Two32Arg}, " GroupIdZ" );
134-
135- // FlatGroupID = z + y*numZ + x*numY*numZ
136- // Where x,y,z are the group ID components, and numZ and numY are the
137- // corresponding AS group-count arguments to the DispatchMesh Direct3D API
138- auto *GroupYxNumZ = B.CreateMul (
139- GroupIdY, HlslOP->GetU32Const (m_DispatchArgumentZ), " GroupYxNumZ" );
140- auto *FlatGroupNumZY = B.CreateAdd (GroupIdZ, GroupYxNumZ, " FlatGroupNumZY" );
141- auto *GroupXxNumYZ = B.CreateMul (
142- GroupIdX,
143- HlslOP->GetU32Const (m_DispatchArgumentY * m_DispatchArgumentZ),
144- " GroupXxNumYZ" );
145- auto *FlatGroupID =
146- B.CreateAdd (GroupXxNumYZ, FlatGroupNumZY, " FlatGroFlatGroupIDupNum" );
147-
148- // The ultimate goal is a single unique thread ID for this AS thread.
149- // So take the flat group number, multiply it by the number of
150- // threads per group...
151- auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul (
152- FlatGroupID,
153- HlslOP->GetU32Const (DM.GetNumThreads (0 ) * DM.GetNumThreads (1 ) *
154- DM.GetNumThreads (2 )),
155- " FlatGroupIDWithSpaceForThreadInGroupId" );
156-
157- auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc (
158- DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty (Ctx));
159- Constant *FlattenedThreadIdInGroupOpcode =
160- HlslOP->GetU32Const ((unsigned )DXIL::OpCode::FlattenedThreadIdInGroup);
161- auto FlatThreadIdInGroup = B.CreateCall (FlattenedThreadIdInGroupFunc,
162- {FlattenedThreadIdInGroupOpcode},
163- " FlattenedThreadIdInGroup" );
164-
165- // ...and add the flat thread id:
166- auto *FlatId = B.CreateAdd (FlatGroupIDWithSpaceForThreadInGroupId,
167- FlatThreadIdInGroup, " FlatId" );
168-
169- AddValueToExpandedPayload (HlslOP, B, expanded, NewStructAlloca,
170- OriginalPayloadStructType->getStructNumElements (),
171- FlatId);
172- AddValueToExpandedPayload (
173- HlslOP, B, expanded, NewStructAlloca,
174- OriginalPayloadStructType->getStructNumElements () + 1 ,
175- DispatchMesh.get_threadGroupCountY ());
176- AddValueToExpandedPayload (
177- HlslOP, B, expanded, NewStructAlloca,
178- OriginalPayloadStructType->getStructNumElements () + 2 ,
179- DispatchMesh.get_threadGroupCountZ ());
180- }
181-
182- DM.ReEmitDxilResources ();
183194
184- return true ;
195+ return false ;
185196}
186197
187198char DxilPIXAddTidToAmplificationShaderPayload::ID = 0 ;
0 commit comments