From 208611ec3e367080ba817fa2dbce2fda20151e1e Mon Sep 17 00:00:00 2001 From: Damian Debkowski Date: Thu, 14 Dec 2023 09:25:05 -0800 Subject: [PATCH] feat(bsr): invoke writeAndClose() in encode --- internal/bsr/convert/convert_test.go | 9 +++------ internal/bsr/encode.go | 14 ++++---------- internal/bsr/encode_test.go | 2 -- 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/internal/bsr/convert/convert_test.go b/internal/bsr/convert/convert_test.go index ce014096c2..4b2c45b015 100644 --- a/internal/bsr/convert/convert_test.go +++ b/internal/bsr/convert/convert_test.go @@ -236,12 +236,9 @@ func TestConvert_ToAsciicast_SessionProgram(t *testing.T) { err = writeToChannels(ctx, outW, messageOutboundBsrChunks...) require.NoError(t, err) - outWC := outW.(io.Closer) - outWC.Close() - - ch.Close(ctx) - conn.Close(ctx) - sesh.Close(ctx) + require.NoError(t, ch.Close(ctx)) + require.NoError(t, conn.Close(ctx)) + require.NoError(t, sesh.Close(ctx)) opSesh, err := bsr.OpenSession(ctx, srm.Id, fs, keyFn) require.NoError(t, err) diff --git a/internal/bsr/encode.go b/internal/bsr/encode.go index a15e8a4453..3ed4c7e1af 100644 --- a/internal/bsr/encode.go +++ b/internal/bsr/encode.go @@ -129,15 +129,9 @@ func (e ChunkEncoder) Encode(ctx context.Context, c Chunk) (int, error) { copy(encodedChunk[chunkBaseSize:], encode.compress.Bytes()) binary.BigEndian.PutUint32(encodedChunk[chunkBaseSize+length:], sum) - return e.w.Write(encodedChunk) -} - -// Close closes the encoder. -func (e *ChunkEncoder) Close() error { - var i interface{} = e.w - v, ok := i.(io.WriteCloser) - if ok { - return v.Close() + if c.GetType() == ChunkEnd { + return e.w.WriteAndClose(encodedChunk) } - return nil + + return e.w.Write(encodedChunk) } diff --git a/internal/bsr/encode_test.go b/internal/bsr/encode_test.go index 46e4b7dbd5..4633c06f18 100644 --- a/internal/bsr/encode_test.go +++ b/internal/bsr/encode_test.go @@ -280,8 +280,6 @@ func TestChunkEncoder(t *testing.T) { require.NoError(t, err) wrote += w } - err = enc.Close() - require.NoError(t, err) got := buf.Bytes() assert.Equal(t, len(tc.want), wrote)