// Copyright (C) MongoDB, Inc. 2017-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package bson import ( "bytes" "errors" "reflect" "testing" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest" ) func TestBasicEncode(t *testing.T) { for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { got := make(bsonrw.SliceWriter, 0, 1024) vw, err := bsonrw.NewBSONValueWriter(&got) noerr(t, err) reg := DefaultRegistry encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val)) noerr(t, err) err = encoder.EncodeValue(bsoncodec.EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.val)) noerr(t, err) if !bytes.Equal(got, tc.want) { t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) t.Errorf("Bytes:\n%v\n%v", got, tc.want) } }) } } func TestEncoderEncode(t *testing.T) { for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { got := make(bsonrw.SliceWriter, 0, 1024) vw, err := bsonrw.NewBSONValueWriter(&got) noerr(t, err) enc, err := NewEncoder(vw) noerr(t, err) err = enc.Encode(tc.val) noerr(t, err) if !bytes.Equal(got, tc.want) { t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) t.Errorf("Bytes:\n%v\n%v", got, tc.want) } }) } t.Run("Marshaler", func(t *testing.T) { testCases := []struct { name string buf []byte err error wanterr error vw bsonrw.ValueWriter }{ { "error", nil, errors.New("Marshaler error"), errors.New("Marshaler error"), &bsonrwtest.ValueReaderWriter{}, }, { "copy error", []byte{0x05, 0x00, 0x00, 0x00, 0x00}, nil, errors.New("copy error"), &bsonrwtest.ValueReaderWriter{Err: errors.New("copy error"), ErrAfter: bsonrwtest.WriteDocument}, }, { "success", []byte{0x07, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00}, nil, nil, nil, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { marshaler := testMarshaler{buf: tc.buf, err: tc.err} var vw bsonrw.ValueWriter var err error b := make(bsonrw.SliceWriter, 0, 100) compareVW := false if tc.vw != nil { vw = tc.vw } else { compareVW = true vw, err = bsonrw.NewBSONValueWriter(&b) noerr(t, err) } enc, err := NewEncoder(vw) noerr(t, err) got := enc.Encode(marshaler) want := tc.wanterr if !compareErrors(got, want) { t.Errorf("Did not receive expected error. got %v; want %v", got, want) } if compareVW { buf := b if !bytes.Equal(buf, tc.buf) { t.Errorf("Copied bytes do not match. got %v; want %v", buf, tc.buf) } } }) } }) } type testMarshaler struct { buf []byte err error } func (tm testMarshaler) MarshalBSON() ([]byte, error) { return tm.buf, tm.err } func docToBytes(d interface{}) []byte { b, err := Marshal(d) if err != nil { panic(err) } return b } type byteMarshaler []byte func (bm byteMarshaler) MarshalBSON() ([]byte, error) { return bm, nil } type _Interface interface { method() } type _impl struct { Foo string } func (_impl) method() {}