return connection authz info from session.AuthorizeConnection (#380)

pull/381/head
Jim 6 years ago committed by GitHub
parent 66400c9cff
commit 223591d835
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -66,7 +66,11 @@ with active_session as (
-- check that the session hasn't expired.
s.expiration_time > now() and
-- check that there are still connections available. connection_limit of 0 equals unlimited connections
(s.connection_limit = 0 or s.connection_limit > (select count(*) from session_connection sc where sc.session_id = $1)) and
(
s.connection_limit = 0
or
s.connection_limit > (select count(*) from session_connection sc where sc.session_id = $1)
) and
-- check that there's a state of active
s.public_id in (
select
@ -75,19 +79,34 @@ with active_session as (
session_state ss
where
ss.session_id = $1 and
ss.state = 'active'
) and
-- check that there are no cancelling or terminated states
s.public_id not in(
select
ss.session_id
from
session_state ss
where
ss.session_id = $1 and
ss.state in('cancelling', 'terminated')
ss.state = 'active' and
-- if there's no end_time, then this is the current state.
ss.end_time is null
)
)
select * from active_session;
`
remainingConnectionsCte = `
with
session_connection_count(current_connection_count) as (
select count(*)
from
session_connection sc
where
sc.session_id = $1
),
session_connection_limit(expiration_time, connection_limit) as (
select
s.expiration_time,
s.connection_limit
from
session s
where
s.public_id = $1
)
select expiration_time, connection_limit, current_connection_count
from
session_connection_limit, session_connection_count;
`
)

@ -9,6 +9,7 @@ import (
"strings"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/kms"
wrapping "github.com/hashicorp/go-kms-wrapping"
)
@ -281,18 +282,18 @@ func (r *Repository) TerminateSession(ctx context.Context, sessionId string, ses
// AuthorizeConnection will check to see if a connection is allowed. Currently,
// that authorization checks:
// * the hasn't expired based on the session.Expiration
// * number of connections already created is less than session.ConnectionLimit
// * the hasn't expired based on the session.Expiration
// * number of connections already created is less than session.ConnectionLimit
// If authorization is success, it creates/stores a new connection in the repo
// and returns it, along with it's states. If the authorization fails, it
// an error of ErrInvalidStateForOperation.
func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId string) (*Connection, []*ConnectionState, error) {
func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId string) (*Connection, []*ConnectionState, *ConnectionAuthzSummary, error) {
if sessionId == "" {
return nil, nil, fmt.Errorf("authorize connection: missing session id: %w", db.ErrInvalidParameter)
return nil, nil, nil, fmt.Errorf("authorize connection: missing session id: %w", db.ErrInvalidParameter)
}
connectionId, err := newConnectionId()
if err != nil {
return nil, nil, fmt.Errorf("authorize connection: %w", err)
return nil, nil, nil, fmt.Errorf("authorize connection: %w", err)
}
connection := AllocConnection()
@ -321,9 +322,43 @@ func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId string)
},
)
if err != nil {
return nil, nil, fmt.Errorf("authorize connection: %w", err)
return nil, nil, nil, fmt.Errorf("authorize connection: %w", err)
}
return &connection, connectionStates, nil
authzSummary, err := r.sessionAuthzSummary(ctx, connection.SessionId)
if err != nil {
return nil, nil, nil, fmt.Errorf("authorize connection: %w", err)
}
return &connection, connectionStates, authzSummary, nil
}
type ConnectionAuthzSummary struct {
ExpirationTime *timestamp.Timestamp
ConnectionLimit uint32
CurrentConnectionCount uint32
}
func (r *Repository) sessionAuthzSummary(ctx context.Context, sessionId string) (*ConnectionAuthzSummary, error) {
tx, err := r.reader.DB()
if err != nil {
return nil, fmt.Errorf("session summary: unable to get DB: %w", err)
}
rows, err := tx.QueryContext(ctx, remainingConnectionsCte, sessionId)
if err != nil {
return nil, fmt.Errorf("session summary: query failed: %w", err)
}
defer rows.Close()
var info *ConnectionAuthzSummary
for rows.Next() {
if info != nil {
return nil, fmt.Errorf("session summary: query returned more than one row")
}
info = &ConnectionAuthzSummary{}
if err := r.reader.ScanRows(rows, info); err != nil {
return nil, fmt.Errorf("session summary: scan row failed: %w", err)
}
}
return info, nil
}
// ConnectSession updates a connection in the repo with a state of "connected".

@ -409,7 +409,7 @@ func TestRepository_AuthorizeConnect(t *testing.T) {
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
setupFn := func(exp *timestamp.Timestamp) string {
setupFn := func(exp *timestamp.Timestamp) *Session {
composedOf := TestSessionParams(t, conn, wrapper, iamRepo)
if exp != nil {
composedOf.ExpirationTime = exp
@ -419,37 +419,48 @@ func TestRepository_AuthorizeConnect(t *testing.T) {
tofu := TestTofu(t)
_, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu)
require.NoError(t, err)
return s.PublicId
return s
}
testSession := setupFn(nil)
tests := []struct {
name string
sessionId string
wantErr bool
wantIsError error
name string
session *Session
wantErr bool
wantIsError error
wantAuthzInfo ConnectionAuthzSummary
}{
{
name: "valid",
sessionId: setupFn(nil),
name: "valid",
session: testSession,
wantAuthzInfo: ConnectionAuthzSummary{
ConnectionLimit: 1,
CurrentConnectionCount: 1,
ExpirationTime: testSession.ExpirationTime,
},
},
{
name: "empty-sessionId",
sessionId: "",
name: "empty-sessionId",
session: func() *Session {
s := AllocSession()
return &s
}(),
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
{
name: "exceeded-connection-limit",
sessionId: func() string {
sessionId := setupFn(nil)
_ = TestConnection(t, conn, sessionId, "127.0.0.1", 22, "127.0.0.1", 2222)
return sessionId
session: func() *Session {
session := setupFn(nil)
_ = TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
return session
}(),
wantErr: true,
wantIsError: ErrInvalidStateForOperation,
},
{
name: "expired-session",
sessionId: setupFn(&timestamp.Timestamp{Timestamp: ptypes.TimestampNow()}),
session: setupFn(&timestamp.Timestamp{Timestamp: ptypes.TimestampNow()}),
wantErr: true,
wantIsError: ErrInvalidStateForOperation,
},
@ -458,7 +469,7 @@ func TestRepository_AuthorizeConnect(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
c, cs, err := repo.AuthorizeConnection(context.Background(), tt.sessionId)
c, cs, authzInfo, err := repo.AuthorizeConnection(context.Background(), tt.session.PublicId)
if tt.wantErr {
require.Error(err)
if tt.wantIsError != nil {
@ -470,6 +481,9 @@ func TestRepository_AuthorizeConnect(t *testing.T) {
require.NotNil(c)
require.NotNil(cs)
assert.Equal(StatusAuthorized.String(), cs[0].Status)
assert.Equal(tt.wantAuthzInfo.ExpirationTime, authzInfo.ExpirationTime)
assert.Equal(tt.wantAuthzInfo.ConnectionLimit, authzInfo.ConnectionLimit)
assert.Equal(tt.wantAuthzInfo.CurrentConnectionCount, authzInfo.CurrentConnectionCount)
})
}
}

Loading…
Cancel
Save