Skip to content

Commit c72837a

Browse files
committed
fix: uninitialized checks for arrays (handle bitcast)
1 parent 9ac7ffc commit c72837a

File tree

1 file changed

+35
-25
lines changed

1 file changed

+35
-25
lines changed

lib/Optimizer/GILPasses/DetectUninitializedPass.cpp

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,23 @@ class DetectUninitializedPass
115115
state[allocatedPtr] = MemoryState::Uninitialized;
116116
} else if (auto *ptrOffsets
117117
= llvm::dyn_cast<gil::PtrOffsetInst>(&inst)) {
118-
gil::Value basePtr = ptrOffsets->getBasePointer();
118+
gil::Value basePtr = ptrOffsets->getBasePtr();
119119
MemoryState baseState = state.lookup(basePtr);
120120
state[ptrOffsets->getResult(0)] = baseState;
121121
} else if (auto *structFieldPtr
122122
= llvm::dyn_cast<gil::StructFieldPtrInst>(&inst)) {
123-
gil::Value basePtr = structFieldPtr->getStructValue();
123+
gil::Value basePtr = structFieldPtr->getStructPtr();
124124
MemoryState baseState = state.lookup(basePtr);
125125
state[structFieldPtr->getResult(0)] = baseState;
126+
} else if (auto *bitcastInst
127+
= llvm::dyn_cast<gil::BitcastInst>(&inst)) {
128+
auto source = bitcastInst->getOperand();
129+
auto result = bitcastInst->getResult(0);
130+
131+
if (llvm::isa<glu::types::PointerTy>(&*source.getType())
132+
&& llvm::isa<glu::types::PointerTy>(&*result.getType())) {
133+
state[result] = state.lookup(source);
134+
}
126135
}
127136
}
128137
}
@@ -267,7 +276,7 @@ class DetectUninitializedPass
267276
= llvm::dyn_cast_or_null<gil::StructFieldPtrInst>(
268277
destPtr.getDefiningInstruction()
269278
)) {
270-
gil::Value baseStruct = structFieldPtr->getStructValue();
279+
gil::Value baseStruct = structFieldPtr->getStructPtr();
271280
currentState[baseStruct] = MemoryState::Initialized;
272281
}
273282
}
@@ -278,8 +287,8 @@ class DetectUninitializedPass
278287

279288
MemoryState state = currentState.lookup(srcPtr);
280289
if (!currentState.contains(srcPtr)) {
281-
state = MemoryState::Initialized; // Default to initialized for
282-
// unknown values
290+
// Default to initialized for untracked values
291+
state = MemoryState::Initialized;
283292
}
284293

285294
if (state != MemoryState::Initialized) {
@@ -288,10 +297,8 @@ class DetectUninitializedPass
288297
);
289298
}
290299

291-
if (load->getResultCount() > 0) {
292-
gil::Value loadedValue = load->getResult(0);
293-
currentState[loadedValue] = state;
294-
}
300+
gil::Value loadedValue = load->getResult(0);
301+
currentState[loadedValue] = state;
295302

296303
if (load->getOwnershipKind() == gil::LoadOwnershipKind::Take) {
297304
currentState[srcPtr] = MemoryState::Uninitialized;
@@ -300,19 +307,13 @@ class DetectUninitializedPass
300307

301308
void visitAllocaInst(gil::AllocaInst *alloca)
302309
{
303-
if (alloca->getResultCount() > 0) {
304-
gil::Value allocatedPtr = alloca->getResult(0);
305-
currentState[allocatedPtr] = MemoryState::Uninitialized;
306-
}
310+
gil::Value allocatedPtr = alloca->getResult(0);
311+
currentState[allocatedPtr] = MemoryState::Uninitialized;
307312
}
308313

309314
void visitPtrOffsetInst(gil::PtrOffsetInst *inst)
310315
{
311-
if (inst->getResultCount() == 0) {
312-
return;
313-
}
314-
315-
gil::Value basePtr = inst->getBasePointer();
316+
gil::Value basePtr = inst->getBasePtr();
316317
MemoryState baseState = getTrackedStateOrDefault(
317318
basePtr, currentState, MemoryState::Uninitialized
318319
);
@@ -323,11 +324,7 @@ class DetectUninitializedPass
323324

324325
void visitStructFieldPtrInst(gil::StructFieldPtrInst *inst)
325326
{
326-
if (inst->getResultCount() == 0) {
327-
return;
328-
}
329-
330-
gil::Value basePtr = inst->getStructValue();
327+
gil::Value basePtr = inst->getStructPtr();
331328
MemoryState baseState = getTrackedStateOrDefault(
332329
basePtr, currentState, MemoryState::Uninitialized
333330
);
@@ -336,12 +333,25 @@ class DetectUninitializedPass
336333
currentState[resultPtr] = baseState;
337334
}
338335

339-
void visitStructExtractInst(gil::StructExtractInst *inst)
336+
void visitBitcastInst(gil::BitcastInst *inst)
340337
{
341-
if (inst->getResultCount() == 0) {
338+
auto value = inst->getOperand();
339+
auto result = inst->getResult(0);
340+
341+
if (!llvm::isa<glu::types::PointerTy>(&*value.getType())
342+
|| !llvm::isa<glu::types::PointerTy>(&*result.getType())) {
342343
return;
343344
}
344345

346+
MemoryState sourceState = getTrackedStateOrDefault(
347+
value, currentState, MemoryState::Uninitialized
348+
);
349+
350+
currentState[result] = sourceState;
351+
}
352+
353+
void visitStructExtractInst(gil::StructExtractInst *inst)
354+
{
345355
gil::Value fieldValue = inst->getResult(0);
346356
currentState[fieldValue] = MemoryState::Initialized;
347357
}

0 commit comments

Comments
 (0)