33#define DEBUG_TYPE "ir2vec"
36 "Number of lookups to entities not present in the vocabulary");
48 cl::desc(
"Weight for opcode embeddings"),
51 cl::desc(
"Weight for type embeddings"),
54 cl::desc(
"Weight for argument embeddings"),
59 "Generate symbolic embeddings"),
61 "Generate flow-aware embeddings")),
76 std::vector<double> TempOut;
88 assert(this->
size() == RHS.
size() &&
"Vectors must have the same dimension");
89 std::transform(this->
begin(), this->
end(), RHS.
begin(), this->begin(),
101 assert(this->
size() == RHS.
size() &&
"Vectors must have the same dimension");
102 std::transform(this->
begin(), this->
end(), RHS.
begin(), this->begin(),
103 std::minus<double>());
115 [Factor](
double Elem) {
return Elem * Factor; });
126 assert(this->
size() == Src.
size() &&
"Vectors must have the same dimension");
127 for (
size_t Itr = 0; Itr < this->
size(); ++Itr)
128 (*
this)[Itr] += Src[Itr] * Factor;
133 double Tolerance)
const {
134 assert(this->
size() == RHS.
size() &&
"Vectors must have the same dimension");
135 for (
size_t Itr = 0; Itr < this->
size(); ++Itr)
136 if (std::abs((*
this)[Itr] - RHS[Itr]) > Tolerance) {
137 LLVM_DEBUG(
errs() <<
"Embedding mismatch at index " << Itr <<
": "
138 << (*
this)[Itr] <<
" vs " << RHS[Itr]
139 <<
"; Tolerance: " << Tolerance <<
"\n");
147 for (
const auto &Elem : Data)
148 OS <<
" " <<
format(
"%.2f", Elem) <<
" ";
160 return std::make_unique<SymbolicEmbedder>(
F,
Vocab);
162 return std::make_unique<FlowAwareEmbedder>(
F,
Vocab);
170 if (
F.isDeclaration())
183 for (
const auto &
I : BB)
184 if (!
I.isDebugOrPseudoInst())
193 for (
const auto &
Op :
I.operands())
196 Vocab[
I.getOpcode()] +
Vocab[
I.getType()->getTypeID()] + ArgEmb;
198 InstVector +=
Vocab[IC->getPredicate()];
204 auto It = InstVecMap.
find(&
I);
205 if (It != InstVecMap.
end())
211 for (
const auto &
Op :
I.operands()) {
214 auto DefIt = InstVecMap.
find(DefInst);
226 if (DefIt != InstVecMap.
end())
227 ArgEmb += DefIt->second;
234 LLVM_DEBUG(
errs() <<
"Using embedding from vocabulary for operand: "
242 Vocab[
I.getOpcode()] +
Vocab[
I.getType()->getTypeID()] + ArgEmb;
244 InstVector +=
Vocab[IC->getPredicate()];
245 InstVecMap[&
I] = InstVector;
254 : Sections(
std::
move(SectionData)), TotalSize([&] {
255 assert(!Sections.empty() &&
"Vocabulary has no sections");
258 for (
const auto &Section : Sections) {
259 assert(!Section.empty() &&
"Vocabulary section is empty");
260 Size += Section.size();
267 assert(!Sections.empty() &&
"Vocabulary has no sections");
268 assert(!Sections[0].empty() &&
"First section of vocabulary is empty");
269 unsigned ExpectedDim =
static_cast<unsigned>(Sections[0][0].size());
273 [[maybe_unused]]
auto allSameDim =
274 [ExpectedDim](
const std::vector<Embedding> &Section) {
275 return std::all_of(Section.begin(), Section.end(),
277 return Emb.size() == ExpectedDim;
280 assert(std::all_of(Sections.begin(), Sections.end(), allSameDim) &&
281 "All embeddings must have the same dimension");
287 assert(SectionId < Storage->Sections.size() &&
"Invalid section ID");
288 assert(LocalIndex < Storage->Sections[SectionId].
size() &&
289 "Local index out of range");
290 return Storage->Sections[SectionId][LocalIndex];
297 LocalIndex >= Storage->Sections[SectionId].size()) {
298 assert(LocalIndex == Storage->Sections[SectionId].size() &&
299 "Local index should be at the end of the current section");
308 return Storage ==
Other.Storage && SectionId ==
Other.SectionId &&
309 LocalIndex ==
Other.LocalIndex;
314 return !(*
this ==
Other);
319 VocabMap &TargetVocab,
unsigned &Dim) {
324 "JSON root is not an object");
329 "Missing '" + std::string(
Key) +
330 "' section in vocabulary file");
333 "Unable to parse '" + std::string(
Key) +
334 "' section from vocabulary");
336 Dim = TargetVocab.begin()->second.size();
339 "Dimension of '" + std::string(
Key) +
340 "' section of the vocabulary is zero");
342 if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
343 [Dim](
const std::pair<StringRef, Embedding> &Entry) {
344 return Entry.second.size() == Dim;
348 "All vectors in the '" + std::string(
Key) +
349 "' section of the vocabulary are not of the same dimension");
359 assert(Opcode >= 1 && Opcode <= MaxOpcodes &&
"Invalid opcode");
360#define HANDLE_INST(NUM, OPCODE, CLASS) \
361 if (Opcode == NUM) { \
364#include "llvm/IR/Instruction.def"
366 return "UnknownOpcode";
391 if (LocalIndex < fcmpRange)
396 LocalIndex - fcmpRange);
402 PredNameBuffer =
"FCMP_";
404 PredNameBuffer =
"ICMP_";
406 return PredNameBuffer;
410 assert(Pos < NumCanonicalEntries &&
"Position out of bounds in vocabulary");
412 if (Pos < MaxOpcodes)
415 if (Pos < OperandBaseOffset)
416 return getVocabKeyForCanonicalTypeID(
419 if (Pos < PredicateBaseOffset)
421 static_cast<OperandKind>(Pos - OperandBaseOffset));
428 ModuleAnalysisManager::Invalidator &Inv)
const {
430 return !(PAC.preservedWhenStateless());
434 float DummyVal = 0.1f;
438 std::vector<std::vector<Embedding>> Sections;
442 std::vector<Embedding> OpcodeSec;
443 OpcodeSec.reserve(MaxOpcodes);
444 for (
unsigned I = 0;
I < MaxOpcodes; ++
I) {
445 OpcodeSec.emplace_back(Dim, DummyVal);
448 Sections.push_back(std::move(OpcodeSec));
451 std::vector<Embedding> TypeSec;
454 TypeSec.emplace_back(Dim, DummyVal);
457 Sections.push_back(std::move(TypeSec));
460 std::vector<Embedding> OperandSec;
463 OperandSec.emplace_back(Dim, DummyVal);
466 Sections.push_back(std::move(OperandSec));
469 std::vector<Embedding> PredicateSec;
472 PredicateSec.emplace_back(Dim, DummyVal);
475 Sections.push_back(std::move(PredicateSec));
481using VocabMap = std::map<std::string, Embedding>;
491 auto Content = BufOrError.get()->getBuffer();
494 if (!ParsedVocabValue)
497 unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
499 OpcVocab, OpcodeDim))
510 if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
512 "Vocabulary sections have different dimensions");
526 auto handleMissingEntity = [](
const std::string &Val) {
528 <<
" is not in vocabulary, using zero vector; This "
529 "would result in an error in future.\n");
533 unsigned Dim = OpcVocab.
begin()->second.size();
534 assert(Dim > 0 &&
"Vocabulary dimension must be greater than zero");
537 std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
539 for (
unsigned Opcode :
seq(0u, Vocabulary::MaxOpcodes)) {
541 auto It = OpcVocab.find(VocabKey.
str());
542 if (It != OpcVocab.end())
543 NumericOpcodeEmbeddings[Opcode] = It->second;
545 handleMissingEntity(VocabKey.
str());
552 StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
554 if (
auto It = TypeVocab.find(VocabKey.
str()); It != TypeVocab.end()) {
555 NumericTypeEmbeddings[CTypeID] = It->second;
558 handleMissingEntity(VocabKey.
str());
567 auto It = ArgVocab.find(VocabKey.
str());
568 if (It != ArgVocab.end()) {
569 NumericArgEmbeddings[OpKind] = It->second;
572 handleMissingEntity(VocabKey.
str());
582 auto It = ArgVocab.find(VocabKey.
str());
583 if (It != ArgVocab.end()) {
584 NumericPredEmbeddings[PK] = It->second;
587 handleMissingEntity(VocabKey.
str());
592 std::vector<std::vector<Embedding>> Sections(4);
593 Sections[
static_cast<unsigned>(Section::Opcodes)] =
594 std::move(NumericOpcodeEmbeddings);
595 Sections[
static_cast<unsigned>(Section::CanonicalTypes)] =
596 std::move(NumericTypeEmbeddings);
597 Sections[
static_cast<unsigned>(Section::Operands)] =
598 std::move(NumericArgEmbeddings);
599 Sections[
static_cast<unsigned>(Section::Predicates)] =
600 std::move(NumericPredEmbeddings);
613 VocabMap OpcVocab, TypeVocab, ArgVocab;
615 readVocabularyFromFile(VocabFilePath, OpcVocab, TypeVocab, ArgVocab))
616 return std::move(Err);
619 auto scaleVocabSection = [](VocabMap &Vocab,
float Weight) {
620 for (
auto &Entry : Vocab)
621 Entry.second *= Weight;
628 return Vocabulary(buildVocabStorage(OpcVocab, TypeVocab, ArgVocab));
637 Ctx.emitError(
"Error reading vocabulary: " + EI.
message());
643 auto Ctx = &M.getContext();
645 if (Vocab.has_value())
651 Ctx->emitError(
"IR2Vec vocabulary file path not specified; You may need to "
652 "set it using --ir2vec-vocab-path");
660 emitError(VocabOrErr.takeError(), *Ctx);
664 return std::move(*VocabOrErr);
679 OS <<
"Error creating IR2Vec embeddings \n";
683 OS <<
"IR2Vec embeddings for function " <<
F.getName() <<
":\n";
684 OS <<
"Function vector: ";
685 Emb->getFunctionVector().print(OS);
687 OS <<
"Basic block vectors:\n";
689 OS <<
"Basic block: " << BB.getName() <<
":\n";
690 Emb->getBBVector(BB).print(OS);
693 OS <<
"Instruction vectors:\n";
696 OS <<
"Instruction: ";
698 Emb->getInstVector(
I).print(OS);
708 assert(IR2VecVocabulary.isValid() &&
"IR2Vec Vocabulary is invalid");
712 for (
const auto &Entry : IR2VecVocabulary) {
713 OS <<
"Key: " << IR2VecVocabulary.getStringKey(Pos++) <<
": ";
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines the IR2Vec vocabulary analysis (IR2VecVocabAnalysis), the core ir2vec::Embedder inte...
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
ModuleAnalysisManager MAM
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
LLVM Basic Block Representation.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
static LLVM_ABI StringRef getPredicateName(Predicate P)
iterator find(const_arg_type_t< KeyT > Val)
Base class for error info classes.
virtual std::string message() const
Return the error message as a string.
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Tagged union holding either a T or an Error.
Error takeError()
Take ownership of the stored error.
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
This analysis provides the vocabulary for IR2Vec.
ir2vec::Vocabulary Result
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM)
static LLVM_ABI AnalysisKey Key
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
This is an important class for using LLVM in a threaded context.
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
StringRef - Represent a constant reference to a string, i.e.
std::string str() const
str - Get the contents as an std::string.
LLVM Value Representation.
static LLVM_ABI std::unique_ptr< Embedder > create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab)
Factory method to create an Embedder object.
const unsigned Dimension
Dimension of the vector representation; captured from the input vocabulary.
Embedding computeEmbeddings() const
Function to compute embeddings.
Iterator support for section-based access.
const_iterator(const VocabStorage *Storage, unsigned SectionId, size_t LocalIndex)
LLVM_ABI bool operator!=(const const_iterator &Other) const
LLVM_ABI const_iterator & operator++()
LLVM_ABI const Embedding & operator*() const
LLVM_ABI bool operator==(const const_iterator &Other) const
Generic storage class for section-based vocabularies.
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
unsigned getNumSections() const
Get number of sections.
size_t size() const
Get total number of entries across all sections.
VocabStorage()=default
Default constructor creates empty storage (invalid state)
const_iterator begin() const
std::map< std::string, Embedding > VocabMap
Class for storing and accessing the IR2Vec vocabulary.
static LLVM_ABI StringRef getVocabKeyForOperandKind(OperandKind Kind)
Function to get vocabulary key for a given OperandKind.
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const
static LLVM_ABI Expected< Vocabulary > fromFile(StringRef VocabFilePath, float OpcWeight=1.0, float TypeWeight=0.5, float ArgWeight=0.2)
Create a Vocabulary by loading embeddings from a JSON file.
static LLVM_ABI OperandKind getOperandKind(const Value *Op)
Function to classify an operand into OperandKind.
friend class llvm::IR2VecVocabAnalysis
static LLVM_ABI StringRef getStringKey(unsigned Pos)
Returns the string key for a given index position in the vocabulary.
static constexpr unsigned MaxCanonicalTypeIDs
static constexpr unsigned MaxOperandKinds
OperandKind
Operand kinds supported by IR2Vec Vocabulary.
static LLVM_ABI StringRef getVocabKeyForPredicate(CmpInst::Predicate P)
Function to get vocabulary key for a given predicate.
static LLVM_ABI StringRef getVocabKeyForOpcode(unsigned Opcode)
Function to get vocabulary key for a given Opcode.
LLVM_ABI bool isValid() const
static LLVM_ABI VocabStorage createDummyVocabForTest(unsigned Dim=1)
Create a dummy vocabulary for testing purposes.
static constexpr unsigned MaxPredicateKinds
CanonicalTypeID
Canonical type IDs supported by IR2Vec Vocabulary.
An Object is a JSON object, which maps strings to heterogeneous JSON values.
LLVM_ABI Value * get(StringRef K)
The root is the trivial Path to the root value.
A "cursor" marking a position within a Value.
A Value is a JSON value of unknown type.
const json::Object * getAsObject() const
This class implements an extremely fast bulk output stream that can only output to a stream.
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
LLVM_ABI cl::opt< float > ArgWeight
LLVM_ABI cl::opt< std::string > VocabFile
LLVM_ABI cl::opt< float > OpcWeight
LLVM_ABI cl::opt< float > TypeWeight
LLVM_ABI cl::opt< IR2VecKind > IR2VecEmbeddingKind
LLVM_ABI llvm::cl::OptionCategory IR2VecCategory
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
bool fromJSON(const Value &E, std::string &Out, Path P)
ir2vec::Embedding Embedding
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
void handleAllErrors(Error E, HandlerTs &&... Handlers)
Behaves the same as handleErrors, except that by contract all errors must be handled by the given han...
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
IR2VecKind
IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
DWARFExpression::Operation Op
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Implement std::hash so that hash_code can be used in STL containers.
A special type used by analysis passes to provide an address that identifies that particular analysis...
Embedding is a datatype that wraps std::vector<double>.
LLVM_ABI bool approximatelyEquals(const Embedding &RHS, double Tolerance=1e-4) const
Returns true if the embedding is approximately equal to the RHS embedding within the specified tolera...
LLVM_ABI Embedding & operator+=(const Embedding &RHS)
Arithmetic operators.
LLVM_ABI Embedding operator-(const Embedding &RHS) const
LLVM_ABI Embedding & operator-=(const Embedding &RHS)
LLVM_ABI Embedding operator*(double Factor) const
LLVM_ABI Embedding & operator*=(double Factor)
LLVM_ABI Embedding operator+(const Embedding &RHS) const
LLVM_ABI Embedding & scaleAndAdd(const Embedding &Src, float Factor)
Adds the Src Embedding, scaled by Factor, to the Embedding on which it is called.
LLVM_ABI void print(raw_ostream &OS) const