tests: add more netlink tests for neighbors/routes

Differential Revision: https://reviews.freebsd.org/D38912
MFC after:	2 weeks
This commit is contained in:
Alexander V. Chernikov 2023-03-07 17:30:35 +00:00
parent a9a38dea37
commit c57dfd92c8
4 changed files with 183 additions and 15 deletions

View file

@ -29,6 +29,12 @@ def align4(val: int) -> int:
return roundup2(val, 4)
def enum_or_int(val) -> int:
if isinstance(val, Enum):
return val.value
return val
class SockaddrNl(Structure):
_fields_ = [
("nl_len", c_ubyte),
@ -125,8 +131,8 @@ class NlRtMsgType(Enum):
RTM_DELROUTE = 25
RTM_GETROUTE = 26
RTM_NEWNEIGH = 28
RTM_DELNEIGH = 27
RTM_GETNEIGH = 28
RTM_DELNEIGH = 29
RTM_GETNEIGH = 30
RTM_NEWRULE = 32
RTM_DELRULE = 33
RTM_GETRULE = 34
@ -491,6 +497,39 @@ class IfattrType(Enum):
IFA_TARGET_NETNSID = auto()
class NdMsg(Structure):
_fields_ = [
("ndm_family", c_ubyte),
("ndm_pad1", c_ubyte),
("ndm_pad2", c_ubyte),
("ndm_ifindex", c_uint),
("ndm_state", c_ushort),
("ndm_flags", c_ubyte),
("ndm_type", c_ubyte),
]
class NdAttrType(Enum):
NDA_UNSPEC = 0
NDA_DST = 1
NDA_LLADDR = 2
NDA_CACHEINFO = 3
NDA_PROBES = 4
NDA_VLAN = 5
NDA_PORT = 6
NDA_VNI = 7
NDA_IFINDEX = 8
NDA_MASTER = 9
NDA_LINK_NETNSID = 10
NDA_SRC_VNI = 11
NDA_PROTOCOL = 12
NDA_NH_ID = 13
NDA_FDB_EXT_ATTRS = 14
NDA_FLAGS_EXT = 15
NDA_NDM_STATE_MASK = 16
NDA_NDM_FLAGS_MASK = 17
class GenlMsgHdr(Structure):
_fields_ = [
("cmd", c_ubyte),
@ -702,7 +741,7 @@ def __bytes__(self):
class NlAttrU32(NlAttr):
def __init__(self, nla_type, val):
self.u32 = val
self.u32 = enum_or_int(val)
super().__init__(nla_type, b"")
@property
@ -729,7 +768,7 @@ def __bytes__(self):
class NlAttrU16(NlAttr):
def __init__(self, nla_type, val):
self.u16 = val
self.u16 = enum_or_int(val)
super().__init__(nla_type, b"")
@property
@ -756,7 +795,7 @@ def __bytes__(self):
class NlAttrU8(NlAttr):
def __init__(self, nla_type, val):
self.u8 = val
self.u8 = enum_or_int(val)
super().__init__(nla_type, b"")
@property
@ -842,6 +881,11 @@ def _print_attr_value(self):
return " iface=if#{}".format(self.u32)
class NlAttrMac(NlAttr):
def _print_attr_value(self):
return ["{:02}".format(int(d)) for d in data[4:]].join(":")
class NlAttrTable(NlAttrU32):
def _print_attr_value(self):
return " rtable={}".format(self.u32)
@ -1067,26 +1111,44 @@ def prepare_attrs_map(attrs: List[AttrDescr]) -> Dict[str, Dict]:
)
rtnl_nd_attrs = prepare_attrs_map(
[
AttrDescr(NdAttrType.NDA_DST, NlAttrIp),
AttrDescr(NdAttrType.NDA_IFINDEX, NlAttrIfindex),
AttrDescr(NdAttrType.NDA_FLAGS_EXT, NlAttrU32),
AttrDescr(NdAttrType.NDA_LLADDR, NlAttrMac),
]
)
class BaseNetlinkMessage(object):
def __init__(self, helper, nlmsg_type):
self.nlmsg_type = nlmsg_type
self.nlmsg_type = enum_or_int(nlmsg_type)
self.ut = unittest.TestCase()
self.nla_list = []
self._orig_data = None
self.helper = helper
self.nl_hdr = Nlmsghdr(
nlmsg_type=nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
)
self.base_hdr = None
def set_request(self, need_ack=True):
self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST])
if need_ack:
self.add_nlflags([NlmBaseFlags.NLM_F_ACK])
def add_nlflags(self, flags: List):
int_flags = 0
for flag in flags:
int_flags |= enum_or_int(flag)
self.nl_hdr.nlmsg_flags |= int_flags
def add_nla(self, nla):
self.nla_list.append(nla)
def _get_nla(self, nla_list, nla_type):
if isinstance(nla_type, Enum):
nla_type_raw = nla_type.value
else:
nla_type_raw = nla_type
nla_type_raw = enum_or_int(nla_type)
for nla in nla_list:
if nla.nla_type == nla_type_raw:
return nla
@ -1102,10 +1164,7 @@ def parse_nl_header(data: bytes):
return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)
def is_type(self, nlmsg_type):
if isinstance(nlmsg_type, Enum):
nlmsg_type_raw = nlmsg_type.value
else:
nlmsg_type_raw = nlmsg_type
nlmsg_type_raw = enum_or_int(nlmsg_type)
return nlmsg_type_raw == self.nl_hdr.nlmsg_type
def is_reply(self, hdr):
@ -1422,6 +1481,37 @@ def print_base_header(self, hdr, prepend=""):
)
class NetlinkNdMessage(BaseNetlinkRtMessage):
messages = [
NlRtMsgType.RTM_NEWNEIGH.value,
NlRtMsgType.RTM_DELNEIGH.value,
NlRtMsgType.RTM_GETNEIGH.value,
]
nl_attrs_map = rtnl_nd_attrs
def __init__(self, helper, nlm_type):
super().__init__(helper, nlm_type)
self.base_hdr = NdMsg()
def parse_base_header(self, data):
if len(data) < sizeof(NdMsg):
raise ValueError("length less than NdMsg header")
nd_hdr = NdMsg.from_buffer_copy(data)
return (nd_hdr, sizeof(NdMsg))
def print_base_header(self, hdr, prepend=""):
family = self.helper.get_af_name(hdr.ndm_family)
print(
"{}family={}, ndm_ifindex={}, ndm_state={}, ndm_flags={}".format( # noqa: E501
prepend,
family,
hdr.ndm_ifindex,
hdr.ndm_state,
hdr.ndm_flags,
)
)
class Nlsock:
def __init__(self, family, helper):
self.helper = helper
@ -1435,6 +1525,7 @@ def build_msgmap(self):
NetlinkRtMessage,
NetlinkIflaMessage,
NetlinkIfaMessage,
NetlinkNdMessage,
NetlinkDoneMessage,
NetlinkErrorMessage,
]

