From cae97fcd2e2019c05e09aec24c7068e6fbbfc5d1 Mon Sep 17 00:00:00 2001 From: Elim Tsiagbey Date: Mon, 8 Jan 2024 14:36:21 -0500 Subject: [PATCH] feat: Add function that writes to a file without expanding the buffer and closes the file. --- internal/bsr/bsr.go | 6 +- internal/bsr/convert/convert_test.go | 3 +- internal/bsr/convert/ssh_test.go | 6 +- internal/bsr/encode.go | 6 +- internal/bsr/encode_bench_test.go | 17 +++-- internal/bsr/encode_test.go | 36 ++++++++--- internal/bsr/internal/checksum/checksum.go | 40 +++++++++++- internal/bsr/internal/fstest/fs.go | 72 +++++++++++++++++++--- internal/bsr/internal/fstest/fs_test.go | 25 ++++++++ internal/bsr/internal/fstest/local.go | 34 +++++++++- internal/bsr/internal/journal/journal.go | 32 ++++++++++ internal/storage/storage.go | 10 ++- 12 files changed, 253 insertions(+), 34 deletions(-) diff --git a/internal/bsr/bsr.go b/internal/bsr/bsr.go index 102f10618b..613f24124d 100644 --- a/internal/bsr/bsr.go +++ b/internal/bsr/bsr.go @@ -657,7 +657,7 @@ func (c *Connection) NewMessagesWriter(ctx context.Context, dir Direction) (io.W } // NewRequestsWriter creates a writer for recording connection requests. -func (c *Connection) NewRequestsWriter(ctx context.Context, dir Direction) (io.Writer, error) { +func (c *Connection) NewRequestsWriter(ctx context.Context, dir Direction) (storage.Writer, error) { const op = "bsr.(Connection).NewRequestsWriter" switch { @@ -704,7 +704,7 @@ func (c *Channel) Close(ctx context.Context) error { } // NewMessagesWriter creates a writer for recording channel messages. -func (c *Channel) NewMessagesWriter(ctx context.Context, dir Direction) (io.Writer, error) { +func (c *Channel) NewMessagesWriter(ctx context.Context, dir Direction) (storage.Writer, error) { const op = "bsr.(Channel).NewMessagesWriter" switch { @@ -726,7 +726,7 @@ func (c *Channel) NewMessagesWriter(ctx context.Context, dir Direction) (io.Writ } // NewRequestsWriter creates a writer for recording channel requests. -func (c *Channel) NewRequestsWriter(ctx context.Context, dir Direction) (io.Writer, error) { +func (c *Channel) NewRequestsWriter(ctx context.Context, dir Direction) (storage.Writer, error) { const op = "bsr.(Channel).NewRequestsWriter" switch { diff --git a/internal/bsr/convert/convert_test.go b/internal/bsr/convert/convert_test.go index cad6bbf954..ce014096c2 100644 --- a/internal/bsr/convert/convert_test.go +++ b/internal/bsr/convert/convert_test.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/boundary/internal/bsr/internal/fstest" "github.com/hashicorp/boundary/internal/bsr/kms" "github.com/hashicorp/boundary/internal/bsr/ssh" + "github.com/hashicorp/boundary/internal/storage" "github.com/stretchr/testify/require" ) @@ -45,7 +46,7 @@ func testChunks(s string, d bsr.Direction, p bsr.Protocol) []bsr.Chunk { } } -func writeToChannels(ctx context.Context, w io.Writer, chunks ...bsr.Chunk) error { +func writeToChannels(ctx context.Context, w storage.Writer, chunks ...bsr.Chunk) error { w.Write(bsr.Magic.Bytes()) enc, err := bsr.NewChunkEncoder(ctx, w, bsr.NoCompression, bsr.NoEncryption) if err != nil { diff --git a/internal/bsr/convert/ssh_test.go b/internal/bsr/convert/ssh_test.go index 02810a95da..6522d04f83 100644 --- a/internal/bsr/convert/ssh_test.go +++ b/internal/bsr/convert/ssh_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/boundary/internal/bsr" sshv1 "github.com/hashicorp/boundary/internal/bsr/gen/ssh/v1" + "github.com/hashicorp/boundary/internal/bsr/internal/fstest" "github.com/hashicorp/boundary/internal/bsr/ssh" "github.com/stretchr/testify/require" ) @@ -31,9 +32,10 @@ func Test_sshChannelToAsciicast(t *testing.T) { return f } newScanner := func(chunks ...bsr.Chunk) *bsr.ChunkScanner { - var buf bytes.Buffer + buf, err := fstest.NewTempBuffer() + require.NoError(t, err) buf.Write(bsr.Magic.Bytes()) - enc, err := bsr.NewChunkEncoder(ctx, &buf, bsr.NoCompression, bsr.NoEncryption) + enc, err := bsr.NewChunkEncoder(ctx, buf, bsr.NoCompression, bsr.NoEncryption) require.NoError(t, err) for _, c := range chunks { diff --git a/internal/bsr/encode.go b/internal/bsr/encode.go index 8a5b63cd67..a15e8a4453 100644 --- a/internal/bsr/encode.go +++ b/internal/bsr/encode.go @@ -13,6 +13,8 @@ import ( "hash/crc32" "io" "sync" + + "github.com/hashicorp/boundary/internal/storage" ) type encodeCache struct { @@ -43,13 +45,13 @@ var encodeCachePool = &sync.Pool{ // ChunkEncoder will encode a chunk and write it to the writer. // It will compress the chunk data based on the compression. type ChunkEncoder struct { - w io.Writer + w storage.Writer compression Compression encryption Encryption } // NewChunkEncoder creates a ChunkEncoder. -func NewChunkEncoder(ctx context.Context, w io.Writer, c Compression, e Encryption) (*ChunkEncoder, error) { +func NewChunkEncoder(ctx context.Context, w storage.Writer, c Compression, e Encryption) (*ChunkEncoder, error) { const op = "bsr.NewChunkEncoder" if w == nil { diff --git a/internal/bsr/encode_bench_test.go b/internal/bsr/encode_bench_test.go index 885cbc8968..438a2d2631 100644 --- a/internal/bsr/encode_bench_test.go +++ b/internal/bsr/encode_bench_test.go @@ -4,11 +4,12 @@ package bsr import ( - "bytes" "context" "fmt" "testing" "time" + + "github.com/hashicorp/boundary/internal/bsr/internal/fstest" ) type testChunk struct { @@ -48,8 +49,11 @@ func BenchmarkEncodeParallel(b *testing.B) { b.StartTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - var buf bytes.Buffer - enc, _ := NewChunkEncoder(ctx, &buf, NoCompression, NoEncryption) + buf, err := fstest.NewTempBuffer() + if err != nil { + panic("could not create buffer") + } + enc, _ := NewChunkEncoder(ctx, buf, NoCompression, NoEncryption) for _, c := range chunks { if _, err := enc.Encode(ctx, c); err != nil { b.Fatal("Encode:", err) @@ -75,8 +79,11 @@ func BenchmarkEncodeSequential(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - var buf bytes.Buffer - enc, _ := NewChunkEncoder(ctx, &buf, NoCompression, NoEncryption) + buf, err := fstest.NewTempBuffer() + if err != nil { + panic("could not create buffer") + } + enc, _ := NewChunkEncoder(ctx, buf, NoCompression, NoEncryption) for _, c := range chunks { if _, err := enc.Encode(ctx, c); err != nil { b.Fatal("Encode:", err) diff --git a/internal/bsr/encode_test.go b/internal/bsr/encode_test.go index f4334aa5eb..46e4b7dbd5 100644 --- a/internal/bsr/encode_test.go +++ b/internal/bsr/encode_test.go @@ -9,11 +9,12 @@ import ( "context" "errors" "fmt" - "io" "testing" "time" "github.com/hashicorp/boundary/internal/bsr" + "github.com/hashicorp/boundary/internal/bsr/internal/fstest" + "github.com/hashicorp/boundary/internal/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -268,8 +269,9 @@ func TestChunkEncoder(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - var buf bytes.Buffer - enc, err := bsr.NewChunkEncoder(ctx, &buf, tc.c, tc.e) + buf, err := fstest.NewTempBuffer() + require.NoError(t, err) + enc, err := bsr.NewChunkEncoder(ctx, buf, tc.c, tc.e) require.NoError(t, err) var wrote int @@ -294,6 +296,10 @@ func (e errorWriter) Write(_ []byte) (int, error) { return 0, fmt.Errorf("write error") } +func (e errorWriter) WriteAndClose(_ []byte) (int, error) { + return 0, fmt.Errorf("write error") +} + func TestChunkEncoderEncodeError(t *testing.T) { ctx := context.Background() @@ -301,7 +307,7 @@ func TestChunkEncoderEncodeError(t *testing.T) { cases := []struct { name string - w io.Writer + w storage.Writer c bsr.Compression e bsr.Encryption chunk bsr.Chunk @@ -309,7 +315,11 @@ func TestChunkEncoderEncodeError(t *testing.T) { }{ { "chunk-marshal-error", - func() io.Writer { var buf bytes.Buffer; return &buf }(), + func() storage.Writer { + buf, err := fstest.NewTempBuffer() + require.NoError(t, err) + return buf + }(), bsr.NoCompression, bsr.NoEncryption, &testChunk{ @@ -325,7 +335,7 @@ func TestChunkEncoderEncodeError(t *testing.T) { }, { "writer-error", - func() io.Writer { return errorWriter{} }(), + func() storage.Writer { return errorWriter{} }(), bsr.NoCompression, bsr.NoEncryption, &testChunk{ @@ -357,21 +367,29 @@ func TestChunkEncoderErrors(t *testing.T) { cases := []struct { name string - w io.Writer + w storage.Writer c bsr.Compression e bsr.Encryption want error }{ { "invalid-compression", - func() io.Writer { var buf bytes.Buffer; return &buf }(), + func() storage.Writer { + buf, err := fstest.NewTempBuffer() + require.NoError(t, err) + return buf + }(), bsr.Compression(255), bsr.NoEncryption, errors.New("bsr.NewChunkEncoder: invalid compression: invalid parameter"), }, { "invalid-encryption", - func() io.Writer { var buf bytes.Buffer; return &buf }(), + func() storage.Writer { + buf, err := fstest.NewTempBuffer() + require.NoError(t, err) + return buf + }(), bsr.NoCompression, bsr.Encryption(255), errors.New("bsr.NewChunkEncoder: invalid encryption: invalid parameter"), diff --git a/internal/bsr/internal/checksum/checksum.go b/internal/bsr/internal/checksum/checksum.go index 62c7f0c352..ab101b357d 100644 --- a/internal/bsr/internal/checksum/checksum.go +++ b/internal/bsr/internal/checksum/checksum.go @@ -29,6 +29,12 @@ type writerFile interface { fs.File io.WriteCloser io.StringWriter + WriteAndClose([]byte) (int, error) +} + +type underlyingFile interface { + fs.File + WriteAndClose([]byte) (int, error) } // File is a writable file that will compute its checksum while it is written @@ -45,7 +51,7 @@ type File struct { checksumWriter io.Writer - underlying fs.File + underlying underlyingFile } // NewFile wraps the provided writerFile in a checksumed File. When the file is @@ -113,4 +119,36 @@ func (f *File) Close() error { return closeErrors } +// WriteAndClose writes to the underlying file, closes the Sha256SumWriter +// and computes the checksum and writes it to f.checksumWriter +func (f *File) WriteAndClose(b []byte) (int, error) { + const op = "checksum.(File).WriteAndClose" + + var closeErrors error + + // Call stat before closure; calling it after results in an err + s, err := f.Stat() + if err != nil { + closeErrors = errors.Join(closeErrors, fmt.Errorf("%s: %w", op, err)) + return 0, closeErrors + } + + n, err := f.underlying.WriteAndClose(b) + if err != nil { + return 0, fmt.Errorf("%s: %w", op, err) + } + + sum, err := f.Sha256SumWriter.Sum(f.ctx, crypto.WithHexEncoding(true)) + if err != nil { + closeErrors = errors.Join(closeErrors, fmt.Errorf("%s: %w", op, err)) + return 0, closeErrors + } + + if _, err := f.checksumWriter.Write([]byte(fmt.Sprintf(checksumLine, sum, s.Name()))); err != nil { + closeErrors = errors.Join(closeErrors, fmt.Errorf("%s: %w", op, err)) + } + + return n, closeErrors +} + var _ writerFile = (*File)(nil) diff --git a/internal/bsr/internal/fstest/fs.go b/internal/bsr/internal/fstest/fs.go index d728c2a8ce..61824c6f1f 100644 --- a/internal/bsr/internal/fstest/fs.go +++ b/internal/bsr/internal/fstest/fs.go @@ -268,6 +268,11 @@ type MemFile struct { syncMode storage.SyncMode accessMode storage.AccessMode + bufferSize uint64 + minimumBufferSize uint64 + bufferOffset int64 + currentOffset int64 + sync.RWMutex } @@ -279,14 +284,18 @@ func NewMemFile(n string, mode sfs.FileMode, options ...Option) *MemFile { storageOpts := storage.GetOpts(opts.withStorageOptions...) return &MemFile{ - name: n, - Buf: bytes.NewBuffer([]byte{}), - src: []byte{}, - mode: mode, - accessMode: storageOpts.WithFileAccessMode, - syncMode: storageOpts.WithCloseSyncMode, - statFunc: opts.withStatFunc, - closeFunc: opts.withCloseFunc, + name: n, + Buf: bytes.NewBuffer([]byte{}), + src: []byte{}, + mode: mode, + accessMode: storageOpts.WithFileAccessMode, + syncMode: storageOpts.WithCloseSyncMode, + statFunc: opts.withStatFunc, + closeFunc: opts.withCloseFunc, + bufferSize: storageOpts.WithBuffer, + minimumBufferSize: storageOpts.WithMinimumBuffer, + currentOffset: 0, + bufferOffset: 0, } } @@ -347,6 +356,10 @@ func (m *MemFile) Close() error { m.Lock() defer m.Unlock() + return m.close() +} + +func (m *MemFile) close() error { if m.Closed { return fmt.Errorf("close on closed file") } @@ -367,6 +380,10 @@ func (m *MemFile) Write(p []byte) (n int, err error) { m.Lock() defer m.Unlock() + return m.write(p) +} + +func (m *MemFile) write(p []byte) (n int, err error) { if m.Closed { return 0, fmt.Errorf("write on closed file") } @@ -385,6 +402,24 @@ func (m *MemFile) Write(p []byte) (n int, err error) { return m.Buf.Write(p) } +// WriteAndClose writes and closes the file. +func (m *MemFile) WriteAndClose(p []byte) (int, error) { + m.Lock() + defer m.Unlock() + + n, err := m.write(p) + if err != nil { + return n, fmt.Errorf("write failed: %w", err) + } + + err = m.close() + if err != nil { + return n, fmt.Errorf("close failed: %w", err) + } + + return n, nil +} + // TempFile implements storage.TempFile type TempFile struct { *os.File @@ -412,3 +447,24 @@ func (t *TempFile) Close() error { } return os.Remove(fname.Name()) } + +func (t *TempFile) WriteAndClose(b []byte) (int, error) { + panic("not implemented") +} + +type TempBuffer struct { + bytes.Buffer +} + +// NewTempBuffer creates a TempBuffer. +func NewTempBuffer() (*TempBuffer, error) { + var testBuffer bytes.Buffer + return &TempBuffer{ + testBuffer, + }, nil +} + +// WriteAndClose writes and closes the file. +func (t *TempBuffer) WriteAndClose(b []byte) (int, error) { + return t.Write(b) +} diff --git a/internal/bsr/internal/fstest/fs_test.go b/internal/bsr/internal/fstest/fs_test.go index 94df7012e3..df8976c75a 100644 --- a/internal/bsr/internal/fstest/fs_test.go +++ b/internal/bsr/internal/fstest/fs_test.go @@ -628,3 +628,28 @@ func TestOutOfSpace(t *testing.T) { require.ErrorIs(t, err, fstest.ErrOutOfSpace) assert.Nil(t, r2) } + +func TestMemFile_WriteAndClose(t *testing.T) { + mf := fstest.NewMemFile( + "test-write-and-close-file", + 0o644, + fstest.WithStorageOptions([]storage.Option{ + storage.WithFileAccessMode(storage.ReadWrite), + })) + require.False(t, mf.Closed, "MemFile is closed") + + str1 := "Input 1" + n, err := mf.Write([]byte(str1)) + require.NoError(t, err) + assert.NotNil(t, n) + + str2 := "Input 2" + n, err = mf.WriteAndClose([]byte(str2)) + require.NoError(t, err) + assert.NotNil(t, n) + + expectedString := "Input 1Input 2" + assert.Equal(t, expectedString, mf.Buf.String()) + + require.True(t, mf.Closed, "MemFile is not closed") +} diff --git a/internal/bsr/internal/fstest/local.go b/internal/bsr/internal/fstest/local.go index 7fbb842cee..95626a2778 100644 --- a/internal/bsr/internal/fstest/local.go +++ b/internal/bsr/internal/fstest/local.go @@ -314,9 +314,15 @@ func (f *LocalFile) Read(b []byte) (int, error) { // Close closes the file preventing reads or writes. func (f *LocalFile) Close() error { - const op = "fstest.(LocalFile).Stat" f.Lock() defer f.Unlock() + + return f.close() +} + +func (f *LocalFile) close() error { + const op = "fstest.(LocalFile).close" + if f.closed { return nil } @@ -330,9 +336,15 @@ func (f *LocalFile) Close() error { } func (f *LocalFile) Write(b []byte) (int, error) { - const op = "fstest.(localFile).Write" f.Lock() defer f.Unlock() + + return f.write(b) +} + +func (f *LocalFile) write(b []byte) (int, error) { + const op = "fstest.(localFile).write" + if f.closed { return 0, fmt.Errorf("%s: file is closed", op) } @@ -352,3 +364,21 @@ func (f *LocalFile) WriteString(s string) (int, error) { const op = "storage.(localFile).WriteString" return f.Write([]byte(s)) } + +// WriteAndClose writes and closes the file. +func (f *LocalFile) WriteAndClose(b []byte) (int, error) { + f.Lock() + defer f.Unlock() + + n, err := f.write(b) + if err != nil { + return n, fmt.Errorf("write failed: %w", err) + } + + err = f.close() + if err != nil { + return n, fmt.Errorf("close failed: %w", err) + } + + return n, nil +} diff --git a/internal/bsr/internal/journal/journal.go b/internal/bsr/internal/journal/journal.go index b0130719ba..e461a34f9a 100644 --- a/internal/bsr/internal/journal/journal.go +++ b/internal/bsr/internal/journal/journal.go @@ -20,6 +20,7 @@ type writerFile interface { fs.File io.WriteCloser io.StringWriter + WriteAndClose([]byte) (int, error) } // Journal is used to record meta data about the operations that will be and @@ -61,6 +62,11 @@ func (j *Journal) Record(op, f string) error { return err } +// WriteAndClose writes and closes the journal +func (j *Journal) WriteAndClose(b []byte) (int, error) { + panic("not implemented") +} + // File is a writable file that will update a Journal as it closed. type File struct { j *Journal @@ -102,4 +108,30 @@ func (f *File) Close() error { return f.j.Record("CLOSED", s.Name()) } +// WriteAndClose writes to the underlying file and closes the underlying file, +// writing to the journal prior to and after +func (f *File) WriteAndClose(b []byte) (int, error) { + const op = "journal.(File).Close" + + s, err := f.Stat() + if err != nil { + return 0, fmt.Errorf("%s: %w", op, err) + } + + if err := f.j.Record("CLOSING", s.Name()); err != nil { + return 0, fmt.Errorf("%s: %w", op, err) + } + + n, err := f.writerFile.WriteAndClose(b) + if err != nil { + return n, fmt.Errorf("%s: %w", op, err) + } + + if err := f.j.Record("CLOSED", s.Name()); err != nil { + return n, fmt.Errorf("%s: %w", op, err) + } + + return n, err +} + var _ writerFile = (*File)(nil) diff --git a/internal/storage/storage.go b/internal/storage/storage.go index e12ab4eb97..48bde0a2f3 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -55,8 +55,8 @@ type Container interface { // File represents a storage File. type File interface { fs.File - io.Writer io.StringWriter + Writer } // TempFile is a temporary File. It will get removed when Closed. @@ -64,3 +64,11 @@ type TempFile interface { File io.Seeker } + +// Writer is an interface that extends the io.Writer interface with an additional +// WriteAndClose method. WriteAndClose writes a byte slice and closes the file in +// a single call. +type Writer interface { + io.Writer + WriteAndClose([]byte) (int, error) +}