27#define DEBUG_TYPE "mir2vec"
30 "Number of lookups to MIR entities not present in the vocabulary");
42 cl::desc(
"Weight for machine opcode embeddings"),
57 if (!TII || OpcodeEntries.empty())
60 buildCanonicalOpcodeMapping();
62 unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
63 assert(CanonicalOpcodeCount > 0 &&
64 "No canonical opcodes found for target - invalid vocabulary");
65 Layout.OperandBase = CanonicalOpcodeCount;
66 generateStorage(OpcodeEntries);
67 Layout.TotalEntries = Storage.size();
82 assert(!InstrName.
empty() &&
"Instruction name should not be empty");
85 static const Regex BaseOpcodeRegex(
"([a-zA-Z_]+)");
88 if (BaseOpcodeRegex.
match(InstrName, &Matches) && Matches.
size() > 1) {
91 while (!Match.
empty() && Match.
back() ==
'_')
97 return InstrName.
str();
101 assert(!UniqueBaseOpcodeNames.empty() &&
"Canonical mapping not built");
102 auto It = std::find(UniqueBaseOpcodeNames.begin(),
103 UniqueBaseOpcodeNames.end(), BaseName.
str());
104 assert(It != UniqueBaseOpcodeNames.end() &&
105 "Base name not found in unique opcodes");
106 return std::distance(UniqueBaseOpcodeNames.begin(), It);
109unsigned MIRVocabulary::getCanonicalOpcodeIndex(
unsigned Opcode)
const {
117 assert(Pos < Layout.TotalEntries &&
"Position out of bounds in vocabulary");
120 if (Pos < Layout.OperandBase && Pos < UniqueBaseOpcodeNames.size()) {
122 auto It = UniqueBaseOpcodeNames.begin();
123 std::advance(It, Pos);
131void MIRVocabulary::generateStorage(
const VocabMap &OpcodeMap) {
138 <<
"; using zero vector. This will result in an error "
140 ++MIRVocabMissCounter;
144 unsigned EmbeddingDim = OpcodeMap.begin()->second.size();
145 std::vector<Embedding> OpcodeEmbeddings(Layout.OperandBase,
149 for (
auto COpcodeName : UniqueBaseOpcodeNames) {
150 if (
auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) {
152 assert(COpcodeIndex < Layout.OperandBase &&
153 "Canonical index out of bounds");
154 OpcodeEmbeddings[COpcodeIndex] = It->second;
156 handleMissingEntity(COpcodeName);
164 auto scaleVocabSection = [](std::vector<Embedding> &Embeddings,
169 scaleVocabSection(OpcodeEmbeddings,
OpcWeight);
171 std::vector<std::vector<Embedding>> Sections(1);
172 Sections[0] = std::move(OpcodeEmbeddings);
174 Storage = ir2vec::VocabStorage(std::move(Sections));
177void MIRVocabulary::buildCanonicalOpcodeMapping() {
179 if (!UniqueBaseOpcodeNames.empty())
183 for (
unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
185 UniqueBaseOpcodeNames.insert(BaseOpcode);
188 LLVM_DEBUG(
dbgs() <<
"MIR2Vec: Built canonical mapping for target with "
189 << UniqueBaseOpcodeNames.size()
190 <<
" unique base opcodes\n");
199 "MIR2Vec Vocabulary Analysis",
false,
true)
205 return "MIR2Vec Vocabulary Analysis";
208Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
214 "MIR2Vec vocabulary file path not specified; set it "
215 "using --mir2vec-vocab-path");
221 auto Content = BufOrError.get()->getBuffer();
223 Expected<json::Value> ParsedVocabValue =
json::parse(Content);
224 if (!ParsedVocabValue)
229 "entities", *ParsedVocabValue, StrVocabMap, Dim))
235void MIR2VecVocabLegacyAnalysis::emitError(
Error Err, LLVMContext &Ctx) {
239mir2vec::MIRVocabulary
241 if (StrVocabMap.empty()) {
242 if (
Error Err = readVocabulary()) {
243 emitError(std::move(Err), M.getContext());
252 for (
const auto &
F : M) {
253 if (
F.isDeclaration())
275 "MIR2Vec Vocabulary Printer Pass",
false,
true)
287 auto MIR2VecVocab =
Analysis.getMIR2VecVocabulary(M);
289 if (!MIR2VecVocab.isValid()) {
290 OS <<
"MIR2Vec Vocabulary Printer: Invalid vocabulary\n";
295 for (
const auto &Entry : MIR2VecVocab) {
296 OS <<
"Key: " << MIR2VecVocab.getStringKey(Pos++) <<
": ";
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
block Block Frequency Analysis
const HexagonInstrInfo * TII
Module.h This file contains the declarations for the Module class.
This file defines the MIR2Vec vocabulary analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIRE...
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Lightweight error class with error context and mandatory checking.
static ErrorSuccess success()
Create a success value.
Error takeError()
Take ownership of the stored error.
LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location inf...
Pass to analyze and populate MIR2Vec vocabulary from a module.
mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M)
This pass prints the embeddings in the MIR2Vec vocabulary.
bool doFinalization(Module &M) override
doFinalization - Virtual method overriden by subclasses to do any necessary clean up after all passes...
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
This class contains meta information specific to a module.
LLVM_ABI MachineFunction * getMachineFunction(const Function &F) const
Returns the MachineFunction associated to IR function F if there is one, otherwise nullptr.
static ErrorOr< std::unique_ptr< MemoryBuffer > > getFileOrSTDIN(const Twine &Filename, bool IsText=false, bool RequiresNullTerminator=true, std::optional< Align > Alignment=std::nullopt)
Open the specified file as a MemoryBuffer, or open stdin if the Filename is "-".
A Module instance is used to store all the information related to an LLVM module.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
LLVM_ABI bool match(StringRef String, SmallVectorImpl< StringRef > *Matches=nullptr, std::string *Error=nullptr) const
matches - Match the regex against a given String.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
std::string str() const
str - Get the contents as an std::string.
constexpr bool empty() const
empty - Check if the string is empty.
char back() const
back - Get the last character in the string.
StringRef drop_back(size_t N=1) const
Return a StringRef equal to 'this' but with the last N elements dropped.
TargetInstrInfo - Interface to description of machine instruction set.
static Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim)
Parse a vocabulary section from JSON and populate the target vocabulary map.
Class for storing and accessing the MIR2Vec vocabulary.
static std::string extractBaseOpcodeName(StringRef InstrName)
Static method for extracting base opcode names (public for testing)
std::string getStringKey(unsigned Pos) const
Get the string key for a vocabulary entry at the given position.
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const
Get canonical index for base name (public for testing)
This class implements an extremely fast bulk output stream that can only output to a stream.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
initializer< Ty > init(const Ty &Val)
LLVM_ABI llvm::Expected< Value > parse(llvm::StringRef JSON)
Parses the provided JSON source, or returns a ParseError.
llvm::cl::OptionCategory MIR2VecCategory
cl::opt< float > OpcWeight
static cl::opt< std::string > VocabFile("mir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""), cl::cat(MIR2VecCategory))
ir2vec::Embedding Embedding
This is an optimization pass for GlobalISel generic memory operations.
Error createFileError(const Twine &F, Error E)
Concatenate a source file path and/or name with an Error.
LLVM_ABI std::error_code inconvertibleErrorCode()
The value returned by this function can be returned from convertToErrorCode for Error values where no...
Error createStringError(std::error_code EC, char const *Fmt, const Ts &... Vals)
Create formatted StringError object.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
Error make_error(ArgTs &&... Args)
Make a Error instance representing failure using the given error info type.
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
LLVM_ABI MachineFunctionPass * createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary contents to the given stream a...
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)