LLVM 23.0.0git
IR2Vec.cpp
Go to the documentation of this file.
1//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the IR2Vec algorithm.
11///
12//===----------------------------------------------------------------------===//
13
15
17#include "llvm/ADT/Sequence.h"
19#include "llvm/ADT/Statistic.h"
20#include "llvm/IR/CFG.h"
21#include "llvm/IR/Module.h"
22#include "llvm/IR/PassManager.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/Errc.h"
25#include "llvm/Support/Error.h"
27#include "llvm/Support/Format.h"
29
30using namespace llvm;
31using namespace ir2vec;
32
33#define DEBUG_TYPE "ir2vec"
34
35STATISTIC(VocabMissCounter,
36 "Number of lookups to entities not present in the vocabulary");
37
38namespace llvm {
39namespace ir2vec {
41
42// FIXME: Use a default vocab when not specified
44 VocabFile("ir2vec-vocab-path", cl::Optional,
45 cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
47cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
48 cl::desc("Weight for opcode embeddings"),
50cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
51 cl::desc("Weight for type embeddings"),
53cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
54 cl::desc("Weight for argument embeddings"),
57 "ir2vec-kind", cl::Optional,
59 "Generate symbolic embeddings"),
61 "Generate flow-aware embeddings")),
62 cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
64
65} // namespace ir2vec
66} // namespace llvm
67
69
//===----------------------------------------------------------------------===//
71// Local helper functions
72//===----------------------------------------------------------------------===//
73namespace llvm::json {
74inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
76 std::vector<double> TempOut;
77 if (!llvm::json::fromJSON(E, TempOut, P))
78 return false;
79 Out = Embedding(std::move(TempOut));
80 return true;
81}
82} // namespace llvm::json
83
//===----------------------------------------------------------------------===//
85// Embedding
86//===----------------------------------------------------------------------===//
88 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
89 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
90 std::plus<double>());
91 return *this;
92}
93
95 Embedding Result(*this);
96 Result += RHS;
97 return Result;
98}
99
101 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
102 std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
103 std::minus<double>());
104 return *this;
105}
106
108 Embedding Result(*this);
109 Result -= RHS;
110 return Result;
111}
112
114 std::transform(this->begin(), this->end(), this->begin(),
115 [Factor](double Elem) { return Elem * Factor; });
116 return *this;
117}
118
119Embedding Embedding::operator*(double Factor) const {
120 Embedding Result(*this);
121 Result *= Factor;
122 return Result;
123}
124
125Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
126 assert(this->size() == Src.size() && "Vectors must have the same dimension");
127 for (size_t Itr = 0; Itr < this->size(); ++Itr)
128 (*this)[Itr] += Src[Itr] * Factor;
129 return *this;
130}
131
133 double Tolerance) const {
134 assert(this->size() == RHS.size() && "Vectors must have the same dimension");
135 for (size_t Itr = 0; Itr < this->size(); ++Itr)
136 if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {
137 LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
138 << (*this)[Itr] << " vs " << RHS[Itr]
139 << "; Tolerance: " << Tolerance << "\n");
140 return false;
141 }
142 return true;
143}
144
146 OS << " [";
147 for (const auto &Elem : Data)
148 OS << " " << format("%.2f", Elem) << " ";
149 OS << "]\n";
150}
151
//===----------------------------------------------------------------------===//
153// Embedder and its subclasses
154//===----------------------------------------------------------------------===//
155
156std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
157 const Vocabulary &Vocab) {
158 switch (Mode) {
160 return std::make_unique<SymbolicEmbedder>(F, Vocab);
162 return std::make_unique<FlowAwareEmbedder>(F, Vocab);
163 }
164 return nullptr;
165}
166
168 Embedding FuncVector(Dimension, 0.0);
169
170 if (F.isDeclaration())
171 return FuncVector;
172
173 // Consider only the basic blocks that are reachable from entry
174 for (const BasicBlock *BB : depth_first(&F))
175 FuncVector += computeEmbeddings(*BB);
176 return FuncVector;
177}
178
180 Embedding BBVector(Dimension, 0);
181
182 // We consider only the non-debug and non-pseudo instructions
183 for (const auto &I : BB)
184 if (!I.isDebugOrPseudoInst())
185 BBVector += computeEmbeddings(I);
186 return BBVector;
187}
188
190 // Currently, we always (re)compute the embeddings for symbolic embedder.
191 // This is cheaper than caching the vectors.
192 Embedding ArgEmb(Dimension, 0);
193 for (const auto &Op : I.operands())
194 ArgEmb += Vocab[*Op];
195 auto InstVector =
196 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
197 if (const auto *IC = dyn_cast<CmpInst>(&I))
198 InstVector += Vocab[IC->getPredicate()];
199 return InstVector;
200}
201
203 // If we have already computed the embedding for this instruction, return it
204 auto It = InstVecMap.find(&I);
205 if (It != InstVecMap.end())
206 return It->second;
207
208 // TODO: Handle call instructions differently.
209 // For now, we treat them like other instructions
210 Embedding ArgEmb(Dimension, 0);
211 for (const auto &Op : I.operands()) {
212 // If the operand is defined elsewhere, we use its embedding
213 if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
214 auto DefIt = InstVecMap.find(DefInst);
215 // Fixme (#159171): Ideally we should never miss an instruction
216 // embedding here.
217 // But when we have cyclic dependencies (e.g., phi
218 // nodes), we might miss the embedding. In such cases, we fall back to
219 // using the vocabulary embedding. This can be fixed by iterating to a
220 // fixed-point, or by using a simple solver for the set of simultaneous
221 // equations.
222 // Another case when we might miss an instruction embedding is when
223 // the operand instruction is in a different basic block that has not
224 // been processed yet. This can be fixed by processing the basic blocks
225 // in a topological order.
226 if (DefIt != InstVecMap.end())
227 ArgEmb += DefIt->second;
228 else
229 ArgEmb += Vocab[*Op];
230 }
231 // If the operand is not defined by an instruction, we use the
232 // vocabulary
233 else {
234 LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
235 << *Op << "=" << Vocab[*Op][0] << "\n");
236 ArgEmb += Vocab[*Op];
237 }
238 }
239 // Create the instruction vector by combining opcode, type, and arguments
240 // embeddings
241 auto InstVector =
242 Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
243 if (const auto *IC = dyn_cast<CmpInst>(&I))
244 InstVector += Vocab[IC->getPredicate()];
245 InstVecMap[&I] = InstVector;
246 return InstVector;
247}
248
//===----------------------------------------------------------------------===//
250// VocabStorage
251//===----------------------------------------------------------------------===//
252
253VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData)
254 : Sections(std::move(SectionData)), TotalSize([&] {
255 assert(!Sections.empty() && "Vocabulary has no sections");
256 // Compute total size across all sections
257 size_t Size = 0;
258 for (const auto &Section : Sections) {
259 assert(!Section.empty() && "Vocabulary section is empty");
260 Size += Section.size();
261 }
262 return Size;
263 }()),
264 Dimension([&] {
265 // Get dimension from the first embedding in the first section - all
266 // embeddings must have the same dimension
267 assert(!Sections.empty() && "Vocabulary has no sections");
268 assert(!Sections[0].empty() && "First section of vocabulary is empty");
269 unsigned ExpectedDim = static_cast<unsigned>(Sections[0][0].size());
270
271 // Verify that all embeddings across all sections have the same
272 // dimension
273 [[maybe_unused]] auto allSameDim =
274 [ExpectedDim](const std::vector<Embedding> &Section) {
275 return std::all_of(Section.begin(), Section.end(),
276 [ExpectedDim](const Embedding &Emb) {
277 return Emb.size() == ExpectedDim;
278 });
279 };
280 assert(std::all_of(Sections.begin(), Sections.end(), allSameDim) &&
281 "All embeddings must have the same dimension");
282
283 return ExpectedDim;
284 }()) {}
285
287 assert(SectionId < Storage->Sections.size() && "Invalid section ID");
288 assert(LocalIndex < Storage->Sections[SectionId].size() &&
289 "Local index out of range");
290 return Storage->Sections[SectionId][LocalIndex];
291}
292
294 ++LocalIndex;
295 // Check if we need to move to the next section
296 if (SectionId < Storage->getNumSections() &&
297 LocalIndex >= Storage->Sections[SectionId].size()) {
298 assert(LocalIndex == Storage->Sections[SectionId].size() &&
299 "Local index should be at the end of the current section");
300 LocalIndex = 0;
301 ++SectionId;
302 }
303 return *this;
304}
305
307 const const_iterator &Other) const {
308 return Storage == Other.Storage && SectionId == Other.SectionId &&
309 LocalIndex == Other.LocalIndex;
310}
311
313 const const_iterator &Other) const {
314 return !(*this == Other);
315}
316
318 const json::Value &ParsedVocabValue,
319 VocabMap &TargetVocab, unsigned &Dim) {
320 json::Path::Root Path("");
321 const json::Object *RootObj = ParsedVocabValue.getAsObject();
322 if (!RootObj)
324 "JSON root is not an object");
325
326 const json::Value *SectionValue = RootObj->get(Key);
327 if (!SectionValue)
329 "Missing '" + std::string(Key) +
330 "' section in vocabulary file");
331 if (!json::fromJSON(*SectionValue, TargetVocab, Path))
333 "Unable to parse '" + std::string(Key) +
334 "' section from vocabulary");
335
336 Dim = TargetVocab.begin()->second.size();
337 if (Dim == 0)
339 "Dimension of '" + std::string(Key) +
340 "' section of the vocabulary is zero");
341
342 if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
343 [Dim](const std::pair<StringRef, Embedding> &Entry) {
344 return Entry.second.size() == Dim;
345 }))
346 return createStringError(
348 "All vectors in the '" + std::string(Key) +
349 "' section of the vocabulary are not of the same dimension");
350
351 return Error::success();
352}
353
//===----------------------------------------------------------------------===//
355// Vocabulary
356//===----------------------------------------------------------------------===//
357
359 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
360#define HANDLE_INST(NUM, OPCODE, CLASS) \
361 if (Opcode == NUM) { \
362 return #OPCODE; \
363 }
364#include "llvm/IR/Instruction.def"
365#undef HANDLE_INST
366 return "UnknownOpcode";
367}
368
369// Helper function to classify an operand into OperandKind
379
380unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) {
383 else
386}
387
388CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) {
389 unsigned fcmpRange =
391 if (LocalIndex < fcmpRange)
393 LocalIndex);
394 else
396 LocalIndex - fcmpRange);
397}
398
400 static SmallString<16> PredNameBuffer;
402 PredNameBuffer = "FCMP_";
403 else
404 PredNameBuffer = "ICMP_";
405 PredNameBuffer += CmpInst::getPredicateName(Pred);
406 return PredNameBuffer;
407}
408
410 assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
411 // Opcode
412 if (Pos < MaxOpcodes)
413 return getVocabKeyForOpcode(Pos + 1);
414 // Type
415 if (Pos < OperandBaseOffset)
416 return getVocabKeyForCanonicalTypeID(
417 static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
418 // Operand
419 if (Pos < PredicateBaseOffset)
421 static_cast<OperandKind>(Pos - OperandBaseOffset));
422 // Predicates
423 return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
424}
425
426// For now, assume vocabulary is stable unless explicitly invalidated.
428 ModuleAnalysisManager::Invalidator &Inv) const {
429 auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
430 return !(PAC.preservedWhenStateless());
431}
432
434 float DummyVal = 0.1f;
435
436 // Create sections for opcodes, types, operands, and predicates
437 // Order must match Vocabulary::Section enum
438 std::vector<std::vector<Embedding>> Sections;
439 Sections.reserve(4);
440
441 // Opcodes section
442 std::vector<Embedding> OpcodeSec;
443 OpcodeSec.reserve(MaxOpcodes);
444 for (unsigned I = 0; I < MaxOpcodes; ++I) {
445 OpcodeSec.emplace_back(Dim, DummyVal);
446 DummyVal += 0.1f;
447 }
448 Sections.push_back(std::move(OpcodeSec));
449
450 // Types section
451 std::vector<Embedding> TypeSec;
452 TypeSec.reserve(MaxCanonicalTypeIDs);
453 for (unsigned I = 0; I < MaxCanonicalTypeIDs; ++I) {
454 TypeSec.emplace_back(Dim, DummyVal);
455 DummyVal += 0.1f;
456 }
457 Sections.push_back(std::move(TypeSec));
458
459 // Operands section
460 std::vector<Embedding> OperandSec;
461 OperandSec.reserve(MaxOperandKinds);
462 for (unsigned I = 0; I < MaxOperandKinds; ++I) {
463 OperandSec.emplace_back(Dim, DummyVal);
464 DummyVal += 0.1f;
465 }
466 Sections.push_back(std::move(OperandSec));
467
468 // Predicates section
469 std::vector<Embedding> PredicateSec;
470 PredicateSec.reserve(MaxPredicateKinds);
471 for (unsigned I = 0; I < MaxPredicateKinds; ++I) {
472 PredicateSec.emplace_back(Dim, DummyVal);
473 DummyVal += 0.1f;
474 }
475 Sections.push_back(std::move(PredicateSec));
476
477 return VocabStorage(std::move(Sections));
478}
479
480namespace {
481using VocabMap = std::map<std::string, Embedding>;
482
483/// Read vocabulary JSON file and populate the section maps.
484Error readVocabularyFromFile(StringRef VocabFilePath, VocabMap &OpcVocab,
485 VocabMap &TypeVocab, VocabMap &ArgVocab) {
486 auto BufOrError =
487 MemoryBuffer::getFileOrSTDIN(VocabFilePath, /*IsText=*/true);
488 if (!BufOrError)
489 return createFileError(VocabFilePath, BufOrError.getError());
490
491 auto Content = BufOrError.get()->getBuffer();
492
493 Expected<json::Value> ParsedVocabValue = json::parse(Content);
494 if (!ParsedVocabValue)
495 return ParsedVocabValue.takeError();
496
497 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
498 if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue,
499 OpcVocab, OpcodeDim))
500 return Err;
501
502 if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue,
503 TypeVocab, TypeDim))
504 return Err;
505
506 if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue,
507 ArgVocab, ArgDim))
508 return Err;
509
510 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
512 "Vocabulary sections have different dimensions");
513
514 return Error::success();
515}
516} // anonymous namespace
517
518/// Generate VocabStorage from vocabulary maps.
519VocabStorage Vocabulary::buildVocabStorage(const VocabMap &OpcVocab,
520 const VocabMap &TypeVocab,
521 const VocabMap &ArgVocab) {
522
523 // Helper for handling missing entities in the vocabulary.
524 // Currently, we use a zero vector. In the future, we will throw an error to
525 // ensure that *all* known entities are present in the vocabulary.
526 auto handleMissingEntity = [](const std::string &Val) {
527 LLVM_DEBUG(errs() << Val
528 << " is not in vocabulary, using zero vector; This "
529 "would result in an error in future.\n");
530 ++VocabMissCounter;
531 };
532
533 unsigned Dim = OpcVocab.begin()->second.size();
534 assert(Dim > 0 && "Vocabulary dimension must be greater than zero");
535
536 // Handle Opcodes
537 std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
538 Embedding(Dim));
539 for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
540 StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
541 auto It = OpcVocab.find(VocabKey.str());
542 if (It != OpcVocab.end())
543 NumericOpcodeEmbeddings[Opcode] = It->second;
544 else
545 handleMissingEntity(VocabKey.str());
546 }
547
548 // Handle Types - only canonical types are present in vocabulary
549 std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
550 Embedding(Dim));
551 for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
552 StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
553 static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
554 if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
555 NumericTypeEmbeddings[CTypeID] = It->second;
556 continue;
557 }
558 handleMissingEntity(VocabKey.str());
559 }
560
561 // Handle Arguments/Operands
562 std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
563 Embedding(Dim));
564 for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
566 StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
567 auto It = ArgVocab.find(VocabKey.str());
568 if (It != ArgVocab.end()) {
569 NumericArgEmbeddings[OpKind] = It->second;
570 continue;
571 }
572 handleMissingEntity(VocabKey.str());
573 }
574
575 // Handle Predicates: part of Operands section. We look up predicate keys
576 // in ArgVocab.
577 std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
578 Embedding(Dim, 0));
579 for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
580 StringRef VocabKey =
581 Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
582 auto It = ArgVocab.find(VocabKey.str());
583 if (It != ArgVocab.end()) {
584 NumericPredEmbeddings[PK] = It->second;
585 continue;
586 }
587 handleMissingEntity(VocabKey.str());
588 }
589
590 // Create section-based storage instead of flat vocabulary
591 // Order must match Vocabulary::Section enum
592 std::vector<std::vector<Embedding>> Sections(4);
593 Sections[static_cast<unsigned>(Section::Opcodes)] =
594 std::move(NumericOpcodeEmbeddings); // Section::Opcodes
595 Sections[static_cast<unsigned>(Section::CanonicalTypes)] =
596 std::move(NumericTypeEmbeddings); // Section::CanonicalTypes
597 Sections[static_cast<unsigned>(Section::Operands)] =
598 std::move(NumericArgEmbeddings); // Section::Operands
599 Sections[static_cast<unsigned>(Section::Predicates)] =
600 std::move(NumericPredEmbeddings); // Section::Predicates
601
602 // Create VocabStorage from organized sections
603 return VocabStorage(std::move(Sections));
604}
605
//===----------------------------------------------------------------------===//
607// Vocabulary
608//===----------------------------------------------------------------------===//
609
611 float OpcWeight, float TypeWeight,
612 float ArgWeight) {
613 VocabMap OpcVocab, TypeVocab, ArgVocab;
614 if (auto Err =
615 readVocabularyFromFile(VocabFilePath, OpcVocab, TypeVocab, ArgVocab))
616 return std::move(Err);
617
618 // Scale the vocabulary sections based on the provided weights
619 auto scaleVocabSection = [](VocabMap &Vocab, float Weight) {
620 for (auto &Entry : Vocab)
621 Entry.second *= Weight;
622 };
623 scaleVocabSection(OpcVocab, OpcWeight);
624 scaleVocabSection(TypeVocab, TypeWeight);
625 scaleVocabSection(ArgVocab, ArgWeight);
626
627 // Generate the numeric lookup vocabulary
628 return Vocabulary(buildVocabStorage(OpcVocab, TypeVocab, ArgVocab));
629}
630
//===----------------------------------------------------------------------===//
632// IR2VecVocabAnalysis
633//===----------------------------------------------------------------------===//
634
635void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
636 handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
637 Ctx.emitError("Error reading vocabulary: " + EI.message());
638 });
639}
640
643 auto Ctx = &M.getContext();
644 // If vocabulary is already populated by the constructor, use it.
645 if (Vocab.has_value())
646 return Vocabulary(std::move(Vocab.value()));
647
648 // Otherwise, try to read from the vocabulary file specified via CLI.
649 if (VocabFile.empty()) {
650 // FIXME: Use default vocabulary
651 Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "
652 "set it using --ir2vec-vocab-path");
653 return Vocabulary(); // Return invalid result
654 }
655
656 // Use the static factory method to load the vocabulary.
657 auto VocabOrErr =
659 if (!VocabOrErr) {
660 emitError(VocabOrErr.takeError(), *Ctx);
661 return Vocabulary();
662 }
663
664 return std::move(*VocabOrErr);
665}
666
//===----------------------------------------------------------------------===//
668// Printer Passes
669//===----------------------------------------------------------------------===//
670
673 auto &Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
674 assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
675
676 for (Function &F : M) {
678 if (!Emb) {
679 OS << "Error creating IR2Vec embeddings \n";
680 continue;
681 }
682
683 OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
684 OS << "Function vector: ";
685 Emb->getFunctionVector().print(OS);
686
687 OS << "Basic block vectors:\n";
688 for (const BasicBlock &BB : F) {
689 OS << "Basic block: " << BB.getName() << ":\n";
690 Emb->getBBVector(BB).print(OS);
691 }
692
693 OS << "Instruction vectors:\n";
694 for (const BasicBlock &BB : F) {
695 for (const Instruction &I : BB) {
696 OS << "Instruction: ";
697 I.print(OS);
698 Emb->getInstVector(I).print(OS);
699 }
700 }
701 }
702 return PreservedAnalyses::all();
703}
704
707 auto &IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
708 assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");
709
710 // Print each entry
711 unsigned Pos = 0;
712 for (const auto &Entry : IR2VecVocabulary) {
713 OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": ";
714 Entry.print(OS);
715 }
716 return PreservedAnalyses::all();
717}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define P(N)
ModuleAnalysisManager MAM
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:114
LLVM Basic Block Representation.
Definition BasicBlock.h:62
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
static LLVM_ABI StringRef getPredicateName(Predicate P)
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
iterator end()
Definition DenseMap.h:81
Base class for error info classes.
Definition Error.h:44
virtual std::string message() const
Return the error message as a string.
Definition Error.h:52
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
Tagged union holding either a T or a Error.
Definition Error.h:485
Error takeError()
Take ownership of the stored error.
Definition Error.h:612
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:671
This analysis provides the vocabulary for IR2Vec.
Definition IR2Vec.h:639
ir2vec::Vocabulary Result
Definition IR2Vec.h:649
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:642
static LLVM_ABI AnalysisKey Key
Definition IR2Vec.h:645
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:705
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
Definition SmallString.h:26
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
std::string str() const
str - Get the contents as an std::string.
Definition StringRef.h:222
LLVM Value Representation.
Definition Value.h:75
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
Definition IR2Vec.cpp:156
const Vocabulary & Vocab
Definition IR2Vec.h:552
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
Definition IR2Vec.h:555
Embedding computeEmbeddings() const
Function to compute embeddings.
Definition IR2Vec.cpp:167
const Function & F
Definition IR2Vec.h:551
Iterator support for section-based access.
Definition IR2Vec.h:202
const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)
Definition IR2Vec.h:208
LLVM_ABI bool operator!=(const const_iterator &Other) const
Definition IR2Vec.cpp:312
LLVM_ABI const_iterator & operator++()
Definition IR2Vec.cpp:293
LLVM_ABI const Embedding & operator*() const
Definition IR2Vec.cpp:286
LLVM_ABI bool operator==(const const_iterator &Other) const
Definition IR2Vec.cpp:306
Generic storage class for section-based vocabularies.
Definition IR2Vec.h:157
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
Definition IR2Vec.cpp:317
unsigned getNumSections() const
Get number of sections.
Definition IR2Vec.h:185
size_t size() const
Get total number of entries across all sections.
Definition IR2Vec.h:182
VocabStorage()=default
Default constructor creates empty storage (invalid state)
const_iterator begin() const
Definition IR2Vec.h:218
std::map< std::string, Embedding > VocabMap
Definition IR2Vec.h:223
Class for storing and accessing the IR2Vec vocabulary.
Definition IR2Vec.h:248
static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
Definition IR2Vec.h:369
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
Definition IR2Vec.cpp:427
static LLVM_ABI Expected< Vocabulary > fromFile(StringRef VocabFilePath, float OpcWeight=1.0, float TypeWeight=0.5, float ArgWeight=0.2)
Create a Vocabulary by loading embeddings from a JSON file.
Definition IR2Vec.cpp:610
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
Definition IR2Vec.cpp:370
friend class llvm::IR2VecVocabAnalysis
Definition IR2Vec.h:249
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
Definition IR2Vec.cpp:409
static constexpr unsigned MaxCanonicalTypeIDs
Definition IR2Vec.h:319
static constexpr unsigned MaxOperandKinds
Definition IR2Vec.h:321
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
Definition IR2Vec.h:305
static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)
Function to get vocabulary key for a given predicate.
Definition IR2Vec.cpp:399
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
Definition IR2Vec.cpp:358
LLVM_ABI bool isValid() const
Definition IR2Vec.h:347
static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
Definition IR2Vec.cpp:433
static constexpr unsigned MaxPredicateKinds
Definition IR2Vec.h:325
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
Definition IR2Vec.h:287
An Object is a JSON object, which maps strings to heterogeneous JSON values.
Definition JSON.h:98
LLVM_ABI Value * get(StringRef K)
Definition JSON.cpp:30
The root is the trivial Path to the root value.
Definition JSON.h:712
A "cursor" marking a position within a Value.
Definition JSON.h:665
A Value is a JSON value of unknown type.
Definition JSON.h:291
const json::Object * getAsObject() const
Definition JSON.h:465
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
LLVM_ABI cl::opt< float > ArgWeight
LLVM_ABI cl::opt< std::string > VocabFile
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
LLVM_ABI llvm::cl::OptionCategory IR2VecCategory
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
Definition JSON.cpp:681
bool fromJSON(const Value &E, std::string &Out, Path P)
Definition JSON.h:741
ir2vec::Embedding Embedding
Definition MIR2Vec.h:79
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
Definition Error.h:1399
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
void handleAllErrors(Error E, HandlerTs &&... Handlers)
Behaves the same as handleErrors, except that by contract all errors must be handled by the given han...
Definition Error.h:990
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
Definition Error.h:1305
@ illegal_byte_sequence
Definition Errc.h:52
@ invalid_argument
Definition Errc.h:56
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
Definition IR2Vec.h:71
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
Definition Format.h:129
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
@ Other
Any other memory.
Definition ModRef.h:68
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1917
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
Embedding is a datatype that wraps std::vector<double>.
Definition IR2Vec.h:88
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...
Definition IR2Vec.cpp:132
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
Definition IR2Vec.cpp:87
LLVM_ABI Embedding operator-(const Embedding &RHS) const
Definition IR2Vec.cpp:107
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
Definition IR2Vec.cpp:100
LLVM_ABI Embedding operator*(double Factor) const
Definition IR2Vec.cpp:119
size_t size() const
Definition IR2Vec.h:101
LLVM_ABI Embedding & operator*=(double Factor)
Definition IR2Vec.cpp:113
LLVM_ABI Embedding operator+(const Embedding &RHS) const
Definition IR2Vec.cpp:94
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds Src Embedding scaled by Factor with the called Embedding.
Definition IR2Vec.cpp:125
LLVM_ABI void print(raw_ostream &OS) const
Definition IR2Vec.cpp:145