// Copyright 2017 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows zos package socket_test import ( "bytes" "fmt" "io/ioutil" "net" "os" "os/exec" "path/filepath" "runtime" "strings" "syscall" "testing" "golang.org/x/net/internal/socket" "golang.org/x/net/nettest" ) func TestSocket(t *testing.T) { t.Run("Option", func(t *testing.T) { testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4}) }) } func testSocketOption(t *testing.T, so *socket.Option) { c, err := nettest.NewLocalPacketListener("udp") if err != nil { t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) } defer c.Close() cc, err := socket.NewConn(c.(net.Conn)) if err != nil { t.Fatal(err) } const N = 2048 if err := so.SetInt(cc, N); err != nil { t.Fatal(err) } n, err := so.GetInt(cc) if err != nil { t.Fatal(err) } if n < N { t.Fatalf("got %d; want greater than or equal to %d", n, N) } } type mockControl struct { Level int Type int Data []byte } func TestControlMessage(t *testing.T) { switch runtime.GOOS { case "windows": t.Skipf("not supported on %s", runtime.GOOS) } for _, tt := range []struct { cs []mockControl }{ { []mockControl{ {Level: 1, Type: 1}, }, }, { []mockControl{ {Level: 2, Type: 2, Data: []byte{0xfe}}, }, }, { []mockControl{ {Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}}, }, }, { []mockControl{ {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}}, }, }, { []mockControl{ {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}}, {Level: 2, Type: 2, Data: []byte{0xfe}}, }, }, } { var w []byte var tailPadLen int mm := socket.NewControlMessage([]int{0}) for i, c := range tt.cs { m := socket.NewControlMessage([]int{len(c.Data)}) l := len(m) - len(mm) if i == len(tt.cs)-1 && l > len(c.Data) { tailPadLen = l - len(c.Data) } w = append(w, m...) } var err error ww := make([]byte, len(w)) copy(ww, w) m := socket.ControlMessage(ww) for _, c := range tt.cs { if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil { t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err) } copy(m.Data(len(c.Data)), c.Data) m = m.Next(len(c.Data)) } m = socket.ControlMessage(w) for _, c := range tt.cs { m, err = m.Marshal(c.Level, c.Type, c.Data) if err != nil { t.Fatalf("(%v).Marshal() = %v", tt.cs, err) } } if !bytes.Equal(ww, w) { t.Fatalf("got %#v; want %#v", ww, w) } ws := [][]byte{w} if tailPadLen > 0 { // Test a message with no tail padding. nopad := w[:len(w)-tailPadLen] ws = append(ws, [][]byte{nopad}...) } for _, w := range ws { ms, err := socket.ControlMessage(w).Parse() if err != nil { t.Fatalf("(%v).Parse() = %v", tt.cs, err) } for i, m := range ms { lvl, typ, dataLen, err := m.ParseHeader() if err != nil { t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err) } if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) { t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data)) } } } } } func TestUDP(t *testing.T) { switch runtime.GOOS { case "windows": t.Skipf("not supported on %s", runtime.GOOS) } c, err := nettest.NewLocalPacketListener("udp") if err != nil { t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) } defer c.Close() // test that wrapped connections work with NewConn too type wrappedConn struct{ *net.UDPConn } cc, err := socket.NewConn(&wrappedConn{c.(*net.UDPConn)}) if err != nil { t.Fatal(err) } // create a dialed connection talking (only) to c/cc cDialed, err := net.Dial("udp", c.LocalAddr().String()) if err != nil { t.Fatal(err) } ccDialed, err := socket.NewConn(cDialed) if err != nil { t.Fatal(err) } const data = "HELLO-R-U-THERE" messageTests := []struct { name string conn *socket.Conn dest net.Addr }{ { name: "Message", conn: cc, dest: c.LocalAddr(), }, { name: "Message-dialed", conn: ccDialed, dest: nil, }, } for _, tt := range messageTests { t.Run(tt.name, func(t *testing.T) { wm := socket.Message{ Buffers: bytes.SplitAfter([]byte(data), []byte("-")), Addr: tt.dest, } if err := tt.conn.SendMsg(&wm, 0); err != nil { t.Fatal(err) } b := make([]byte, 32) rm := socket.Message{ Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]}, } if err := cc.RecvMsg(&rm, 0); err != nil { t.Fatal(err) } received := string(b[:rm.N]) if received != data { t.Fatalf("Roundtrip SendMsg/RecvMsg got %q; want %q", received, data) } }) } switch runtime.GOOS { case "android", "linux": messagesTests := []struct { name string conn *socket.Conn dest net.Addr }{ { name: "Messages", conn: cc, dest: c.LocalAddr(), }, { name: "Messages-dialed", conn: ccDialed, dest: nil, }, } for _, tt := range messagesTests { t.Run(tt.name, func(t *testing.T) { wmbs := bytes.SplitAfter([]byte(data), []byte("-")) wms := []socket.Message{ {Buffers: wmbs[:1], Addr: tt.dest}, {Buffers: wmbs[1:], Addr: tt.dest}, } n, err := tt.conn.SendMsgs(wms, 0) if err != nil { t.Fatal(err) } if n != len(wms) { t.Fatalf("SendMsgs(%#v) != %d; want %d", wms, n, len(wms)) } rmbs := [][]byte{make([]byte, 32), make([]byte, 32)} rms := []socket.Message{ {Buffers: [][]byte{rmbs[0]}}, {Buffers: [][]byte{rmbs[1][:1], rmbs[1][1:3], rmbs[1][3:7], rmbs[1][7:11], rmbs[1][11:]}}, } nrecv := 0 for nrecv < len(rms) { n, err := cc.RecvMsgs(rms[nrecv:], 0) if err != nil { t.Fatal(err) } nrecv += n } received0, received1 := string(rmbs[0][:rms[0].N]), string(rmbs[1][:rms[1].N]) assembled := received0 + received1 assembledReordered := received1 + received0 if assembled != data && assembledReordered != data { t.Fatalf("Roundtrip SendMsgs/RecvMsgs got %q / %q; want %q", assembled, assembledReordered, data) } }) } t.Run("Messages-undialed-no-dst", func(t *testing.T) { // sending without destination address should fail. // This checks that the internally recycled buffers are reset correctly. data := []byte("HELLO-R-U-THERE") wmbs := bytes.SplitAfter(data, []byte("-")) wms := []socket.Message{ {Buffers: wmbs[:1], Addr: nil}, {Buffers: wmbs[1:], Addr: nil}, } n, err := cc.SendMsgs(wms, 0) if n != 0 && err == nil { t.Fatal("expected error, destination address required") } }) } // The behavior of transmission for zero byte paylaod depends // on each platform implementation. Some may transmit only // protocol header and options, other may transmit nothing. // We test only that SendMsg and SendMsgs will not crash with // empty buffers. wm := socket.Message{ Buffers: [][]byte{{}}, Addr: c.LocalAddr(), } cc.SendMsg(&wm, 0) wms := []socket.Message{ {Buffers: [][]byte{{}}, Addr: c.LocalAddr()}, } cc.SendMsgs(wms, 0) } func BenchmarkUDP(b *testing.B) { c, err := nettest.NewLocalPacketListener("udp") if err != nil { b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) } defer c.Close() cc, err := socket.NewConn(c.(net.Conn)) if err != nil { b.Fatal(err) } data := []byte("HELLO-R-U-THERE") wm := socket.Message{ Buffers: [][]byte{data}, Addr: c.LocalAddr(), } rm := socket.Message{ Buffers: [][]byte{make([]byte, 128)}, OOB: make([]byte, 128), } for M := 1; M <= 1<<9; M = M << 1 { b.Run(fmt.Sprintf("Iter-%d", M), func(b *testing.B) { for i := 0; i < b.N; i++ { for j := 0; j < M; j++ { if err := cc.SendMsg(&wm, 0); err != nil { b.Fatal(err) } if err := cc.RecvMsg(&rm, 0); err != nil { b.Fatal(err) } } } }) switch runtime.GOOS { case "android", "linux": wms := make([]socket.Message, M) for i := range wms { wms[i].Buffers = [][]byte{data} wms[i].Addr = c.LocalAddr() } rms := make([]socket.Message, M) for i := range rms { rms[i].Buffers = [][]byte{make([]byte, 128)} rms[i].OOB = make([]byte, 128) } b.Run(fmt.Sprintf("Batch-%d", M), func(b *testing.B) { for i := 0; i < b.N; i++ { if _, err := cc.SendMsgs(wms, 0); err != nil { b.Fatal(err) } if _, err := cc.RecvMsgs(rms, 0); err != nil { b.Fatal(err) } } }) } } } func TestRace(t *testing.T) { tests := []string{ ` package main import ( "log" "net" "golang.org/x/net/ipv4" ) var g byte func main() { c, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { log.Fatalf("ListenPacket: %v", err) } cc := ipv4.NewPacketConn(c) sync := make(chan bool) src := make([]byte, 100) dst := make([]byte, 100) go func() { if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil { log.Fatalf("WriteTo: %v", err) } }() go func() { if _, _, _, err := cc.ReadFrom(dst); err != nil { log.Fatalf("ReadFrom: %v", err) } sync <- true }() g = dst[0] <-sync } `, ` package main import ( "log" "net" "golang.org/x/net/ipv4" ) func main() { c, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { log.Fatalf("ListenPacket: %v", err) } cc := ipv4.NewPacketConn(c) sync := make(chan bool) src := make([]byte, 100) dst := make([]byte, 100) go func() { if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil { log.Fatalf("WriteTo: %v", err) } sync <- true }() src[0] = 0 go func() { if _, _, _, err := cc.ReadFrom(dst); err != nil { log.Fatalf("ReadFrom: %v", err) } }() <-sync } `, } platforms := map[string]bool{ "linux/amd64": true, "linux/ppc64le": true, "linux/arm64": true, } if !platforms[runtime.GOOS+"/"+runtime.GOARCH] { t.Skip("skipping test on non-race-enabled host.") } if runtime.Compiler == "gccgo" { t.Skip("skipping race test when built with gccgo") } dir, err := ioutil.TempDir("", "testrace") if err != nil { t.Fatalf("failed to create temp directory: %v", err) } defer os.RemoveAll(dir) goBinary := filepath.Join(runtime.GOROOT(), "bin", "go") t.Logf("%s version", goBinary) got, err := exec.Command(goBinary, "version").CombinedOutput() if len(got) > 0 { t.Logf("%s", got) } if err != nil { t.Fatalf("go version failed: %v", err) } for i, test := range tests { t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { src := filepath.Join(dir, fmt.Sprintf("test%d.go", i)) if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil { t.Fatalf("failed to write file: %v", err) } t.Logf("%s run -race %s", goBinary, src) got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput() if len(got) > 0 { t.Logf("%s", got) } if strings.Contains(string(got), "-race requires cgo") { t.Log("CGO is not enabled so can't use -race") } else if !strings.Contains(string(got), "WARNING: DATA RACE") { t.Errorf("race not detected for test %d: err:%v", i, err) } }) } }