Skip to content

[IR2Vec] Restructuring Vocabulary #145119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/svkeerthy/06-20-overloading_operator_for_embeddngs
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ class FunctionPropertiesInfo {
void reIncludeBB(const BasicBlock &BB);

ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
std::optional<ir2vec::Vocab> IR2VecVocab;
const ir2vec::Vocabulary *IR2VecVocab = nullptr;

public:
LLVM_ABI static FunctionPropertiesInfo
getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
const LoopInfo &LI,
const IR2VecVocabResult *VocabResult);
const ir2vec::Vocabulary *Vocabulary);

LLVM_ABI static FunctionPropertiesInfo
getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
Expand Down Expand Up @@ -145,9 +145,7 @@ class FunctionPropertiesInfo {
return FunctionEmbedding;
}

const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
return IR2VecVocab;
}
const ir2vec::Vocabulary *getIR2VecVocab() const { return IR2VecVocab; }

// Helper intended to be useful for unittests
void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
Expand Down
125 changes: 85 additions & 40 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/JSON.h"
Expand All @@ -42,10 +43,10 @@ class Module;
class BasicBlock;
class Instruction;
class Function;
class Type;
class Value;
class raw_ostream;
class LLVMContext;
class IR2VecVocabAnalysis;

/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
Expand Down Expand Up @@ -124,9 +125,73 @@ struct Embedding {

using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
// FIXME: Current the keys are strings. This can be changed to
// use integers for cheaper lookups.
using Vocab = std::map<std::string, Embedding>;

/// Class for storing and accessing the IR2Vec vocabulary.
/// Encapsulates all vocabulary-related constants, logic, and access methods.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
bool Valid = false;

/// Operand kinds supported by IR2Vec Vocabulary
#define OPERAND_KINDS \
OPERAND_KIND(FunctionID, "Function") \
OPERAND_KIND(PointerID, "Pointer") \
OPERAND_KIND(ConstantID, "Constant") \
OPERAND_KIND(VariableID, "Variable")

enum class OperandKind : unsigned {
#define OPERAND_KIND(Name, Str) Name,
OPERAND_KINDS
#undef OPERAND_KIND
MaxOperandKind
};

#undef OPERAND_KINDS

/// Vocabulary layout constants
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
#include "llvm/IR/Instruction.def"
#undef LAST_OTHER_INST

static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1;
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);

/// Helper function to get vocabulary key for a given OperandKind
static StringRef getVocabKeyForOperandKind(OperandKind Kind);

/// Helper function to classify an operand into OperandKind
static OperandKind getOperandKind(const Value *Op);

/// Helper function to get vocabulary key for a given TypeID
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);

public:
Vocabulary() = default;
Vocabulary(VocabVector &&Vocab);

bool isValid() const;
unsigned getDimension() const;
unsigned size() const;

const ir2vec::Embedding &at(unsigned Position) const;
const ir2vec::Embedding &operator[](unsigned Opcode) const;
const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
const ir2vec::Embedding &operator[](const Value *Arg) const;

/// Returns the string key for a given index position in the vocabulary.
/// This is useful for debugging or printing the vocabulary. Do not use this
/// for embedding generation as string based lookups are inefficient.
static StringRef getStringKey(unsigned Pos);

/// Create a dummy vocabulary for testing purposes.
static VocabVector createDummyVocabForTest(unsigned Dim = 1);

bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv) const;
};

