diff --git a/pkg/expansion/expand_match.go b/pkg/expansion/expand_match.go index 1aa2c253..b9d1c96e 100644 --- a/pkg/expansion/expand_match.go +++ b/pkg/expansion/expand_match.go @@ -12,39 +12,60 @@ type ExpandRegexMatch struct { Only []string } -var DefaultRefRegexp = regexp.MustCompile(`((secret)?ref)\+([^\+:]*:\/\/[^\+\n ]+[^\+\n ",])\+?`) +var DefaultRefRegexp = regexp.MustCompile(`((secret)?ref)\+([^:]*:\/\/[^ \n",]*[^ \n",+])\+?`) func (e *ExpandRegexMatch) InString(s string) (string, error) { - var sb strings.Builder - for { - ixs := e.Target.FindStringSubmatchIndex(s) - if ixs == nil { - sb.WriteString(s) - return sb.String(), nil - } - kind := s[ixs[2]:ixs[3]] - if len(e.Only) > 0 { - var shouldExpand bool - for _, k := range e.Only { - if k == kind { - shouldExpand = true + // Keep expanding until no more expressions are found (for nested expressions) + maxIterations := 10 // Prevent infinite loops + iteration := 0 + + for iteration < maxIterations { + originalString := s + var sb strings.Builder + hasChanges := false + + for { + ixs := e.Target.FindStringSubmatchIndex(s) + if ixs == nil { + sb.WriteString(s) + break + } + kind := s[ixs[2]:ixs[3]] + if len(e.Only) > 0 { + var shouldExpand bool + for _, k := range e.Only { + if k == kind { + shouldExpand = true + break + } + } + if !shouldExpand { + sb.WriteString(s) break } } - if !shouldExpand { - sb.WriteString(s) - return sb.String(), nil + ref := s[ixs[6]:ixs[7]] + val, err := e.Lookup(ref) + if err != nil { + return "", fmt.Errorf("expand %s: %v", ref, err) } + sb.WriteString(s[:ixs[0]]) + sb.WriteString(val) + s = s[ixs[1]:] + hasChanges = true } - ref := s[ixs[6]:ixs[7]] - val, err := e.Lookup(ref) - if err != nil { - return "", fmt.Errorf("expand %s: %v", ref, err) + + s = sb.String() + + // If no changes were made in this iteration, we're done + if !hasChanges || s == originalString { + return s, nil } - sb.WriteString(s[:ixs[0]]) - sb.WriteString(val) - s = s[ixs[1]:] + + iteration++ } + + return "", fmt.Errorf("maximum iterations (%d) reached while expanding nested expressions", maxIterations) } func (e *ExpandRegexMatch) InMap(target map[string]interface{}) (map[string]interface{}, error) { diff --git a/vals_test.go b/vals_test.go index 9fdfe9de..a3931b3a 100644 --- a/vals_test.go +++ b/vals_test.go @@ -187,3 +187,179 @@ datetime_offset: "2025-01-01T12:34:56+01:00" require.Equal(t, expected, buf.String()) } + +func TestNestedExpressions(t *testing.T) { + // Set up test environment variables + os.Setenv("TEST_VAR", "hello-world") + os.Setenv("NESTED_VAR", "nested-value") + defer func() { + os.Unsetenv("TEST_VAR") + os.Unsetenv("NESTED_VAR") + }() + + tests := []struct { + input map[string]interface{} + expected map[string]interface{} + name string + }{ + { + name: "echo with envsubst nested", + input: map[string]interface{}{ + "test": "ref+echo://ref+envsubst://$TEST_VAR/foo", + }, + expected: map[string]interface{}{ + "test": "hello-world/foo", + }, + }, + { + name: "envsubst with echo nested", + input: map[string]interface{}{ + "test": "ref+envsubst://prefix-ref+echo://$NESTED_VAR-suffix", + }, + expected: map[string]interface{}{ + "test": "prefix-nested-value-suffix", + }, + }, + { + name: "multiple nested expressions", + input: map[string]interface{}{ + "test1": "ref+echo://ref+envsubst://$TEST_VAR/path", + "test2": "ref+envsubst://ref+echo://$NESTED_VAR", + }, + expected: map[string]interface{}{ + "test1": "hello-world/path", + "test2": "nested-value", + }, + }, + { + name: "deeply nested expressions", + input: map[string]interface{}{ + "test": "ref+echo://prefix/ref+envsubst://ref+echo://$NESTED_VAR/suffix", + }, + expected: map[string]interface{}{ + "test": "prefix/nested-value/suffix", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Eval(tt.input) + require.NoError(t, err) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestNestedExpressionsWithGet(t *testing.T) { + // Set up test environment variables + os.Setenv("TEST_VAR", "hello-world") + defer os.Unsetenv("TEST_VAR") + + runtime, err := New(Options{}) + require.NoError(t, err) + + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple nested expression", + input: "ref+echo://ref+envsubst://$TEST_VAR/foo", + expected: "hello-world/foo", + }, + { + name: "envsubst with echo nested", + input: "ref+envsubst://prefix-ref+echo://suffix", + expected: "prefix-suffix", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := runtime.Get(tt.input) + require.NoError(t, err) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestNestedExpressionsBackwardCompatibility(t *testing.T) { + // Ensure that existing non-nested expressions still work + tests := []struct { + input map[string]interface{} + expected map[string]interface{} + name string + }{ + { + name: "simple echo", + input: map[string]interface{}{ + "test": "ref+echo://hello-world", + }, + expected: map[string]interface{}{ + "test": "hello-world", + }, + }, + { + name: "echo with fragment", + input: map[string]interface{}{ + "test": "ref+echo://foo/bar/baz#/foo/bar", + }, + expected: map[string]interface{}{ + "test": "baz", + }, + }, + { + name: "file provider", + input: map[string]interface{}{ + "test": "ref+file://./myjson.json#/baz/mykey", + }, + expected: map[string]interface{}{ + "test": "myvalue", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Eval(tt.input) + require.NoError(t, err) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestNestedExpressionsEdgeCases(t *testing.T) { + // Set up test environment variables + os.Setenv("EDGE_VAR", "edge-value") + defer os.Unsetenv("EDGE_VAR") + + tests := []struct { + name string + input string + expected string + expectErr bool + }{ + { + name: "nested expression with special characters", + input: "ref+echo://ref+envsubst://$EDGE_VAR-test_123", + expected: "edge-value-test_123", + }, + } + + runtime, err := New(Options{}) + require.NoError(t, err) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := runtime.Get(tt.input) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +}