diff --git a/pkg/nazabits/bits.go b/pkg/nazabits/bits.go index a37f222..9021703 100644 --- a/pkg/nazabits/bits.go +++ b/pkg/nazabits/bits.go @@ -55,8 +55,9 @@ func NewBitWriter(b []byte) BitWriter { } } +// @param b: 当b不为0和1时,取b的最低位 func (bw *BitWriter) WriteBit(b uint8) { - bw.core[bw.index] |= b << (7 - bw.pos) + bw.core[bw.index] |= (b & 0x1) << (7 - bw.pos) bw.pos++ if bw.pos == 8 { bw.pos = 0 diff --git a/pkg/nazabits/bits_test.go b/pkg/nazabits/bits_test.go index 129e8da..7ecb681 100644 --- a/pkg/nazabits/bits_test.go +++ b/pkg/nazabits/bits_test.go @@ -121,6 +121,15 @@ func TestBitWriter_WriteBit(t *testing.T) { } assert.Equal(t, uint8(48), v[0]) assert.Equal(t, uint8(57), v[1]) + + v = make([]byte, 2) + bw = nazabits.NewBitWriter(v) + bs = []uint8{2, 4, 3, 5, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1} + for _, b := range bs { + bw.WriteBit(b) + } + assert.Equal(t, uint8(48), v[0]) + assert.Equal(t, uint8(57), v[1]) } func TestBitWriter_WriteBits(t *testing.T) { @@ -134,6 +143,11 @@ func TestBitWriter_WriteBits(t *testing.T) { bw.WriteBits(1, 1) assert.Equal(t, uint8(48), v[0]) assert.Equal(t, uint8(57), v[1]) + + v = make([]byte, 1) + bw = nazabits.NewBitWriter(v) + bw.WriteBits(3, 1+8+32+128) + assert.Equal(t, uint8(1<<5), v[0]) } func BenchmarkGetBits16(b *testing.B) {