refact(session): Replace state start_time/end_time with time range (#5134)

This replaces the start_time, end_time, and previous_end_time columns on
the session_state table with a time range column that records when the
state was active. This allowed for dropping several constraints on the
session_state table that were used to ensure that states for a session
did not overlap. Instead this can be ensured with a single exclusion
constraint. By reducing the number of constraints, it improves the
performance for a number of queries involved with session state. Most
notably when resolving delete cascades of resources related to sessions,
such as targets, where as part of the delete transaction large numbers
of sessions get canceled. Similarly, when sessions get deleted, and the
corresponding session_state rows are deleted, these constraints all must
be checked during the transaction.

The domain layer continues to expose StartTime and EndTime fields on
session.State. However it drops PreviousEndTime, which was never used by
anything calling the domain layer.

See:
    https://www.postgresql.org/docs/current/ddl-constraints.html#DDL-CONSTRAINTS-EXCLUSION
    https://www.postgresql.org/docs/current/rangetypes.html
pull/5140/head
Timothy Messier 1 year ago committed by GitHub
parent f6bdbd116e
commit 37bf59243a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -395,6 +395,7 @@ begin;
references session_state (session_id, end_time)
);
-- Replaced in 91/06_session_state_tstzrange.up.sql
create trigger immutable_columns before update on session_state
for each row execute procedure immutable_columns('session_id', 'state', 'start_time', 'previous_end_time');

@ -397,6 +397,7 @@ begin;
end;
$$ language plpgsql;
-- Replaced in 91/06_session_state_tstzrange.up.sql
create trigger wh_insert_session_state after insert on session_state
for each row execute function wh_insert_session_state();

@ -77,6 +77,7 @@ begin
end; $$;
-- Replaces trigger from 0/50_session.up.sql
-- Replaced in 91/06_session_state_tstzrange.up.sql
-- Update insert session state transition trigger
drop trigger insert_session_state on session_state;
drop function insert_session_state();

@ -4,6 +4,7 @@
begin;
-- Replaces the view created in 69/02_session_worker_protocol.up.sql
-- Replaced in 91/06_session_state_tstzrange.up.sql
drop view session_list;
create view session_list as
select s.public_id,

@ -69,6 +69,7 @@ begin;
'for the user that corresponds to the provided auth_token_id.';
-- Replaces function from 60/03_wh_sessions.up.sql
-- Replaced in 91/06_session_state_tstzrange.up.sql
create function wh_insert_session() returns trigger
as $$
declare

