diff --git a/lang/funcs/collection.go b/lang/funcs/collection.go index 9ac5e06307..b3f0f09278 100644 --- a/lang/funcs/collection.go +++ b/lang/funcs/collection.go @@ -411,12 +411,17 @@ var KeysFunc = function.New(&function.Spec{ Params: []function.Parameter{ { Name: "inputMap", - Type: cty.Map(cty.DynamicPseudoType), + Type: cty.DynamicPseudoType, }, }, Type: function.StaticReturnType(cty.List(cty.String)), Impl: func(args []cty.Value, retType cty.Type) (ret cty.Value, err error) { var keys []cty.Value + ty := args[0].Type() + + if !ty.IsObjectType() && !ty.IsMapType() { + return cty.NilVal, fmt.Errorf("keys() requires a map") + } for it := args[0].ElementIterator(); it.Next(); { k, _ := it.Element() @@ -861,11 +866,25 @@ var ValuesFunc = function.New(&function.Spec{ Params: []function.Parameter{ { Name: "values", - Type: cty.Map(cty.DynamicPseudoType), + Type: cty.DynamicPseudoType, }, }, Type: func(args []cty.Value) (ret cty.Type, err error) { - return cty.List(args[0].Type().ElementType()), nil + ty := args[0].Type() + if ty.IsMapType() { + return cty.List(ty.ElementType()), nil + } else if ty.IsObjectType() { + var tys []cty.Type + for _, v := range ty.AttributeTypes() { + tys = append(tys, v) + } + retType, _ := convert.UnifyUnsafe(tys) + if retType == cty.NilType { + return cty.NilType, fmt.Errorf("all arguments must have the same type") + } + return cty.List(retType), nil + } + return cty.NilType, fmt.Errorf("values() requires a map as the first argument") }, Impl: func(args []cty.Value, retType cty.Type) (ret cty.Value, err error) { mapVar := args[0] @@ -886,9 +905,17 @@ var ValuesFunc = function.New(&function.Spec{ var values []cty.Value for it := keys.ElementIterator(); it.Next(); { - _, k := it.Element() - value := mapVar.Index(cty.StringVal(k.AsString())) - values = append(values, value) + _, key := it.Element() + k := key.AsString() + if mapVar.Type().IsObjectType() { + if mapVar.Type().HasAttribute(k) { + value := mapVar.GetAttr(k) + values = append(values, value) + } + } else { + value := mapVar.Index(cty.StringVal(k)) + values = append(values, value) + } } if len(values) == 0 { diff --git a/lang/funcs/collection_test.go b/lang/funcs/collection_test.go index 26a50af2f6..5078631e2d 100644 --- a/lang/funcs/collection_test.go +++ b/lang/funcs/collection_test.go @@ -997,6 +997,17 @@ func TestKeys(t *testing.T) { }), false, }, + { // same as above, but an object type + cty.ObjectVal(map[string]cty.Value{ + "hello": cty.NumberIntVal(1), + "goodbye": cty.StringVal("adieu"), + }), + cty.ListVal([]cty.Value{ + cty.StringVal("goodbye"), + cty.StringVal("hello"), + }), + false, + }, { // Not a map cty.StringVal("foo"), cty.NilVal, @@ -1967,6 +1978,17 @@ func TestValues(t *testing.T) { }), false, }, + { + cty.ObjectVal(map[string]cty.Value{ + "hello": cty.StringVal("world"), + "what's": cty.StringVal("up"), + }), + cty.ListVal([]cty.Value{ + cty.StringVal("world"), + cty.StringVal("up"), + }), + false, + }, { // note ordering: keys are sorted first cty.MapVal(map[string]cty.Value{ "hello": cty.NumberIntVal(1),