Skip to content
Open
119 changes: 119 additions & 0 deletions mlir/include/mlir/Query/Matcher/Diagnostics.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
//===--- Diagnostics.h - Helper class for error diagnostics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Diagnostics class to manage error messages. Implementation shares similarity
// to clang-query Diagnostics.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include <string>
#include <vector>

namespace mlir::query::matcher::internal {

// Represents the line and column numbers in a source query.
struct SourceLocation {
unsigned line{};
unsigned column{};
};

// Represents a range in a source query, defined by its start and end locations.
struct SourceRange {
SourceLocation start{};
SourceLocation end{};
};

// Diagnostics class to manage error messages.
class Diagnostics {
public:
// All errors from the system.
enum class ErrorType {
None,

// Parser Errors
ParserFailedToBuildMatcher,
ParserInvalidToken,
ParserNoCloseParen,
ParserNoCode,
ParserNoComma,
ParserNoOpenParen,
ParserNotAMatcher,
ParserOverloadedType,
ParserStringError,
ParserTrailingCode,

// Registry Errors
RegistryMatcherNotFound,
RegistryValueNotFound,
RegistryWrongArgCount,
RegistryWrongArgType
};

// Helper stream class for constructing error messages.
class ArgStream {
public:
ArgStream(std::vector<std::string> *out) : out(out) {}
template <class T>
ArgStream &operator<<(const T &arg) {
return operator<<(llvm::Twine(arg));
}
ArgStream &operator<<(const llvm::Twine &arg);

private:
std::vector<std::string> *out;
};

// Add an error message with the specified range and error type.
// Returns an ArgStream object to allow constructing the error message using
// the << operator.
ArgStream addError(SourceRange range, ErrorType error);

// Print all error messages to the specified output stream.
void print(llvm::raw_ostream &OS) const;

private:
// Information stored for one frame of the context.
struct ContextFrame {
SourceRange range;
std::vector<std::string> args;
};

// Information stored for each error found.
struct ErrorContent {
std::vector<ContextFrame> contextStack;
struct Message {
SourceRange range;
ErrorType type;
std::vector<std::string> args;
};
std::vector<Message> messages;
};

// Get an array reference to the error contents.
llvm::ArrayRef<ErrorContent> errors() const { return errorValues; }

void printMessage(const ErrorContent::Message &message,
const llvm::Twine Prefix, llvm::raw_ostream &OS) const;

void printErrorContent(const ErrorContent &content,
llvm::raw_ostream &OS) const;

std::vector<ContextFrame> contextStack;
std::vector<ErrorContent> errorValues;
};

} // namespace mlir::query::matcher::internal

#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
196 changes: 196 additions & 0 deletions mlir/include/mlir/Query/Matcher/Marshallers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
//===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains function templates and classes to wrap matcher construct
// functions. It provides a collection of template function and classes that
// present a generic marshalling layer on top of matcher construct functions.
// The registry uses these to export all marshaller constructors with a uniform
// interface. This mechanism takes inspiration from clang-query.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H

#include "Diagnostics.h"
#include "VariantValue.h"

namespace mlir::query::matcher::internal {

// Helper template class for jumping from argument type to the correct is/get
// functions in VariantValue. This is used for verifying and extracting the
// matcher arguments.
template <class T>
struct ArgTypeTraits;
template <class T>
struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};

template <>
struct ArgTypeTraits<StringRef> {

static bool hasCorrectType(const VariantValue &value) {
return value.isString();
}

static const StringRef &get(const VariantValue &value) {
return value.getString();
}

static ArgKind getKind() { return ArgKind::String; }

static std::optional<std::string> getBestGuess(const VariantValue &) {
return std::nullopt;
}
};

template <>
struct ArgTypeTraits<DynMatcher> {

static bool hasCorrectType(const VariantValue &value) {
return value.isMatcher();
}

static DynMatcher get(const VariantValue &value) {
return *value.getMatcher().getDynMatcher();
}

static ArgKind getKind() { return ArgKind::Matcher; }

static std::optional<std::string> getBestGuess(const VariantValue &) {
return std::nullopt;
}
};

// Interface for generic matcher descriptor.
// Offers a create() method that constructs the matcher from the provided
// arguments.
class MatcherDescriptor {
public:
virtual ~MatcherDescriptor() = default;
virtual VariantMatcher create(SourceRange nameRange,
const ArrayRef<ParserValue> args,
Diagnostics *error) const = 0;

// Returns the number of arguments accepted by the matcher.
virtual unsigned getNumArgs() const = 0;

// Append the set of argument types accepted for argument 'ArgNo' to
// 'ArgKinds'.
virtual void getArgKinds(unsigned argNo,
std::vector<ArgKind> &argKinds) const = 0;
};