@ -0,0 +1,178 @@
-- Copyright (c) HashiCorp, Inc.
-- SPDX-License-Identifier: BUSL-1.1
begin;
-- Add new active_time_range column that will replace two start_time, end_time columns.
-- Also drop a number of constraints on the start_time, end_time columns. This will allow
-- from dropping these columns after the new column has been set with the correct data.
alter table session_state
add column active_time_range tstzrange not null default tstzrange(now(), null, '[]'),
drop constraint end_times_in_sequence,
drop constraint previous_end_time_and_start_time_in_sequence,
drop constraint start_and_end_times_in_sequence,
drop constraint session_state_session_id_previous_end_time_fkey;
-- Set the new active_time_range column for any existing rows using start_time and end_time.
update session_state
set active_time_range = tstzrange(start_time, end_time, '[)');
-- Replaces view from 72/03/session_list_perf_fix.up.sql
-- Switch view to tuse the new column. This also eliminates the previous_end_time column
-- from the view, since it also will be dropped.
drop view session_list;
create view session_list as
select s.public_id,
s.user_id,
shsh.host_id,
shsh.host_set_id,
s.target_id,
s.auth_token_id,
s.project_id,
s.certificate,
s.expiration_time,
s.termination_reason,
s.create_time,
s.update_time,
s.version,
s.endpoint,
s.connection_limit,
ss.state,
lower(ss.active_time_range) as start_time,
upper(ss.active_time_range) as end_time
from session s
join session_state ss on s.public_id = ss.session_id
left join session_host_set_host shsh on s.public_id = shsh.session_id;
-- Now we can finally drop the old columns and add a constraint on the new column
-- that ensures there are no overlaps on the active_time_range for a given session.
alter table session_state
drop column start_time,
drop column end_time,
drop column previous_end_time,
add constraint session_state_active_time_range_excl
exclude using gist (session_id with =,
active_time_range with &&),
add constraint active_time_range_not_empty
check (not isempty(active_time_range));
-- There are still a number of functions that reference the old columns.
-- These all need to be updated to use the new column instead.
-- Replaces trigger from 0/50_session.up.sql
drop trigger immutable_columns on session_state;
create trigger immutable_columns before update on session_state
for each row execute procedure immutable_columns('session_id', 'state');
-- Replaces function from 28/02_prior_session_trigger.up.sql
drop trigger insert_session_state on session_state;
drop function insert_session_state();
create function insert_session_state() returns trigger
as $$
declare
old_col_state text;
begin
update session_state
set active_time_range = tstzrange(lower(active_time_range), now(), '[)')
where session_id = new.session_id
and upper(active_time_range) is null
returning state
into old_col_state;
if not found then
new.prior_state = 'pending';
else
new.prior_state = old_col_state;
end if;
new.active_time_range = tstzrange(now(), null, '[]');
return new;
end;
$$ language plpgsql;
create trigger insert_session_state before insert on session_state
for each row execute procedure insert_session_state();
-- Replaces function from 84/02_wh_upsert_user_refact.up.sql
drop trigger wh_insert_session on session;
drop function wh_insert_session;
create function wh_insert_session() returns trigger
as $$
declare
new_row wh_session_accumulating_fact%rowtype;
begin
with
pending_timestamp (date_dim_key, time_dim_key, ts) as (
select wh_date_key(lower(active_time_range)), wh_time_key(lower(active_time_range)), lower(active_time_range)
from session_state
where session_id = new.public_id
and state = 'pending'
)
insert into wh_session_accumulating_fact (
session_id,
auth_token_id,
host_key,
user_key,
credential_group_key,
session_pending_date_key,
session_pending_time_key,
session_pending_time
)
select new.public_id,
new.auth_token_id,
'no host source', -- will be updated by wh_upsert_host
wh_upsert_user(new.auth_token_id),
'no credentials', -- will be updated by wh_upsert_credential_group
pending_timestamp.date_dim_key,
pending_timestamp.time_dim_key,
pending_timestamp.ts
from pending_timestamp
returning * into strict new_row;
return null;
end;
$$ language plpgsql;
create trigger wh_insert_session after insert on session
for each row execute procedure wh_insert_session();
-- Replaces function from 15/01_wh_rename_key_columns.up.sql
drop trigger wh_insert_session_state on session_state;
drop function wh_insert_session_state;
create function wh_insert_session_state() returns trigger
as $$
declare
date_col text;
time_col text;
ts_col text;
q text;
session_row wh_session_accumulating_fact%rowtype;
begin
if new.state = 'pending' then
-- The pending state is the first state which is handled by the
-- wh_insert_session trigger. The update statement in this trigger will
-- fail for the pending state because the row for the session has not yet
-- been inserted into the wh_session_accumulating_fact table.
return null;
end if;
date_col = 'session_' || new.state || '_date_key';
time_col = 'session_' || new.state || '_time_key';
ts_col = 'session_' || new.state || '_time';
q = format(' update wh_session_accumulating_fact
set (%I, %I, %I) = (select wh_date_key(%L), wh_time_key(%L), %L::timestamptz)
where session_id = %L
returning *',
date_col, time_col, ts_col,
lower(new.active_time_range), lower(new.active_time_range), lower(new.active_time_range),
new.session_id);
execute q into strict session_row;
return null;
end;
$$ language plpgsql;
create trigger wh_insert_session_state after insert on session_state
for each row execute function wh_insert_session_state();
commit;

