summaryrefslogtreecommitdiff
path: root/vendor/github.com/google/cel-go/parser/parser.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/google/cel-go/parser/parser.go')
-rw-r--r--vendor/github.com/google/cel-go/parser/parser.go232
1 files changed, 104 insertions, 128 deletions
diff --git a/vendor/github.com/google/cel-go/parser/parser.go b/vendor/github.com/google/cel-go/parser/parser.go
index e6f70f9..5cbb176 100644
--- a/vendor/github.com/google/cel-go/parser/parser.go
+++ b/vendor/github.com/google/cel-go/parser/parser.go
@@ -21,17 +21,15 @@ import (
"regexp"
"strconv"
"strings"
- "sync"
- antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4"
+ antlr "github.com/antlr4-go/antlr/v4"
"github.com/google/cel-go/common"
+ "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/runes"
+ "github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser/gen"
-
- exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
- structpb "google.golang.org/protobuf/types/known/structpb"
)
// Parser encapsulates the context necessary to perform parsing for different expressions.
@@ -88,10 +86,13 @@ func mustNewParser(opts ...Option) *Parser {
}
// Parse parses the expression represented by source and returns the result.
-func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) {
+func (p *Parser) Parse(source common.Source) (*ast.AST, *common.Errors) {
+ errs := common.NewErrors(source)
+ fac := ast.NewExprFactory()
impl := parser{
- errors: &parseErrors{common.NewErrors(source)},
- helper: newParserHelper(source),
+ errors: &parseErrors{errs},
+ exprFactory: fac,
+ helper: newParserHelper(source, fac),
macros: p.macros,
maxRecursionDepth: p.maxRecursionDepth,
errorReportingLimit: p.errorReportingLimit,
@@ -99,23 +100,21 @@ func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors
errorRecoveryLookaheadTokenLimit: p.errorRecoveryTokenLookaheadLimit,
populateMacroCalls: p.populateMacroCalls,
enableOptionalSyntax: p.enableOptionalSyntax,
+ enableVariadicOperatorASTs: p.enableVariadicOperatorASTs,
}
buf, ok := source.(runes.Buffer)
if !ok {
buf = runes.NewBuffer(source.Content())
}
- var e *exprpb.Expr
+ var out ast.Expr
if buf.Len() > p.expressionSizeCodePointLimit {
- e = impl.reportError(common.NoLocation,
+ out = impl.reportError(common.NoLocation,
"expression code point size exceeds limit: size: %d, limit %d",
buf.Len(), p.expressionSizeCodePointLimit)
} else {
- e = impl.parse(buf, source.Description())
+ out = impl.parse(buf, source.Description())
}
- return &exprpb.ParsedExpr{
- Expr: e,
- SourceInfo: impl.helper.getSourceInfo(),
- }, impl.errors.Errors
+ return ast.NewAST(out, impl.helper.getSourceInfo()), errs
}
// reservedIds are not legal to use as variables. We exclude them post-parse, as they *are* valid
@@ -148,7 +147,7 @@ var reservedIds = map[string]struct{}{
// This function calls ParseWithMacros with AllMacros.
//
// Deprecated: Use NewParser().Parse() instead.
-func Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) {
+func Parse(source common.Source) (*ast.AST, *common.Errors) {
return mustNewParser(Macros(AllMacros...)).Parse(source)
}
@@ -285,6 +284,7 @@ var _ antlr.ErrorStrategy = &recoveryLimitErrorStrategy{}
type parser struct {
gen.BaseCELVisitor
errors *parseErrors
+ exprFactory ast.ExprFactory
helper *parserHelper
macros map[string]Macro
recursionDepth int
@@ -295,55 +295,24 @@ type parser struct {
errorRecoveryLookaheadTokenLimit int
populateMacroCalls bool
enableOptionalSyntax bool
+ enableVariadicOperatorASTs bool
}
-var (
- _ gen.CELVisitor = (*parser)(nil)
+var _ gen.CELVisitor = (*parser)(nil)
- lexerPool *sync.Pool = &sync.Pool{
- New: func() any {
- l := gen.NewCELLexer(nil)
- l.RemoveErrorListeners()
- return l
- },
- }
-
- parserPool *sync.Pool = &sync.Pool{
- New: func() any {
- p := gen.NewCELParser(nil)
- p.RemoveErrorListeners()
- return p
- },
- }
-)
+func (p *parser) parse(expr runes.Buffer, desc string) ast.Expr {
+ lexer := gen.NewCELLexer(newCharStream(expr, desc))
+ lexer.RemoveErrorListeners()
+ lexer.AddErrorListener(p)
-func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
- // TODO: get rid of these pools once https://github.com/antlr/antlr4/pull/3571 is in a release
- lexer := lexerPool.Get().(*gen.CELLexer)
- prsr := parserPool.Get().(*gen.CELParser)
+ prsr := gen.NewCELParser(antlr.NewCommonTokenStream(lexer, 0))
+ prsr.RemoveErrorListeners()
prsrListener := &recursionListener{
maxDepth: p.maxRecursionDepth,
ruleTypeDepth: map[int]*int{},
}
- defer func() {
- // Unfortunately ANTLR Go runtime is missing (*antlr.BaseParser).RemoveParseListeners,
- // so this is good enough until that is exported.
- // Reset the lexer and parser before putting them back in the pool.
- lexer.RemoveErrorListeners()
- prsr.RemoveParseListener(prsrListener)
- prsr.RemoveErrorListeners()
- lexer.SetInputStream(nil)
- prsr.SetInputStream(nil)
- lexerPool.Put(lexer)
- parserPool.Put(prsr)
- }()
-
- lexer.SetInputStream(newCharStream(expr, desc))
- prsr.SetInputStream(antlr.NewCommonTokenStream(lexer, 0))
-
- lexer.AddErrorListener(p)
prsr.AddErrorListener(p)
prsr.AddParseListener(prsrListener)
@@ -357,9 +326,9 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
if val := recover(); val != nil {
switch err := val.(type) {
case *lookaheadLimitError:
- p.errors.ReportError(common.NoLocation, err.Error())
+ p.errors.internalError(err.Error())
case *recursionError:
- p.errors.ReportError(common.NoLocation, err.Error())
+ p.errors.internalError(err.Error())
case *tooManyErrors:
// do nothing
case *recoveryLimitError:
@@ -370,7 +339,7 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
}
}()
- return p.Visit(prsr.Start()).(*exprpb.Expr)
+ return p.Visit(prsr.Start_()).(ast.Expr)
}
// Visitor implementations.
@@ -449,7 +418,7 @@ func (p *parser) Visit(tree antlr.ParseTree) any {
// Report at least one error if the parser reaches an unknown parse element.
// Typically, this happens if the parser has already encountered a syntax error elsewhere.
- if len(p.errors.GetErrors()) == 0 {
+ if p.errors.errorCount() == 0 {
txt := "<<nil>>"
if t != nil {
txt = fmt.Sprintf("<<%T>>", t)
@@ -467,46 +436,46 @@ func (p *parser) VisitStart(ctx *gen.StartContext) any {
// Visit a parse tree produced by CELParser#expr.
func (p *parser) VisitExpr(ctx *gen.ExprContext) any {
- result := p.Visit(ctx.GetE()).(*exprpb.Expr)
+ result := p.Visit(ctx.GetE()).(ast.Expr)
if ctx.GetOp() == nil {
return result
}
opID := p.helper.id(ctx.GetOp())
- ifTrue := p.Visit(ctx.GetE1()).(*exprpb.Expr)
- ifFalse := p.Visit(ctx.GetE2()).(*exprpb.Expr)
+ ifTrue := p.Visit(ctx.GetE1()).(ast.Expr)
+ ifFalse := p.Visit(ctx.GetE2()).(ast.Expr)
return p.globalCallOrMacro(opID, operators.Conditional, result, ifTrue, ifFalse)
}
// Visit a parse tree produced by CELParser#conditionalOr.
func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any {
- result := p.Visit(ctx.GetE()).(*exprpb.Expr)
- b := newBalancer(p.helper, operators.LogicalOr, result)
+ result := p.Visit(ctx.GetE()).(ast.Expr)
+ l := p.newLogicManager(operators.LogicalOr, result)
rest := ctx.GetE1()
for i, op := range ctx.GetOps() {
if i >= len(rest) {
return p.reportError(ctx, "unexpected character, wanted '||'")
}
- next := p.Visit(rest[i]).(*exprpb.Expr)
+ next := p.Visit(rest[i]).(ast.Expr)
opID := p.helper.id(op)
- b.addTerm(opID, next)
+ l.addTerm(opID, next)
}
- return b.balance()
+ return l.toExpr()
}
// Visit a parse tree produced by CELParser#conditionalAnd.
func (p *parser) VisitConditionalAnd(ctx *gen.ConditionalAndContext) any {
- result := p.Visit(ctx.GetE()).(*exprpb.Expr)
- b := newBalancer(p.helper, operators.LogicalAnd, result)
+ result := p.Visit(ctx.GetE()).(ast.Expr)
+ l := p.newLogicManager(operators.LogicalAnd, result)
rest := ctx.GetE1()
for i, op := range ctx.GetOps() {
if i >= len(rest) {
return p.reportError(ctx, "unexpected character, wanted '&&'")
}
- next := p.Visit(rest[i]).(*exprpb.Expr)
+ next := p.Visit(rest[i]).(ast.Expr)
opID := p.helper.id(op)
- b.addTerm(opID, next)
+ l.addTerm(opID, next)
}
- return b.balance()
+ return l.toExpr()
}
// Visit a parse tree produced by CELParser#relation.
@@ -516,9 +485,9 @@ func (p *parser) VisitRelation(ctx *gen.RelationContext) any {
opText = ctx.GetOp().GetText()
}
if op, found := operators.Find(opText); found {
- lhs := p.Visit(ctx.Relation(0)).(*exprpb.Expr)
+ lhs := p.Visit(ctx.Relation(0)).(ast.Expr)
opID := p.helper.id(ctx.GetOp())
- rhs := p.Visit(ctx.Relation(1)).(*exprpb.Expr)
+ rhs := p.Visit(ctx.Relation(1)).(ast.Expr)
return p.globalCallOrMacro(opID, op, lhs, rhs)
}
return p.reportError(ctx, "operator not found")
@@ -531,9 +500,9 @@ func (p *parser) VisitCalc(ctx *gen.CalcContext) any {
opText = ctx.GetOp().GetText()
}
if op, found := operators.Find(opText); found {
- lhs := p.Visit(ctx.Calc(0)).(*exprpb.Expr)
+ lhs := p.Visit(ctx.Calc(0)).(ast.Expr)
opID := p.helper.id(ctx.GetOp())
- rhs := p.Visit(ctx.Calc(1)).(*exprpb.Expr)
+ rhs := p.Visit(ctx.Calc(1)).(ast.Expr)
return p.globalCallOrMacro(opID, op, lhs, rhs)
}
return p.reportError(ctx, "operator not found")
@@ -549,7 +518,7 @@ func (p *parser) VisitLogicalNot(ctx *gen.LogicalNotContext) any {
return p.Visit(ctx.Member())
}
opID := p.helper.id(ctx.GetOps()[0])
- target := p.Visit(ctx.Member()).(*exprpb.Expr)
+ target := p.Visit(ctx.Member()).(ast.Expr)
return p.globalCallOrMacro(opID, operators.LogicalNot, target)
}
@@ -558,13 +527,13 @@ func (p *parser) VisitNegate(ctx *gen.NegateContext) any {
return p.Visit(ctx.Member())
}
opID := p.helper.id(ctx.GetOps()[0])
- target := p.Visit(ctx.Member()).(*exprpb.Expr)
+ target := p.Visit(ctx.Member()).(ast.Expr)
return p.globalCallOrMacro(opID, operators.Negate, target)
}
// VisitSelect visits a parse tree produced by CELParser#Select.
func (p *parser) VisitSelect(ctx *gen.SelectContext) any {
- operand := p.Visit(ctx.Member()).(*exprpb.Expr)
+ operand := p.Visit(ctx.Member()).(ast.Expr)
// Handle the error case where no valid identifier is specified.
if ctx.GetId() == nil || ctx.GetOp() == nil {
return p.helper.newExpr(ctx)
@@ -585,7 +554,7 @@ func (p *parser) VisitSelect(ctx *gen.SelectContext) any {
// VisitMemberCall visits a parse tree produced by CELParser#MemberCall.
func (p *parser) VisitMemberCall(ctx *gen.MemberCallContext) any {
- operand := p.Visit(ctx.Member()).(*exprpb.Expr)
+ operand := p.Visit(ctx.Member()).(ast.Expr)
// Handle the error case where no valid identifier is specified.
if ctx.GetId() == nil {
return p.helper.newExpr(ctx)
@@ -597,13 +566,13 @@ func (p *parser) VisitMemberCall(ctx *gen.MemberCallContext) any {
// Visit a parse tree produced by CELParser#Index.
func (p *parser) VisitIndex(ctx *gen.IndexContext) any {
- target := p.Visit(ctx.Member()).(*exprpb.Expr)
+ target := p.Visit(ctx.Member()).(ast.Expr)
// Handle the error case where no valid identifier is specified.
if ctx.GetOp() == nil {
return p.helper.newExpr(ctx)
}
opID := p.helper.id(ctx.GetOp())
- index := p.Visit(ctx.GetIndex()).(*exprpb.Expr)
+ index := p.Visit(ctx.GetIndex()).(ast.Expr)
operator := operators.Index
if ctx.GetOpt() != nil {
if !p.enableOptionalSyntax {
@@ -627,7 +596,7 @@ func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) any {
messageName = "." + messageName
}
objID := p.helper.id(ctx.GetOp())
- entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry)
+ entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]ast.EntryExpr)
return p.helper.newObject(objID, messageName, entries...)
}
@@ -635,16 +604,16 @@ func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) any {
func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext) any {
if ctx == nil || ctx.GetFields() == nil {
// This is the result of a syntax error handled elswhere, return empty.
- return []*exprpb.Expr_CreateStruct_Entry{}
+ return []ast.EntryExpr{}
}
- result := make([]*exprpb.Expr_CreateStruct_Entry, len(ctx.GetFields()))
+ result := make([]ast.EntryExpr, len(ctx.GetFields()))
cols := ctx.GetCols()
vals := ctx.GetValues()
for i, f := range ctx.GetFields() {
if i >= len(cols) || i >= len(vals) {
// This is the result of a syntax error detected elsewhere.
- return []*exprpb.Expr_CreateStruct_Entry{}
+ return []ast.EntryExpr{}
}
initID := p.helper.id(cols[i])
optField := f.(*gen.OptFieldContext)
@@ -656,10 +625,10 @@ func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext
// The field may be empty due to a prior error.
id := optField.IDENTIFIER()
if id == nil {
- return []*exprpb.Expr_CreateStruct_Entry{}
+ return []ast.EntryExpr{}
}
fieldName := id.GetText()
- value := p.Visit(vals[i]).(*exprpb.Expr)
+ value := p.Visit(vals[i]).(ast.Expr)
field := p.helper.newObjectField(initID, fieldName, value, optional)
result[i] = field
}
@@ -699,9 +668,9 @@ func (p *parser) VisitCreateList(ctx *gen.CreateListContext) any {
// Visit a parse tree produced by CELParser#CreateStruct.
func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) any {
structID := p.helper.id(ctx.GetOp())
- entries := []*exprpb.Expr_CreateStruct_Entry{}
+ entries := []ast.EntryExpr{}
if ctx.GetEntries() != nil {
- entries = p.Visit(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry)
+ entries = p.Visit(ctx.GetEntries()).([]ast.EntryExpr)
}
return p.helper.newMap(structID, entries...)
}
@@ -710,17 +679,17 @@ func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) any {
func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) any {
if ctx == nil || ctx.GetKeys() == nil {
// This is the result of a syntax error handled elswhere, return empty.
- return []*exprpb.Expr_CreateStruct_Entry{}
+ return []ast.EntryExpr{}
}
- result := make([]*exprpb.Expr_CreateStruct_Entry, len(ctx.GetCols()))
+ result := make([]ast.EntryExpr, len(ctx.GetCols()))
keys := ctx.GetKeys()
vals := ctx.GetValues()
for i, col := range ctx.GetCols() {
colID := p.helper.id(col)
if i >= len(keys) || i >= len(vals) {
// This is the result of a syntax error detected elsewhere.
- return []*exprpb.Expr_CreateStruct_Entry{}
+ return []ast.EntryExpr{}
}
optKey := keys[i]
optional := optKey.GetOpt() != nil
@@ -728,8 +697,8 @@ func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) any
p.reportError(optKey, "unsupported syntax '?'")
continue
}
- key := p.Visit(optKey.GetE()).(*exprpb.Expr)
- value := p.Visit(vals[i]).(*exprpb.Expr)
+ key := p.Visit(optKey.GetE()).(ast.Expr)
+ value := p.Visit(vals[i]).(ast.Expr)
entry := p.helper.newMapEntry(colID, key, value, optional)
result[i] = entry
}
@@ -809,30 +778,27 @@ func (p *parser) VisitBoolFalse(ctx *gen.BoolFalseContext) any {
// Visit a parse tree produced by CELParser#Null.
func (p *parser) VisitNull(ctx *gen.NullContext) any {
- return p.helper.newLiteral(ctx,
- &exprpb.Constant{
- ConstantKind: &exprpb.Constant_NullValue{
- NullValue: structpb.NullValue_NULL_VALUE}})
+ return p.helper.exprFactory.NewLiteral(p.helper.newID(ctx), types.NullValue)
}
-func (p *parser) visitExprList(ctx gen.IExprListContext) []*exprpb.Expr {
+func (p *parser) visitExprList(ctx gen.IExprListContext) []ast.Expr {
if ctx == nil {
- return []*exprpb.Expr{}
+ return []ast.Expr{}
}
return p.visitSlice(ctx.GetE())
}
-func (p *parser) visitListInit(ctx gen.IListInitContext) ([]*exprpb.Expr, []int32) {
+func (p *parser) visitListInit(ctx gen.IListInitContext) ([]ast.Expr, []int32) {
if ctx == nil {
- return []*exprpb.Expr{}, []int32{}
+ return []ast.Expr{}, []int32{}
}
elements := ctx.GetElems()
- result := make([]*exprpb.Expr, len(elements))
+ result := make([]ast.Expr, len(elements))
optionals := []int32{}
for i, e := range elements {
- ex := p.Visit(e.GetE()).(*exprpb.Expr)
+ ex := p.Visit(e.GetE()).(ast.Expr)
if ex == nil {
- return []*exprpb.Expr{}, []int32{}
+ return []ast.Expr{}, []int32{}
}
result[i] = ex
if e.GetOpt() != nil {
@@ -846,13 +812,13 @@ func (p *parser) visitListInit(ctx gen.IListInitContext) ([]*exprpb.Expr, []int3
return result, optionals
}
-func (p *parser) visitSlice(expressions []gen.IExprContext) []*exprpb.Expr {
+func (p *parser) visitSlice(expressions []gen.IExprContext) []ast.Expr {
if expressions == nil {
- return []*exprpb.Expr{}
+ return []ast.Expr{}
}
- result := make([]*exprpb.Expr, len(expressions))
+ result := make([]ast.Expr, len(expressions))
for i, e := range expressions {
- ex := p.Visit(e).(*exprpb.Expr)
+ ex := p.Visit(e).(ast.Expr)
result[i] = ex
}
return result
@@ -867,24 +833,31 @@ func (p *parser) unquote(ctx any, value string, isBytes bool) string {
return text
}
-func (p *parser) reportError(ctx any, format string, args ...any) *exprpb.Expr {
+func (p *parser) newLogicManager(function string, term ast.Expr) *logicManager {
+ if p.enableVariadicOperatorASTs {
+ return newVariadicLogicManager(p.exprFactory, function, term)
+ }
+ return newBalancingLogicManager(p.exprFactory, function, term)
+}
+
+func (p *parser) reportError(ctx any, format string, args ...any) ast.Expr {
var location common.Location
- switch ctx.(type) {
+ err := p.helper.newExpr(ctx)
+ switch c := ctx.(type) {
case common.Location:
- location = ctx.(common.Location)
+ location = c
case antlr.Token, antlr.ParserRuleContext:
- err := p.helper.newExpr(ctx)
- location = p.helper.getLocation(err.GetId())
+ location = p.helper.getLocation(err.ID())
}
- err := p.helper.newExpr(ctx)
// Provide arguments to the report error.
- p.errors.ReportError(location, format, args...)
+ p.errors.reportErrorAtID(err.ID(), location, format, args...)
return err
}
// ANTLR Parse listener implementations
func (p *parser) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, line, column int, msg string, e antlr.RecognitionException) {
- l := p.helper.source.NewLocation(line, column)
+ offset := p.helper.sourceInfo.ComputeOffset(int32(line), int32(column))
+ l := p.helper.getLocationByOffset(offset)
// Hack to keep existing error messages consistent with previous versions of CEL when a reserved word
// is used as an identifier. This behavior needs to be overhauled to provide consistent, normalized error
// messages out of ANTLR to prevent future breaking changes related to error message content.
@@ -903,33 +876,33 @@ func (p *parser) SyntaxError(recognizer antlr.Recognizer, offendingSymbol any, l
}
}
-func (p *parser) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
+func (p *parser) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
// Intentional
}
-func (p *parser) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
+func (p *parser) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs *antlr.ATNConfigSet) {
// Intentional
}
-func (p *parser) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs antlr.ATNConfigSet) {
+func (p *parser) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs *antlr.ATNConfigSet) {
// Intentional
}
-func (p *parser) globalCallOrMacro(exprID int64, function string, args ...*exprpb.Expr) *exprpb.Expr {
+func (p *parser) globalCallOrMacro(exprID int64, function string, args ...ast.Expr) ast.Expr {
if expr, found := p.expandMacro(exprID, function, nil, args...); found {
return expr
}
return p.helper.newGlobalCall(exprID, function, args...)
}
-func (p *parser) receiverCallOrMacro(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr {
+func (p *parser) receiverCallOrMacro(exprID int64, function string, target ast.Expr, args ...ast.Expr) ast.Expr {
if expr, found := p.expandMacro(exprID, function, target, args...); found {
return expr
}
return p.helper.newReceiverCall(exprID, function, target, args...)
}
-func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) (*exprpb.Expr, bool) {
+func (p *parser) expandMacro(exprID int64, function string, target ast.Expr, args ...ast.Expr) (ast.Expr, bool) {
macro, found := p.macros[makeMacroKey(function, len(args), target != nil)]
if !found {
macro, found = p.macros[makeVarArgMacroKey(function, target != nil)]
@@ -944,10 +917,12 @@ func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr,
expr, err := macro.Expander()(eh, target, args)
// An error indicates that the macro was matched, but the arguments were not well-formed.
if err != nil {
- if err.Location != nil {
- return p.reportError(err.Location, err.Message), true
+ loc := err.Location
+ if loc == nil {
+ loc = p.helper.getLocation(exprID)
}
- return p.reportError(p.helper.getLocation(exprID), err.Message), true
+ p.helper.deleteID(exprID)
+ return p.reportError(loc, err.Message), true
}
// A nil value from the macro indicates that the macro implementation decided that
// an expansion should not be performed.
@@ -955,8 +930,9 @@ func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr,
return nil, false
}
if p.populateMacroCalls {
- p.helper.addMacroCall(expr.GetId(), function, target, args...)
+ p.helper.addMacroCall(expr.ID(), function, target, args...)
}
+ p.helper.deleteID(exprID)
return expr, true
}