/* Copyright 2015 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package wsstream import ( "encoding/base64" "io" "io/ioutil" "net/http" "net/http/httptest" "reflect" "sync" "testing" "golang.org/x/net/websocket" ) func newServer(handler http.Handler) (*httptest.Server, string) { server := httptest.NewServer(handler) serverAddr := server.Listener.Addr().String() return server, serverAddr } func TestRawConn(t *testing.T) { channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel} conn := NewConn(NewDefaultChannelProtocols(channels)) s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { conn.Open(w, req) })) defer s.Close() client, err := websocket.Dial("ws://"+addr, "", "http://localhost/") if err != nil { t.Fatal(err) } defer client.Close() <-conn.ready wg := sync.WaitGroup{} // verify we can read a client write wg.Add(1) go func() { defer wg.Done() data, err := ioutil.ReadAll(conn.channels[0]) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(data, []byte("client")) { t.Errorf("unexpected server read: %v", data) } }() if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 { t.Fatalf("%d: %v", n, err) } // verify we can read a server write wg.Add(1) go func() { defer wg.Done() if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 { t.Fatalf("%d: %v", n, err) } }() data := make([]byte, 1024) if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil { t.Fatalf("%d: %v", n, err) } if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) { t.Errorf("unexpected client read: %v", data[:7]) } // verify that an ignore channel is empty in both directions. if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil { t.Errorf("writes should be ignored") } data = make([]byte, 1024) if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF { t.Errorf("reads should be ignored") } // verify that a write to a Read channel doesn't block if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil { t.Errorf("writes should be ignored") } // verify that a read from a Write channel doesn't block data = make([]byte, 1024) if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF { t.Errorf("reads should be ignored") } // verify that a client write to a Write channel doesn't block (is dropped) if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 { t.Fatalf("%d: %v", n, err) } client.Close() wg.Wait() } func TestBase64Conn(t *testing.T) { conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel})) s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { conn.Open(w, req) })) defer s.Close() config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") if err != nil { t.Fatal(err) } config.Protocol = []string{"base64.channel.k8s.io"} client, err := websocket.DialConfig(config) if err != nil { t.Fatal(err) } defer client.Close() <-conn.ready wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() data, err := ioutil.ReadAll(conn.channels[0]) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(data, []byte("client")) { t.Errorf("unexpected server read: %s", string(data)) } }() clientData := base64.StdEncoding.EncodeToString([]byte("client")) if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 { t.Fatalf("%d: %v", n, err) } wg.Add(1) go func() { defer wg.Done() if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 { t.Fatalf("%d: %v", n, err) } }() data := make([]byte, 1024) if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil { t.Fatalf("%d: %v", n, err) } expect := []byte(base64.StdEncoding.EncodeToString([]byte("server"))) if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) { t.Errorf("unexpected client read: %v", data[:9]) } client.Close() wg.Wait() } type versionTest struct { supported map[string]bool // protocol -> binary requested []string error bool expected string } func versionTests() []versionTest { const ( binary = true base64 = false ) return []versionTest{ { supported: nil, requested: []string{"raw"}, error: true, }, { supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, requested: nil, expected: "", }, { supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, requested: []string{"v1.raw"}, error: true, }, { supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, requested: []string{"v1.raw", "v1.base64"}, error: true, }, { supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, requested: []string{"v1.raw", "raw"}, expected: "raw", }, { supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64}, requested: []string{"v1.raw"}, expected: "v1.raw", }, { supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64}, requested: []string{"v2.base64"}, expected: "v2.base64", }, } } func TestVersionedConn(t *testing.T) { for i, test := range versionTests() { func() { supportedProtocols := map[string]ChannelProtocolConfig{} for p, binary := range test.supported { supportedProtocols[p] = ChannelProtocolConfig{ Binary: binary, Channels: []ChannelType{ReadWriteChannel}, } } conn := NewConn(supportedProtocols) // note that it's not enough to wait for conn.ready to avoid a race here. Hence, // we use a channel. selectedProtocol := make(chan string) s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { p, _, _ := conn.Open(w, req) selectedProtocol <- p })) defer s.Close() config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") if err != nil { t.Fatal(err) } config.Protocol = test.requested client, err := websocket.DialConfig(config) if err != nil { if !test.error { t.Fatalf("test %d: didn't expect error: %v", i, err) } else { return } } defer client.Close() if test.error && err == nil { t.Fatalf("test %d: expected an error", i) } <-conn.ready if got, expected := <-selectedProtocol, test.expected; got != expected { t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected) } }() } }