From 33b0021547ce0e5dd51016b127c20feff4f6e543 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Mon, 21 Sep 2020 02:22:59 -0400 Subject: [PATCH] Add Sessions CLI command and add session cleanup to worker (#388) --- internal/cmd/commands.go | 25 +++ .../cmd/commands/authtokens/authtokens.go | 14 +- internal/cmd/commands/authtokens/funcs.go | 37 ++-- internal/cmd/commands/proxy/funcs.go | 1 + internal/cmd/commands/proxy/proxy.go | 10 +- internal/cmd/commands/sessions/funcs.go | 96 ++++++++++ internal/cmd/commands/sessions/session.go | 177 ++++++++++++++++++ internal/gen/controller.swagger.json | 20 ++ .../api/services/session_service.pb.go | 25 +-- .../api/services/session_service.pb.gw.go | 34 ++-- internal/perms/grants_test.go | 2 +- .../api/services/v1/session_service.proto | 29 +-- internal/servers/worker/handler.go | 4 +- internal/servers/worker/session.go | 41 ++-- internal/servers/worker/status.go | 60 ++++-- internal/types/resource/resource.go | 2 + internal/types/resource/resource_test.go | 4 + sdk/strutil/strutil.go | 4 +- 18 files changed, 491 insertions(+), 94 deletions(-) create mode 100644 internal/cmd/commands/sessions/funcs.go create mode 100644 internal/cmd/commands/sessions/session.go diff --git a/internal/cmd/commands.go b/internal/cmd/commands.go index 63756d3db7..2981021c6c 100644 --- a/internal/cmd/commands.go +++ b/internal/cmd/commands.go @@ -20,6 +20,7 @@ import ( "github.com/hashicorp/boundary/internal/cmd/commands/proxy" "github.com/hashicorp/boundary/internal/cmd/commands/roles" "github.com/hashicorp/boundary/internal/cmd/commands/scopes" + "github.com/hashicorp/boundary/internal/cmd/commands/sessions" "github.com/hashicorp/boundary/internal/cmd/commands/targets" "github.com/hashicorp/boundary/internal/cmd/commands/users" "github.com/hashicorp/boundary/internal/cmd/commands/worker" @@ -560,6 +561,30 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) { }, nil }, + "sessions": func() (cli.Command, error) { + return &sessions.Command{ + Command: base.NewCommand(ui), + }, nil + }, + "sessions read": func() (cli.Command, error) { + return &sessions.Command{ + Command: base.NewCommand(ui), + Func: "read", + }, nil + }, + "sessions list": func() (cli.Command, error) { + return &sessions.Command{ + Command: base.NewCommand(ui), + Func: "list", + }, nil + }, + "sessions cancel": func() (cli.Command, error) { + return &sessions.Command{ + Command: base.NewCommand(ui), + Func: "cancel", + }, nil + }, + "targets": func() (cli.Command, error) { return &targets.Command{ Command: base.NewCommand(ui), diff --git a/internal/cmd/commands/authtokens/authtokens.go b/internal/cmd/commands/authtokens/authtokens.go index d26f905ccb..f3cc41a7d5 100644 --- a/internal/cmd/commands/authtokens/authtokens.go +++ b/internal/cmd/commands/authtokens/authtokens.go @@ -169,13 +169,13 @@ func (c *Command) Run(args []string) int { output = append(output, "") } output = append(output, - fmt.Sprintf(" ID: %s", t.Id), - fmt.Sprintf(" Approximate Last Used Time: %s", t.ApproximateLastUsedTime.Local().Format(time.RFC3339)), - fmt.Sprintf(" Auth Method ID: %s", t.AuthMethodId), - fmt.Sprintf(" Created Time: %s", t.CreatedTime.Local().Format(time.RFC3339)), - fmt.Sprintf(" Expiration Time: %s", t.ExpirationTime.Local().Format(time.RFC3339)), - fmt.Sprintf(" Updated Time: %s", t.UpdatedTime.Local().Format(time.RFC3339)), - fmt.Sprintf(" User ID: %s", t.UserId), + fmt.Sprintf(" ID: %s", t.Id), + fmt.Sprintf(" Approximate Last Used Time: %s", t.ApproximateLastUsedTime.Local().Format(time.RFC3339)), + fmt.Sprintf(" Auth Method ID: %s", t.AuthMethodId), + fmt.Sprintf(" Created Time: %s", t.CreatedTime.Local().Format(time.RFC3339)), + fmt.Sprintf(" Expiration Time: %s", t.ExpirationTime.Local().Format(time.RFC3339)), + fmt.Sprintf(" Updated Time: %s", t.UpdatedTime.Local().Format(time.RFC3339)), + fmt.Sprintf(" User ID: %s", t.UserId), ) } c.UI.Output(base.WrapForHelpText(output)) diff --git a/internal/cmd/commands/authtokens/funcs.go b/internal/cmd/commands/authtokens/funcs.go index 81b8493386..2c9dc3e21a 100644 --- a/internal/cmd/commands/authtokens/funcs.go +++ b/internal/cmd/commands/authtokens/funcs.go @@ -1,7 +1,6 @@ package authtokens import ( - "fmt" "time" "github.com/hashicorp/boundary/api/authtokens" @@ -9,19 +8,29 @@ import ( ) func generateAuthTokenTableOutput(in *authtokens.AuthToken) string { - var ret []string - ret = append(ret, []string{ - "", - "Auth Token information:", - fmt.Sprintf(" Approximate Last Used Time: %s", in.ApproximateLastUsedTime.Local().Format(time.RFC3339)), - fmt.Sprintf(" Auth Method ID: %s", in.AuthMethodId), - fmt.Sprintf(" Created Time: %s", in.CreatedTime.Local().Format(time.RFC3339)), - fmt.Sprintf(" Expiration Time: %s", in.ExpirationTime.Local().Format(time.RFC3339)), - fmt.Sprintf(" ID: %s", in.Id), - fmt.Sprintf(" Scope ID: %s", in.Scope.Id), - fmt.Sprintf(" Updated Time: %s", in.UpdatedTime.Local().Format(time.RFC3339)), - fmt.Sprintf(" User ID: %s", in.UserId), - }..., + nonAttributeMap := map[string]interface{}{ + "ID": in.Id, + "Scope ID": in.Scope.Id, + "Auth Method ID": in.AuthMethodId, + "User ID": in.UserId, + "Created Time": in.CreatedTime.Local().Format(time.RFC3339), + "Updated Time": in.UpdatedTime.Local().Format(time.RFC3339), + "Expiration Time": in.ExpirationTime.Local().Format(time.RFC3339), + "Approximate Last Used Time": in.ApproximateLastUsedTime.Local().Format(time.RFC3339), + } + + maxLength := 0 + for k := range nonAttributeMap { + if len(k) > maxLength { + maxLength = len(k) + } + } + + ret := []string{"", "Auth Token information:"} + + ret = append(ret, + // We do +2 because there is another +2 offset for host sets below + base.WrapMap(2, maxLength+2, nonAttributeMap), ) return base.WrapForHelpText(ret) diff --git a/internal/cmd/commands/proxy/funcs.go b/internal/cmd/commands/proxy/funcs.go index e2570f4415..6a33de3469 100644 --- a/internal/cmd/commands/proxy/funcs.go +++ b/internal/cmd/commands/proxy/funcs.go @@ -10,6 +10,7 @@ func generateSessionInfoTableOutput(in SessionInfo) string { var ret []string nonAttributeMap := map[string]interface{}{ + "Session ID": in.SessionId, "Protocol": in.Protocol, "Address": in.Address, "Port": in.Port, diff --git a/internal/cmd/commands/proxy/proxy.go b/internal/cmd/commands/proxy/proxy.go index aaf9093349..8a653bce5a 100644 --- a/internal/cmd/commands/proxy/proxy.go +++ b/internal/cmd/commands/proxy/proxy.go @@ -40,6 +40,7 @@ type SessionInfo struct { Protocol string `json:"protocol"` Expiration time.Time `json:"expiration"` ConnectionLimit int32 `json:"connection_limit"` + SessionId string `json:"session_id"` } type ConnectionInfo struct { @@ -348,6 +349,7 @@ func (c *Command) Run(args []string) (retCode int) { Port: listenerAddr.Port, Expiration: c.expiration, ConnectionLimit: data.GetConnectionLimit(), + SessionId: data.GetSessionId(), } switch base.Format(c.UI) { @@ -365,7 +367,6 @@ func (c *Command) Run(args []string) (retCode int) { c.connWg = new(sync.WaitGroup) c.connWg.Add(1) - go func() { defer c.connWg.Done() for { @@ -483,6 +484,13 @@ func (c *Command) handleConnection( } var handshakeResult proxy.HandshakeResult if err := wspb.Read(c.Context, conn, &handshakeResult); err != nil { + switch { + case strings.Contains(err.Error(), "unable to authorize connection"): + // There's no reason to think we'd be able to authorize any more + // connections after the first has failed + c.connsLeftCh <- 0 + return errors.New("Unable to authorize connection") + } return fmt.Errorf("error reading handshake result: %w", err) } diff --git a/internal/cmd/commands/sessions/funcs.go b/internal/cmd/commands/sessions/funcs.go new file mode 100644 index 0000000000..e532e34423 --- /dev/null +++ b/internal/cmd/commands/sessions/funcs.go @@ -0,0 +1,96 @@ +package sessions + +import ( + "fmt" + "time" + + "github.com/hashicorp/boundary/api/sessions" + "github.com/hashicorp/boundary/internal/cmd/base" +) + +func generateSessionTableOutput(in *sessions.Session) string { + nonAttributeMap := map[string]interface{}{ + "ID": in.Id, + "Target ID": in.TargetId, + "Scope ID": in.Scope.Id, + "Created Time": in.CreatedTime.Local().Format(time.RFC3339), + "Updated Time": in.UpdatedTime.Local().Format(time.RFC3339), + "Expiration Time": in.ExpirationTime.Local().Format(time.RFC3339), + "Version": in.Version, + "Type": in.Type, + "Auth Token ID": in.AuthTokenId, + "User ID": in.UserId, + "Host Set ID": in.HostSetId, + "Host ID": in.HostId, + "Endpoint": in.Endpoint, + "Status": in.Status, + } + + maxLength := 0 + for k := range nonAttributeMap { + if len(k) > maxLength { + maxLength = len(k) + } + } + + var statesMaps []map[string]interface{} + if len(in.States) > 0 { + for _, state := range in.States { + m := map[string]interface{}{ + "Status": state.Status, + "Start Time": state.StartTime.Local().Format(time.RFC3339), + "End Time": state.EndTime.Local().Format(time.RFC3339), + } + statesMaps = append(statesMaps, m) + } + if l := len("Start Time"); l > maxLength { + maxLength = l + } + } + + var workerInfoMaps []map[string]interface{} + if len(in.WorkerInfo) > 0 { + for _, wi := range in.WorkerInfo { + m := map[string]interface{}{ + "Address": wi.Address, + } + workerInfoMaps = append(workerInfoMaps, m) + } + if l := len("Address"); l > maxLength { + maxLength = l + } + } + + ret := []string{"", "Session information:"} + + ret = append(ret, + // We do +2 because there is another +2 offset for host sets below + base.WrapMap(2, maxLength+2, nonAttributeMap), + ) + + if len(in.States) > 0 { + ret = append(ret, + fmt.Sprintf(" States: %s", ""), + ) + for _, m := range statesMaps { + ret = append(ret, + base.WrapMap(4, maxLength, m), + "", + ) + } + } + + if len(in.WorkerInfo) > 0 { + ret = append(ret, + fmt.Sprintf(" Worker Info: %s", ""), + ) + for _, m := range workerInfoMaps { + ret = append(ret, + base.WrapMap(4, maxLength, m), + "", + ) + } + } + + return base.WrapForHelpText(ret) +} diff --git a/internal/cmd/commands/sessions/session.go b/internal/cmd/commands/sessions/session.go new file mode 100644 index 0000000000..41e7eb6616 --- /dev/null +++ b/internal/cmd/commands/sessions/session.go @@ -0,0 +1,177 @@ +package sessions + +import ( + "fmt" + "time" + + "github.com/hashicorp/boundary/api" + "github.com/hashicorp/boundary/api/sessions" + "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/boundary/internal/cmd/common" + "github.com/hashicorp/boundary/internal/types/resource" + "github.com/hashicorp/boundary/sdk/strutil" + "github.com/kr/pretty" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +var _ cli.Command = (*Command)(nil) +var _ cli.CommandAutocomplete = (*Command)(nil) + +type Command struct { + *base.Command + + Func string +} + +func (c *Command) Synopsis() string { + return common.SynopsisFunc(c.Func, "session") +} + +var flagsMap = map[string][]string{ + "read": {"id"}, + "cancel": {"id"}, + "list": {"scope-id"}, +} + +func (c *Command) Help() string { + helpMap := common.HelpMap("session") + if c.Func == "" { + return helpMap["base"]() + } + return helpMap[c.Func]() + c.Flags().Help() +} + +func (c *Command) Flags() *base.FlagSets { + set := c.FlagSet(base.FlagSetHTTP | base.FlagSetClient | base.FlagSetOutputFormat) + + if len(flagsMap[c.Func]) > 0 { + f := set.NewFlagSet("Command Options") + common.PopulateCommonFlags(c.Command, f, resource.Session.String(), flagsMap[c.Func]) + } + + return set +} + +func (c *Command) AutocompleteArgs() complete.Predictor { + return complete.PredictAnything +} + +func (c *Command) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *Command) Run(args []string) int { + if c.Func == "" { + return cli.RunResultHelp + } + + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + if strutil.StrListContains(flagsMap[c.Func], "id") && c.FlagId == "" { + c.UI.Error("ID is required but not passed in via -id") + return 1 + } + if strutil.StrListContains(flagsMap[c.Func], "scope-id") && c.FlagScopeId == "" { + c.UI.Error("Scope ID must be passed in via -scope-id") + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(fmt.Sprintf("Error creating API client: %s", err.Error())) + return 2 + } + + sessionClient := sessions.NewClient(client) + + var result api.GenericResult + var listResult api.GenericListResult + var apiErr *api.Error + + switch c.Func { + case "read": + result, apiErr, err = sessionClient.Read(c.Context, c.FlagId) + case "cancel": + result, apiErr, err = sessionClient.Cancel(c.Context, c.FlagId, 0, sessions.WithAutomaticVersioning(true)) + case "list": + listResult, apiErr, err = sessionClient.List(c.Context, c.FlagScopeId) + } + + plural := "session" + if c.Func == "list" { + plural = "sessions" + } + if err != nil { + c.UI.Error(fmt.Sprintf("Error trying to %s %s: %s", c.Func, plural, err.Error())) + return 2 + } + if apiErr != nil { + c.UI.Error(fmt.Sprintf("Error from controller when performing %s on %s: %s", c.Func, plural, pretty.Sprint(apiErr))) + return 1 + } + + switch c.Func { + case "list": + listedSessions := listResult.GetItems().([]*sessions.Session) + switch base.Format(c.UI) { + case "json": + if len(listedSessions) == 0 { + c.UI.Output("null") + return 0 + } + b, err := base.JsonFormatter{}.Format(listedSessions) + if err != nil { + c.UI.Error(fmt.Errorf("Error formatting as JSON: %w", err).Error()) + return 1 + } + c.UI.Output(string(b)) + + case "table": + if len(listedSessions) == 0 { + c.UI.Output("No auth tokens found") + return 0 + } + var output []string + output = []string{ + "", + "Session information:", + } + for i, t := range listedSessions { + if i > 0 { + output = append(output, "") + } + output = append(output, + fmt.Sprintf(" ID: %s", t.Id), + fmt.Sprintf(" Created Time: %s", t.CreatedTime.Local().Format(time.RFC3339)), + fmt.Sprintf(" Expiration Time: %s", t.ExpirationTime.Local().Format(time.RFC3339)), + fmt.Sprintf(" Updated Time: %s", t.UpdatedTime.Local().Format(time.RFC3339)), + fmt.Sprintf(" User ID: %s", t.UserId), + fmt.Sprintf(" Target ID: %s", t.UserId), + ) + } + c.UI.Output(base.WrapForHelpText(output)) + } + return 0 + } + + sess := result.GetItem().(*sessions.Session) + switch base.Format(c.UI) { + case "table": + c.UI.Output(generateSessionTableOutput(sess)) + case "json": + b, err := base.JsonFormatter{}.Format(sess) + if err != nil { + c.UI.Error(fmt.Errorf("Error formatting as JSON: %w", err).Error()) + return 1 + } + c.UI.Output(string(b)) + } + + return 0 +} diff --git a/internal/gen/controller.swagger.json b/internal/gen/controller.swagger.json index f9958f0328..75c45d03c7 100644 --- a/internal/gen/controller.swagger.json +++ b/internal/gen/controller.swagger.json @@ -1895,6 +1895,14 @@ "in": "path", "required": true, "type": "string" + }, + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/controller.api.services.v1.CancelSessionRequest" + } } ], "tags": [ @@ -3465,6 +3473,18 @@ } } }, + "controller.api.services.v1.CancelSessionRequest": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "version": { + "type": "integer", + "format": "int64" + } + } + }, "controller.api.services.v1.CancelSessionResponse": { "type": "object", "properties": { diff --git a/internal/gen/controller/api/services/session_service.pb.go b/internal/gen/controller/api/services/session_service.pb.go index 95f06f2b48..c398382efb 100644 --- a/internal/gen/controller/api/services/session_service.pb.go +++ b/internal/gen/controller/api/services/session_service.pb.go @@ -386,7 +386,7 @@ var file_controller_api_services_v1_session_service_proto_rawDesc = []byte{ 0x32, 0x2d, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x2e, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, - 0x04, 0x69, 0x74, 0x65, 0x6d, 0x32, 0x8f, 0x04, 0x0a, 0x0e, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, + 0x04, 0x69, 0x74, 0x65, 0x6d, 0x32, 0x92, 0x04, 0x0a, 0x0e, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0xa6, 0x01, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x2d, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, @@ -408,23 +408,24 @@ var file_controller_api_services_v1_session_service_proto_rawDesc = []byte{ 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x2b, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x0e, 0x12, 0x0c, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x92, 0x41, 0x14, 0x12, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x73, 0x20, 0x61, 0x6c, 0x6c, 0x20, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x73, 0x12, 0xb2, 0x01, 0x0a, 0x0d, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x53, 0x65, 0x73, + 0x6e, 0x73, 0x12, 0xb5, 0x01, 0x0a, 0x0d, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x30, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x31, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, - 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3c, 0x82, 0xd3, 0xe4, 0x93, 0x02, - 0x20, 0x22, 0x18, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2f, - 0x7b, 0x69, 0x64, 0x7d, 0x3a, 0x63, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x62, 0x04, 0x69, 0x74, 0x65, - 0x6d, 0x92, 0x41, 0x13, 0x12, 0x11, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x73, 0x20, 0x61, 0x20, - 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x4d, 0x5a, 0x4b, 0x67, 0x69, 0x74, 0x68, 0x75, - 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, - 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, - 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, - 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x3b, 0x73, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3f, 0x82, 0xd3, 0xe4, 0x93, 0x02, + 0x23, 0x22, 0x18, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2f, + 0x7b, 0x69, 0x64, 0x7d, 0x3a, 0x63, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x3a, 0x01, 0x2a, 0x62, 0x04, + 0x69, 0x74, 0x65, 0x6d, 0x92, 0x41, 0x13, 0x12, 0x11, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x73, + 0x20, 0x61, 0x20, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x4d, 0x5a, 0x4b, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, + 0x72, 0x70, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2f, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, + 0x6c, 0x65, 0x72, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, + 0x3b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( diff --git a/internal/gen/controller/api/services/session_service.pb.gw.go b/internal/gen/controller/api/services/session_service.pb.gw.go index 949e9def9b..97007da5e9 100644 --- a/internal/gen/controller/api/services/session_service.pb.gw.go +++ b/internal/gen/controller/api/services/session_service.pb.gw.go @@ -135,14 +135,18 @@ func local_request_SessionService_ListSessions_0(ctx context.Context, marshaler } -var ( - filter_SessionService_CancelSession_0 = &utilities.DoubleArray{Encoding: map[string]int{"id": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} -) - func request_SessionService_CancelSession_0(ctx context.Context, marshaler runtime.Marshaler, client SessionServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var protoReq CancelSessionRequest var metadata runtime.ServerMetadata + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + var ( val string ok bool @@ -160,13 +164,6 @@ func request_SessionService_CancelSession_0(ctx context.Context, marshaler runti return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "id", err) } - if err := req.ParseForm(); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_SessionService_CancelSession_0); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - msg, err := client.CancelSession(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err @@ -176,6 +173,14 @@ func local_request_SessionService_CancelSession_0(ctx context.Context, marshaler var protoReq CancelSessionRequest var metadata runtime.ServerMetadata + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + var ( val string ok bool @@ -193,13 +198,6 @@ func local_request_SessionService_CancelSession_0(ctx context.Context, marshaler return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "id", err) } - if err := req.ParseForm(); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_SessionService_CancelSession_0); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - msg, err := server.CancelSession(ctx, &protoReq) return msg, metadata, err diff --git a/internal/perms/grants_test.go b/internal/perms/grants_test.go index dc7014ca3f..8e48ef4e87 100644 --- a/internal/perms/grants_test.go +++ b/internal/perms/grants_test.go @@ -125,7 +125,7 @@ func Test_ValidateType(t *testing.T) { } } -func Test_MarshallingAndCloning(t *testing.T) { +func Test_MarshalingAndCloning(t *testing.T) { t.Parallel() type input struct { diff --git a/internal/proto/local/controller/api/services/v1/session_service.proto b/internal/proto/local/controller/api/services/v1/session_service.proto index f6eb2a407b..c908cd4875 100644 --- a/internal/proto/local/controller/api/services/v1/session_service.proto +++ b/internal/proto/local/controller/api/services/v1/session_service.proto @@ -16,12 +16,12 @@ service SessionService { // resource an error is returned. rpc GetSession(GetSessionRequest) returns (GetSessionResponse) { option (google.api.http) = { - get: "/v1/sessions/{id}" - response_body: "item" - }; + get: "/v1/sessions/{id}" + response_body: "item" + }; option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { - summary: "Gets a single Session" - }; + summary: "Gets a single Session" + }; } // ListSessions returns a list of stored sessions which exist inside the project @@ -30,11 +30,11 @@ service SessionService { // reference a non existing scope, an error is returned. rpc ListSessions(ListSessionsRequest) returns (ListSessionsResponse) { option (google.api.http) = { - get: "/v1/sessions" - }; + get: "/v1/sessions" + }; option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { - summary: "Lists all Sessions" - }; + summary: "Lists all Sessions" + }; } // CancelSession cancels an existing session in boundary. An error @@ -42,12 +42,13 @@ service SessionService { // not exist. rpc CancelSession(CancelSessionRequest) returns (CancelSessionResponse) { option (google.api.http) = { - post: "/v1/sessions/{id}:cancel" - response_body: "item" - }; + post: "/v1/sessions/{id}:cancel" + body: "*" + response_body: "item" + }; option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { - summary: "Cancels a Session" - }; + summary: "Cancels a Session" + }; } } diff --git a/internal/servers/worker/handler.go b/internal/servers/worker/handler.go index 23e58d55da..65a9a241b8 100644 --- a/internal/servers/worker/handler.go +++ b/internal/servers/worker/handler.go @@ -139,7 +139,9 @@ func (w *Worker) handleProxy() http.HandlerFunc { defer func() { connectionId := ci.id - if err := w.closeConnections(r.Context(), si, []string{connectionId}); err != nil { + if err := w.closeConnections(r.Context(), map[string]string{ + connectionId: si.id, + }); err != nil { w.logger.Error("error marking connection closed", "error", err, "connection_id", connectionId) } }() diff --git a/internal/servers/worker/session.go b/internal/servers/worker/session.go index be60ed27a0..b9c54291dc 100644 --- a/internal/servers/worker/session.go +++ b/internal/servers/worker/session.go @@ -24,10 +24,12 @@ type connInfo struct { connCtx context.Context connCancel context.CancelFunc status pbs.CONNECTIONSTATUS + closeTime time.Time } type sessionInfo struct { sync.RWMutex + id string sessionTls *tls.Config status pbs.SESSIONSTATUS lookupSessionResponse *pbs.LookupSessionResponse @@ -102,6 +104,7 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error) } si := &sessionInfo{ + id: resp.GetAuthorization().GetSessionId(), sessionTls: tlsConf, lookupSessionResponse: resp, status: resp.GetStatus(), @@ -224,13 +227,13 @@ func (w *Worker) closeConnection(ctx context.Context, req *pbs.CloseConnectionRe return resp, nil } -func (w *Worker) closeConnections(ctx context.Context, si *sessionInfo, connectionIds []string) error { - w.logger.Trace("marking connection as closed", "connection_ids", connectionIds) +func (w *Worker) closeConnections(ctx context.Context, closeMap map[string]string) error { + w.logger.Trace("marking connections as closed", "session_and_connection_ids", fmt.Sprint("%#v", closeMap)) - closeData := make([]*pbs.CloseConnectionRequestData, 0, len(connectionIds)) - for _, v := range connectionIds { + closeData := make([]*pbs.CloseConnectionRequestData, 0, len(closeMap)) + for connId := range closeMap { closeData = append(closeData, &pbs.CloseConnectionRequestData{ - ConnectionId: v, + ConnectionId: connId, Reason: session.UnknownReason.String(), }) } @@ -243,17 +246,29 @@ func (w *Worker) closeConnections(ctx context.Context, si *sessionInfo, connecti return err } closedIds := make([]string, 0, len(connStatus.GetCloseResponseData())) - si.Lock() + + // Here we build a reverse map from closeMap, that is, session ID to + // connection IDs, for more efficient locking + revMap := make(map[string][]*pbs.CloseConnectionResponseData) for _, v := range connStatus.GetCloseResponseData() { - closedIds = append(closedIds, v.GetConnectionId()) - if v.GetStatus() == pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED { - delete(si.connInfoMap, v.GetConnectionId()) - } else { - ci := si.connInfoMap[v.GetConnectionId()] - ci.status = v.GetStatus() + revMap[closeMap[v.GetConnectionId()]] = append(revMap[closeMap[v.GetConnectionId()]], v) + } + for k, v := range revMap { + siRaw, ok := w.sessionInfoMap.Load(k) + if !ok { + w.logger.Warn("could not find session ID in info map after closing connections", "session_id", k) + } + si := siRaw.(*sessionInfo) + si.Lock() + for _, connResult := range v { + ci := si.connInfoMap[connResult.GetConnectionId()] + ci.status = connResult.GetStatus() + if ci.status == pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED { + ci.closeTime = time.Now() + } } + si.Unlock() } - si.Unlock() w.logger.Trace("connections successfully marked closed", "connection_ids", closedIds) return nil } diff --git a/internal/servers/worker/status.go b/internal/servers/worker/status.go index 065eaf6266..3faee341a6 100644 --- a/internal/servers/worker/status.go +++ b/internal/servers/worker/status.go @@ -45,6 +45,8 @@ func (w *Worker) startStatusTicking(cancelCtx context.Context) { return case <-timer.C: + // First send info as-is. We'll perform cleanup duties after we + // get cancel/job change info back. var activeJobs []*pbs.JobStatus w.sessionInfoMap.Range(func(key, value interface{}) bool { var jobInfo pbs.SessionJobInfo @@ -119,24 +121,60 @@ func (w *Worker) startStatusTicking(cancelCtx context.Context) { continue } si := siRaw.(*sessionInfo) - closeIds := make([]string, 0, len(result.GetJobsRequests())) si.Lock() si.status = sessInfo.GetStatus() - if request.GetRequestType() == pbs.CHANGETYPE_CHANGETYPE_CANCEL { - for k, v := range si.connInfoMap { - v.connCancel() - w.logger.Info("terminated connection", "session_id", sessionId, "connection_id", k) - closeIds = append(closeIds, k) - } - } si.Unlock() - if err := w.closeConnections(cancelCtx, si, closeIds); err != nil { - w.logger.Error("error marking connections closed", "error", err, "connection_ids", closeIds) - } } } } } + + // Cleanup: Run through current jobs. Cancel connections for any + // canceling session or any session that is expired. Clear out + // sessions that are canceled or expired with all connections + // marked as closed. Close any that aren't marked as such. + closeInfo := make(map[string]string) + cleanSessionIds := make([]string, 0) + w.sessionInfoMap.Range(func(key, value interface{}) bool { + si := value.(*sessionInfo) + si.Lock() + if time.Until(si.lookupSessionResponse.Expiration.AsTime()) < 0 || + si.status == pbs.SESSIONSTATUS_SESSIONSTATUS_CANCELLING { + var toClose int + for k, v := range si.connInfoMap { + if v.closeTime.IsZero() { + toClose++ + v.connCancel() + w.logger.Info("terminated connection due to cancelation or expiration", "session_id", si.id, "connection_id", k) + closeInfo[k] = si.id + } + } + // closeTime is marked by closeConnections iff the + // status is returned for that connection as closed. If + // the session is no longer valid and all connections + // are marked closed, clean up the session. + if toClose == 0 { + cleanSessionIds = append(cleanSessionIds, si.id) + } + } + si.Unlock() + return true + }) + + // Note that we won't clean these from the info map until the + // next time we run this function + if len(closeInfo) > 0 { + if err := w.closeConnections(cancelCtx, closeInfo); err != nil { + w.logger.Error("error marking connections closed", "error", err) + } + } + + // Forget sessions where the session is expired/canceled and all + // connections are canceled and marked closed + for _, v := range cleanSessionIds { + w.sessionInfoMap.Delete(v) + } + timer.Reset(getRandomInterval()) } } diff --git a/internal/types/resource/resource.go b/internal/types/resource/resource.go index 9cde3413f2..8902d843f7 100644 --- a/internal/types/resource/resource.go +++ b/internal/types/resource/resource.go @@ -39,6 +39,7 @@ func (r Type) String() string { "target", "controller", "worker", + "session", }[r] } @@ -58,4 +59,5 @@ var Map = map[string]Type{ Target.String(): Target, Controller.String(): Controller, Worker.String(): Worker, + Session.String(): Session, } diff --git a/internal/types/resource/resource_test.go b/internal/types/resource/resource_test.go index 2e61a68c5d..bf5643e271 100644 --- a/internal/types/resource/resource_test.go +++ b/internal/types/resource/resource_test.go @@ -72,6 +72,10 @@ func Test_Resource(t *testing.T) { typeString: "worker", want: Worker, }, + { + typeString: "session", + want: Session, + }, } for _, tt := range tests { t.Run(tt.typeString, func(t *testing.T) { diff --git a/sdk/strutil/strutil.go b/sdk/strutil/strutil.go index 0069a1377b..6a3be0451b 100644 --- a/sdk/strutil/strutil.go +++ b/sdk/strutil/strutil.go @@ -128,7 +128,7 @@ func ParseArbitraryKeyValues(input string, out map[string]string, sep string) er // metadata was supplied as JSON input. err = json.Unmarshal([]byte(input), &out) if err != nil { - // If JSON unmarshalling fails, consider that the input was + // If JSON unmarshaling fails, consider that the input was // supplied as a comma separated string of 'key=value' pairs. if err = ParseKeyValues(input, out, sep); err != nil { return errwrap.Wrapf("failed to parse the input: {{err}}", err) @@ -194,7 +194,7 @@ func ParseArbitraryStringSlice(input string, sep string) []string { // metadata was supplied as JSON input. err = json.Unmarshal([]byte(input), &ret) if err != nil { - // If JSON unmarshalling fails, consider that the input was + // If JSON unmarshaling fails, consider that the input was // supplied as a separated string of values. return ParseStringSlice(input, sep) }