diff --git a/oci.go b/oci.go index 4380c190..3071d3e5 100644 --- a/oci.go +++ b/oci.go @@ -15,11 +15,13 @@ package main import ( + "bufio" "errors" "fmt" "net" "os" "path/filepath" + "strings" "syscall" vc "github.com/containers/virtcontainers" @@ -44,7 +46,9 @@ const ( var errNeedLinuxResource = errors.New("Linux resource cannot be empty") -var cgroupsDirPath = "/sys/fs/cgroup" +var cgroupsDirPath string + +var procMountInfo = "/proc/self/mountinfo" // getContainerInfo returns the container status and its pod ID. // It internally expands the container ID from the prefix provided. @@ -191,6 +195,12 @@ func processCgroupsPathForResource(ociSpec oci.CompatOCISpec, resource string, i return "", errNeedLinuxResource } + var err error + cgroupsDirPath, err = getCgroupsDirPath(procMountInfo) + if err != nil { + return "", fmt.Errorf("get CgroupsDirPath error: %s", err) + } + // Relative cgroups path provided. if filepath.IsAbs(ociSpec.Linux.CgroupsPath) == false { return filepath.Join(cgroupsDirPath, resource, ociSpec.Linux.CgroupsPath), nil @@ -302,3 +312,41 @@ func noNeedForOutput(detach bool, tty bool) bool { return true } + +func getCgroupsDirPath(mountInfoFile string) (string, error) { + if cgroupsDirPath != "" { + return cgroupsDirPath, nil + } + + f, err := os.Open(mountInfoFile) + if err != nil { + return "", err + } + defer f.Close() + + var cgroupRootPath string + scanner := bufio.NewScanner(f) + for scanner.Scan() { + text := scanner.Text() + index := strings.Index(text, " - ") + if index < 0 { + continue + } + fields := strings.Split(text, " ") + postSeparatorFields := strings.Fields(text[index+3:]) + numPostFields := len(postSeparatorFields) + + if len(fields) < 5 || postSeparatorFields[0] != "cgroup" || numPostFields < 3 { + continue + } + + cgroupRootPath = filepath.Dir(fields[4]) + break + } + + if _, err = os.Stat(cgroupRootPath); err != nil { + return "", err + } + + return cgroupRootPath, nil +} diff --git a/oci_test.go b/oci_test.go index 7c3fbc19..4d0f9a9e 100644 --- a/oci_test.go +++ b/oci_test.go @@ -530,7 +530,12 @@ func TestIsCgroupMounted(t *testing.T) { assert.False(isCgroupMounted(os.TempDir()), "%s is not a cgroup", os.TempDir()) - memoryCgroupPath := "/sys/fs/cgroup/memory" + cgroupsDirPath = "" + cgroupRootPath, err := getCgroupsDirPath(procMountInfo) + if err != nil { + assert.NoError(err) + } + memoryCgroupPath := filepath.Join(cgroupRootPath, "memory") if _, err := os.Stat(memoryCgroupPath); os.IsNotExist(err) { t.Skipf("memory cgroup does not exist: %s", memoryCgroupPath) } @@ -562,3 +567,56 @@ func TestProcessCgroupsPathForResource(t *testing.T) { assert.False(vcMock.IsMockError(err)) } } + +func TestGetCgroupsDirPath(t *testing.T) { + assert := assert.New(t) + + type testData struct { + contents string + expectedResult string + expectError bool + } + + dir, err := ioutil.TempDir("", "") + if err != nil { + assert.NoError(err) + } + defer os.RemoveAll(dir) + + // make sure tested cgroupsDirPath is existed + testedCgroupDir := filepath.Join(dir, "weirdCgroup") + err = os.Mkdir(testedCgroupDir, testDirMode) + assert.NoError(err) + + weirdCgroupPath := filepath.Join(testedCgroupDir, "memory") + + data := []testData{ + {fmt.Sprintf("num1 num2 num3 / %s num6 num7 - cgroup cgroup rw,memory", weirdCgroupPath), testedCgroupDir, false}, + // cgroup mount is not properly formated, if fields post - less than 3 + {fmt.Sprintf("num1 num2 num3 / %s num6 num7 - cgroup cgroup ", weirdCgroupPath), "", true}, + {"a a a a a a a - b c d", "", true}, + {"a \na b \na b c\na b c d", "", true}, + {"", "", true}, + } + + file := filepath.Join(dir, "mountinfo") + + //file does not exist, should error here + _, err = getCgroupsDirPath(file) + assert.Error(err) + + for _, d := range data { + err := ioutil.WriteFile(file, []byte(d.contents), testFileMode) + assert.NoError(err) + + cgroupsDirPath = "" + path, err := getCgroupsDirPath(file) + if d.expectError { + assert.Error(err, fmt.Sprintf("got %q, test data: %+v", path, d)) + } else { + assert.NoError(err, fmt.Sprintf("got %q, test data: %+v", path, d)) + } + + assert.Equal(d.expectedResult, path) + } +}