85 #include "llvm/IR/IntrinsicsBPF.h"
95 #define DEBUG_TYPE "bpf-abstract-member-access"
105 M, Intrinsic::bpf_passthrough, {Input->getType(), Input->getType()});
115 using namespace llvm;
118 class BPFAbstractMemberAccess final {
131 typedef std::stack<std::pair<CallInst *, CallInfo>> CallInfoStack;
135 BPFPreserveArrayAI = 1,
136 BPFPreserveUnionAI = 2,
137 BPFPreserveStructAI = 3,
138 BPFPreserveFieldInfoAI = 4,
145 static std::map<std::string, GlobalVariable *> GEPGlobals;
147 std::map<CallInst *, std::pair<CallInst *, CallInfo>> AIChain;
151 std::map<CallInst *, CallInfo> BaseAICalls;
165 bool removePreserveAccessIndexIntrinsic(
Function &
F);
166 void replaceWithGEP(std::vector<CallInst *> &CallList,
168 bool HasPreserveFieldInfoCall(CallInfoStack &CallStack);
176 std::string &AccessKey,
MDNode *&BaseMeta);
178 std::string &AccessKey,
bool &IsInt32Ret);
183 std::map<std::string, GlobalVariable *> BPFAbstractMemberAccess::GEPGlobals;
185 class BPFAbstractMemberAccessLegacyPass final :
public FunctionPass {
189 return BPFAbstractMemberAccess(
TM).run(
F);
207 "BPF Abstract Member Access",
false,
false)
210 return new BPFAbstractMemberAccessLegacyPass(
TM);
214 LLVM_DEBUG(
dbgs() <<
"********** Abstract Member Accesses **********\n");
221 if (
M->debug_compile_units().empty())
224 DL = &
M->getDataLayout();
225 return doTransformation(
F);
229 if (
Tag != dwarf::DW_TAG_typedef &&
Tag != dwarf::DW_TAG_const_type &&
230 Tag != dwarf::DW_TAG_volatile_type &&
231 Tag != dwarf::DW_TAG_restrict_type &&
232 Tag != dwarf::DW_TAG_member)
234 if (
Tag == dwarf::DW_TAG_typedef && !skipTypedef)
240 while (
auto *DTy = dyn_cast<DIDerivedType>(Ty)) {
243 Ty = DTy->getBaseType();
249 while (
auto *DTy = dyn_cast<DIDerivedType>(Ty)) {
252 Ty = DTy->getBaseType();
260 for (
uint32_t I = StartDim;
I < Elements.size(); ++
I) {
261 if (
auto *Element = dyn_cast_or_null<DINode>(Elements[
I]))
262 if (Element->getTag() == dwarf::DW_TAG_subrange_type) {
263 const DISubrange *SR = cast<DISubrange>(Element);
265 DimSize *= CI->getSExtValue();
274 return Call->getParamElementType(0);
278 bool BPFAbstractMemberAccess::IsPreserveDIAccessIndexCall(
const CallInst *Call,
283 const auto *GV = dyn_cast<GlobalValue>(
Call->getCalledOperand());
286 if (GV->getName().startswith(
"llvm.preserve.array.access.index")) {
287 CInfo.Kind = BPFPreserveArrayAI;
288 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
290 report_fatal_error(
"Missing metadata for llvm.preserve.array.access.index intrinsic");
291 CInfo.AccessIndex = getConstant(
Call->getArgOperand(2));
292 CInfo.Base =
Call->getArgOperand(0);
296 if (GV->getName().startswith(
"llvm.preserve.union.access.index")) {
297 CInfo.Kind = BPFPreserveUnionAI;
298 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
300 report_fatal_error(
"Missing metadata for llvm.preserve.union.access.index intrinsic");
301 CInfo.AccessIndex = getConstant(
Call->getArgOperand(1));
302 CInfo.Base =
Call->getArgOperand(0);
305 if (GV->getName().startswith(
"llvm.preserve.struct.access.index")) {
306 CInfo.Kind = BPFPreserveStructAI;
307 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
309 report_fatal_error(
"Missing metadata for llvm.preserve.struct.access.index intrinsic");
310 CInfo.AccessIndex = getConstant(
Call->getArgOperand(2));
311 CInfo.Base =
Call->getArgOperand(0);
315 if (GV->getName().startswith(
"llvm.bpf.preserve.field.info")) {
316 CInfo.Kind = BPFPreserveFieldInfoAI;
317 CInfo.Metadata =
nullptr;
319 uint64_t InfoKind = getConstant(
Call->getArgOperand(1));
322 CInfo.AccessIndex = InfoKind;
325 if (GV->getName().startswith(
"llvm.bpf.preserve.type.info")) {
326 CInfo.Kind = BPFPreserveFieldInfoAI;
327 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
341 if (GV->getName().startswith(
"llvm.bpf.preserve.enum.value")) {
342 CInfo.Kind = BPFPreserveFieldInfoAI;
343 CInfo.Metadata =
Call->getMetadata(LLVMContext::MD_preserve_access_index);
359 void BPFAbstractMemberAccess::replaceWithGEP(std::vector<CallInst *> &CallList,
362 for (
auto Call : CallList) {
364 if (DimensionIndex > 0)
365 Dimension = getConstant(
Call->getArgOperand(DimensionIndex));
371 IdxList.push_back(Zero);
372 IdxList.push_back(
Call->getArgOperand(GEPIndex));
377 Call->eraseFromParent();
381 bool BPFAbstractMemberAccess::removePreserveAccessIndexIntrinsic(
Function &
F) {
382 std::vector<CallInst *> PreserveArrayIndexCalls;
383 std::vector<CallInst *> PreserveUnionIndexCalls;
384 std::vector<CallInst *> PreserveStructIndexCalls;
389 auto *
Call = dyn_cast<CallInst>(&
I);
391 if (!IsPreserveDIAccessIndexCall(Call, CInfo))
395 if (CInfo.Kind == BPFPreserveArrayAI)
396 PreserveArrayIndexCalls.push_back(Call);
397 else if (CInfo.Kind == BPFPreserveUnionAI)
398 PreserveUnionIndexCalls.push_back(Call);
400 PreserveStructIndexCalls.push_back(Call);
413 replaceWithGEP(PreserveArrayIndexCalls, 1, 2);
414 replaceWithGEP(PreserveStructIndexCalls, 0, 1);
415 for (
auto Call : PreserveUnionIndexCalls) {
416 Call->replaceAllUsesWith(
Call->getArgOperand(0));
417 Call->eraseFromParent();
426 bool BPFAbstractMemberAccess::IsValidAIChain(
const MDNode *ParentType,
428 const MDNode *ChildType) {
437 if (isa<DIDerivedType>(CType))
441 if (
const auto *PtrTy = dyn_cast<DIDerivedType>(PType)) {
442 if (PtrTy->getTag() != dwarf::DW_TAG_pointer_type)
448 const auto *PTy = dyn_cast<DICompositeType>(PType);
449 const auto *CTy = dyn_cast<DICompositeType>(CType);
450 assert(PTy && CTy &&
"ParentType or ChildType is null or not composite");
453 assert(PTyTag == dwarf::DW_TAG_array_type ||
454 PTyTag == dwarf::DW_TAG_structure_type ||
455 PTyTag == dwarf::DW_TAG_union_type);
458 assert(CTyTag == dwarf::DW_TAG_array_type ||
459 CTyTag == dwarf::DW_TAG_structure_type ||
460 CTyTag == dwarf::DW_TAG_union_type);
463 if (PTyTag == dwarf::DW_TAG_array_type && PTyTag == CTyTag)
464 return PTy->getBaseType() == CTy->getBaseType();
467 if (PTyTag == dwarf::DW_TAG_array_type)
468 Ty = PTy->getBaseType();
470 Ty = dyn_cast<DIType>(PTy->getElements()[ParentAI]);
475 void BPFAbstractMemberAccess::traceAICall(
CallInst *Call,
482 if (
auto *BI = dyn_cast<BitCastInst>(Inst)) {
483 traceBitCast(BI, Call, ParentInfo);
484 }
else if (
auto *CI = dyn_cast<CallInst>(Inst)) {
487 if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
488 IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
489 ChildInfo.Metadata)) {
490 AIChain[CI] = std::make_pair(Call, ParentInfo);
491 traceAICall(CI, ChildInfo);
493 BaseAICalls[
Call] = ParentInfo;
495 }
else if (
auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
496 if (GI->hasAllZeroIndices())
497 traceGEP(GI, Call, ParentInfo);
499 BaseAICalls[
Call] = ParentInfo;
501 BaseAICalls[
Call] = ParentInfo;
506 void BPFAbstractMemberAccess::traceBitCast(
BitCastInst *BitCast,
514 if (
auto *BI = dyn_cast<BitCastInst>(Inst)) {
515 traceBitCast(BI, Parent, ParentInfo);
516 }
else if (
auto *CI = dyn_cast<CallInst>(Inst)) {
518 if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
519 IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
520 ChildInfo.Metadata)) {
521 AIChain[CI] = std::make_pair(Parent, ParentInfo);
522 traceAICall(CI, ChildInfo);
524 BaseAICalls[Parent] = ParentInfo;
526 }
else if (
auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
527 if (GI->hasAllZeroIndices())
528 traceGEP(GI, Parent, ParentInfo);
530 BaseAICalls[Parent] = ParentInfo;
532 BaseAICalls[Parent] = ParentInfo;
544 if (
auto *BI = dyn_cast<BitCastInst>(Inst)) {
545 traceBitCast(BI, Parent, ParentInfo);
546 }
else if (
auto *CI = dyn_cast<CallInst>(Inst)) {
548 if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
549 IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
550 ChildInfo.Metadata)) {
551 AIChain[CI] = std::make_pair(Parent, ParentInfo);
552 traceAICall(CI, ChildInfo);
554 BaseAICalls[Parent] = ParentInfo;
556 }
else if (
auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
557 if (GI->hasAllZeroIndices())
558 traceGEP(GI, Parent, ParentInfo);
560 BaseAICalls[Parent] = ParentInfo;
562 BaseAICalls[Parent] = ParentInfo;
567 void BPFAbstractMemberAccess::collectAICallChains(
Function &
F) {
574 auto *
Call = dyn_cast<CallInst>(&
I);
575 if (!IsPreserveDIAccessIndexCall(Call, CInfo) ||
576 AIChain.find(Call) != AIChain.end())
579 traceAICall(Call, CInfo);
583 uint64_t BPFAbstractMemberAccess::getConstant(
const Value *IndexValue) {
584 const ConstantInt *CV = dyn_cast<ConstantInt>(IndexValue);
590 void BPFAbstractMemberAccess::GetStorageBitRange(
DIDerivedType *MemberTy,
591 Align RecordAlignment,
597 if (RecordAlignment > 8) {
600 if (MemberBitOffset / 64 != (MemberBitOffset + MemberBitSize) / 64)
602 "requiring too big alignment");
603 RecordAlignment =
Align(8);
607 if (MemberBitSize > AlignBits)
609 "bitfield size greater than record alignment");
611 StartBitOffset = MemberBitOffset & ~(AlignBits - 1);
612 if ((StartBitOffset + AlignBits) < (MemberBitOffset + MemberBitSize))
614 "cross alignment boundary");
615 EndBitOffset = StartBitOffset + AlignBits;
628 if (
Tag == dwarf::DW_TAG_array_type) {
631 (EltTy->getSizeInBits() >> 3);
632 }
else if (
Tag == dwarf::DW_TAG_structure_type) {
633 auto *MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
637 unsigned SBitOffset, NextSBitOffset;
638 GetStorageBitRange(MemberTy, *RecordAlignment, SBitOffset,
640 PatchImm += SBitOffset >> 3;
647 if (
Tag == dwarf::DW_TAG_array_type) {
649 return calcArraySize(CTy, 1) * (EltTy->getSizeInBits() >> 3);
651 auto *MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
654 return SizeInBits >> 3;
656 unsigned SBitOffset, NextSBitOffset;
657 GetStorageBitRange(MemberTy, *RecordAlignment, SBitOffset,
659 SizeInBits = NextSBitOffset - SBitOffset;
660 if (SizeInBits & (SizeInBits - 1))
662 return SizeInBits >> 3;
668 if (
Tag == dwarf::DW_TAG_array_type) {
674 auto *MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
679 const auto *BTy = dyn_cast<DIBasicType>(BaseTy);
681 const auto *CompTy = dyn_cast<DICompositeType>(BaseTy);
683 if (!CompTy || CompTy->getTag() != dwarf::DW_TAG_enumeration_type)
686 BTy = dyn_cast<DIBasicType>(BaseTy);
688 uint32_t Encoding = BTy->getEncoding();
689 return (Encoding == dwarf::DW_ATE_signed || Encoding == dwarf::DW_ATE_signed_char);
699 bool IsBitField =
false;
702 if (
Tag == dwarf::DW_TAG_array_type) {
706 MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
714 return 64 - SizeInBits;
717 unsigned SBitOffset, NextSBitOffset;
718 GetStorageBitRange(MemberTy, *RecordAlignment, SBitOffset, NextSBitOffset);
719 if (NextSBitOffset - SBitOffset > 64)
724 return SBitOffset + 64 - OffsetInBits - SizeInBits;
726 return OffsetInBits + 64 - NextSBitOffset;
731 bool IsBitField =
false;
733 if (
Tag == dwarf::DW_TAG_array_type) {
737 MemberTy = cast<DIDerivedType>(CTy->
getElements()[AccessIndex]);
745 return 64 - SizeInBits;
748 unsigned SBitOffset, NextSBitOffset;
749 GetStorageBitRange(MemberTy, *RecordAlignment, SBitOffset, NextSBitOffset);
750 if (NextSBitOffset - SBitOffset > 64)
753 return 64 - SizeInBits;
759 bool BPFAbstractMemberAccess::HasPreserveFieldInfoCall(CallInfoStack &CallStack) {
761 while (CallStack.size()) {
762 auto StackElem = CallStack.top();
763 if (StackElem.second.Kind == BPFPreserveFieldInfoAI)
773 Value *BPFAbstractMemberAccess::computeBaseAndAccessKey(
CallInst *Call,
775 std::string &AccessKey,
779 CallInfoStack CallStack;
783 CallStack.push(std::make_pair(Call, CInfo));
784 CInfo = AIChain[
Call].second;
801 while (CallStack.size()) {
802 auto StackElem = CallStack.top();
803 Call = StackElem.first;
804 CInfo = StackElem.second;
812 if (CInfo.Kind == BPFPreserveUnionAI ||
813 CInfo.Kind == BPFPreserveStructAI) {
817 TypeMeta = PossibleTypeDef;
822 assert(CInfo.Kind == BPFPreserveArrayAI);
828 uint64_t AccessIndex = CInfo.AccessIndex;
831 bool CheckElemType =
false;
832 if (
const auto *CTy = dyn_cast<DICompositeType>(Ty)) {
842 auto *DTy = cast<DIDerivedType>(Ty);
843 assert(DTy->getTag() == dwarf::DW_TAG_pointer_type);
846 CTy = dyn_cast<DICompositeType>(BaseTy);
848 CheckElemType =
true;
849 }
else if (CTy->
getTag() != dwarf::DW_TAG_array_type) {
850 FirstIndex += AccessIndex;
851 CheckElemType =
true;
858 auto *CTy = dyn_cast<DICompositeType>(BaseTy);
860 if (HasPreserveFieldInfoCall(CallStack))
865 unsigned CTag = CTy->
getTag();
866 if (CTag == dwarf::DW_TAG_structure_type || CTag == dwarf::DW_TAG_union_type) {
869 if (HasPreserveFieldInfoCall(CallStack))
883 while (CallStack.size()) {
884 auto StackElem = CallStack.top();
885 CInfo = StackElem.second;
888 if (CInfo.Kind == BPFPreserveFieldInfoAI) {
889 InfoKind = CInfo.AccessIndex;
897 if (CallStack.size()) {
898 auto StackElem2 = CallStack.top();
899 CallInfo CInfo2 = StackElem2.second;
900 if (CInfo2.Kind == BPFPreserveFieldInfoAI) {
901 InfoKind = CInfo2.AccessIndex;
902 assert(CallStack.size() == 1);
907 uint64_t AccessIndex = CInfo.AccessIndex;
910 MDNode *MDN = CInfo.Metadata;
913 PatchImm = GetFieldInfo(InfoKind, CTy, AccessIndex, PatchImm,
914 CInfo.RecordAlignment);
931 std::string &AccessKey,
937 std::string AccessStr(
"0");
953 cast<GlobalVariable>(
Call->getArgOperand(1)->stripPointerCasts());
965 const auto *CTy = cast<DICompositeType>(BaseTy);
966 assert(CTy->
getTag() == dwarf::DW_TAG_enumeration_type);
969 const auto *
Enum = cast<DIEnumerator>(Element);
970 if (
Enum->getName() == EnumeratorStr) {
979 PatchImm = std::stoll(std::string(EValueStr));
985 AccessKey =
"llvm." + Ty->
getName().
str() +
":" +
994 bool BPFAbstractMemberAccess::transformGEPChain(
CallInst *Call,
996 std::string AccessKey;
1001 IsInt32Ret = CInfo.Kind == BPFPreserveFieldInfoAI;
1002 if (CInfo.Kind == BPFPreserveFieldInfoAI && CInfo.Metadata) {
1003 TypeMeta = computeAccessKey(Call, CInfo, AccessKey, IsInt32Ret);
1005 Base = computeBaseAndAccessKey(Call, CInfo, AccessKey, TypeMeta);
1013 if (GEPGlobals.find(AccessKey) == GEPGlobals.end()) {
1021 nullptr, AccessKey);
1023 GV->
setMetadata(LLVMContext::MD_preserve_access_index, TypeMeta);
1024 GEPGlobals[AccessKey] = GV;
1026 GV = GEPGlobals[AccessKey];
1029 if (CInfo.Kind == BPFPreserveFieldInfoAI) {
1039 Call->replaceAllUsesWith(PassThroughInst);
1040 Call->eraseFromParent();
1059 BB->getInstList().insert(
Call->getIterator(), BCInst);
1064 BB->getInstList().insert(
Call->getIterator(),
GEP);
1068 BB->getInstList().insert(
Call->getIterator(), BCInst2);
1117 Call->replaceAllUsesWith(PassThroughInst);
1118 Call->eraseFromParent();
1123 bool BPFAbstractMemberAccess::doTransformation(
Function &
F) {
1124 bool Transformed =
false;
1129 collectAICallChains(
F);
1131 for (
auto &
C : BaseAICalls)
1132 Transformed = transformGEPChain(
C.first,
C.second) || Transformed;
1134 return removePreserveAccessIndexIntrinsic(
F) || Transformed;