From 37bf59243a7a67fdb19ce942cde425301ed8eb12 Mon Sep 17 00:00:00 2001 From: Timothy Messier Date: Mon, 30 Sep 2024 14:39:19 -0400 Subject: [PATCH] refact(session): Replace state start_time/end_time with time range (#5134) This replaces the start_time, end_time, and previous_end_time columns on the session_state table with a time range column that records when the state was active. This allowed for dropping several constraints on the session_state table that were used to ensure that states for a session did not overlap. Instead this can be ensured with a single exclusion constraint. By reducing the number of constraints, it improves the performance for a number of queries involved with session state. Most notably when resolving delete cascades of resources related to sessions, such as targets, where as part of the delete transaction large numbers of sessions get canceled. Similarly, when sessions get deleted, and the corresponding session_state rows are deleted, these constraints all must be checked during the transaction. The domain layer continues to expose StartTime and EndTime fields on session.State. However it drops PreviousEndTime, which was never used by anything calling the domain layer. See: https://www.postgresql.org/docs/current/ddl-constraints.html#DDL-CONSTRAINTS-EXCLUSION https://www.postgresql.org/docs/current/rangetypes.html --- .../oss/postgres/0/50_session.up.sql | 1 + .../15/01_wh_rename_key_columns.up.sql | 1 + .../28/02_prior_session_trigger.up.sql | 1 + .../72/03_session_list_perf_fix.up.sql | 1 + .../84/02_wh_upsert_user_refact.up.sql | 1 + .../91/06_session_state_tstzrange.up.sql | 178 ++++++++++++++++++ internal/session/immutable_fields_test.go | 62 +++--- internal/session/query.go | 18 +- internal/session/repository.go | 21 +-- internal/session/repository_session.go | 19 +- internal/session/session.go | 45 ++--- internal/session/state.go | 14 -- internal/session/state_test.go | 146 -------------- internal/session/testing.go | 14 +- 14 files changed, 283 insertions(+), 239 deletions(-) create mode 100644 internal/db/schema/migrations/oss/postgres/91/06_session_state_tstzrange.up.sql diff --git a/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql b/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql index 62db5118df..1c2b2f5664 100644 --- a/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql +++ b/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql @@ -395,6 +395,7 @@ begin; references session_state (session_id, end_time) ); + -- Replaced in 91/06_session_state_tstzrange.up.sql create trigger immutable_columns before update on session_state for each row execute procedure immutable_columns('session_id', 'state', 'start_time', 'previous_end_time'); diff --git a/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql b/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql index 16b9a8e67f..7ecc4a146e 100644 --- a/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql +++ b/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql @@ -397,6 +397,7 @@ begin; end; $$ language plpgsql; + -- Replaced in 91/06_session_state_tstzrange.up.sql create trigger wh_insert_session_state after insert on session_state for each row execute function wh_insert_session_state(); diff --git a/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql b/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql index 831728b3eb..72f82fac8e 100644 --- a/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql +++ b/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql @@ -77,6 +77,7 @@ begin end; $$; -- Replaces trigger from 0/50_session.up.sql +-- Replaced in 91/06_session_state_tstzrange.up.sql -- Update insert session state transition trigger drop trigger insert_session_state on session_state; drop function insert_session_state(); diff --git a/internal/db/schema/migrations/oss/postgres/72/03_session_list_perf_fix.up.sql b/internal/db/schema/migrations/oss/postgres/72/03_session_list_perf_fix.up.sql index fcd4a2d1c4..809fe54a11 100644 --- a/internal/db/schema/migrations/oss/postgres/72/03_session_list_perf_fix.up.sql +++ b/internal/db/schema/migrations/oss/postgres/72/03_session_list_perf_fix.up.sql @@ -4,6 +4,7 @@ begin; -- Replaces the view created in 69/02_session_worker_protocol.up.sql + -- Replaced in 91/06_session_state_tstzrange.up.sql drop view session_list; create view session_list as select s.public_id, diff --git a/internal/db/schema/migrations/oss/postgres/84/02_wh_upsert_user_refact.up.sql b/internal/db/schema/migrations/oss/postgres/84/02_wh_upsert_user_refact.up.sql index 4f52d710ba..a55e0aeed9 100644 --- a/internal/db/schema/migrations/oss/postgres/84/02_wh_upsert_user_refact.up.sql +++ b/internal/db/schema/migrations/oss/postgres/84/02_wh_upsert_user_refact.up.sql @@ -69,6 +69,7 @@ begin; 'for the user that corresponds to the provided auth_token_id.'; -- Replaces function from 60/03_wh_sessions.up.sql + -- Replaced in 91/06_session_state_tstzrange.up.sql create function wh_insert_session() returns trigger as $$ declare diff --git a/internal/db/schema/migrations/oss/postgres/91/06_session_state_tstzrange.up.sql b/internal/db/schema/migrations/oss/postgres/91/06_session_state_tstzrange.up.sql new file mode 100644 index 0000000000..a0bf087153 --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/91/06_session_state_tstzrange.up.sql @@ -0,0 +1,178 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +begin; + -- Add new active_time_range column that will replace two start_time, end_time columns. + -- Also drop a number of constraints on the start_time, end_time columns. This will allow + -- from dropping these columns after the new column has been set with the correct data. + alter table session_state + add column active_time_range tstzrange not null default tstzrange(now(), null, '[]'), + drop constraint end_times_in_sequence, + drop constraint previous_end_time_and_start_time_in_sequence, + drop constraint start_and_end_times_in_sequence, + drop constraint session_state_session_id_previous_end_time_fkey; + + -- Set the new active_time_range column for any existing rows using start_time and end_time. + update session_state + set active_time_range = tstzrange(start_time, end_time, '[)'); + + -- Replaces view from 72/03/session_list_perf_fix.up.sql + -- Switch view to tuse the new column. This also eliminates the previous_end_time column + -- from the view, since it also will be dropped. + drop view session_list; + create view session_list as + select s.public_id, + s.user_id, + shsh.host_id, + shsh.host_set_id, + s.target_id, + s.auth_token_id, + s.project_id, + s.certificate, + s.expiration_time, + s.termination_reason, + s.create_time, + s.update_time, + s.version, + s.endpoint, + s.connection_limit, + ss.state, + lower(ss.active_time_range) as start_time, + upper(ss.active_time_range) as end_time + from session s + join session_state ss on s.public_id = ss.session_id + left join session_host_set_host shsh on s.public_id = shsh.session_id; + + -- Now we can finally drop the old columns and add a constraint on the new column + -- that ensures there are no overlaps on the active_time_range for a given session. + alter table session_state + drop column start_time, + drop column end_time, + drop column previous_end_time, + add constraint session_state_active_time_range_excl + exclude using gist (session_id with =, + active_time_range with &&), + add constraint active_time_range_not_empty + check (not isempty(active_time_range)); + + -- There are still a number of functions that reference the old columns. + -- These all need to be updated to use the new column instead. + + -- Replaces trigger from 0/50_session.up.sql + drop trigger immutable_columns on session_state; + create trigger immutable_columns before update on session_state + for each row execute procedure immutable_columns('session_id', 'state'); + + -- Replaces function from 28/02_prior_session_trigger.up.sql + drop trigger insert_session_state on session_state; + drop function insert_session_state(); + create function insert_session_state() returns trigger + as $$ + declare + old_col_state text; + begin + update session_state + set active_time_range = tstzrange(lower(active_time_range), now(), '[)') + where session_id = new.session_id + and upper(active_time_range) is null + returning state + into old_col_state; + + if not found then + new.prior_state = 'pending'; + else + new.prior_state = old_col_state; + end if; + + new.active_time_range = tstzrange(now(), null, '[]'); + + return new; + end; + $$ language plpgsql; + + create trigger insert_session_state before insert on session_state + for each row execute procedure insert_session_state(); + + -- Replaces function from 84/02_wh_upsert_user_refact.up.sql + drop trigger wh_insert_session on session; + drop function wh_insert_session; + create function wh_insert_session() returns trigger + as $$ + declare + new_row wh_session_accumulating_fact%rowtype; + begin + with + pending_timestamp (date_dim_key, time_dim_key, ts) as ( + select wh_date_key(lower(active_time_range)), wh_time_key(lower(active_time_range)), lower(active_time_range) + from session_state + where session_id = new.public_id + and state = 'pending' + ) + insert into wh_session_accumulating_fact ( + session_id, + auth_token_id, + host_key, + user_key, + credential_group_key, + session_pending_date_key, + session_pending_time_key, + session_pending_time + ) + select new.public_id, + new.auth_token_id, + 'no host source', -- will be updated by wh_upsert_host + wh_upsert_user(new.auth_token_id), + 'no credentials', -- will be updated by wh_upsert_credential_group + pending_timestamp.date_dim_key, + pending_timestamp.time_dim_key, + pending_timestamp.ts + from pending_timestamp + returning * into strict new_row; + return null; + end; + $$ language plpgsql; + + create trigger wh_insert_session after insert on session + for each row execute procedure wh_insert_session(); + + -- Replaces function from 15/01_wh_rename_key_columns.up.sql + drop trigger wh_insert_session_state on session_state; + drop function wh_insert_session_state; + + create function wh_insert_session_state() returns trigger + as $$ + declare + date_col text; + time_col text; + ts_col text; + q text; + session_row wh_session_accumulating_fact%rowtype; + begin + if new.state = 'pending' then + -- The pending state is the first state which is handled by the + -- wh_insert_session trigger. The update statement in this trigger will + -- fail for the pending state because the row for the session has not yet + -- been inserted into the wh_session_accumulating_fact table. + return null; + end if; + + date_col = 'session_' || new.state || '_date_key'; + time_col = 'session_' || new.state || '_time_key'; + ts_col = 'session_' || new.state || '_time'; + + q = format(' update wh_session_accumulating_fact + set (%I, %I, %I) = (select wh_date_key(%L), wh_time_key(%L), %L::timestamptz) + where session_id = %L + returning *', + date_col, time_col, ts_col, + lower(new.active_time_range), lower(new.active_time_range), lower(new.active_time_range), + new.session_id); + execute q into strict session_row; + + return null; + end; + $$ language plpgsql; + + create trigger wh_insert_session_state after insert on session_state + for each row execute function wh_insert_session_state(); +commit; diff --git a/internal/session/immutable_fields_test.go b/internal/session/immutable_fields_test.go index d5efdbe497..4082f5c238 100644 --- a/internal/session/immutable_fields_test.go +++ b/internal/session/immutable_fields_test.go @@ -5,6 +5,7 @@ package session import ( "context" + "fmt" "testing" "github.com/hashicorp/boundary/internal/db" @@ -97,15 +98,35 @@ func TestState_ImmutableFields(t *testing.T) { wrapper := db.TestWrapper(t) iamRepo := iam.TestRepo(t, conn, wrapper) - ts := timestamp.Timestamp{Timestamp: ×tamppb.Timestamp{Seconds: 0, Nanos: 0}} - _, _ = iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) session := TestDefaultSession(t, conn, wrapper, iamRepo) state := TestState(t, conn, session.PublicId, StatusActive) - var new State - err := rw.LookupWhere(context.Background(), &new, "session_id = ? and state = ?", []any{state.SessionId, state.Status}) - require.NoError(t, err) + fetchSession := func(ctx context.Context, rw *db.Db, sessionId string, startTime *timestamp.Timestamp) (*State, error) { + const selectQuery = ` +select session_id, + state, + lower(active_time_range) as start_time, + upper(active_time_range) as end_time + from session_state + where session_id = ? + and lower(active_time_range) = ?;` + var states []*State + rows, err := rw.Query(ctx, selectQuery, []any{sessionId, startTime}) + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + if err := rw.ScanRows(ctx, rows, &states); err != nil { + return nil, err + } + } + if len(states) != 1 { + return nil, fmt.Errorf("found %d states, expected 1", len(states)) + } + return states[0], nil + } tests := []struct { name string @@ -115,7 +136,7 @@ func TestState_ImmutableFields(t *testing.T) { { name: "session_id", update: func() *State { - s := new.Clone().(*State) + s := state.Clone().(*State) s.SessionId = "s_thisIsNotAValidId" return s }(), @@ -124,47 +145,28 @@ func TestState_ImmutableFields(t *testing.T) { { name: "status", update: func() *State { - s := new.Clone().(*State) + s := state.Clone().(*State) s.Status = "canceling" return s }(), fieldMask: []string{"Status"}, }, - { - name: "start time", - update: func() *State { - s := new.Clone().(*State) - s.StartTime = &ts - return s - }(), - fieldMask: []string{"StartTime"}, - }, - { - name: "previous_end_time", - update: func() *State { - s := new.Clone().(*State) - s.PreviousEndTime = &ts - return s - }(), - fieldMask: []string{"PreviousEndTime"}, - }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() assert, require := assert.New(t), require.New(t) - orig := new.Clone() - err := rw.LookupWhere(context.Background(), orig, "session_id = ? and start_time = ?", []any{new.SessionId, new.StartTime}) + orig, err := fetchSession(ctx, rw, state.SessionId, state.StartTime) require.NoError(err) rowsUpdated, err := rw.Update(context.Background(), tt.update, tt.fieldMask, nil, db.WithSkipVetForWrite(true)) require.Error(err) assert.Equal(0, rowsUpdated) - after := new.Clone() - err = rw.LookupWhere(context.Background(), after, "session_id = ? and start_time = ?", []any{new.SessionId, new.StartTime}) + after, err := fetchSession(ctx, rw, state.SessionId, state.StartTime) require.NoError(err) - assert.Equal(orig.(*State), after) + assert.Equal(orig, after) }) } } diff --git a/internal/session/query.go b/internal/session/query.go index da18eef887..c743092d98 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -90,7 +90,7 @@ active_session as ( where ss.session_id in (select * from unexpired_session) and ss.state = 'active' and - ss.end_time is null + upper(ss.active_time_range) is null ) insert into session_connection ( session_id, @@ -150,7 +150,7 @@ from where ss.session_id = @public_id and ss.state = 'canceling' and - ss.end_time is null + upper(ss.active_time_range) is null ) update session us set version = version +1, @@ -226,7 +226,7 @@ with canceling_session(session_id) as session_state ss where ss.state = 'canceling' and - ss.end_time is null + upper(ss.active_time_range) is null ) update session us set termination_reason = @@ -371,7 +371,7 @@ where and session_state.state = 'terminated' and - session_state.start_time < wt_sub_seconds_from_now(@threshold_seconds) + lower(session_state.active_time_range) < wt_sub_seconds_from_now(@threshold_seconds) ; ` sessionCredentialRewrapQuery = ` @@ -451,6 +451,16 @@ order by update_time desc, public_id desc; ` estimateCountSessions = ` select reltuples::bigint as estimate from pg_class where oid in ('session'::regclass) +` + + selectStates = ` + select session_id, + state, + lower(active_time_range) as start_time, + upper(active_time_range) as end_time + from session_state + where session_id = ? +order by active_time_range desc; ` ) diff --git a/internal/session/repository.go b/internal/session/repository.go index b7e98b56ef..79a266eb2c 100644 --- a/internal/session/repository.go +++ b/internal/session/repository.go @@ -144,33 +144,32 @@ func (r *Repository) convertToSessions(ctx context.Context, sessionList []*sessi PublicId: sv.PublicId, UserId: sv.UserId, HostId: sv.HostId, - TargetId: sv.TargetId, HostSetId: sv.HostSetId, + TargetId: sv.TargetId, AuthTokenId: sv.AuthTokenId, ProjectId: sv.ProjectId, Certificate: sv.Certificate, - CtCertificatePrivateKey: nil, // CtCertificatePrivateKey should not be returned in lists - CertificatePrivateKey: nil, // CertificatePrivateKey should not be returned in lists ExpirationTime: sv.ExpirationTime, - CtTofuToken: nil, // CtTofuToken should not be returned in lists - TofuToken: nil, // TofuToken should not be returned in lists TerminationReason: sv.TerminationReason, CreateTime: sv.CreateTime, UpdateTime: sv.UpdateTime, Version: sv.Version, Endpoint: sv.Endpoint, ConnectionLimit: sv.ConnectionLimit, - KeyId: "", // KeyId should not be returned in lists + CtCertificatePrivateKey: nil, // CtCertificatePrivateKey should not be returned in lists + CertificatePrivateKey: nil, // CertificatePrivateKey should not be returned in lists + CtTofuToken: nil, // CtTofuToken should not be returned in lists + TofuToken: nil, // TofuToken should not be returned in lists + KeyId: "", // KeyId should not be returned in lists } } if _, ok := states[sv.EndTime]; !ok { states[sv.EndTime] = &State{ - SessionId: sv.PublicId, - Status: Status(sv.Status), - PreviousEndTime: sv.PreviousEndTime, - StartTime: sv.StartTime, - EndTime: sv.EndTime, + SessionId: sv.PublicId, + Status: Status(sv.Status), + StartTime: sv.StartTime, + EndTime: sv.EndTime, } } diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 5cff87003b..864db5c26d 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -212,7 +212,7 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt .. if err := read.LookupById(ctx, &session); err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId))) } - states, err := fetchStates(ctx, read, sessionId, db.WithOrder("start_time desc")) + states, err := fetchStates(ctx, read, sessionId) if err != nil { return errors.Wrap(ctx, err, op) } @@ -618,7 +618,7 @@ func (r *Repository) fetchActivatedSessionStatesTx(ctx context.Context, reader d var txErr error var returnedStates []*State - returnedStates, txErr = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc")) + returnedStates, txErr = fetchStates(ctx, reader, sessionId) if txErr != nil { return nil, errors.Wrap(ctx, txErr, op) } @@ -823,7 +823,7 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV if rowsAffected != 0 && rowsAffected != 1 { return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated session %s to state %s and %d rows inserted (should be 0 or 1)", sessionId, s.String(), rowsAffected)) } - returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc")) + returnedStates, err = fetchStates(ctx, reader, sessionId) if err != nil { return errors.Wrap(ctx, err, op) } @@ -854,7 +854,7 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV // non-active state, i.e. "canceling" or "terminated" It returns a *StateReport // object for each session that is not active, with its current status. func (r *Repository) CheckIfNotActive(ctx context.Context, reportedSessions []string) ([]*StateReport, error) { - const op = "session.(Repository).listSessionIdAndState" + const op = "session.(Repository).CheckIfNotActive" notActive := make([]*StateReport, 0, len(reportedSessions)) if len(reportedSessions) <= 0 { @@ -872,7 +872,7 @@ func (r *Repository) CheckIfNotActive(ctx context.Context, reportedSessions []st db.ExpBackoff{}, func(reader db.Reader, _ db.Writer) error { var states []*State - err := reader.SearchWhere(ctx, &states, "end_time is null and session_id in (?)", []any{reportedSessions}) + err := reader.SearchWhere(ctx, &states, "upper(active_time_range) is null and session_id in (?)", []any{reportedSessions}) if err != nil { return errors.Wrap(ctx, err, op) } @@ -926,9 +926,16 @@ func (r *Repository) deleteSessionsTerminatedBefore(ctx context.Context, thresho func fetchStates(ctx context.Context, r db.Reader, sessionId string, opt ...db.Option) ([]*State, error) { const op = "session.fetchStates" var states []*State - if err := r.SearchWhere(ctx, &states, "session_id = ?", []any{sessionId}, opt...); err != nil { + rows, err := r.Query(ctx, selectStates, []any{sessionId}, opt...) + if err != nil { return nil, errors.Wrap(ctx, err, op) } + defer rows.Close() + for rows.Next() { + if err := r.ScanRows(ctx, rows, &states); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + } if len(states) == 0 { return nil, nil } diff --git a/internal/session/session.go b/internal/session/session.go index 063b460eac..6cbf75fb60 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -575,34 +575,27 @@ func (s *Session) decrypt(ctx context.Context, cipher wrapping.Wrapper) error { } type sessionListView struct { - // Session fields - PublicId string `json:"public_id,omitempty" gorm:"primary_key"` - UserId string `json:"user_id,omitempty" gorm:"default:null"` - HostId string `json:"host_id,omitempty" gorm:"default:null"` - TargetId string `json:"target_id,omitempty" gorm:"default:null"` - HostSetId string `json:"host_set_id,omitempty" gorm:"default:null"` - AuthTokenId string `json:"auth_token_id,omitempty" gorm:"default:null"` - ProjectId string `json:"project_id,omitempty" gorm:"default:null"` - Certificate []byte `json:"certificate,omitempty" gorm:"default:null"` - CtCertificatePrivateKey []byte `json:"ct_certificate_private_key,omitempty" gorm:"column:certificate_private_key;default:null" wrapping:"ct,certificate_private_key"` - CertificatePrivateKey []byte `json:"certificate_private_key,omitempty" gorm:"-" wrapping:"pt,certificate_private_key"` - ExpirationTime *timestamp.Timestamp `json:"expiration_time,omitempty" gorm:"default:null"` - CtTofuToken []byte `json:"ct_tofu_token,omitempty" gorm:"column:tofu_token;default:null" wrapping:"ct,tofu_token"` - TofuToken []byte `json:"tofu_token,omitempty" gorm:"-" wrapping:"pt,tofu_token"` - TerminationReason string `json:"termination_reason,omitempty" gorm:"default:null"` - CreateTime *timestamp.Timestamp `json:"create_time,omitempty" gorm:"default:current_timestamp"` - UpdateTime *timestamp.Timestamp `json:"update_time,omitempty" gorm:"default:current_timestamp"` - Version uint32 `json:"version,omitempty" gorm:"default:null"` - Endpoint string `json:"-" gorm:"default:null"` - ConnectionLimit int32 `json:"connection_limit,omitempty" gorm:"default:null"` - KeyId string `json:"key_id,omitempty" gorm:"default:null"` - ProtocolWorkerId string `json:"protocol_worker_id,omitempty" gorm:"default:null"` + // Session fields, we omit some fields that are not included when listing sessions. + PublicId string `gorm:"primary_key"` + UserId string `gorm:"default:null"` + HostId string `gorm:"default:null"` + HostSetId string `gorm:"default:null"` + TargetId string `gorm:"default:null"` + AuthTokenId string `gorm:"default:null"` + ProjectId string `gorm:"default:null"` + Certificate []byte `gorm:"default:null"` + ExpirationTime *timestamp.Timestamp `gorm:"default:null"` + TerminationReason string `gorm:"default:null"` + CreateTime *timestamp.Timestamp `gorm:"default:current_timestamp"` + UpdateTime *timestamp.Timestamp `gorm:"default:current_timestamp"` + Version uint32 `gorm:"default:null"` + Endpoint string `gorm:"default:null"` + ConnectionLimit int32 `gorm:"default:null"` // State fields - Status string `json:"state,omitempty" gorm:"column:state"` - PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"` - StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"` - EndTime *timestamp.Timestamp `json:"end_time,omitempty" gorm:"default:current_timestamp"` + Status string `gorm:"column:state"` + StartTime *timestamp.Timestamp `gorm:"column:start_time"` + EndTime *timestamp.Timestamp `gorm:"column:end_time"` } // TableName returns the tablename to override the default gorm table name diff --git a/internal/session/state.go b/internal/session/state.go index 445994206c..a7b0cfac6e 100644 --- a/internal/session/state.go +++ b/internal/session/state.go @@ -54,8 +54,6 @@ type State struct { SessionId string `json:"session_id,omitempty" gorm:"primary_key"` // status of the session Status Status `json:"status,omitempty" gorm:"column:state"` - // PreviousEndTime from the RDBMS - PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"` // StartTime from the RDBMS StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"` // EndTime from the RDBMS @@ -95,15 +93,6 @@ func (s *State) Clone() any { SessionId: s.SessionId, Status: s.Status, } - if s.PreviousEndTime != nil { - clone.PreviousEndTime = ×tamp.Timestamp{ - Timestamp: ×tamppb.Timestamp{ - Seconds: s.PreviousEndTime.Timestamp.Seconds, - Nanos: s.PreviousEndTime.Timestamp.Nanos, - }, - } - } - if s.StartTime != nil { clone.StartTime = ×tamp.Timestamp{ Timestamp: ×tamppb.Timestamp{ @@ -163,8 +152,5 @@ func (s *State) validate(ctx context.Context) error { if s.EndTime != nil { return errors.New(ctx, errors.InvalidParameter, op, "end time is not settable") } - if s.PreviousEndTime != nil { - return errors.New(ctx, errors.InvalidParameter, op, "previous end time is not settable") - } return nil } diff --git a/internal/session/state_test.go b/internal/session/state_test.go index 25e73f1df4..29a9a86b45 100644 --- a/internal/session/state_test.go +++ b/internal/session/state_test.go @@ -4,160 +4,14 @@ package session import ( - "context" "testing" "github.com/hashicorp/boundary/internal/db" - "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestState_Create(t *testing.T) { - t.Parallel() - conn, _ := db.TestSetup(t, "postgres") - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - session := TestDefaultSession(t, conn, wrapper, iamRepo) - - type args struct { - sessionId string - status Status - } - - tests := []struct { - name string - args args - want *State - wantErr bool - wantIsErr errors.Code - create bool - wantCreateErr bool - }{ - { - name: "valid", - args: args{ - sessionId: session.PublicId, - status: StatusActive, - }, - want: &State{ - SessionId: session.PublicId, - Status: StatusActive, - }, - create: true, - }, - { - name: "empty-sessionId", - args: args{ - status: StatusPending, - }, - wantErr: true, - wantIsErr: errors.InvalidParameter, - }, - { - name: "empty-status", - args: args{ - sessionId: session.PublicId, - }, - wantErr: true, - wantIsErr: errors.InvalidParameter, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - got, err := NewState(context.Background(), tt.args.sessionId, tt.args.status) - if tt.wantErr { - require.Error(err) - assert.True(errors.Match(errors.T(tt.wantIsErr), err)) - return - } - require.NoError(err) - assert.Equal(tt.want, got) - if tt.create { - err = db.New(conn).Create(context.Background(), got) - if tt.wantCreateErr { - assert.Error(err) - return - } else { - assert.NoError(err) - } - } - }) - } -} - -func TestState_Delete(t *testing.T) { - t.Parallel() - ctx := context.Background() - conn, _ := db.TestSetup(t, "postgres") - rw := db.New(conn) - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - session := TestDefaultSession(t, conn, wrapper, iamRepo) - session2 := TestDefaultSession(t, conn, wrapper, iamRepo) - - tests := []struct { - name string - state *State - deleteStateId string - wantRowsDeleted int - wantErr bool - wantErrMsg string - }{ - { - name: "valid", - state: TestState(t, conn, session.PublicId, StatusTerminated), - wantErr: false, - wantRowsDeleted: 1, - }, - { - name: "bad-id", - state: TestState(t, conn, session2.PublicId, StatusTerminated), - deleteStateId: func() string { - id, err := db.NewPublicId(ctx, StatePrefix) - require.NoError(t, err) - return id - }(), - wantErr: false, - wantRowsDeleted: 0, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - - var initialState State - err := rw.LookupWhere(ctx, &initialState, "session_id = ? and state = ?", []any{tt.state.SessionId, tt.state.Status}) - require.NoError(err) - - deleteState := allocState() - if tt.deleteStateId != "" { - deleteState.SessionId = tt.deleteStateId - } else { - deleteState.SessionId = tt.state.SessionId - } - deleteState.StartTime = initialState.StartTime - deletedRows, err := rw.Delete(ctx, &deleteState) - if tt.wantErr { - require.Error(err) - return - } - require.NoError(err) - if tt.wantRowsDeleted == 0 { - assert.Equal(tt.wantRowsDeleted, deletedRows) - return - } - assert.Equal(tt.wantRowsDeleted, deletedRows) - foundState := allocState() - err = rw.LookupWhere(ctx, &foundState, "session_id = ? and start_time = ?", []any{tt.state.SessionId, initialState.StartTime}) - require.Error(err) - assert.True(errors.IsNotFoundError(err)) - }) - } -} - func TestState_Clone(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") diff --git a/internal/session/testing.go b/internal/session/testing.go index 9bd1d60c17..e9237e9785 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -46,13 +46,23 @@ func TestConnection(t testing.TB, conn *db.DB, sessionId, clientTcpAddr string, // TestState creates a test state for the sessionId in the repository. func TestState(t testing.TB, conn *db.DB, sessionId string, state Status) *State { + const insertSessionState = ` +insert into session_state (session_id, state, active_time_range) + values ($1, $2, tstzrange($3, null, '[]')) + returning lower(active_time_range) as start_time +;` t.Helper() require := require.New(t) rw := db.New(conn) s, err := NewState(context.Background(), sessionId, state) require.NoError(err) - err = rw.Create(context.Background(), s) + rows, err := rw.Query(context.Background(), insertSessionState, []any{s.SessionId, s.Status, s.StartTime}) require.NoError(err) + defer rows.Close() + for rows.Next() { + err := rows.Scan(&s.StartTime) + require.NoError(err) + } return s } @@ -142,7 +152,7 @@ func TestSession(t testing.TB, conn *db.DB, rootWrapper wrapping.Wrapper, c Comp require.NoError(err) } - ss, err := fetchStates(ctx, rw, s.PublicId, append(opts.withDbOpts, db.WithOrder("start_time desc"))...) + ss, err := fetchStates(ctx, rw, s.PublicId, opts.withDbOpts...) require.NoError(err) s.States = ss