diff --git a/internal/db/migrations/postgres.gen.go b/internal/db/migrations/postgres.gen.go index 32b16dbfe5..ab8efe2e75 100644 --- a/internal/db/migrations/postgres.gen.go +++ b/internal/db/migrations/postgres.gen.go @@ -3758,10 +3758,10 @@ begin; on update cascade, -- the client_tcp_address is the network address of the client which initiated -- the connection to a worker - client_tcp_address inet not null, + client_tcp_address inet, -- maybe null on insert -- the client_tcp_port is the network port at the address of the client the -- worker proxied a connection for the user - client_tcp_port integer not null + client_tcp_port integer -- maybe null on insert check( client_tcp_port > 0 and @@ -3769,10 +3769,10 @@ begin; ), -- the endpoint_tcp_address is the network address of the endpoint which the -- worker initiated the connection to, for the user - endpoint_tcp_address inet not null, + endpoint_tcp_address inet, -- maybe be null on insert -- the endpoint_tcp_port is the network port at the address of the endpoint the -- worker proxied a connection to, for the user - endpoint_tcp_port integer not null + endpoint_tcp_port integer -- maybe null on insert check( endpoint_tcp_port > 0 and @@ -3807,7 +3807,7 @@ begin; immutable_columns before update on session_connection - for each row execute procedure immutable_columns('public_id', 'session_id', 'client_tcp_address', 'client_tcp_port', 'endpoint_tcp_address', 'endpoint_tcp_port', 'create_time'); + for each row execute procedure immutable_columns('public_id', 'session_id', 'create_time'); create trigger update_version_column @@ -3826,7 +3826,7 @@ begin; for each row execute procedure default_create_time(); -- insert_new_connection_state() is used in an after insert trigger on the - -- session_connection table. it will insert a state of "connected" in + -- session_connection table. it will insert a state of "authorized" in -- session_connection_state for the new session connection. create or replace function insert_new_connection_state() @@ -3835,7 +3835,7 @@ begin; begin insert into session_connection_state (connection_id, state) values - (new.public_id, 'connected'); + (new.public_id, 'authorized'); return new; end; $$ language plpgsql; @@ -3879,12 +3879,13 @@ begin; create table session_connection_state_enm ( name text primary key check ( - name in ('connected', 'closed') + name in ('authorized', 'connected', 'closed') ) ); insert into session_connection_state_enm (name) values + ('authorized'), ('connected'), ('closed'); diff --git a/internal/db/migrations/postgres/51_connection.up.sql b/internal/db/migrations/postgres/51_connection.up.sql index 5d355752da..287a706a0b 100644 --- a/internal/db/migrations/postgres/51_connection.up.sql +++ b/internal/db/migrations/postgres/51_connection.up.sql @@ -86,10 +86,10 @@ begin; on update cascade, -- the client_tcp_address is the network address of the client which initiated -- the connection to a worker - client_tcp_address inet not null, + client_tcp_address inet, -- maybe null on insert -- the client_tcp_port is the network port at the address of the client the -- worker proxied a connection for the user - client_tcp_port integer not null + client_tcp_port integer -- maybe null on insert check( client_tcp_port > 0 and @@ -97,10 +97,10 @@ begin; ), -- the endpoint_tcp_address is the network address of the endpoint which the -- worker initiated the connection to, for the user - endpoint_tcp_address inet not null, + endpoint_tcp_address inet, -- maybe be null on insert -- the endpoint_tcp_port is the network port at the address of the endpoint the -- worker proxied a connection to, for the user - endpoint_tcp_port integer not null + endpoint_tcp_port integer -- maybe null on insert check( endpoint_tcp_port > 0 and @@ -135,7 +135,7 @@ begin; immutable_columns before update on session_connection - for each row execute procedure immutable_columns('public_id', 'session_id', 'client_tcp_address', 'client_tcp_port', 'endpoint_tcp_address', 'endpoint_tcp_port', 'create_time'); + for each row execute procedure immutable_columns('public_id', 'session_id', 'create_time'); create trigger update_version_column @@ -154,7 +154,7 @@ begin; for each row execute procedure default_create_time(); -- insert_new_connection_state() is used in an after insert trigger on the - -- session_connection table. it will insert a state of "connected" in + -- session_connection table. it will insert a state of "authorized" in -- session_connection_state for the new session connection. create or replace function insert_new_connection_state() @@ -163,7 +163,7 @@ begin; begin insert into session_connection_state (connection_id, state) values - (new.public_id, 'connected'); + (new.public_id, 'authorized'); return new; end; $$ language plpgsql; @@ -207,12 +207,13 @@ begin; create table session_connection_state_enm ( name text primary key check ( - name in ('connected', 'closed') + name in ('authorized', 'connected', 'closed') ) ); insert into session_connection_state_enm (name) values + ('authorized'), ('connected'), ('closed'); diff --git a/internal/session/connection.go b/internal/session/connection.go index 9a4df98dd5..effe65220f 100644 --- a/internal/session/connection.go +++ b/internal/session/connection.go @@ -122,14 +122,6 @@ func (c *Connection) VetForWrite(ctx context.Context, r db.Reader, opType db.OpT return fmt.Errorf("connection vet for write: public id is immutable: %w", db.ErrInvalidParameter) case contains(opts.WithFieldMaskPaths, "SessionId"): return fmt.Errorf("connection vet for write: session id is immutable: %w", db.ErrInvalidParameter) - case contains(opts.WithFieldMaskPaths, "ClientTcpAddress"): - return fmt.Errorf("connection vet for write: client address is immutable: %w", db.ErrInvalidParameter) - case contains(opts.WithFieldMaskPaths, "ClientTcpPort"): - return fmt.Errorf("connection vet for write: client port is immutable: %w", db.ErrInvalidParameter) - case contains(opts.WithFieldMaskPaths, "EndpointTcpAddress"): - return fmt.Errorf("connection vet for write: endpoint address is immutable: %w", db.ErrInvalidParameter) - case contains(opts.WithFieldMaskPaths, "EndpointTcpPort"): - return fmt.Errorf("connection vet for write: endpoint port is immutable: %w", db.ErrInvalidParameter) case contains(opts.WithFieldMaskPaths, "CreateTime"): return fmt.Errorf("connection vet for write: create time is immutable: %w", db.ErrInvalidParameter) case contains(opts.WithFieldMaskPaths, "UpdateTime"): diff --git a/internal/session/connection_state.go b/internal/session/connection_state.go index c34a03a941..c360bf1604 100644 --- a/internal/session/connection_state.go +++ b/internal/session/connection_state.go @@ -17,8 +17,9 @@ const ( type ConnectionStatus string const ( - StatusConnected ConnectionStatus = "connected" - StatusClosed ConnectionStatus = "closed" + StatusAuthorized ConnectionStatus = "authorized" + StatusConnected ConnectionStatus = "connected" + StatusClosed ConnectionStatus = "closed" ) // String representation of the state's status diff --git a/internal/session/immutable_fields_test.go b/internal/session/immutable_fields_test.go index 31f62ba5e4..5a0d3d1071 100644 --- a/internal/session/immutable_fields_test.go +++ b/internal/session/immutable_fields_test.go @@ -194,42 +194,6 @@ func TestConnection_ImmutableFields(t *testing.T) { }(), fieldMask: []string{"SessionId"}, }, - { - name: "client_tcp_address", - update: func() *Connection { - c := new.Clone().(*Connection) - c.ClientTcpAddress = "0.0.0.0" - return c - }(), - fieldMask: []string{"ClientTcpAddress"}, - }, - { - name: "client_tcp_port", - update: func() *Connection { - c := new.Clone().(*Connection) - c.ClientTcpPort = 443 - return c - }(), - fieldMask: []string{"ClientTcpPort"}, - }, - { - name: "endpoint_tcp_address", - update: func() *Connection { - c := new.Clone().(*Connection) - c.EndpointTcpAddress = "0.0.0.0" - return c - }(), - fieldMask: []string{"EndpointTcpAddress"}, - }, - { - name: "endpoint_tcp_port", - update: func() *Connection { - c := new.Clone().(*Connection) - c.EndpointTcpPort = 443 - return c - }(), - fieldMask: []string{"EndpointTcpPort"}, - }, { name: "create time", update: func() *Connection { diff --git a/internal/session/query.go b/internal/session/query.go index 670e374028..a2ecca0c4e 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -51,26 +51,22 @@ with terminated as ( ) select * from terminated; ` - createConnectionCte = ` + authorizeConnectionCte = ` insert into session_connection ( session_id, - public_id, - client_tcp_address, - client_tcp_port, - endpoint_tcp_address, - endpoint_tcp_port + public_id ) with active_session as ( select $1 as session_id, - $2 as public_id, - $3::inet as client_tcp_address, - $4::int as client_tcp_port, - $5::inet as endpoint_tcp_address, - $6::int as endpoint_tcp_port + $2 as public_id from session s where + -- check that the session hasn't expired. + s.expiration_time > now() and + -- check that there are still connections available. connection_limit of 0 equals unlimited connections + (s.connection_limit = 0 or s.connection_limit > (select count(*) from session_connection sc where sc.session_id = $1)) and -- check that there's a state of active s.public_id in ( select @@ -90,7 +86,7 @@ with active_session as ( where ss.session_id = $1 and ss.state in('cancelling', 'terminated') - ) + ) ) select * from active_session; ` diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index b729be984c..3c6b44eb32 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -279,17 +279,20 @@ func (r *Repository) TerminateSession(ctx context.Context, sessionId string, ses return &updatedSession, returnedStates, nil } -// ConnectSession creates a connection in the repo with a state of "connected". -// Returns an ErrCancelledOrTerminatedSession error if a connection cannot be made -// because the session was cancelled or terminated. -func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) { - // ConnectWith.validate will check all the fields... - if err := c.validate(); err != nil { - return nil, nil, fmt.Errorf("connect session: %w", err) +// AuthorizeConnection will check to see if a connection is allowed. Currently, +// that authorization checks: +// * the hasn't expired based on the session.Expiration +// * number of connections already created is less than session.ConnectionLimit +// If authorization is success, it creates/stores a new connection in the repo +// and returns it, along with it's states. If the authorization fails, it +// an error of ErrInvalidStateForOperation. +func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId string) (*Connection, []*ConnectionState, error) { + if sessionId == "" { + return nil, nil, fmt.Errorf("authorize connection: missing session id: %w", db.ErrInvalidParameter) } connectionId, err := newConnectionId() if err != nil { - return nil, nil, fmt.Errorf("connect session: %w", err) + return nil, nil, fmt.Errorf("authorize connection: %w", err) } connection := AllocConnection() @@ -300,15 +303,15 @@ func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connec db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { - rowsAffected, err := w.Exec(createConnectionCte, []interface{}{c.SessionId, connectionId, c.ClientTcpAddress, c.ClientTcpPort, c.EndpointTcpAddress, c.EndpointTcpPort}) + rowsAffected, err := w.Exec(authorizeConnectionCte, []interface{}{sessionId, connectionId}) if err != nil { - return fmt.Errorf("unable to connect session %s: %w", c.SessionId, err) + return fmt.Errorf("unable to authorize connection %s: %w", sessionId, err) } if rowsAffected == 0 { - return fmt.Errorf("session %s is not active: %w", c.SessionId, ErrInvalidStateForOperation) + return fmt.Errorf("session %s is not authorized (not active, expired or connection limit reached): %w", sessionId, ErrInvalidStateForOperation) } if err := reader.LookupById(ctx, &connection); err != nil { - return fmt.Errorf("lookup session: failed %w for %s", err, c.SessionId) + return fmt.Errorf("lookup connection: failed %w for %s", err, sessionId) } connectionStates, err = fetchConnectionStates(ctx, reader, connectionId, db.WithOrder("start_time desc")) if err != nil { @@ -317,6 +320,59 @@ func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connec return nil }, ) + if err != nil { + return nil, nil, fmt.Errorf("authorize connection: %w", err) + } + return &connection, connectionStates, nil +} + +// ConnectSession updates a connection in the repo with a state of "connected". +func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) { + // ConnectWith.validate will check all the fields... + if err := c.validate(); err != nil { + return nil, nil, fmt.Errorf("connect session: %w", err) + } + var connection Connection + var connectionStates []*ConnectionState + _, err := r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + connection = AllocConnection() + connection.PublicId = c.ConnectionId + connection.ClientTcpAddress = c.ClientTcpAddress + connection.ClientTcpPort = c.ClientTcpPort + connection.EndpointTcpAddress = c.EndpointTcpAddress + connection.EndpointTcpPort = c.EndpointTcpPort + fieldMask := []string{ + "ClientTcpAddress", + "ClientTcpPort", + "EndpointTcpAddress", + "EndpointTcpPort", + } + rowsUpdated, err := w.Update(ctx, &connection, fieldMask, nil) + if err != nil { + return err + } + if err == nil && rowsUpdated > 1 { + // return err, which will result in a rollback of the update + return errors.New("error more than 1 connection would have been updated ") + } + newState, err := NewConnectionState(connection.PublicId, StatusConnected) + if err != nil { + return err + } + if err := w.Create(ctx, newState); err != nil { + return err + } + connectionStates, err = fetchConnectionStates(ctx, reader, c.ConnectionId, db.WithOrder("start_time desc")) + if err != nil { + return err + } + return nil + }, + ) if err != nil { return nil, nil, fmt.Errorf("connect session: %w", err) } diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 7fbb153bd2..ac9d3e6b7a 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -8,6 +8,7 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" @@ -398,6 +399,81 @@ func TestRepository_UpdateState(t *testing.T) { } } +func TestRepository_AuthorizeConnect(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + + setupFn := func(exp *timestamp.Timestamp) string { + composedOf := TestSessionParams(t, conn, wrapper, iamRepo) + if exp != nil { + composedOf.ExpirationTime = exp + } + s := TestSession(t, conn, wrapper, composedOf) + srv := TestWorker(t, conn, wrapper) + tofu := TestTofu(t) + _, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) + require.NoError(t, err) + return s.PublicId + } + tests := []struct { + name string + sessionId string + wantErr bool + wantIsError error + }{ + { + name: "valid", + sessionId: setupFn(nil), + }, + { + name: "empty-sessionId", + sessionId: "", + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + { + name: "exceeded-connection-limit", + sessionId: func() string { + sessionId := setupFn(nil) + _ = TestConnection(t, conn, sessionId, "127.0.0.1", 22, "127.0.0.1", 2222) + return sessionId + }(), + wantErr: true, + wantIsError: ErrInvalidStateForOperation, + }, + { + name: "expired-session", + sessionId: setupFn(×tamp.Timestamp{Timestamp: ptypes.TimestampNow()}), + wantErr: true, + wantIsError: ErrInvalidStateForOperation, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + c, cs, err := repo.AuthorizeConnection(context.Background(), tt.sessionId) + if tt.wantErr { + require.Error(err) + if tt.wantIsError != nil { + assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error()) + } + return + } + require.NoError(err) + require.NotNil(c) + require.NotNil(cs) + assert.Equal(StatusAuthorized.String(), cs[0].Status) + }) + } +} + func TestRepository_ConnectSession(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") @@ -414,9 +490,9 @@ func TestRepository_ConnectSession(t *testing.T) { tofu := TestTofu(t) _, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) require.NoError(t, err) - _ = TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) return ConnectWith{ - SessionId: s.PublicId, + ConnectionId: c.PublicId, ClientTcpAddress: "127.0.0.1", ClientTcpPort: 22, EndpointTcpAddress: "127.0.0.1", @@ -437,7 +513,7 @@ func TestRepository_ConnectSession(t *testing.T) { name: "empty-SessionId", connectWith: func() ConnectWith { cw := setupFn() - cw.SessionId = "" + cw.ConnectionId = "" return cw }(), wantErr: true, diff --git a/internal/session/session_connect_with.go b/internal/session/session_connect_with.go index db2da0e296..584b94b9ce 100644 --- a/internal/session/session_connect_with.go +++ b/internal/session/session_connect_with.go @@ -9,7 +9,7 @@ import ( // ConnectWith defines the boundary data that is saved in the repo when the // worker has established a connection between the client and the endpoint. type ConnectWith struct { - SessionId string + ConnectionId string ClientTcpAddress string ClientTcpPort uint32 EndpointTcpAddress string @@ -17,7 +17,7 @@ type ConnectWith struct { } func (c ConnectWith) validate() error { - if c.SessionId == "" { + if c.ConnectionId == "" { return fmt.Errorf("missing session id: %w", db.ErrInvalidParameter) } if c.ClientTcpAddress == "" { diff --git a/internal/session/session_connect_with_test.go b/internal/session/session_connect_with_test.go index a101fa741d..8b2d8022c4 100644 --- a/internal/session/session_connect_with_test.go +++ b/internal/session/session_connect_with_test.go @@ -86,7 +86,7 @@ func TestConnectWith_validate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := ConnectWith{ - SessionId: tt.fields.SessionId, + ConnectionId: tt.fields.SessionId, ClientTcpAddress: tt.fields.ClientTcpAddress, ClientTcpPort: tt.fields.ClientTcpPort, EndpointTcpAddress: tt.fields.EndpointTcpAddress, diff --git a/internal/session/testing.go b/internal/session/testing.go index 529c844a1f..fefcde304b 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -36,6 +36,11 @@ func TestConnection(t *testing.T, conn *gorm.DB, sessionId, clientTcpAddr string c.PublicId = id err = rw.Create(context.Background(), c) require.NoError(err) + + connectedState, err := NewConnectionState(c.PublicId, StatusConnected) + require.NoError(err) + err = rw.Create(context.Background(), connectedState) + require.NoError(err) return c }