@ -5,6 +5,7 @@ package session
import (
"context"
"fmt"
"testing"
"github.com/hashicorp/boundary/internal/db"
@ -97,15 +98,35 @@ func TestState_ImmutableFields(t *testing.T) {
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
ts := timestamp.Timestamp{Timestamp: &timestamppb.Timestamp{Seconds: 0, Nanos: 0}}
_, _ = iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
session := TestDefaultSession(t, conn, wrapper, iamRepo)
state := TestState(t, conn, session.PublicId, StatusActive)
var new State
err := rw.LookupWhere(context.Background(), &new, "session_id = ? and state = ?", []any{state.SessionId, state.Status})
require.NoError(t, err)
fetchSession := func(ctx context.Context, rw *db.Db, sessionId string, startTime *timestamp.Timestamp) (*State, error) {
const selectQuery = `
select session_id,
state,
lower(active_time_range) as start_time,
upper(active_time_range) as end_time
from session_state
where session_id = ?
and lower(active_time_range) = ?;`
var states []*State
rows, err := rw.Query(ctx, selectQuery, []any{sessionId, startTime})
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
if err := rw.ScanRows(ctx, rows, &states); err != nil {
return nil, err
}
}
if len(states) != 1 {
return nil, fmt.Errorf("found %d states, expected 1", len(states))
}
return states[0], nil
}
tests := []struct {
name string
@ -115,7 +136,7 @@ func TestState_ImmutableFields(t *testing.T) {
{
name: "session_id",
update: func() *State {
s := new.Clone().(*State)
s := state.Clone().(*State)
s.SessionId = "s_thisIsNotAValidId"
return s
}(),
@ -124,47 +145,28 @@ func TestState_ImmutableFields(t *testing.T) {
{
name: "status",
update: func() *State {
s := new.Clone().(*State)
s := state.Clone().(*State)
s.Status = "canceling"
return s
}(),
fieldMask: []string{"Status"},
},
{
name: "start time",
update: func() *State {
s := new.Clone().(*State)
s.StartTime = &ts
return s
}(),
fieldMask: []string{"StartTime"},
},
{
name: "previous_end_time",
update: func() *State {
s := new.Clone().(*State)
s.PreviousEndTime = &ts
return s
}(),
fieldMask: []string{"PreviousEndTime"},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
assert, require := assert.New(t), require.New(t)
orig := new.Clone()
err := rw.LookupWhere(context.Background(), orig, "session_id = ? and start_time = ?", []any{new.SessionId, new.StartTime})
orig, err := fetchSession(ctx, rw, state.SessionId, state.StartTime)
require.NoError(err)
rowsUpdated, err := rw.Update(context.Background(), tt.update, tt.fieldMask, nil, db.WithSkipVetForWrite(true))
require.Error(err)
assert.Equal(0, rowsUpdated)
after := new.Clone()
err = rw.LookupWhere(context.Background(), after, "session_id = ? and start_time = ?", []any{new.SessionId, new.StartTime})
after, err := fetchSession(ctx, rw, state.SessionId, state.StartTime)
require.NoError(err)
assert.Equal(orig.(*State), after)
assert.Equal(orig, after)
})
}
}

@ -90,7 +90,7 @@ active_session as (
where
ss.session_id in (select * from unexpired_session) and
ss.state = 'active' and
ss.end_time is null
upper(ss.active_time_range) is null
)
insert into session_connection (
session_id,
@ -150,7 +150,7 @@ from
where
ss.session_id = @public_id and
ss.state = 'canceling' and
ss.end_time is null
upper(ss.active_time_range) is null
)
update session us
set version = version +1,
@ -226,7 +226,7 @@ with canceling_session(session_id) as
session_state ss
where
ss.state = 'canceling' and
ss.end_time is null
upper(ss.active_time_range) is null
)
update session us
set termination_reason =
@ -371,7 +371,7 @@ where
and
session_state.state = 'terminated'
and
session_state.start_time < wt_sub_seconds_from_now(@threshold_seconds)
lower(session_state.active_time_range) < wt_sub_seconds_from_now(@threshold_seconds)
;
`
sessionCredentialRewrapQuery = `
@ -451,6 +451,16 @@ order by update_time desc, public_id desc;
`
estimateCountSessions = `
select reltuples::bigint as estimate from pg_class where oid in ('session'::regclass)
`
selectStates = `
select session_id,
state,
lower(active_time_range) as start_time,
upper(active_time_range) as end_time
from session_state
where session_id = ?
order by active_time_range desc;
`
)

@ -144,33 +144,32 @@ func (r *Repository) convertToSessions(ctx context.Context, sessionList []*sessi
PublicId: sv.PublicId,
UserId: sv.UserId,
HostId: sv.HostId,
TargetId: sv.TargetId,
HostSetId: sv.HostSetId,
TargetId: sv.TargetId,
AuthTokenId: sv.AuthTokenId,
ProjectId: sv.ProjectId,
Certificate: sv.Certificate,
CtCertificatePrivateKey: nil, // CtCertificatePrivateKey should not be returned in lists
CertificatePrivateKey: nil, // CertificatePrivateKey should not be returned in lists
ExpirationTime: sv.ExpirationTime,
CtTofuToken: nil, // CtTofuToken should not be returned in lists
TofuToken: nil, // TofuToken should not be returned in lists
TerminationReason: sv.TerminationReason,
CreateTime: sv.CreateTime,
UpdateTime: sv.UpdateTime,
Version: sv.Version,
Endpoint: sv.Endpoint,
ConnectionLimit: sv.ConnectionLimit,
KeyId: "", // KeyId should not be returned in lists
CtCertificatePrivateKey: nil, // CtCertificatePrivateKey should not be returned in lists
CertificatePrivateKey: nil, // CertificatePrivateKey should not be returned in lists
CtTofuToken: nil, // CtTofuToken should not be returned in lists
TofuToken: nil, // TofuToken should not be returned in lists
KeyId: "", // KeyId should not be returned in lists
}
}
if _, ok := states[sv.EndTime]; !ok {
states[sv.EndTime] = &State{
SessionId: sv.PublicId,
Status: Status(sv.Status),
PreviousEndTime: sv.PreviousEndTime,
StartTime: sv.StartTime,
EndTime: sv.EndTime,
SessionId: sv.PublicId,
Status: Status(sv.Status),
StartTime: sv.StartTime,
EndTime: sv.EndTime,
}
}

@ -212,7 +212,7 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt ..
if err := read.LookupById(ctx, &session); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId)))
}
states, err := fetchStates(ctx, read, sessionId, db.WithOrder("start_time desc"))
states, err := fetchStates(ctx, read, sessionId)
if err != nil {
return errors.Wrap(ctx, err, op)
}
@ -618,7 +618,7 @@ func (r *Repository) fetchActivatedSessionStatesTx(ctx context.Context, reader d
var txErr error
var returnedStates []*State
returnedStates, txErr = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc"))
returnedStates, txErr = fetchStates(ctx, reader, sessionId)
if txErr != nil {
return nil, errors.Wrap(ctx, txErr, op)
}
@ -823,7 +823,7 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV
if rowsAffected != 0 && rowsAffected != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated session %s to state %s and %d rows inserted (should be 0 or 1)", sessionId, s.String(), rowsAffected))
}
returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc"))
returnedStates, err = fetchStates(ctx, reader, sessionId)
if err != nil {
return errors.Wrap(ctx, err, op)
}
@ -854,7 +854,7 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV
// non-active state, i.e. "canceling" or "terminated" It returns a *StateReport
// object for each session that is not active, with its current status.
func (r *Repository) CheckIfNotActive(ctx context.Context, reportedSessions []string) ([]*StateReport, error) {
const op = "session.(Repository).listSessionIdAndState"
const op = "session.(Repository).CheckIfNotActive"
notActive := make([]*StateReport, 0, len(reportedSessions))
if len(reportedSessions) <= 0 {
@ -872,7 +872,7 @@ func (r *Repository) CheckIfNotActive(ctx context.Context, reportedSessions []st
db.ExpBackoff{},
func(reader db.Reader, _ db.Writer) error {
var states []*State
err := reader.SearchWhere(ctx, &states, "end_time is null and session_id in (?)", []any{reportedSessions})
err := reader.SearchWhere(ctx, &states, "upper(active_time_range) is null and session_id in (?)", []any{reportedSessions})
if err != nil {
return errors.Wrap(ctx, err, op)
}
@ -926,9 +926,16 @@ func (r *Repository) deleteSessionsTerminatedBefore(ctx context.Context, thresho
func fetchStates(ctx context.Context, r db.Reader, sessionId string, opt ...db.Option) ([]*State, error) {
const op = "session.fetchStates"
var states []*State
if err := r.SearchWhere(ctx, &states, "session_id = ?", []any{sessionId}, opt...); err != nil {
rows, err := r.Query(ctx, selectStates, []any{sessionId}, opt...)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
defer rows.Close()
for rows.Next() {
if err := r.ScanRows(ctx, rows, &states); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
}
if len(states) == 0 {
return nil, nil
}

@ -575,34 +575,27 @@ func (s *Session) decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
}
type sessionListView struct {
// Session fields
PublicId string `json:"public_id,omitempty" gorm:"primary_key"`
UserId string `json:"user_id,omitempty" gorm:"default:null"`
HostId string `json:"host_id,omitempty" gorm:"default:null"`
TargetId string `json:"target_id,omitempty" gorm:"default:null"`
HostSetId string `json:"host_set_id,omitempty" gorm:"default:null"`
AuthTokenId string `json:"auth_token_id,omitempty" gorm:"default:null"`
ProjectId string `json:"project_id,omitempty" gorm:"default:null"`
Certificate []byte `json:"certificate,omitempty" gorm:"default:null"`
CtCertificatePrivateKey []byte `json:"ct_certificate_private_key,omitempty" gorm:"column:certificate_private_key;default:null" wrapping:"ct,certificate_private_key"`
CertificatePrivateKey []byte `json:"certificate_private_key,omitempty" gorm:"-" wrapping:"pt,certificate_private_key"`
ExpirationTime *timestamp.Timestamp `json:"expiration_time,omitempty" gorm:"default:null"`
CtTofuToken []byte `json:"ct_tofu_token,omitempty" gorm:"column:tofu_token;default:null" wrapping:"ct,tofu_token"`
TofuToken []byte `json:"tofu_token,omitempty" gorm:"-" wrapping:"pt,tofu_token"`
TerminationReason string `json:"termination_reason,omitempty" gorm:"default:null"`
CreateTime *timestamp.Timestamp `json:"create_time,omitempty" gorm:"default:current_timestamp"`
UpdateTime *timestamp.Timestamp `json:"update_time,omitempty" gorm:"default:current_timestamp"`
Version uint32 `json:"version,omitempty" gorm:"default:null"`
Endpoint string `json:"-" gorm:"default:null"`
ConnectionLimit int32 `json:"connection_limit,omitempty" gorm:"default:null"`
KeyId string `json:"key_id,omitempty" gorm:"default:null"`
ProtocolWorkerId string `json:"protocol_worker_id,omitempty" gorm:"default:null"`
// Session fields, we omit some fields that are not included when listing sessions.
PublicId string `gorm:"primary_key"`
UserId string `gorm:"default:null"`
HostId string `gorm:"default:null"`
HostSetId string `gorm:"default:null"`
TargetId string `gorm:"default:null"`
AuthTokenId string `gorm:"default:null"`
ProjectId string `gorm:"default:null"`
Certificate []byte `gorm:"default:null"`
ExpirationTime *timestamp.Timestamp `gorm:"default:null"`
TerminationReason string `gorm:"default:null"`
CreateTime *timestamp.Timestamp `gorm:"default:current_timestamp"`
UpdateTime *timestamp.Timestamp `gorm:"default:current_timestamp"`
Version uint32 `gorm:"default:null"`
Endpoint string `gorm:"default:null"`
ConnectionLimit int32 `gorm:"default:null"`
// State fields
Status string `json:"state,omitempty" gorm:"column:state"`
PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"`
StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"`
EndTime *timestamp.Timestamp `json:"end_time,omitempty" gorm:"default:current_timestamp"`
Status string `gorm:"column:state"`
StartTime *timestamp.Timestamp `gorm:"column:start_time"`
EndTime *timestamp.Timestamp `gorm:"column:end_time"`
}
// TableName returns the tablename to override the default gorm table name

