LLVM 23.0.0git
IR2Vec.h
Go to the documentation of this file.
1//===- IR2Vec.h - Implementation of IR2Vec ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See the LICENSE file for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
/// This file defines the IR2Vec vocabulary analysis (IR2VecVocabAnalysis),
/// the core ir2vec::Embedder interface for generating IR embeddings,
/// and related utilities such as IR2VecPrinterPass.
13///
/// Program embeddings are learned representations of a program, or are
/// derived from such representations. They are used to present programs as
/// input to machine learning algorithms. IR2Vec represents LLVM IR as
/// embeddings.
18///
19/// The IR2Vec algorithm is described in the following paper:
20///
21/// IR2Vec: LLVM IR Based Scalable Program Embeddings, S. VenkataKeerthy,
22/// Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar, Ramakrishna
23/// Upadrasta, and Y. N. Srikant, ACM Transactions on Architecture and
24/// Code Optimization (TACO), 2020. https://doi.org/10.1145/3418463.
25/// https://arxiv.org/abs/1909.06228
26///
27/// To obtain embeddings:
28/// First run IR2VecVocabAnalysis to populate the vocabulary.
29/// Then, use the Embedder interface to generate embeddings for the desired IR
30/// entities. See the documentation for more details -
31/// https://llvm.org/docs/MLGO.html#ir2vec-embeddings
32///
33//===----------------------------------------------------------------------===//
34
35#ifndef LLVM_ANALYSIS_IR2VEC_H
36#define LLVM_ANALYSIS_IR2VEC_H
37
38#include "llvm/ADT/DenseMap.h"
40#include "llvm/IR/PassManager.h"
41#include "llvm/IR/Type.h"
45#include "llvm/Support/JSON.h"
46#include <array>
47#include <map>
48#include <optional>
49
50namespace llvm {
51
52class Module;
53class BasicBlock;
54class Instruction;
55class Function;
56class Value;
57class raw_ostream;
58class LLVMContext;
60
61/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
62/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
63/// of the IR entities. Flow-aware embeddings build on top of symbolic
64/// embeddings and additionally capture the flow information in the IR.
65/// IR2VecKind is used to specify the type of embeddings to generate.
/// Note: The implementation of FlowAware embeddings is not the same as the one
/// described in the paper. The current implementation is a simplified version
/// that captures flow information through SSA-based use-def chains only;
/// unlike the embedding computation described in the paper, it does not trace
/// through memory-level use-defs.
72
73namespace ir2vec {
74
81
82/// Embedding is a datatype that wraps std::vector<double>. It provides
83/// additional functionality for arithmetic and comparison operations.
84/// It is meant to be used *like* std::vector<double> but is more restrictive
85/// in the sense that it does not allow the user to change the size of the
86/// embedding vector. The dimension of the embedding is fixed at the time of
87/// construction of Embedding object. But the elements can be modified in-place.
88struct Embedding {
89private:
90 std::vector<double> Data;
91
92public:
93 Embedding() = default;
94 Embedding(const std::vector<double> &V) : Data(V) {}
95 Embedding(std::vector<double> &&V) : Data(std::move(V)) {}
96 Embedding(std::initializer_list<double> IL) : Data(IL) {}
97
98 explicit Embedding(size_t Size) : Data(Size, 0.0) {}
99 Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {}
100
101 size_t size() const { return Data.size(); }
102 bool empty() const { return Data.empty(); }
103
104 double &operator[](size_t Itr) {
105 assert(Itr < Data.size() && "Index out of bounds");
106 return Data[Itr];
107 }
108
109 const double &operator[](size_t Itr) const {
110 assert(Itr < Data.size() && "Index out of bounds");
111 return Data[Itr];
112 }
113
114 using iterator = std::vector<double>::iterator;
115 using const_iterator = std::vector<double>::const_iterator;
116
117 iterator begin() { return Data.begin(); }
118 iterator end() { return Data.end(); }
119 const_iterator begin() const { return Data.begin(); }
120 const_iterator end() const { return Data.end(); }
121 const_iterator cbegin() const { return Data.cbegin(); }
122 const_iterator cend() const { return Data.cend(); }
123
124 const std::vector<double> &getData() const { return Data; }
125
126 /// Arithmetic operators
131 LLVM_ABI Embedding &operator*=(double Factor);
132 LLVM_ABI Embedding operator*(double Factor) const;
133
134 /// Adds Src Embedding scaled by Factor with the called Embedding.
135 /// Called_Embedding += Src * Factor
136 LLVM_ABI Embedding &scaleAndAdd(const Embedding &Src, float Factor);
137
138 /// Returns true if the embedding is approximately equal to the RHS embedding
139 /// within the specified tolerance.
141 double Tolerance = 1e-4) const;
142
143 /// Returns true if all elements of the embedding are zero.
144 bool isZero() const {
145 return llvm::all_of(Data, [](double D) { return D == 0.0; });
146 }
147
148 LLVM_ABI void print(raw_ostream &OS) const;
149};
150
153
154/// Generic storage class for section-based vocabularies.
155/// VocabStorage provides a generic foundation for storing and accessing
156/// embeddings organized into sections.
158private:
159 /// Section-based storage
160 std::vector<std::vector<Embedding>> Sections;
161
// FIXME: Check whether these members can be made const (and the move
// assignment deleted) once Vocabulary creation is switched to static
// factory methods.
165 size_t TotalSize = 0;
166 unsigned Dimension = 0;
167
168public:
169 /// Default constructor creates empty storage (invalid state)
170 VocabStorage() = default;
171
172 /// Create a VocabStorage with pre-organized section data
173 LLVM_ABI VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);
174
177
178 VocabStorage(const VocabStorage &) = delete;
180
181 /// Get total number of entries across all sections
182 size_t size() const { return TotalSize; }
183
184 /// Get number of sections
185 unsigned getNumSections() const {
186 return static_cast<unsigned>(Sections.size());
187 }
188
189 /// Section-based access: Storage[sectionId][localIndex]
190 const std::vector<Embedding> &operator[](unsigned SectionId) const {
191 assert(SectionId < Sections.size() && "Invalid section ID");
192 return Sections[SectionId];
193 }
194
195 /// Get vocabulary dimension
196 unsigned getDimension() const { return Dimension; }
197
198 /// Check if vocabulary is valid (has data)
199 bool isValid() const { return TotalSize > 0; }
200
201 /// Iterator support for section-based access
203 const VocabStorage *Storage;
204 unsigned SectionId = 0;
205 size_t LocalIndex = 0;
206
207 public:
208 const_iterator(const VocabStorage *Storage, unsigned SectionId,
209 size_t LocalIndex)
210 : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
211
212 LLVM_ABI const Embedding &operator*() const;
214 LLVM_ABI bool operator==(const const_iterator &Other) const;
215 LLVM_ABI bool operator!=(const const_iterator &Other) const;
216 };
217
218 const_iterator begin() const { return const_iterator(this, 0, 0); }
220 return const_iterator(this, getNumSections(), 0);
221 }
222
223 using VocabMap = std::map<std::string, Embedding>;
224 /// Parse a vocabulary section from JSON and populate the target vocabulary
225 /// map.
227 const json::Value &ParsedVocabValue,
228 VocabMap &TargetVocab, unsigned &Dim);
229};
230
231/// Class for storing and accessing the IR2Vec vocabulary.
232/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
233/// seed embeddings are the initial learned representations of the entities
234/// of LLVM IR. The IR2Vec representation for a given IR is derived from these
235/// seed embeddings.
236///
/// The vocabulary contains the seed embeddings for instruction opcodes, types,
/// and operands; comparison predicates (ICmp/FCmp) are kept in a section of
/// their own. Types are grouped/canonicalized for better learning (e.g., all
/// float variants map to FloatTy). The vocabulary hides this canonicalization:
/// the exposed APIs accept all known LLVM IR opcodes, types, and operands.
242///
243/// This class helps populate the seed embeddings in an internal vector-based
244/// ADT. It provides logic to map every IR entity to a specific slot index or
245/// position in this vector, enabling O(1) embedding lookup while avoiding
246/// unnecessary computations involving string based lookups while generating the
247/// embeddings.
250
// Vocabulary Layout:
// +-----------------+---------------------------------------------------+
// | Entity Type     | Index Range                                       |
// +-----------------+---------------------------------------------------+
// | Opcodes         | [0 .. (MaxOpcodes-1)]                             |
// | Canonical Types | [MaxOpcodes .. (OperandBaseOffset-1)]             |
// | Operands        | [OperandBaseOffset .. (PredicateBaseOffset-1)]    |
// | Predicates      | [PredicateBaseOffset .. (NumCanonicalEntries-1)]  |
// +-----------------+---------------------------------------------------+
// where OperandBaseOffset = MaxOpcodes + MaxCanonicalTypeIDs and
// PredicateBaseOffset = OperandBaseOffset + MaxOperandKinds.
// Note: MaxOpcodes is the number of unique opcodes supported by LLVM IR.
//       MaxCanonicalTypeIDs is the number of canonicalized type IDs.
//       "Similar" LLVM Types are grouped/canonicalized together. E.g., all
//       float variants (FloatTy, DoubleTy, HalfTy, etc.) map to
//       CanonicalTypeID::FloatTy. This helps reduce the vocabulary size
//       and improves learning. Comparison predicates (ICmp/FCmp) have their
//       own dense section of MaxPredicateKinds entries. This can be extended
//       to include other specializations in future.
267 enum class Section : unsigned {
268 Opcodes = 0,
269 CanonicalTypes = 1,
270 Operands = 2,
271 Predicates = 3,
272 MaxSections
273 };
274
275 // Use section-based storage for better organization and efficiency
276 VocabStorage Storage;
277
278 static constexpr unsigned NumICmpPredicates =
279 static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
280 static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
281 static constexpr unsigned NumFCmpPredicates =
282 static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
283 static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
284
285public:
286 /// Canonical type IDs supported by IR2Vec Vocabulary
302
303 /// Operand kinds supported by IR2Vec Vocabulary
311
312 /// Vocabulary layout constants
313#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
314#include "llvm/IR/Instruction.def"
315#undef LAST_OTHER_INST
316
317 static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
318 static constexpr unsigned MaxCanonicalTypeIDs =
319 static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
320 static constexpr unsigned MaxOperandKinds =
321 static_cast<unsigned>(OperandKind::MaxOperandKind);
322 // CmpInst::Predicate has gaps. We want the vocabulary to be dense without
323 // empty slots.
324 static constexpr unsigned MaxPredicateKinds =
325 NumICmpPredicates + NumFCmpPredicates;
326
327 Vocabulary() = default;
328 LLVM_ABI Vocabulary(VocabStorage &&Storage) : Storage(std::move(Storage)) {}
329
330 Vocabulary(const Vocabulary &) = delete;
331 Vocabulary &operator=(const Vocabulary &) = delete;
332
333 Vocabulary(Vocabulary &&) = default;
335
336 /// Create a Vocabulary by loading embeddings from a JSON file.
337 /// This is the primary entry point for programmatic vocabulary creation,
338 /// suitable for use in Python bindings or other contexts where command-line
339 /// options are not available. Weights are applied to scale the embeddings
340 /// for opcodes, types, and arguments respectively.
341 LLVM_ABI static Expected<Vocabulary> fromFile(StringRef VocabFilePath,
342 float OpcWeight = 1.0,
343 float TypeWeight = 0.5,
344 float ArgWeight = 0.2);
345
346 LLVM_ABI bool isValid() const {
347 return Storage.size() == NumCanonicalEntries;
348 }
349
350 LLVM_ABI unsigned getDimension() const {
351 assert(isValid() && "IR2Vec Vocabulary is invalid");
352 return Storage.getDimension();
353 }
354
355 /// Total number of entries (opcodes + canonicalized types + operand kinds +
356 /// predicates)
357 static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
358
359 /// Function to get vocabulary key for a given Opcode
360 LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
361
362 /// Function to get vocabulary key for a given TypeID
364 return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
365 }
366
367 /// Function to get vocabulary key for a given OperandKind
369 unsigned Index = static_cast<unsigned>(Kind);
370 assert(Index < MaxOperandKinds && "Invalid OperandKind");
371 return OperandKindNames[Index];
372 }
373
374 /// Function to classify an operand into OperandKind
376
377 /// Function to get vocabulary key for a given predicate
379
380 /// Functions to return flat index
381 LLVM_ABI static unsigned getIndex(unsigned Opcode) {
382 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
383 return Opcode - 1; // Convert to zero-based index
384 }
385
387 assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
388 return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
389 }
390
391 LLVM_ABI static unsigned getIndex(const Value &Op) {
392 unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
393 assert(Index < MaxOperandKinds && "Invalid OperandKind");
394 return OperandBaseOffset + Index;
395 }
396
398 return PredicateBaseOffset + getPredicateLocalIndex(P);
399 }
400
401 /// Accessors to get the embedding for a given entity.
402 LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const {
403 assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
404 return Storage[static_cast<unsigned>(Section::Opcodes)][Opcode - 1];
405 }
406
408 assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
409 unsigned LocalIndex = static_cast<unsigned>(getCanonicalTypeID(TypeID));
410 return Storage[static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];
411 }
412
413 LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const {
414 unsigned LocalIndex = static_cast<unsigned>(getOperandKind(&Arg));
415 assert(LocalIndex < MaxOperandKinds && "Invalid OperandKind");
416 return Storage[static_cast<unsigned>(Section::Operands)][LocalIndex];
417 }
418
420 unsigned LocalIndex = getPredicateLocalIndex(P);
421 return Storage[static_cast<unsigned>(Section::Predicates)][LocalIndex];
422 }
423
424 /// Const Iterator type aliases
426
428 assert(isValid() && "IR2Vec Vocabulary is invalid");
429 return Storage.begin();
430 }
431
432 const_iterator cbegin() const { return begin(); }
433
435 assert(isValid() && "IR2Vec Vocabulary is invalid");
436 return Storage.end();
437 }
438
439 const_iterator cend() const { return end(); }
440
441 /// Returns the string key for a given index position in the vocabulary.
442 /// This is useful for debugging or printing the vocabulary. Do not use this
443 /// for embedding generation as string based lookups are inefficient.
444 LLVM_ABI static StringRef getStringKey(unsigned Pos);
445
446 /// Create a dummy vocabulary for testing purposes.
447 LLVM_ABI static VocabStorage createDummyVocabForTest(unsigned Dim = 1);
448
449 LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
450 ModuleAnalysisManager::Invalidator &Inv) const;
451
452private:
453 constexpr static unsigned NumCanonicalEntries =
455
456 // Base offsets for flat index computation
457 constexpr static unsigned OperandBaseOffset =
458 MaxOpcodes + MaxCanonicalTypeIDs;
459 constexpr static unsigned PredicateBaseOffset =
460 OperandBaseOffset + MaxOperandKinds;
461
462 /// Functions for predicate index calculations
463 static unsigned getPredicateLocalIndex(CmpInst::Predicate P);
464 static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex);
465
466 /// String mappings for CanonicalTypeID values
467 static constexpr StringLiteral CanonicalTypeNames[] = {
468 "FloatTy", "VoidTy", "LabelTy", "MetadataTy",
469 "VectorTy", "TokenTy", "IntegerTy", "FunctionTy",
470 "PointerTy", "StructTy", "ArrayTy", "UnknownTy"};
471 static_assert(std::size(CanonicalTypeNames) ==
472 static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
473 "CanonicalTypeNames array size must match MaxCanonicalType");
474
475 /// String mappings for OperandKind values
476 static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
477 "Constant", "Variable"};
478 static_assert(std::size(OperandKindNames) ==
479 static_cast<unsigned>(OperandKind::MaxOperandKind),
480 "OperandKindNames array size must match MaxOperandKind");
481
482 /// Every known TypeID defined in llvm/IR/Type.h is expected to have a
483 /// corresponding mapping here in the same order as enum Type::TypeID.
484 static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
485 CanonicalTypeID::FloatTy, // HalfTyID = 0
486 CanonicalTypeID::FloatTy, // BFloatTyID
487 CanonicalTypeID::FloatTy, // FloatTyID
488 CanonicalTypeID::FloatTy, // DoubleTyID
489 CanonicalTypeID::FloatTy, // X86_FP80TyID
490 CanonicalTypeID::FloatTy, // FP128TyID
491 CanonicalTypeID::FloatTy, // PPC_FP128TyID
492 CanonicalTypeID::VoidTy, // VoidTyID
493 CanonicalTypeID::LabelTy, // LabelTyID
494 CanonicalTypeID::MetadataTy, // MetadataTyID
495 CanonicalTypeID::VectorTy, // X86_AMXTyID
496 CanonicalTypeID::TokenTy, // TokenTyID
497 CanonicalTypeID::IntegerTy, // IntegerTyID
498 CanonicalTypeID::FunctionTy, // FunctionTyID
499 CanonicalTypeID::PointerTy, // PointerTyID
500 CanonicalTypeID::StructTy, // StructTyID
501 CanonicalTypeID::ArrayTy, // ArrayTyID
502 CanonicalTypeID::VectorTy, // FixedVectorTyID
503 CanonicalTypeID::VectorTy, // ScalableVectorTyID
504 CanonicalTypeID::PointerTy, // TypedPointerTyID
505 CanonicalTypeID::UnknownTy // TargetExtTyID
506 }};
507 static_assert(TypeIDMapping.size() == MaxTypeIDs,
508 "TypeIDMapping must cover all Type::TypeID values");
509
510 /// Function to get vocabulary key for canonical type by enum
511 LLVM_ABI static StringRef
512 getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
513 unsigned Index = static_cast<unsigned>(CType);
514 assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
515 return CanonicalTypeNames[Index];
516 }
517
518 /// Function to convert TypeID to CanonicalTypeID
519 LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID) {
520 unsigned Index = static_cast<unsigned>(TypeID);
521 assert(Index < MaxTypeIDs && "Invalid TypeID");
522 return TypeIDMapping[Index];
523 }
524
525 /// Function to get the predicate enum value for a given index. Index is
526 /// relative to the predicates section of the vocabulary. E.g., Index 0
527 /// corresponds to the first predicate.
528 LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index) {
529 assert(Index < MaxPredicateKinds && "Invalid predicate index");
530 return getPredicateFromLocalIndex(Index);
531 }
532
533 using VocabMap = std::map<std::string, Embedding>;
534
535 /// Generate VocabStorage from vocabulary maps.
536 static VocabStorage buildVocabStorage(const VocabMap &OpcVocab,
537 const VocabMap &TypeVocab,
538 const VocabMap &ArgVocab);
539};
540
541/// Embedder provides the interface to generate embeddings (vector
542/// representations) for instructions, basic blocks, and functions. The
543/// vector representations are generated using IR2Vec algorithms.
544///
545/// The Embedder class is an abstract class and it is intended to be
546/// subclassed for different IR2Vec algorithms like Symbolic and Flow-aware.
547class Embedder {
548protected:
549 const Function &F;
551
552 /// Dimension of the vector representation; captured from the input vocabulary
553 const unsigned Dimension;
554
555 /// Weights for different entities (like opcode, arguments, types)
556 /// in the IR instructions to generate the vector representation.
558
563
564 /// Function to compute embeddings.
566
567 /// Function to compute the embedding for a given basic block.
568 Embedding computeEmbeddings(const BasicBlock &BB) const;
569
570 /// Function to compute the embedding for a given instruction.
571 /// Specific to the kind of embeddings being computed.
572 virtual Embedding computeEmbeddings(const Instruction &I) const = 0;
573
574public:
575 virtual ~Embedder() = default;
576
577 /// Factory method to create an Embedder object.
578 LLVM_ABI static std::unique_ptr<Embedder>
580
581 /// Computes and returns the embedding for a given instruction in the function
582 /// F
584 return computeEmbeddings(I);
585 }
586
587 /// Computes and returns the embedding for a given basic block in the function
588 /// F
590 return computeEmbeddings(BB);
591 }
592
593 /// Computes and returns the embedding for the current function.
595
596 /// Invalidate embeddings if cached. The embeddings may not be relevant
597 /// anymore when the IR changes due to transformations. In such cases, the
598 /// cached embeddings should be invalidated to ensure
599 /// correctness/recomputation. This is a no-op for SymbolicEmbedder but
600 /// removes all the cached entries in FlowAwareEmbedder.
601 virtual void invalidateEmbeddings() {}
602};
603
604/// Class for computing the Symbolic embeddings of IR2Vec.
605/// Symbolic embeddings are constructed based on the entity-level
606/// representations obtained from the Vocabulary.
608private:
609 Embedding computeEmbeddings(const Instruction &I) const override;
610
611public:
614};
615
616/// Class for computing the Flow-aware embeddings of IR2Vec.
617/// Flow-aware embeddings build on the vocabulary, just like Symbolic
618/// embeddings, and additionally capture the flow information in the IR.
620private:
621 // FlowAware embeddings would benefit from caching instruction embeddings as
622 // they are reused while computing the embeddings of other instructions.
623 mutable InstEmbeddingsMap InstVecMap;
624 Embedding computeEmbeddings(const Instruction &I) const override;
625
626public:
629 void invalidateEmbeddings() override { InstVecMap.clear(); }
630};
631
632} // namespace ir2vec
633
634/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
635/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
636/// its corresponding embedding.
637class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
638 std::optional<ir2vec::VocabStorage> Vocab;
639
640 void emitError(Error Err, LLVMContext &Ctx);
641
642public:
646 : Vocab(std::move(Vocab)) {}
649};
650
651/// This pass prints the IR2Vec embeddings for instructions, basic blocks, and
652/// functions.
653class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
654 raw_ostream &OS;
655
656public:
657 explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
659 static bool isRequired() { return true; }
660};
661
662/// This pass prints the embeddings in the vocabulary
663class IR2VecVocabPrinterPass : public PassInfoMixin<IR2VecVocabPrinterPass> {
664 raw_ostream &OS;
665
666public:
667 explicit IR2VecVocabPrinterPass(raw_ostream &OS) : OS(OS) {}
669 static bool isRequired() { return true; }
670};
671
672} // namespace llvm
673
674#endif // LLVM_ANALYSIS_IR2VEC_H
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define LLVM_ABI
Definition Compiler.h:213
This file defines the DenseMap class.
Provides ErrorOr<T> smart pointer.
This header defines various interfaces for pass management in LLVM.
This file supports working with JSON data.
#define I(x, y, z)
Definition MD5.cpp:57
Type::TypeID TypeID
#define P(N)
ModuleAnalysisManager MAM
static cl::opt< RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode > Mode("regalloc-enable-advisor", cl::Hidden, cl::init(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default), cl::desc("Enable regalloc advisor mode"), cl::values(clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Default, "default", "Default"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Release, "release", "precompiled"), clEnumValN(RegAllocEvictionAdvisorAnalysisLegacy::AdvisorMode::Development, "development", "for training")))
Value * RHS
LLVM Basic Block Representation.
Definition BasicBlock.h:62
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
Tagged union holding either a T or a Error.
Definition Error.h:485
IR2VecPrinterPass(raw_ostream &OS)
Definition IR2Vec.h:657
static bool isRequired()
Definition IR2Vec.h:659
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:670
This analysis provides the vocabulary for IR2Vec.
Definition IR2Vec.h:637
ir2vec::Vocabulary Result
Definition IR2Vec.h:647
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:641
LLVM_ABI IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
Definition IR2Vec.h:645
static LLVM_ABI AnalysisKey Key
Definition IR2Vec.h:643
static bool isRequired()
Definition IR2Vec.h:669
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
Definition IR2Vec.cpp:704
IR2VecVocabPrinterPass(raw_ostream &OS)
Definition IR2Vec.h:667
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs...
Definition StringRef.h:864
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
TypeID
Definitions of all of the base types for the Type system.
Definition Type.h:54
LLVM Value Representation.
Definition Value.h:75
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
Definition IR2Vec.cpp:156
virtual Embedding computeEmbeddings(const Instruction &I) const =0
Function to compute the embedding for a given instruction.
LLVM_ABI Embedding getInstVector(const Instruction &I) const
Computes and returns the embedding for a given instruction in the function F.
Definition IR2Vec.h:583
const Vocabulary & Vocab
Definition IR2Vec.h:550
virtual ~Embedder()=default
const float TypeWeight
Definition IR2Vec.h:557
const float OpcWeight
Weights for different entities (like opcode, arguments, types) in the IR instructions to generate the...
Definition IR2Vec.h:557
LLVM_ABI Embedding getFunctionVector() const
Computes and returns the embedding for the current function.
Definition IR2Vec.h:594
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab)
Definition IR2Vec.h:559
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
Definition IR2Vec.h:553
virtual void invalidateEmbeddings()
Invalidate embeddings if cached.
Definition IR2Vec.h:601
Embedding computeEmbeddings() const
Function to compute embeddings.
Definition IR2Vec.cpp:167
const float ArgWeight
Definition IR2Vec.h:557
const Function & F
Definition IR2Vec.h:549
LLVM_ABI Embedding getBBVector(const BasicBlock &BB) const
Computes and returns the embedding for a given basic block in the function F.
Definition IR2Vec.h:589
void invalidateEmbeddings() override
Invalidate embeddings if cached.
Definition IR2Vec.h:629
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
Definition IR2Vec.h:627
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
Definition IR2Vec.h:612
Iterator support for section-based access.
Definition IR2Vec.h:202
const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)
Definition IR2Vec.h:208
LLVM_ABI bool operator!=(const const_iterator &Other) const
Definition IR2Vec.cpp:311
LLVM_ABI const_iterator & operator++()
Definition IR2Vec.cpp:292
LLVM_ABI const Embedding & operator*() const
Definition IR2Vec.cpp:285
LLVM_ABI bool operator==(const const_iterator &Other) const
Definition IR2Vec.cpp:305
Generic storage class for section-based vocabularies.
Definition IR2Vec.h:157
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
Definition IR2Vec.cpp:316
VocabStorage & operator=(VocabStorage &&)=default
const_iterator end() const
Definition IR2Vec.h:219
unsigned getNumSections() const
Get number of sections.
Definition IR2Vec.h:185
VocabStorage & operator=(const VocabStorage &)=delete
unsigned getDimension() const
Get vocabulary dimension.
Definition IR2Vec.h:196
size_t size() const
Get total number of entries across all sections.
Definition IR2Vec.h:182
VocabStorage()=default
Default constructor creates empty storage (invalid state)
const_iterator begin() const
Definition IR2Vec.h:218
bool isValid() const
Check if vocabulary is valid (has data)
Definition IR2Vec.h:199
VocabStorage(VocabStorage &&)=default
std::map< std::string, Embedding > VocabMap
Definition IR2Vec.h:223
const std::vector< Embedding > & operator[](unsigned SectionId) const
Section-based access: Storage[sectionId][localIndex].
Definition IR2Vec.h:190
VocabStorage(const VocabStorage &)=delete
Class for storing and accessing the IR2Vec vocabulary.
Definition IR2Vec.h:248
static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
Definition IR2Vec.h:368
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
Definition IR2Vec.cpp:426
static LLVM_ABI Expected< Vocabulary > fromFile(StringRef VocabFilePath, float OpcWeight=1.0, float TypeWeight=0.5, float ArgWeight=0.2)
Create a Vocabulary by loading embeddings from a JSON file.
Definition IR2Vec.cpp:609
const_iterator begin() const
Definition IR2Vec.h:427
LLVM_ABI unsigned getDimension() const
Definition IR2Vec.h:350
Vocabulary(Vocabulary &&)=default
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
Definition IR2Vec.cpp:369
static LLVM_ABI unsigned getIndex(CmpInst::Predicate P)
Definition IR2Vec.h:397
Vocabulary & operator=(const Vocabulary &)=delete
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
Definition IR2Vec.cpp:408
static constexpr unsigned MaxCanonicalTypeIDs
Definition IR2Vec.h:318
LLVM_ABI const ir2vec::Embedding & operator[](CmpInst::Predicate P) const
Definition IR2Vec.h:419
static constexpr unsigned MaxOperandKinds
Definition IR2Vec.h:320
Vocabulary(const Vocabulary &)=delete
const_iterator cbegin() const
Definition IR2Vec.h:432
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
Definition IR2Vec.h:304
static constexpr size_t getCanonicalSize()
Total number of entries (opcodes + canonicalized types + operand kinds + predicates)
Definition IR2Vec.h:357
static LLVM_ABI unsigned getIndex(const Value &Op)
Definition IR2Vec.h:391
static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)
Function to get vocabulary key for a given predicate.
Definition IR2Vec.cpp:398
static constexpr unsigned MaxTypeIDs
Definition IR2Vec.h:317
LLVM_ABI Vocabulary(VocabStorage &&Storage)
Definition IR2Vec.h:328
LLVM_ABI const ir2vec::Embedding & operator[](Type::TypeID TypeID) const
Definition IR2Vec.h:407
static LLVM_ABI unsigned getIndex(Type::TypeID TypeID)
Definition IR2Vec.h:386
const_iterator end() const
Definition IR2Vec.h:434
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
Definition IR2Vec.cpp:357
static LLVM_ABI StringRef getVocabKeyForTypeID(Type::TypeID TypeID)
Function to get vocabulary key for a given TypeID.
Definition IR2Vec.h:363
VocabStorage::const_iterator const_iterator
Const iterator type aliases.
Definition IR2Vec.h:425
const_iterator cend() const
Definition IR2Vec.h:439
static LLVM_ABI unsigned getIndex(unsigned Opcode)
Functions to return flat index.
Definition IR2Vec.h:381
LLVM_ABI bool isValid() const
Definition IR2Vec.h:346
Vocabulary & operator=(Vocabulary &&Other)=delete
LLVM_ABI const ir2vec::Embedding & operator[](unsigned Opcode) const
Accessors to get the embedding for a given entity.
Definition IR2Vec.h:402
static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
Definition IR2Vec.cpp:432
static constexpr unsigned MaxPredicateKinds
Definition IR2Vec.h:324
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
Definition IR2Vec.h:287
LLVM_ABI const ir2vec::Embedding & operator[](const Value &Arg) const
Definition IR2Vec.h:413
A Value is a JSON value of unknown type.
Definition JSON.h:291
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
DenseMap< const Instruction *, Embedding > InstEmbeddingsMap
Definition IR2Vec.h:151
LLVM_ABI cl::opt< float > ArgWeight
DenseMap< const BasicBlock *, Embedding > BBEmbeddingsMap
Definition IR2Vec.h:152
LLVM_ABI cl::opt< std::string > VocabFile
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
LLVM_ABI llvm::cl::OptionCategory IR2VecCategory
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1737
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
Definition IR2Vec.h:71
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
@ Other
Any other memory.
Definition ModRef.h:68
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1915
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
A CRTP mix-in that provides informational APIs needed for analysis passes.
Definition PassManager.h:93
A special type used by analysis passes to provide an address that identifies that particular analysis pass type.
Definition Analysis.h:29
A CRTP mix-in to automatically provide informational APIs needed for passes.
Definition PassManager.h:70
Embedding is a datatype that wraps std::vector<double>.
Definition IR2Vec.h:88
const_iterator end() const
Definition IR2Vec.h:120
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolerance.
Definition IR2Vec.cpp:132
const_iterator cbegin() const
Definition IR2Vec.h:121
std::vector< double >::iterator iterator
Definition IR2Vec.h:114
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
Definition IR2Vec.cpp:87
std::vector< double >::const_iterator const_iterator
Definition IR2Vec.h:115
LLVM_ABI Embedding operator-(const Embedding &RHS) const
Definition IR2Vec.cpp:107
const std::vector< double > & getData() const
Definition IR2Vec.h:124
Embedding(size_t Size, double InitialValue)
Definition IR2Vec.h:99
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
Definition IR2Vec.cpp:100
const_iterator cend() const
Definition IR2Vec.h:122
bool isZero() const
Returns true if all elements of the embedding are zero.
Definition IR2Vec.h:144
LLVM_ABI Embedding operator*(double Factor) const
Definition IR2Vec.cpp:119
size_t size() const
Definition IR2Vec.h:101
LLVM_ABI Embedding & operator*=(double Factor)
Definition IR2Vec.cpp:113
Embedding(std::initializer_list< double > IL)
Definition IR2Vec.h:96
Embedding(const std::vector< double > &V)
Definition IR2Vec.h:94
LLVM_ABI Embedding operator+(const Embedding &RHS) const
Definition IR2Vec.cpp:94
bool empty() const
Definition IR2Vec.h:102
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds Src Embedding scaled by Factor with the called Embedding.
Definition IR2Vec.cpp:125
Embedding(std::vector< double > &&V)
Definition IR2Vec.h:95
const double & operator[](size_t Itr) const
Definition IR2Vec.h:109
Embedding(size_t Size)
Definition IR2Vec.h:98
LLVM_ABI void print(raw_ostream &OS) const
Definition IR2Vec.cpp:145
const_iterator begin() const
Definition IR2Vec.h:119
double & operator[](size_t Itr)
Definition IR2Vec.h:104