View file

@ -9,6 +9,7 @@ ATF_TESTS_C += test_snl
ATF_TESTS_PYTEST += test_nl_core.py
ATF_TESTS_PYTEST += test_rtnl_iface.py
ATF_TESTS_PYTEST += test_rtnl_ifaddr.py
ATF_TESTS_PYTEST += test_rtnl_neigh.py
ATF_TESTS_PYTEST += test_rtnl_route.py
CFLAGS+= -I${.CURDIR:H:H:H}

View file

@ -0,0 +1,53 @@
import socket
import pytest
from atf_python.sys.net.netlink import NdAttrType
from atf_python.sys.net.netlink import NetlinkNdMessage
from atf_python.sys.net.netlink import NetlinkTestTemplate
from atf_python.sys.net.netlink import NlConst
from atf_python.sys.net.netlink import NlRtMsgType
from atf_python.sys.net.vnet import SingleVnetTestTemplate
class TestRtNlNeigh(NetlinkTestTemplate, SingleVnetTestTemplate):
def setup_method(self, method):
method_name = method.__name__
if "4" in method_name:
self.IPV4_PREFIXES = ["192.0.2.1/24"]
if "6" in method_name:
self.IPV6_PREFIXES = ["2001:db8::1/64"]
super().setup_method(method)
self.setup_netlink(NlConst.NETLINK_ROUTE)
def filter_iface(self, family, num_items):
epair_ifname = self.vnet.iface_alias_map["if1"].name
epair_ifindex = socket.if_nametoindex(epair_ifname)
msg = NetlinkNdMessage(self.helper, NlRtMsgType.RTM_GETNEIGH)
msg.set_request()
msg.base_hdr.ndm_family = family
msg.base_hdr.ndm_ifindex = epair_ifindex
self.write_message(msg)
ret = []
for rx_msg in self.read_msg_list(
msg.nl_hdr.nlmsg_seq, NlRtMsgType.RTM_NEWNEIGH
):
ifname = socket.if_indextoname(rx_msg.base_hdr.ndm_ifindex)
family = rx_msg.base_hdr.ndm_family
assert ifname == epair_ifname
assert family == family
assert rx_msg.get_nla(NdAttrType.NDA_DST) is not None
assert rx_msg.get_nla(NdAttrType.NDA_LLADDR) is not None
ret.append(rx_msg)
assert len(ret) == num_items
@pytest.mark.timeout(5)
def test_6_filter_iface(self):
"""Tests that listing outputs all nd6 records"""
return self.filter_iface(socket.AF_INET6, 2)
@pytest.mark.timeout(5)
def test_4_filter_iface(self):
"""Tests that listing outputs all arp records"""
return self.filter_iface(socket.AF_INET, 1)

