// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package ssh import ( "bytes" "crypto" "crypto/rand" "encoding/binary" "io" "testing" "golang.org/x/crypto/chacha20" "golang.org/x/crypto/internal/poly1305" ) func TestDefaultCiphersExist(t *testing.T) { for _, cipherAlgo := range supportedCiphers { if _, ok := cipherModes[cipherAlgo]; !ok { t.Errorf("supported cipher %q is unknown", cipherAlgo) } } for _, cipherAlgo := range preferredCiphers { if _, ok := cipherModes[cipherAlgo]; !ok { t.Errorf("preferred cipher %q is unknown", cipherAlgo) } } } func TestPacketCiphers(t *testing.T) { defaultMac := "hmac-sha2-256" defaultCipher := "aes128-ctr" for cipher := range cipherModes { t.Run("cipher="+cipher, func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) }) } for mac := range macModes { t.Run("mac="+mac, func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) }) } } func testPacketCipher(t *testing.T, cipher, mac string) { kr := &kexResult{Hash: crypto.SHA1} algs := directionAlgorithms{ Cipher: cipher, MAC: mac, Compression: "none", } client, err := newPacketCipher(clientKeys, algs, kr) if err != nil { t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) } server, err := newPacketCipher(clientKeys, algs, kr) if err != nil { t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) } want := "bla bla" input := []byte(want) buf := &bytes.Buffer{} if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil { t.Fatalf("writeCipherPacket(%q, %q): %v", cipher, mac, err) } packet, err := server.readCipherPacket(0, buf) if err != nil { t.Fatalf("readCipherPacket(%q, %q): %v", cipher, mac, err) } if string(packet) != want { t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want) } } func TestCBCOracleCounterMeasure(t *testing.T) { kr := &kexResult{Hash: crypto.SHA1} algs := directionAlgorithms{ Cipher: aes128cbcID, MAC: "hmac-sha1", Compression: "none", } client, err := newPacketCipher(clientKeys, algs, kr) if err != nil { t.Fatalf("newPacketCipher(client): %v", err) } want := "bla bla" input := []byte(want) buf := &bytes.Buffer{} if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil { t.Errorf("writeCipherPacket: %v", err) } packetSize := buf.Len() buf.Write(make([]byte, 2*maxPacket)) // We corrupt each byte, but this usually will only test the // 'packet too large' or 'MAC failure' cases. lastRead := -1 for i := 0; i < packetSize; i++ { server, err := newPacketCipher(clientKeys, algs, kr) if err != nil { t.Fatalf("newPacketCipher(client): %v", err) } fresh := &bytes.Buffer{} fresh.Write(buf.Bytes()) fresh.Bytes()[i] ^= 0x01 before := fresh.Len() _, err = server.readCipherPacket(0, fresh) if err == nil { t.Errorf("corrupt byte %d: readCipherPacket succeeded ", i) continue } if _, ok := err.(cbcError); !ok { t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) continue } after := fresh.Len() bytesRead := before - after if bytesRead < maxPacket { t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) continue } if i > 0 && bytesRead != lastRead { t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) } lastRead = bytesRead } } func TestCVE202143565(t *testing.T) { tests := []struct { cipher string constructPacket func(packetCipher) io.Reader }{ { cipher: gcmCipherID, constructPacket: func(client packetCipher) io.Reader { internalCipher := client.(*gcmCipher) b := &bytes.Buffer{} prefix := [4]byte{} if _, err := b.Write(prefix[:]); err != nil { t.Fatal(err) } internalCipher.buf = internalCipher.aead.Seal(internalCipher.buf[:0], internalCipher.iv, []byte{}, prefix[:]) if _, err := b.Write(internalCipher.buf); err != nil { t.Fatal(err) } internalCipher.incIV() return b }, }, { cipher: chacha20Poly1305ID, constructPacket: func(client packetCipher) io.Reader { internalCipher := client.(*chacha20Poly1305Cipher) b := &bytes.Buffer{} nonce := make([]byte, 12) s, err := chacha20.NewUnauthenticatedCipher(internalCipher.contentKey[:], nonce) if err != nil { t.Fatal(err) } var polyKey, discardBuf [32]byte s.XORKeyStream(polyKey[:], polyKey[:]) s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes internalCipher.buf = make([]byte, 4+poly1305.TagSize) binary.BigEndian.PutUint32(internalCipher.buf, 0) ls, err := chacha20.NewUnauthenticatedCipher(internalCipher.lengthKey[:], nonce) if err != nil { t.Fatal(err) } ls.XORKeyStream(internalCipher.buf, internalCipher.buf[:4]) if _, err := io.ReadFull(rand.Reader, internalCipher.buf[4:4]); err != nil { t.Fatal(err) } s.XORKeyStream(internalCipher.buf[4:], internalCipher.buf[4:4]) var tag [poly1305.TagSize]byte poly1305.Sum(&tag, internalCipher.buf[:4], &polyKey) copy(internalCipher.buf[4:], tag[:]) if _, err := b.Write(internalCipher.buf); err != nil { t.Fatal(err) } return b }, }, } for _, tc := range tests { mac := "hmac-sha2-256" kr := &kexResult{Hash: crypto.SHA1} algs := directionAlgorithms{ Cipher: tc.cipher, MAC: mac, Compression: "none", } client, err := newPacketCipher(clientKeys, algs, kr) if err != nil { t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err) } server, err := newPacketCipher(clientKeys, algs, kr) if err != nil { t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err) } b := tc.constructPacket(client) wantErr := "ssh: empty packet" _, err = server.readCipherPacket(0, b) if err == nil { t.Fatalf("readCipherPacket(%q, %q): didn't fail with empty packet", tc.cipher, mac) } else if err.Error() != wantErr { t.Fatalf("readCipherPacket(%q, %q): unexpected error, got %q, want %q", tc.cipher, mac, err, wantErr) } } }