28#define DEBUG_TYPE "mir2vec"
31 "Number of lookups to MIR entities not present in the vocabulary");
43 cl::desc(
"Weight for machine opcode embeddings"),
50 cl::desc(
"Weight for register operand embeddings"),
55 "Generate symbolic embeddings for MIR")),
67 VocabMap &&PhysicalRegisterMap,
68 VocabMap &&VirtualRegisterMap,
73 buildCanonicalOpcodeMapping();
74 unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
75 assert(CanonicalOpcodeCount > 0 &&
76 "No canonical opcodes found for target - invalid vocabulary");
78 buildRegisterOperandMapping();
81 Layout.OpcodeBase = 0;
82 Layout.CommonOperandBase = CanonicalOpcodeCount;
84 Layout.PhyRegBase = Layout.CommonOperandBase + std::size(CommonOperandNames);
85 Layout.VirtRegBase = Layout.PhyRegBase + RegisterOperandNames.size();
87 generateStorage(OpcodeMap, CommonOperandMap, PhysicalRegisterMap,
89 Layout.TotalEntries = Storage.size();
97 if (OpcodeMap.empty() || CommonOperandMap.empty() || PhyRegMap.empty() ||
100 "Empty vocabulary entries provided");
102 MIRVocabulary Vocab(std::move(OpcodeMap), std::move(CommonOperandMap),
103 std::move(PhyRegMap), std::move(
VirtRegMap), TII, TRI,
109 "Failed to create valid vocabulary storage");
111 return std::move(Vocab);
126 assert(!InstrName.
empty() &&
"Instruction name should not be empty");
129 static const Regex BaseOpcodeRegex(
"([a-zA-Z_]+)");
132 if (BaseOpcodeRegex.
match(InstrName, &Matches) && Matches.
size() > 1) {
135 while (!Match.
empty() && Match.
back() ==
'_')
141 return InstrName.
str();
145 assert(!UniqueBaseOpcodeNames.empty() &&
"Canonical mapping not built");
146 auto It = std::find(UniqueBaseOpcodeNames.begin(),
147 UniqueBaseOpcodeNames.end(), BaseName.
str());
148 assert(It != UniqueBaseOpcodeNames.end() &&
149 "Base name not found in unique opcodes");
150 return std::distance(UniqueBaseOpcodeNames.begin(), It);
153unsigned MIRVocabulary::getCanonicalOpcodeIndex(
unsigned Opcode)
const {
154 auto BaseOpcode = extractBaseOpcodeName(
TII.getName(Opcode));
155 return getCanonicalIndexForBaseName(BaseOpcode);
160 auto It = std::find(std::begin(CommonOperandNames),
161 std::end(CommonOperandNames), OperandName);
162 assert(It != std::end(CommonOperandNames) &&
163 "Operand name not found in common operands");
164 return Layout.CommonOperandBase +
165 std::distance(std::begin(CommonOperandNames), It);
170 bool IsPhysical)
const {
171 auto It = std::find(RegisterOperandNames.begin(), RegisterOperandNames.end(),
173 assert(It != RegisterOperandNames.end() &&
174 "Register name not found in register operands");
175 unsigned LocalIndex = std::distance(RegisterOperandNames.begin(), It);
176 return (IsPhysical ? Layout.PhyRegBase : Layout.VirtRegBase) + LocalIndex;
180 assert(Pos < Layout.TotalEntries &&
"Position out of bounds in vocabulary");
183 if (Pos < Layout.CommonOperandBase) {
185 auto It = UniqueBaseOpcodeNames.begin();
186 std::advance(It, Pos);
187 assert(It != UniqueBaseOpcodeNames.end() &&
188 "Canonical index out of bounds in opcode section");
192 auto getLocalIndex = [](
unsigned Pos,
size_t BaseOffset,
size_t Bound,
194 unsigned LocalIndex = Pos - BaseOffset;
195 assert(LocalIndex < Bound && Msg);
200 if (Pos < Layout.PhyRegBase) {
201 unsigned LocalIndex = getLocalIndex(
202 Pos, Layout.CommonOperandBase, std::size(CommonOperandNames),
203 "Local index out of bounds in common operands");
204 return CommonOperandNames[LocalIndex].str();
208 if (Pos < Layout.VirtRegBase) {
209 unsigned LocalIndex =
210 getLocalIndex(Pos, Layout.PhyRegBase, RegisterOperandNames.size(),
211 "Local index out of bounds in physical registers");
212 return "PhyReg_" + RegisterOperandNames[LocalIndex];
216 unsigned LocalIndex =
217 getLocalIndex(Pos, Layout.VirtRegBase, RegisterOperandNames.size(),
218 "Local index out of bounds in virtual registers");
219 return "VirtReg_" + RegisterOperandNames[LocalIndex];
222void MIRVocabulary::generateStorage(
const VocabMap &OpcodeMap,
223 const VocabMap &CommonOperandsMap,
224 const VocabMap &PhyRegMap,
232 <<
"; using zero vector. This will result in an error "
234 ++MIRVocabMissCounter;
238 unsigned EmbeddingDim = OpcodeMap.begin()->second.size();
239 std::vector<Embedding> OpcodeEmbeddings(Layout.CommonOperandBase,
243 for (
auto COpcodeName : UniqueBaseOpcodeNames) {
244 if (
auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) {
245 auto COpcodeIndex = getCanonicalIndexForBaseName(COpcodeName);
246 assert(COpcodeIndex < Layout.CommonOperandBase &&
247 "Canonical index out of bounds");
248 OpcodeEmbeddings[COpcodeIndex] = It->second;
250 handleMissingEntity(COpcodeName);
255 std::vector<Embedding> CommonOperandEmbeddings(std::size(CommonOperandNames),
257 unsigned OperandIndex = 0;
258 for (
const auto &CommonOperandName : CommonOperandNames) {
259 if (
auto It = CommonOperandsMap.find(CommonOperandName.str());
260 It != CommonOperandsMap.end()) {
261 CommonOperandEmbeddings[OperandIndex] = It->second;
263 handleMissingEntity(CommonOperandName);
269 auto createRegisterEmbeddings = [&](
const VocabMap &RegMap) {
270 std::vector<Embedding> RegEmbeddings(
TRI.getNumRegClasses(),
272 unsigned RegOperandIndex = 0;
273 for (
const auto &RegOperandName : RegisterOperandNames) {
274 if (
auto It = RegMap.find(RegOperandName); It != RegMap.end())
275 RegEmbeddings[RegOperandIndex] = It->second;
277 handleMissingEntity(RegOperandName);
280 return RegEmbeddings;
284 std::vector<Embedding> PhyRegEmbeddings = createRegisterEmbeddings(PhyRegMap);
285 std::vector<Embedding> VirtRegEmbeddings =
289 auto scaleVocabSection = [](std::vector<Embedding> &Embeddings,
294 scaleVocabSection(OpcodeEmbeddings,
OpcWeight);
299 std::vector<std::vector<Embedding>> Sections(
300 static_cast<unsigned>(Section::MaxSections));
301 Sections[
static_cast<unsigned>(Section::Opcodes)] =
302 std::move(OpcodeEmbeddings);
303 Sections[
static_cast<unsigned>(Section::CommonOperands)] =
304 std::move(CommonOperandEmbeddings);
305 Sections[
static_cast<unsigned>(Section::PhyRegisters)] =
306 std::move(PhyRegEmbeddings);
307 Sections[
static_cast<unsigned>(Section::VirtRegisters)] =
308 std::move(VirtRegEmbeddings);
313void MIRVocabulary::buildCanonicalOpcodeMapping() {
315 if (!UniqueBaseOpcodeNames.empty())
319 for (
unsigned Opcode = 0; Opcode <
TII.getNumOpcodes(); ++Opcode) {
320 std::string BaseOpcode = extractBaseOpcodeName(
TII.getName(Opcode));
321 UniqueBaseOpcodeNames.insert(BaseOpcode);
324 LLVM_DEBUG(
dbgs() <<
"MIR2Vec: Built canonical mapping for target with "
325 << UniqueBaseOpcodeNames.size()
326 <<
" unique base opcodes\n");
329void MIRVocabulary::buildRegisterOperandMapping() {
331 if (!RegisterOperandNames.empty())
334 for (
unsigned RC = 0; RC <
TRI.getNumRegClasses(); ++RC) {
341 RegisterOperandNames.push_back(ClassName.
str());
345unsigned MIRVocabulary::getCommonOperandIndex(
348 "Expected non-register operand type");
354unsigned MIRVocabulary::getRegisterOperandIndex(
Register Reg)
const {
355 assert(!RegisterOperandNames.empty() &&
"Register operand mapping not built");
358 "Expected a physical or virtual register");
366 RegClass =
TRI.getMinimalPhysRegClass(
Reg);
368 RegClass =
MRI.getRegClass(
Reg);
371 return RegClass->
getID();
380 assert(Dim > 0 &&
"Dimension must be greater than zero");
382 float DummyVal = 0.1f;
384 VocabMap DummyOpcMap, DummyOperandMap, DummyPhyRegMap, DummyVirtRegMap;
387 for (
unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
389 if (DummyOpcMap.count(BaseOpcode) == 0) {
390 DummyOpcMap[BaseOpcode] =
Embedding(Dim, DummyVal);
396 for (
const auto &CommonOperandName : CommonOperandNames) {
397 DummyOperandMap[CommonOperandName.str()] =
Embedding(Dim, DummyVal);
402 for (
unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) {
407 std::string ClassName = TRI.getRegClassName(RegClass);
408 DummyPhyRegMap[ClassName] =
Embedding(Dim, DummyVal);
409 DummyVirtRegMap[ClassName] =
Embedding(Dim, DummyVal);
415 std::move(DummyOpcMap), std::move(DummyOperandMap),
416 std::move(DummyPhyRegMap), std::move(DummyVirtRegMap), TII, TRI, MRI);
425 VocabMap OpcVocab, CommonOperandVocab, PhyRegVocabMap, VirtRegVocabMap;
427 if (
Error Err = readVocabulary(OpcVocab, CommonOperandVocab, PhyRegVocabMap,
429 return std::move(Err);
431 for (
const auto &
F : M) {
432 if (
F.isDeclaration())
435 if (
auto *MF = MMI.getMachineFunction(
F)) {
436 auto &Subtarget = MF->getSubtarget();
437 if (
const auto *
TII = Subtarget.getInstrInfo())
438 if (
const auto *
TRI = Subtarget.getRegisterInfo())
440 std::move(OpcVocab), std::move(CommonOperandVocab),
441 std::move(PhyRegVocabMap), std::move(VirtRegVocabMap), *
TII, *
TRI,
446 "No machine functions found in module");
449Error MIR2VecVocabProvider::readVocabulary(VocabMap &OpcodeVocab,
450 VocabMap &CommonOperandVocab,
451 VocabMap &PhyRegVocabMap,
452 VocabMap &VirtRegVocabMap) {
456 "MIR2Vec vocabulary file path not specified; set it "
457 "using --mir2vec-vocab-path");
463 auto Content = BufOrError.get()->getBuffer();
466 if (!ParsedVocabValue)
469 unsigned OpcodeDim = 0, CommonOperandDim = 0, PhyRegOperandDim = 0,
470 VirtRegOperandDim = 0;
472 "Opcodes", *ParsedVocabValue, OpcodeVocab, OpcodeDim))
476 "CommonOperands", *ParsedVocabValue, CommonOperandVocab,
481 "PhysicalRegisters", *ParsedVocabValue, PhyRegVocabMap,
486 "VirtualRegisters", *ParsedVocabValue, VirtRegVocabMap,
491 if (!(OpcodeDim == CommonOperandDim && CommonOperandDim == PhyRegOperandDim &&
492 PhyRegOperandDim == VirtRegOperandDim)) {
495 "MIR2Vec vocabulary sections have different dimensions");
503 "MIR2Vec Vocabulary Analysis",
false,
true)
509 return "MIR2Vec Vocabulary Analysis";
521 return std::make_unique<SymbolicMIREmbedder>(
MF,
Vocab);
530 const auto &Subtarget =
MF.getSubtarget();
531 const auto *
TII = Subtarget.getInstrInfo();
533 MF.getFunction().getContext().emitError(
534 "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
539 for (
const auto &
MI :
MBB) {
541 if (
MI.isDebugInstr())
562std::unique_ptr<SymbolicMIREmbedder>
565 return std::make_unique<SymbolicMIREmbedder>(
MF,
Vocab);
570 if (
MI.isDebugInstr())
578 InstructionEmbedding +=
Vocab[MO];
580 return InstructionEmbedding;
589 "MIR2Vec Vocabulary Printer Pass",
false,
true)
601 auto MIR2VecVocabOrErr =
Analysis.getMIR2VecVocabulary(M);
603 if (!MIR2VecVocabOrErr) {
604 OS <<
"MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
605 <<
toString(MIR2VecVocabOrErr.takeError()) <<
"\n";
609 auto &MIR2VecVocab = *MIR2VecVocabOrErr;
611 for (
const auto &Entry : MIR2VecVocab) {
612 OS <<
"Key: " << MIR2VecVocab.getStringKey(Pos++) <<
": ";
626 "MIR2Vec Embedder Printer Pass",
false,
true)
630 "MIR2Vec Embedder Printer Pass",
false,
true)
635 Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
636 assert(VocabOrErr &&
"Failed to get MIR2Vec vocabulary");
637 auto &MIRVocab = *VocabOrErr;
641 OS <<
"Error creating MIR2Vec embeddings for function " << MF.getName()
646 OS <<
"MIR2Vec embeddings for machine function " << MF.getName() <<
":\n";
647 OS <<
"Machine Function vector: ";
648 Emb->getMFunctionVector().print(OS);
650 OS <<
"Machine basic block vectors:\n";
652 OS <<
"Machine basic block: " <<
MBB.getFullName() <<
":\n";
653 Emb->getMBBVector(
MBB).print(OS);
656 OS <<
"Machine instruction vectors:\n";
661 if (
MI.isDebugInstr())
664 OS <<
"Machine instruction: ";
666 Emb->getMInstVector(
MI).print(OS);
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
block Block Frequency Analysis
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
const HexagonInstrInfo * TII
Module.h This file contains the declarations for the Module class.
This file defines the MIR2Vec framework for generating Machine IR embeddings.
Register const TargetRegisterInfo * TRI
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
SmallVector< MachineBasicBlock *, 4 > MBBVector
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Tagged union holding either a T or a Error.
Error takeError()
Take ownership of the stored error.
This pass prints the MIR2Vec embeddings for machine functions, basic blocks, and instructions.
MIR2VecPrinterLegacyPass(raw_ostream &OS)
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
Pass to analyze and populate MIR2Vec vocabulary from a module.
This pass prints the embeddings in the MIR2Vec vocabulary.
bool doFinalization(Module &M) override
doFinalization - Virtual method overriden by subclasses to do any necessary clean up after all passes...
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
Expected< mir2vec::MIRVocabulary > getVocabulary(const Module &M)
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
Representation of each machine instruction.
MachineOperand class - Representation of each machine instruction operand.
@ MO_Register
Register operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
LLVM_ABI bool match(StringRef String, SmallVectorImpl< StringRef > *Matches=nullptr, std::string *Error=nullptr) const
matches - Match the regex against a given String.
Wrapper class representing virtual and physical registers.
constexpr bool isValid() const
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
std::string str() const
str - Get the contents as an std::string.
constexpr bool empty() const
empty - Check if the string is empty.
char back() const
back - Get the last character in the string.
StringRef drop_back(size_t N=1) const
Return a StringRef equal to 'this' but with the last N elements dropped.
TargetInstrInfo - Interface to description of machine instruction set.
unsigned getID() const
Return the register class ID number.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
Generic storage class for section-based vocabularies.
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
unsigned getDimension() const
Get vocabulary dimension.
bool isValid() const
Check if vocabulary is valid (has data)
const unsigned Dimension
Dimension of the embeddings; Captured from the vocabulary.
const MIRVocabulary & Vocab
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
Embedding computeEmbeddings() const
Function to compute embeddings.
const MachineFunction & MF
static std::unique_ptr< MIREmbedder > create(MIR2VecKind Mode, const MachineFunction &MF, const MIRVocabulary &Vocab)
Factory method to create an Embedder object of the specified kind Returns nullptr if the requested ki...
Class for storing and accessing the MIR2Vec vocabulary.
unsigned getCanonicalIndexForOperandName(StringRef OperandName) const
unsigned getCanonicalIndexForRegisterClass(StringRef RegName, bool IsPhysical=true) const
static Expected< MIRVocabulary > create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI)
Factory method to create MIRVocabulary from vocabulary map.
static std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
static Expected< MIRVocabulary > createDummyVocabForTest(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get indices from opcode or operand names.
static std::unique_ptr< SymbolicMIREmbedder > create(const MachineFunction &MF, const MIRVocabulary &Vocab)
SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab)
This class implements an extremely fast bulk output stream that can only output to a stream.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
OperandType
Operands are tagged with one of the values of this enum.
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
llvm::cl::OptionCategory MIR2VecCategory
cl::opt< float > OpcWeight
static cl::opt< std::string > VocabFile("mir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""), cl::cat(MIR2VecCategory))
cl::opt< float > RegOperandWeight
ir2vec::Embedding Embedding
cl::opt< float > CommonOperandWeight
cl::opt< MIR2VecKind > MIR2VecEmbeddingKind("mir2vec-kind", cl::Optional, cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic", "Generate symbolic embeddings for MIR")), cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"), cl::cat(MIR2VecCategory))
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
MachineFunctionPass * createMIR2VecPrinterLegacyPass(raw_ostream &OS)
Create a machine pass that prints MIR2Vec embeddings.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
LLVM_ABI MachineFunctionPass * createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary contents to the given stream a...
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
iterator_range< df_iterator< T > > depth_first(const T &G)