// Copyright (c) 2021 VMware, Inc. or its affiliates. All Rights Reserved. // Copyright (c) 2012-2021, Sean Treadway, SoundCloud Ltd. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package amqp091 import ( "bytes" "io" "reflect" "testing" "time" ) type server struct { *testing.T r reader // framer <- client w writer // framer -> client S io.ReadWriteCloser // Server IO C io.ReadWriteCloser // Client IO // captured client frames start connectionStartOk tune connectionTuneOk } var defaultLogin = "guest" var defaultPassword = "guest" var defaultPlainAuth = &PlainAuth{defaultLogin, defaultPassword} var defaultAMQPlainAuth = &AMQPlainAuth{defaultLogin, defaultPassword} func defaultConfigWithAuth(auth Authentication) Config { return Config{ SASL: []Authentication{auth}, Vhost: "/", Locale: defaultLocale, } } func defaultConfig() Config { return defaultConfigWithAuth(defaultPlainAuth) } func amqplainConfig() Config { return defaultConfigWithAuth(defaultAMQPlainAuth) } func newServer(t *testing.T, serverIO, clientIO io.ReadWriteCloser) *server { return &server{ T: t, r: reader{serverIO}, w: writer{serverIO}, S: serverIO, C: clientIO, } } func newSession(t *testing.T) (io.ReadWriteCloser, *server) { rs, wc := io.Pipe() rc, ws := io.Pipe() rws := &logIO{t, "server", pipe{rs, ws}} rwc := &logIO{t, "client", pipe{rc, wc}} return rwc, newServer(t, rws, rwc) } func (t *server) expectBytes(b []byte) { in := make([]byte, len(b)) if _, err := io.ReadFull(t.S, in); err != nil { t.Fatalf("io error expecting bytes: %v", err) } if !bytes.Equal(b, in) { t.Fatalf("failed bytes: expected: %s got: %s", string(b), string(in)) } } func (t *server) send(channel int, m message) { defer time.AfterFunc(time.Second, func() { t.Fatalf("send deadlock") }).Stop() if msg, ok := m.(messageWithContent); ok { props, body := msg.getContent() class, _ := msg.id() if err := t.w.WriteFrame(&methodFrame{ ChannelId: uint16(channel), Method: msg, }); err != nil { t.Fatalf("WriteFrame error: %v", err) } if err := t.w.WriteFrame(&headerFrame{ ChannelId: uint16(channel), ClassId: class, Size: uint64(len(body)), Properties: props, }); err != nil { t.Fatalf("WriteFrame error: %v", err) } if err := t.w.WriteFrame(&bodyFrame{ ChannelId: uint16(channel), Body: body, }); err != nil { t.Fatalf("WriteFrame error: %v", err) } } else { if err := t.w.WriteFrame(&methodFrame{ ChannelId: uint16(channel), Method: m, }); err != nil { t.Fatalf("WriteFrame error: %v", err) } } } // drops all but method frames expected on the given channel func (t *server) recv(channel int, m message) message { defer time.AfterFunc(time.Second, func() { t.Fatalf("recv deadlock") }).Stop() var remaining int var header *headerFrame var body []byte for { frame, err := t.r.ReadFrame() if err != nil { t.Fatalf("frame err, read: %s", err) } if frame.channel() != uint16(channel) { t.Fatalf("expected frame on channel %d, got channel %d", channel, frame.channel()) } switch f := frame.(type) { case *heartbeatFrame: // drop case *headerFrame: // start content state header = f remaining = int(header.Size) if remaining == 0 { m.(messageWithContent).setContent(header.Properties, nil) return m } case *bodyFrame: // continue until terminated body = append(body, f.Body...) remaining -= len(f.Body) if remaining <= 0 { m.(messageWithContent).setContent(header.Properties, body) return m } case *methodFrame: if reflect.TypeOf(m) == reflect.TypeOf(f.Method) { wantv := reflect.ValueOf(m).Elem() havev := reflect.ValueOf(f.Method).Elem() wantv.Set(havev) if _, ok := m.(messageWithContent); !ok { return m } } else { t.Fatalf("expected method type: %T, got: %T", m, f.Method) } default: t.Fatalf("unexpected frame: %+v", f) } } } func (t *server) expectAMQP() { t.expectBytes([]byte{'A', 'M', 'Q', 'P', 0, 0, 9, 1}) } func (t *server) connectionStartWithMechanisms(mechs string, recv bool) { t.send(0, &connectionStart{ VersionMajor: 0, VersionMinor: 9, Mechanisms: mechs, Locales: defaultLocale, }) if recv { t.recv(0, &t.start) } } func (t *server) connectionStart() { t.connectionStartWithMechanisms("PLAIN", true) } func (t *server) connectionTune() { t.send(0, &connectionTune{ ChannelMax: 11, FrameMax: 20000, Heartbeat: 10, }) t.recv(0, &t.tune) } func (t *server) connectionOpen() { t.expectAMQP() t.connectionStart() t.connectionTune() t.recv(0, &connectionOpen{}) t.send(0, &connectionOpenOk{}) } func (t *server) connectionClose() { t.recv(0, &connectionClose{}) t.send(0, &connectionCloseOk{}) } func (t *server) channelOpen(id int) { t.recv(id, &channelOpen{}) t.send(id, &channelOpenOk{}) } func TestDefaultClientProperties(t *testing.T) { rwc, srv := newSession(t) t.Cleanup(func() { rwc.Close() }) go func() { srv.connectionOpen() }() if c, err := Open(rwc, defaultConfig()); err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } if want, got := defaultProduct, srv.start.ClientProperties["product"]; want != got { t.Errorf("expected product %s got: %s", want, got) } if want, got := buildVersion, srv.start.ClientProperties["version"]; want != got { t.Errorf("expected version %s got: %s", want, got) } if want, got := defaultLocale, srv.start.Locale; want != got { t.Errorf("expected locale %s got: %s", want, got) } } func TestCustomClientProperties(t *testing.T) { rwc, srv := newSession(t) t.Cleanup(func() { rwc.Close() }) config := defaultConfig() config.Properties = Table{ "product": "foo", "version": "1.0", } go func() { srv.connectionOpen() }() if c, err := Open(rwc, config); err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } if want, got := config.Properties["product"], srv.start.ClientProperties["product"]; want != got { t.Errorf("expected product %s got: %s", want, got) } if want, got := config.Properties["version"], srv.start.ClientProperties["version"]; want != got { t.Errorf("expected version %s got: %s", want, got) } } func TestOpen(t *testing.T) { rwc, srv := newSession(t) t.Cleanup(func() { rwc.Close() }) go func() { srv.connectionOpen() }() if c, err := Open(rwc, defaultConfig()); err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } } func TestChannelOpen(t *testing.T) { rwc, srv := newSession(t) t.Cleanup(func() { rwc.Close() }) go func() { srv.connectionOpen() srv.channelOpen(1) }() c, err := Open(rwc, defaultConfig()) if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } ch, err := c.Channel() if err != nil { t.Fatalf("could not open channel: %v (%s)", ch, err) } } func TestOpenFailedSASLUnsupportedMechanisms(t *testing.T) { rwc, srv := newSession(t) t.Cleanup(func() { rwc.Close() }) go func() { srv.expectAMQP() srv.connectionStartWithMechanisms("KERBEROS NTLM", false) }() c, err := Open(rwc, defaultConfig()) if err != ErrSASL { t.Fatalf("expected ErrSASL got: %+v on %+v", err, c) } } func TestOpenAMQPlainAuth(t *testing.T) { auth := make(chan Table) rwc, srv := newSession(t) t.Cleanup(func() { rwc.Close() }) go func() { srv.expectAMQP() srv.connectionStartWithMechanisms("AMQPLAIN", true) var authresp bytes.Buffer _ = writeLongstr(&authresp, srv.start.Response) table, _ := readTable(&authresp) srv.connectionTune() srv.recv(0, &connectionOpen{}) srv.send(0, &connectionOpenOk{}) auth <- table }() if c, err := Open(rwc, amqplainConfig()); err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } table := <-auth if table["LOGIN"] != defaultLogin { t.Fatalf("unexpected login: want: %s, got: %s", defaultLogin, table["LOGIN"]) } if table["PASSWORD"] != defaultPassword { t.Fatalf("unexpected password: want: %s, got: %s", defaultPassword, table["PASSWORD"]) } } func TestOpenFailedCredentials(t *testing.T) { rwc, srv := newSession(t) go func() { srv.expectAMQP() srv.connectionStart() // Now kill/timeout the connection indicating bad auth rwc.Close() }() c, err := Open(rwc, defaultConfig()) if err != ErrCredentials { t.Fatalf("expected ErrCredentials got: %+v on %+v", err, c) } } func TestOpenFailedVhost(t *testing.T) { rwc, srv := newSession(t) go func() { srv.expectAMQP() srv.connectionStart() srv.connectionTune() srv.recv(0, &connectionOpen{}) // Now kill/timeout the connection on bad Vhost rwc.Close() }() c, err := Open(rwc, defaultConfig()) if err != ErrVhost { t.Fatalf("expected ErrVhost got: %+v on %+v", err, c) } } func TestConfirmMultipleOrdersDeliveryTags(t *testing.T) { rwc, srv := newSession(t) defer rwc.Close() go func() { srv.connectionOpen() srv.channelOpen(1) srv.recv(1, &confirmSelect{}) srv.send(1, &confirmSelectOk{}) srv.recv(1, &basicPublish{}) srv.recv(1, &basicPublish{}) srv.recv(1, &basicPublish{}) srv.recv(1, &basicPublish{}) // Single tag, plus multiple, should produce // 2, 1, 3, 4 srv.send(1, &basicAck{DeliveryTag: 2}) srv.send(1, &basicAck{DeliveryTag: 1}) srv.send(1, &basicAck{DeliveryTag: 4, Multiple: true}) srv.recv(1, &basicPublish{}) srv.recv(1, &basicPublish{}) srv.recv(1, &basicPublish{}) srv.recv(1, &basicPublish{}) // And some more, but in reverse order, multiple then one // 5, 6, 7, 8 srv.send(1, &basicAck{DeliveryTag: 6, Multiple: true}) srv.send(1, &basicAck{DeliveryTag: 8}) srv.send(1, &basicAck{DeliveryTag: 7}) }() c, err := Open(rwc, defaultConfig()) if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } ch, err := c.Channel() if err != nil { t.Fatalf("could not open channel: %v (%s)", ch, err) } confirm := ch.NotifyPublish(make(chan Confirmation)) err = ch.Confirm(false) if err != nil { t.Fatalf("channel error setting confirm mode: %v (%s)", ch, err) } go func() { var e error if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 1")}); e != nil { t.Errorf("publish error: %v", err) } if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 2")}); e != nil { t.Errorf("publish error: %v", err) } if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 3")}); e != nil { t.Errorf("publish error: %v", err) } if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 4")}); e != nil { t.Errorf("publish error: %v", err) } }() // received out of order, consumed in order for i, tag := range []uint64{1, 2, 3, 4} { if ack := <-confirm; tag != ack.DeliveryTag { t.Fatalf("failed ack, expected ack#%d to be %d, got %d", i, tag, ack.DeliveryTag) } } go func() { var e error if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 5")}); e != nil { t.Errorf("publish error: %v", err) } if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 6")}); e != nil { t.Errorf("publish error: %v", err) } if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 7")}); e != nil { t.Errorf("publish error: %v", err) } if e = ch.Publish("", "q", false, false, Publishing{Body: []byte("pub 8")}); e != nil { t.Errorf("publish error: %v", err) } }() for i, tag := range []uint64{5, 6, 7, 8} { if ack := <-confirm; tag != ack.DeliveryTag { t.Fatalf("failed ack, expected ack#%d to be %d, got %d", i, tag, ack.DeliveryTag) } } } func TestDeferredConfirmations(t *testing.T) { rwc, srv := newSession(t) defer rwc.Close() go func() { srv.connectionOpen() srv.channelOpen(1) srv.recv(1, &confirmSelect{}) srv.send(1, &confirmSelectOk{}) srv.recv(1, &basicPublish{}) srv.recv(1, &basicPublish{}) srv.recv(1, &basicPublish{}) srv.recv(1, &basicPublish{}) }() c, err := Open(rwc, defaultConfig()) if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } ch, err := c.Channel() if err != nil { t.Fatalf("could not open channel: %v (%s)", ch, err) } err = ch.Confirm(false) if err != nil { t.Fatalf("channel error setting confirm mode: %v (%s)", ch, err) } var results []*DeferredConfirmation for i := 1; i < 5; i++ { dc, err := ch.PublishWithDeferredConfirm("", "q", false, false, Publishing{Body: []byte("pub")}) if err != nil { t.Fatalf("failed to PublishWithDeferredConfirm: %v", err) } results = append(results, dc) } acks := make(chan Confirmation, 4) for _, result := range results { go func(r *DeferredConfirmation) { acks <- Confirmation{Ack: r.Wait(), DeliveryTag: r.DeliveryTag} }(result) } // received out of order, consumed out of order assertReceive := func(ack Confirmation, tags ...uint64) { for _, tag := range tags { if tag == ack.DeliveryTag { return } } t.Fatalf("failed ack, expected ack to be in set %v, got %d", tags, ack.DeliveryTag) } srv.send(1, &basicAck{DeliveryTag: 2}) assertReceive(<-acks, 2) srv.send(1, &basicAck{DeliveryTag: 1}) assertReceive(<-acks, 1) srv.send(1, &basicAck{DeliveryTag: 4, Multiple: true}) assertReceive(<-acks, 3, 4) // 3 and 4 are non-determistic due to map ordering assertReceive(<-acks, 3, 4) } func TestNotifyClosesReusedPublisherConfirmChan(t *testing.T) { rwc, srv := newSession(t) go func() { srv.connectionOpen() srv.channelOpen(1) srv.recv(1, &confirmSelect{}) srv.send(1, &confirmSelectOk{}) srv.recv(0, &connectionClose{}) srv.send(0, &connectionCloseOk{}) }() c, err := Open(rwc, defaultConfig()) if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } ch, err := c.Channel() if err != nil { t.Fatalf("could not open channel: %v (%s)", ch, err) } ackAndNack := make(chan uint64) ch.NotifyConfirm(ackAndNack, ackAndNack) if err := ch.Confirm(false); err != nil { t.Fatalf("expected to enter confirm mode: %v", err) } if err := c.Close(); err != nil { t.Fatalf("could not close connection: %v (%s)", c, err) } } func TestNotifyClosesAllChansAfterConnectionClose(t *testing.T) { rwc, srv := newSession(t) go func() { srv.connectionOpen() srv.channelOpen(1) srv.recv(0, &connectionClose{}) srv.send(0, &connectionCloseOk{}) }() c, err := Open(rwc, defaultConfig()) if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } ch, err := c.Channel() if err != nil { t.Fatalf("could not open channel: %v (%s)", ch, err) } if err := c.Close(); err != nil { t.Fatalf("could not close connection: %v (%s)", c, err) } select { case <-c.NotifyClose(make(chan *Error)): case <-time.After(time.Millisecond): t.Errorf("expected to close NotifyClose chan after Connection.Close") } select { case <-ch.NotifyClose(make(chan *Error)): case <-time.After(time.Millisecond): t.Errorf("expected to close Connection.NotifyClose chan after Connection.Close") } select { case <-ch.NotifyFlow(make(chan bool)): case <-time.After(time.Millisecond): t.Errorf("expected to close Channel.NotifyFlow chan after Connection.Close") } select { case <-ch.NotifyCancel(make(chan string)): case <-time.After(time.Millisecond): t.Errorf("expected to close Channel.NofityCancel chan after Connection.Close") } select { case <-ch.NotifyReturn(make(chan Return)): case <-time.After(time.Millisecond): t.Errorf("expected to close Channel.NotifyReturn chan after Connection.Close") } confirms := ch.NotifyPublish(make(chan Confirmation)) select { case <-confirms: case <-time.After(time.Millisecond): t.Errorf("expected to close confirms on Channel.NotifyPublish chan after Connection.Close") } } // Should not panic when sending bodies split at different boundaries func TestPublishBodySliceIssue74(t *testing.T) { rwc, srv := newSession(t) defer rwc.Close() const frameSize = 100 const publishings = frameSize * 3 done := make(chan bool) base := make([]byte, publishings) go func() { srv.connectionOpen() srv.channelOpen(1) for i := 0; i < publishings; i++ { srv.recv(1, &basicPublish{}) } done <- true }() cfg := defaultConfig() cfg.FrameSize = frameSize c, err := Open(rwc, cfg) if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } ch, err := c.Channel() if err != nil { t.Fatalf("could not open channel: %v (%s)", ch, err) } for i := 0; i < publishings; i++ { go func(ii int) { if err := ch.Publish("", "q", false, false, Publishing{Body: base[0:ii]}); err != nil { t.Errorf("publish error: %v", err) } }(i) } <-done } // Should not panic when server and client have frame_size of 0 func TestPublishZeroFrameSizeIssue161(t *testing.T) { rwc, srv := newSession(t) defer rwc.Close() const frameSize = 0 const publishings = 1 done := make(chan bool) go func() { srv.connectionOpen() srv.channelOpen(1) for i := 0; i < publishings; i++ { srv.recv(1, &basicPublish{}) } done <- true }() cfg := defaultConfig() cfg.FrameSize = frameSize c, err := Open(rwc, cfg) // override the tuned framesize with a hard 0, as would happen when rabbit is configured with 0 c.Config.FrameSize = frameSize if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } ch, err := c.Channel() if err != nil { t.Fatalf("could not open channel: %v (%s)", ch, err) } for i := 0; i < publishings; i++ { go func() { if err := ch.Publish("", "q", false, false, Publishing{Body: []byte("anything")}); err != nil { t.Errorf("publish error: %v", err) } }() } <-done } func TestPublishAndShutdownDeadlockIssue84(t *testing.T) { rwc, srv := newSession(t) defer rwc.Close() go func() { srv.connectionOpen() srv.channelOpen(1) srv.recv(1, &basicPublish{}) // Mimic a broken io pipe so that Publish catches the error and goes into shutdown srv.S.Close() }() c, err := Open(rwc, defaultConfig()) if err != nil { t.Fatalf("couldn't create connection: %v (%s)", c, err) } ch, err := c.Channel() if err != nil { t.Fatalf("couldn't open channel: %v (%s)", ch, err) } defer time.AfterFunc(500*time.Millisecond, func() { t.Fatalf("Publish deadlock") }).Stop() for { if err := ch.Publish("exchange", "q", false, false, Publishing{Body: []byte("test")}); err != nil { t.Log("successfully caught disconnect error", err) return } } } // TestChannelReturnsCloseRace ensures that receiving a basicReturn frame and // sending the notification to the bound channel does not race with // channel.shutdown() which closes all registered notification channels - checks // for a "send on closed channel" panic func TestChannelReturnsCloseRace(t *testing.T) { defer time.AfterFunc(5*time.Second, func() { t.Fatalf("Shutdown deadlock") }).Stop() ch := newChannel(&Connection{}, 1) // Register a channel to close in channel.shutdown() notify := make(chan Return, 1) ch.NotifyReturn(notify) go func() { for range notify { // Drain notifications } }() // Simulate receiving a load of returns (triggering a write to the above // channel) while we call shutdown concurrently go func() { for i := 0; i < 100; i++ { ch.dispatch(&basicReturn{}) } }() ch.shutdown(nil) } // TestLeakClosedConsumersIssue264 ensures that closing a consumer with // prefetched messages does not leak the buffering goroutine. func TestLeakClosedConsumersIssue264(t *testing.T) { const tag = "consumer-tag" rwc, srv := newSession(t) defer rwc.Close() go func() { srv.connectionOpen() srv.channelOpen(1) srv.recv(1, &basicQos{}) srv.send(1, &basicQosOk{}) srv.recv(1, &basicConsume{}) srv.send(1, &basicConsumeOk{ConsumerTag: tag}) // This delivery is intended to be consumed srv.send(1, &basicDeliver{ConsumerTag: tag, DeliveryTag: 1}) // This delivery is intended to be dropped srv.send(1, &basicDeliver{ConsumerTag: tag, DeliveryTag: 2}) srv.recv(0, &connectionClose{}) srv.send(0, &connectionCloseOk{}) srv.C.Close() }() c, err := Open(rwc, defaultConfig()) if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } ch, err := c.Channel() if err != nil { t.Fatalf("could not open channel: %v (%s)", ch, err) } err = ch.Qos(2, 0, false) if err != nil { t.Fatalf("channel Qos error: %v (%s)", ch, err) } consumer, err := ch.Consume("queue", tag, false, false, false, false, nil) if err != nil { t.Fatalf("unexpected error during consumer: %v", err) } first := <-consumer if want, got := uint64(1), first.DeliveryTag; want != got { t.Fatalf("unexpected delivery tag: want: %d, got: %d", want, got) } if err := c.Close(); err != nil { t.Fatalf("unexpected error during connection close: %v", err) } if _, open := <-consumer; open { t.Fatalf("expected deliveries channel to be closed immediately when the connection is closed so not to leak the bufferDeliveries goroutine") } }