diff --git a/template/interpolate/funcs.go b/template/interpolate/funcs.go new file mode 100644 index 000000000..f3c17d8b7 --- /dev/null +++ b/template/interpolate/funcs.go @@ -0,0 +1,46 @@ +package interpolate + +import ( + "errors" + "os" + "text/template" +) + +// Funcs are the interpolation funcs that are available within interpolations. +var FuncGens = map[string]FuncGenerator{ + "env": funcGenEnv, + "user": funcGenUser, +} + +// FuncGenerator is a function that given a context generates a template +// function for the template. +type FuncGenerator func(*Context) interface{} + +// Funcs returns the functions that can be used for interpolation given +// a context. +func Funcs(ctx *Context) template.FuncMap { + result := make(map[string]interface{}) + for k, v := range FuncGens { + result[k] = v(ctx) + } + + return template.FuncMap(result) +} + +func funcGenEnv(ctx *Context) interface{} { + return func(k string) (string, error) { + if ctx.DisableEnv { + // The error message doesn't have to be that detailed since + // semantic checks should catch this. + return "", errors.New("env vars are not allowed here") + } + + return os.Getenv(k), nil + } +} + +func funcGenUser(ctx *Context) interface{} { + return func() string { + return "" + } +} diff --git a/template/interpolate/funcs_test.go b/template/interpolate/funcs_test.go new file mode 100644 index 000000000..2bc70b0bf --- /dev/null +++ b/template/interpolate/funcs_test.go @@ -0,0 +1,66 @@ +package interpolate + +import ( + "os" + "testing" +) + +func TestFuncEnv(t *testing.T) { + cases := []struct { + Input string + Output string + }{ + { + `{{env "PACKER_TEST_ENV"}}`, + `foo`, + }, + + { + `{{env "PACKER_TEST_ENV_NOPE"}}`, + ``, + }, + } + + os.Setenv("PACKER_TEST_ENV", "foo") + defer os.Setenv("PACKER_TEST_ENV", "") + + ctx := &Context{} + for _, tc := range cases { + i := &I{Value: tc.Input} + result, err := i.Render(ctx) + if err != nil { + t.Fatalf("Input: %s\n\nerr: %s", tc.Input, err) + } + + if result != tc.Output { + t.Fatalf("Input: %s\n\nGot: %s", tc.Input, result) + } + } +} + +func TestFuncEnv_disable(t *testing.T) { + cases := []struct { + Input string + Output string + Error bool + }{ + { + `{{env "PACKER_TEST_ENV"}}`, + "", + true, + }, + } + + ctx := &Context{DisableEnv: true} + for _, tc := range cases { + i := &I{Value: tc.Input} + result, err := i.Render(ctx) + if (err != nil) != tc.Error { + t.Fatalf("Input: %s\n\nerr: %s", tc.Input, err) + } + + if result != tc.Output { + t.Fatalf("Input: %s\n\nGot: %s", tc.Input, result) + } + } +} diff --git a/template/interpolate/i.go b/template/interpolate/i.go new file mode 100644 index 000000000..68095a03f --- /dev/null +++ b/template/interpolate/i.go @@ -0,0 +1,38 @@ +package interpolate + +import ( + "bytes" + "text/template" +) + +// Context is the context that an interpolation is done in. This defines +// things such as available variables. +type Context struct { + DisableEnv bool +} + +// I stands for "interpolation" and is the main interpolation struct +// in order to render values. +type I struct { + Value string +} + +// Render renders the interpolation with the given context. +func (i *I) Render(ctx *Context) (string, error) { + tpl, err := i.template(ctx) + if err != nil { + return "", err + } + + var result bytes.Buffer + data := map[string]interface{}{} + if err := tpl.Execute(&result, data); err != nil { + return "", err + } + + return result.String(), nil +} + +func (i *I) template(ctx *Context) (*template.Template, error) { + return template.New("root").Funcs(Funcs(ctx)).Parse(i.Value) +} diff --git a/template/interpolate/i_test.go b/template/interpolate/i_test.go new file mode 100644 index 000000000..a678afbc4 --- /dev/null +++ b/template/interpolate/i_test.go @@ -0,0 +1,32 @@ +package interpolate + +import ( + "testing" +) + +func TestIRender(t *testing.T) { + cases := map[string]struct { + Ctx *Context + Value string + Result string + }{ + "basic": { + nil, + "foo", + "foo", + }, + } + + for k, tc := range cases { + i := &I{Value: tc.Value} + result, err := i.Render(tc.Ctx) + if err != nil { + t.Fatalf("%s\n\ninput: %s\n\nerr: %s", k, tc.Value, err) + } + if result != tc.Result { + t.Fatalf( + "%s\n\ninput: %s\n\nexpected: %s\n\ngot: %s", + k, tc.Value, tc.Result, result) + } + } +} diff --git a/template/interpolate/parse.go b/template/interpolate/parse.go new file mode 100644 index 000000000..b18079510 --- /dev/null +++ b/template/interpolate/parse.go @@ -0,0 +1,42 @@ +package interpolate + +import ( + "fmt" + "text/template" + "text/template/parse" +) + +// functionsCalled returns a map (to be used as a set) of the functions +// that are called from the given text template. +func functionsCalled(t *template.Template) map[string]struct{} { + result := make(map[string]struct{}) + functionsCalledWalk(t.Tree.Root, result) + return result +} + +func functionsCalledWalk(raw parse.Node, r map[string]struct{}) { + switch node := raw.(type) { + case *parse.ActionNode: + functionsCalledWalk(node.Pipe, r) + case *parse.CommandNode: + if in, ok := node.Args[0].(*parse.IdentifierNode); ok { + r[in.Ident] = struct{}{} + } + + for _, n := range node.Args[1:] { + functionsCalledWalk(n, r) + } + case *parse.ListNode: + for _, n := range node.Nodes { + functionsCalledWalk(n, r) + } + case *parse.PipeNode: + for _, n := range node.Cmds { + functionsCalledWalk(n, r) + } + case *parse.StringNode, *parse.TextNode: + // Ignore + default: + panic(fmt.Sprintf("unknown type: %T", node)) + } +} diff --git a/template/interpolate/parse_test.go b/template/interpolate/parse_test.go new file mode 100644 index 000000000..3398ddbf1 --- /dev/null +++ b/template/interpolate/parse_test.go @@ -0,0 +1,39 @@ +package interpolate + +import ( + "reflect" + "testing" + "text/template" +) + +func TestFunctionsCalled(t *testing.T) { + cases := []struct { + Input string + Result map[string]struct{} + }{ + { + "foo", + map[string]struct{}{}, + }, + + { + "foo {{user `bar`}}", + map[string]struct{}{ + "user": struct{}{}, + }, + }, + } + + funcs := Funcs(&Context{}) + for _, tc := range cases { + tpl, err := template.New("root").Funcs(funcs).Parse(tc.Input) + if err != nil { + t.Fatalf("err parsing: %v\n\n%s", tc.Input, err) + } + + actual := functionsCalled(tpl) + if !reflect.DeepEqual(actual, tc.Result) { + t.Fatalf("bad: %v\n\ngot: %#v", tc.Input, actual) + } + } +}