diff --git a/virtcontainers/network.go b/virtcontainers/network.go index 3778471982..a6e574d510 100644 --- a/virtcontainers/network.go +++ b/virtcontainers/network.go @@ -6,6 +6,7 @@ package virtcontainers import ( + "bufio" "encoding/hex" "encoding/json" "fmt" @@ -591,6 +592,107 @@ func newNetwork(networkType NetworkModel) network { } } +const procMountInfoFile = "/proc/self/mountinfo" + +// getNetNsFromBindMount returns the network namespace for the bind-mounted path +func getNetNsFromBindMount(nsPath string, procMountFile string) (string, error) { + netNsMountType := "nsfs" + + // Resolve all symlinks in the path as the mountinfo file contains + // resolved paths. + nsPath, err := filepath.EvalSymlinks(nsPath) + if err != nil { + return "", err + } + + f, err := os.Open(procMountFile) + if err != nil { + return "", err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + text := scanner.Text() + + // Scan the mountinfo file to search for the network namespace path + // This file contains mounts in the eg format: + // "711 26 0:3 net:[4026532009] /run/docker/netns/default rw shared:535 - nsfs nsfs rw" + // + // Reference: https://www.kernel.org/doc/Documentation/filesystems/proc.txt + + // We are interested in the first 9 fields of this file, + // to check for the correct mount type. + fields := strings.Split(text, " ") + if len(fields) < 9 { + continue + } + + // We check here if the mount type is a network namespace mount type, namely "nsfs" + mountTypeFieldIdx := 8 + if fields[mountTypeFieldIdx] != netNsMountType { + continue + } + + // This is the mount point/destination for the mount + mntDestIdx := 4 + if fields[mntDestIdx] != nsPath { + continue + } + + // This is the root/source of the mount + return fields[3], nil + } + + return "", nil +} + +// hostNetworkingRequested checks if the network namespace requested is the +// same as the current process. +func hostNetworkingRequested(configNetNs string) (bool, error) { + var evalNS, nsPath, currentNsPath string + var err error + + // Net namespace provided as "/proc/pid/ns/net" or "/proc//task//ns/net" + if strings.HasPrefix(configNetNs, "/proc") && strings.HasSuffix(configNetNs, "/ns/net") { + if _, err := os.Stat(configNetNs); err != nil { + return false, err + } + + // Here we are trying to resolve the path but it fails because + // namespaces links don't really exist. For this reason, the + // call to EvalSymlinks will fail when it will try to stat the + // resolved path found. As we only care about the path, we can + // retrieve it from the PathError structure. + if _, err = filepath.EvalSymlinks(configNetNs); err != nil { + nsPath = err.(*os.PathError).Path + } else { + return false, fmt.Errorf("Net namespace path %s is not a symlink", configNetNs) + } + + _, evalNS = filepath.Split(nsPath) + + } else { + // Bind-mounted path provided + evalNS, _ = getNetNsFromBindMount(configNetNs, procMountInfoFile) + } + + currentNS := fmt.Sprintf("/proc/%d/task/%d/ns/net", os.Getpid(), unix.Gettid()) + if _, err = filepath.EvalSymlinks(currentNS); err != nil { + currentNsPath = err.(*os.PathError).Path + } else { + return false, fmt.Errorf("Unexpected: Current network namespace path is not a symlink") + } + + _, evalCurrentNS := filepath.Split(currentNsPath) + + if evalNS == evalCurrentNS { + return true, nil + } + + return false, nil +} + func initNetworkCommon(config NetworkConfig) (string, bool, error) { if !config.InterworkingModel.IsValid() || config.InterworkingModel == NetXConnectDefaultModel { config.InterworkingModel = DefaultNetInterworkingModel @@ -605,6 +707,15 @@ func initNetworkCommon(config NetworkConfig) (string, bool, error) { return path, true, nil } + isHostNs, err := hostNetworkingRequested(config.NetNSPath) + if err != nil { + return "", false, err + } + + if isHostNs { + return "", false, fmt.Errorf("Host networking requested, not supported by runtime") + } + return config.NetNSPath, false, nil } diff --git a/virtcontainers/network_test.go b/virtcontainers/network_test.go index f829c3aaab..c2442f4f33 100644 --- a/virtcontainers/network_test.go +++ b/virtcontainers/network_test.go @@ -7,12 +7,16 @@ package virtcontainers import ( "fmt" + "io/ioutil" "net" "os" + "path/filepath" "reflect" + "syscall" "testing" "github.com/containernetworking/plugins/pkg/ns" + "github.com/stretchr/testify/assert" "github.com/vishvananda/netlink" "github.com/vishvananda/netns" ) @@ -488,3 +492,84 @@ func TestVhostUserEndpointAttach(t *testing.T) { t.Fatal(err) } } + +func TestGetNetNsFromBindMount(t *testing.T) { + assert := assert.New(t) + + tmpdir, err := ioutil.TempDir("", "") + assert.NoError(err) + defer os.RemoveAll(tmpdir) + + mountFile := filepath.Join(tmpdir, "mountInfo") + nsPath := filepath.Join(tmpdir, "ns123") + + // Non-existent namespace path + _, err = getNetNsFromBindMount(nsPath, mountFile) + assert.NotNil(err) + + tmpNSPath := filepath.Join(tmpdir, "testNetNs") + f, err := os.Create(tmpNSPath) + assert.NoError(err) + defer f.Close() + + type testData struct { + contents string + expectedResult string + } + + data := []testData{ + {fmt.Sprintf("711 26 0:3 net:[4026532008] %s rw shared:535 - nsfs nsfs rw", tmpNSPath), "net:[4026532008]"}, + {"711 26 0:3 net:[4026532008] /run/netns/ns123 rw shared:535 - tmpfs tmpfs rw", ""}, + {"a a a a a a a - b c d", ""}, + {"", ""}, + } + + for i, d := range data { + err := ioutil.WriteFile(mountFile, []byte(d.contents), 0640) + assert.NoError(err) + + path, err := getNetNsFromBindMount(tmpNSPath, mountFile) + assert.NoError(err, fmt.Sprintf("got %q, test data: %+v", path, d)) + + assert.Equal(d.expectedResult, path, "Test %d, expected %s, got %s", i, d.expectedResult, path) + } +} + +func TestHostNetworkingRequested(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(testDisabledAsNonRoot) + } + + assert := assert.New(t) + + // Network namespace same as the host + selfNsPath := "/proc/self/ns/net" + isHostNs, err := hostNetworkingRequested(selfNsPath) + assert.NoError(err) + assert.True(isHostNs) + + // Non-existent netns path + nsPath := "/proc/123/ns/net" + _, err = hostNetworkingRequested(nsPath) + assert.Error(err) + + // Bind-mounted Netns + tmpdir, err := ioutil.TempDir("", "") + assert.NoError(err) + defer os.RemoveAll(tmpdir) + + // Create a bind mount to the current network namespace. + tmpFile := filepath.Join(tmpdir, "testNetNs") + f, err := os.Create(tmpFile) + assert.NoError(err) + defer f.Close() + + err = syscall.Mount(selfNsPath, tmpFile, "bind", syscall.MS_BIND, "") + assert.Nil(err) + + isHostNs, err = hostNetworkingRequested(tmpFile) + assert.NoError(err) + assert.True(isHostNs) + + syscall.Unmount(tmpFile, 0) +}