/* Copyright 2014-2021 Docker Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package spdystream import ( "errors" "fmt" "io" "net" "net/http" "sync" "time" "github.com/moby/spdystream/spdy" ) var ( ErrInvalidStreamId = errors.New("Invalid stream id") ErrTimeout = errors.New("Timeout occurred") ErrReset = errors.New("Stream reset") ErrWriteClosedStream = errors.New("Write on closed stream") ) const ( FRAME_WORKERS = 5 QUEUE_SIZE = 50 ) type StreamHandler func(stream *Stream) type AuthHandler func(header http.Header, slot uint8, parent uint32) bool type idleAwareFramer struct { f *spdy.Framer conn *Connection writeLock sync.Mutex resetChan chan struct{} setTimeoutLock sync.Mutex setTimeoutChan chan time.Duration timeout time.Duration } func newIdleAwareFramer(framer *spdy.Framer) *idleAwareFramer { iaf := &idleAwareFramer{ f: framer, resetChan: make(chan struct{}, 2), // setTimeoutChan needs to be buffered to avoid deadlocks when calling setIdleTimeout at about // the same time the connection is being closed setTimeoutChan: make(chan time.Duration, 1), } return iaf } func (i *idleAwareFramer) monitor() { var ( timer *time.Timer expired <-chan time.Time resetChan = i.resetChan setTimeoutChan = i.setTimeoutChan ) Loop: for { select { case timeout := <-i.setTimeoutChan: i.timeout = timeout if timeout == 0 { if timer != nil { timer.Stop() } } else { if timer == nil { timer = time.NewTimer(timeout) expired = timer.C } else { timer.Reset(timeout) } } case <-resetChan: if timer != nil && i.timeout > 0 { timer.Reset(i.timeout) } case <-expired: i.conn.streamCond.L.Lock() streams := i.conn.streams i.conn.streams = make(map[spdy.StreamId]*Stream) i.conn.streamCond.Broadcast() i.conn.streamCond.L.Unlock() go func() { for _, stream := range streams { stream.resetStream() } i.conn.Close() }() case <-i.conn.closeChan: if timer != nil { timer.Stop() } // Start a goroutine to drain resetChan. This is needed because we've seen // some unit tests with large numbers of goroutines get into a situation // where resetChan fills up, at least 1 call to Write() is still trying to // send to resetChan, the connection gets closed, and this case statement // attempts to grab the write lock that Write() already has, causing a // deadlock. // // See https://github.com/moby/spdystream/issues/49 for more details. go func() { for range resetChan { } }() go func() { for range setTimeoutChan { } }() i.writeLock.Lock() close(resetChan) i.resetChan = nil i.writeLock.Unlock() i.setTimeoutLock.Lock() close(i.setTimeoutChan) i.setTimeoutChan = nil i.setTimeoutLock.Unlock() break Loop } } // Drain resetChan for range resetChan { } } func (i *idleAwareFramer) WriteFrame(frame spdy.Frame) error { i.writeLock.Lock() defer i.writeLock.Unlock() if i.resetChan == nil { return io.EOF } err := i.f.WriteFrame(frame) if err != nil { return err } i.resetChan <- struct{}{} return nil } func (i *idleAwareFramer) ReadFrame() (spdy.Frame, error) { frame, err := i.f.ReadFrame() if err != nil { return nil, err } // resetChan should never be closed since it is only closed // when the connection has closed its closeChan. This closure // only occurs after all Reads have finished // TODO (dmcgowan): refactor relationship into connection i.resetChan <- struct{}{} return frame, nil } func (i *idleAwareFramer) setIdleTimeout(timeout time.Duration) { i.setTimeoutLock.Lock() defer i.setTimeoutLock.Unlock() if i.setTimeoutChan == nil { return } i.setTimeoutChan <- timeout } type Connection struct { conn net.Conn framer *idleAwareFramer closeChan chan bool goneAway bool lastStreamChan chan<- *Stream goAwayTimeout time.Duration closeTimeout time.Duration streamLock *sync.RWMutex streamCond *sync.Cond streams map[spdy.StreamId]*Stream nextIdLock sync.Mutex receiveIdLock sync.Mutex nextStreamId spdy.StreamId receivedStreamId spdy.StreamId pingIdLock sync.Mutex pingId uint32 pingChans map[uint32]chan error shutdownLock sync.Mutex shutdownChan chan error hasShutdown bool // for testing https://github.com/moby/spdystream/pull/56 dataFrameHandler func(*spdy.DataFrame) error } // NewConnection creates a new spdy connection from an existing // network connection. func NewConnection(conn net.Conn, server bool) (*Connection, error) { framer, framerErr := spdy.NewFramer(conn, conn) if framerErr != nil { return nil, framerErr } idleAwareFramer := newIdleAwareFramer(framer) var sid spdy.StreamId var rid spdy.StreamId var pid uint32 if server { sid = 2 rid = 1 pid = 2 } else { sid = 1 rid = 2 pid = 1 } streamLock := new(sync.RWMutex) streamCond := sync.NewCond(streamLock) session := &Connection{ conn: conn, framer: idleAwareFramer, closeChan: make(chan bool), goAwayTimeout: time.Duration(0), closeTimeout: time.Duration(0), streamLock: streamLock, streamCond: streamCond, streams: make(map[spdy.StreamId]*Stream), nextStreamId: sid, receivedStreamId: rid, pingId: pid, pingChans: make(map[uint32]chan error), shutdownChan: make(chan error), } session.dataFrameHandler = session.handleDataFrame idleAwareFramer.conn = session go idleAwareFramer.monitor() return session, nil } // Ping sends a ping frame across the connection and // returns the response time func (s *Connection) Ping() (time.Duration, error) { pid := s.pingId s.pingIdLock.Lock() if s.pingId > 0x7ffffffe { s.pingId = s.pingId - 0x7ffffffe } else { s.pingId = s.pingId + 2 } s.pingIdLock.Unlock() pingChan := make(chan error) s.pingChans[pid] = pingChan defer delete(s.pingChans, pid) frame := &spdy.PingFrame{Id: pid} startTime := time.Now() writeErr := s.framer.WriteFrame(frame) if writeErr != nil { return time.Duration(0), writeErr } select { case <-s.closeChan: return time.Duration(0), errors.New("connection closed") case err, ok := <-pingChan: if ok && err != nil { return time.Duration(0), err } break } return time.Since(startTime), nil } // Serve handles frames sent from the server, including reply frames // which are needed to fully initiate connections. Both clients and servers // should call Serve in a separate goroutine before creating streams. func (s *Connection) Serve(newHandler StreamHandler) { // use a WaitGroup to wait for all frames to be drained after receiving // go-away. var wg sync.WaitGroup // Parition queues to ensure stream frames are handled // by the same worker, ensuring order is maintained frameQueues := make([]*PriorityFrameQueue, FRAME_WORKERS) for i := 0; i < FRAME_WORKERS; i++ { frameQueues[i] = NewPriorityFrameQueue(QUEUE_SIZE) // Ensure frame queue is drained when connection is closed go func(frameQueue *PriorityFrameQueue) { <-s.closeChan frameQueue.Drain() }(frameQueues[i]) wg.Add(1) go func(frameQueue *PriorityFrameQueue) { // let the WaitGroup know this worker is done defer wg.Done() s.frameHandler(frameQueue, newHandler) }(frameQueues[i]) } var ( partitionRoundRobin int goAwayFrame *spdy.GoAwayFrame ) Loop: for { readFrame, err := s.framer.ReadFrame() if err != nil { if err != io.EOF { debugMessage("frame read error: %s", err) } else { debugMessage("(%p) EOF received", s) } break } var priority uint8 var partition int switch frame := readFrame.(type) { case *spdy.SynStreamFrame: if s.checkStreamFrame(frame) { priority = frame.Priority partition = int(frame.StreamId % FRAME_WORKERS) debugMessage("(%p) Add stream frame: %d ", s, frame.StreamId) s.addStreamFrame(frame) } else { debugMessage("(%p) Rejected stream frame: %d ", s, frame.StreamId) continue } case *spdy.SynReplyFrame: priority = s.getStreamPriority(frame.StreamId) partition = int(frame.StreamId % FRAME_WORKERS) case *spdy.DataFrame: priority = s.getStreamPriority(frame.StreamId) partition = int(frame.StreamId % FRAME_WORKERS) case *spdy.RstStreamFrame: priority = s.getStreamPriority(frame.StreamId) partition = int(frame.StreamId % FRAME_WORKERS) case *spdy.HeadersFrame: priority = s.getStreamPriority(frame.StreamId) partition = int(frame.StreamId % FRAME_WORKERS) case *spdy.PingFrame: priority = 0 partition = partitionRoundRobin partitionRoundRobin = (partitionRoundRobin + 1) % FRAME_WORKERS case *spdy.GoAwayFrame: // hold on to the go away frame and exit the loop goAwayFrame = frame break Loop default: priority = 7 partition = partitionRoundRobin partitionRoundRobin = (partitionRoundRobin + 1) % FRAME_WORKERS } frameQueues[partition].Push(readFrame, priority) } close(s.closeChan) // wait for all frame handler workers to indicate they've drained their queues // before handling the go away frame wg.Wait() if goAwayFrame != nil { s.handleGoAwayFrame(goAwayFrame) } // now it's safe to close remote channels and empty s.streams s.streamCond.L.Lock() // notify streams that they're now closed, which will // unblock any stream Read() calls for _, stream := range s.streams { stream.closeRemoteChannels() } s.streams = make(map[spdy.StreamId]*Stream) s.streamCond.Broadcast() s.streamCond.L.Unlock() } func (s *Connection) frameHandler(frameQueue *PriorityFrameQueue, newHandler StreamHandler) { for { popFrame := frameQueue.Pop() if popFrame == nil { return } var frameErr error switch frame := popFrame.(type) { case *spdy.SynStreamFrame: frameErr = s.handleStreamFrame(frame, newHandler) case *spdy.SynReplyFrame: frameErr = s.handleReplyFrame(frame) case *spdy.DataFrame: frameErr = s.dataFrameHandler(frame) case *spdy.RstStreamFrame: frameErr = s.handleResetFrame(frame) case *spdy.HeadersFrame: frameErr = s.handleHeaderFrame(frame) case *spdy.PingFrame: frameErr = s.handlePingFrame(frame) case *spdy.GoAwayFrame: frameErr = s.handleGoAwayFrame(frame) default: frameErr = fmt.Errorf("unhandled frame type: %T", frame) } if frameErr != nil { debugMessage("frame handling error: %s", frameErr) } } } func (s *Connection) getStreamPriority(streamId spdy.StreamId) uint8 { stream, streamOk := s.getStream(streamId) if !streamOk { return 7 } return stream.priority } func (s *Connection) addStreamFrame(frame *spdy.SynStreamFrame) { var parent *Stream if frame.AssociatedToStreamId != spdy.StreamId(0) { parent, _ = s.getStream(frame.AssociatedToStreamId) } stream := &Stream{ streamId: frame.StreamId, parent: parent, conn: s, startChan: make(chan error), headers: frame.Headers, finished: (frame.CFHeader.Flags & spdy.ControlFlagUnidirectional) != 0x00, replyCond: sync.NewCond(new(sync.Mutex)), dataChan: make(chan []byte), headerChan: make(chan http.Header), closeChan: make(chan bool), priority: frame.Priority, } if frame.CFHeader.Flags&spdy.ControlFlagFin != 0x00 { stream.closeRemoteChannels() } s.addStream(stream) } // checkStreamFrame checks to see if a stream frame is allowed. // If the stream is invalid, then a reset frame with protocol error // will be returned. func (s *Connection) checkStreamFrame(frame *spdy.SynStreamFrame) bool { s.receiveIdLock.Lock() defer s.receiveIdLock.Unlock() if s.goneAway { return false } validationErr := s.validateStreamId(frame.StreamId) if validationErr != nil { go func() { resetErr := s.sendResetFrame(spdy.ProtocolError, frame.StreamId) if resetErr != nil { debugMessage("reset error: %s", resetErr) } }() return false } return true } func (s *Connection) handleStreamFrame(frame *spdy.SynStreamFrame, newHandler StreamHandler) error { stream, ok := s.getStream(frame.StreamId) if !ok { return fmt.Errorf("Missing stream: %d", frame.StreamId) } newHandler(stream) return nil } func (s *Connection) handleReplyFrame(frame *spdy.SynReplyFrame) error { debugMessage("(%p) Reply frame received for %d", s, frame.StreamId) stream, streamOk := s.getStream(frame.StreamId) if !streamOk { debugMessage("Reply frame gone away for %d", frame.StreamId) // Stream has already gone away return nil } if stream.replied { // Stream has already received reply return nil } stream.replied = true // TODO Check for error if (frame.CFHeader.Flags & spdy.ControlFlagFin) != 0x00 { s.remoteStreamFinish(stream) } close(stream.startChan) return nil } func (s *Connection) handleResetFrame(frame *spdy.RstStreamFrame) error { stream, streamOk := s.getStream(frame.StreamId) if !streamOk { // Stream has already been removed return nil } s.removeStream(stream) stream.closeRemoteChannels() if !stream.replied { stream.replied = true stream.startChan <- ErrReset close(stream.startChan) } stream.finishLock.Lock() stream.finished = true stream.finishLock.Unlock() return nil } func (s *Connection) handleHeaderFrame(frame *spdy.HeadersFrame) error { stream, streamOk := s.getStream(frame.StreamId) if !streamOk { // Stream has already gone away return nil } if !stream.replied { // No reply received...Protocol error? return nil } // TODO limit headers while not blocking (use buffered chan or goroutine?) select { case <-stream.closeChan: return nil case stream.headerChan <- frame.Headers: } if (frame.CFHeader.Flags & spdy.ControlFlagFin) != 0x00 { s.remoteStreamFinish(stream) } return nil } func (s *Connection) handleDataFrame(frame *spdy.DataFrame) error { debugMessage("(%p) Data frame received for %d", s, frame.StreamId) stream, streamOk := s.getStream(frame.StreamId) if !streamOk { debugMessage("(%p) Data frame gone away for %d", s, frame.StreamId) // Stream has already gone away return nil } if !stream.replied { debugMessage("(%p) Data frame not replied %d", s, frame.StreamId) // No reply received...Protocol error? return nil } debugMessage("(%p) (%d) Data frame handling", stream, stream.streamId) if len(frame.Data) > 0 { stream.dataLock.RLock() select { case <-stream.closeChan: debugMessage("(%p) (%d) Data frame not sent (stream shut down)", stream, stream.streamId) case stream.dataChan <- frame.Data: debugMessage("(%p) (%d) Data frame sent", stream, stream.streamId) } stream.dataLock.RUnlock() } if (frame.Flags & spdy.DataFlagFin) != 0x00 { s.remoteStreamFinish(stream) } return nil } func (s *Connection) handlePingFrame(frame *spdy.PingFrame) error { if s.pingId&0x01 != frame.Id&0x01 { return s.framer.WriteFrame(frame) } pingChan, pingOk := s.pingChans[frame.Id] if pingOk { close(pingChan) } return nil } func (s *Connection) handleGoAwayFrame(frame *spdy.GoAwayFrame) error { debugMessage("(%p) Go away received", s) s.receiveIdLock.Lock() if s.goneAway { s.receiveIdLock.Unlock() return nil } s.goneAway = true s.receiveIdLock.Unlock() if s.lastStreamChan != nil { stream, _ := s.getStream(frame.LastGoodStreamId) go func() { s.lastStreamChan <- stream }() } // Do not block frame handler waiting for closure go s.shutdown(s.goAwayTimeout) return nil } func (s *Connection) remoteStreamFinish(stream *Stream) { stream.closeRemoteChannels() stream.finishLock.Lock() if stream.finished { // Stream is fully closed, cleanup s.removeStream(stream) } stream.finishLock.Unlock() } // CreateStream creates a new spdy stream using the parameters for // creating the stream frame. The stream frame will be sent upon // calling this function, however this function does not wait for // the reply frame. If waiting for the reply is desired, use // the stream Wait or WaitTimeout function on the stream returned // by this function. func (s *Connection) CreateStream(headers http.Header, parent *Stream, fin bool) (*Stream, error) { // MUST synchronize stream creation (all the way to writing the frame) // as stream IDs **MUST** increase monotonically. s.nextIdLock.Lock() defer s.nextIdLock.Unlock() streamId := s.getNextStreamId() if streamId == 0 { return nil, fmt.Errorf("Unable to get new stream id") } stream := &Stream{ streamId: streamId, parent: parent, conn: s, startChan: make(chan error), headers: headers, dataChan: make(chan []byte), headerChan: make(chan http.Header), closeChan: make(chan bool), } debugMessage("(%p) (%p) Create stream", s, stream) s.addStream(stream) return stream, s.sendStream(stream, fin) } func (s *Connection) shutdown(closeTimeout time.Duration) { // TODO Ensure this isn't called multiple times s.shutdownLock.Lock() if s.hasShutdown { s.shutdownLock.Unlock() return } s.hasShutdown = true s.shutdownLock.Unlock() var timeout <-chan time.Time if closeTimeout > time.Duration(0) { timeout = time.After(closeTimeout) } streamsClosed := make(chan bool) go func() { s.streamCond.L.Lock() for len(s.streams) > 0 { debugMessage("Streams opened: %d, %#v", len(s.streams), s.streams) s.streamCond.Wait() } s.streamCond.L.Unlock() close(streamsClosed) }() var err error select { case <-streamsClosed: // No active streams, close should be safe err = s.conn.Close() case <-timeout: // Force ungraceful close err = s.conn.Close() // Wait for cleanup to clear active streams <-streamsClosed } if err != nil { duration := 10 * time.Minute time.AfterFunc(duration, func() { select { case err, ok := <-s.shutdownChan: if ok { debugMessage("Unhandled close error after %s: %s", duration, err) } default: } }) s.shutdownChan <- err } close(s.shutdownChan) } // Closes spdy connection by sending GoAway frame and initiating shutdown func (s *Connection) Close() error { s.receiveIdLock.Lock() if s.goneAway { s.receiveIdLock.Unlock() return nil } s.goneAway = true s.receiveIdLock.Unlock() var lastStreamId spdy.StreamId if s.receivedStreamId > 2 { lastStreamId = s.receivedStreamId - 2 } goAwayFrame := &spdy.GoAwayFrame{ LastGoodStreamId: lastStreamId, Status: spdy.GoAwayOK, } err := s.framer.WriteFrame(goAwayFrame) go s.shutdown(s.closeTimeout) if err != nil { return err } return nil } // CloseWait closes the connection and waits for shutdown // to finish. Note the underlying network Connection // is not closed until the end of shutdown. func (s *Connection) CloseWait() error { closeErr := s.Close() if closeErr != nil { return closeErr } shutdownErr, ok := <-s.shutdownChan if ok { return shutdownErr } return nil } // Wait waits for the connection to finish shutdown or for // the wait timeout duration to expire. This needs to be // called either after Close has been called or the GOAWAYFRAME // has been received. If the wait timeout is 0, this function // will block until shutdown finishes. If wait is never called // and a shutdown error occurs, that error will be logged as an // unhandled error. func (s *Connection) Wait(waitTimeout time.Duration) error { var timeout <-chan time.Time if waitTimeout > time.Duration(0) { timeout = time.After(waitTimeout) } select { case err, ok := <-s.shutdownChan: if ok { return err } case <-timeout: return ErrTimeout } return nil } // NotifyClose registers a channel to be called when the remote // peer inidicates connection closure. The last stream to be // received by the remote will be sent on the channel. The notify // timeout will determine the duration between go away received // and the connection being closed. func (s *Connection) NotifyClose(c chan<- *Stream, timeout time.Duration) { s.goAwayTimeout = timeout s.lastStreamChan = c } // SetCloseTimeout sets the amount of time close will wait for // streams to finish before terminating the underlying network // connection. Setting the timeout to 0 will cause close to // wait forever, which is the default. func (s *Connection) SetCloseTimeout(timeout time.Duration) { s.closeTimeout = timeout } // SetIdleTimeout sets the amount of time the connection may sit idle before // it is forcefully terminated. func (s *Connection) SetIdleTimeout(timeout time.Duration) { s.framer.setIdleTimeout(timeout) } func (s *Connection) sendHeaders(headers http.Header, stream *Stream, fin bool) error { var flags spdy.ControlFlags if fin { flags = spdy.ControlFlagFin } headerFrame := &spdy.HeadersFrame{ StreamId: stream.streamId, Headers: headers, CFHeader: spdy.ControlFrameHeader{Flags: flags}, } return s.framer.WriteFrame(headerFrame) } func (s *Connection) sendReply(headers http.Header, stream *Stream, fin bool) error { var flags spdy.ControlFlags if fin { flags = spdy.ControlFlagFin } replyFrame := &spdy.SynReplyFrame{ StreamId: stream.streamId, Headers: headers, CFHeader: spdy.ControlFrameHeader{Flags: flags}, } return s.framer.WriteFrame(replyFrame) } func (s *Connection) sendResetFrame(status spdy.RstStreamStatus, streamId spdy.StreamId) error { resetFrame := &spdy.RstStreamFrame{ StreamId: streamId, Status: status, } return s.framer.WriteFrame(resetFrame) } func (s *Connection) sendReset(status spdy.RstStreamStatus, stream *Stream) error { return s.sendResetFrame(status, stream.streamId) } func (s *Connection) sendStream(stream *Stream, fin bool) error { var flags spdy.ControlFlags if fin { flags = spdy.ControlFlagFin stream.finished = true } var parentId spdy.StreamId if stream.parent != nil { parentId = stream.parent.streamId } streamFrame := &spdy.SynStreamFrame{ StreamId: spdy.StreamId(stream.streamId), AssociatedToStreamId: spdy.StreamId(parentId), Headers: stream.headers, CFHeader: spdy.ControlFrameHeader{Flags: flags}, } return s.framer.WriteFrame(streamFrame) } // getNextStreamId returns the next sequential id // every call should produce a unique value or an error func (s *Connection) getNextStreamId() spdy.StreamId { sid := s.nextStreamId if sid > 0x7fffffff { return 0 } s.nextStreamId = s.nextStreamId + 2 return sid } // PeekNextStreamId returns the next sequential id and keeps the next id untouched func (s *Connection) PeekNextStreamId() spdy.StreamId { sid := s.nextStreamId return sid } func (s *Connection) validateStreamId(rid spdy.StreamId) error { if rid > 0x7fffffff || rid < s.receivedStreamId { return ErrInvalidStreamId } s.receivedStreamId = rid + 2 return nil } func (s *Connection) addStream(stream *Stream) { s.streamCond.L.Lock() s.streams[stream.streamId] = stream debugMessage("(%p) (%p) Stream added, broadcasting: %d", s, stream, stream.streamId) s.streamCond.Broadcast() s.streamCond.L.Unlock() } func (s *Connection) removeStream(stream *Stream) { s.streamCond.L.Lock() delete(s.streams, stream.streamId) debugMessage("(%p) (%p) Stream removed, broadcasting: %d", s, stream, stream.streamId) s.streamCond.Broadcast() s.streamCond.L.Unlock() } func (s *Connection) getStream(streamId spdy.StreamId) (stream *Stream, ok bool) { s.streamLock.RLock() stream, ok = s.streams[streamId] s.streamLock.RUnlock() return } // FindStream looks up the given stream id and either waits for the // stream to be found or returns nil if the stream id is no longer // valid. func (s *Connection) FindStream(streamId uint32) *Stream { var stream *Stream var ok bool s.streamCond.L.Lock() stream, ok = s.streams[spdy.StreamId(streamId)] debugMessage("(%p) Found stream %d? %t", s, spdy.StreamId(streamId), ok) for !ok && streamId >= uint32(s.receivedStreamId) { s.streamCond.Wait() stream, ok = s.streams[spdy.StreamId(streamId)] } s.streamCond.L.Unlock() return stream } func (s *Connection) CloseChan() <-chan bool { return s.closeChan }