diff --git a/internal/host/options.go b/internal/host/options.go index 6e2afe4d67..87d27c1555 100644 --- a/internal/host/options.go +++ b/internal/host/options.go @@ -19,6 +19,7 @@ type options struct { WithLimit int WithOrderByCreateTime bool Ascending bool + WithSetMembers bool } func getDefaultOptions() options { @@ -44,3 +45,13 @@ func WithOrderByCreateTime(ascending bool) Option { return nil } } + +// WithSetMembers controls whether to include set members in a lookup. This will +// always be true for static and may be true for plugin when we have caching, +// but for now this skips API calls on authentication. +func WithSetMembers(with bool) Option { + return func(o *options) error { + o.WithSetMembers = true + return nil + } +} diff --git a/internal/host/plugin/loopback.go b/internal/host/plugin/loopback.go index 7202493a49..cfd3c06fdb 100644 --- a/internal/host/plugin/loopback.go +++ b/internal/host/plugin/loopback.go @@ -2,6 +2,7 @@ package plugin import ( "context" + "fmt" "github.com/hashicorp/boundary/internal/errors" plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin" @@ -20,7 +21,7 @@ var _ plgpb.HostPluginServiceServer = (*loopbackPlugin)(nil) type loopbackPlugin struct { *TestPluginServer - hostMap map[string]*loopbackPluginHostInfo + hostMap map[string][]*loopbackPluginHostInfo } type loopbackPluginHostInfo struct { @@ -33,7 +34,7 @@ type loopbackPluginHostInfo struct { func NewLoopbackPlugin() plgpb.HostPluginServiceServer { ret := &loopbackPlugin{ TestPluginServer: new(TestPluginServer), - hostMap: make(map[string]*loopbackPluginHostInfo), + hostMap: make(map[string][]*loopbackPluginHostInfo), } ret.OnCreateCatalogFn = ret.onCreateCatalog ret.OnCreateSetFn = ret.onCreateSet @@ -71,11 +72,24 @@ func (l *loopbackPlugin) onCreateSet(ctx context.Context, req *plgpb.OnCreateSet if attrs := set.GetAttributes(); attrs != nil { attrsMap := attrs.AsMap() if field := attrsMap[loopbackPluginHostInfoAttrField]; field != nil { - hostInfo := new(loopbackPluginHostInfo) - if err := mapstructure.Decode(field, hostInfo); err != nil { - return nil, errors.Wrap(ctx, err, op) + switch t := field.(type) { + case []interface{}: + for _, h := range t { + hostInfo := new(loopbackPluginHostInfo) + if err := mapstructure.Decode(h, hostInfo); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + l.hostMap[set.GetId()] = append(l.hostMap[set.GetId()], hostInfo) + } + case map[string]interface{}: + hostInfo := new(loopbackPluginHostInfo) + if err := mapstructure.Decode(t, hostInfo); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + l.hostMap[set.GetId()] = append(l.hostMap[set.GetId()], hostInfo) + default: + return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unknown host info type %T", t)) } - l.hostMap[set.GetId()] = hostInfo } } return nil, nil @@ -101,16 +115,18 @@ func (l *loopbackPlugin) listHosts(ctx context.Context, req *plgpb.ListHostsRequ } resp := new(plgpb.ListHostsResponse) for _, set := range req.GetSets() { - hostInfo := l.hostMap[set.GetId()] - if hostInfo == nil { + hostInfos := l.hostMap[set.GetId()] + if len(hostInfos) == 0 { continue } - resp.Hosts = append(resp.Hosts, &plgpb.ListHostsResponseHost{ - SetIds: []string{set.GetId()}, - ExternalId: hostInfo.ExternalId, - IpAddresses: hostInfo.IpAddresses, - DnsNames: hostInfo.DnsNames, - }) + for _, host := range hostInfos { + resp.Hosts = append(resp.Hosts, &plgpb.ListHostsResponseHost{ + SetIds: []string{set.GetId()}, + ExternalId: host.ExternalId, + IpAddresses: host.IpAddresses, + DnsNames: host.DnsNames, + }) + } } return resp, nil } diff --git a/internal/host/plugin/loopback_test.go b/internal/host/plugin/loopback_test.go index ec73cd731e..6bd0dfcf1b 100644 --- a/internal/host/plugin/loopback_test.go +++ b/internal/host/plugin/loopback_test.go @@ -197,3 +197,142 @@ func TestLoopbackPlugin(t *testing.T) { }) } } + +func TestLoopbackPluginArrays(t *testing.T) { + require := tr.New(t) + ctx := context.Background() + + plg := NewLoopbackPlugin() + + // Add data to some sets + hostInfo1 := map[string]interface{}{ + loopbackPluginHostInfoAttrField: []interface{}{ + map[string]interface{}{ + "set_ids": []interface{}{"set1"}, + "external_id": "host1a", + "ip_addresses": []interface{}{"1.2.3.4", "2.3.4.5"}, + "dns_names": []interface{}{"foo.com"}, + }, + map[string]interface{}{ + "set_ids": []interface{}{"set1"}, + "external_id": "host1b", + "ip_addresses": []interface{}{"3.4.5.6", "4.5.6.7"}, + "dns_names": []interface{}{"bar.com"}, + }, + }, + } + attrs, err := structpb.NewStruct(hostInfo1) + require.NoError(err) + _, err = plg.OnCreateSet(ctx, &plgpb.OnCreateSetRequest{ + Set: &hostsets.HostSet{ + Id: "set1", + Attributes: attrs, + }, + }) + require.NoError(err) + hostInfo2 := map[string]interface{}{ + loopbackPluginHostInfoAttrField: []interface{}{ + map[string]interface{}{ + "set_ids": []interface{}{"set2"}, + "external_id": "host2a", + "ip_addresses": []interface{}{"10.20.30.40", "20.30.40.50"}, + "dns_names": []interface{}{"foz.com"}, + }, + map[string]interface{}{ + "set_ids": []interface{}{"set2"}, + "external_id": "host2b", + "ip_addresses": []interface{}{"30.40.50.60", "40.50.60.70"}, + "dns_names": []interface{}{"baz.com"}, + }, + }, + } + attrs, err = structpb.NewStruct(hostInfo2) + require.NoError(err) + _, err = plg.OnCreateSet(ctx, &plgpb.OnCreateSetRequest{ + Set: &hostsets.HostSet{ + Id: "set2", + Attributes: attrs, + }, + }) + require.NoError(err) + + // Define test struct and validation function + type testInfo struct { + name string + sets []string + found []interface{} + } + validateSets := func(t *testing.T, tt testInfo) { + require, assert := tr.New(t), ta.New(t) + var hostSets []*hostsets.HostSet + for _, set := range tt.sets { + hostSets = append(hostSets, &hostsets.HostSet{Id: set}) + } + resp, err := plg.ListHosts(ctx, &plgpb.ListHostsRequest{ + Sets: hostSets, + }) + require.NoError(err) + if len(tt.found) == 0 { + assert.Len(resp.GetHosts(), 0) + return + } + + require.Greater(len(resp.GetHosts()), 0) + + var found []interface{} + for _, host := range resp.GetHosts() { + hostMap := map[string]interface{}{ + "external_id": host.GetExternalId(), + } + var sets []interface{} + for _, set := range host.SetIds { + sets = append(sets, set) + } + var ips []interface{} + for _, ip := range host.GetIpAddresses() { + ips = append(ips, ip) + } + var names []interface{} + for _, name := range host.GetDnsNames() { + names = append(names, name) + } + if len(sets) > 0 { + hostMap["set_ids"] = sets + } + if len(ips) > 0 { + hostMap["ip_addresses"] = ips + } + if len(names) > 0 { + hostMap["dns_names"] = names + } + found = append(found, hostMap) + } + assert.ElementsMatch(tt.found, found) + } + + // First set of tests: check that we can look up sets individually and + // together + setTests := []testInfo{ + { + name: "set 1", + sets: []string{"set1"}, + found: hostInfo1[loopbackPluginHostInfoAttrField].([]interface{}), + }, + { + name: "set 2", + sets: []string{"set2"}, + found: hostInfo2[loopbackPluginHostInfoAttrField].([]interface{}), + }, + { + name: "sets 1 and 2", + sets: []string{"set1", "set2"}, + found: append(hostInfo1[loopbackPluginHostInfoAttrField].([]interface{}), + hostInfo2[loopbackPluginHostInfoAttrField].([]interface{})...), + }, + } + for _, tt := range setTests { + t.Run(tt.name, func(t *testing.T) { + validateSets(t, tt) + }) + } +} diff --git a/internal/host/plugin/repository_host_set.go b/internal/host/plugin/repository_host_set.go index 3abda7fc7b..61e2337f71 100644 --- a/internal/host/plugin/repository_host_set.go +++ b/internal/host/plugin/repository_host_set.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/boundary/internal/oplog" hostplugin "github.com/hashicorp/boundary/internal/plugin/host" hcpb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/hostcatalogs" + hspb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/hostsets" pb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/hostsets" plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin" "google.golang.org/grpc/codes" @@ -153,28 +154,69 @@ func (r *Repository) CreateSet(ctx context.Context, scopeId string, s *HostSet, return returnedHostSet, plg, nil } -// LookupSet will look up a host set in the repository and return the host -// set. If the host set is not found, it will return nil, nil. -// All options are ignored. -func (r *Repository) LookupSet(ctx context.Context, publicId string, opt ...host.Option) (*HostSet, *hostplugin.Plugin, error) { +// LookupSet will look up a host set in the repository and return the host set, +// as well as host IDs that match. If the host set is not found, it will return +// nil, nil, nil, nil. Supported options: WithSetMembers, which requests that +// host IDs contained within the set are looked up and returned. (In the future +// we may make it automatic to return this if it's coming from the database.) +func (r *Repository) LookupSet(ctx context.Context, publicId string, opt ...host.Option) (*HostSet, []string, *hostplugin.Plugin, error) { const op = "plugin.(Repository).LookupSet" if publicId == "" { - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "no public id") + return nil, nil, nil, errors.New(ctx, errors.InvalidParameter, op, "no public id") + } + + opts, err := host.GetOpts(opt...) + if err != nil { + return nil, nil, nil, errors.Wrap(ctx, err, op) } sets, plg, err := r.getSets(ctx, publicId, "", opt...) if err != nil { - return nil, nil, errors.Wrap(ctx, err, op) + return nil, nil, nil, errors.Wrap(ctx, err, op) } switch { case len(sets) == 0: - return nil, nil, nil // not an error to return no rows for a "lookup" + return nil, nil, nil, nil // not an error to return no rows for a "lookup" case len(sets) > 1: - return nil, nil, errors.New(ctx, errors.NotSpecificIntegrity, op, fmt.Sprintf("%s matched more than 1 ", publicId)) - default: - return sets[0], plg, nil + return nil, nil, nil, errors.New(ctx, errors.NotSpecificIntegrity, op, fmt.Sprintf("%s matched more than 1 ", publicId)) } + + setToReturn := sets[0] + var hostIdsToReturn []string + + if plg != nil && opts.WithSetMembers { + plgSet, err := toPluginSet(ctx, setToReturn) + if err != nil { + return nil, nil, nil, errors.Wrap(ctx, err, op) + } + plgClient, ok := r.plugins[plg.GetPublicId()] + if !ok { + return nil, nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("no plugin found for plugin id %s", plg.GetPublicId())) + } + resp, err := plgClient.ListHosts(ctx, &plgpb.ListHostsRequest{ + Sets: []*hspb.HostSet{plgSet}, + }) + switch { + case err != nil: + // If it's just not implemented, e.g. for tests, don't error out, return what we have + if status.Code(err) != codes.Unimplemented { + return nil, nil, nil, errors.Wrap(ctx, err, op) + } + case resp != nil: + for _, respHost := range resp.GetHosts() { + hostId, err := newHostId(ctx, setToReturn.GetCatalogId(), respHost.GetExternalId()) + if err != nil { + return nil, nil, nil, errors.Wrap(ctx, err, op) + } + hostIdsToReturn = append(hostIdsToReturn, hostId) + } + } + } + + sort.Strings(hostIdsToReturn) + + return setToReturn, hostIdsToReturn, plg, nil } // ListSets returns a slice of HostSets for the catalogId. WithLimit is the @@ -337,7 +379,7 @@ func (agg *hostSetAgg) TableName() string { return "host_plugin_host_set_with_va // toPluginSet returns a host set in the format expected by the host plugin system. func toPluginSet(ctx context.Context, in *HostSet) (*pb.HostSet, error) { - const op = "plugin.toPluginCatalog" + const op = "plugin.toPluginSet" if in == nil { return nil, errors.New(ctx, errors.InvalidParameter, op, "nil storage plugin") } @@ -363,6 +405,8 @@ func (r *Repository) Endpoints(ctx context.Context, setIds []string) ([]*host.En if len(setIds) == 0 { return nil, errors.New(ctx, errors.InvalidParameter, op, "no set ids") } + + // Fist, look up the sets corresponding to the set IDs var setAggs []*hostSetAgg if err := r.reader.SearchWhere(ctx, &setAggs, "public_id in (?)", []interface{}{setIds}); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("can't retrieve sets %v", setIds))) @@ -377,13 +421,15 @@ func (r *Repository) Endpoints(ctx context.Context, setIds []string) ([]*host.En } type catalogInfo struct { - publicId string - plg plgpb.HostPluginServiceServer - setInfos map[string]*setInfo - plgCat *hcpb.HostCatalog - persisted *plgpb.HostCatalogPersisted + publicId string // ID of the catalog + plg plgpb.HostPluginServiceServer // plugin client for the catalog + setInfos map[string]*setInfo // map of set IDs to set information + plgCat *hcpb.HostCatalog // storage host catalog + persisted *plgpb.HostCatalogPersisted // host catalog persisted (secret) data } + // Next, look up the distinct catalog info and assign set infos to it. + // Notably, this does not include persisted info. catalogInfos := make(map[string]*catalogInfo) for _, ag := range setAggs { ci, ok := catalogInfos[ag.CatalogId] @@ -416,6 +462,7 @@ func (r *Repository) Endpoints(ctx context.Context, setIds []string) ([]*host.En catalogInfos[ag.CatalogId] = ci } + // Now, look up the catalog persisted (secret) information catIds := make([]string, 0, len(catalogInfos)) for k := range catalogInfos { catIds = append(catIds, k) @@ -438,7 +485,7 @@ func (r *Repository) Endpoints(ctx context.Context, setIds []string) ([]*host.En } ci.plgCat = plgCat - // TODO: Do these looksups from the DB in bulk instead of individually. + // TODO: Do these lookups from the DB in bulk instead of individually. per, err := r.getPersistedDataForCatalog(ctx, c) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("persisted catalog lookup failed")) @@ -451,6 +498,8 @@ func (r *Repository) Endpoints(ctx context.Context, setIds []string) ([]*host.En var hosts []interface{} hostIds := map[string]bool{} var es []*host.Endpoint + + // For each distinct catalog, list all sets at once for _, ci := range catalogInfos { var sets []*pb.HostSet @@ -459,9 +508,9 @@ func (r *Repository) Endpoints(ctx context.Context, setIds []string) ([]*host.En } resp, err := ci.plg.ListHosts(ctx, &plgpb.ListHostsRequest{ - Catalog: ci.plgCat, - Sets: sets, - // Persisted: ci.persisted, + Catalog: ci.plgCat, + Sets: sets, + Persisted: ci.persisted, }) if err != nil { return nil, errors.Wrap(ctx, err, op) diff --git a/internal/host/plugin/repository_host_set_test.go b/internal/host/plugin/repository_host_set_test.go index 3327ab13b1..ffd6dba73b 100644 --- a/internal/host/plugin/repository_host_set_test.go +++ b/internal/host/plugin/repository_host_set_test.go @@ -352,7 +352,7 @@ func TestRepository_LookupSet(t *testing.T) { repo, err := NewRepository(rw, rw, kms, plgm) assert.NoError(err) require.NotNil(repo) - got, _, err := repo.LookupSet(ctx, tt.in) + got, _, _, err := repo.LookupSet(ctx, tt.in) if tt.wantIsErr != 0 { assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err: %q got: %q", tt.wantIsErr, err) assert.Nil(got) diff --git a/internal/servers/controller/handlers/host_sets/host_set_service.go b/internal/servers/controller/handlers/host_sets/host_set_service.go index 8cae153bbd..899d6b5cb5 100644 --- a/internal/servers/controller/handlers/host_sets/host_set_service.go +++ b/internal/servers/controller/handlers/host_sets/host_set_service.go @@ -9,6 +9,7 @@ import ( pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services" "github.com/hashicorp/boundary/internal/host" "github.com/hashicorp/boundary/internal/host/plugin" + plugstore "github.com/hashicorp/boundary/internal/host/plugin/store" "github.com/hashicorp/boundary/internal/host/static" "github.com/hashicorp/boundary/internal/host/static/store" "github.com/hashicorp/boundary/internal/libs/endpoint" @@ -422,7 +423,7 @@ func (s Service) getFromRepo(ctx context.Context, id string) (host.Set, []host.H if err != nil { return nil, nil, nil, err } - hset, hsplg, err := repo.LookupSet(ctx, id) + hset, hostIds, hsplg, err := repo.LookupSet(ctx, id, host.WithSetMembers(true)) if err != nil { return nil, nil, nil, err } @@ -431,6 +432,14 @@ func (s Service) getFromRepo(ctx context.Context, id string) (host.Set, []host.H } hs = hset plg = toPluginInfo(hsplg) + for _, h := range hostIds { + hl = append(hl, &plugin.Host{ + Host: &plugstore.Host{ + PublicId: h, + CatalogId: hset.CatalogId, + }, + }) + } } return hs, hl, plg, nil } @@ -668,7 +677,7 @@ func (s Service) parentAndAuthResult(ctx context.Context, id string, a action.Ty } set = ss case plugin.Subtype: - ps, _, err := pluginRepo.LookupSet(ctx, id) + ps, _, _, err := pluginRepo.LookupSet(ctx, id) if err != nil { res.Error = err return nil, res diff --git a/internal/servers/controller/handlers/host_sets/host_set_service_test.go b/internal/servers/controller/handlers/host_sets/host_set_service_test.go index a772e29cc9..d8b6eebd93 100644 --- a/internal/servers/controller/handlers/host_sets/host_set_service_test.go +++ b/internal/servers/controller/handlers/host_sets/host_set_service_test.go @@ -148,17 +148,19 @@ func TestGet_Plugin(t *testing.T) { repoFn := func() (*static.Repository, error) { return static.NewRepository(rw, rw, kms) } - pluginRepoFn := func() (*plugin.Repository, error) { - return plugin.NewRepository(rw, rw, kms, map[string]plgpb.HostPluginServiceServer{}) - } name := "test" prefEndpoints := []string{"cidr:1.2.3.4", "cidr:2.3.4.5/24"} plg := hostplugin.TestPlugin(t, conn, name) - hc := plugin.TestCatalog(t, conn, proj.GetPublicId(), plg.GetPublicId()) - hs := plugin.TestSet(t, conn, kms, hc, map[string]plgpb.HostPluginServiceServer{ + plgm := map[string]plgpb.HostPluginServiceServer{ plg.GetPublicId(): &plugin.TestPluginServer{}, - }, plugin.WithPreferredEndpoints(prefEndpoints)) + } + pluginRepoFn := func() (*plugin.Repository, error) { + return plugin.NewRepository(rw, rw, kms, plgm) + } + + hc := plugin.TestCatalog(t, conn, proj.GetPublicId(), plg.GetPublicId()) + hs := plugin.TestSet(t, conn, kms, hc, plgm, plugin.WithPreferredEndpoints(prefEndpoints)) toMerge := &pbs.GetHostSetRequest{}