View file

@ -2,9 +2,11 @@
import socket
import pytest
from atf_python.sys.net.tools import ToolsHelper
from atf_python.sys.net.netlink import NetlinkRtMessage
from atf_python.sys.net.netlink import NetlinkTestTemplate
from atf_python.sys.net.netlink import NlAttrIp
from atf_python.sys.net.netlink import NlAttrU32
from atf_python.sys.net.netlink import NlConst
from atf_python.sys.net.netlink import NlmBaseFlags
from atf_python.sys.net.netlink import NlmGetFlags
@ -22,6 +24,27 @@ def setup_method(self, method):
super().setup_method(method)
self.setup_netlink(NlConst.NETLINK_ROUTE)
@pytest.mark.timeout(5)
def test_add_route6_ll_gw(self):
epair_ifname = self.vnet.iface_alias_map["if1"].name
epair_ifindex = socket.if_nametoindex(epair_ifname)
msg = NetlinkRtMessage(self.helper, NlRtMsgType.RTM_NEWROUTE)
msg.set_request()
msg.add_nlflags([NlmNewFlags.NLM_F_CREATE])
msg.base_hdr.rtm_family = socket.AF_INET6
msg.base_hdr.rtm_dst_len = 64
msg.add_nla(NlAttrIp(RtattrType.RTA_DST, "2001:db8:2::"))
msg.add_nla(NlAttrIp(RtattrType.RTA_GATEWAY, "fe80::1"))
msg.add_nla(NlAttrU32(RtattrType.RTA_OIF, epair_ifindex))
rx_msg = self.get_reply(msg)
assert rx_msg.is_type(NlMsgType.NLMSG_ERROR)
assert rx_msg.error_code == 0
ToolsHelper.print_net_debug()
ToolsHelper.print_output("netstat -6onW")
@pytest.mark.timeout(20)
def test_buffer_override(self):
msg_flags = (