/// Embedder provides the interface to generate embeddings (vector
/// representations) for instructions, basic blocks, and functions. The
Expand All @@ -137,7 +202,7 @@ using Vocab = std::map<std::string, Embedding>;
class Embedder {
protected:
const Function &F;
const Vocab &Vocabulary;
const Vocabulary &Vocab;

/// Dimension of the vector representation; captured from the input vocabulary
const unsigned Dimension;
Expand All @@ -152,7 +217,7 @@ class Embedder {
mutable BBEmbeddingsMap BBVecMap;
mutable InstEmbeddingsMap InstVecMap;

Embedder(const Function &F, const Vocab &Vocabulary);
Embedder(const Function &F, const Vocabulary &Vocab);

/// Helper function to compute embeddings. It generates embeddings for all
/// the instructions and basic blocks in the function F. Logic of computing
Expand All @@ -163,16 +228,12 @@ class Embedder {
/// Specific to the kind of embeddings being computed.
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;

/// Lookup vocabulary for a given Key. If the key is not found, it returns a
/// zero vector.
Embedding lookupVocab(const std::string &Key) const;

public:
virtual ~Embedder() = default;

/// Factory method to create an Embedder object.
static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
const Vocab &Vocabulary);
const Vocabulary &Vocab);

/// Returns a map containing instructions and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
Expand All @@ -198,56 +259,40 @@ class Embedder {
/// representations obtained from the Vocabulary.
class SymbolicEmbedder : public Embedder {
private:
/// Utility function to compute the embedding for a given type.
Embedding getTypeEmbedding(const Type *Ty) const;

/// Utility function to compute the embedding for a given operand.
Embedding getOperandEmbedding(const Value *Op) const;

void computeEmbeddings() const override;
void computeEmbeddings(const BasicBlock &BB) const override;

public:
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary)
: Embedder(F, Vocabulary) {
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
: Embedder(F, Vocab) {
FuncVector = Embedding(Dimension, 0);
}
};

} // namespace ir2vec

/// Class for storing the result of the IR2VecVocabAnalysis.
class IR2VecVocabResult {
ir2vec::Vocab Vocabulary;
bool Valid = false;

public:
IR2VecVocabResult() = default;
IR2VecVocabResult(ir2vec::Vocab &&Vocabulary);

bool isValid() const { return Valid; }
const ir2vec::Vocab &getVocabulary() const;
unsigned getDimension() const;
bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv) const;
};

/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
/// its corresponding embedding.
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
ir2vec::Vocab Vocabulary;
using VocabVector = std::vector<ir2vec::Embedding>;
using VocabMap = std::map<std::string, ir2vec::Embedding>;
VocabMap OpcVocab, TypeVocab, ArgVocab;
VocabVector Vocab;

unsigned Dim = 0;
Error readVocabulary();
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
ir2vec::Vocab &TargetVocab, unsigned &Dim);
VocabMap &TargetVocab, unsigned &Dim);
void generateNumMappedVocab();
void emitError(Error Err, LLVMContext &Ctx);

public:
static AnalysisKey Key;
IR2VecVocabAnalysis() = default;
explicit IR2VecVocabAnalysis(const ir2vec::Vocab &Vocab);
explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
using Result = IR2VecVocabResult;
explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
using Result = ir2vec::Vocabulary;
Result run(Module &M, ModuleAnalysisManager &MAM);
};

Expand Down
20 changes: 10 additions & 10 deletions llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,20 @@ FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
// We use the cached result of the IR2VecVocabAnalysis run by
// InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
// use IR2Vec embeddings.
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
auto Vocabulary = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
FAM.getResult<LoopAnalysis>(F), VocabResult);
FAM.getResult<LoopAnalysis>(F), Vocabulary);
}

FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
const Function &F, const DominatorTree &DT, const LoopInfo &LI,
const IR2VecVocabResult *VocabResult) {
const ir2vec::Vocabulary *Vocabulary) {

FunctionPropertiesInfo FPI;
if (VocabResult && VocabResult->isValid()) {
FPI.IR2VecVocab = VocabResult->getVocabulary();
FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
if (Vocabulary && Vocabulary->isValid()) {
FPI.IR2VecVocab = Vocabulary;
FPI.FunctionEmbedding = ir2vec::Embedding(Vocabulary->getDimension(), 0.0);
}
for (const auto &BB : F)
if (DT.isReachableFromEntry(&BB))
Expand Down Expand Up @@ -588,9 +588,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
return false;
DominatorTree DT(F);
LoopInfo LI(DT);
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
auto Vocabulary = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
auto Fresh =
FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, Vocabulary);
return FPI == Fresh;
}
Loading
Loading