// Copyright (C) MongoDB, Inc. 2017-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package bsonrw import ( "bytes" "errors" "fmt" "io/ioutil" "math" "reflect" "strings" "testing" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) func TestNewBSONValueWriter(t *testing.T) { _, got := NewBSONValueWriter(nil) want := errNilWriter if !compareErrors(got, want) { t.Errorf("Returned error did not match what was expected. got %v; want %v", got, want) } vw, got := NewBSONValueWriter(errWriter{}) want = nil if !compareErrors(got, want) { t.Errorf("Returned error did not match what was expected. got %v; want %v", got, want) } if vw == nil { t.Errorf("Expected non-nil ValueWriter to be returned from NewBSONValueWriter") } } func TestValueWriter(t *testing.T) { header := []byte{0x00, 0x00, 0x00, 0x00} oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} testCases := []struct { name string fn interface{} params []interface{} want []byte }{ { "WriteBinary", (*valueWriter).WriteBinary, []interface{}{[]byte{0x01, 0x02, 0x03}}, bsoncore.AppendBinaryElement(header, "foo", 0x00, []byte{0x01, 0x02, 0x03}), }, { "WriteBinaryWithSubtype (not 0x02)", (*valueWriter).WriteBinaryWithSubtype, []interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xFF)}, bsoncore.AppendBinaryElement(header, "foo", 0xFF, []byte{0x01, 0x02, 0x03}), }, { "WriteBinaryWithSubtype (0x02)", (*valueWriter).WriteBinaryWithSubtype, []interface{}{[]byte{0x01, 0x02, 0x03}, byte(0x02)}, bsoncore.AppendBinaryElement(header, "foo", 0x02, []byte{0x01, 0x02, 0x03}), }, { "WriteBoolean", (*valueWriter).WriteBoolean, []interface{}{true}, bsoncore.AppendBooleanElement(header, "foo", true), }, { "WriteDBPointer", (*valueWriter).WriteDBPointer, []interface{}{"bar", oid}, bsoncore.AppendDBPointerElement(header, "foo", "bar", oid), }, { "WriteDateTime", (*valueWriter).WriteDateTime, []interface{}{int64(12345678)}, bsoncore.AppendDateTimeElement(header, "foo", 12345678), }, { "WriteDecimal128", (*valueWriter).WriteDecimal128, []interface{}{primitive.NewDecimal128(10, 20)}, bsoncore.AppendDecimal128Element(header, "foo", primitive.NewDecimal128(10, 20)), }, { "WriteDouble", (*valueWriter).WriteDouble, []interface{}{float64(3.14159)}, bsoncore.AppendDoubleElement(header, "foo", 3.14159), }, { "WriteInt32", (*valueWriter).WriteInt32, []interface{}{int32(123456)}, bsoncore.AppendInt32Element(header, "foo", 123456), }, { "WriteInt64", (*valueWriter).WriteInt64, []interface{}{int64(1234567890)}, bsoncore.AppendInt64Element(header, "foo", 1234567890), }, { "WriteJavascript", (*valueWriter).WriteJavascript, []interface{}{"var foo = 'bar';"}, bsoncore.AppendJavaScriptElement(header, "foo", "var foo = 'bar';"), }, { "WriteMaxKey", (*valueWriter).WriteMaxKey, []interface{}{}, bsoncore.AppendMaxKeyElement(header, "foo"), }, { "WriteMinKey", (*valueWriter).WriteMinKey, []interface{}{}, bsoncore.AppendMinKeyElement(header, "foo"), }, { "WriteNull", (*valueWriter).WriteNull, []interface{}{}, bsoncore.AppendNullElement(header, "foo"), }, { "WriteObjectID", (*valueWriter).WriteObjectID, []interface{}{oid}, bsoncore.AppendObjectIDElement(header, "foo", oid), }, { "WriteRegex", (*valueWriter).WriteRegex, []interface{}{"bar", "baz"}, bsoncore.AppendRegexElement(header, "foo", "bar", "abz"), }, { "WriteString", (*valueWriter).WriteString, []interface{}{"hello, world!"}, bsoncore.AppendStringElement(header, "foo", "hello, world!"), }, { "WriteSymbol", (*valueWriter).WriteSymbol, []interface{}{"symbollolz"}, bsoncore.AppendSymbolElement(header, "foo", "symbollolz"), }, { "WriteTimestamp", (*valueWriter).WriteTimestamp, []interface{}{uint32(10), uint32(20)}, bsoncore.AppendTimestampElement(header, "foo", 10, 20), }, { "WriteUndefined", (*valueWriter).WriteUndefined, []interface{}{}, bsoncore.AppendUndefinedElement(header, "foo"), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { fn := reflect.ValueOf(tc.fn) if fn.Kind() != reflect.Func { t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind()) } if fn.Type().NumIn() != len(tc.params)+1 || fn.Type().In(0) != reflect.TypeOf((*valueWriter)(nil)) { t.Fatalf("fn must have at least one parameter and the first parameter must be a *valueWriter") } if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf((*error)(nil)).Elem() { t.Fatalf("fn must have one return value and it must be an error.") } params := make([]reflect.Value, 1, len(tc.params)+1) vw := newValueWriter(ioutil.Discard) params[0] = reflect.ValueOf(vw) for _, param := range tc.params { params = append(params, reflect.ValueOf(param)) } _, err := vw.WriteDocument() noerr(t, err) _, err = vw.WriteDocumentElement("foo") noerr(t, err) results := fn.Call(params) if !results[0].IsValid() { err = results[0].Interface().(error) } else { err = nil } noerr(t, err) got := vw.buf want := tc.want if !bytes.Equal(got, want) { t.Errorf("Bytes are not equal.\n\tgot %v\n\twant %v", got, want) } t.Run("incorrect transition", func(t *testing.T) { vw = newValueWriter(ioutil.Discard) results := fn.Call(params) got := results[0].Interface().(error) fnName := tc.name if strings.Contains(fnName, "WriteBinary") { fnName = "WriteBinaryWithSubtype" } want := TransitionError{current: mTopLevel, name: fnName, modes: []mode{mElement, mValue}, action: "write"} if !compareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } }) }) } t.Run("WriteArray", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mArray) want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel, name: "WriteArray", modes: []mode{mElement, mValue}, action: "write"} _, got := vw.WriteArray() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteCodeWithScope", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mArray) want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel, name: "WriteCodeWithScope", modes: []mode{mElement, mValue}, action: "write"} _, got := vw.WriteCodeWithScope("") if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteDocument", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mArray) want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel, name: "WriteDocument", modes: []mode{mElement, mValue, mTopLevel}, action: "write"} _, got := vw.WriteDocument() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteDocumentElement", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mElement) want := TransitionError{current: mElement, destination: mElement, parent: mTopLevel, name: "WriteDocumentElement", modes: []mode{mTopLevel, mDocument}, action: "write"} _, got := vw.WriteDocumentElement("") if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteDocumentEnd", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mElement) want := fmt.Errorf("incorrect mode to end document: %s", mElement) got := vw.WriteDocumentEnd() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } vw.pop() vw.buf = append(vw.buf, make([]byte, 1023)...) maxSize = 512 want = errMaxDocumentSizeExceeded{size: 1024} got = vw.WriteDocumentEnd() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } maxSize = math.MaxInt32 want = errors.New("what a nice fake error we have here") vw.w = errWriter{err: want} got = vw.WriteDocumentEnd() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteArrayElement", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mElement) want := TransitionError{current: mElement, destination: mValue, parent: mTopLevel, name: "WriteArrayElement", modes: []mode{mArray}, action: "write"} _, got := vw.WriteArrayElement() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) t.Run("WriteArrayEnd", func(t *testing.T) { vw := newValueWriter(ioutil.Discard) vw.push(mElement) want := fmt.Errorf("incorrect mode to end array: %s", mElement) got := vw.WriteArrayEnd() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } vw.push(mArray) vw.buf = append(vw.buf, make([]byte, 1019)...) maxSize = 512 want = errMaxDocumentSizeExceeded{size: 1024} got = vw.WriteArrayEnd() if !compareErrors(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } maxSize = math.MaxInt32 }) t.Run("WriteBytes", func(t *testing.T) { t.Run("writeElementHeader error", func(t *testing.T) { vw := newValueWriterFromSlice(nil) want := TransitionError{current: mTopLevel, destination: mode(0), name: "WriteValueBytes", modes: []mode{mElement, mValue}, action: "write"} got := vw.WriteValueBytes(bsontype.EmbeddedDocument, nil) if !compareErrors(got, want) { t.Errorf("Did not received expected error. got %v; want %v", got, want) } }) t.Run("success", func(t *testing.T) { index, doc := bsoncore.ReserveLength(nil) doc = bsoncore.AppendStringElement(doc, "hello", "world") doc = append(doc, 0x00) doc = bsoncore.UpdateLength(doc, index, int32(len(doc))) index, want := bsoncore.ReserveLength(nil) want = bsoncore.AppendDocumentElement(want, "foo", doc) want = append(want, 0x00) want = bsoncore.UpdateLength(want, index, int32(len(want))) vw := newValueWriterFromSlice(make([]byte, 0, 512)) _, err := vw.WriteDocument() noerr(t, err) _, err = vw.WriteDocumentElement("foo") noerr(t, err) err = vw.WriteValueBytes(bsontype.EmbeddedDocument, doc) noerr(t, err) err = vw.WriteDocumentEnd() noerr(t, err) got := vw.buf if !bytes.Equal(got, want) { t.Errorf("Bytes are not equal. got %v; want %v", got, want) } }) }) } type errWriter struct { err error } func (ew errWriter) Write([]byte) (int, error) { return 0, ew.err }