/* * * Copyright 2018 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package conn import ( "bytes" "encoding/binary" "fmt" "io" "math" "net" "reflect" "testing" core "google.golang.org/grpc/credentials/alts/internal" ) var ( nextProtocols = []string{"ALTSRP_GCM_AES128"} altsRecordFuncs = map[string]ALTSRecordFunc{ // ALTS handshaker protocols. "ALTSRP_GCM_AES128": func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) { return NewAES128GCM(s, keyData) }, } ) func init() { for protocol, f := range altsRecordFuncs { if err := RegisterProtocol(protocol, f); err != nil { panic(err) } } } // testConn mimics a net.Conn to the peer. type testConn struct { net.Conn in *bytes.Buffer out *bytes.Buffer } func (c *testConn) Read(b []byte) (n int, err error) { return c.in.Read(b) } func (c *testConn) Write(b []byte) (n int, err error) { return c.out.Write(b) } func (c *testConn) Close() error { return nil } func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *conn { key := []byte{ // 16 arbitrary bytes. 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} tc := testConn{ in: in, out: out, } c, err := NewConn(&tc, side, np, key, nil) if err != nil { panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err)) } return c.(*conn) } func newConnPair(np string) (client, server *conn) { clientBuf := new(bytes.Buffer) serverBuf := new(bytes.Buffer) clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np) serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np) return clientConn, serverConn } func testPingPong(t *testing.T, np string) { clientConn, serverConn := newConnPair(np) clientMsg := []byte("Client Message") if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { t.Fatalf("Client Write() = %v, %v; want %v, ", n, err, len(clientMsg)) } rcvClientMsg := make([]byte, len(clientMsg)) if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil { t.Fatalf("Server Read() = %v, %v; want %v, ", n, err, len(rcvClientMsg)) } if !reflect.DeepEqual(clientMsg, rcvClientMsg) { t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg) } serverMsg := []byte("Server Message") if n, err := serverConn.Write(serverMsg); n != len(serverMsg) || err != nil { t.Fatalf("Server Write() = %v, %v; want %v, ", n, err, len(serverMsg)) } rcvServerMsg := make([]byte, len(serverMsg)) if n, err := clientConn.Read(rcvServerMsg); n != len(rcvServerMsg) || err != nil { t.Fatalf("Client Read() = %v, %v; want %v, ", n, err, len(rcvServerMsg)) } if !reflect.DeepEqual(serverMsg, rcvServerMsg) { t.Fatalf("Server Write()/Client Read() = %v, want %v", rcvServerMsg, serverMsg) } } func TestPingPong(t *testing.T) { for _, np := range nextProtocols { testPingPong(t, np) } } func testSmallReadBuffer(t *testing.T, np string) { clientConn, serverConn := newConnPair(np) msg := []byte("Very Important Message") if n, err := clientConn.Write(msg); err != nil { t.Fatalf("Write() = %v, %v; want %v, ", n, err, len(msg)) } rcvMsg := make([]byte, len(msg)) n := 2 // Arbitrary index to break rcvMsg in two. rcvMsg1 := rcvMsg[:n] rcvMsg2 := rcvMsg[n:] if n, err := serverConn.Read(rcvMsg1); n != len(rcvMsg1) || err != nil { t.Fatalf("Read() = %v, %v; want %v, ", n, err, len(rcvMsg1)) } if n, err := serverConn.Read(rcvMsg2); n != len(rcvMsg2) || err != nil { t.Fatalf("Read() = %v, %v; want %v, ", n, err, len(rcvMsg2)) } if !reflect.DeepEqual(msg, rcvMsg) { t.Fatalf("Write()/Read() = %v, want %v", rcvMsg, msg) } } func TestSmallReadBuffer(t *testing.T) { for _, np := range nextProtocols { testSmallReadBuffer(t, np) } } func testLargeMsg(t *testing.T, np string) { clientConn, serverConn := newConnPair(np) // msgLen is such that the length in the framing is larger than the // default size of one frame. msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 msg := make([]byte, msgLen) if n, err := clientConn.Write(msg); n != len(msg) || err != nil { t.Fatalf("Write() = %v, %v; want %v, ", n, err, len(msg)) } rcvMsg := make([]byte, len(msg)) if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil { t.Fatalf("Read() = %v, %v; want %v, ", n, err, len(rcvMsg)) } if !reflect.DeepEqual(msg, rcvMsg) { t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg) } } func TestLargeMsg(t *testing.T) { for _, np := range nextProtocols { testLargeMsg(t, np) } } func testIncorrectMsgType(t *testing.T, np string) { // framedMsg is an empty ciphertext with correct framing but wrong // message type. framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize) binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], msgTypeFieldSize) wrongMsgType := uint32(0x22) binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType) in := bytes.NewBuffer(framedMsg) c := newTestALTSRecordConn(in, nil, core.ClientSide, np) b := make([]byte, 1) if n, err := c.Read(b); n != 0 || err == nil { t.Fatalf("Read() = , want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType)) } } func TestIncorrectMsgType(t *testing.T) { for _, np := range nextProtocols { testIncorrectMsgType(t, np) } } func testFrameTooLarge(t *testing.T, np string) { buf := new(bytes.Buffer) clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np) serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np) // payloadLen is such that the length in the framing is larger than // allowed in one frame. payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 payload := make([]byte, payloadLen) c, err := clientConn.crypto.Encrypt(nil, payload) if err != nil { t.Fatalf(fmt.Sprintf("Error encrypting message: %v", err)) } msgLen := msgTypeFieldSize + len(c) framedMsg := make([]byte, MsgLenFieldSize+msgLen) binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(c))) msg := framedMsg[MsgLenFieldSize:] binary.LittleEndian.PutUint32(msg[:msgTypeFieldSize], altsRecordMsgType) copy(msg[msgTypeFieldSize:], c) if _, err = buf.Write(framedMsg); err != nil { t.Fatal(fmt.Sprintf("Unexpected error writing to buffer: %v", err)) } b := make([]byte, 1) if n, err := serverConn.Read(b); n != 0 || err == nil { t.Fatalf("Read() = , want %v", fmt.Errorf("received the frame length %d larger than the limit %d", altsRecordLengthLimit+1, altsRecordLengthLimit)) } } func TestFrameTooLarge(t *testing.T) { for _, np := range nextProtocols { testFrameTooLarge(t, np) } } func testWriteLargeData(t *testing.T, np string) { // Test sending and receiving messages larger than the maximum write // buffer size. clientConn, serverConn := newConnPair(np) // Message size is intentionally chosen to not be multiple of // payloadLengthLimtit. msgSize := altsWriteBufferMaxSize + (100 * 1024) clientMsg := make([]byte, msgSize) for i := 0; i < msgSize; i++ { clientMsg[i] = 0xAA } if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { t.Fatalf("Client Write() = %v, %v; want %v, ", n, err, len(clientMsg)) } // We need to keep reading until the entire message is received. The // reason we set all bytes of the message to a value other than zero is // to avoid ambiguous zero-init value of rcvClientMsg buffer and the // actual received data. rcvClientMsg := make([]byte, 0, msgSize) numberOfExpectedFrames := int(math.Ceil(float64(msgSize) / float64(serverConn.payloadLengthLimit))) for i := 0; i < numberOfExpectedFrames; i++ { expectedRcvSize := serverConn.payloadLengthLimit if i == numberOfExpectedFrames-1 { // Last frame might be smaller. expectedRcvSize = msgSize % serverConn.payloadLengthLimit } tmpBuf := make([]byte, expectedRcvSize) if n, err := serverConn.Read(tmpBuf); n != len(tmpBuf) || err != nil { t.Fatalf("Server Read() = %v, %v; want %v, ", n, err, len(tmpBuf)) } rcvClientMsg = append(rcvClientMsg, tmpBuf...) } if !reflect.DeepEqual(clientMsg, rcvClientMsg) { t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg) } } func TestWriteLargeData(t *testing.T) { for _, np := range nextProtocols { testWriteLargeData(t, np) } }