diff --git a/device.go b/device.go index 3a3a157ac..994aa3970 100644 --- a/device.go +++ b/device.go @@ -42,9 +42,7 @@ const ( ) var ( - sysBusPrefix = sysfsDir + "/bus/pci/devices" pciBusRescanFile = sysfsDir + "/bus/pci/rescan" - pciBusPathFormat = "%s/%s/pci_bus/" systemDevPath = "/dev" getSCSIDevPath = getSCSIDevPathImpl getPmemDevPath = getPmemDevPathImpl @@ -82,10 +80,11 @@ type devIndex map[string]devIndexEntry // Guest-side PCI path, identifies a PCI device by where it sits in // the PCI topology. // -// Has the format "bridgeAddr/deviceAddr" where bridgeAddr is the -// address at which the brige is attached on the root bus, while -// deviceAddr is the address at which the device is attached on the -// bridge +// Has the format "xx/.../yy/zz" Here, zz is the slot of the device on +// its PCI bridge, yy is the slot of the bridge on its parent bridge +// and so forth until xx is the slot of the "most upstream" bridge on +// the root bus. If a device is connected directly to the root bus, +// its pciPath is just "zz" type PciPath struct { path string } @@ -108,42 +107,38 @@ func rescanPciBus() error { // the syfs path for the PCI host bridge, based on the PCI path // provided. func pciPathToSysfsImpl(pciPath PciPath) (string, error) { - tokens := strings.Split(pciPath.path, "/") - - if len(tokens) != 2 { - return "", fmt.Errorf("PCI path for device should be of format [bridgeAddr/deviceAddr], got %q", pciPath) - } + var relPath string + bus := "0000:00" - bridgeID := tokens[0] - deviceID := tokens[1] + tokens := strings.Split(pciPath.path, "/") - // Deduce the complete bridge address based on the bridge address identifier passed - // and the fact that bridges are attached on the main bus with function 0. - pciBridgeAddr := fmt.Sprintf("0000:00:%s.0", bridgeID) + for i, slot := range tokens { + // Full PCI address of this device along the path + bdf := fmt.Sprintf("%s:%s.0", bus, slot) - // Find out the bus exposed by bridge - bridgeBusPath := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, pciBridgeAddr) + relPath = filepath.Join(relPath, bdf) - files, err := ioutil.ReadDir(bridgeBusPath) - if err != nil { - return "", fmt.Errorf("Error with getting bridge pci bus : %s", err) - } + if i == len(tokens)-1 { + // Final device need not be a bridge + break + } - busNum := len(files) - if busNum != 1 { - return "", fmt.Errorf("Expected an entry for bus in %s, got %d entries instead", bridgeBusPath, busNum) - } + // Find out the bus exposed by bridge + bridgeBusPath := filepath.Join(sysfsDir, rootBusPath, relPath, "pci_bus") - bus := files[0].Name() + files, err := ioutil.ReadDir(bridgeBusPath) + if err != nil { + return "", fmt.Errorf("Error reading %s : %s", bridgeBusPath, err) + } - // Device address is based on the bus of the bridge to which it is attached. - // We do not pass devices as multifunction, hence the trailing 0 in the address. - pciDeviceAddr := fmt.Sprintf("%s:%s.0", bus, deviceID) + if len(files) != 1 { + return "", fmt.Errorf("Expected exactly one PCI bus in %s, got %d instead", bridgeBusPath, len(files)) + } - sysfsRelPath := fmt.Sprintf("%s/%s", pciBridgeAddr, pciDeviceAddr) - agentLog.WithField("sysfsRelPath", sysfsRelPath).Info("Fetched sysfs relative path for PCI device") + bus = files[0].Name() + } - return sysfsRelPath, nil + return relPath, nil } func getDeviceName(s *sandbox, devID string) (string, error) { diff --git a/device_test.go b/device_test.go index 0b60e1c0a..34f5622a0 100644 --- a/device_test.go +++ b/device_test.go @@ -99,11 +99,15 @@ func TestPciPathToSysfs(t *testing.T) { } defer os.RemoveAll(testDir) - // Set sysBusPrefix to test directory for unit tests. - sysBusPrefix = testDir + // Set sysfsDir to test directory for unit tests. + sysfsDir = testDir + rootBus := filepath.Join(sysfsDir, rootBusPath) + err = os.MkdirAll(rootBus, mountPerm) + assert.NoError(t, err) - _, err = pciPathToSysfs(PciPath{"02"}) - assert.Error(t, err) + sysRelPath, err := pciPathToSysfs(PciPath{"02"}) + assert.NoError(t, err) + assert.Equal(t, sysRelPath, "0000:00:02.0") _, err = pciPathToSysfs(PciPath{"02/03"}) assert.Error(t, err) @@ -112,13 +116,14 @@ func TestPciPathToSysfs(t *testing.T) { assert.Error(t, err) // Create mock sysfs files for the device at 0000:00:02.0 - bridge2Path := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, "0000:00:02.0") + bridge2Path := filepath.Join(rootBus, "0000:00:02.0") err = os.MkdirAll(bridge2Path, mountPerm) assert.NoError(t, err) - _, err = pciPathToSysfs(PciPath{"02"}) - assert.Error(t, err) + sysRelPath, err = pciPathToSysfs(PciPath{"02"}) + assert.NoError(t, err) + assert.Equal(t, sysRelPath, "0000:00:02.0") _, err = pciPathToSysfs(PciPath{"02/03"}) assert.Error(t, err) @@ -128,18 +133,40 @@ func TestPciPathToSysfs(t *testing.T) { // Create mock sysfs files to indicate that 0000:00:02.0 is a bridge to bus 01 bridge2Bus := "0000:01" - err = os.MkdirAll(filepath.Join(bridge2Path, bridge2Bus), mountPerm) + err = os.MkdirAll(filepath.Join(bridge2Path, "pci_bus", bridge2Bus), mountPerm) assert.NoError(t, err) - _, err = pciPathToSysfs(PciPath{"02"}) - assert.Error(t, err) + sysRelPath, err = pciPathToSysfs(PciPath{"02"}) + assert.NoError(t, err) + assert.Equal(t, sysRelPath, "0000:00:02.0") - sysRelPath, err := pciPathToSysfs(PciPath{"02/03"}) + sysRelPath, err = pciPathToSysfs(PciPath{"02/03"}) assert.NoError(t, err) assert.Equal(t, sysRelPath, "0000:00:02.0/0000:01:03.0") _, err = pciPathToSysfs(PciPath{"02/03/04"}) assert.Error(t, err) + + // Create mock sysfs files for a bridge at 0000:01:03.0 to bus 02 + bridge3Path := filepath.Join(bridge2Path, "0000:01:03.0") + bridge3Bus := "0000:02" + err = os.MkdirAll(filepath.Join(bridge3Path, "pci_bus", bridge3Bus), mountPerm) + assert.NoError(t, err) + + err = os.MkdirAll(bridge3Path, mountPerm) + assert.NoError(t, err) + + sysRelPath, err = pciPathToSysfs(PciPath{"02"}) + assert.NoError(t, err) + assert.Equal(t, sysRelPath, "0000:00:02.0") + + sysRelPath, err = pciPathToSysfs(PciPath{"02/03"}) + assert.NoError(t, err) + assert.Equal(t, sysRelPath, "0000:00:02.0/0000:01:03.0") + + sysRelPath, err = pciPathToSysfs(PciPath{"02/03/04"}) + assert.NoError(t, err) + assert.Equal(t, sysRelPath, "0000:00:02.0/0000:01:03.0/0000:02:04.0") } func TestScanSCSIBus(t *testing.T) { diff --git a/mount_test.go b/mount_test.go index eff66284a..da00a4654 100644 --- a/mount_test.go +++ b/mount_test.go @@ -223,9 +223,9 @@ func TestVirtioBlkStorageHandlerSuccessful(t *testing.T) { pciPath := fmt.Sprintf("%s/%s", bridgeID, deviceID) - sysBusPrefix = testDir - bridgeBusPath := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, "0000:00:02.0") - + // Set sysfsDir to test directory for unit tests. + sysfsDir = testDir + bridgeBusPath := filepath.Join(sysfsDir, rootBusPath, "0000:00:02.0", "pci_bus") err = os.MkdirAll(filepath.Join(bridgeBusPath, pciBus), mountPerm) assert.Nil(t, err)