diff --git a/Gopkg.lock b/Gopkg.lock index 6b92f8dae6..fab79b9e68 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -397,7 +397,7 @@ revision = "8cba5a8e5f2816f26f9dc34b8ea968279a5a76eb" [[projects]] - digest = "1:54e0385cece7064d8afd967e0987e52cd8b261c8c7e22d84e8b25719d7379e74" + digest = "1:903bfb87f41dc18a533d327b5b03fd2f562f34deafcbd0db93d4be3fa8d6099c" name = "github.com/kata-containers/agent" packages = [ "pkg/types", @@ -405,7 +405,7 @@ "protocols/grpc", ] pruneopts = "NUT" - revision = "3ffb7ca1067565a45ee9fbfcb109eb85e7e899af" + revision = "32c87e75c2e4c014961f104c3c59b87f2aee3384" [[projects]] branch = "master" diff --git a/Gopkg.toml b/Gopkg.toml index 8c2aa840a5..ebf76340a5 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -52,7 +52,7 @@ [[constraint]] name = "github.com/kata-containers/agent" - revision = "3ffb7ca1067565a45ee9fbfcb109eb85e7e899af" + revision = "32c87e75c2e4c014961f104c3c59b87f2aee3384" [[constraint]] name = "github.com/containerd/cri-containerd" diff --git a/vendor/github.com/kata-containers/agent/protocols/client/client.go b/vendor/github.com/kata-containers/agent/protocols/client/client.go index a4b0ac3db9..911c252ef8 100644 --- a/vendor/github.com/kata-containers/agent/protocols/client/client.go +++ b/vendor/github.com/kata-containers/agent/protocols/client/client.go @@ -27,9 +27,9 @@ import ( ) const ( - unixSocketScheme = "unix" - vsockSocketScheme = "vsock" - hybridVSockScheme = "hvsock" + UnixSocketScheme = "unix" + VSockSocketScheme = "vsock" + HybridVSockScheme = "hvsock" ) var defaultDialTimeout = 15 * time.Second @@ -142,7 +142,7 @@ func parse(sock string) (string, *url.URL, error) { var grpcAddr string // validate more switch addr.Scheme { - case vsockSocketScheme: + case VSockSocketScheme: if addr.Hostname() == "" || addr.Port() == "" || addr.Path != "" { return "", nil, grpcStatus.Errorf(codes.InvalidArgument, "Invalid vsock scheme: %s", sock) } @@ -152,19 +152,19 @@ func parse(sock string) (string, *url.URL, error) { if _, err := strconv.ParseUint(addr.Port(), 10, 32); err != nil { return "", nil, grpcStatus.Errorf(codes.InvalidArgument, "Invalid vsock port: %s", sock) } - grpcAddr = vsockSocketScheme + ":" + addr.Host - case unixSocketScheme: + grpcAddr = VSockSocketScheme + ":" + addr.Host + case UnixSocketScheme: fallthrough case "": if (addr.Host == "" && addr.Path == "") || addr.Port() != "" { return "", nil, grpcStatus.Errorf(codes.InvalidArgument, "Invalid unix scheme: %s", sock) } if addr.Host == "" { - grpcAddr = unixSocketScheme + ":///" + addr.Path + grpcAddr = UnixSocketScheme + ":///" + addr.Path } else { - grpcAddr = unixSocketScheme + ":///" + addr.Host + "/" + addr.Path + grpcAddr = UnixSocketScheme + ":///" + addr.Host + "/" + addr.Path } - case hybridVSockScheme: + case HybridVSockScheme: if addr.Path == "" { return "", nil, grpcStatus.Errorf(codes.InvalidArgument, "Invalid hybrid vsock scheme: %s", sock) } @@ -178,7 +178,7 @@ func parse(sock string) (string, *url.URL, error) { return "", nil, grpcStatus.Errorf(codes.InvalidArgument, "Invalid hybrid vsock port %s: %v", sock, err) } hybridVSockPort = uint32(port) - grpcAddr = hybridVSockScheme + ":" + hvsocket[0] + grpcAddr = HybridVSockScheme + ":" + hvsocket[0] default: return "", nil, grpcStatus.Errorf(codes.InvalidArgument, "Invalid scheme: %s", sock) } @@ -209,11 +209,11 @@ func heartBeat(session *yamux.Session) { func agentDialer(addr *url.URL, enableYamux bool) dialer { var d dialer switch addr.Scheme { - case vsockSocketScheme: + case VSockSocketScheme: d = vsockDialer - case hybridVSockScheme: - d = hybridVSockDialer - case unixSocketScheme: + case HybridVSockScheme: + d = HybridVSockDialer + case UnixSocketScheme: fallthrough default: d = unixDialer @@ -281,7 +281,7 @@ func parseGrpcVsockAddr(sock string) (uint32, uint32, error) { if len(sp) != 3 { return 0, 0, grpcStatus.Errorf(codes.InvalidArgument, "Invalid vsock address: %s", sock) } - if sp[0] != vsockSocketScheme { + if sp[0] != VSockSocketScheme { return 0, 0, grpcStatus.Errorf(codes.InvalidArgument, "Invalid vsock URL scheme: %s", sp[0]) } @@ -297,16 +297,26 @@ func parseGrpcVsockAddr(sock string) (uint32, uint32, error) { return uint32(cid), uint32(port), nil } -func parseGrpcHybridVSockAddr(sock string) (string, error) { +func parseGrpcHybridVSockAddr(sock string) (string, uint32, error) { sp := strings.Split(sock, ":") - if len(sp) != 2 { - return "", grpcStatus.Errorf(codes.InvalidArgument, "Invalid hybrid vsock address: %s", sock) + // scheme and host are required + if len(sp) < 2 { + return "", 0, grpcStatus.Errorf(codes.InvalidArgument, "Invalid hybrid vsock address: %s", sock) } - if sp[0] != hybridVSockScheme { - return "", grpcStatus.Errorf(codes.InvalidArgument, "Invalid hybrid vsock URL scheme: %s", sock) + if sp[0] != HybridVSockScheme { + return "", 0, grpcStatus.Errorf(codes.InvalidArgument, "Invalid hybrid vsock URL scheme: %s", sock) } - return sp[1], nil + port := uint32(0) + // the third is the port + if len(sp) == 3 { + p, err := strconv.ParseUint(sp[2], 10, 32) + if err == nil { + port = uint32(p) + } + } + + return sp[1], port, nil } // This would bypass the grpc dialer backoff strategy and handle dial timeout @@ -371,8 +381,9 @@ func vsockDialer(sock string, timeout time.Duration) (net.Conn, error) { return commonDialer(timeout, dialFunc, timeoutErr) } -func hybridVSockDialer(sock string, timeout time.Duration) (net.Conn, error) { - udsPath, err := parseGrpcHybridVSockAddr(sock) +// HybridVSockDialer dials to a hybrid virtio socket +func HybridVSockDialer(sock string, timeout time.Duration) (net.Conn, error) { + udsPath, port, err := parseGrpcHybridVSockAddr(sock) if err != nil { return nil, err } @@ -382,10 +393,16 @@ func hybridVSockDialer(sock string, timeout time.Duration) (net.Conn, error) { if err != nil { return nil, err } + + if port == 0 { + // use the port read at parse() + port = hybridVSockPort + } + // Once the connection is opened, the following command MUST BE sent, // the hypervisor needs to know the port number where the agent is listening in order to // create the connection - if _, err = conn.Write([]byte(fmt.Sprintf("CONNECT %d\n", hybridVSockPort))); err != nil { + if _, err = conn.Write([]byte(fmt.Sprintf("CONNECT %d\n", port))); err != nil { conn.Close() return nil, err }