diff options
Diffstat (limited to 'vendor/github.com/google/cel-go/ext')
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/BUILD.bazel | 17 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/README.md | 30 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/extension_option_factory.go | 75 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/formatting.go | 23 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/formatting_v2.go | 788 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/lists.go | 265 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/math.go | 47 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/native.go | 5 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/regex.go | 332 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/sets.go | 14 | ||||
| -rw-r--r-- | vendor/github.com/google/cel-go/ext/strings.go | 63 |
11 files changed, 1601 insertions, 58 deletions
diff --git a/vendor/github.com/google/cel-go/ext/BUILD.bazel b/vendor/github.com/google/cel-go/ext/BUILD.bazel index b764fa1..ef4f4ec 100644 --- a/vendor/github.com/google/cel-go/ext/BUILD.bazel +++ b/vendor/github.com/google/cel-go/ext/BUILD.bazel @@ -10,12 +10,15 @@ go_library( "bindings.go", "comprehensions.go", "encoders.go", + "extension_option_factory.go", "formatting.go", + "formatting_v2.go", "guards.go", "lists.go", "math.go", "native.go", "protos.go", + "regex.go", "sets.go", "strings.go", ], @@ -24,10 +27,12 @@ go_library( deps = [ "//cel:go_default_library", "//checker:go_default_library", + "//common:go_default_library", "//common/ast:go_default_library", "//common/decls:go_default_library", - "//common/overloads:go_default_library", + "//common/env:go_default_library", "//common/operators:go_default_library", + "//common/overloads:go_default_library", "//common/types:go_default_library", "//common/types/pb:go_default_library", "//common/types/ref:go_default_library", @@ -48,11 +53,15 @@ go_test( srcs = [ "bindings_test.go", "comprehensions_test.go", - "encoders_test.go", + "encoders_test.go", + "extension_option_factory_test.go", + "formatting_test.go", + "formatting_v2_test.go", "lists_test.go", "math_test.go", "native_test.go", "protos_test.go", + "regex_test.go", "sets_test.go", "strings_test.go", ], @@ -62,14 +71,16 @@ go_test( deps = [ "//cel:go_default_library", "//checker:go_default_library", + "//common:go_default_library", + "//common/env:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", "//common/types/traits:go_default_library", "//test:go_default_library", "//test/proto2pb:go_default_library", "//test/proto3pb:go_default_library", + "@org_golang_google_protobuf//encoding/protojson:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//types/known/wrapperspb:go_default_library", - "@org_golang_google_protobuf//encoding/protojson:go_default_library", ], ) diff --git a/vendor/github.com/google/cel-go/ext/README.md b/vendor/github.com/google/cel-go/ext/README.md index 4620204..41ae6a3 100644 --- a/vendor/github.com/google/cel-go/ext/README.md +++ b/vendor/github.com/google/cel-go/ext/README.md @@ -356,6 +356,23 @@ Examples: math.isFinite(0.0/0.0) // returns false math.isFinite(1.2) // returns true +### Math.Sqrt + +Introduced at version: 2 + +Returns the square root of the given input as double +Throws error for negative or non-numeric inputs + + math.sqrt(<double>) -> <double> + math.sqrt(<int>) -> <double> + math.sqrt(<uint>) -> <double> + +Examples: + + math.sqrt(81) // returns 9.0 + math.sqrt(985.25) // returns 31.388692231439016 + math.sqrt(-15) // returns NaN + ## Protos Protos configure extended macros and functions for proto manipulation. @@ -395,7 +412,7 @@ zero-based. ### Distinct -**Introduced in version 2** +**Introduced in version 2 (cost support in version 3)** Returns the distinct elements of a list. @@ -409,7 +426,7 @@ Examples: ### Flatten -**Introduced in version 1** +**Introduced in version 1 (cost support in version 3)** Flattens a list recursively. If an optional depth is provided, the list is flattened to a the specificied level. @@ -428,7 +445,7 @@ Examples: ### Range -**Introduced in version 2** +**Introduced in version 2 (cost support in version 3)** Returns a list of integers from 0 to n-1. @@ -441,7 +458,7 @@ Examples: ### Reverse -**Introduced in version 2** +**Introduced in version 2 (cost support in version 3)** Returns the elements of a list in reverse order. @@ -454,6 +471,7 @@ Examples: ### Slice +**Introduced in version 0 (cost support in version 3)** Returns a new sub-list using the indexes provided. @@ -466,7 +484,7 @@ Examples: ### Sort -**Introduced in version 2** +**Introduced in version 2 (cost support in version 3)** Sorts a list with comparable elements. If the element type is not comparable or the element types are not the same, the function will produce an error. @@ -483,7 +501,7 @@ Examples: ### SortBy -**Introduced in version 2** +**Introduced in version 2 (cost support in version 3)** Sorts a list by a key value, i.e., the order is determined by the result of an expression applied to each element of the list. diff --git a/vendor/github.com/google/cel-go/ext/extension_option_factory.go b/vendor/github.com/google/cel-go/ext/extension_option_factory.go new file mode 100644 index 0000000..cebf0d7 --- /dev/null +++ b/vendor/github.com/google/cel-go/ext/extension_option_factory.go @@ -0,0 +1,75 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/env" +) + +// ExtensionOptionFactory converts an ExtensionConfig value to a CEL environment option. +func ExtensionOptionFactory(configElement any) (cel.EnvOption, bool) { + ext, isExtension := configElement.(*env.Extension) + if !isExtension { + return nil, false + } + fac, found := extFactories[ext.Name] + if !found { + return nil, false + } + // If the version is 'latest', set the version value to the max uint. + ver, err := ext.VersionNumber() + if err != nil { + return func(*cel.Env) (*cel.Env, error) { + return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version) + }, true + } + return fac(ver), true +} + +// extensionFactory accepts a version and produces a CEL environment associated with the versioned extension. +type extensionFactory func(uint32) cel.EnvOption + +var extFactories = map[string]extensionFactory{ + "bindings": func(version uint32) cel.EnvOption { + return Bindings(BindingsVersion(version)) + }, + "encoders": func(version uint32) cel.EnvOption { + return Encoders(EncodersVersion(version)) + }, + "lists": func(version uint32) cel.EnvOption { + return Lists(ListsVersion(version)) + }, + "math": func(version uint32) cel.EnvOption { + return Math(MathVersion(version)) + }, + "protos": func(version uint32) cel.EnvOption { + return Protos(ProtosVersion(version)) + }, + "sets": func(version uint32) cel.EnvOption { + return Sets(SetsVersion(version)) + }, + "strings": func(version uint32) cel.EnvOption { + return Strings(StringsVersion(version)) + }, + "two-var-comprehensions": func(version uint32) cel.EnvOption { + return TwoVarComprehensions(TwoVarComprehensionsVersion(version)) + }, + "regex": func(version uint32) cel.EnvOption { + return Regex(RegexVersion(version)) + }, +} diff --git a/vendor/github.com/google/cel-go/ext/formatting.go b/vendor/github.com/google/cel-go/ext/formatting.go index aa334cc..111184b 100644 --- a/vendor/github.com/google/cel-go/ext/formatting.go +++ b/vendor/github.com/google/cel-go/ext/formatting.go @@ -268,14 +268,17 @@ func makeMatcher(locale string) (language.Matcher, error) { type stringFormatter struct{} +// String implements formatStringInterpolator.String. func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) { return FormatString(arg, locale) } +// Decimal implements formatStringInterpolator.Decimal. func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) { return formatDecimal(arg, locale) } +// Fixed implements formatStringInterpolator.Fixed. func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) { if precision == nil { precision = new(int) @@ -307,6 +310,7 @@ func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, e } } +// Scientific implements formatStringInterpolator.Scientific. func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) { if precision == nil { precision = new(int) @@ -337,6 +341,7 @@ func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (stri } } +// Binary implements formatStringInterpolator.Binary. func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) { switch arg.Type() { case types.IntType: @@ -358,6 +363,7 @@ func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) { } } +// Hex implements formatStringInterpolator.Hex. func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) { return func(arg ref.Val, locale string) (string, error) { fmtStr := "%x" @@ -388,6 +394,7 @@ func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, erro } } +// Octal implements formatStringInterpolator.Octal. func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) { switch arg.Type() { case types.IntType: @@ -504,6 +511,7 @@ type stringFormatChecker struct { ast *ast.AST } +// String implements formatStringInterpolator.String. func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) { formatArg := c.args[c.currArgIndex] valid, badID := c.verifyString(formatArg) @@ -513,6 +521,7 @@ func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) return "", nil } +// Decimal implements formatStringInterpolator.Decimal. func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() valid := c.verifyTypeOneOf(id, types.IntType, types.UintType) @@ -522,6 +531,7 @@ func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error return "", nil } +// Fixed implements formatStringInterpolator.Fixed. func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (string, error) { return func(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() @@ -534,6 +544,7 @@ func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (strin } } +// Scientific implements formatStringInterpolator.Scientific. func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) (string, error) { return func(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() @@ -545,6 +556,7 @@ func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) ( } } +// Binary implements formatStringInterpolator.Binary. func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.BoolType) @@ -554,6 +566,7 @@ func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error) return "", nil } +// Hex implements formatStringInterpolator.Hex. func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string, error) { return func(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() @@ -565,6 +578,7 @@ func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string, } } +// Octal implements formatStringInterpolator.Octal. func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error) { id := c.args[c.currArgIndex].ID() valid := c.verifyTypeOneOf(id, types.IntType, types.UintType) @@ -574,6 +588,7 @@ func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error) return "", nil } +// Arg implements formatListArgs.Arg. func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) { c.argsRequested++ c.currArgIndex = index @@ -582,6 +597,7 @@ func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) { return types.Int(0), nil } +// Size implements formatListArgs.Size. func (c *stringFormatChecker) Size() int64 { return int64(len(c.args)) } @@ -686,10 +702,12 @@ func newFormatError(id int64, msg string, args ...any) error { } } +// Error implements error. func (e formatError) Error() string { return e.msg } +// Is implements errors.Is. func (e formatError) Is(target error) bool { return e.msg == target.Error() } @@ -699,6 +717,7 @@ type stringArgList struct { args traits.Lister } +// Arg implements formatListArgs.Arg. func (c *stringArgList) Arg(index int64) (ref.Val, error) { if index >= c.args.Size().Value().(int64) { return nil, fmt.Errorf("index %d out of range", index) @@ -706,6 +725,7 @@ func (c *stringArgList) Arg(index int64) (ref.Val, error) { return c.args.Get(types.Int(index)), nil } +// Size implements formatListArgs.Size. func (c *stringArgList) Size() int64 { return c.args.Size().Value().(int64) } @@ -887,14 +907,17 @@ func newParseFormatError(msg string, wrapped error) error { return parseFormatError{msg: msg, wrapped: wrapped} } +// Error implements error. func (e parseFormatError) Error() string { return fmt.Sprintf("%s: %s", e.msg, e.wrapped.Error()) } +// Is implements errors.Is. func (e parseFormatError) Is(target error) bool { return e.Error() == target.Error() } +// Is implements errors.Unwrap. func (e parseFormatError) Unwrap() error { return e.wrapped } diff --git a/vendor/github.com/google/cel-go/ext/formatting_v2.go b/vendor/github.com/google/cel-go/ext/formatting_v2.go new file mode 100644 index 0000000..ca8efbc --- /dev/null +++ b/vendor/github.com/google/cel-go/ext/formatting_v2.go @@ -0,0 +1,788 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "errors" + "fmt" + "math" + "sort" + "strconv" + "strings" + "time" + "unicode" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +type clauseImplV2 func(ref.Val) (string, error) + +type appendingFormatterV2 struct { + buf []byte +} + +type formattedMapEntryV2 struct { + key string + val string +} + +func (af *appendingFormatterV2) format(arg ref.Val) error { + switch arg.Type() { + case types.BoolType: + argBool, ok := arg.Value().(bool) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BoolType) + } + af.buf = strconv.AppendBool(af.buf, argBool) + return nil + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + af.buf = strconv.AppendInt(af.buf, argInt, 10) + return nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + af.buf = strconv.AppendUint(af.buf, argUint, 10) + return nil + case types.DoubleType: + argDbl, ok := arg.Value().(float64) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) + } + if math.IsNaN(argDbl) { + af.buf = append(af.buf, "NaN"...) + return nil + } + if math.IsInf(argDbl, -1) { + af.buf = append(af.buf, "-Infinity"...) + return nil + } + if math.IsInf(argDbl, 1) { + af.buf = append(af.buf, "Infinity"...) + return nil + } + af.buf = strconv.AppendFloat(af.buf, argDbl, 'f', -1, 64) + return nil + case types.BytesType: + argBytes, ok := arg.Value().([]byte) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BytesType) + } + af.buf = append(af.buf, argBytes...) + return nil + case types.StringType: + argStr, ok := arg.Value().(string) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.StringType) + } + af.buf = append(af.buf, argStr...) + return nil + case types.DurationType: + argDur, ok := arg.Value().(time.Duration) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DurationType) + } + af.buf = strconv.AppendFloat(af.buf, argDur.Seconds(), 'f', -1, 64) + af.buf = append(af.buf, "s"...) + return nil + case types.TimestampType: + argTime, ok := arg.Value().(time.Time) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.TimestampType) + } + af.buf = argTime.UTC().AppendFormat(af.buf, time.RFC3339Nano) + return nil + case types.NullType: + af.buf = append(af.buf, "null"...) + return nil + case types.TypeType: + argType, ok := arg.Value().(string) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.TypeType) + } + af.buf = append(af.buf, argType...) + return nil + case types.ListType: + argList, ok := arg.(traits.Lister) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.ListType) + } + argIter := argList.Iterator() + af.buf = append(af.buf, "["...) + if argIter.HasNext() == types.True { + if err := af.format(argIter.Next()); err != nil { + return err + } + for argIter.HasNext() == types.True { + af.buf = append(af.buf, ", "...) + if err := af.format(argIter.Next()); err != nil { + return err + } + } + } + af.buf = append(af.buf, "]"...) + return nil + case types.MapType: + argMap, ok := arg.(traits.Mapper) + if !ok { + return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.MapType) + } + argIter := argMap.Iterator() + ents := []formattedMapEntryV2{} + for argIter.HasNext() == types.True { + key := argIter.Next() + val, ok := argMap.Find(key) + if !ok { + return fmt.Errorf("key missing from map: '%s'", key) + } + keyStr, err := formatStringV2(key) + if err != nil { + return err + } + valStr, err := formatStringV2(val) + if err != nil { + return err + } + ents = append(ents, formattedMapEntryV2{keyStr, valStr}) + } + sort.SliceStable(ents, func(x, y int) bool { + return ents[x].key < ents[y].key + }) + af.buf = append(af.buf, "{"...) + for i, e := range ents { + if i > 0 { + af.buf = append(af.buf, ", "...) + } + af.buf = append(af.buf, e.key...) + af.buf = append(af.buf, ": "...) + af.buf = append(af.buf, e.val...) + } + af.buf = append(af.buf, "}"...) + return nil + default: + return stringFormatErrorV2(runtimeID, arg.Type().TypeName()) + } +} + +func formatStringV2(arg ref.Val) (string, error) { + var fmter appendingFormatterV2 + if err := fmter.format(arg); err != nil { + return "", err + } + return string(fmter.buf), nil +} + +type stringFormatterV2 struct{} + +// String implements formatStringInterpolatorV2.String. +func (c *stringFormatterV2) String(arg ref.Val) (string, error) { + return formatStringV2(arg) +} + +// Decimal implements formatStringInterpolatorV2.Decimal. +func (c *stringFormatterV2) Decimal(arg ref.Val) (string, error) { + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return strconv.FormatInt(argInt, 10), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return strconv.FormatUint(argUint, 10), nil + case types.DoubleType: + argDbl, ok := arg.Value().(float64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) + } + if math.IsNaN(argDbl) { + return "NaN", nil + } + if math.IsInf(argDbl, -1) { + return "-Infinity", nil + } + if math.IsInf(argDbl, 1) { + return "Infinity", nil + } + return strconv.FormatFloat(argDbl, 'f', -1, 64), nil + default: + return "", decimalFormatErrorV2(runtimeID, arg.Type().TypeName()) + } +} + +// Fixed implements formatStringInterpolatorV2.Fixed. +func (c *stringFormatterV2) Fixed(precision int) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + fmtStr := fmt.Sprintf("%%.%df", precision) + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return fmt.Sprintf(fmtStr, argInt), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return fmt.Sprintf(fmtStr, argUint), nil + case types.DoubleType: + argDbl, ok := arg.Value().(float64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) + } + if math.IsNaN(argDbl) { + return "NaN", nil + } + if math.IsInf(argDbl, -1) { + return "-Infinity", nil + } + if math.IsInf(argDbl, 1) { + return "Infinity", nil + } + return fmt.Sprintf(fmtStr, argDbl), nil + default: + return "", fixedPointFormatErrorV2(runtimeID, arg.Type().TypeName()) + } + } +} + +// Scientific implements formatStringInterpolatorV2.Scientific. +func (c *stringFormatterV2) Scientific(precision int) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + fmtStr := fmt.Sprintf("%%1.%de", precision) + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return fmt.Sprintf(fmtStr, argInt), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return fmt.Sprintf(fmtStr, argUint), nil + case types.DoubleType: + argDbl, ok := arg.Value().(float64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) + } + if math.IsNaN(argDbl) { + return "NaN", nil + } + if math.IsInf(argDbl, -1) { + return "-Infinity", nil + } + if math.IsInf(argDbl, 1) { + return "Infinity", nil + } + return fmt.Sprintf(fmtStr, argDbl), nil + default: + return "", scientificFormatErrorV2(runtimeID, arg.Type().TypeName()) + } + } +} + +// Binary implements formatStringInterpolatorV2.Binary. +func (c *stringFormatterV2) Binary(arg ref.Val) (string, error) { + switch arg.Type() { + case types.BoolType: + argBool, ok := arg.Value().(bool) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BoolType) + } + if argBool { + return "1", nil + } + return "0", nil + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return strconv.FormatInt(argInt, 2), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return strconv.FormatUint(argUint, 2), nil + default: + return "", binaryFormatErrorV2(runtimeID, arg.Type().TypeName()) + } +} + +// Hex implements formatStringInterpolatorV2.Hex. +func (c *stringFormatterV2) Hex(useUpper bool) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + var fmtStr string + if useUpper { + fmtStr = "%X" + } else { + fmtStr = "%x" + } + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return fmt.Sprintf(fmtStr, argInt), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return fmt.Sprintf(fmtStr, argUint), nil + case types.StringType: + argStr, ok := arg.Value().(string) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.StringType) + } + return fmt.Sprintf(fmtStr, argStr), nil + case types.BytesType: + argBytes, ok := arg.Value().([]byte) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BytesType) + } + return fmt.Sprintf(fmtStr, argBytes), nil + default: + return "", hexFormatErrorV2(runtimeID, arg.Type().TypeName()) + } + } +} + +// Octal implements formatStringInterpolatorV2.Octal. +func (c *stringFormatterV2) Octal(arg ref.Val) (string, error) { + switch arg.Type() { + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) + } + return strconv.FormatInt(argInt, 8), nil + case types.UintType: + argUint, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) + } + return strconv.FormatUint(argUint, 8), nil + default: + return "", octalFormatErrorV2(runtimeID, arg.Type().TypeName()) + } +} + +// stringFormatValidatorV2 implements the cel.ASTValidator interface allowing for static validation +// of string.format calls. +type stringFormatValidatorV2 struct{} + +// Name returns the name of the validator. +func (stringFormatValidatorV2) Name() string { + return "cel.validator.string_format" +} + +// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip +// during homogeneous aggregate literal type-checks. +func (stringFormatValidatorV2) Configure(config cel.MutableValidatorConfig) error { + functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string) + functions = append(functions, "format") + return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions) +} + +// Validate parses all literal format strings and type checks the format clause against the argument +// at the corresponding ordinal within the list literal argument to the function, if one is specified. +func (stringFormatValidatorV2) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) { + root := ast.NavigateAST(a) + formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a)) + for _, e := range formatCallExprs { + call := e.AsCall() + formatStr := call.Target().AsLiteral().Value().(string) + args := call.Args()[0].AsList().Elements() + formatCheck := &stringFormatCheckerV2{ + args: args, + ast: a, + } + // use a placeholder locale, since locale doesn't affect syntax + _, err := parseFormatStringV2(formatStr, formatCheck, formatCheck) + if err != nil { + iss.ReportErrorAtID(getErrorExprID(e.ID(), err), "%v", err) + continue + } + seenArgs := formatCheck.argsRequested + if len(args) > seenArgs { + iss.ReportErrorAtID(e.ID(), + "too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(args)) + } + } +} + +// stringFormatCheckerV2 implements the formatStringInterpolater interface +type stringFormatCheckerV2 struct { + args []ast.Expr + argsRequested int + currArgIndex int64 + ast *ast.AST +} + +// String implements formatStringInterpolatorV2.String. +func (c *stringFormatCheckerV2) String(arg ref.Val) (string, error) { + formatArg := c.args[c.currArgIndex] + valid, badID := c.verifyString(formatArg) + if !valid { + return "", stringFormatErrorV2(badID, c.typeOf(badID).TypeName()) + } + return "", nil +} + +// Decimal implements formatStringInterpolatorV2.Decimal. +func (c *stringFormatCheckerV2) Decimal(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType) + if !valid { + return "", decimalFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil +} + +// Fixed implements formatStringInterpolatorV2.Fixed. +func (c *stringFormatCheckerV2) Fixed(precision int) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType) + if !valid { + return "", fixedPointFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil + } +} + +// Scientific implements formatStringInterpolatorV2.Scientific. +func (c *stringFormatCheckerV2) Scientific(precision int) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType) + if !valid { + return "", scientificFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil + } +} + +// Binary implements formatStringInterpolatorV2.Binary. +func (c *stringFormatCheckerV2) Binary(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.BoolType, types.IntType, types.UintType) + if !valid { + return "", binaryFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil +} + +// Hex implements formatStringInterpolatorV2.Hex. +func (c *stringFormatCheckerV2) Hex(useUpper bool) func(ref.Val) (string, error) { + return func(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.StringType, types.BytesType) + if !valid { + return "", hexFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil + } +} + +// Octal implements formatStringInterpolatorV2.Octal. +func (c *stringFormatCheckerV2) Octal(arg ref.Val) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType) + if !valid { + return "", octalFormatErrorV2(id, c.typeOf(id).TypeName()) + } + return "", nil +} + +// Arg implements formatListArgs.Arg. +func (c *stringFormatCheckerV2) Arg(index int64) (ref.Val, error) { + c.argsRequested++ + c.currArgIndex = index + // return a dummy value - this is immediately passed to back to us + // through one of the FormatCallback functions, so anything will do + return types.Int(0), nil +} + +// Size implements formatListArgs.Size. +func (c *stringFormatCheckerV2) Size() int64 { + return int64(len(c.args)) +} + +func (c *stringFormatCheckerV2) typeOf(id int64) *cel.Type { + return c.ast.GetType(id) +} + +func (c *stringFormatCheckerV2) verifyTypeOneOf(id int64, validTypes ...*cel.Type) bool { + t := c.typeOf(id) + if t == cel.DynType { + return true + } + for _, vt := range validTypes { + // Only check runtime type compatibility without delving deeper into parameterized types + if t.Kind() == vt.Kind() { + return true + } + } + return false +} + +func (c *stringFormatCheckerV2) verifyString(sub ast.Expr) (bool, int64) { + paramA := cel.TypeParamType("A") + paramB := cel.TypeParamType("B") + subVerified := c.verifyTypeOneOf(sub.ID(), + cel.ListType(paramA), cel.MapType(paramA, paramB), + cel.IntType, cel.UintType, cel.DoubleType, cel.BoolType, cel.StringType, + cel.TimestampType, cel.BytesType, cel.DurationType, cel.TypeType, cel.NullType) + if !subVerified { + return false, sub.ID() + } + switch sub.Kind() { + case ast.ListKind: + for _, e := range sub.AsList().Elements() { + // recursively verify if we're dealing with a list/map + verified, id := c.verifyString(e) + if !verified { + return false, id + } + } + return true, sub.ID() + case ast.MapKind: + for _, e := range sub.AsMap().Entries() { + // recursively verify if we're dealing with a list/map + entry := e.AsMapEntry() + verified, id := c.verifyString(entry.Key()) + if !verified { + return false, id + } + verified, id = c.verifyString(entry.Value()) + if !verified { + return false, id + } + } + return true, sub.ID() + default: + return true, sub.ID() + } +} + +// helper routines for reporting common errors during string formatting static validation and +// runtime execution. + +func binaryFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "only ints, uints, and bools can be formatted as binary, was given %s", badType) +} + +func decimalFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "decimal clause can only be used on ints, uints, and doubles, was given %s", badType) +} + +func fixedPointFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "fixed-point clause can only be used on ints, uints, and doubles, was given %s", badType) +} + +func hexFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "only ints, uints, bytes, and strings can be formatted as hex, was given %s", badType) +} + +func octalFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "octal clause can only be used on ints and uints, was given %s", badType) +} + +func scientificFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "scientific clause can only be used on ints, uints, and doubles, was given %s", badType) +} + +func stringFormatErrorV2(id int64, badType string) error { + return newFormatError(id, "string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", badType) +} + +// formatStringInterpolatorV2 is an interface that allows user-defined behavior +// for formatting clause implementations, as well as argument retrieval. +// Each function is expected to support the appropriate types as laid out in +// the string.format documentation, and to return an error if given an inappropriate type. +type formatStringInterpolatorV2 interface { + // String takes a ref.Val and a string representing the current locale identifier + // and returns the Val formatted as a string, or an error if one occurred. + String(ref.Val) (string, error) + + // Decimal takes a ref.Val and a string representing the current locale identifier + // and returns the Val formatted as a decimal integer, or an error if one occurred. + Decimal(ref.Val) (string, error) + + // Fixed takes an int pointer representing precision (or nil if none was given) and + // returns a function operating in a similar manner to String and Decimal, taking a + // ref.Val and locale and returning the appropriate string. A closure is returned + // so precision can be set without needing an additional function call/configuration. + Fixed(int) func(ref.Val) (string, error) + + // Scientific functions identically to Fixed, except the string returned from the closure + // is expected to be in scientific notation. + Scientific(int) func(ref.Val) (string, error) + + // Binary takes a ref.Val and a string representing the current locale identifier + // and returns the Val formatted as a binary integer, or an error if one occurred. + Binary(ref.Val) (string, error) + + // Hex takes a boolean that, if true, indicates the hex string output by the returned + // closure should use uppercase letters for A-F. + Hex(bool) func(ref.Val) (string, error) + + // Octal takes a ref.Val and a string representing the current locale identifier and + // returns the Val formatted in octal, or an error if one occurred. + Octal(ref.Val) (string, error) +} + +// parseFormatString formats a string according to the string.format syntax, taking the clause implementations +// from the provided FormatCallback and the args from the given FormatList. +func parseFormatStringV2(formatStr string, callback formatStringInterpolatorV2, list formatListArgs) (string, error) { + i := 0 + argIndex := 0 + var builtStr strings.Builder + for i < len(formatStr) { + if formatStr[i] == '%' { + if i+1 < len(formatStr) && formatStr[i+1] == '%' { + err := builtStr.WriteByte('%') + if err != nil { + return "", fmt.Errorf("error writing format string: %w", err) + } + i += 2 + continue + } else { + argAny, err := list.Arg(int64(argIndex)) + if err != nil { + return "", err + } + if i+1 >= len(formatStr) { + return "", errors.New("unexpected end of string") + } + if int64(argIndex) >= list.Size() { + return "", fmt.Errorf("index %d out of range", argIndex) + } + numRead, val, refErr := parseAndFormatClauseV2(formatStr[i:], argAny, callback, list) + if refErr != nil { + return "", refErr + } + _, err = builtStr.WriteString(val) + if err != nil { + return "", fmt.Errorf("error writing format string: %w", err) + } + i += numRead + argIndex++ + } + } else { + err := builtStr.WriteByte(formatStr[i]) + if err != nil { + return "", fmt.Errorf("error writing format string: %w", err) + } + i++ + } + } + return builtStr.String(), nil +} + +// parseAndFormatClause parses the format clause at the start of the given string with val, and returns +// how many characters were consumed and the substituted string form of val, or an error if one occurred. +func parseAndFormatClauseV2(formatStr string, val ref.Val, callback formatStringInterpolatorV2, list formatListArgs) (int, string, error) { + i := 1 + read, formatter, err := parseFormattingClauseV2(formatStr[i:], callback) + i += read + if err != nil { + return -1, "", newParseFormatError("could not parse formatting clause", err) + } + + valStr, err := formatter(val) + if err != nil { + return -1, "", newParseFormatError("error during formatting", err) + } + return i, valStr, nil +} + +func parseFormattingClauseV2(formatStr string, callback formatStringInterpolatorV2) (int, clauseImplV2, error) { + i := 0 + read, precision, err := parsePrecisionV2(formatStr[i:]) + i += read + if err != nil { + return -1, nil, fmt.Errorf("error while parsing precision: %w", err) + } + r := rune(formatStr[i]) + i++ + switch r { + case 's': + return i, callback.String, nil + case 'd': + return i, callback.Decimal, nil + case 'f': + return i, callback.Fixed(precision), nil + case 'e': + return i, callback.Scientific(precision), nil + case 'b': + return i, callback.Binary, nil + case 'x', 'X': + return i, callback.Hex(unicode.IsUpper(r)), nil + case 'o': + return i, callback.Octal, nil + default: + return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r) + } +} + +func parsePrecisionV2(formatStr string) (int, int, error) { + i := 0 + if formatStr[i] != '.' { + return i, defaultPrecision, nil + } + i++ + var buffer strings.Builder + for { + if i >= len(formatStr) { + return -1, -1, errors.New("could not find end of precision specifier") + } + if !isASCIIDigit(rune(formatStr[i])) { + break + } + buffer.WriteByte(formatStr[i]) + i++ + } + precision, err := strconv.Atoi(buffer.String()) + if err != nil { + return -1, -1, fmt.Errorf("error while converting precision to integer: %w", err) + } + if precision < 0 { + return -1, -1, fmt.Errorf("negative precision: %d", precision) + } + return i, precision, nil +} diff --git a/vendor/github.com/google/cel-go/ext/lists.go b/vendor/github.com/google/cel-go/ext/lists.go index 9a3cce3..b27ddf2 100644 --- a/vendor/github.com/google/cel-go/ext/lists.go +++ b/vendor/github.com/google/cel-go/ext/lists.go @@ -20,11 +20,14 @@ import ( "sort" "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/interpreter" "github.com/google/cel-go/parser" ) @@ -44,7 +47,7 @@ var comparableTypes = []*cel.Type{ // // # Distinct // -// Introduced in version: 2 +// Introduced in version: 2 (cost support in version 3) // // Returns the distinct elements of a list. // @@ -58,7 +61,7 @@ var comparableTypes = []*cel.Type{ // // # Range // -// Introduced in version: 2 +// Introduced in version: 2 (cost support in version 3) // // Returns a list of integers from 0 to n-1. // @@ -70,7 +73,7 @@ var comparableTypes = []*cel.Type{ // // # Reverse // -// Introduced in version: 2 +// Introduced in version: 2 (cost support in version 3) // // Returns the elements of a list in reverse order. // @@ -82,6 +85,8 @@ var comparableTypes = []*cel.Type{ // // # Slice // +// Introduced in version: 0 (cost support in version 3) +// // Returns a new sub-list using the indexes provided. // // <list>.slice(<int>, <int>) -> <list> @@ -93,12 +98,14 @@ var comparableTypes = []*cel.Type{ // // # Flatten // +// Introduced in version: 1 (cost support in version 3) +// // Flattens a list recursively. -// If an optional depth is provided, the list is flattened to a the specificied level. +// If an optional depth is provided, the list is flattened to a the specified level. // A negative depth value will result in an error. // -// <list>.flatten(<list>) -> <list> -// <list>.flatten(<list>, <int>) -> <list> +// <list>.flatten() -> <list> +// <list>.flatten(<int>) -> <list> // // Examples: // @@ -110,7 +117,7 @@ var comparableTypes = []*cel.Type{ // // # Sort // -// Introduced in version: 2 +// Introduced in version: 2 (cost support in version 3) // // Sorts a list with comparable elements. If the element type is not comparable // or the element types are not the same, the function will produce an error. @@ -127,6 +134,8 @@ var comparableTypes = []*cel.Type{ // // # SortBy // +// Introduced in version: 2 (cost support in version 3) +// // Sorts a list by a key value, i.e., the order is determined by the result of // an expression applied to each element of the list. // The output of the key expression must be a comparable type, otherwise the @@ -303,9 +312,8 @@ func (lib listsLib) CompileOptions() []cel.EnvOption { opts = append(opts, cel.Function("lists.range", cel.Overload("lists_range", []*cel.Type{cel.IntType}, cel.ListType(cel.IntType), - cel.FunctionBinding(func(args ...ref.Val) ref.Val { - n := args[0].(types.Int) - result, err := genRange(n) + cel.UnaryBinding(func(n ref.Val) ref.Val { + result, err := genRange(n.(types.Int)) if err != nil { return types.WrapErr(err) } @@ -316,9 +324,8 @@ func (lib listsLib) CompileOptions() []cel.EnvOption { opts = append(opts, cel.Function("reverse", cel.MemberOverload("list_reverse", []*cel.Type{listType}, listType, - cel.FunctionBinding(func(args ...ref.Val) ref.Val { - list := args[0].(traits.Lister) - result, err := reverseList(list) + cel.UnaryBinding(func(list ref.Val) ref.Val { + result, err := reverseList(list.(traits.Lister)) if err != nil { return types.WrapErr(err) } @@ -339,13 +346,61 @@ func (lib listsLib) CompileOptions() []cel.EnvOption { ), )) } + if lib.version >= 3 { + estimators := []checker.CostOption{ + checker.OverloadCostEstimate("list_slice", estimateListSlice), + checker.OverloadCostEstimate("list_flatten", estimateListFlatten), + checker.OverloadCostEstimate("list_flatten_int", estimateListFlatten), + checker.OverloadCostEstimate("lists_range", estimateListsRange), + checker.OverloadCostEstimate("list_reverse", estimateListReverse), + checker.OverloadCostEstimate("list_distinct", estimateListDistinct), + } + for _, t := range comparableTypes { + estimators = append(estimators, + checker.OverloadCostEstimate( + fmt.Sprintf("list_%s_sort", t.TypeName()), + estimateListSort(t), + ), + checker.OverloadCostEstimate( + fmt.Sprintf("list_%s_sortByAssociatedKeys", t.TypeName()), + estimateListSortBy(t), + ), + ) + } + opts = append(opts, cel.CostEstimatorOptions(estimators...)) + } return opts } // ProgramOptions implements the Library interface method. -func (listsLib) ProgramOptions() []cel.ProgramOption { - return []cel.ProgramOption{} +func (lib *listsLib) ProgramOptions() []cel.ProgramOption { + var opts []cel.ProgramOption + if lib.version >= 3 { + // TODO: Add cost trackers for list operations + trackers := []interpreter.CostTrackerOption{ + interpreter.OverloadCostTracker("list_slice", trackListOutputSize), + interpreter.OverloadCostTracker("list_flatten", trackListFlatten), + interpreter.OverloadCostTracker("list_flatten_int", trackListFlatten), + interpreter.OverloadCostTracker("lists_range", trackListOutputSize), + interpreter.OverloadCostTracker("list_reverse", trackListOutputSize), + interpreter.OverloadCostTracker("list_distinct", trackListDistinct), + } + for _, t := range comparableTypes { + trackers = append(trackers, + interpreter.OverloadCostTracker( + fmt.Sprintf("list_%s_sort", t.TypeName()), + trackListSort, + ), + interpreter.OverloadCostTracker( + fmt.Sprintf("list_%s_sortByAssociatedKeys", t.TypeName()), + trackListSortBy, + ), + ) + } + opts = append(opts, cel.CostTrackerOptions(trackers...)) + } + return opts } func genRange(n types.Int) (ref.Val, error) { @@ -450,20 +505,24 @@ func sortListByAssociatedKeys(list, keys traits.Lister) (ref.Val, error) { sortedIndices := make([]ref.Val, 0, listLength) for i := types.IntZero; i < listLength; i++ { - if keys.Get(i).Type() != elem.Type() { - return nil, fmt.Errorf("list elements must have the same type") - } sortedIndices = append(sortedIndices, i) } + var err error sort.Slice(sortedIndices, func(i, j int) bool { iKey := keys.Get(sortedIndices[i]) jKey := keys.Get(sortedIndices[j]) + if iKey.Type() != elem.Type() || jKey.Type() != elem.Type() { + err = fmt.Errorf("list elements must have the same type") + return false + } return iKey.(traits.Comparer).Compare(jKey) == types.IntNegOne }) + if err != nil { + return nil, err + } sorted := make([]ref.Val, 0, listLength) - for _, sortedIdx := range sortedIndices { sorted = append(sorted, list.Get(sortedIdx)) } @@ -550,3 +609,171 @@ func templatedOverloads(types []*cel.Type, template func(t *cel.Type) cel.Functi } return overloads } + +// estimateListSlice computes an O(n) slice operation with a cost factor of 1. +func estimateListSlice(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if target == nil || len(args) != 2 { + return nil + } + sz := estimateSize(estimator, *target) + start := nodeAsIntValue(args[0], 0) + end := nodeAsIntValue(args[1], sz.Max) + return estimateAllocatingListCall(1, checker.FixedSizeEstimate(end-start)) +} + +// estimateListsRange computes an O(n) range operation with a cost factor of 1. +func estimateListsRange(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if target != nil || len(args) != 1 { + return nil + } + return estimateAllocatingListCall(1, checker.FixedSizeEstimate(nodeAsIntValue(args[0], math.MaxUint))) +} + +// estimateListReverse computes an O(n) reverse operation with a cost factor of 1. +func estimateListReverse(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if target == nil || len(args) != 0 { + return nil + } + return estimateAllocatingListCall(1, estimateSize(estimator, *target)) +} + +// estimateListFlatten computes an O(n) flatten operation with a cost factor proportional to the flatten depth. +func estimateListFlatten(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if target == nil || len(args) > 1 { + return nil + } + depth := uint64(1) + if len(args) == 1 { + depth = nodeAsIntValue(args[0], math.MaxUint) + } + return estimateAllocatingListCall(float64(depth), estimateSize(estimator, *target)) +} + +// Compute an O(n^2) with a cost factor of 2, equivalent to sets.contains with a result list +// which can vary in size from 1 element to the original list size. +func estimateListDistinct(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if target == nil || len(args) != 0 { + return nil + } + sz := estimateSize(estimator, *target) + costFactor := 2.0 + return estimateAllocatingListCall(costFactor, sz.Multiply(sz)) +} + +// estimateListSort computes an O(n^2) sort operation with a cost factor of 2 for the equality +// operations against the elements in the list against themselves which occur during the sort computation. +func estimateListSort(t *types.Type) checker.FunctionEstimator { + return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if target == nil || len(args) != 0 { + return nil + } + return estimateListSortCost(estimator, *target, t) + } +} + +// estimateListSortBy computes an O(n^2) sort operation with a cost factor of 2 for the equality +// operations against the sort index list which occur during the sort computation. +func estimateListSortBy(u *types.Type) checker.FunctionEstimator { + return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if target == nil || len(args) != 1 { + return nil + } + // Estimate the size of the list used as the sort index + return estimateListSortCost(estimator, args[0], u) + } +} + +// estimateListSortCost estimates an O(n^2) sort operation with a cost factor of 2 for the equality +// operations which occur during the sort computation. +func estimateListSortCost(estimator checker.CostEstimator, node checker.AstNode, elemType *types.Type) *checker.CallEstimate { + sz := estimateSize(estimator, node) + costFactor := 2.0 + switch elemType { + case types.StringType, types.BytesType: + costFactor += common.StringTraversalCostFactor + } + return estimateAllocatingListCall(costFactor, sz.Multiply(sz)) +} + +// estimateAllocatingListCall computes cost as a function of the size of the result list with a +// baseline cost for the call dispatch and the associated list allocation. +func estimateAllocatingListCall(costFactor float64, listSize checker.SizeEstimate) *checker.CallEstimate { + return estimateListCall(costFactor, listSize, true) +} + +// estimateListCall computes cost as a function of the size of the target list and whether the +// call allocates memory. +func estimateListCall(costFactor float64, listSize checker.SizeEstimate, allocates bool) *checker.CallEstimate { + cost := listSize.MultiplyByCostFactor(costFactor).Add(callCostEstimate) + if allocates { + cost = cost.Add(checker.FixedCostEstimate(common.ListCreateBaseCost)) + } + return &checker.CallEstimate{CostEstimate: cost, ResultSize: &listSize} +} + +// trackListOutputSize computes cost as a function of the size of the result list. +func trackListOutputSize(_ []ref.Val, result ref.Val) *uint64 { + return trackAllocatingListCall(1, actualSize(result)) +} + +// trackListFlatten computes cost as a function of the size of the result list and the depth of +// the flatten operation. +func trackListFlatten(args []ref.Val, _ ref.Val) *uint64 { + depth := 1.0 + if len(args) == 2 { + depth = float64(args[1].(types.Int)) + } + inputSize := actualSize(args[0]) + return trackAllocatingListCall(depth, inputSize) +} + +// trackListDistinct computes costs as a worst-case O(n^2) operation over the input list. +func trackListDistinct(args []ref.Val, _ ref.Val) *uint64 { + return trackListSelfCompare(args[0].(traits.Lister)) +} + +// trackListSort computes costs as a worst-case O(n^2) operation over the input list. +func trackListSort(args []ref.Val, result ref.Val) *uint64 { + return trackListSelfCompare(args[0].(traits.Lister)) +} + +// trackListSortBy computes costs as a worst-case O(n^2) operation over the sort index list. +func trackListSortBy(args []ref.Val, result ref.Val) *uint64 { + return trackListSelfCompare(args[1].(traits.Lister)) +} + +// trackListSelfCompare computes costs as a worst-case O(n^2) operation over the input list. +func trackListSelfCompare(l traits.Lister) *uint64 { + sz := actualSize(l) + costFactor := 2.0 + if sz == 0 { + return trackAllocatingListCall(costFactor, 0) + } + elem := l.Get(types.IntZero) + if elem.Type() == types.StringType || elem.Type() == types.BytesType { + costFactor += common.StringTraversalCostFactor + } + return trackAllocatingListCall(costFactor, sz*sz) +} + +// trackAllocatingListCall computes costs as a function of the size of the result list with a baseline cost +// for the call dispatch and the associated list allocation. +func trackAllocatingListCall(costFactor float64, size uint64) *uint64 { + cost := uint64(float64(size)*costFactor) + callCost + common.ListCreateBaseCost + return &cost +} + +func nodeAsIntValue(node checker.AstNode, defaultVal uint64) uint64 { + if node.Expr().Kind() != ast.LiteralKind { + return defaultVal + } + lit := node.Expr().AsLiteral() + if lit.Type() != types.IntType { + return defaultVal + } + val := lit.(types.Int) + if val < types.IntZero { + return 0 + } + return uint64(lit.(types.Int)) +} diff --git a/vendor/github.com/google/cel-go/ext/math.go b/vendor/github.com/google/cel-go/ext/math.go index 250246d..6df8e37 100644 --- a/vendor/github.com/google/cel-go/ext/math.go +++ b/vendor/github.com/google/cel-go/ext/math.go @@ -325,6 +325,23 @@ import ( // // math.isFinite(0.0/0.0) // returns false // math.isFinite(1.2) // returns true +// +// # Math.Sqrt +// +// Introduced at version: 2 +// +// Returns the square root of the given input as double +// Throws error for negative or non-numeric inputs +// +// math.sqrt(<double>) -> <double> +// math.sqrt(<int>) -> <double> +// math.sqrt(<uint>) -> <double> +// +// Examples: +// +// math.sqrt(81) // returns 9.0 +// math.sqrt(985.25) // returns 31.388692231439016 +// math.sqrt(-15) // returns NaN func Math(options ...MathOption) cel.EnvOption { m := &mathLib{version: math.MaxUint32} for _, o := range options { @@ -357,6 +374,9 @@ const ( absFunc = "math.abs" signFunc = "math.sign" + // SquareRoot function + sqrtFunc = "math.sqrt" + // Bitwise functions bitAndFunc = "math.bitAnd" bitOrFunc = "math.bitOr" @@ -548,6 +568,18 @@ func (lib *mathLib) CompileOptions() []cel.EnvOption { ), ) } + if lib.version >= 2 { + opts = append(opts, + cel.Function(sqrtFunc, + cel.Overload("math_sqrt_double", []*cel.Type{cel.DoubleType}, cel.DoubleType, + cel.UnaryBinding(sqrt)), + cel.Overload("math_sqrt_int", []*cel.Type{cel.IntType}, cel.DoubleType, + cel.UnaryBinding(sqrt)), + cel.Overload("math_sqrt_uint", []*cel.Type{cel.UintType}, cel.DoubleType, + cel.UnaryBinding(sqrt)), + ), + ) + } return opts } @@ -691,6 +723,21 @@ func sign(val ref.Val) ref.Val { } } + +func sqrt(val ref.Val) ref.Val { + switch v := val.(type) { + case types.Double: + return types.Double(math.Sqrt(float64(v))) + case types.Int: + return types.Double(math.Sqrt(float64(v))) + case types.Uint: + return types.Double(math.Sqrt(float64(v))) + default: + return types.NewErr("no such overload: sqrt") + } +} + + func bitAndPairInt(first, second ref.Val) ref.Val { l := first.(types.Int) r := second.(types.Int) diff --git a/vendor/github.com/google/cel-go/ext/native.go b/vendor/github.com/google/cel-go/ext/native.go index 1c33def..ceaa274 100644 --- a/vendor/github.com/google/cel-go/ext/native.go +++ b/vendor/github.com/google/cel-go/ext/native.go @@ -81,7 +81,7 @@ var ( // the time that it is invoked. // // There is also the possibility to rename the fields of native structs by setting the `cel` tag -// for fields you want to override. In order to enable this feature, pass in the `EnableStructTag` +// for fields you want to override. In order to enable this feature, pass in the `ParseStructTags(true)` // option. Here is an example to see it in action: // // ```go @@ -609,7 +609,8 @@ func newNativeTypes(fieldNameHandler NativeTypesFieldNameHandler, rawType reflec var iterateStructMembers func(reflect.Type) iterateStructMembers = func(t reflect.Type) { if k := t.Kind(); k == reflect.Pointer || k == reflect.Slice || k == reflect.Array || k == reflect.Map { - t = t.Elem() + iterateStructMembers(t.Elem()) + return } if t.Kind() != reflect.Struct { return diff --git a/vendor/github.com/google/cel-go/ext/regex.go b/vendor/github.com/google/cel-go/ext/regex.go new file mode 100644 index 0000000..1a66f65 --- /dev/null +++ b/vendor/github.com/google/cel-go/ext/regex.go @@ -0,0 +1,332 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "errors" + "fmt" + "math" + "regexp" + "strconv" + "strings" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +const ( + regexReplace = "regex.replace" + regexExtract = "regex.extract" + regexExtractAll = "regex.extractAll" +) + +// Regex returns a cel.EnvOption to configure extended functions for regular +// expression operations. +// +// Note: all functions use the 'regex' namespace. If you are +// currently using a variable named 'regex', the functions will likely work as +// intended, however there is some chance for collision. +// +// This library depends on the CEL optional type. Please ensure that the +// cel.OptionalTypes() is enabled when using regex extensions. +// +// # Replace +// +// The `regex.replace` function replaces all non-overlapping substring of a regex +// pattern in the target string with a replacement string. Optionally, you can +// limit the number of replacements by providing a count argument. When the count +// is a negative number, the function acts as replace all. Only numeric (\N) +// capture group references are supported in the replacement string, with +// validation for correctness. Backslashed-escaped digits (\1 to \9) within the +// replacement argument can be used to insert text matching the corresponding +// parenthesized group in the regexp pattern. An error will be thrown for invalid +// regex or replace string. +// +// regex.replace(target: string, pattern: string, replacement: string) -> string +// regex.replace(target: string, pattern: string, replacement: string, count: int) -> string +// +// Examples: +// +// regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi' +// regex.replace('banana', 'a', 'x', 0) == 'banana' +// regex.replace('banana', 'a', 'x', 1) == 'bxnana' +// regex.replace('banana', 'a', 'x', 2) == 'bxnxna' +// regex.replace('banana', 'a', 'x', -12) == 'bxnxnx' +// regex.replace('foo bar', '(fo)o (ba)r', r'\2 \1') == 'ba fo' +// regex.replace('test', '(.)', r'\2') \\ Runtime Error invalid replace string +// regex.replace('foo bar', '(', '$2 $1') \\ Runtime Error invalid regex string +// regex.replace('id=123', r'id=(?P<value>\d+)', r'value: \values') \\ Runtime Error invalid replace string +// +// # Extract +// +// The `regex.extract` function returns the first match of a regex pattern in a +// string. If no match is found, it returns an optional none value. An error will +// be thrown for invalid regex or for multiple capture groups. +// +// regex.extract(target: string, pattern: string) -> optional<string> +// +// Examples: +// +// regex.extract('hello world', 'hello(.*)') == optional.of(' world') +// regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A') +// regex.extract('HELLO', 'hello') == optional.empty() +// regex.extract('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group +// +// # Extract All +// +// The `regex.extractAll` function returns a list of all matches of a regex +// pattern in a target string. If no matches are found, it returns an empty list. An error will +// be thrown for invalid regex or for multiple capture groups. +// +// regex.extractAll(target: string, pattern: string) -> list<string> +// +// Examples: +// +// regex.extractAll('id:123, id:456', 'id:\\d+') == ['id:123', 'id:456'] +// regex.extractAll('id:123, id:456', 'assa') == [] +// regex.extractAll('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group +func Regex(options ...RegexOptions) cel.EnvOption { + s := ®exLib{ + version: math.MaxUint32, + } + for _, o := range options { + s = o(s) + } + return cel.Lib(s) +} + +// RegexOptions declares a functional operator for configuring regex extension. +type RegexOptions func(*regexLib) *regexLib + +// RegexVersion configures the version of the Regex library definitions to use. See [Regex] for supported values. +func RegexVersion(version uint32) RegexOptions { + return func(lib *regexLib) *regexLib { + lib.version = version + return lib + } +} + +type regexLib struct { + version uint32 +} + +// LibraryName implements that SingletonLibrary interface method. +func (r *regexLib) LibraryName() string { + return "cel.lib.ext.regex" +} + +// CompileOptions implements the cel.Library interface method. +func (r *regexLib) CompileOptions() []cel.EnvOption { + optionalTypesEnabled := func(env *cel.Env) (*cel.Env, error) { + if !env.HasLibrary("cel.lib.optional") { + return nil, errors.New("regex library requires the optional library") + } + return env, nil + } + opts := []cel.EnvOption{ + cel.Function(regexExtract, + cel.Overload("regex_extract_string_string", []*cel.Type{cel.StringType, cel.StringType}, cel.OptionalType(cel.StringType), + cel.BinaryBinding(extract))), + + cel.Function(regexExtractAll, + cel.Overload("regex_extractAll_string_string", []*cel.Type{cel.StringType, cel.StringType}, cel.ListType(cel.StringType), + cel.BinaryBinding(extractAll))), + + cel.Function(regexReplace, + cel.Overload("regex_replace_string_string_string", []*cel.Type{cel.StringType, cel.StringType, cel.StringType}, cel.StringType, + cel.FunctionBinding(regReplace)), + cel.Overload("regex_replace_string_string_string_int", []*cel.Type{cel.StringType, cel.StringType, cel.StringType, cel.IntType}, cel.StringType, + cel.FunctionBinding((regReplaceN))), + ), + cel.EnvOption(optionalTypesEnabled), + } + return opts +} + +// ProgramOptions implements the cel.Library interface method +func (r *regexLib) ProgramOptions() []cel.ProgramOption { + return []cel.ProgramOption{} +} + +func compileRegex(regexStr string) (*regexp.Regexp, error) { + re, err := regexp.Compile(regexStr) + if err != nil { + return nil, fmt.Errorf("given regex is invalid: %w", err) + } + return re, nil +} + +func regReplace(args ...ref.Val) ref.Val { + target := args[0].(types.String) + regexStr := args[1].(types.String) + replaceStr := args[2].(types.String) + + return regReplaceN(target, regexStr, replaceStr, types.Int(-1)) +} + +func regReplaceN(args ...ref.Val) ref.Val { + target := string(args[0].(types.String)) + regexStr := string(args[1].(types.String)) + replaceStr := string(args[2].(types.String)) + replaceCount := int64(args[3].(types.Int)) + + if replaceCount == 0 { + return types.String(target) + } + + if replaceCount > math.MaxInt32 { + return types.NewErr("integer overflow") + } + + // If replaceCount is negative, just do a replaceAll. + if replaceCount < 0 { + replaceCount = -1 + } + + re, err := regexp.Compile(regexStr) + if err != nil { + return types.WrapErr(err) + } + + var resultBuilder strings.Builder + var lastIndex int + counter := int64(0) + + matches := re.FindAllStringSubmatchIndex(target, -1) + + for _, match := range matches { + if replaceCount != -1 && counter >= replaceCount { + break + } + + processedReplacement, err := replaceStrValidator(target, re, match, replaceStr) + if err != nil { + return types.WrapErr(err) + } + + resultBuilder.WriteString(target[lastIndex:match[0]]) + resultBuilder.WriteString(processedReplacement) + lastIndex = match[1] + counter++ + } + + resultBuilder.WriteString(target[lastIndex:]) + return types.String(resultBuilder.String()) +} + +func replaceStrValidator(target string, re *regexp.Regexp, match []int, replacement string) (string, error) { + groupCount := re.NumSubexp() + var sb strings.Builder + runes := []rune(replacement) + + for i := 0; i < len(runes); i++ { + c := runes[i] + + if c != '\\' { + sb.WriteRune(c) + continue + } + + if i+1 >= len(runes) { + return "", fmt.Errorf("invalid replacement string: '%s' \\ not allowed at end", replacement) + } + + i++ + nextChar := runes[i] + + if nextChar == '\\' { + sb.WriteRune('\\') + continue + } + + groupNum, err := strconv.Atoi(string(nextChar)) + if err != nil { + return "", fmt.Errorf("invalid replacement string: '%s' \\ must be followed by a digit or \\", replacement) + } + + if groupNum > groupCount { + return "", fmt.Errorf("replacement string references group %d but regex has only %d group(s)", groupNum, groupCount) + } + + if match[2*groupNum] != -1 { + sb.WriteString(target[match[2*groupNum]:match[2*groupNum+1]]) + } + } + return sb.String(), nil +} + +func extract(target, regexStr ref.Val) ref.Val { + t := string(target.(types.String)) + r := string(regexStr.(types.String)) + re, err := compileRegex(r) + if err != nil { + return types.WrapErr(err) + } + + if len(re.SubexpNames())-1 > 1 { + return types.WrapErr(fmt.Errorf("regular expression has more than one capturing group: %q", r)) + } + + matches := re.FindStringSubmatch(t) + if len(matches) == 0 { + return types.OptionalNone + } + + // If there is a capturing group, return the first match; otherwise, return the whole match. + if len(matches) > 1 { + capturedGroup := matches[1] + // If optional group is empty, return OptionalNone. + if capturedGroup == "" { + return types.OptionalNone + } + return types.OptionalOf(types.String(capturedGroup)) + } + return types.OptionalOf(types.String(matches[0])) +} + +func extractAll(target, regexStr ref.Val) ref.Val { + t := string(target.(types.String)) + r := string(regexStr.(types.String)) + re, err := compileRegex(r) + if err != nil { + return types.WrapErr(err) + } + + groupCount := len(re.SubexpNames()) - 1 + if groupCount > 1 { + return types.WrapErr(fmt.Errorf("regular expression has more than one capturing group: %q", r)) + } + + matches := re.FindAllStringSubmatch(t, -1) + result := make([]string, 0, len(matches)) + if len(matches) == 0 { + return types.NewStringList(types.DefaultTypeAdapter, result) + } + + if groupCount != 1 { + for _, match := range matches { + result = append(result, match[0]) + } + return types.NewStringList(types.DefaultTypeAdapter, result) + } + + for _, match := range matches { + if match[1] != "" { + result = append(result, match[1]) + } + } + return types.NewStringList(types.DefaultTypeAdapter, result) +} diff --git a/vendor/github.com/google/cel-go/ext/sets.go b/vendor/github.com/google/cel-go/ext/sets.go index 9a9ef6e..ecac4bf 100644 --- a/vendor/github.com/google/cel-go/ext/sets.go +++ b/vendor/github.com/google/cel-go/ext/sets.go @@ -236,13 +236,13 @@ func setsEquivalent(listA, listB ref.Val) ref.Val { func estimateSetsCost(costFactor float64) checker.FunctionEstimator { return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { - if len(args) == 2 { - arg0Size := estimateSize(estimator, args[0]) - arg1Size := estimateSize(estimator, args[1]) - costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate) - return &checker.CallEstimate{CostEstimate: costEstimate} + if len(args) != 2 { + return nil } - return nil + arg0Size := estimateSize(estimator, args[0]) + arg1Size := estimateSize(estimator, args[1]) + costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate) + return &checker.CallEstimate{CostEstimate: costEstimate} } } @@ -273,6 +273,6 @@ func actualSize(value ref.Val) uint64 { } var ( - callCostEstimate = checker.CostEstimate{Min: 1, Max: 1} + callCostEstimate = checker.FixedCostEstimate(1) callCost = uint64(1) ) diff --git a/vendor/github.com/google/cel-go/ext/strings.go b/vendor/github.com/google/cel-go/ext/strings.go index 2e590a4..de65421 100644 --- a/vendor/github.com/google/cel-go/ext/strings.go +++ b/vendor/github.com/google/cel-go/ext/strings.go @@ -286,10 +286,15 @@ const ( // // 'gums'.reverse() // returns 'smug' // 'John Smith'.reverse() // returns 'htimS nhoJ' +// +// Introduced at version: 4 +// +// Formatting updated to adhere to https://github.com/google/cel-spec/blob/master/doc/extensions/strings.md. +// +// <string>.format(<list>) -> <string> func Strings(options ...StringsOption) cel.EnvOption { s := &stringLib{ - version: math.MaxUint32, - validateFormat: true, + version: math.MaxUint32, } for _, o := range options { s = o(s) @@ -298,9 +303,8 @@ func Strings(options ...StringsOption) cel.EnvOption { } type stringLib struct { - locale string - version uint32 - validateFormat bool + locale string + version uint32 } // LibraryName implements the SingletonLibrary interface method. @@ -314,6 +318,8 @@ type StringsOption func(*stringLib) *stringLib // StringsLocale configures the library with the given locale. The locale tag will // be checked for validity at the time that EnvOptions are configured. If this option // is not passed, string.format will behave as if en_US was passed as the locale. +// +// If StringsVersion is greater than or equal to 4, this option is ignored. func StringsLocale(locale string) StringsOption { return func(sl *stringLib) *stringLib { sl.locale = locale @@ -340,10 +346,9 @@ func StringsVersion(version uint32) StringsOption { // StringsValidateFormatCalls validates type-checked ASTs to ensure that string.format() calls have // valid formatting clauses and valid argument types for each clause. // -// Enabled by default. +// Deprecated func StringsValidateFormatCalls(value bool) StringsOption { return func(s *stringLib) *stringLib { - s.validateFormat = value return s } } @@ -351,7 +356,7 @@ func StringsValidateFormatCalls(value bool) StringsOption { // CompileOptions implements the Library interface method. func (lib *stringLib) CompileOptions() []cel.EnvOption { formatLocale := "en_US" - if lib.locale != "" { + if lib.version < 4 && lib.locale != "" { // ensure locale is properly-formed if set _, err := language.Parse(lib.locale) if err != nil { @@ -466,21 +471,29 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption { }))), } if lib.version >= 1 { - opts = append(opts, cel.Function("format", - cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType, - cel.FunctionBinding(func(args ...ref.Val) ref.Val { - s := string(args[0].(types.String)) - formatArgs := args[1].(traits.Lister) - return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale)) - }))), + if lib.version >= 4 { + opts = append(opts, cel.Function("format", + cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType, + cel.FunctionBinding(func(args ...ref.Val) ref.Val { + s := string(args[0].(types.String)) + formatArgs := args[1].(traits.Lister) + return stringOrError(parseFormatStringV2(s, &stringFormatterV2{}, &stringArgList{formatArgs})) + })))) + } else { + opts = append(opts, cel.Function("format", + cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType, + cel.FunctionBinding(func(args ...ref.Val) ref.Val { + s := string(args[0].(types.String)) + formatArgs := args[1].(traits.Lister) + return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale)) + })))) + } + opts = append(opts, cel.Function("strings.quote", cel.Overload("strings_quote", []*cel.Type{cel.StringType}, cel.StringType, cel.UnaryBinding(func(str ref.Val) ref.Val { s := str.(types.String) return stringOrError(quote(string(s))) - }))), - - cel.ASTValidators(stringFormatValidator{})) - + })))) } if lib.version >= 2 { opts = append(opts, @@ -529,8 +542,12 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption { }))), ) } - if lib.validateFormat { - opts = append(opts, cel.ASTValidators(stringFormatValidator{})) + if lib.version >= 1 { + if lib.version >= 4 { + opts = append(opts, cel.ASTValidators(stringFormatValidatorV2{})) + } else { + opts = append(opts, cel.ASTValidators(stringFormatValidator{})) + } } return opts } @@ -590,6 +607,10 @@ func lastIndexOf(str, substr string) (int64, error) { if substr == "" { return int64(len(runes)), nil } + + if len(str) < len(substr) { + return -1, nil + } return lastIndexOfOffset(str, substr, int64(len(runes)-1)) } |
