diff options
-rw-r--r-- | contrib/libvhost-user/libvhost-user.c | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/contrib/libvhost-user/libvhost-user.c b/contrib/libvhost-user/libvhost-user.c index 533d55d82a..3abc9689e5 100644 --- a/contrib/libvhost-user/libvhost-user.c +++ b/contrib/libvhost-user/libvhost-user.c @@ -948,6 +948,7 @@ static bool vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg) { int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK; + bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK; if (index >= dev->max_queues) { vmsg_close_fds(vmsg); @@ -955,8 +956,12 @@ vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg) return false; } - if (vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK || - vmsg->fd_num != 1) { + if (nofd) { + vmsg_close_fds(vmsg); + return true; + } + + if (vmsg->fd_num != 1) { vmsg_close_fds(vmsg); vu_panic(dev, "Invalid fds in request: %d", vmsg->request); return false; @@ -1053,6 +1058,7 @@ static bool vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg) { int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK; + bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK; DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64); @@ -1066,8 +1072,8 @@ vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg) dev->vq[index].kick_fd = -1; } - dev->vq[index].kick_fd = vmsg->fds[0]; - DPRINT("Got kick_fd: %d for vq: %d\n", vmsg->fds[0], index); + dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0]; + DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index); dev->vq[index].started = true; if (dev->iface->queue_set_started) { @@ -1147,6 +1153,7 @@ static bool vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg) { int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK; + bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK; DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64); @@ -1159,14 +1166,14 @@ vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg) dev->vq[index].call_fd = -1; } - dev->vq[index].call_fd = vmsg->fds[0]; + dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0]; /* in case of I/O hang after reconnecting */ - if (eventfd_write(vmsg->fds[0], 1)) { + if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) { return -1; } - DPRINT("Got call_fd: %d for vq: %d\n", vmsg->fds[0], index); + DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index); return false; } @@ -1175,6 +1182,7 @@ static bool vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg) { int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK; + bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK; DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64); @@ -1187,7 +1195,7 @@ vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg) dev->vq[index].err_fd = -1; } - dev->vq[index].err_fd = vmsg->fds[0]; + dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0]; return false; } |