LLVM 22.0.0git
MIR2Vec.cpp
Go to the documentation of this file.
1//===- MIR2Vec.cpp - Implementation of MIR2Vec ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the MIR2Vec algorithm for Machine IR embeddings.
11///
12//===----------------------------------------------------------------------===//
13
15#include "llvm/ADT/Statistic.h"
17#include "llvm/IR/Module.h"
19#include "llvm/Pass.h"
20#include "llvm/Support/Errc.h"
22#include "llvm/Support/Regex.h"
23
24using namespace llvm;
25using namespace mir2vec;
26
27#define DEBUG_TYPE "mir2vec"
28
29STATISTIC(MIRVocabMissCounter,
30 "Number of lookups to MIR entities not present in the vocabulary");
31
32namespace llvm {
33namespace mir2vec {
35
36// FIXME: Use a default vocab when not specified
38 VocabFile("mir2vec-vocab-path", cl::Optional,
39 cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""),
41cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
42 cl::desc("Weight for machine opcode embeddings"),
44} // namespace mir2vec
45} // namespace llvm
46
47//===----------------------------------------------------------------------===//
48// Vocabulary Implementation
49//===----------------------------------------------------------------------===//
50
51MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
52 const TargetInstrInfo *TII)
53 : TII(*TII) {
54 // Fixme: Use static factory methods for creating vocabularies instead of
55 // public constructors
56 // Early return for invalid inputs - creates empty/invalid vocabulary
57 if (!TII || OpcodeEntries.empty())
58 return;
59
60 buildCanonicalOpcodeMapping();
61
62 unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
63 assert(CanonicalOpcodeCount > 0 &&
64 "No canonical opcodes found for target - invalid vocabulary");
65 Layout.OperandBase = CanonicalOpcodeCount;
66 generateStorage(OpcodeEntries);
67 Layout.TotalEntries = Storage.size();
68}
69
71 // Extract base instruction name using regex to capture letters and
72 // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
73 //
74 // TODO: Consider more sophisticated extraction:
75 // - Handle complex prefixes like "AVX1_SETALLONES" correctly (Currently, it
76 // would naively map to "AVX")
77 // - Extract width suffixes (8,16,32,64) as separate features
78 // - Capture addressing mode suffixes (r,i,m,ri,etc.) for better analysis
79 // (Currently, instances like "MOV32mi" map to "MOV", but "ADDPDrr" would map
80 // to "ADDPDrr")
81
82 assert(!InstrName.empty() && "Instruction name should not be empty");
83
84 // Use regex to extract initial sequence of letters and underscores
85 static const Regex BaseOpcodeRegex("([a-zA-Z_]+)");
87
88 if (BaseOpcodeRegex.match(InstrName, &Matches) && Matches.size() > 1) {
89 StringRef Match = Matches[1];
90 // Trim trailing underscores
91 while (!Match.empty() && Match.back() == '_')
92 Match = Match.drop_back();
93 return Match.str();
94 }
95
96 // Fallback to original name if no pattern matches
97 return InstrName.str();
98}
99
101 assert(!UniqueBaseOpcodeNames.empty() && "Canonical mapping not built");
102 auto It = std::find(UniqueBaseOpcodeNames.begin(),
103 UniqueBaseOpcodeNames.end(), BaseName.str());
104 assert(It != UniqueBaseOpcodeNames.end() &&
105 "Base name not found in unique opcodes");
106 return std::distance(UniqueBaseOpcodeNames.begin(), It);
107}
108
109unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
110 assert(isValid() && "MIR2Vec Vocabulary is invalid");
111 auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
112 return getCanonicalIndexForBaseName(BaseOpcode);
113}
114
115std::string MIRVocabulary::getStringKey(unsigned Pos) const {
116 assert(isValid() && "MIR2Vec Vocabulary is invalid");
117 assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
118
119 // For now, all entries are opcodes since we only have one section
120 if (Pos < Layout.OperandBase && Pos < UniqueBaseOpcodeNames.size()) {
121 // Convert canonical index back to base opcode name
122 auto It = UniqueBaseOpcodeNames.begin();
123 std::advance(It, Pos);
124 return *It;
125 }
126
127 llvm_unreachable("Invalid position in vocabulary");
128 return "";
129}
130
131void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {
132
133 // Helper for handling missing entities in the vocabulary.
134 // Currently, we use a zero vector. In the future, we will throw an error to
135 // ensure that *all* known entities are present in the vocabulary.
136 auto handleMissingEntity = [](StringRef Key) {
137 LLVM_DEBUG(errs() << "MIR2Vec: Missing vocabulary entry for " << Key
138 << "; using zero vector. This will result in an error "
139 "in the future.\n");
140 ++MIRVocabMissCounter;
141 };
142
143 // Initialize opcode embeddings section
144 unsigned EmbeddingDim = OpcodeMap.begin()->second.size();
145 std::vector<Embedding> OpcodeEmbeddings(Layout.OperandBase,
146 Embedding(EmbeddingDim));
147
148 // Populate opcode embeddings using canonical mapping
149 for (auto COpcodeName : UniqueBaseOpcodeNames) {
150 if (auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) {
151 auto COpcodeIndex = getCanonicalIndexForBaseName(COpcodeName);
152 assert(COpcodeIndex < Layout.OperandBase &&
153 "Canonical index out of bounds");
154 OpcodeEmbeddings[COpcodeIndex] = It->second;
155 } else {
156 handleMissingEntity(COpcodeName);
157 }
158 }
159
160 // TODO: Add operand/argument embeddings as additional sections
161 // This will require extending the vocabulary format and layout
162
163 // Scale the vocabulary sections based on the provided weights
164 auto scaleVocabSection = [](std::vector<Embedding> &Embeddings,
165 double Weight) {
166 for (auto &Embedding : Embeddings)
167 Embedding *= Weight;
168 };
169 scaleVocabSection(OpcodeEmbeddings, OpcWeight);
170
171 std::vector<std::vector<Embedding>> Sections(1);
172 Sections[0] = std::move(OpcodeEmbeddings);
173
174 Storage = ir2vec::VocabStorage(std::move(Sections));
175}
176
177void MIRVocabulary::buildCanonicalOpcodeMapping() {
178 // Check if already built
179 if (!UniqueBaseOpcodeNames.empty())
180 return;
181
182 // Build mapping from opcodes to canonical base opcode indices
183 for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
184 std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
185 UniqueBaseOpcodeNames.insert(BaseOpcode);
186 }
187
188 LLVM_DEBUG(dbgs() << "MIR2Vec: Built canonical mapping for target with "
189 << UniqueBaseOpcodeNames.size()
190 << " unique base opcodes\n");
191}
192
193//===----------------------------------------------------------------------===//
194// MIR2VecVocabLegacyAnalysis Implementation
195//===----------------------------------------------------------------------===//
196
199 "MIR2Vec Vocabulary Analysis", false, true)
202 "MIR2Vec Vocabulary Analysis", false, true)
203
204StringRef MIR2VecVocabLegacyAnalysis::getPassName() const {
205 return "MIR2Vec Vocabulary Analysis";
206}
207
208Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
209 // TODO: Extend vocabulary format to support multiple sections
210 // (opcodes, operands, etc.) similar to IR2Vec structure
211 if (VocabFile.empty())
212 return createStringError(
214 "MIR2Vec vocabulary file path not specified; set it "
215 "using --mir2vec-vocab-path");
216
217 auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
218 if (!BufOrError)
219 return createFileError(VocabFile, BufOrError.getError());
220
221 auto Content = BufOrError.get()->getBuffer();
222
223 Expected<json::Value> ParsedVocabValue = json::parse(Content);
224 if (!ParsedVocabValue)
225 return ParsedVocabValue.takeError();
226
227 unsigned Dim = 0;
229 "entities", *ParsedVocabValue, StrVocabMap, Dim))
230 return Err;
231
232 return Error::success();
233}
234
235void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) {
236 Ctx.emitError(toString(std::move(Err)));
237}
238
239mir2vec::MIRVocabulary
241 if (StrVocabMap.empty()) {
242 if (Error Err = readVocabulary()) {
243 emitError(std::move(Err), M.getContext());
244 return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
245 }
246 }
247
248 // Get machine module info to access machine functions and target info
250
251 // Find first available machine function to get target instruction info
252 for (const auto &F : M) {
253 if (F.isDeclaration())
254 continue;
255
256 if (auto *MF = MMI.getMachineFunction(F)) {
257 const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
258 return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII);
259 }
260 }
261
262 // No machine functions available - return invalid vocabulary
263 emitError(make_error<StringError>("No machine functions found in module",
265 M.getContext());
266 return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr);
267}
268
269//===----------------------------------------------------------------------===//
270// Printer Passes Implementation
271//===----------------------------------------------------------------------===//
272
275 "MIR2Vec Vocabulary Printer Pass", false, true)
279 "MIR2Vec Vocabulary Printer Pass", false, true)
280
284
287 auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M);
288
289 if (!MIR2VecVocab.isValid()) {
290 OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
291 return false;
292 }
293
294 unsigned Pos = 0;
295 for (const auto &Entry : MIR2VecVocab) {
296 OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
297 Entry.print(OS);
298 }
299
300 return false;
301}
302
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
block Block Frequency Analysis
const HexagonInstrInfo * TII
Module.h This file contains the declarations for the Module class.
#define F(x, y, z)
Definition MD5.cpp:55
This file defines the MIR2Vec vocabulary analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIRE...
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
Error takeError()
Take ownership of the stored error.
Definition Error.h:612
LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location inf...
Pass to analyze and populate MIR2Vec vocabulary from a module.
Definition MIR2Vec.h:140
mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M)
Definition MIR2Vec.cpp:240
This pass prints the embeddings in the MIR2Vec vocabulary.
Definition MIR2Vec.h:163
bool doFinalization(Module &M) override
doFinalization - Virtual method overriden by subclasses to do any necessary clean up after all passes...
Definition MIR2Vec.cpp:285
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
Definition MIR2Vec.cpp:281
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
Definition MIR2Vec.h:168
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
This class contains meta information specific to a module.
LLVM_ABI MachineFunction * getMachineFunction(const Function &F) const
Returns the MachineFunction associated to IR function F if there is one, otherwise nullptr.
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
LLVM_ABI bool match(StringRef String, SmallVectorImpl< StringRef > *Matches=nullptr, std::string *Error=nullptr) const
matches - Match the regex against a given String.
Definition Regex.cpp:83
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
std::string str() const
str - Get the contents as an std::string.
Definition StringRef.h:225
constexpr bool empty() const
empty - Check if the string is empty.
Definition StringRef.h:143
char back() const
back - Get the last character in the string.
Definition StringRef.h:155
StringRef drop_back(size_t N=1) const
Return a StringRef equal to 'this' but with the last N elements dropped.
Definition StringRef.h:618
TargetInstrInfo - Interface to description of machine instruction set.
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
Definition IR2Vec.cpp:335
Class for storing and accessing the MIR2Vec vocabulary.
Definition MIR2Vec.h:62
static std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
Definition MIR2Vec.cpp:70
std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
Definition MIR2Vec.cpp:115
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get canonical index for base name (public for testing)
Definition MIR2Vec.cpp:100
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
initializer< Ty > init(const Ty &Val)
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
Definition JSON.cpp:675
llvm::cl::OptionCategory MIR2VecCategory
cl::opt< float > OpcWeight
static cl::opt< std::string > VocabFile("mir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""), cl::cat(MIR2VecCategory))
ir2vec::Embedding Embedding
Definition MIR2Vec.h:58
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Definition Error.h:1399
LLVM_ABI std::error_code inconvertibleErrorCode()
The value returned by this function can be returned from convertToErrorCode for Error values where no...
Definition Error.cpp:98
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
Definition Error.h:1305
@ invalid_argument
Definition Errc.h:56
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
Error make_error(ArgTs &&... Args)
Make a Error instance representing failure using the given error info type.
Definition Error.h:340
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
LLVM_ABI MachineFunctionPass * createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary contents to the given stream a...
Definition MIR2Vec.cpp:304
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)