diff --git a/protocols/client/client.go b/protocols/client/client.go index 6e386fe129..4da19ece98 100644 --- a/protocols/client/client.go +++ b/protocols/client/client.go @@ -8,6 +8,7 @@ package client import ( "context" + "fmt" "net" "net/url" "strconv" @@ -31,6 +32,7 @@ const ( ) var defaultDialTimeout = 15 * time.Second +var defaultCloseTimeout = 5 * time.Second // AgentClient is an agent gRPC client connection wrapper for agentgrpc.AgentServiceClient type AgentClient struct { @@ -45,7 +47,26 @@ type yamuxSessionStream struct { } func (y *yamuxSessionStream) Close() error { - return y.session.Close() + waitCh := y.session.CloseChan() + timeout := time.NewTimer(defaultCloseTimeout) + + if err := y.Conn.Close(); err != nil { + return err + } + + if err := y.session.Close(); err != nil { + return err + } + + // block until session is really closed + select { + case <-waitCh: + timeout.Stop() + case <-timeout.C: + return fmt.Errorf("timeout waiting for session close") + } + + return nil } type dialer func(string, time.Duration) (net.Conn, error)