From f9b93d13a2956cf2035be7e232b4c2b921bc980b Mon Sep 17 00:00:00 2001 From: James Bardin Date: Fri, 8 Dec 2023 15:10:41 -0500 Subject: [PATCH] update grpcwrap to handle provider functions --- internal/grpcwrap/provider.go | 58 ++++++++++++++++++++++++++++++++-- internal/grpcwrap/provider6.go | 58 ++++++++++++++++++++++++++++++++-- 2 files changed, 111 insertions(+), 5 deletions(-) diff --git a/internal/grpcwrap/provider.go b/internal/grpcwrap/provider.go index ac6775f7ba..4e4906d7df 100644 --- a/internal/grpcwrap/provider.go +++ b/internal/grpcwrap/provider.go @@ -5,6 +5,7 @@ package grpcwrap import ( "context" + "fmt" "github.com/hashicorp/terraform/internal/plugin/convert" "github.com/hashicorp/terraform/internal/providers" @@ -67,6 +68,12 @@ func (p *provider) GetSchema(_ context.Context, req *tfplugin5.GetProviderSchema Block: convert.ConfigSchemaToProto(dat.Block), } } + if decls, err := convert.FunctionDeclsToProto(p.schema.Functions); err == nil { + resp.Functions = decls + } else { + resp.Diagnostics = convert.AppendProtoDiag(resp.Diagnostics, err) + return resp, nil + } resp.ServerCapabilities = &tfplugin5.ServerCapabilities{ GetProviderSchemaOptional: p.schema.ServerCapabilities.GetProviderSchemaOptional, @@ -399,10 +406,55 @@ func (p *provider) GetFunctions(context.Context, *tfplugin5.GetFunctions_Request } func (p *provider) CallFunction(_ context.Context, req *tfplugin5.CallFunction_Request) (*tfplugin5.CallFunction_Response, error) { - panic("unimplemented") - return nil, nil -} + var err error + resp := &tfplugin5.CallFunction_Response{} + + funcSchema := p.schema.Functions[req.Name] + + var args []cty.Value + if len(req.Arguments) != 0 { + args = make([]cty.Value, len(req.Arguments)) + for i, rawArg := range req.Arguments { + + var argTy cty.Type + if i < len(funcSchema.Parameters) { + argTy = funcSchema.Parameters[i].Type + } else { + if funcSchema.VariadicParameter == nil { + resp.Diagnostics = convert.AppendProtoDiag( + resp.Diagnostics, fmt.Errorf("too many arguments for non-variadic function"), + ) + return resp, nil + } + argTy = funcSchema.VariadicParameter.Type + } + + argVal, err := decodeDynamicValue(rawArg, argTy) + if err != nil { + resp.Diagnostics = convert.AppendProtoDiag(resp.Diagnostics, err) + return resp, nil + } + args[i] = argVal + } + } + + callResp := p.provider.CallFunction(providers.CallFunctionRequest{ + FunctionName: req.Name, + Arguments: args, + }) + resp.Diagnostics = convert.AppendProtoDiag(resp.Diagnostics, callResp.Diagnostics) + if callResp.Diagnostics.HasErrors() { + return resp, nil + } + + resp.Result, err = encodeDynamicValue(callResp.Result, funcSchema.ReturnType) + if err != nil { + resp.Diagnostics = convert.AppendProtoDiag(resp.Diagnostics, err) + return resp, nil + } + return resp, nil +} func (p *provider) Stop(context.Context, *tfplugin5.Stop_Request) (*tfplugin5.Stop_Response, error) { resp := &tfplugin5.Stop_Response{} err := p.provider.Stop() diff --git a/internal/grpcwrap/provider6.go b/internal/grpcwrap/provider6.go index 92984e92c0..16b98322d9 100644 --- a/internal/grpcwrap/provider6.go +++ b/internal/grpcwrap/provider6.go @@ -5,6 +5,7 @@ package grpcwrap import ( "context" + "fmt" "github.com/hashicorp/terraform/internal/plugin6/convert" "github.com/hashicorp/terraform/internal/providers" @@ -39,6 +40,7 @@ func (p *provider6) GetProviderSchema(_ context.Context, req *tfplugin6.GetProvi resp := &tfplugin6.GetProviderSchema_Response{ ResourceSchemas: make(map[string]*tfplugin6.Schema), DataSourceSchemas: make(map[string]*tfplugin6.Schema), + Functions: make(map[string]*tfplugin6.Function), } resp.Provider = &tfplugin6.Schema{ @@ -67,6 +69,12 @@ func (p *provider6) GetProviderSchema(_ context.Context, req *tfplugin6.GetProvi Block: convert.ConfigSchemaToProto(dat.Block), } } + if decls, err := convert.FunctionDeclsToProto(p.schema.Functions); err == nil { + resp.Functions = decls + } else { + resp.Diagnostics = convert.AppendProtoDiag(resp.Diagnostics, err) + return resp, nil + } resp.ServerCapabilities = &tfplugin6.ServerCapabilities{ GetProviderSchemaOptional: p.schema.ServerCapabilities.GetProviderSchemaOptional, @@ -399,8 +407,54 @@ func (p *provider6) GetFunctions(context.Context, *tfplugin6.GetFunctions_Reques } func (p *provider6) CallFunction(_ context.Context, req *tfplugin6.CallFunction_Request) (*tfplugin6.CallFunction_Response, error) { - panic("unimplemented") - return nil, nil + var err error + resp := &tfplugin6.CallFunction_Response{} + + funcSchema := p.schema.Functions[req.Name] + + var args []cty.Value + if len(req.Arguments) != 0 { + args = make([]cty.Value, len(req.Arguments)) + for i, rawArg := range req.Arguments { + + var argTy cty.Type + if i < len(funcSchema.Parameters) { + argTy = funcSchema.Parameters[i].Type + } else { + if funcSchema.VariadicParameter == nil { + resp.Diagnostics = convert.AppendProtoDiag( + resp.Diagnostics, fmt.Errorf("too many arguments for non-variadic function"), + ) + return resp, nil + } + argTy = funcSchema.VariadicParameter.Type + } + + argVal, err := decodeDynamicValue6(rawArg, argTy) + if err != nil { + resp.Diagnostics = convert.AppendProtoDiag(resp.Diagnostics, err) + return resp, nil + } + args[i] = argVal + } + } + + callResp := p.provider.CallFunction(providers.CallFunctionRequest{ + FunctionName: req.Name, + Arguments: args, + }) + resp.Diagnostics = convert.AppendProtoDiag(resp.Diagnostics, callResp.Diagnostics) + if callResp.Diagnostics.HasErrors() { + return resp, nil + } + + resp.Result, err = encodeDynamicValue6(callResp.Result, funcSchema.ReturnType) + if err != nil { + resp.Diagnostics = convert.AppendProtoDiag(resp.Diagnostics, err) + return resp, nil + } + + return resp, nil } func (p *provider6) StopProvider(context.Context, *tfplugin6.StopProvider_Request) (*tfplugin6.StopProvider_Response, error) {