@ -54,8 +54,6 @@ type State struct {
SessionId string `json:"session_id,omitempty" gorm:"primary_key"`
// status of the session
Status Status `json:"status,omitempty" gorm:"column:state"`
// PreviousEndTime from the RDBMS
PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"`
// StartTime from the RDBMS
StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"`
// EndTime from the RDBMS
@ -95,15 +93,6 @@ func (s *State) Clone() any {
SessionId: s.SessionId,
Status: s.Status,
}
if s.PreviousEndTime != nil {
clone.PreviousEndTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: s.PreviousEndTime.Timestamp.Seconds,
Nanos: s.PreviousEndTime.Timestamp.Nanos,
},
}
}
if s.StartTime != nil {
clone.StartTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
@ -163,8 +152,5 @@ func (s *State) validate(ctx context.Context) error {
if s.EndTime != nil {
return errors.New(ctx, errors.InvalidParameter, op, "end time is not settable")
}
if s.PreviousEndTime != nil {
return errors.New(ctx, errors.InvalidParameter, op, "previous end time is not settable")
}
return nil
}

@ -4,160 +4,14 @@
package session
import (
"context"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestState_Create(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
session := TestDefaultSession(t, conn, wrapper, iamRepo)
type args struct {
sessionId string
status Status
}
tests := []struct {
name string
args args
want *State
wantErr bool
wantIsErr errors.Code
create bool
wantCreateErr bool
}{
{
name: "valid",
args: args{
sessionId: session.PublicId,
status: StatusActive,
},
want: &State{
SessionId: session.PublicId,
Status: StatusActive,
},
create: true,
},
{
name: "empty-sessionId",
args: args{
status: StatusPending,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
},
{
name: "empty-status",
args: args{
sessionId: session.PublicId,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
got, err := NewState(context.Background(), tt.args.sessionId, tt.args.status)
if tt.wantErr {
require.Error(err)
assert.True(errors.Match(errors.T(tt.wantIsErr), err))
return
}
require.NoError(err)
assert.Equal(tt.want, got)
if tt.create {
err = db.New(conn).Create(context.Background(), got)
if tt.wantCreateErr {
assert.Error(err)
return
} else {
assert.NoError(err)
}
}
})
}
}
func TestState_Delete(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
session := TestDefaultSession(t, conn, wrapper, iamRepo)
session2 := TestDefaultSession(t, conn, wrapper, iamRepo)
tests := []struct {
name string
state *State
deleteStateId string
wantRowsDeleted int
wantErr bool
wantErrMsg string
}{
{
name: "valid",
state: TestState(t, conn, session.PublicId, StatusTerminated),
wantErr: false,
wantRowsDeleted: 1,
},
{
name: "bad-id",
state: TestState(t, conn, session2.PublicId, StatusTerminated),
deleteStateId: func() string {
id, err := db.NewPublicId(ctx, StatePrefix)
require.NoError(t, err)
return id
}(),
wantErr: false,
wantRowsDeleted: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var initialState State
err := rw.LookupWhere(ctx, &initialState, "session_id = ? and state = ?", []any{tt.state.SessionId, tt.state.Status})
require.NoError(err)
deleteState := allocState()
if tt.deleteStateId != "" {
deleteState.SessionId = tt.deleteStateId
} else {
deleteState.SessionId = tt.state.SessionId
}
deleteState.StartTime = initialState.StartTime
deletedRows, err := rw.Delete(ctx, &deleteState)
if tt.wantErr {
require.Error(err)
return
}
require.NoError(err)
if tt.wantRowsDeleted == 0 {
assert.Equal(tt.wantRowsDeleted, deletedRows)
return
}
assert.Equal(tt.wantRowsDeleted, deletedRows)
foundState := allocState()
err = rw.LookupWhere(ctx, &foundState, "session_id = ? and start_time = ?", []any{tt.state.SessionId, initialState.StartTime})
require.Error(err)
assert.True(errors.IsNotFoundError(err))
})
}
}
func TestState_Clone(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")

@ -46,13 +46,23 @@ func TestConnection(t testing.TB, conn *db.DB, sessionId, clientTcpAddr string,
// TestState creates a test state for the sessionId in the repository.
func TestState(t testing.TB, conn *db.DB, sessionId string, state Status) *State {
const insertSessionState = `
insert into session_state (session_id, state, active_time_range)
values ($1, $2, tstzrange($3, null, '[]'))
returning lower(active_time_range) as start_time
;`
t.Helper()
require := require.New(t)
rw := db.New(conn)
s, err := NewState(context.Background(), sessionId, state)
require.NoError(err)
err = rw.Create(context.Background(), s)
rows, err := rw.Query(context.Background(), insertSessionState, []any{s.SessionId, s.Status, s.StartTime})
require.NoError(err)
defer rows.Close()
for rows.Next() {
err := rows.Scan(&s.StartTime)
require.NoError(err)
}
return s
}
@ -142,7 +152,7 @@ func TestSession(t testing.TB, conn *db.DB, rootWrapper wrapping.Wrapper, c Comp
require.NoError(err)
}
ss, err := fetchStates(ctx, rw, s.PublicId, append(opts.withDbOpts, db.WithOrder("start_time desc"))...)
ss, err := fetchStates(ctx, rw, s.PublicId, opts.withDbOpts...)
require.NoError(err)
s.States = ss

Loading…
Cancel
Save