LLVM 22.0.0git
MIR2Vec.cpp
Go to the documentation of this file.
1//===- MIR2Vec.cpp - Implementation of MIR2Vec ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the MIR2Vec algorithm for Machine IR embeddings.
11///
12//===----------------------------------------------------------------------===//
13
16#include "llvm/ADT/Statistic.h"
18#include "llvm/IR/Module.h"
20#include "llvm/Pass.h"
21#include "llvm/Support/Errc.h"
23#include "llvm/Support/Regex.h"
24
25using namespace llvm;
26using namespace mir2vec;
27
28#define DEBUG_TYPE "mir2vec"
29
30STATISTIC(MIRVocabMissCounter,
31 "Number of lookups to MIR entities not present in the vocabulary");
32
33namespace llvm {
34namespace mir2vec {
36
37// FIXME: Use a default vocab when not specified
39 VocabFile("mir2vec-vocab-path", cl::Optional,
40 cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""),
42cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
43 cl::desc("Weight for machine opcode embeddings"),
46 "mir2vec-common-operand-weight", cl::Optional, cl::init(1.0),
47 cl::desc("Weight for common operand embeddings"), cl::cat(MIR2VecCategory));
49 RegOperandWeight("mir2vec-reg-operand-weight", cl::Optional, cl::init(1.0),
50 cl::desc("Weight for register operand embeddings"),
53 "mir2vec-kind", cl::Optional,
55 "Generate symbolic embeddings for MIR")),
56 cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
58
59} // namespace mir2vec
60} // namespace llvm
61
62//===----------------------------------------------------------------------===//
63// Vocabulary
64//===----------------------------------------------------------------------===//
65
66MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeMap, VocabMap &&CommonOperandMap,
67 VocabMap &&PhysicalRegisterMap,
68 VocabMap &&VirtualRegisterMap,
69 const TargetInstrInfo &TII,
72 : TII(TII), TRI(TRI), MRI(MRI) {
73 buildCanonicalOpcodeMapping();
74 unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
75 assert(CanonicalOpcodeCount > 0 &&
76 "No canonical opcodes found for target - invalid vocabulary");
77
78 buildRegisterOperandMapping();
79
80 // Define layout of vocabulary sections
81 Layout.OpcodeBase = 0;
82 Layout.CommonOperandBase = CanonicalOpcodeCount;
83 // We expect same classes for physical and virtual registers
84 Layout.PhyRegBase = Layout.CommonOperandBase + std::size(CommonOperandNames);
85 Layout.VirtRegBase = Layout.PhyRegBase + RegisterOperandNames.size();
86
87 generateStorage(OpcodeMap, CommonOperandMap, PhysicalRegisterMap,
88 VirtualRegisterMap);
89 Layout.TotalEntries = Storage.size();
90}
91
93MIRVocabulary::create(VocabMap &&OpcodeMap, VocabMap &&CommonOperandMap,
94 VocabMap &&PhyRegMap, VocabMap &&VirtRegMap,
95 const TargetInstrInfo &TII, const TargetRegisterInfo &TRI,
96 const MachineRegisterInfo &MRI) {
97 if (OpcodeMap.empty() || CommonOperandMap.empty() || PhyRegMap.empty() ||
98 VirtRegMap.empty())
100 "Empty vocabulary entries provided");
101
102 MIRVocabulary Vocab(std::move(OpcodeMap), std::move(CommonOperandMap),
103 std::move(PhyRegMap), std::move(VirtRegMap), TII, TRI,
104 MRI);
105
106 // Validate Storage after construction
107 if (!Vocab.Storage.isValid())
109 "Failed to create valid vocabulary storage");
110 Vocab.ZeroEmbedding = Embedding(Vocab.Storage.getDimension(), 0.0);
111 return std::move(Vocab);
112}
113
115 // Extract base instruction name using regex to capture letters and
116 // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE"
117 //
118 // TODO: Consider more sophisticated extraction:
119 // - Handle complex prefixes like "AVX1_SETALLONES" correctly (Currently, it
120 // would naively map to "AVX")
121 // - Extract width suffixes (8,16,32,64) as separate features
122 // - Capture addressing mode suffixes (r,i,m,ri,etc.) for better analysis
123 // (Currently, instances like "MOV32mi" map to "MOV", but "ADDPDrr" would map
124 // to "ADDPDrr")
125
126 assert(!InstrName.empty() && "Instruction name should not be empty");
127
128 // Use regex to extract initial sequence of letters and underscores
129 static const Regex BaseOpcodeRegex("([a-zA-Z_]+)");
131
132 if (BaseOpcodeRegex.match(InstrName, &Matches) && Matches.size() > 1) {
133 StringRef Match = Matches[1];
134 // Trim trailing underscores
135 while (!Match.empty() && Match.back() == '_')
136 Match = Match.drop_back();
137 return Match.str();
138 }
139
140 // Fallback to original name if no pattern matches
141 return InstrName.str();
142}
143
145 assert(!UniqueBaseOpcodeNames.empty() && "Canonical mapping not built");
146 auto It = std::find(UniqueBaseOpcodeNames.begin(),
147 UniqueBaseOpcodeNames.end(), BaseName.str());
148 assert(It != UniqueBaseOpcodeNames.end() &&
149 "Base name not found in unique opcodes");
150 return std::distance(UniqueBaseOpcodeNames.begin(), It);
151}
152
153unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
154 auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
155 return getCanonicalIndexForBaseName(BaseOpcode);
156}
157
158unsigned
160 auto It = std::find(std::begin(CommonOperandNames),
161 std::end(CommonOperandNames), OperandName);
162 assert(It != std::end(CommonOperandNames) &&
163 "Operand name not found in common operands");
164 return Layout.CommonOperandBase +
165 std::distance(std::begin(CommonOperandNames), It);
166}
167
168unsigned
170 bool IsPhysical) const {
171 auto It = std::find(RegisterOperandNames.begin(), RegisterOperandNames.end(),
172 RegName);
173 assert(It != RegisterOperandNames.end() &&
174 "Register name not found in register operands");
175 unsigned LocalIndex = std::distance(RegisterOperandNames.begin(), It);
176 return (IsPhysical ? Layout.PhyRegBase : Layout.VirtRegBase) + LocalIndex;
177}
178
179std::string MIRVocabulary::getStringKey(unsigned Pos) const {
180 assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
181
182 // Handle opcodes section
183 if (Pos < Layout.CommonOperandBase) {
184 // Convert canonical index back to base opcode name
185 auto It = UniqueBaseOpcodeNames.begin();
186 std::advance(It, Pos);
187 assert(It != UniqueBaseOpcodeNames.end() &&
188 "Canonical index out of bounds in opcode section");
189 return *It;
190 }
191
192 auto getLocalIndex = [](unsigned Pos, size_t BaseOffset, size_t Bound,
193 const char *Msg) {
194 unsigned LocalIndex = Pos - BaseOffset;
195 assert(LocalIndex < Bound && Msg);
196 return LocalIndex;
197 };
198
199 // Handle common operands section
200 if (Pos < Layout.PhyRegBase) {
201 unsigned LocalIndex = getLocalIndex(
202 Pos, Layout.CommonOperandBase, std::size(CommonOperandNames),
203 "Local index out of bounds in common operands");
204 return CommonOperandNames[LocalIndex].str();
205 }
206
207 // Handle physical registers section
208 if (Pos < Layout.VirtRegBase) {
209 unsigned LocalIndex =
210 getLocalIndex(Pos, Layout.PhyRegBase, RegisterOperandNames.size(),
211 "Local index out of bounds in physical registers");
212 return "PhyReg_" + RegisterOperandNames[LocalIndex];
213 }
214
215 // Handle virtual registers section
216 unsigned LocalIndex =
217 getLocalIndex(Pos, Layout.VirtRegBase, RegisterOperandNames.size(),
218 "Local index out of bounds in virtual registers");
219 return "VirtReg_" + RegisterOperandNames[LocalIndex];
220}
221
222void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
223 const VocabMap &CommonOperandsMap,
224 const VocabMap &PhyRegMap,
225 const VocabMap &VirtRegMap) {
226
227 // Helper for handling missing entities in the vocabulary.
228 // Currently, we use a zero vector. In the future, we will throw an error to
229 // ensure that *all* known entities are present in the vocabulary.
230 auto handleMissingEntity = [](StringRef Key) {
231 LLVM_DEBUG(errs() << "MIR2Vec: Missing vocabulary entry for " << Key
232 << "; using zero vector. This will result in an error "
233 "in the future.\n");
234 ++MIRVocabMissCounter;
235 };
236
237 // Initialize opcode embeddings section
238 unsigned EmbeddingDim = OpcodeMap.begin()->second.size();
239 std::vector<Embedding> OpcodeEmbeddings(Layout.CommonOperandBase,
240 Embedding(EmbeddingDim));
241
242 // Populate opcode embeddings using canonical mapping
243 for (auto COpcodeName : UniqueBaseOpcodeNames) {
244 if (auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) {
245 auto COpcodeIndex = getCanonicalIndexForBaseName(COpcodeName);
246 assert(COpcodeIndex < Layout.CommonOperandBase &&
247 "Canonical index out of bounds");
248 OpcodeEmbeddings[COpcodeIndex] = It->second;
249 } else {
250 handleMissingEntity(COpcodeName);
251 }
252 }
253
254 // Initialize common operand embeddings section
255 std::vector<Embedding> CommonOperandEmbeddings(std::size(CommonOperandNames),
256 Embedding(EmbeddingDim));
257 unsigned OperandIndex = 0;
258 for (const auto &CommonOperandName : CommonOperandNames) {
259 if (auto It = CommonOperandsMap.find(CommonOperandName.str());
260 It != CommonOperandsMap.end()) {
261 CommonOperandEmbeddings[OperandIndex] = It->second;
262 } else {
263 handleMissingEntity(CommonOperandName);
264 }
265 ++OperandIndex;
266 }
267
268 // Helper lambda for creating register operand embeddings
269 auto createRegisterEmbeddings = [&](const VocabMap &RegMap) {
270 std::vector<Embedding> RegEmbeddings(TRI.getNumRegClasses(),
271 Embedding(EmbeddingDim));
272 unsigned RegOperandIndex = 0;
273 for (const auto &RegOperandName : RegisterOperandNames) {
274 if (auto It = RegMap.find(RegOperandName); It != RegMap.end())
275 RegEmbeddings[RegOperandIndex] = It->second;
276 else
277 handleMissingEntity(RegOperandName);
278 ++RegOperandIndex;
279 }
280 return RegEmbeddings;
281 };
282
283 // Initialize register operand embeddings sections
284 std::vector<Embedding> PhyRegEmbeddings = createRegisterEmbeddings(PhyRegMap);
285 std::vector<Embedding> VirtRegEmbeddings =
286 createRegisterEmbeddings(VirtRegMap);
287
288 // Scale the vocabulary sections based on the provided weights
289 auto scaleVocabSection = [](std::vector<Embedding> &Embeddings,
290 double Weight) {
291 for (auto &Embedding : Embeddings)
292 Embedding *= Weight;
293 };
294 scaleVocabSection(OpcodeEmbeddings, OpcWeight);
295 scaleVocabSection(CommonOperandEmbeddings, CommonOperandWeight);
296 scaleVocabSection(PhyRegEmbeddings, RegOperandWeight);
297 scaleVocabSection(VirtRegEmbeddings, RegOperandWeight);
298
299 std::vector<std::vector<Embedding>> Sections(
300 static_cast<unsigned>(Section::MaxSections));
301 Sections[static_cast<unsigned>(Section::Opcodes)] =
302 std::move(OpcodeEmbeddings);
303 Sections[static_cast<unsigned>(Section::CommonOperands)] =
304 std::move(CommonOperandEmbeddings);
305 Sections[static_cast<unsigned>(Section::PhyRegisters)] =
306 std::move(PhyRegEmbeddings);
307 Sections[static_cast<unsigned>(Section::VirtRegisters)] =
308 std::move(VirtRegEmbeddings);
309
310 Storage = ir2vec::VocabStorage(std::move(Sections));
311}
312
313void MIRVocabulary::buildCanonicalOpcodeMapping() {
314 // Check if already built
315 if (!UniqueBaseOpcodeNames.empty())
316 return;
317
318 // Build mapping from opcodes to canonical base opcode indices
319 for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
320 std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
321 UniqueBaseOpcodeNames.insert(BaseOpcode);
322 }
323
324 LLVM_DEBUG(dbgs() << "MIR2Vec: Built canonical mapping for target with "
325 << UniqueBaseOpcodeNames.size()
326 << " unique base opcodes\n");
327}
328
329void MIRVocabulary::buildRegisterOperandMapping() {
330 // Check if already built
331 if (!RegisterOperandNames.empty())
332 return;
333
334 for (unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) {
335 const TargetRegisterClass *RegClass = TRI.getRegClass(RC);
336 if (!RegClass)
337 continue;
338
339 // Get the register class name
340 StringRef ClassName = TRI.getRegClassName(RegClass);
341 RegisterOperandNames.push_back(ClassName.str());
342 }
343}
344
345unsigned MIRVocabulary::getCommonOperandIndex(
346 MachineOperand::MachineOperandType OperandType) const {
347 assert(OperandType != MachineOperand::MO_Register &&
348 "Expected non-register operand type");
349 assert(OperandType > MachineOperand::MO_Register &&
350 OperandType < MachineOperand::MO_Last && "Operand type out of bounds");
351 return static_cast<unsigned>(OperandType) - 1;
352}
353
354unsigned MIRVocabulary::getRegisterOperandIndex(Register Reg) const {
355 assert(!RegisterOperandNames.empty() && "Register operand mapping not built");
356 assert(Reg.isValid() && "Invalid register; not expected here");
357 assert((Reg.isPhysical() || Reg.isVirtual()) &&
358 "Expected a physical or virtual register");
359
360 const TargetRegisterClass *RegClass = nullptr;
361
362 // For physical registers, use TRI to get minimal register class as a
363 // physical register can belong to multiple classes. For virtual
364 // registers, use MRI to uniquely identify the assigned register class.
365 if (Reg.isPhysical())
366 RegClass = TRI.getMinimalPhysRegClass(Reg);
367 else
368 RegClass = MRI.getRegClass(Reg);
369
370 if (RegClass)
371 return RegClass->getID();
372 // Fallback for registers without a class (shouldn't happen)
373 llvm_unreachable("Register operand without a valid register class");
374 return 0;
375}
376
378 const TargetInstrInfo &TII, const TargetRegisterInfo &TRI,
379 const MachineRegisterInfo &MRI, unsigned Dim) {
380 assert(Dim > 0 && "Dimension must be greater than zero");
381
382 float DummyVal = 0.1f;
383
384 VocabMap DummyOpcMap, DummyOperandMap, DummyPhyRegMap, DummyVirtRegMap;
385
386 // Process opcodes directly without creating temporary vocabulary
387 for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
388 std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
389 if (DummyOpcMap.count(BaseOpcode) == 0) { // Only add if not already present
390 DummyOpcMap[BaseOpcode] = Embedding(Dim, DummyVal);
391 DummyVal += 0.1f;
392 }
393 }
394
395 // Add common operands
396 for (const auto &CommonOperandName : CommonOperandNames) {
397 DummyOperandMap[CommonOperandName.str()] = Embedding(Dim, DummyVal);
398 DummyVal += 0.1f;
399 }
400
401 // Process register classes directly
402 for (unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) {
403 const TargetRegisterClass *RegClass = TRI.getRegClass(RC);
404 if (!RegClass)
405 continue;
406
407 std::string ClassName = TRI.getRegClassName(RegClass);
408 DummyPhyRegMap[ClassName] = Embedding(Dim, DummyVal);
409 DummyVirtRegMap[ClassName] = Embedding(Dim, DummyVal);
410 DummyVal += 0.1f;
411 }
412
413 // Create vocabulary directly without temporary instance
415 std::move(DummyOpcMap), std::move(DummyOperandMap),
416 std::move(DummyPhyRegMap), std::move(DummyVirtRegMap), TII, TRI, MRI);
417}
418
419//===----------------------------------------------------------------------===//
420// MIR2VecVocabProvider and MIR2VecVocabLegacyAnalysis
421//===----------------------------------------------------------------------===//
422
425 VocabMap OpcVocab, CommonOperandVocab, PhyRegVocabMap, VirtRegVocabMap;
426
427 if (Error Err = readVocabulary(OpcVocab, CommonOperandVocab, PhyRegVocabMap,
428 VirtRegVocabMap))
429 return std::move(Err);
430
431 for (const auto &F : M) {
432 if (F.isDeclaration())
433 continue;
434
435 if (auto *MF = MMI.getMachineFunction(F)) {
436 auto &Subtarget = MF->getSubtarget();
437 if (const auto *TII = Subtarget.getInstrInfo())
438 if (const auto *TRI = Subtarget.getRegisterInfo())
440 std::move(OpcVocab), std::move(CommonOperandVocab),
441 std::move(PhyRegVocabMap), std::move(VirtRegVocabMap), *TII, *TRI,
442 MF->getRegInfo());
443 }
444 }
446 "No machine functions found in module");
447}
448
449Error MIR2VecVocabProvider::readVocabulary(VocabMap &OpcodeVocab,
450 VocabMap &CommonOperandVocab,
451 VocabMap &PhyRegVocabMap,
452 VocabMap &VirtRegVocabMap) {
453 if (VocabFile.empty())
454 return createStringError(
456 "MIR2Vec vocabulary file path not specified; set it "
457 "using --mir2vec-vocab-path");
458
459 auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
460 if (!BufOrError)
461 return createFileError(VocabFile, BufOrError.getError());
462
463 auto Content = BufOrError.get()->getBuffer();
464
465 Expected<json::Value> ParsedVocabValue = json::parse(Content);
466 if (!ParsedVocabValue)
467 return ParsedVocabValue.takeError();
468
469 unsigned OpcodeDim = 0, CommonOperandDim = 0, PhyRegOperandDim = 0,
470 VirtRegOperandDim = 0;
472 "Opcodes", *ParsedVocabValue, OpcodeVocab, OpcodeDim))
473 return Err;
474
476 "CommonOperands", *ParsedVocabValue, CommonOperandVocab,
477 CommonOperandDim))
478 return Err;
479
481 "PhysicalRegisters", *ParsedVocabValue, PhyRegVocabMap,
482 PhyRegOperandDim))
483 return Err;
484
486 "VirtualRegisters", *ParsedVocabValue, VirtRegVocabMap,
487 VirtRegOperandDim))
488 return Err;
489
490 // All sections must have the same embedding dimension
491 if (!(OpcodeDim == CommonOperandDim && CommonOperandDim == PhyRegOperandDim &&
492 PhyRegOperandDim == VirtRegOperandDim)) {
493 return createStringError(
495 "MIR2Vec vocabulary sections have different dimensions");
496 }
497
498 return Error::success();
499}
500
503 "MIR2Vec Vocabulary Analysis", false, true)
506 "MIR2Vec Vocabulary Analysis", false, true)
507
508StringRef MIR2VecVocabLegacyAnalysis::getPassName() const {
509 return "MIR2Vec Vocabulary Analysis";
510}
511
512//===----------------------------------------------------------------------===//
513// MIREmbedder and its subclasses
514//===----------------------------------------------------------------------===//
515
516std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
517 const MachineFunction &MF,
518 const MIRVocabulary &Vocab) {
519 switch (Mode) {
521 return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
522 }
523 return nullptr;
524}
525
528
529 // Get instruction info for opcode name resolution
530 const auto &Subtarget = MF.getSubtarget();
531 const auto *TII = Subtarget.getInstrInfo();
532 if (!TII) {
533 MF.getFunction().getContext().emitError(
534 "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
535 return MBBVector;
536 }
537
538 // Process each machine instruction in the basic block
539 for (const auto &MI : MBB) {
540 // Skip debug instructions and other metadata
541 if (MI.isDebugInstr())
542 continue;
544 }
545
546 return MBBVector;
547}
548
550 Embedding MFuncVector(Dimension, 0);
551
552 // Consider all reachable machine basic blocks in the function
553 for (const auto *MBB : depth_first(&MF))
554 MFuncVector += computeEmbeddings(*MBB);
555 return MFuncVector;
556}
557
561
562std::unique_ptr<SymbolicMIREmbedder>
564 const MIRVocabulary &Vocab) {
565 return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
566}
567
569 // Skip debug instructions and other metadata
570 if (MI.isDebugInstr())
571 return Embedding(Dimension, 0);
572
573 // Opcode embedding
574 Embedding InstructionEmbedding = Vocab[MI.getOpcode()];
575
576 // Add operand contributions
577 for (const MachineOperand &MO : MI.operands())
578 InstructionEmbedding += Vocab[MO];
579
580 return InstructionEmbedding;
581}
582
583//===----------------------------------------------------------------------===//
584// Printer Passes
585//===----------------------------------------------------------------------===//
586
589 "MIR2Vec Vocabulary Printer Pass", false, true)
593 "MIR2Vec Vocabulary Printer Pass", false, true)
594
598
601 auto MIR2VecVocabOrErr = Analysis.getMIR2VecVocabulary(M);
602
603 if (!MIR2VecVocabOrErr) {
604 OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
605 << toString(MIR2VecVocabOrErr.takeError()) << "\n";
606 return false;
607 }
608
609 auto &MIR2VecVocab = *MIR2VecVocabOrErr;
610 unsigned Pos = 0;
611 for (const auto &Entry : MIR2VecVocab) {
612 OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": ";
613 Entry.print(OS);
614 }
615
616 return false;
617}
618
623
626 "MIR2Vec Embedder Printer Pass", false, true)
630 "MIR2Vec Embedder Printer Pass", false, true)
631
634 auto VocabOrErr =
635 Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
636 assert(VocabOrErr && "Failed to get MIR2Vec vocabulary");
637 auto &MIRVocab = *VocabOrErr;
638
639 auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab);
640 if (!Emb) {
641 OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
642 << "\n";
643 return false;
644 }
645
646 OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
647 OS << "Machine Function vector: ";
648 Emb->getMFunctionVector().print(OS);
649
650 OS << "Machine basic block vectors:\n";
651 for (const MachineBasicBlock &MBB : MF) {
652 OS << "Machine basic block: " << MBB.getFullName() << ":\n";
653 Emb->getMBBVector(MBB).print(OS);
654 }
655
656 OS << "Machine instruction vectors:\n";
657 for (const MachineBasicBlock &MBB : MF) {
658 for (const MachineInstr &MI : MBB) {
659 // Skip debug instructions as they are not
660 // embedded
661 if (MI.isDebugInstr())
662 continue;
663
664 OS << "Machine instruction: ";
665 MI.print(OS);
666 Emb->getMInstVector(MI).print(OS);
667 }
668 }
669
670 return false;
671}
672
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
block Block Frequency Analysis
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
Module.h This file contains the declarations for the Module class.
#define RegName(no)
#define F(x, y, z)
Definition MD5.cpp:55
This file defines the MIR2Vec framework for generating Machine IR embeddings.
Register Reg
Register const TargetRegisterInfo * TRI
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
SmallVector< MachineBasicBlock *, 4 > MBBVector
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
Tagged union holding either a T or a Error.
Definition Error.h:485
Error takeError()
Take ownership of the stored error.
Definition Error.h:612
This pass prints the MIR2Vec embeddings for machine functions, basic blocks, and instructions.
Definition MIR2Vec.h:436
MIR2VecPrinterLegacyPass(raw_ostream &OS)
Definition MIR2Vec.h:441
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
Definition MIR2Vec.cpp:632
Pass to analyze and populate MIR2Vec vocabulary from a module.
Definition MIR2Vec.h:381
This pass prints the embeddings in the MIR2Vec vocabulary.
Definition MIR2Vec.h:413
bool doFinalization(Module &M) override
doFinalization - Virtual method overridden by subclasses to do any necessary clean up after all passes...
Definition MIR2Vec.cpp:599
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
Definition MIR2Vec.cpp:595
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
Definition MIR2Vec.h:418
Expected< mir2vec::MIRVocabulary > getVocabulary(const Module &M)
Definition MIR2Vec.cpp:424
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
Representation of each machine instruction.
MachineOperand class - Representation of each machine instruction operand.
@ MO_Register
Register operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
LLVM_ABI bool match(StringRef String, SmallVectorImpl< StringRef > *Matches=nullptr, std::string *Error=nullptr) const
matches - Match the regex against a given String.
Definition Regex.cpp:83
Wrapper class representing virtual and physical registers.
Definition Register.h:19
constexpr bool isValid() const
Definition Register.h:107
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
Definition Register.h:74
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
Definition Register.h:78
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
std::string str() const
str - Get the contents as an std::string.
Definition StringRef.h:225
constexpr bool empty() const
empty - Check if the string is empty.
Definition StringRef.h:143
char back() const
back - Get the last character in the string.
Definition StringRef.h:155
StringRef drop_back(size_t N=1) const
Return a StringRef equal to 'this' but with the last N elements dropped.
Definition StringRef.h:618
TargetInstrInfo - Interface to description of machine instruction set.
unsigned getID() const
Return the register class ID number.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
Generic storage class for section-based vocabularies.
Definition IR2Vec.h:151
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
Definition IR2Vec.cpp:316
unsigned getDimension() const
Get vocabulary dimension.
Definition IR2Vec.h:190
bool isValid() const
Check if vocabulary is valid (has data)
Definition IR2Vec.h:193
const unsigned Dimension
Dimension of the embeddings; Captured from the vocabulary.
Definition MIR2Vec.h:293
const MIRVocabulary & Vocab
Definition MIR2Vec.h:290
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
Definition MIR2Vec.h:298
Embedding computeEmbeddings() const
Function to compute embeddings.
Definition MIR2Vec.cpp:549
const MachineFunction & MF
Definition MIR2Vec.h:289
static std::unique_ptr< MIREmbedder > create(MIR2VecKind Mode, const MachineFunction &MF, const MIRVocabulary &Vocab)
Factory method to create an Embedder object of the specified kind Returns nullptr if the requested ki...
Definition MIR2Vec.cpp:516
Class for storing and accessing the MIR2Vec vocabulary.
Definition MIR2Vec.h:86
unsigned getCanonicalIndexForOperandName(StringRef OperandName) const
Definition MIR2Vec.cpp:159
unsigned getCanonicalIndexForRegisterClass(StringRef RegName, bool IsPhysical=true) const
Definition MIR2Vec.cpp:169
static Expected< MIRVocabulary > create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI)
Factory method to create MIRVocabulary from vocabulary map.
Definition MIR2Vec.cpp:93
static std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
Definition MIR2Vec.cpp:114
static Expected< MIRVocabulary > createDummyVocabForTest(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
Definition MIR2Vec.cpp:377
std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
Definition MIR2Vec.cpp:179
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get indices from opcode or operand names.
Definition MIR2Vec.cpp:144
static std::unique_ptr< SymbolicMIREmbedder > create(const MachineFunction &MF, const MIRVocabulary &Vocab)
Definition MIR2Vec.cpp:563
SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab)
Definition MIR2Vec.cpp:558
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
OperandType
Operands are tagged with one of the values of this enum.
Definition MCInstrDesc.h:60
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
Definition JSON.cpp:675
llvm::cl::OptionCategory MIR2VecCategory
cl::opt< float > OpcWeight
static cl::opt< std::string > VocabFile("mir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""), cl::cat(MIR2VecCategory))
cl::opt< float > RegOperandWeight
Definition MIR2Vec.h:77
ir2vec::Embedding Embedding
Definition MIR2Vec.h:79
cl::opt< float > CommonOperandWeight
Definition MIR2Vec.h:77
cl::opt< MIR2VecKind > MIR2VecEmbeddingKind("mir2vec-kind", cl::Optional, cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic", "Generate symbolic embeddings for MIR")), cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"), cl::cat(MIR2VecCategory))
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Definition Error.h:1399
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
Definition Error.h:1305
@ illegal_byte_sequence
Definition Errc.h:52
@ invalid_argument
Definition Errc.h:56
MachineFunctionPass * createMIR2VecPrinterLegacyPass(raw_ostream &OS)
Create a machine pass that prints MIR2Vec embeddings.
Definition MIR2Vec.cpp:673
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
LLVM_ABI MachineFunctionPass * createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary contents to the given stream a...
Definition MIR2Vec.cpp:620
MIR2VecKind
Definition MIR2Vec.h:68
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
iterator_range< df_iterator< T > > depth_first(const T &G)