Skip to content

Commit 6d7039c

Browse files
committed
[mlir][mlir-query] Introduce mlir-query tool with autocomplete support
This commit adds the initial version of the mlir-query tool, which leverages the pre-existing matchers defined in mlir/include/mlir/IR/Matchers.h The tool provides the following set of basic queries: QUERY MATCHER hasOpAttrName(string) -> m_Attr hasOpName(string) -> m_Op isConstantOp() -> m_Constant isNegInfFloat() -> m_NegInfFloat isNegZeroFloat() -> m_NegZeroFloat isNonZero() -> m_NonZero isOne() -> m_One isOneFloat() -> m_OneFloat isPosInfFloat() -> m_PosInfFloat isPosZeroFloat() -> m_PosZeroFloat isZero() -> m_Zero isZeroFloat() -> m_AnyZeroFloat Differential Revision: https://reviews.llvm.org/D155127
1 parent 8130166 commit 6d7039c

25 files changed

+2547
-0
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
//===--- Diagnostics.h - Helper class for error diagnostics -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Diagnostics class to manage error messages. Implementation shares similarity
10+
// to clang-query Diagnostics.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
15+
#define MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
16+
17+
#include "llvm/ADT/ArrayRef.h"
18+
#include "llvm/ADT/StringRef.h"
19+
#include "llvm/ADT/Twine.h"
20+
#include "llvm/Support/raw_ostream.h"
21+
#include <string>
22+
#include <vector>
23+
24+
namespace mlir::query::matcher {
25+
26+
// Represents the line and column numbers in a source query.
27+
struct SourceLocation {
28+
unsigned line{};
29+
unsigned column{};
30+
};
31+
32+
// Represents a range in a source query, defined by its start and end locations.
33+
struct SourceRange {
34+
SourceLocation start{};
35+
SourceLocation end{};
36+
};
37+
38+
// Diagnostics class to manage error messages.
39+
class Diagnostics {
40+
public:
41+
// Parser context types.
42+
enum ContextType { CT_MatcherArg, CT_MatcherConstruct };
43+
44+
// All errors from the system.
45+
enum ErrorType {
46+
ET_None,
47+
48+
// Parser Errors
49+
ET_ParserFailedToBuildMatcher,
50+
ET_ParserInvalidToken,
51+
ET_ParserNoCloseParen,
52+
ET_ParserNoCode,
53+
ET_ParserNoComma,
54+
ET_ParserNoOpenParen,
55+
ET_ParserNotAMatcher,
56+
ET_ParserOverloadedType,
57+
ET_ParserStringError,
58+
ET_ParserTrailingCode,
59+
60+
// Registry Errors
61+
ET_RegistryMatcherNotFound,
62+
ET_RegistryValueNotFound,
63+
ET_RegistryWrongArgCount,
64+
ET_RegistryWrongArgType
65+
};
66+
67+
// Helper stream class for constructing error messages.
68+
class ArgStream {
69+
public:
70+
ArgStream(std::vector<std::string> *out) : out(out) {}
71+
template <class T>
72+
ArgStream &operator<<(const T &arg) {
73+
return operator<<(llvm::Twine(arg));
74+
}
75+
ArgStream &operator<<(const llvm::Twine &arg);
76+
77+
private:
78+
std::vector<std::string> *out;
79+
};
80+
81+
// Context for constructing a matcher or parsing its argument.
82+
struct Context {
83+
enum ConstructMatcherEnum { ConstructMatcher };
84+
Context(ConstructMatcherEnum, Diagnostics *error,
85+
llvm::StringRef matcherName, SourceRange matcherRange);
86+
enum MatcherArgEnum { MatcherArg };
87+
Context(MatcherArgEnum, Diagnostics *error, llvm::StringRef matcherName,
88+
SourceRange matcherRange, int argNumber);
89+
~Context();
90+
91+
private:
92+
Diagnostics *const error;
93+
};
94+
95+
// Context for managing overloaded matcher construction.
96+
struct OverloadContext {
97+
// Construct an overload context with the given error.
98+
OverloadContext(Diagnostics *error);
99+
~OverloadContext();
100+
// Revert all errors that occurred within this context.
101+
void revertErrors();
102+
103+
private:
104+
Diagnostics *const error;
105+
unsigned beginIndex{};
106+
};
107+
108+
// Add an error message with the specified range and error type.
109+
// Returns an ArgStream object to allow constructing the error message using
110+
// the << operator.
111+
ArgStream addError(SourceRange range, ErrorType error);
112+
113+
// Information stored for one frame of the context.
114+
struct ContextFrame {
115+
ContextType type;
116+
SourceRange range;
117+
std::vector<std::string> args;
118+
};
119+
120+
// Information stored for each error found.
121+
struct ErrorContent {
122+
std::vector<ContextFrame> contextStack;
123+
struct Message {
124+
SourceRange range;
125+
ErrorType type;
126+
std::vector<std::string> args;
127+
};
128+
std::vector<Message> messages;
129+
};
130+
131+
// Get an array reference to the error contents.
132+
llvm::ArrayRef<ErrorContent> errors() const { return errorValues; }
133+
134+
// Print all error messages to the specified output stream.
135+
void print(llvm::raw_ostream &OS) const;
136+
137+
// Print the full error messages, including the context information, to the
138+
// specified output stream.
139+
void printFull(llvm::raw_ostream &OS) const;
140+
141+
private:
142+
// Push a new context frame onto the context stack with the specified type and
143+
// range.
144+
ArgStream pushContextFrame(ContextType type, SourceRange range);
145+
146+
std::vector<ContextFrame> contextStack;
147+
std::vector<ErrorContent> errorValues;
148+
};
149+
150+
} // namespace mlir::query::matcher
151+
152+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
//===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains function templates and classes to wrap matcher construct
10+
// functions. It provides a collection of template function and classes that
11+
// present a generic marshalling layer on top of matcher construct functions.
12+
// The registry uses these to export all marshaller constructors with a uniform
13+
// interface. This mechanism takes inspiration from clang-query.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
18+
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
19+
20+
#include "Diagnostics.h"
21+
#include "VariantValue.h"
22+
23+
namespace mlir::query::matcher::internal {
24+
25+
// Helper template class for jumping from argument type to the correct is/get
26+
// functions in VariantValue. This is used for verifying and extracting the
27+
// matcher arguments.
28+
template <class T>
29+
struct ArgTypeTraits;
30+
template <class T>
31+
struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};
32+
33+
template <>
34+
struct ArgTypeTraits<StringRef> {
35+
36+
static bool hasCorrectType(const VariantValue &value) {
37+
return value.isString();
38+
}
39+
40+
static const StringRef &get(const VariantValue &value) {
41+
return value.getString();
42+
}
43+
44+
static ArgKind getKind() { return ArgKind(ArgKind::AK_String); }
45+
46+
static std::optional<std::string> getBestGuess(const VariantValue &) {
47+
return std::nullopt;
48+
}
49+
};
50+
51+
template <>
52+
struct ArgTypeTraits<DynMatcher> {
53+
54+
static bool hasCorrectType(const VariantValue &value) {
55+
return value.isMatcher();
56+
}
57+
58+
static DynMatcher get(const VariantValue &value) {
59+
return *value.getMatcher().getDynMatcher();
60+
}
61+
62+
static ArgKind getKind() { return ArgKind(ArgKind::AK_Matcher); }
63+
64+
static std::optional<std::string> getBestGuess(const VariantValue &) {
65+
return std::nullopt;
66+
}
67+
};
68+
69+
// Interface for generic matcher descriptor.
70+
// Offers a create() method that constructs the matcher from the provided
71+
// arguments.
72+
class MatcherDescriptor {
73+
public:
74+
virtual ~MatcherDescriptor() = default;
75+
virtual VariantMatcher create(SourceRange nameRange,
76+
const ArrayRef<ParserValue> args,
77+
Diagnostics *error) const = 0;
78+
79+
// Returns the number of arguments accepted by the matcher.
80+
virtual unsigned getNumArgs() const = 0;
81+
82+
// Append the set of argument types accepted for argument 'ArgNo' to
83+
// 'ArgKinds'.
84+
virtual void getArgKinds(unsigned argNo,
85+
std::vector<ArgKind> &argKinds) const = 0;
86+
};
87+
88+
class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
89+
public:
90+
using MarshallerType = VariantMatcher (*)(void (*func)(),
91+
StringRef matcherName,
92+
SourceRange nameRange,
93+
ArrayRef<ParserValue> args,
94+
Diagnostics *error);
95+
96+
// Marshaller Function to unpack the arguments and call Func. Func is the
97+
// Matcher construct function. This is the function that the matcher
98+
// expressions would use to create the matcher.
99+
FixedArgCountMatcherDescriptor(MarshallerType marshaller, void (*func)(),
100+
StringRef matcherName,
101+
ArrayRef<ArgKind> argKinds)
102+
: marshaller(marshaller), func(func), matcherName(matcherName),
103+
argKinds(argKinds.begin(), argKinds.end()) {}
104+
105+
VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
106+
Diagnostics *error) const override {
107+
return marshaller(func, matcherName, nameRange, args, error);
108+
}
109+
110+
unsigned getNumArgs() const override { return argKinds.size(); }
111+
112+
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
113+
kinds.push_back(argKinds[argNo]);
114+
}
115+
116+
private:
117+
const MarshallerType marshaller;
118+
void (*const func)();
119+
const StringRef matcherName;
120+
const std::vector<ArgKind> argKinds;
121+
};
122+
123+
// Helper function to check if argument count matches expected count
124+
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
125+
ArrayRef<ParserValue> args, Diagnostics *error) {
126+
if (args.size() != expectedArgCount) {
127+
error->addError(nameRange, error->ET_RegistryWrongArgCount)
128+
<< expectedArgCount << args.size();
129+
return false;
130+
}
131+
return true;
132+
}
133+
134+
// Helper function for checking argument type
135+
template <typename ArgType, size_t Index>
136+
inline bool checkArgTypeAtIndex(StringRef matcherName,
137+
ArrayRef<ParserValue> args,
138+
Diagnostics *error) {
139+
if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) {
140+
error->addError(args[Index].range, error->ET_RegistryWrongArgType)
141+
<< matcherName << Index + 1;
142+
return false;
143+
}
144+
return true;
145+
}
146+
147+
// Marshaller function for fixed number of arguments
148+
template <typename ReturnType, typename... ArgTypes, size_t... Is>
149+
static VariantMatcher
150+
matcherMarshallFixedImpl(void (*func)(), StringRef matcherName,
151+
SourceRange nameRange, ArrayRef<ParserValue> args,
152+
Diagnostics *error, std::index_sequence<Is...>) {
153+
using FuncType = ReturnType (*)(ArgTypes...);
154+
155+
// Check if the argument count matches the expected count
156+
if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error)) {
157+
return VariantMatcher();
158+
}
159+
160+
// Check if each argument at the corresponding index has the correct type
161+
if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
162+
ReturnType fnPointer = reinterpret_cast<FuncType>(func)(
163+
ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
164+
return VariantMatcher::SingleMatcher(
165+
*DynMatcher::constructDynMatcherFromMatcherFn(fnPointer));
166+
} else {
167+
return VariantMatcher();
168+
}
169+
}
170+
171+
template <typename ReturnType, typename... ArgTypes>
172+
static VariantMatcher
173+
matcherMarshallFixed(void (*func)(), StringRef matcherName,
174+
SourceRange nameRange, ArrayRef<ParserValue> args,
175+
Diagnostics *error) {
176+
return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
177+
func, matcherName, nameRange, args, error,
178+
std::index_sequence_for<ArgTypes...>{});
179+
}
180+
181+
// Fixed number of arguments overload
182+
template <typename ReturnType, typename... ArgTypes>
183+
std::unique_ptr<MatcherDescriptor>
184+
makeMatcherAutoMarshall(ReturnType (*func)(ArgTypes...),
185+
StringRef matcherName) {
186+
// Create a vector of argument kinds
187+
std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
188+
return std::make_unique<FixedArgCountMatcherDescriptor>(
189+
matcherMarshallFixed<ReturnType, ArgTypes...>,
190+
reinterpret_cast<void (*)()>(func), matcherName, argKinds);
191+
}
192+
193+
} // namespace mlir::query::matcher::internal
194+
195+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H

0 commit comments

Comments
 (0)