40#ifndef LLVM_CODEGEN_MIR2VEC_H
41#define LLVM_CODEGEN_MIR2VEC_H
88 using VocabMap = std::map<std::string, ir2vec::Embedding>;
119 enum class Section :
unsigned {
128 std::set<std::string> UniqueBaseOpcodeNames;
138 "Immediate",
"CImmediate",
"FPImmediate",
"MBB",
139 "FrameIndex",
"ConstantPoolIndex",
"TargetIndex",
"JumpTableIndex",
140 "ExternalSymbol",
"GlobalAddress",
"BlockAddress",
"RegisterMask",
141 "RegisterLiveOut",
"Metadata",
"MCSymbol",
"CFIIndex",
142 "IntrinsicID",
"Predicate",
"ShuffleMask"};
144 "Common operand names size changed, update accordingly");
146 const TargetInstrInfo &
TII;
147 const TargetRegisterInfo &
TRI;
148 const MachineRegisterInfo &
MRI;
150 void generateStorage(
const VocabMap &OpcodeMap,
151 const VocabMap &CommonOperandMap,
152 const VocabMap &PhyRegMap,
const VocabMap &VirtRegMap);
153 void buildCanonicalOpcodeMapping();
154 void buildRegisterOperandMapping();
157 unsigned getCanonicalOpcodeIndex(
unsigned Opcode)
const;
164 unsigned getRegisterOperandIndex(
Register Reg)
const;
169 unsigned LocalIndex = getCommonOperandIndex(OperandType);
170 return Storage[
static_cast<unsigned>(Section::CommonOperands)][LocalIndex];
177 return ZeroEmbedding;
185 return ZeroEmbedding;
187 unsigned LocalIndex = getRegisterOperandIndex(
Reg);
189 Reg.isPhysical() ? Section::PhyRegisters : Section::VirtRegisters;
190 return Storage[
static_cast<unsigned>(SectionID)][LocalIndex];
195 unsigned getEntityIDForCommonOperand(
197 return Layout.CommonOperandBase + getCommonOperandIndex(OperandType);
202 unsigned getEntityIDForRegister(
Register Reg)
const {
203 if (!
Reg.isValid() ||
Reg.isStack())
206 unsigned LocalIndex = getRegisterOperandIndex(
Reg);
208 Reg.isPhysical() ? Layout.PhyRegBase : Layout.VirtRegBase;
209 return BaseOffset + LocalIndex;
221 bool IsPhysical =
true)
const;
231 return Layout.OpcodeBase + getCanonicalOpcodeIndex(Opcode);
238 return getEntityIDForRegister(MO.
getReg());
239 return getEntityIDForCommonOperand(MO.
getType());
244 unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
245 return Storage[
static_cast<unsigned>(Section::Opcodes)][LocalIndex];
249 auto OperandType = Operand.
getType();
251 return operator[](Operand.
getReg());
253 return operator[](OperandType);
266 create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap,
280 MIRVocabulary(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap,
352 static std::unique_ptr<SymbolicMIREmbedder>
367 using VocabMap = std::map<std::string, mir2vec::Embedding>;
375 Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab,
376 VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap);
382 using VocabVector = std::vector<mir2vec::Embedding>;
383 using VocabMap = std::map<std::string, mir2vec::Embedding>;
402 Provider = std::make_unique<MIR2VecVocabProvider>(MMI);
430 return "MIR2Vec Vocabulary Printer Pass";
452 return "MIR2Vec Embedder Printer Pass";
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Provides ErrorOr<T> smart pointer.
const HexagonInstrInfo * TII
This file defines the IR2Vec vocabulary analysis (IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This header defines various interfaces for pass management in LLVM.
Register const TargetRegisterInfo * TRI
Promote Memory to Register
static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesAll()
Set by analyses that do not transform their input at all.
Lightweight error class with error context and mandatory checking.
Tagged union holding either a T or a Error.
This is an important class for using LLVM in a threaded context.
MIR2VecPrinterLegacyPass(raw_ostream &OS)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
StringRef getPassName() const override
getPassName - Return a nice clean name for a pass.
Pass to analyze and populate MIR2Vec vocabulary from a module.
MIR2VecVocabProvider & getProvider()
Expected< mir2vec::MIRVocabulary > getMIR2VecVocabulary(const Module &M)
std::unique_ptr< MIR2VecVocabProvider > Provider
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
MIR2VecVocabLegacyAnalysis()
StringRef getPassName() const override
getPassName - Return a nice clean name for a pass.
bool doFinalization(Module &M) override
doFinalization - Virtual method overridden by subclasses to do any necessary clean up after all passes...
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
MIR2Vec vocabulary provider used by pass managers and standalone tools.
MIR2VecVocabProvider(const MachineModuleInfo &MMI)
Expected< mir2vec::MIRVocabulary > getVocabulary(const Module &M)
MachineFunctionPass(char &ID)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
Representation of each machine instruction.
This class contains meta information specific to a module.
MachineOperand class - Representation of each machine instruction operand.
MachineOperandType getType() const
getType - Returns the MachineOperandType for this operand.
Register getReg() const
getReg - Returns the register number.
@ MO_Register
Register operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
A Module instance is used to store all the information related to an LLVM module.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs...
StringRef - Represent a constant reference to a string, i.e.
TargetInstrInfo - Interface to description of machine instruction set.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
Iterator support for section-based access.
Generic storage class for section-based vocabularies.
Base class for MIR embedders.
const unsigned Dimension
Dimension of the embeddings; Captured from the vocabulary.
Embedding getMFunctionVector() const
Computes and returns the embedding for the current machine function.
const MIRVocabulary & Vocab
Embedding getMInstVector(const MachineInstr &MI) const
Computes and returns the embedding for a given machine instruction MI in the machine function MF.
virtual Embedding computeEmbeddings(const MachineInstr &MI) const =0
Function to compute the embedding for a given machine instruction.
Embedding getMBBVector(const MachineBasicBlock &MBB) const
Computes and returns the embedding for a given machine basic block in the machine function MF.
const float RegOperandWeight
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
const float CommonOperandWeight
Embedding computeEmbeddings() const
Function to compute embeddings.
const float OpcWeight
Weight for opcode embeddings.
const MachineFunction & MF
virtual ~MIREmbedder()=default
static std::unique_ptr< MIREmbedder > create(MIR2VecKind Mode, const MachineFunction &MF, const MIRVocabulary &Vocab)
Factory method to create an Embedder object of the specified kind. Returns nullptr if the requested ki...
Class for storing and accessing the MIR2Vec vocabulary.
unsigned getDimension() const
unsigned getEntityIDForOpcode(unsigned Opcode) const
Get entity ID (flat index) for an opcode. This is used for triplet generation.
const_iterator end() const
unsigned getCanonicalIndexForOperandName(StringRef OperandName) const
const Embedding & operator[](MachineOperand Operand) const
unsigned getCanonicalIndexForRegisterClass(StringRef RegName, bool IsPhysical=true) const
static Expected< MIRVocabulary > create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI)
Factory method to create MIRVocabulary from vocabulary map.
static std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
ir2vec::VocabStorage::const_iterator const_iterator
const_iterator begin() const
const Embedding & operator[](unsigned Opcode) const
size_t getCanonicalSize() const
Total number of entries in the vocabulary.
static Expected< MIRVocabulary > createDummyVocabForTest(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
unsigned getEntityIDForMachineOperand(const MachineOperand &MO) const
Get entity ID (flat index) for a machine operand. This is used for triplet generation.
std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get indices from opcode or operand names.
Class for computing Symbolic embeddings. Symbolic embeddings are constructed based on the entity-level...
static std::unique_ptr< SymbolicMIREmbedder > create(const MachineFunction &MF, const MIRVocabulary &Vocab)
SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab)
This class implements an extremely fast bulk output stream that can only output to a stream.
DenseMap< const MachineInstr *, Embedding > MachineInstEmbeddingsMap
llvm::cl::OptionCategory MIR2VecCategory
cl::opt< float > OpcWeight
cl::opt< float > RegOperandWeight
ir2vec::Embedding Embedding
DenseMap< const MachineBasicBlock *, Embedding > MachineBlockEmbeddingsMap
cl::opt< float > CommonOperandWeight
This is an optimization pass for GlobalISel generic memory operations.
MachineFunctionPass * createMIR2VecPrinterLegacyPass(raw_ostream &OS)
Create a machine pass that prints MIR2Vec embeddings.
Embedding is a datatype that wraps std::vector<double>.