diff --git a/src/virtio-blk.c b/src/virtio-blk.c
index d50dbba..d7ec168 100644
--- a/src/virtio-blk.c
+++ b/src/virtio-blk.c
@@ -103,7 +103,7 @@ static void virtio_blk_complete_request(struct virtq *vq)
     struct vring_packed_desc *used_desc = desc;
     ssize_t io_bytes = 0;
-    void *hdr = vm_guest_to_host(v, desc->addr);
+    void *hdr = vm_guest_buf(v, desc->addr, hdr_sz);
     if (!hdr || desc->len < hdr_sz)
         return;
     memcpy(&req, hdr, hdr_sz);
@@ -112,33 +112,48 @@ static void virtio_blk_complete_request(struct virtq *vq)
             return;
         desc = virtq_get_avail(vq);
         req.data_size = desc->len;
-        req.data = vm_guest_to_host(v, desc->addr);
-        if (!req.data)
-            return;
-
-        if (req.type == VIRTIO_BLK_T_IN)
-            io_bytes = virtio_blk_read(dev, req.data, req.sector << 9,
-                                       req.data_size);
-        else
-            io_bytes = virtio_blk_write(dev, req.data, req.sector << 9,
-                                        req.data_size);
-
-        status = io_bytes < 0 ? VIRTIO_BLK_S_IOERR : VIRTIO_BLK_S_OK;
+        req.data = vm_guest_buf(v, desc->addr, req.data_size);
+
+        /* Validate that the request fits in the backing store. Both the
+         * shift (sector*512) and the addition (offset+data_size) must not
+         * overflow, and the end must be within diskimg->size. Any failure
+         * yields VIRTIO_BLK_S_IOERR with no data transferred. */
+        uint64_t off, end;
+        bool io_ok = false;
+        if (req.data && !__builtin_mul_overflow(req.sector, 512, &off) &&
+            !__builtin_add_overflow(off, req.data_size, &end) &&
+            end <= (uint64_t) dev->diskimg->size) {
+            if (req.type == VIRTIO_BLK_T_IN)
+                io_bytes = virtio_blk_read(dev, req.data, (off_t) off,
+                                           req.data_size);
+            else
+                io_bytes = virtio_blk_write(dev, req.data, (off_t) off,
+                                            req.data_size);
+            /* A short read/write leaves part of the guest buffer stale,
+             * so treat anything less than the full request as IOERR. */
+            io_ok = io_bytes >= 0 && (size_t) io_bytes == req.data_size;
+        }
+        status = io_ok ? VIRTIO_BLK_S_OK : VIRTIO_BLK_S_IOERR;
     } else {
         status = VIRTIO_BLK_S_UNSUPP;
     }
 
     if (!virtq_check_next(desc))
         return;
     desc = virtq_get_avail(vq);
-    req.status = vm_guest_to_host(v, desc->addr);
+    /* The status descriptor must advertise at least one device-writable
+     * byte; otherwise we'd clobber memory the guest did not offer. */
+    if (desc->len < 1)
+        return;
+    req.status = vm_guest_buf(v, desc->addr, 1);
     if (!req.status)
         return;
     *req.status = status;
     /* used.len is total bytes the device wrote into device-writable
      * buffers across the chain: the 1-byte status is always written, plus
-     * io_bytes of data on a successful IN. */
+     * the data buffer on a successful IN. On any error we report only the
+     * status byte so the guest does not consume stale data. */
     size_t written = 1;
-    if (req.type == VIRTIO_BLK_T_IN && io_bytes > 0)
+    if (status == VIRTIO_BLK_S_OK && req.type == VIRTIO_BLK_T_IN)
         written += (size_t) io_bytes;
     used_desc->len = (uint32_t) written;
     used_desc->flags ^= (1ULL << VRING_PACKED_DESC_F_USED);
