WIP add context to state bag

pull/5810/head
Matthew Hooker 8 years ago
parent 07a5af66f8
commit 030b5fd4f0
No known key found for this signature in database
GPG Key ID: 7B5F933D9CE8C6A1

@ -1,15 +1,20 @@
package multistep package multistep
import ( import (
"context"
"sync" "sync"
) )
// Add context to state bag to prevent changing step signature
// StateBag holds the state that is used by the Runner and Steps. The // StateBag holds the state that is used by the Runner and Steps. The
// StateBag implementation must be safe for concurrent access. // StateBag implementation must be safe for concurrent access.
type StateBag interface { type StateBag interface {
Get(string) interface{} Get(string) interface{}
GetOk(string) (interface{}, bool) GetOk(string) (interface{}, bool)
Put(string, interface{}) Put(string, interface{})
Context() context.Context
WithContext(context.Context) StateBag
} }
// BasicStateBag implements StateBag by using a normal map underneath // BasicStateBag implements StateBag by using a normal map underneath
@ -17,7 +22,13 @@ type StateBag interface {
type BasicStateBag struct { type BasicStateBag struct {
data map[string]interface{} data map[string]interface{}
l sync.RWMutex l sync.RWMutex
once sync.Once ctx context.Context
}
func NewBasicStateBag() *BasicStateBag {
b := new(BasicStateBag)
b.data = make(map[string]interface{})
return b
} }
func (b *BasicStateBag) Get(k string) interface{} { func (b *BasicStateBag) Get(k string) interface{} {
@ -37,11 +48,32 @@ func (b *BasicStateBag) Put(k string, v interface{}) {
b.l.Lock() b.l.Lock()
defer b.l.Unlock() defer b.l.Unlock()
// Make sure the map is initialized one time, on write
b.once.Do(func() {
b.data = make(map[string]interface{})
})
// Write the data // Write the data
b.data[k] = v b.data[k] = v
} }
func (b *BasicStateBag) Context() context.Context {
if b.ctx != nil {
return b.ctx
}
return context.Background()
}
// WithContext returns a copy of BasicStateBag with the provided context
// We copy the state bag
func (b *BasicStateBag) WithContext(ctx context.Context) *BasicStateBag {
if ctx == nil {
panic("nil context")
}
// read lock because copying is a read operation
b.l.RLock()
defer b.l.RUnlock()
b2 := NewBasicStateBag()
for k, v := range b.data {
b2.data[k] = v
}
b2.ctx = ctx
return b2
}

Loading…
Cancel
Save