class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
public:
using MarshallerType = VariantMatcher (*)(void (*func)(),
StringRef matcherName,
SourceRange nameRange,
ArrayRef<ParserValue> args,
Diagnostics *error);

// Marshaller Function to unpack the arguments and call Func. Func is the
// Matcher construct function. This is the function that the matcher
// expressions would use to create the matcher.
FixedArgCountMatcherDescriptor(MarshallerType marshaller, void (*func)(),
StringRef matcherName,
ArrayRef<ArgKind> argKinds)
: marshaller(marshaller), func(func), matcherName(matcherName),
argKinds(argKinds.begin(), argKinds.end()) {}

VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
Diagnostics *error) const override {
return marshaller(func, matcherName, nameRange, args, error);
}

unsigned getNumArgs() const override { return argKinds.size(); }

void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
kinds.push_back(argKinds[argNo]);
}

private:
const MarshallerType marshaller;
void (*const func)();
const StringRef matcherName;
const std::vector<ArgKind> argKinds;
};

// Helper function to check if argument count matches expected count
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
ArrayRef<ParserValue> args, Diagnostics *error) {
if (args.size() != expectedArgCount) {
error->addError(nameRange, Diagnostics::ErrorType::RegistryWrongArgCount)
<< expectedArgCount << args.size();
return false;
}
return true;
}

// Helper function for checking argument type
template <typename ArgType, size_t Index>
inline bool checkArgTypeAtIndex(StringRef matcherName,
ArrayRef<ParserValue> args,
Diagnostics *error) {
if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) {
error->addError(args[Index].range,
Diagnostics::ErrorType::RegistryWrongArgType)
<< matcherName << Index + 1;
return false;
}
return true;
}

// Marshaller function for fixed number of arguments
template <typename ReturnType, typename... ArgTypes, size_t... Is>
static VariantMatcher
matcherMarshallFixedImpl(void (*func)(), StringRef matcherName,
SourceRange nameRange, ArrayRef<ParserValue> args,
Diagnostics *error, std::index_sequence<Is...>) {
using FuncType = ReturnType (*)(ArgTypes...);

// Check if the argument count matches the expected count
if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error)) {
return VariantMatcher();
}

// Check if each argument at the corresponding index has the correct type
if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
ReturnType fnPointer = reinterpret_cast<FuncType>(func)(
ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
return VariantMatcher::SingleMatcher(
*DynMatcher::constructDynMatcherFromMatcherFn(fnPointer));
} else {
return VariantMatcher();
}
}

template <typename ReturnType, typename... ArgTypes>
static VariantMatcher
matcherMarshallFixed(void (*func)(), StringRef matcherName,
SourceRange nameRange, ArrayRef<ParserValue> args,
Diagnostics *error) {
return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
func, matcherName, nameRange, args, error,
std::index_sequence_for<ArgTypes...>{});
}

// Fixed number of arguments overload
template <typename ReturnType, typename... ArgTypes>
std::unique_ptr<MatcherDescriptor>
makeMatcherAutoMarshall(ReturnType (*func)(ArgTypes...),
StringRef matcherName) {
// Create a vector of argument kinds
std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
return std::make_unique<FixedArgCountMatcherDescriptor>(
matcherMarshallFixed<ReturnType, ArgTypes...>,
reinterpret_cast<void (*)()>(func), matcherName, argKinds);
}

} // namespace mlir::query::matcher::internal

#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
42 changes: 42 additions & 0 deletions mlir/include/mlir/Query/Matcher/MatchFinder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@

//===- MatchFinder.h - Structural query framework ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the MatchFinder class, which is used to find operations
// that match a given matcher.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H

#include "MatchersInternal.h"

namespace mlir::query::matcher {

// MatchFinder is used to find all operations that match a given matcher.
class MatchFinder {
public:
// Returns all operations that match the given matcher.
static std::vector<Operation *> getMatches(Operation *root,
DynMatcher matcher) {
std::vector<Operation *> matches;

// Simple match finding with walk.
root->walk([&](Operation *subOp) {
if (matcher.match(subOp))
matches.push_back(subOp);
});

return matches;
}
};

} // namespace mlir::query::matcher

#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
Loading