changes needed for sessions.AuthorizeConnection (#377)

pull/380/head
Jim 6 years ago committed by GitHub
parent ee7cdde7de
commit 66400c9cff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3758,10 +3758,10 @@ begin;
on update cascade,
-- the client_tcp_address is the network address of the client which initiated
-- the connection to a worker
client_tcp_address inet not null,
client_tcp_address inet, -- maybe null on insert
-- the client_tcp_port is the network port at the address of the client the
-- worker proxied a connection for the user
client_tcp_port integer not null
client_tcp_port integer -- maybe null on insert
check(
client_tcp_port > 0
and
@ -3769,10 +3769,10 @@ begin;
),
-- the endpoint_tcp_address is the network address of the endpoint which the
-- worker initiated the connection to, for the user
endpoint_tcp_address inet not null,
endpoint_tcp_address inet, -- maybe be null on insert
-- the endpoint_tcp_port is the network port at the address of the endpoint the
-- worker proxied a connection to, for the user
endpoint_tcp_port integer not null
endpoint_tcp_port integer -- maybe null on insert
check(
endpoint_tcp_port > 0
and
@ -3807,7 +3807,7 @@ begin;
immutable_columns
before
update on session_connection
for each row execute procedure immutable_columns('public_id', 'session_id', 'client_tcp_address', 'client_tcp_port', 'endpoint_tcp_address', 'endpoint_tcp_port', 'create_time');
for each row execute procedure immutable_columns('public_id', 'session_id', 'create_time');
create trigger
update_version_column
@ -3826,7 +3826,7 @@ begin;
for each row execute procedure default_create_time();
-- insert_new_connection_state() is used in an after insert trigger on the
-- session_connection table. it will insert a state of "connected" in
-- session_connection table. it will insert a state of "authorized" in
-- session_connection_state for the new session connection.
create or replace function
insert_new_connection_state()
@ -3835,7 +3835,7 @@ begin;
begin
insert into session_connection_state (connection_id, state)
values
(new.public_id, 'connected');
(new.public_id, 'authorized');
return new;
end;
$$ language plpgsql;
@ -3879,12 +3879,13 @@ begin;
create table session_connection_state_enm (
name text primary key
check (
name in ('connected', 'closed')
name in ('authorized', 'connected', 'closed')
)
);
insert into session_connection_state_enm (name)
values
('authorized'),
('connected'),
('closed');

@ -86,10 +86,10 @@ begin;
on update cascade,
-- the client_tcp_address is the network address of the client which initiated
-- the connection to a worker
client_tcp_address inet not null,
client_tcp_address inet, -- maybe null on insert
-- the client_tcp_port is the network port at the address of the client the
-- worker proxied a connection for the user
client_tcp_port integer not null
client_tcp_port integer -- maybe null on insert
check(
client_tcp_port > 0
and
@ -97,10 +97,10 @@ begin;
),
-- the endpoint_tcp_address is the network address of the endpoint which the
-- worker initiated the connection to, for the user
endpoint_tcp_address inet not null,
endpoint_tcp_address inet, -- maybe be null on insert
-- the endpoint_tcp_port is the network port at the address of the endpoint the
-- worker proxied a connection to, for the user
endpoint_tcp_port integer not null
endpoint_tcp_port integer -- maybe null on insert
check(
endpoint_tcp_port > 0
and
@ -135,7 +135,7 @@ begin;
immutable_columns
before
update on session_connection
for each row execute procedure immutable_columns('public_id', 'session_id', 'client_tcp_address', 'client_tcp_port', 'endpoint_tcp_address', 'endpoint_tcp_port', 'create_time');
for each row execute procedure immutable_columns('public_id', 'session_id', 'create_time');
create trigger
update_version_column
@ -154,7 +154,7 @@ begin;
for each row execute procedure default_create_time();
-- insert_new_connection_state() is used in an after insert trigger on the
-- session_connection table. it will insert a state of "connected" in
-- session_connection table. it will insert a state of "authorized" in
-- session_connection_state for the new session connection.
create or replace function
insert_new_connection_state()
@ -163,7 +163,7 @@ begin;
begin
insert into session_connection_state (connection_id, state)
values
(new.public_id, 'connected');
(new.public_id, 'authorized');
return new;
end;
$$ language plpgsql;
@ -207,12 +207,13 @@ begin;
create table session_connection_state_enm (
name text primary key
check (
name in ('connected', 'closed')
name in ('authorized', 'connected', 'closed')
)
);
insert into session_connection_state_enm (name)
values
('authorized'),
('connected'),
('closed');

@ -122,14 +122,6 @@ func (c *Connection) VetForWrite(ctx context.Context, r db.Reader, opType db.OpT
return fmt.Errorf("connection vet for write: public id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "SessionId"):
return fmt.Errorf("connection vet for write: session id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "ClientTcpAddress"):
return fmt.Errorf("connection vet for write: client address is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "ClientTcpPort"):
return fmt.Errorf("connection vet for write: client port is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "EndpointTcpAddress"):
return fmt.Errorf("connection vet for write: endpoint address is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "EndpointTcpPort"):
return fmt.Errorf("connection vet for write: endpoint port is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "CreateTime"):
return fmt.Errorf("connection vet for write: create time is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "UpdateTime"):

@ -17,8 +17,9 @@ const (
type ConnectionStatus string
const (
StatusConnected ConnectionStatus = "connected"
StatusClosed ConnectionStatus = "closed"
StatusAuthorized ConnectionStatus = "authorized"
StatusConnected ConnectionStatus = "connected"
StatusClosed ConnectionStatus = "closed"
)
// String representation of the state's status

@ -194,42 +194,6 @@ func TestConnection_ImmutableFields(t *testing.T) {
}(),
fieldMask: []string{"SessionId"},
},
{
name: "client_tcp_address",
update: func() *Connection {
c := new.Clone().(*Connection)
c.ClientTcpAddress = "0.0.0.0"
return c
}(),
fieldMask: []string{"ClientTcpAddress"},
},
{
name: "client_tcp_port",
update: func() *Connection {
c := new.Clone().(*Connection)
c.ClientTcpPort = 443
return c
}(),
fieldMask: []string{"ClientTcpPort"},
},
{
name: "endpoint_tcp_address",
update: func() *Connection {
c := new.Clone().(*Connection)
c.EndpointTcpAddress = "0.0.0.0"
return c
}(),
fieldMask: []string{"EndpointTcpAddress"},
},
{
name: "endpoint_tcp_port",
update: func() *Connection {
c := new.Clone().(*Connection)
c.EndpointTcpPort = 443
return c
}(),
fieldMask: []string{"EndpointTcpPort"},
},
{
name: "create time",
update: func() *Connection {

@ -51,26 +51,22 @@ with terminated as (
)
select * from terminated;
`
createConnectionCte = `
authorizeConnectionCte = `
insert into session_connection (
session_id,
public_id,
client_tcp_address,
client_tcp_port,
endpoint_tcp_address,
endpoint_tcp_port
public_id
)
with active_session as (
select
$1 as session_id,
$2 as public_id,
$3::inet as client_tcp_address,
$4::int as client_tcp_port,
$5::inet as endpoint_tcp_address,
$6::int as endpoint_tcp_port
$2 as public_id
from
session s
where
-- check that the session hasn't expired.
s.expiration_time > now() and
-- check that there are still connections available. connection_limit of 0 equals unlimited connections
(s.connection_limit = 0 or s.connection_limit > (select count(*) from session_connection sc where sc.session_id = $1)) and
-- check that there's a state of active
s.public_id in (
select
@ -90,7 +86,7 @@ with active_session as (
where
ss.session_id = $1 and
ss.state in('cancelling', 'terminated')
)
)
)
select * from active_session;
`

@ -279,17 +279,20 @@ func (r *Repository) TerminateSession(ctx context.Context, sessionId string, ses
return &updatedSession, returnedStates, nil
}
// ConnectSession creates a connection in the repo with a state of "connected".
// Returns an ErrCancelledOrTerminatedSession error if a connection cannot be made
// because the session was cancelled or terminated.
func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) {
// ConnectWith.validate will check all the fields...
if err := c.validate(); err != nil {
return nil, nil, fmt.Errorf("connect session: %w", err)
// AuthorizeConnection will check to see if a connection is allowed. Currently,
// that authorization checks:
// * the hasn't expired based on the session.Expiration
// * number of connections already created is less than session.ConnectionLimit
// If authorization is success, it creates/stores a new connection in the repo
// and returns it, along with it's states. If the authorization fails, it
// an error of ErrInvalidStateForOperation.
func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId string) (*Connection, []*ConnectionState, error) {
if sessionId == "" {
return nil, nil, fmt.Errorf("authorize connection: missing session id: %w", db.ErrInvalidParameter)
}
connectionId, err := newConnectionId()
if err != nil {
return nil, nil, fmt.Errorf("connect session: %w", err)
return nil, nil, fmt.Errorf("authorize connection: %w", err)
}
connection := AllocConnection()
@ -300,15 +303,15 @@ func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connec
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
rowsAffected, err := w.Exec(createConnectionCte, []interface{}{c.SessionId, connectionId, c.ClientTcpAddress, c.ClientTcpPort, c.EndpointTcpAddress, c.EndpointTcpPort})
rowsAffected, err := w.Exec(authorizeConnectionCte, []interface{}{sessionId, connectionId})
if err != nil {
return fmt.Errorf("unable to connect session %s: %w", c.SessionId, err)
return fmt.Errorf("unable to authorize connection %s: %w", sessionId, err)
}
if rowsAffected == 0 {
return fmt.Errorf("session %s is not active: %w", c.SessionId, ErrInvalidStateForOperation)
return fmt.Errorf("session %s is not authorized (not active, expired or connection limit reached): %w", sessionId, ErrInvalidStateForOperation)
}
if err := reader.LookupById(ctx, &connection); err != nil {
return fmt.Errorf("lookup session: failed %w for %s", err, c.SessionId)
return fmt.Errorf("lookup connection: failed %w for %s", err, sessionId)
}
connectionStates, err = fetchConnectionStates(ctx, reader, connectionId, db.WithOrder("start_time desc"))
if err != nil {
@ -317,6 +320,59 @@ func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connec
return nil
},
)
if err != nil {
return nil, nil, fmt.Errorf("authorize connection: %w", err)
}
return &connection, connectionStates, nil
}
// ConnectSession updates a connection in the repo with a state of "connected".
func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) {
// ConnectWith.validate will check all the fields...
if err := c.validate(); err != nil {
return nil, nil, fmt.Errorf("connect session: %w", err)
}
var connection Connection
var connectionStates []*ConnectionState
_, err := r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
connection = AllocConnection()
connection.PublicId = c.ConnectionId
connection.ClientTcpAddress = c.ClientTcpAddress
connection.ClientTcpPort = c.ClientTcpPort
connection.EndpointTcpAddress = c.EndpointTcpAddress
connection.EndpointTcpPort = c.EndpointTcpPort
fieldMask := []string{
"ClientTcpAddress",
"ClientTcpPort",
"EndpointTcpAddress",
"EndpointTcpPort",
}
rowsUpdated, err := w.Update(ctx, &connection, fieldMask, nil)
if err != nil {
return err
}
if err == nil && rowsUpdated > 1 {
// return err, which will result in a rollback of the update
return errors.New("error more than 1 connection would have been updated ")
}
newState, err := NewConnectionState(connection.PublicId, StatusConnected)
if err != nil {
return err
}
if err := w.Create(ctx, newState); err != nil {
return err
}
connectionStates, err = fetchConnectionStates(ctx, reader, c.ConnectionId, db.WithOrder("start_time desc"))
if err != nil {
return err
}
return nil
},
)
if err != nil {
return nil, nil, fmt.Errorf("connect session: %w", err)
}

@ -8,6 +8,7 @@ import (
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
@ -398,6 +399,81 @@ func TestRepository_UpdateState(t *testing.T) {
}
}
func TestRepository_AuthorizeConnect(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
setupFn := func(exp *timestamp.Timestamp) string {
composedOf := TestSessionParams(t, conn, wrapper, iamRepo)
if exp != nil {
composedOf.ExpirationTime = exp
}
s := TestSession(t, conn, wrapper, composedOf)
srv := TestWorker(t, conn, wrapper)
tofu := TestTofu(t)
_, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu)
require.NoError(t, err)
return s.PublicId
}
tests := []struct {
name string
sessionId string
wantErr bool
wantIsError error
}{
{
name: "valid",
sessionId: setupFn(nil),
},
{
name: "empty-sessionId",
sessionId: "",
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
{
name: "exceeded-connection-limit",
sessionId: func() string {
sessionId := setupFn(nil)
_ = TestConnection(t, conn, sessionId, "127.0.0.1", 22, "127.0.0.1", 2222)
return sessionId
}(),
wantErr: true,
wantIsError: ErrInvalidStateForOperation,
},
{
name: "expired-session",
sessionId: setupFn(&timestamp.Timestamp{Timestamp: ptypes.TimestampNow()}),
wantErr: true,
wantIsError: ErrInvalidStateForOperation,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
c, cs, err := repo.AuthorizeConnection(context.Background(), tt.sessionId)
if tt.wantErr {
require.Error(err)
if tt.wantIsError != nil {
assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error())
}
return
}
require.NoError(err)
require.NotNil(c)
require.NotNil(cs)
assert.Equal(StatusAuthorized.String(), cs[0].Status)
})
}
}
func TestRepository_ConnectSession(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
@ -414,9 +490,9 @@ func TestRepository_ConnectSession(t *testing.T) {
tofu := TestTofu(t)
_, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu)
require.NoError(t, err)
_ = TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
return ConnectWith{
SessionId: s.PublicId,
ConnectionId: c.PublicId,
ClientTcpAddress: "127.0.0.1",
ClientTcpPort: 22,
EndpointTcpAddress: "127.0.0.1",
@ -437,7 +513,7 @@ func TestRepository_ConnectSession(t *testing.T) {
name: "empty-SessionId",
connectWith: func() ConnectWith {
cw := setupFn()
cw.SessionId = ""
cw.ConnectionId = ""
return cw
}(),
wantErr: true,

@ -9,7 +9,7 @@ import (
// ConnectWith defines the boundary data that is saved in the repo when the
// worker has established a connection between the client and the endpoint.
type ConnectWith struct {
SessionId string
ConnectionId string
ClientTcpAddress string
ClientTcpPort uint32
EndpointTcpAddress string
@ -17,7 +17,7 @@ type ConnectWith struct {
}
func (c ConnectWith) validate() error {
if c.SessionId == "" {
if c.ConnectionId == "" {
return fmt.Errorf("missing session id: %w", db.ErrInvalidParameter)
}
if c.ClientTcpAddress == "" {

@ -86,7 +86,7 @@ func TestConnectWith_validate(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := ConnectWith{
SessionId: tt.fields.SessionId,
ConnectionId: tt.fields.SessionId,
ClientTcpAddress: tt.fields.ClientTcpAddress,
ClientTcpPort: tt.fields.ClientTcpPort,
EndpointTcpAddress: tt.fields.EndpointTcpAddress,

@ -36,6 +36,11 @@ func TestConnection(t *testing.T, conn *gorm.DB, sessionId, clientTcpAddr string
c.PublicId = id
err = rw.Create(context.Background(), c)
require.NoError(err)
connectedState, err := NewConnectionState(c.PublicId, StatusConnected)
require.NoError(err)
err = rw.Create(context.Background(), connectedState)
require.NoError(err)
return c
}

Loading…
Cancel
Save