From 223591d835ad644f3a7d7c4e282690b47a848de3 Mon Sep 17 00:00:00 2001 From: Jim Date: Sat, 19 Sep 2020 21:05:56 -0400 Subject: [PATCH] return connection authz info from session.AuthorizeConnection (#380) --- internal/session/query.go | 43 +++++++++++++----- internal/session/repository_session.go | 49 ++++++++++++++++++--- internal/session/repository_session_test.go | 46 ++++++++++++------- 3 files changed, 103 insertions(+), 35 deletions(-) diff --git a/internal/session/query.go b/internal/session/query.go index a2ecca0c4e..898a749fc7 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -66,7 +66,11 @@ with active_session as ( -- check that the session hasn't expired. s.expiration_time > now() and -- check that there are still connections available. connection_limit of 0 equals unlimited connections - (s.connection_limit = 0 or s.connection_limit > (select count(*) from session_connection sc where sc.session_id = $1)) and + ( + s.connection_limit = 0 + or + s.connection_limit > (select count(*) from session_connection sc where sc.session_id = $1) + ) and -- check that there's a state of active s.public_id in ( select @@ -75,19 +79,34 @@ with active_session as ( session_state ss where ss.session_id = $1 and - ss.state = 'active' - ) and - -- check that there are no cancelling or terminated states - s.public_id not in( - select - ss.session_id - from - session_state ss - where - ss.session_id = $1 and - ss.state in('cancelling', 'terminated') + ss.state = 'active' and + -- if there's no end_time, then this is the current state. + ss.end_time is null ) ) select * from active_session; +` + + remainingConnectionsCte = ` +with +session_connection_count(current_connection_count) as ( + select count(*) + from + session_connection sc + where + sc.session_id = $1 +), +session_connection_limit(expiration_time, connection_limit) as ( + select + s.expiration_time, + s.connection_limit + from + session s + where + s.public_id = $1 +) +select expiration_time, connection_limit, current_connection_count +from + session_connection_limit, session_connection_count; ` ) diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 3c6b44eb32..9d1ce7c44d 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/kms" wrapping "github.com/hashicorp/go-kms-wrapping" ) @@ -281,18 +282,18 @@ func (r *Repository) TerminateSession(ctx context.Context, sessionId string, ses // AuthorizeConnection will check to see if a connection is allowed. Currently, // that authorization checks: -// * the hasn't expired based on the session.Expiration -// * number of connections already created is less than session.ConnectionLimit +// * the hasn't expired based on the session.Expiration +// * number of connections already created is less than session.ConnectionLimit // If authorization is success, it creates/stores a new connection in the repo // and returns it, along with it's states. If the authorization fails, it // an error of ErrInvalidStateForOperation. -func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId string) (*Connection, []*ConnectionState, error) { +func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId string) (*Connection, []*ConnectionState, *ConnectionAuthzSummary, error) { if sessionId == "" { - return nil, nil, fmt.Errorf("authorize connection: missing session id: %w", db.ErrInvalidParameter) + return nil, nil, nil, fmt.Errorf("authorize connection: missing session id: %w", db.ErrInvalidParameter) } connectionId, err := newConnectionId() if err != nil { - return nil, nil, fmt.Errorf("authorize connection: %w", err) + return nil, nil, nil, fmt.Errorf("authorize connection: %w", err) } connection := AllocConnection() @@ -321,9 +322,43 @@ func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId string) }, ) if err != nil { - return nil, nil, fmt.Errorf("authorize connection: %w", err) + return nil, nil, nil, fmt.Errorf("authorize connection: %w", err) } - return &connection, connectionStates, nil + authzSummary, err := r.sessionAuthzSummary(ctx, connection.SessionId) + if err != nil { + return nil, nil, nil, fmt.Errorf("authorize connection: %w", err) + } + return &connection, connectionStates, authzSummary, nil +} + +type ConnectionAuthzSummary struct { + ExpirationTime *timestamp.Timestamp + ConnectionLimit uint32 + CurrentConnectionCount uint32 +} + +func (r *Repository) sessionAuthzSummary(ctx context.Context, sessionId string) (*ConnectionAuthzSummary, error) { + tx, err := r.reader.DB() + if err != nil { + return nil, fmt.Errorf("session summary: unable to get DB: %w", err) + } + rows, err := tx.QueryContext(ctx, remainingConnectionsCte, sessionId) + if err != nil { + return nil, fmt.Errorf("session summary: query failed: %w", err) + } + defer rows.Close() + + var info *ConnectionAuthzSummary + for rows.Next() { + if info != nil { + return nil, fmt.Errorf("session summary: query returned more than one row") + } + info = &ConnectionAuthzSummary{} + if err := r.reader.ScanRows(rows, info); err != nil { + return nil, fmt.Errorf("session summary: scan row failed: %w", err) + } + } + return info, nil } // ConnectSession updates a connection in the repo with a state of "connected". diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index ac9d3e6b7a..d1a320c076 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -409,7 +409,7 @@ func TestRepository_AuthorizeConnect(t *testing.T) { repo, err := NewRepository(rw, rw, kms) require.NoError(t, err) - setupFn := func(exp *timestamp.Timestamp) string { + setupFn := func(exp *timestamp.Timestamp) *Session { composedOf := TestSessionParams(t, conn, wrapper, iamRepo) if exp != nil { composedOf.ExpirationTime = exp @@ -419,37 +419,48 @@ func TestRepository_AuthorizeConnect(t *testing.T) { tofu := TestTofu(t) _, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) require.NoError(t, err) - return s.PublicId + return s } + testSession := setupFn(nil) + tests := []struct { - name string - sessionId string - wantErr bool - wantIsError error + name string + session *Session + wantErr bool + wantIsError error + wantAuthzInfo ConnectionAuthzSummary }{ { - name: "valid", - sessionId: setupFn(nil), + name: "valid", + session: testSession, + wantAuthzInfo: ConnectionAuthzSummary{ + ConnectionLimit: 1, + CurrentConnectionCount: 1, + ExpirationTime: testSession.ExpirationTime, + }, }, { - name: "empty-sessionId", - sessionId: "", + name: "empty-sessionId", + session: func() *Session { + s := AllocSession() + return &s + }(), wantErr: true, wantIsError: db.ErrInvalidParameter, }, { name: "exceeded-connection-limit", - sessionId: func() string { - sessionId := setupFn(nil) - _ = TestConnection(t, conn, sessionId, "127.0.0.1", 22, "127.0.0.1", 2222) - return sessionId + session: func() *Session { + session := setupFn(nil) + _ = TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + return session }(), wantErr: true, wantIsError: ErrInvalidStateForOperation, }, { name: "expired-session", - sessionId: setupFn(×tamp.Timestamp{Timestamp: ptypes.TimestampNow()}), + session: setupFn(×tamp.Timestamp{Timestamp: ptypes.TimestampNow()}), wantErr: true, wantIsError: ErrInvalidStateForOperation, }, @@ -458,7 +469,7 @@ func TestRepository_AuthorizeConnect(t *testing.T) { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - c, cs, err := repo.AuthorizeConnection(context.Background(), tt.sessionId) + c, cs, authzInfo, err := repo.AuthorizeConnection(context.Background(), tt.session.PublicId) if tt.wantErr { require.Error(err) if tt.wantIsError != nil { @@ -470,6 +481,9 @@ func TestRepository_AuthorizeConnect(t *testing.T) { require.NotNil(c) require.NotNil(cs) assert.Equal(StatusAuthorized.String(), cs[0].Status) + assert.Equal(tt.wantAuthzInfo.ExpirationTime, authzInfo.ExpirationTime) + assert.Equal(tt.wantAuthzInfo.ConnectionLimit, authzInfo.ConnectionLimit) + assert.Equal(tt.wantAuthzInfo.CurrentConnectionCount, authzInfo.CurrentConnectionCount) }) } }