Skip to content

Commit

Permalink
Internal refactoring to avoid unneeded netlink parsing
Browse files Browse the repository at this point in the history
The parsing of some netlink structures is performed only for those vdpa objects
that need to be returned to the user.

Signed-off-by: Leonardo Milleri <[email protected]>
  • Loading branch information
Leonardo Milleri authored and lmilleri committed Sep 26, 2023
1 parent ff4e4ec commit 3335d73
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 52 deletions.
24 changes: 8 additions & 16 deletions cmd/kvdpa-cli/kvdpa-cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"fmt"
"os"
"strings"
"text/template"

vdpa "github.com/k8snetworkplumbingwg/govdpa/pkg/kvdpa"
Expand All @@ -26,26 +25,19 @@ const deviceTemplate = ` - Name: {{ .Name }}
func listAction(c *cli.Context) error {
var devs []vdpa.VdpaDevice
var err error
if c.String("mgmtdev") != "" {
var bus, name string
nameParts := strings.Split(c.String("mgmtdev"), "/")
if len(nameParts) == 1 {
name = nameParts[0]
} else if len(nameParts) == 2 {
bus = nameParts[0]
name = nameParts[1]
} else {
return fmt.Errorf("Invalid management device name %s", c.String("mgmtdev"))
}
devs, err = vdpa.GetVdpaDevicesByMgmtDev(bus, name)
var mgmtDev = c.String("mgmtdev")
if mgmtDev != "" {
var busName, devName string
busName, devName, err = vdpa.ExtractBusAndMgmtDevice(mgmtDev)
if err != nil {
return err
}
devs, err = vdpa.GetVdpaDevicesByMgmtDev(busName, devName)
} else {
devs, err = vdpa.ListVdpaDevices()
if err != nil {
fmt.Println(err)
}
}
if err != nil {
return err
}
tmpl := template.Must(template.New("device").Parse(deviceTemplate))

Expand Down
55 changes: 20 additions & 35 deletions pkg/kvdpa/device.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package kvdpa

import (
"errors"
"os"
"path/filepath"
"strings"
"syscall"

"github.com/vishvananda/netlink/nl"
Expand Down Expand Up @@ -185,7 +183,8 @@ func GetVdpaDevice(name string) (VdpaDevice, error) {
return nil, err
}

vdpaDevs, err := parseDevLinkVdpaDevList(msgs)
// No filters, expecting to parse attributes for the device with the given name
vdpaDevs, err := parseDevLinkVdpaDevList("", "", msgs)
if err != nil {
return nil, err
}
Expand All @@ -197,50 +196,27 @@ GetVdpaDevicesByMgmtDev returns the VdpaDevice objects whose MgmtDev
has the given bus and device names.
*/
func GetVdpaDevicesByMgmtDev(busName, devName string) ([]VdpaDevice, error) {
result := []VdpaDevice{}
devices, err := ListVdpaDevices()
if err != nil {
return nil, err
}
for _, device := range devices {
if device.MgmtDev() != nil &&
device.MgmtDev().BusName() == busName &&
device.MgmtDev().DevName() == devName {
result = append(result, device)
}
}
if len(result) == 0 {
return nil, syscall.ENODEV
}
return result, nil
return listVdpaDevicesWithBusDevName(busName, devName)
}

/*ListVdpaDevices returns a list of all available vdpa devices */
func ListVdpaDevices() ([]VdpaDevice, error) {
return listVdpaDevicesWithBusDevName("", "")
}

func listVdpaDevicesWithBusDevName(busName, devName string) ([]VdpaDevice, error) {
msgs, err := GetNetlinkOps().RunVdpaNetlinkCmd(VdpaCmdDevGet, syscall.NLM_F_DUMP, nil)
if err != nil {
return nil, err
}

vdpaDevs, err := parseDevLinkVdpaDevList(msgs)
vdpaDevs, err := parseDevLinkVdpaDevList(busName, devName, msgs)
if err != nil {
return nil, err
}
return vdpaDevs, nil
}

func extractBusNameAndMgmtDeviceName(fullMgmtDeviceName string) (busName string, mgmtDeviceName string, err error) {
numSlashes := strings.Count(fullMgmtDeviceName, "/")
if numSlashes > 1 {
return "", "", errors.New("expected mgmtDeviceName to be either in the format <mgmtBusName>/<mgmtDeviceName> or <mgmtDeviceName>")
} else if numSlashes == 0 {
return "", fullMgmtDeviceName, nil
} else {
values := strings.Split(fullMgmtDeviceName, "/")
return values[0], values[1], nil
}
}

/*
GetVdpaDevicesByPciAddress returns the VdpaDevice objects for the given pciAddress
Expand All @@ -249,7 +225,7 @@ GetVdpaDevicesByPciAddress returns the VdpaDevice objects for the given pciAddre
- MgmtDevName
*/
func GetVdpaDevicesByPciAddress(pciAddress string) ([]VdpaDevice, error) {
busName, mgmtDeviceName, err := extractBusNameAndMgmtDeviceName(pciAddress)
busName, mgmtDeviceName, err := ExtractBusAndMgmtDevice(pciAddress)
if err != nil {
return nil, unix.EINVAL
}
Expand All @@ -263,7 +239,7 @@ func AddVdpaDevice(mgmtDeviceName string, vdpaDeviceName string) error {
return unix.EINVAL
}

busName, mgmtDeviceName, err := extractBusNameAndMgmtDeviceName(mgmtDeviceName)
busName, mgmtDeviceName, err := ExtractBusAndMgmtDevice(mgmtDeviceName)
if err != nil {
return unix.EINVAL
}
Expand Down Expand Up @@ -317,7 +293,7 @@ func DeleteVdpaDevice(vdpaDeviceName string) error {
return nil
}

func parseDevLinkVdpaDevList(msgs [][]byte) ([]VdpaDevice, error) {
func parseDevLinkVdpaDevList(busName string, mgmtDeviceName string, msgs [][]byte) ([]VdpaDevice, error) {
devices := make([]VdpaDevice, 0, len(msgs))

for _, m := range msgs {
Expand All @@ -329,6 +305,15 @@ func parseDevLinkVdpaDevList(msgs [][]byte) ([]VdpaDevice, error) {
if err = dev.parseAttributes(attrs); err != nil {
return nil, err
}

if busName != "" && busName != dev.mgmtDev.busName {
continue
}

if mgmtDeviceName != "" && mgmtDeviceName != dev.mgmtDev.devName {
continue
}

if err = dev.getBusInfo(); err != nil {
return nil, err
}
Expand Down
94 changes: 93 additions & 1 deletion pkg/kvdpa/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,98 @@ func TestVdpaDevList(t *testing.T) {
}
}

func TestVdpaDevListWithFilter(t *testing.T) {
tests := []struct {
name string
err bool
response []VdpaDevice
}{
{
name: "Multiple SR-IOV and SF devices",
err: false,
response: []VdpaDevice{
&vdpaDev{
name: "vdpa0",
mgmtDev: &mgmtDev{
devName: "0000:01:01",
},
},
&vdpaDev{
name: "vdpa1",
mgmtDev: &mgmtDev{
busName: "pci",
devName: "0000:01:02",
},
},
&vdpaDev{
name: "vdpa2",
mgmtDev: &mgmtDev{
busName: "pci",
devName: "0000:01:02",
},
},
&vdpaDev{
name: "vdpa3",
mgmtDev: &mgmtDev{
busName: "pci",
devName: "0000:01:03",
},
},
},
},
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%s_%s", "TestVdpaDevListWithFilter", tt.name), func(t *testing.T) {
netLinkMock := &mocks.NetlinkOps{}
SetNetlinkOps(netLinkMock)
netLinkMock.On("RunVdpaNetlinkCmd",
VdpaCmdDevGet,
mock.MatchedBy(func(flags int) bool {
return (flags|syscall.NLM_F_DUMP != 0)
}),
mock.AnythingOfType("[]*nl.RtAttr")).
Return(vdpaDevToNlMessage(t, tt.response...), nil)
// no filters, all devices are returned
devs, err := ListVdpaDevices()
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, tt.response, devs)
}
// mgmtdev: 0000:01:01
devs, err = GetVdpaDevicesByPciAddress("0000:01:01")
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, len(devs), 1)
assert.Equal(t, tt.response[0], devs[0])
}
// mgmtdev: pci/0000:01:02
devs, err = GetVdpaDevicesByPciAddress("pci/0000:01:02")
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, len(devs), 2)
assert.Equal(t, tt.response[1], devs[0])
assert.Equal(t, tt.response[2], devs[1])
}
// mgmtdev: pci/0000:01:03
devs, err = GetVdpaDevicesByPciAddress("pci/0000:01:03")
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, len(devs), 1)
assert.Equal(t, tt.response[3], devs[0])
}
})
}
}

func TestVdpaDevGet(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -304,7 +396,7 @@ func TestVdpaDevGetByMgmt(t *testing.T) {
},
{
name: "Wrong",
err: syscall.ENODEV,
response: []VdpaDevice{},
mgmtDevName: "noDev",
mgmtBusName: "wrongBus",
},
Expand Down
22 changes: 22 additions & 0 deletions pkg/kvdpa/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package kvdpa

import (
"errors"
"strings"
)

// ExtractBusAndMgmtDevice extracts the busName and deviceName from a full device address (e.g. pci)
// example 1: pci/65:0000.1 -> "pci", "65:0000.1", nil
// example 2: vdpa_sim -> "", "vdpa_sim", nil
// example 3: pci/65:0000.1/1 -> "", "", err
func ExtractBusAndMgmtDevice(fullMgmtDeviceName string) (busName string, mgmtDeviceName string, err error) {
numSlashes := strings.Count(fullMgmtDeviceName, "/")
if numSlashes > 1 {
return "", "", errors.New("expected mgmtDeviceName to be either in the format <mgmtBusName>/<mgmtDeviceName> or <mgmtDeviceName>")
} else if numSlashes == 0 {
return "", fullMgmtDeviceName, nil
} else {
values := strings.Split(fullMgmtDeviceName, "/")
return values[0], values[1], nil
}
}
53 changes: 53 additions & 0 deletions pkg/kvdpa/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package kvdpa

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestExtractBusAndMgmtDevice(t *testing.T) {
tests := []struct {
testName string
deviceAddress string
busName string
devName string
err bool
}{
{
testName: "regular PCI address",
deviceAddress: "pci/0000:65:00.1",
busName: "pci",
devName: "0000:65:00.1",
err: false,
},
{
testName: "no bus",
deviceAddress: "vdpa_sim",
busName: "",
devName: "vdpa_sim",
err: false,
},
{
testName: "wrong address",
deviceAddress: "pci/0000:65:00.1/0",
busName: "",
devName: "",
err: true,
},
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%s_%s", "TestExtractBusAndMgmtDevice", tt.testName), func(t *testing.T) {
busName, devName, err := ExtractBusAndMgmtDevice(tt.deviceAddress)
if tt.err {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, tt.busName, busName)
assert.Equal(t, tt.devName, devName)
}
})
}
}

0 comments on commit 3335d73

Please sign in to comment.