diff --git a/internal/builtin/providers/terraform/provider.go b/internal/builtin/providers/terraform/provider.go index 52f776d47b..816914a38c 100644 --- a/internal/builtin/providers/terraform/provider.go +++ b/internal/builtin/providers/terraform/provider.go @@ -135,6 +135,12 @@ func (p *Provider) ValidateResourceConfig(req providers.ValidateResourceConfigRe return validateDataStoreResourceConfig(req) } +// CallFunction would call a function contributed by this provider, but this +// provider has no functions and so this function just panics. +func (p *Provider) CallFunction(providers.CallFunctionRequest) providers.CallFunctionResponse { + panic("unimplemented - terraform.io/builtin/terraform provider has no functions") +} + // Close is a noop for this provider, since it's run in-process. func (p *Provider) Close() error { return nil diff --git a/internal/plugin/convert/functions.go b/internal/plugin/convert/functions.go new file mode 100644 index 0000000000..4f3598ceb0 --- /dev/null +++ b/internal/plugin/convert/functions.go @@ -0,0 +1,135 @@ +package convert + +import ( + "encoding/json" + "fmt" + + "github.com/hashicorp/terraform/internal/providers" + "github.com/hashicorp/terraform/internal/tfplugin5" +) + +func FunctionDeclsFromProto(protoFuncs map[string]*tfplugin5.Function) (map[string]providers.FunctionDecl, error) { + if len(protoFuncs) == 0 { + return nil, nil + } + + ret := make(map[string]providers.FunctionDecl, len(protoFuncs)) + for name, protoFunc := range protoFuncs { + decl, err := FunctionDeclFromProto(protoFunc) + if err != nil { + return nil, fmt.Errorf("invalid declaration for function %q: %s", name, err) + } + ret[name] = decl + } + return ret, nil +} + +func FunctionDeclFromProto(protoFunc *tfplugin5.Function) (providers.FunctionDecl, error) { + var ret providers.FunctionDecl + + ret.Description = protoFunc.Description + ret.DescriptionKind = schemaStringKind(protoFunc.DescriptionKind) + + if err := json.Unmarshal(protoFunc.Return.Type, &ret.ReturnType); err != nil { + return ret, fmt.Errorf("invalid return type constraint: %s", err) + } + + if len(protoFunc.Parameters) != 0 { + ret.Parameters = make([]providers.FunctionParam, len(protoFunc.Parameters)) + for i, protoParam := range protoFunc.Parameters { + param, err := functionParamFromProto(protoParam) + if err != nil { + return ret, fmt.Errorf("invalid parameter %d (%q): %s", i, protoParam.Name, err) + } + ret.Parameters[i] = param + } + } + if protoFunc.VariadicParameter != nil { + param, err := functionParamFromProto(protoFunc.VariadicParameter) + if err != nil { + return ret, fmt.Errorf("invalid variadic parameter (%q): %s", protoFunc.VariadicParameter.Name, err) + } + ret.VariadicParameter = ¶m + } + + return ret, nil +} + +func functionParamFromProto(protoParam *tfplugin5.Function_Parameter) (providers.FunctionParam, error) { + var ret providers.FunctionParam + ret.Name = protoParam.Name + ret.Description = protoParam.Description + ret.DescriptionKind = schemaStringKind(protoParam.DescriptionKind) + ret.AllowNullValue = protoParam.AllowNullValue + ret.AllowUnknownValues = protoParam.AllowUnknownValues + if err := json.Unmarshal(protoParam.Type, &ret.Type); err != nil { + return ret, fmt.Errorf("invalid type constraint: %s", err) + } + return ret, nil +} + +func FunctionDeclsToProto(fns map[string]providers.FunctionDecl) (map[string]*tfplugin5.Function, error) { + if len(fns) == 0 { + return nil, nil + } + + ret := make(map[string]*tfplugin5.Function, len(fns)) + for name, fn := range fns { + decl, err := FunctionDeclToProto(fn) + if err != nil { + return nil, fmt.Errorf("invalid declaration for function %q: %s", name, err) + } + ret[name] = decl + } + return ret, nil +} + +func FunctionDeclToProto(fn providers.FunctionDecl) (*tfplugin5.Function, error) { + ret := &tfplugin5.Function{ + Return: &tfplugin5.Function_Return{}, + } + + ret.Description = fn.Description + ret.DescriptionKind = protoStringKind(fn.DescriptionKind) + + retTy, err := json.Marshal(fn.ReturnType) + if err != nil { + return ret, fmt.Errorf("invalid return type constraint: %s", err) + } + ret.Return.Type = retTy + + if len(fn.Parameters) != 0 { + ret.Parameters = make([]*tfplugin5.Function_Parameter, len(fn.Parameters)) + for i, fnParam := range fn.Parameters { + protoParam, err := functionParamToProto(fnParam) + if err != nil { + return ret, fmt.Errorf("invalid parameter %d (%q): %s", i, fnParam.Name, err) + } + ret.Parameters[i] = protoParam + } + } + if fn.VariadicParameter != nil { + param, err := functionParamToProto(*fn.VariadicParameter) + if err != nil { + return ret, fmt.Errorf("invalid variadic parameter (%q): %s", fn.VariadicParameter.Name, err) + } + ret.VariadicParameter = param + } + + return ret, nil +} + +func functionParamToProto(param providers.FunctionParam) (*tfplugin5.Function_Parameter, error) { + ret := &tfplugin5.Function_Parameter{} + ret.Name = param.Name + ret.Description = param.Description + ret.DescriptionKind = protoStringKind(param.DescriptionKind) + ret.AllowNullValue = param.AllowNullValue + ret.AllowUnknownValues = param.AllowUnknownValues + ty, err := json.Marshal(param.Type) + if err != nil { + return ret, fmt.Errorf("invalid type constraint: %s", err) + } + ret.Type = ty + return ret, nil +} diff --git a/internal/plugin/grpc_provider.go b/internal/plugin/grpc_provider.go index 89b66f48df..c430c01f89 100644 --- a/internal/plugin/grpc_provider.go +++ b/internal/plugin/grpc_provider.go @@ -141,6 +141,13 @@ func (p *GRPCProvider) GetProviderSchema() providers.GetProviderSchemaResponse { resp.DataSources[name] = convert.ProtoToProviderSchema(data) } + if decls, err := convert.FunctionDeclsFromProto(protoResp.Functions); err == nil { + resp.Functions = decls + } else { + resp.Diagnostics = resp.Diagnostics.Append(err) + return resp + } + if protoResp.ServerCapabilities != nil { resp.ServerCapabilities.PlanDestroy = protoResp.ServerCapabilities.PlanDestroy resp.ServerCapabilities.GetProviderSchemaOptional = protoResp.ServerCapabilities.GetProviderSchemaOptional @@ -678,6 +685,78 @@ func (p *GRPCProvider) ReadDataSource(r providers.ReadDataSourceRequest) (resp p return resp } +func (p *GRPCProvider) CallFunction(r providers.CallFunctionRequest) (resp providers.CallFunctionResponse) { + logger.Trace("GRPCProvider", "CallFunction", r.FunctionName) + + schema := p.GetProviderSchema() + if schema.Diagnostics.HasErrors() { + resp.Diagnostics = schema.Diagnostics + return resp + } + + funcDecl, ok := schema.Functions[r.FunctionName] + // We check for various problems with the request below in the interests + // of robustness, just to avoid crashing while trying to encode/decode, but + // if we reach any of these errors then that suggests a bug in the caller, + // because we should catch function calls that don't match the schema at an + // earlier point than this. + if !ok { + // Should only get here if the caller has a bug, because we should + // have detected earlier any attempt to call a function that the + // provider didn't declare. + resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("provider has no function named %q", r.FunctionName)) + return resp + } + if len(r.Arguments) < len(funcDecl.Parameters) { + resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("not enough arguments for function %q", r.FunctionName)) + return resp + } + if funcDecl.VariadicParameter == nil && len(r.Arguments) > len(funcDecl.Parameters) { + resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("too many arguments for function %q", r.FunctionName)) + return resp + } + args := make([]*proto.DynamicValue, len(r.Arguments)) + for i, argVal := range r.Arguments { + var paramDecl providers.FunctionParam + if i < len(funcDecl.Parameters) { + paramDecl = funcDecl.Parameters[i] + } else { + paramDecl = *funcDecl.VariadicParameter + } + + argValRaw, err := msgpack.Marshal(argVal, paramDecl.Type) + if err != nil { + resp.Diagnostics = resp.Diagnostics.Append(err) + return resp + } + args[i] = &proto.DynamicValue{ + Msgpack: argValRaw, + } + } + + protoResp, err := p.client.CallFunction(p.ctx, &proto.CallFunction_Request{ + Name: r.FunctionName, + Arguments: args, + }) + if err != nil { + resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err)) + return resp + } + resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(protoResp.Diagnostics)) + if resp.Diagnostics.HasErrors() { + return resp + } + + resultVal, err := decodeDynamicValue(protoResp.Result, funcDecl.ReturnType) + if err != nil { + resp.Diagnostics = resp.Diagnostics.Append(err) + return resp + } + + resp.Result = resultVal + return resp +} + // closing the grpc connection is final, and terraform will call it at the end of every phase. func (p *GRPCProvider) Close() error { logger.Trace("GRPCProvider: Close") diff --git a/internal/plugin6/convert/diagnostics.go b/internal/plugin6/convert/diagnostics.go index 2e6ccbe633..a1cd57a234 100644 --- a/internal/plugin6/convert/diagnostics.go +++ b/internal/plugin6/convert/diagnostics.go @@ -44,6 +44,14 @@ func AppendProtoDiag(diags []*proto.Diagnostic, d interface{}) []*proto.Diagnost Severity: proto.Diagnostic_WARNING, Summary: d, }) + case tfdiags.Diagnostic: + diags = append(diags, DiagnosticToProto(d)) + + case tfdiags.Diagnostics: + for _, diag := range d { + diags = append(diags, DiagnosticToProto(diag)) + } + case *proto.Diagnostic: diags = append(diags, d) case []*proto.Diagnostic: @@ -52,6 +60,21 @@ func AppendProtoDiag(diags []*proto.Diagnostic, d interface{}) []*proto.Diagnost return diags } +func DiagnosticToProto(diag tfdiags.Diagnostic) *proto.Diagnostic { + ret := &proto.Diagnostic{} + switch diag.Severity() { + case tfdiags.Error: + ret.Severity = proto.Diagnostic_ERROR + case tfdiags.Warning: + ret.Severity = proto.Diagnostic_WARNING + } + + desc := diag.Description() + ret.Summary = desc.Summary + ret.Detail = desc.Detail + return ret +} + // ProtoToDiagnostics converts a list of proto.Diagnostics to a tf.Diagnostics. func ProtoToDiagnostics(ds []*proto.Diagnostic) tfdiags.Diagnostics { var diags tfdiags.Diagnostics diff --git a/internal/plugin6/convert/functions.go b/internal/plugin6/convert/functions.go new file mode 100644 index 0000000000..54eb50a14c --- /dev/null +++ b/internal/plugin6/convert/functions.go @@ -0,0 +1,135 @@ +package convert + +import ( + "encoding/json" + "fmt" + + "github.com/hashicorp/terraform/internal/providers" + "github.com/hashicorp/terraform/internal/tfplugin6" +) + +func FunctionDeclsFromProto(protoFuncs map[string]*tfplugin6.Function) (map[string]providers.FunctionDecl, error) { + if len(protoFuncs) == 0 { + return nil, nil + } + + ret := make(map[string]providers.FunctionDecl, len(protoFuncs)) + for name, protoFunc := range protoFuncs { + decl, err := FunctionDeclFromProto(protoFunc) + if err != nil { + return nil, fmt.Errorf("invalid declaration for function %q: %s", name, err) + } + ret[name] = decl + } + return ret, nil +} + +func FunctionDeclFromProto(protoFunc *tfplugin6.Function) (providers.FunctionDecl, error) { + var ret providers.FunctionDecl + + ret.Description = protoFunc.Description + ret.DescriptionKind = schemaStringKind(protoFunc.DescriptionKind) + + if err := json.Unmarshal(protoFunc.Return.Type, &ret.ReturnType); err != nil { + return ret, fmt.Errorf("invalid return type constraint: %s", err) + } + + if len(protoFunc.Parameters) != 0 { + ret.Parameters = make([]providers.FunctionParam, len(protoFunc.Parameters)) + for i, protoParam := range protoFunc.Parameters { + param, err := functionParamFromProto(protoParam) + if err != nil { + return ret, fmt.Errorf("invalid parameter %d (%q): %s", i, protoParam.Name, err) + } + ret.Parameters[i] = param + } + } + if protoFunc.VariadicParameter != nil { + param, err := functionParamFromProto(protoFunc.VariadicParameter) + if err != nil { + return ret, fmt.Errorf("invalid variadic parameter (%q): %s", protoFunc.VariadicParameter.Name, err) + } + ret.VariadicParameter = ¶m + } + + return ret, nil +} + +func functionParamFromProto(protoParam *tfplugin6.Function_Parameter) (providers.FunctionParam, error) { + var ret providers.FunctionParam + ret.Name = protoParam.Name + ret.Description = protoParam.Description + ret.DescriptionKind = schemaStringKind(protoParam.DescriptionKind) + ret.AllowNullValue = protoParam.AllowNullValue + ret.AllowUnknownValues = protoParam.AllowUnknownValues + if err := json.Unmarshal(protoParam.Type, &ret.Type); err != nil { + return ret, fmt.Errorf("invalid type constraint: %s", err) + } + return ret, nil +} + +func FunctionDeclsToProto(fns map[string]providers.FunctionDecl) (map[string]*tfplugin6.Function, error) { + if len(fns) == 0 { + return nil, nil + } + + ret := make(map[string]*tfplugin6.Function, len(fns)) + for name, fn := range fns { + decl, err := FunctionDeclToProto(fn) + if err != nil { + return nil, fmt.Errorf("invalid declaration for function %q: %s", name, err) + } + ret[name] = decl + } + return ret, nil +} + +func FunctionDeclToProto(fn providers.FunctionDecl) (*tfplugin6.Function, error) { + ret := &tfplugin6.Function{ + Return: &tfplugin6.Function_Return{}, + } + + ret.Description = fn.Description + ret.DescriptionKind = protoStringKind(fn.DescriptionKind) + + retTy, err := json.Marshal(fn.ReturnType) + if err != nil { + return ret, fmt.Errorf("invalid return type constraint: %s", err) + } + ret.Return.Type = retTy + + if len(fn.Parameters) != 0 { + ret.Parameters = make([]*tfplugin6.Function_Parameter, len(fn.Parameters)) + for i, fnParam := range fn.Parameters { + protoParam, err := functionParamToProto(fnParam) + if err != nil { + return ret, fmt.Errorf("invalid parameter %d (%q): %s", i, fnParam.Name, err) + } + ret.Parameters[i] = protoParam + } + } + if fn.VariadicParameter != nil { + param, err := functionParamToProto(*fn.VariadicParameter) + if err != nil { + return ret, fmt.Errorf("invalid variadic parameter (%q): %s", fn.VariadicParameter.Name, err) + } + ret.VariadicParameter = param + } + + return ret, nil +} + +func functionParamToProto(param providers.FunctionParam) (*tfplugin6.Function_Parameter, error) { + ret := &tfplugin6.Function_Parameter{} + ret.Name = param.Name + ret.Description = param.Description + ret.DescriptionKind = protoStringKind(param.DescriptionKind) + ret.AllowNullValue = param.AllowNullValue + ret.AllowUnknownValues = param.AllowUnknownValues + ty, err := json.Marshal(param.Type) + if err != nil { + return ret, fmt.Errorf("invalid type constraint: %s", err) + } + ret.Type = ty + return ret, nil +} diff --git a/internal/plugin6/grpc_provider.go b/internal/plugin6/grpc_provider.go index 70e65cd7f9..0a4728404e 100644 --- a/internal/plugin6/grpc_provider.go +++ b/internal/plugin6/grpc_provider.go @@ -74,7 +74,6 @@ type GRPCProvider struct { } func (p *GRPCProvider) GetProviderSchema() providers.GetProviderSchemaResponse { - logger.Trace("GRPCProvider.v6: GetProviderSchema") p.mu.Lock() defer p.mu.Unlock() @@ -88,6 +87,7 @@ func (p *GRPCProvider) GetProviderSchema() providers.GetProviderSchemaResponse { return resp } } + logger.Trace("GRPCProvider.v6: GetProviderSchema") // If the local cache is non-zero, we know this instance has called // GetProviderSchema at least once and we can return early. @@ -141,6 +141,13 @@ func (p *GRPCProvider) GetProviderSchema() providers.GetProviderSchemaResponse { resp.DataSources[name] = convert.ProtoToProviderSchema(data) } + if decls, err := convert.FunctionDeclsFromProto(protoResp.Functions); err == nil { + resp.Functions = decls + } else { + resp.Diagnostics = resp.Diagnostics.Append(err) + return resp + } + if protoResp.ServerCapabilities != nil { resp.ServerCapabilities.PlanDestroy = protoResp.ServerCapabilities.PlanDestroy resp.ServerCapabilities.GetProviderSchemaOptional = protoResp.ServerCapabilities.GetProviderSchemaOptional @@ -667,6 +674,78 @@ func (p *GRPCProvider) ReadDataSource(r providers.ReadDataSourceRequest) (resp p return resp } +func (p *GRPCProvider) CallFunction(r providers.CallFunctionRequest) (resp providers.CallFunctionResponse) { + logger.Trace("GRPCProvider.v6", "CallFunction", r.FunctionName) + + schema := p.GetProviderSchema() + if schema.Diagnostics.HasErrors() { + resp.Diagnostics = schema.Diagnostics + return resp + } + + funcDecl, ok := schema.Functions[r.FunctionName] + // We check for various problems with the request below in the interests + // of robustness, just to avoid crashing while trying to encode/decode, but + // if we reach any of these errors then that suggests a bug in the caller, + // because we should catch function calls that don't match the schema at an + // earlier point than this. + if !ok { + // Should only get here if the caller has a bug, because we should + // have detected earlier any attempt to call a function that the + // provider didn't declare. + resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("provider has no function named %q", r.FunctionName)) + return resp + } + if len(r.Arguments) < len(funcDecl.Parameters) { + resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("not enough arguments for function %q", r.FunctionName)) + return resp + } + if funcDecl.VariadicParameter == nil && len(r.Arguments) > len(funcDecl.Parameters) { + resp.Diagnostics = resp.Diagnostics.Append(fmt.Errorf("too many arguments for function %q", r.FunctionName)) + return resp + } + args := make([]*proto6.DynamicValue, len(r.Arguments)) + for i, argVal := range r.Arguments { + var paramDecl providers.FunctionParam + if i < len(funcDecl.Parameters) { + paramDecl = funcDecl.Parameters[i] + } else { + paramDecl = *funcDecl.VariadicParameter + } + + argValRaw, err := msgpack.Marshal(argVal, paramDecl.Type) + if err != nil { + resp.Diagnostics = resp.Diagnostics.Append(err) + return resp + } + args[i] = &proto6.DynamicValue{ + Msgpack: argValRaw, + } + } + + protoResp, err := p.client.CallFunction(p.ctx, &proto6.CallFunction_Request{ + Name: r.FunctionName, + Arguments: args, + }) + if err != nil { + resp.Diagnostics = resp.Diagnostics.Append(grpcErr(err)) + return resp + } + resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(protoResp.Diagnostics)) + if resp.Diagnostics.HasErrors() { + return resp + } + + resultVal, err := decodeDynamicValue(protoResp.Result, funcDecl.ReturnType) + if err != nil { + resp.Diagnostics = resp.Diagnostics.Append(err) + return resp + } + + resp.Result = resultVal + return resp +} + // closing the grpc connection is final, and terraform will call it at the end of every phase. func (p *GRPCProvider) Close() error { logger.Trace("GRPCProvider.v6: Close") diff --git a/internal/provider-simple-v6/provider.go b/internal/provider-simple-v6/provider.go index 86dffcb7d6..d9f3461ae9 100644 --- a/internal/provider-simple-v6/provider.go +++ b/internal/provider-simple-v6/provider.go @@ -145,6 +145,12 @@ func (s simple) ReadDataSource(req providers.ReadDataSourceRequest) (resp provid return resp } +func (s simple) CallFunction(req providers.CallFunctionRequest) (resp providers.CallFunctionResponse) { + // Our schema doesn't include any functions, so it should be impossible + // to get in here. + panic("CallFunction on provider that didn't declare any functions") +} + func (s simple) Close() error { return nil } diff --git a/internal/provider-simple/provider.go b/internal/provider-simple/provider.go index 0bf874aa6e..205e6f7453 100644 --- a/internal/provider-simple/provider.go +++ b/internal/provider-simple/provider.go @@ -136,6 +136,12 @@ func (s simple) ReadDataSource(req providers.ReadDataSourceRequest) (resp provid return resp } +func (s simple) CallFunction(req providers.CallFunctionRequest) (resp providers.CallFunctionResponse) { + // Our schema doesn't include any functions, so it should be impossible + // to get in here. + panic("CallFunction on provider that didn't declare any functions") +} + func (s simple) Close() error { return nil } diff --git a/internal/providers/functions.go b/internal/providers/functions.go index 55f6ccc03a..de626d2984 100644 --- a/internal/providers/functions.go +++ b/internal/providers/functions.go @@ -1,10 +1,10 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - package providers import ( + "fmt" + "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/function" "github.com/hashicorp/terraform/internal/configs/configschema" ) @@ -22,9 +22,118 @@ type FunctionParam struct { Name string // Only for documentation and UI, because arguments are positional Type cty.Type - Nullable bool + AllowNullValue bool AllowUnknownValues bool Description string DescriptionKind configschema.StringKind } + +// BuildFunction takes a factory function which will return an unconfigured +// instance of the provider this declaration belongs to and returns a +// cty function that is ready to be called against that provider. +// +// The given name must be the name under which the provider originally +// registered this declaration, or the returned function will try to use an +// invalid name, leading to errors or undefined behavior. +// +// If the given factory returns an instance of any provider other than the +// one the declaration belongs to, or returns a _configured_ instance of +// the provider rather than an unconfigured one, the behavior of the returned +// function is undefined. +// +// Although not functionally required, callers should ideally pass a factory +// function that either retrieves already-running plugins or memoizes the +// plugins it returns so that many calls to functions in the same provider +// will not incur a repeated startup cost. +func (d FunctionDecl) BuildFunction(name string, factory func() (Interface, error)) function.Function { + + var params []function.Parameter + var varParam *function.Parameter + if len(d.Parameters) > 0 { + params = make([]function.Parameter, len(d.Parameters)) + for i, paramDecl := range d.Parameters { + params[i] = paramDecl.ctyParameter() + } + } + if d.VariadicParameter != nil { + cp := d.VariadicParameter.ctyParameter() + varParam = &cp + } + + return function.New(&function.Spec{ + Type: function.StaticReturnType(d.ReturnType), + Params: params, + VarParam: varParam, + Impl: func(args []cty.Value, retType cty.Type) (cty.Value, error) { + for i, arg := range args { + var param function.Parameter + if i < len(params) { + param = params[i] + } else { + param = *varParam + } + + // We promise provider developers that we won't pass them even + // _nested_ unknown values unless they opt in to dealing with + // them. + if !param.AllowUnknown { + if !arg.IsWhollyKnown() { + return cty.UnknownVal(retType), nil + } + } + + // We also ensure that null values are never passed where they + // are not expected. + if !param.AllowNull { + if arg.IsNull() { + return cty.UnknownVal(retType), fmt.Errorf("argument %q cannot be null", param.Name) + } + } + } + + provider, err := factory() + if err != nil { + return cty.UnknownVal(retType), fmt.Errorf("failed to launch provider plugin: %s", err) + } + + resp := provider.CallFunction(CallFunctionRequest{ + FunctionName: name, + Arguments: args, + }) + // NOTE: We don't actually have any way to surface warnings + // from the function here, because functions just return normal + // Go errors rather than diagnostics. + if resp.Diagnostics.HasErrors() { + return cty.UnknownVal(retType), resp.Diagnostics.Err() + } + + if resp.Result == cty.NilVal { + return cty.UnknownVal(retType), fmt.Errorf("provider returned no result and no errors") + } + + return resp.Result, nil + }, + }) +} + +func (p *FunctionParam) ctyParameter() function.Parameter { + return function.Parameter{ + Name: p.Name, + Type: p.Type, + AllowNull: p.AllowNullValue, + + // While the function may not allow DynamicVal, a `null` literal is + // also dynamically typed. If the parameter is dynamically typed, then + // we must allow this for `null` to pass through. + AllowDynamicType: p.Type == cty.DynamicPseudoType, + + // NOTE: Setting this is not a sufficient implementation of + // FunctionParam.AllowUnknownValues, because cty's function + // system only blocks passing in a top-level unknown, but + // our provider-contributed functions API promises to only + // pass wholly-known values unless AllowUnknownValues is true. + // The function implementation itself must also check this. + AllowUnknown: p.AllowUnknownValues, + } +} diff --git a/internal/providers/mock.go b/internal/providers/mock.go index cce51e0d72..e05c20d0a6 100644 --- a/internal/providers/mock.go +++ b/internal/providers/mock.go @@ -263,6 +263,10 @@ func (m *Mock) ReadDataSource(request ReadDataSourceRequest) ReadDataSourceRespo return response } +func (m *Mock) CallFunction(request CallFunctionRequest) CallFunctionResponse { + return m.Provider.CallFunction(request) +} + func (m *Mock) Close() error { return m.Provider.Close() } diff --git a/internal/terraform/context_plugins.go b/internal/terraform/context_plugins.go index 112996f71c..33df850204 100644 --- a/internal/terraform/context_plugins.go +++ b/internal/terraform/context_plugins.go @@ -7,6 +7,7 @@ import ( "fmt" "log" + "github.com/hashicorp/hcl/v2/hclsyntax" "github.com/hashicorp/terraform/internal/addrs" "github.com/hashicorp/terraform/internal/configs/configschema" "github.com/hashicorp/terraform/internal/providers" @@ -139,6 +140,33 @@ func (cp *contextPlugins) ProviderSchema(addr addrs.Provider) (providers.Provide } } + for n, f := range resp.Functions { + if !hclsyntax.ValidIdentifier(n) { + return resp, fmt.Errorf("provider %s declares function with invalid name %q", addr, n) + } + // We'll also do some enforcement of parameter names, even though they + // are only for docs/UI for now, to leave room for us to potentially + // use them for other purposes later. + seenParams := make(map[string]int, len(f.Parameters)) + for i, p := range f.Parameters { + if !hclsyntax.ValidIdentifier(p.Name) { + return resp, fmt.Errorf("provider %s function %q declares invalid name %q for parameter %d", addr, n, p.Name, i) + } + if prevIdx, exists := seenParams[p.Name]; exists { + return resp, fmt.Errorf("provider %s function %q reuses name %q for both parameters %d and %d", addr, n, p.Name, prevIdx, i) + } + seenParams[p.Name] = i + } + if p := f.VariadicParameter; p != nil { + if !hclsyntax.ValidIdentifier(p.Name) { + return resp, fmt.Errorf("provider %s function %q declares invalid name %q for its variadic parameter", addr, n, p.Name) + } + if prevIdx, exists := seenParams[p.Name]; exists { + return resp, fmt.Errorf("provider %s function %q reuses name %q for both parameter %d and its variadic parameter", addr, n, p.Name, prevIdx) + } + } + } + return resp, nil } @@ -197,3 +225,20 @@ func (cp *contextPlugins) ProvisionerSchema(typ string) (*configschema.Block, er return resp.Provisioner, nil } + +// ProviderFunctionDecls is a helper wrapper around ProviderSchema which first +// reads the schema of the given provider and then returns all of the +// functions it declares, if any. +// +// ProviderFunctionDecl will return an error if the provider schema lookup +// fails, but will return an empty set of functions if a successful response +// returns no functions, or if the provider is using an older protocol version +// which has no support for provider-contributed functions. +func (cp *contextPlugins) ProviderFunctionDecls(providerAddr addrs.Provider) (map[string]providers.FunctionDecl, error) { + providerSchema, err := cp.ProviderSchema(providerAddr) + if err != nil { + return nil, err + } + + return providerSchema.Functions, nil +} diff --git a/internal/terraform/provider_mock.go b/internal/terraform/provider_mock.go index d9cafd5d21..e46e37bc06 100644 --- a/internal/terraform/provider_mock.go +++ b/internal/terraform/provider_mock.go @@ -85,6 +85,11 @@ type MockProvider struct { ReadDataSourceRequest providers.ReadDataSourceRequest ReadDataSourceFn func(providers.ReadDataSourceRequest) providers.ReadDataSourceResponse + CallFunctionCalled bool + CallFunctionResponse providers.CallFunctionResponse + CallFunctionRequest providers.CallFunctionRequest + CallFunctionFn func(providers.CallFunctionRequest) providers.CallFunctionResponse + CloseCalled bool CloseError error } @@ -517,6 +522,20 @@ func (p *MockProvider) ReadDataSource(r providers.ReadDataSourceRequest) (resp p return resp } +func (p *MockProvider) CallFunction(r providers.CallFunctionRequest) providers.CallFunctionResponse { + p.Lock() + defer p.Unlock() + + p.CallFunctionCalled = true + p.CallFunctionRequest = r + + if p.ReadDataSourceFn != nil { + return p.CallFunctionFn(r) + } + + return p.CallFunctionResponse +} + func (p *MockProvider) Close() error { p.Lock() defer p.Unlock()