LLVM 22.0.0git
MIR2Vec.h
Go to the documentation of this file.
1//===- MIR2Vec.h - Implementation of MIR2Vec ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file defines the MIR2Vec vocabulary
11/// analysis (MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
12/// interface for generating Machine IR embeddings, and related utilities.
13///
14/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
15/// LLVM Machine IR as embeddings which can be used as input to machine learning
16/// algorithms.
17///
18/// The original idea of MIR2Vec is described in the following paper:
19///
20/// RL4ReAl: Reinforcement Learning for Register Allocation. S. VenkataKeerthy,
21/// Siddharth Jain, Anilava Kundu, Rohit Aggarwal, Albert Cohen, and Ramakrishna
22/// Upadrasta. 2023. RL4ReAl: Reinforcement Learning for Register Allocation.
23/// Proceedings of the 32nd ACM SIGPLAN International Conference on Compiler
24/// Construction (CC 2023). https://doi.org/10.1145/3578360.3580273.
25/// https://arxiv.org/abs/2204.02013
26///
27//===----------------------------------------------------------------------===//
28
29#ifndef LLVM_CODEGEN_MIR2VEC_H
30#define LLVM_CODEGEN_MIR2VEC_H
31
38#include "llvm/IR/PassManager.h"
39#include "llvm/Pass.h"
42#include <map>
43#include <set>
44#include <string>
45
46namespace llvm {
47
48class Module;
49class raw_ostream;
50class LLVMContext;
52class TargetInstrInfo;
53
54namespace mir2vec {
57
59
60/// Class for storing and accessing the MIR2Vec vocabulary.
61/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
64 using VocabMap = std::map<std::string, ir2vec::Embedding>;
65
66private:
67 // Define vocabulary layout - adapted for MIR
68 struct {
69 size_t OpcodeBase = 0;
70 size_t OperandBase = 0;
71 size_t TotalEntries = 0;
72 } Layout;
73
74 enum class Section : unsigned { Opcodes = 0, MaxSections };
75
77 mutable std::set<std::string> UniqueBaseOpcodeNames;
78 const TargetInstrInfo &TII;
79 void generateStorage(const VocabMap &OpcodeMap);
80 void buildCanonicalOpcodeMapping();
81
82 /// Get canonical index for a machine opcode
83 unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
84
85public:
86 /// Static method for extracting base opcode names (public for testing)
87 static std::string extractBaseOpcodeName(StringRef InstrName);
88
89 /// Get canonical index for base name (public for testing)
90 unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
91
92 /// Get the string key for a vocabulary entry at the given position
93 std::string getStringKey(unsigned Pos) const;
94
95 MIRVocabulary() = delete;
96 MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
98 : Storage(std::move(Storage)), TII(TII) {}
99
100 bool isValid() const {
101 return UniqueBaseOpcodeNames.size() > 0 &&
102 Layout.TotalEntries == Storage.size() && Storage.isValid();
103 }
104
105 unsigned getDimension() const {
106 if (!isValid())
107 return 0;
108 return Storage.getDimension();
109 }
110
111 // Accessor methods
112 const Embedding &operator[](unsigned Opcode) const {
113 assert(isValid() && "MIR2Vec Vocabulary is invalid");
114 unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
115 return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
116 }
117
118 // Iterator access
121 assert(isValid() && "MIR2Vec Vocabulary is invalid");
122 return Storage.begin();
123 }
124
126 assert(isValid() && "MIR2Vec Vocabulary is invalid");
127 return Storage.end();
128 }
129
130 /// Total number of entries in the vocabulary
131 size_t getCanonicalSize() const {
132 assert(isValid() && "Invalid vocabulary");
133 return Storage.size();
134 }
135};
136
137} // namespace mir2vec
138
139/// Pass to analyze and populate MIR2Vec vocabulary from a module
141 using VocabVector = std::vector<mir2vec::Embedding>;
142 using VocabMap = std::map<std::string, mir2vec::Embedding>;
143 VocabMap StrVocabMap;
144 VocabVector Vocab;
145
146 StringRef getPassName() const override;
147 Error readVocabulary();
148 void emitError(Error Err, LLVMContext &Ctx);
149
150protected:
151 void getAnalysisUsage(AnalysisUsage &AU) const override {
153 AU.setPreservesAll();
154 }
155
156public:
157 static char ID;
160};
161
162/// This pass prints the embeddings in the MIR2Vec vocabulary
164 raw_ostream &OS;
165
166public:
167 static char ID;
170
171 bool runOnMachineFunction(MachineFunction &MF) override;
172 bool doFinalization(Module &M) override;
178
179 StringRef getPassName() const override {
180 return "MIR2Vec Vocabulary Printer Pass";
181 }
182};
183
184} // namespace llvm
185
186#endif // LLVM_CODEGEN_MIR2VEC_H
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Provides ErrorOr<T> smart pointer.
const HexagonInstrInfo * TII
This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This header defines various interfaces for pass management in LLVM.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesAll()
Set by analyses that do not transform their input at all.
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
ImmutablePass(char &pid)
Definition Pass.h:287
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
Pass to analyze and populate MIR2Vec vocabulary from a module.
Definition MIR2Vec.h:140
mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M)
Definition MIR2Vec.cpp:240
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition MIR2Vec.h:151
StringRef getPassName() const override
getPassName - Return a nice clean name for a pass.
Definition MIR2Vec.h:179
bool doFinalization(Module &M) override
doFinalization - Virtual method overridden by subclasses to do any necessary clean up after all passes...
Definition MIR2Vec.cpp:285
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overridden to perform the desired machine code transformat...
Definition MIR2Vec.cpp:281
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
Definition MIR2Vec.h:168
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
Definition MIR2Vec.h:173
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition Pass.cpp:85
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
TargetInstrInfo - Interface to description of machine instruction set.
Iterator support for section-based access.
Definition IR2Vec.h:196
Generic storage class for section-based vocabularies.
Definition IR2Vec.h:151
Class for storing and accessing the MIR2Vec vocabulary.
Definition MIR2Vec.h:62
unsigned getDimension() const
Definition MIR2Vec.h:105
const_iterator end() const
Definition MIR2Vec.h:125
static std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
Definition MIR2Vec.cpp:70
ir2vec::VocabStorage::const_iterator const_iterator
Definition MIR2Vec.h:119
const_iterator begin() const
Definition MIR2Vec.h:120
const Embedding & operator[](unsigned Opcode) const
Definition MIR2Vec.h:112
size_t getCanonicalSize() const
Total number of entries in the vocabulary.
Definition MIR2Vec.h:131
MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
Definition MIR2Vec.h:97
std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
Definition MIR2Vec.cpp:115
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get canonical index for base name (public for testing)
Definition MIR2Vec.cpp:100
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
llvm::cl::OptionCategory MIR2VecCategory
cl::opt< float > OpcWeight
ir2vec::Embedding Embedding
Definition MIR2Vec.h:58
This is an optimization pass for GlobalISel generic memory operations.
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1869
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:867
Embedding is a datatype that wraps std::vector<double>.
Definition IR2Vec.h:87