diff --git a/drivers/vdpa/vdpa_user/vduse_dev.c b/drivers/vdpa/vdpa_user/vduse_dev.c index 73c89701fc9d..7ae99691efdf 100644 --- a/drivers/vdpa/vdpa_user/vduse_dev.c +++ b/drivers/vdpa/vdpa_user/vduse_dev.c @@ -8,6 +8,7 @@ * */ +#include "linux/virtio_net.h" #include #include #include @@ -28,6 +29,7 @@ #include #include #include +#include #include #include "iova_domain.h" @@ -141,6 +143,7 @@ static struct workqueue_struct *vduse_irq_bound_wq; static u32 allowed_device_id[] = { VIRTIO_ID_BLOCK, + VIRTIO_ID_NET, }; static inline struct vduse_dev *vdpa_to_vduse(struct vdpa_device *vdpa) @@ -1705,13 +1708,21 @@ static bool device_is_allowed(u32 device_id) return false; } -static bool features_is_valid(u64 features) +static bool features_is_valid(struct vduse_dev_config *config) { - if (!(features & (1ULL << VIRTIO_F_ACCESS_PLATFORM))) + if (!(config->features & BIT_ULL(VIRTIO_F_ACCESS_PLATFORM))) return false; /* Now we only support read-only configuration space */ - if (features & (1ULL << VIRTIO_BLK_F_CONFIG_WCE)) + if ((config->device_id == VIRTIO_ID_BLOCK) && + (config->features & BIT_ULL(VIRTIO_BLK_F_CONFIG_WCE))) + return false; + else if ((config->device_id == VIRTIO_ID_NET) && + (config->features & BIT_ULL(VIRTIO_NET_F_CTRL_VQ))) + return false; + + if ((config->device_id == VIRTIO_ID_NET) && + !(config->features & BIT_ULL(VIRTIO_F_VERSION_1))) return false; return true; @@ -1738,7 +1749,7 @@ static bool vduse_validate_config(struct vduse_dev_config *config) if (!device_is_allowed(config->device_id)) return false; - if (!features_is_valid(config->features)) + if (!features_is_valid(config)) return false; return true; @@ -1821,6 +1832,10 @@ static int vduse_create_dev(struct vduse_dev_config *config, int ret; struct vduse_dev *dev; + ret = -EPERM; + if ((config->device_id == VIRTIO_ID_NET) && !capable(CAP_NET_ADMIN)) + goto err; + ret = -EEXIST; if (vduse_find_dev(config->name)) goto err; @@ -2064,6 +2079,7 @@ static const struct vdpa_mgmtdev_ops vdpa_dev_mgmtdev_ops = { static struct virtio_device_id id_table[] = { { VIRTIO_ID_BLOCK, VIRTIO_DEV_ANY_ID }, + { VIRTIO_ID_NET, VIRTIO_DEV_ANY_ID }, { 0 }, };