// Copyright 2018 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package h2c import ( "context" "crypto/tls" "fmt" "io" "io/ioutil" "log" "net" "net/http" "net/http/httptest" "strings" "testing" "golang.org/x/net/http2" ) func ExampleNewHandler() { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Hello world") }) h2s := &http2.Server{ // ... } h1s := &http.Server{ Addr: ":8080", Handler: NewHandler(handler, h2s), } log.Fatal(h1s.ListenAndServe()) } func TestContext(t *testing.T) { baseCtx := context.WithValue(context.Background(), "testkey", "testvalue") handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.ProtoMajor != 2 { t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor) } if r.Context().Value("testkey") != "testvalue" { t.Errorf("Request doesn't have expected base context: %v", r.Context()) } fmt.Fprint(w, "Hello world") }) h2s := &http2.Server{} h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s)) h1s.Config.BaseContext = func(_ net.Listener) context.Context { return baseCtx } h1s.Start() defer h1s.Close() client := &http.Client{ Transport: &http2.Transport{ AllowHTTP: true, DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { return net.Dial(network, addr) }, }, } resp, err := client.Get(h1s.URL) if err != nil { t.Fatal(err) } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if err := resp.Body.Close(); err != nil { t.Fatal(err) } } func TestPropagation(t *testing.T) { var ( server *http.Server // double the limit because http2 will compress header headerSize = 1 << 11 headerLimit = 1 << 10 ) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.ProtoMajor != 2 { t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor) } if r.Context().Value(http.ServerContextKey).(*http.Server) != server { t.Errorf("Request doesn't have expected http server: %v", r.Context()) } if len(r.Header.Get("Long-Header")) != headerSize { t.Errorf("Request doesn't have expected http header length: %v", len(r.Header.Get("Long-Header"))) } fmt.Fprint(w, "Hello world") }) h2s := &http2.Server{} h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s)) server = h1s.Config server.MaxHeaderBytes = headerLimit server.ConnState = func(conn net.Conn, state http.ConnState) { t.Logf("server conn state: conn %s -> %s, status changed to %s", conn.RemoteAddr(), conn.LocalAddr(), state) } h1s.Start() defer h1s.Close() client := &http.Client{ Transport: &http2.Transport{ AllowHTTP: true, DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { conn, err := net.Dial(network, addr) if conn != nil { t.Logf("client dial tls: %s -> %s", conn.RemoteAddr(), conn.LocalAddr()) } return conn, err }, }, } req, err := http.NewRequest("GET", h1s.URL, nil) if err != nil { t.Fatal(err) } req.Header.Set("Long-Header", strings.Repeat("A", headerSize)) _, err = client.Do(req) if err == nil { t.Fatal("expected server err, got nil") } } func TestMaxBytesHandler(t *testing.T) { const bodyLimit = 10 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Errorf("got request, expected to be blocked by body limit") }) h2s := &http2.Server{} h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s), bodyLimit)) h1s.Start() defer h1s.Close() // Wrap the body in a struct{io.Reader} to prevent it being rewound and resent. body := "0123456789abcdef" req, err := http.NewRequest("POST", h1s.URL, struct{ io.Reader }{strings.NewReader(body)}) if err != nil { t.Fatal(err) } req.Header.Set("Http2-Settings", "") req.Header.Set("Upgrade", "h2c") req.Header.Set("Connection", "Upgrade, HTTP2-Settings") resp, err := h1s.Client().Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if got, want := resp.StatusCode, http.StatusInternalServerError; got != want { t.Errorf("resp.StatusCode = %v, want %v", got, want) } }