From 030b5fd4f07f73d1d8420de9021f2aae3a0ddf83 Mon Sep 17 00:00:00 2001 From: Matthew Hooker Date: Fri, 19 Jan 2018 19:44:01 -0800 Subject: [PATCH] WIP add context to state bag --- helper/multistep/statebag.go | 44 +++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/helper/multistep/statebag.go b/helper/multistep/statebag.go index dab712316..c2f1e8a84 100644 --- a/helper/multistep/statebag.go +++ b/helper/multistep/statebag.go @@ -1,15 +1,20 @@ package multistep import ( + "context" "sync" ) +// Add context to state bag to prevent changing step signature + // StateBag holds the state that is used by the Runner and Steps. The // StateBag implementation must be safe for concurrent access. type StateBag interface { Get(string) interface{} GetOk(string) (interface{}, bool) Put(string, interface{}) + Context() context.Context + WithContext(context.Context) StateBag } // BasicStateBag implements StateBag by using a normal map underneath @@ -17,7 +22,13 @@ type StateBag interface { type BasicStateBag struct { data map[string]interface{} l sync.RWMutex - once sync.Once + ctx context.Context +} + +func NewBasicStateBag() *BasicStateBag { + b := new(BasicStateBag) + b.data = make(map[string]interface{}) + return b } func (b *BasicStateBag) Get(k string) interface{} { @@ -37,11 +48,32 @@ func (b *BasicStateBag) Put(k string, v interface{}) { b.l.Lock() defer b.l.Unlock() - // Make sure the map is initialized one time, on write - b.once.Do(func() { - b.data = make(map[string]interface{}) - }) - // Write the data b.data[k] = v } + +func (b *BasicStateBag) Context() context.Context { + if b.ctx != nil { + return b.ctx + } + return context.Background() +} + +// WithContext returns a copy of BasicStateBag with the provided context +// We copy the state bag +func (b *BasicStateBag) WithContext(ctx context.Context) *BasicStateBag { + if ctx == nil { + panic("nil context") + } + // read lock because copying is a read operation + b.l.RLock() + defer b.l.RUnlock() + + b2 := NewBasicStateBag() + + for k, v := range b.data { + b2.data[k] = v + } + b2.ctx = ctx + return b2 +}