diff --git a/src/virtio-net.c b/src/virtio-net.c
index af90404..1fafc23 100644
--- a/src/virtio-net.c
+++ b/src/virtio-net.c
@@ -153,16 +153,24 @@ void virtio_net_complete_request_rx(struct virtq *vq)
     struct vring_packed_desc *desc;
 
     while ((desc = virtq_get_avail(vq)) != NULL) {
-        uint8_t *data = vm_guest_to_host(v, desc->addr);
+        size_t virtio_header_len = sizeof(struct virtio_net_hdr_v1);
+        /* desc lives in guest-writable memory; snapshot the length we'll
+         * validate and use so a concurrent guest write cannot widen the
+         * access past the bounds check. */
+        uint32_t buf_len = desc->len;
+        uint8_t *data = vm_guest_buf(v, desc->addr, buf_len);
+        if (!data || buf_len < virtio_header_len) {
+            vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_DISABLE;
+            return;
+        }
         struct virtio_net_hdr_v1 *virtio_hdr =
             (struct virtio_net_hdr_v1 *) data;
         memset(virtio_hdr, 0, sizeof(struct virtio_net_hdr_v1));
         virtio_hdr->num_buffers = 1;
 
-        size_t virtio_header_len = sizeof(struct virtio_net_hdr_v1);
         ssize_t read_bytes = read(dev->tapfd, data + virtio_header_len,
-                                  desc->len - virtio_header_len);
+                                  buf_len - virtio_header_len);
         if (read_bytes < 0) {
             vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_DISABLE;
             return;
         }
@@ -183,16 +191,18 @@ void virtio_net_complete_request_tx(struct virtq *vq)
     vm_t *v = container_of(dev, vm_t, virtio_net_dev);
     struct vring_packed_desc *desc;
 
    while ((desc = virtq_get_avail(vq)) != NULL) {
-        uint8_t *data = vm_guest_to_host(v, desc->addr);
         size_t virtio_header_len = sizeof(struct virtio_net_hdr_v1);
+        /* See rx path: snapshot len before bounds check to defeat TOCTOU. */
+        uint32_t buf_len = desc->len;
+        uint8_t *data = vm_guest_buf(v, desc->addr, buf_len);
 
-        if (desc->len < virtio_header_len) {
+        if (!data || buf_len < virtio_header_len) {
             vq->guest_event->flags = VRING_PACKED_EVENT_FLAG_DISABLE;
             return;
         }
 
         uint8_t *actual_data = data + virtio_header_len;
-        size_t actual_data_len = desc->len - virtio_header_len;
+        size_t actual_data_len = buf_len - virtio_header_len;
 
         struct iovec iov[1];
         iov[0].iov_base = actual_data;
diff --git a/src/vm.c b/src/vm.c
index 3876364..344670f 100644
--- a/src/vm.c
+++ b/src/vm.c
@@ -176,6 +176,15 @@ void *vm_guest_to_host(vm_t *v, uint64_t guest)
     return (void *) ((uintptr_t) v->mem + guest - RAM_BASE);
 }
 
+void *vm_guest_buf(vm_t *v, uint64_t guest, size_t len)
+{
+    uint64_t end;
+    if (guest < RAM_BASE || __builtin_add_overflow(guest, len, &end) ||
+        end > RAM_BASE + RAM_SIZE)
+        return NULL;
+    return (void *) ((uintptr_t) v->mem + guest - RAM_BASE);
+}
+
 void vm_irqfd_register(vm_t *v, int fd, int gsi, int flags)
 {
     struct kvm_irqfd irqfd = {
diff --git a/src/vm.h b/src/vm.h
index 5ebd9d3..da87e40 100644
--- a/src/vm.h
+++ b/src/vm.h
@@ -35,6 +35,7 @@ int vm_enable_net(vm_t *v);
 int vm_run(vm_t *v);
 int vm_irq_line(vm_t *v, int irq, int level);
 void *vm_guest_to_host(vm_t *v, uint64_t guest);
+void *vm_guest_buf(vm_t *v, uint64_t guest, size_t len);
 void vm_irqfd_register(vm_t *v, int fd, int gsi, int flags);
 void vm_ioeventfd_register(vm_t *v, int fd,