Merge branch '10GbE' of git://git.kernel.org/pub/scm/linux/kernel/git/jkirsher/net...
[muen/linux.git] / drivers / vhost / vhost.c
1 /* Copyright (C) 2009 Red Hat, Inc.
2  * Copyright (C) 2006 Rusty Russell IBM Corporation
3  *
4  * Author: Michael S. Tsirkin <mst@redhat.com>
5  *
6  * Inspiration, some code, and most witty comments come from
7  * Documentation/virtual/lguest/lguest.c, by Rusty Russell
8  *
9  * This work is licensed under the terms of the GNU GPL, version 2.
10  *
11  * Generic code for virtio server in host kernel.
12  */
13
14 #include <linux/eventfd.h>
15 #include <linux/vhost.h>
16 #include <linux/uio.h>
17 #include <linux/mm.h>
18 #include <linux/mmu_context.h>
19 #include <linux/miscdevice.h>
20 #include <linux/mutex.h>
21 #include <linux/poll.h>
22 #include <linux/file.h>
23 #include <linux/highmem.h>
24 #include <linux/slab.h>
25 #include <linux/vmalloc.h>
26 #include <linux/kthread.h>
27 #include <linux/cgroup.h>
28 #include <linux/module.h>
29 #include <linux/sort.h>
30 #include <linux/sched/mm.h>
31 #include <linux/sched/signal.h>
32 #include <linux/interval_tree_generic.h>
33 #include <linux/nospec.h>
34
35 #include "vhost.h"
36
37 static ushort max_mem_regions = 64;
38 module_param(max_mem_regions, ushort, 0444);
39 MODULE_PARM_DESC(max_mem_regions,
40         "Maximum number of memory regions in memory map. (default: 64)");
41 static int max_iotlb_entries = 2048;
42 module_param(max_iotlb_entries, int, 0444);
43 MODULE_PARM_DESC(max_iotlb_entries,
44         "Maximum number of iotlb entries. (default: 2048)");
45
46 enum {
47         VHOST_MEMORY_F_LOG = 0x1,
48 };
49
/*
 * User-space addresses of the ring slot at index vq->num, viewed as a
 * __virtio16: vhost_used_event() indexes the avail ring and
 * vhost_avail_event() indexes the used ring.
 * The macro argument is parenthesized so that any pointer-valued
 * expression (not just a plain identifier) expands correctly.
 */
#define vhost_used_event(vq) \
	((__virtio16 __user *)&(vq)->avail->ring[(vq)->num])
#define vhost_avail_event(vq) \
	((__virtio16 __user *)&(vq)->used->ring[(vq)->num])
52
53 INTERVAL_TREE_DEFINE(struct vhost_umem_node,
54                      rb, __u64, __subtree_last,
55                      START, LAST, static inline, vhost_umem_interval_tree);
56
57 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
58 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
59 {
60         vq->user_be = !virtio_legacy_is_little_endian();
61 }
62
63 static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq)
64 {
65         vq->user_be = true;
66 }
67
68 static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq)
69 {
70         vq->user_be = false;
71 }
72
73 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
74 {
75         struct vhost_vring_state s;
76
77         if (vq->private_data)
78                 return -EBUSY;
79
80         if (copy_from_user(&s, argp, sizeof(s)))
81                 return -EFAULT;
82
83         if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
84             s.num != VHOST_VRING_BIG_ENDIAN)
85                 return -EINVAL;
86
87         if (s.num == VHOST_VRING_BIG_ENDIAN)
88                 vhost_enable_cross_endian_big(vq);
89         else
90                 vhost_enable_cross_endian_little(vq);
91
92         return 0;
93 }
94
95 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
96                                    int __user *argp)
97 {
98         struct vhost_vring_state s = {
99                 .index = idx,
100                 .num = vq->user_be
101         };
102
103         if (copy_to_user(argp, &s, sizeof(s)))
104                 return -EFAULT;
105
106         return 0;
107 }
108
109 static void vhost_init_is_le(struct vhost_virtqueue *vq)
110 {
111         /* Note for legacy virtio: user_be is initialized at reset time
112          * according to the host endianness. If userspace does not set an
113          * explicit endianness, the default behavior is native endian, as
114          * expected by legacy virtio.
115          */
116         vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
117 }
118 #else
119 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
120 {
121 }
122
123 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
124 {
125         return -ENOIOCTLCMD;
126 }
127
128 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
129                                    int __user *argp)
130 {
131         return -ENOIOCTLCMD;
132 }
133
134 static void vhost_init_is_le(struct vhost_virtqueue *vq)
135 {
136         vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1)
137                 || virtio_legacy_is_little_endian();
138 }
139 #endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */
140
/* Recompute vq->is_le from the ring's current feature/endian state. */
static void vhost_reset_is_le(struct vhost_virtqueue *vq)
{
	vhost_init_is_le(vq);
}
145
146 struct vhost_flush_struct {
147         struct vhost_work work;
148         struct completion wait_event;
149 };
150
151 static void vhost_flush_work(struct vhost_work *work)
152 {
153         struct vhost_flush_struct *s;
154
155         s = container_of(work, struct vhost_flush_struct, work);
156         complete(&s->wait_event);
157 }
158
159 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
160                             poll_table *pt)
161 {
162         struct vhost_poll *poll;
163
164         poll = container_of(pt, struct vhost_poll, table);
165         poll->wqh = wqh;
166         add_wait_queue(wqh, &poll->wait);
167 }
168
169 static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
170                              void *key)
171 {
172         struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
173
174         if (!(key_to_poll(key) & poll->mask))
175                 return 0;
176
177         vhost_poll_queue(poll);
178         return 0;
179 }
180
181 void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
182 {
183         clear_bit(VHOST_WORK_QUEUED, &work->flags);
184         work->fn = fn;
185 }
186 EXPORT_SYMBOL_GPL(vhost_work_init);
187
188 /* Init poll structure */
189 void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
190                      __poll_t mask, struct vhost_dev *dev)
191 {
192         init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
193         init_poll_funcptr(&poll->table, vhost_poll_func);
194         poll->mask = mask;
195         poll->dev = dev;
196         poll->wqh = NULL;
197
198         vhost_work_init(&poll->work, fn);
199 }
200 EXPORT_SYMBOL_GPL(vhost_poll_init);
201
202 /* Start polling a file. We add ourselves to file's wait queue. The caller must
203  * keep a reference to a file until after vhost_poll_stop is called. */
204 int vhost_poll_start(struct vhost_poll *poll, struct file *file)
205 {
206         __poll_t mask;
207         int ret = 0;
208
209         if (poll->wqh)
210                 return 0;
211
212         mask = vfs_poll(file, &poll->table);
213         if (mask)
214                 vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
215         if (mask & EPOLLERR) {
216                 vhost_poll_stop(poll);
217                 ret = -EINVAL;
218         }
219
220         return ret;
221 }
222 EXPORT_SYMBOL_GPL(vhost_poll_start);
223
224 /* Stop polling a file. After this function returns, it becomes safe to drop the
225  * file reference. You must also flush afterwards. */
226 void vhost_poll_stop(struct vhost_poll *poll)
227 {
228         if (poll->wqh) {
229                 remove_wait_queue(poll->wqh, &poll->wait);
230                 poll->wqh = NULL;
231         }
232 }
233 EXPORT_SYMBOL_GPL(vhost_poll_stop);
234
235 void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
236 {
237         struct vhost_flush_struct flush;
238
239         if (dev->worker) {
240                 init_completion(&flush.wait_event);
241                 vhost_work_init(&flush.work, vhost_flush_work);
242
243                 vhost_work_queue(dev, &flush.work);
244                 wait_for_completion(&flush.wait_event);
245         }
246 }
247 EXPORT_SYMBOL_GPL(vhost_work_flush);
248
249 /* Flush any work that has been scheduled. When calling this, don't hold any
250  * locks that are also used by the callback. */
251 void vhost_poll_flush(struct vhost_poll *poll)
252 {
253         vhost_work_flush(poll->dev, &poll->work);
254 }
255 EXPORT_SYMBOL_GPL(vhost_poll_flush);
256
257 void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
258 {
259         if (!dev->worker)
260                 return;
261
262         if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
263                 /* We can only add the work to the list after we're
264                  * sure it was not in the list.
265                  * test_and_set_bit() implies a memory barrier.
266                  */
267                 llist_add(&work->node, &dev->work_list);
268                 wake_up_process(dev->worker);
269         }
270 }
271 EXPORT_SYMBOL_GPL(vhost_work_queue);
272
273 /* A lockless hint for busy polling code to exit the loop */
274 bool vhost_has_work(struct vhost_dev *dev)
275 {
276         return !llist_empty(&dev->work_list);
277 }
278 EXPORT_SYMBOL_GPL(vhost_has_work);
279
280 void vhost_poll_queue(struct vhost_poll *poll)
281 {
282         vhost_work_queue(poll->dev, &poll->work);
283 }
284 EXPORT_SYMBOL_GPL(vhost_poll_queue);
285
286 static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
287 {
288         int j;
289
290         for (j = 0; j < VHOST_NUM_ADDRS; j++)
291                 vq->meta_iotlb[j] = NULL;
292 }
293
294 static void vhost_vq_meta_reset(struct vhost_dev *d)
295 {
296         int i;
297
298         for (i = 0; i < d->nvqs; ++i) {
299                 mutex_lock(&d->vqs[i]->mutex);
300                 __vhost_vq_meta_reset(d->vqs[i]);
301                 mutex_unlock(&d->vqs[i]->mutex);
302         }
303 }
304
305 static void vhost_vq_reset(struct vhost_dev *dev,
306                            struct vhost_virtqueue *vq)
307 {
308         vq->num = 1;
309         vq->desc = NULL;
310         vq->avail = NULL;
311         vq->used = NULL;
312         vq->last_avail_idx = 0;
313         vq->avail_idx = 0;
314         vq->last_used_idx = 0;
315         vq->signalled_used = 0;
316         vq->signalled_used_valid = false;
317         vq->used_flags = 0;
318         vq->log_used = false;
319         vq->log_addr = -1ull;
320         vq->private_data = NULL;
321         vq->acked_features = 0;
322         vq->acked_backend_features = 0;
323         vq->log_base = NULL;
324         vq->error_ctx = NULL;
325         vq->kick = NULL;
326         vq->call_ctx = NULL;
327         vq->log_ctx = NULL;
328         vhost_reset_is_le(vq);
329         vhost_disable_cross_endian(vq);
330         vq->busyloop_timeout = 0;
331         vq->umem = NULL;
332         vq->iotlb = NULL;
333         __vhost_vq_meta_reset(vq);
334 }
335
336 static int vhost_worker(void *data)
337 {
338         struct vhost_dev *dev = data;
339         struct vhost_work *work, *work_next;
340         struct llist_node *node;
341         mm_segment_t oldfs = get_fs();
342
343         set_fs(USER_DS);
344         use_mm(dev->mm);
345
346         for (;;) {
347                 /* mb paired w/ kthread_stop */
348                 set_current_state(TASK_INTERRUPTIBLE);
349
350                 if (kthread_should_stop()) {
351                         __set_current_state(TASK_RUNNING);
352                         break;
353                 }
354
355                 node = llist_del_all(&dev->work_list);
356                 if (!node)
357                         schedule();
358
359                 node = llist_reverse_order(node);
360                 /* make sure flag is seen after deletion */
361                 smp_wmb();
362                 llist_for_each_entry_safe(work, work_next, node, node) {
363                         clear_bit(VHOST_WORK_QUEUED, &work->flags);
364                         __set_current_state(TASK_RUNNING);
365                         work->fn(work);
366                         if (need_resched())
367                                 schedule();
368                 }
369         }
370         unuse_mm(dev->mm);
371         set_fs(oldfs);
372         return 0;
373 }
374
375 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
376 {
377         kfree(vq->indirect);
378         vq->indirect = NULL;
379         kfree(vq->log);
380         vq->log = NULL;
381         kfree(vq->heads);
382         vq->heads = NULL;
383 }
384
385 /* Helper to allocate iovec buffers for all vqs. */
386 static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
387 {
388         struct vhost_virtqueue *vq;
389         int i;
390
391         for (i = 0; i < dev->nvqs; ++i) {
392                 vq = dev->vqs[i];
393                 vq->indirect = kmalloc_array(UIO_MAXIOV,
394                                              sizeof(*vq->indirect),
395                                              GFP_KERNEL);
396                 vq->log = kmalloc_array(UIO_MAXIOV, sizeof(*vq->log),
397                                         GFP_KERNEL);
398                 vq->heads = kmalloc_array(UIO_MAXIOV, sizeof(*vq->heads),
399                                           GFP_KERNEL);
400                 if (!vq->indirect || !vq->log || !vq->heads)
401                         goto err_nomem;
402         }
403         return 0;
404
405 err_nomem:
406         for (; i >= 0; --i)
407                 vhost_vq_free_iovecs(dev->vqs[i]);
408         return -ENOMEM;
409 }
410
411 static void vhost_dev_free_iovecs(struct vhost_dev *dev)
412 {
413         int i;
414
415         for (i = 0; i < dev->nvqs; ++i)
416                 vhost_vq_free_iovecs(dev->vqs[i]);
417 }
418
419 void vhost_dev_init(struct vhost_dev *dev,
420                     struct vhost_virtqueue **vqs, int nvqs)
421 {
422         struct vhost_virtqueue *vq;
423         int i;
424
425         dev->vqs = vqs;
426         dev->nvqs = nvqs;
427         mutex_init(&dev->mutex);
428         dev->log_ctx = NULL;
429         dev->umem = NULL;
430         dev->iotlb = NULL;
431         dev->mm = NULL;
432         dev->worker = NULL;
433         init_llist_head(&dev->work_list);
434         init_waitqueue_head(&dev->wait);
435         INIT_LIST_HEAD(&dev->read_list);
436         INIT_LIST_HEAD(&dev->pending_list);
437         spin_lock_init(&dev->iotlb_lock);
438
439
440         for (i = 0; i < dev->nvqs; ++i) {
441                 vq = dev->vqs[i];
442                 vq->log = NULL;
443                 vq->indirect = NULL;
444                 vq->heads = NULL;
445                 vq->dev = dev;
446                 mutex_init(&vq->mutex);
447                 vhost_vq_reset(dev, vq);
448                 if (vq->handle_kick)
449                         vhost_poll_init(&vq->poll, vq->handle_kick,
450                                         EPOLLIN, dev);
451         }
452 }
453 EXPORT_SYMBOL_GPL(vhost_dev_init);
454
455 /* Caller should have device mutex */
456 long vhost_dev_check_owner(struct vhost_dev *dev)
457 {
458         /* Are you the owner? If not, I don't think you mean to do that */
459         return dev->mm == current->mm ? 0 : -EPERM;
460 }
461 EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
462
463 struct vhost_attach_cgroups_struct {
464         struct vhost_work work;
465         struct task_struct *owner;
466         int ret;
467 };
468
469 static void vhost_attach_cgroups_work(struct vhost_work *work)
470 {
471         struct vhost_attach_cgroups_struct *s;
472
473         s = container_of(work, struct vhost_attach_cgroups_struct, work);
474         s->ret = cgroup_attach_task_all(s->owner, current);
475 }
476
477 static int vhost_attach_cgroups(struct vhost_dev *dev)
478 {
479         struct vhost_attach_cgroups_struct attach;
480
481         attach.owner = current;
482         vhost_work_init(&attach.work, vhost_attach_cgroups_work);
483         vhost_work_queue(dev, &attach.work);
484         vhost_work_flush(dev, &attach.work);
485         return attach.ret;
486 }
487
488 /* Caller should have device mutex */
489 bool vhost_dev_has_owner(struct vhost_dev *dev)
490 {
491         return dev->mm;
492 }
493 EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
494
495 /* Caller should have device mutex */
496 long vhost_dev_set_owner(struct vhost_dev *dev)
497 {
498         struct task_struct *worker;
499         int err;
500
501         /* Is there an owner already? */
502         if (vhost_dev_has_owner(dev)) {
503                 err = -EBUSY;
504                 goto err_mm;
505         }
506
507         /* No owner, become one */
508         dev->mm = get_task_mm(current);
509         worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
510         if (IS_ERR(worker)) {
511                 err = PTR_ERR(worker);
512                 goto err_worker;
513         }
514
515         dev->worker = worker;
516         wake_up_process(worker);        /* avoid contributing to loadavg */
517
518         err = vhost_attach_cgroups(dev);
519         if (err)
520                 goto err_cgroup;
521
522         err = vhost_dev_alloc_iovecs(dev);
523         if (err)
524                 goto err_cgroup;
525
526         return 0;
527 err_cgroup:
528         kthread_stop(worker);
529         dev->worker = NULL;
530 err_worker:
531         if (dev->mm)
532                 mmput(dev->mm);
533         dev->mm = NULL;
534 err_mm:
535         return err;
536 }
537 EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
538
539 struct vhost_umem *vhost_dev_reset_owner_prepare(void)
540 {
541         return kvzalloc(sizeof(struct vhost_umem), GFP_KERNEL);
542 }
543 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
544
545 /* Caller should have device mutex */
546 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
547 {
548         int i;
549
550         vhost_dev_cleanup(dev);
551
552         /* Restore memory to default empty mapping. */
553         INIT_LIST_HEAD(&umem->umem_list);
554         dev->umem = umem;
555         /* We don't need VQ locks below since vhost_dev_cleanup makes sure
556          * VQs aren't running.
557          */
558         for (i = 0; i < dev->nvqs; ++i)
559                 dev->vqs[i]->umem = umem;
560 }
561 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
562
563 void vhost_dev_stop(struct vhost_dev *dev)
564 {
565         int i;
566
567         for (i = 0; i < dev->nvqs; ++i) {
568                 if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) {
569                         vhost_poll_stop(&dev->vqs[i]->poll);
570                         vhost_poll_flush(&dev->vqs[i]->poll);
571                 }
572         }
573 }
574 EXPORT_SYMBOL_GPL(vhost_dev_stop);
575
576 static void vhost_umem_free(struct vhost_umem *umem,
577                             struct vhost_umem_node *node)
578 {
579         vhost_umem_interval_tree_remove(node, &umem->umem_tree);
580         list_del(&node->link);
581         kfree(node);
582         umem->numem--;
583 }
584
585 static void vhost_umem_clean(struct vhost_umem *umem)
586 {
587         struct vhost_umem_node *node, *tmp;
588
589         if (!umem)
590                 return;
591
592         list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
593                 vhost_umem_free(umem, node);
594
595         kvfree(umem);
596 }
597
598 static void vhost_clear_msg(struct vhost_dev *dev)
599 {
600         struct vhost_msg_node *node, *n;
601
602         spin_lock(&dev->iotlb_lock);
603
604         list_for_each_entry_safe(node, n, &dev->read_list, node) {
605                 list_del(&node->node);
606                 kfree(node);
607         }
608
609         list_for_each_entry_safe(node, n, &dev->pending_list, node) {
610                 list_del(&node->node);
611                 kfree(node);
612         }
613
614         spin_unlock(&dev->iotlb_lock);
615 }
616
617 void vhost_dev_cleanup(struct vhost_dev *dev)
618 {
619         int i;
620
621         for (i = 0; i < dev->nvqs; ++i) {
622                 if (dev->vqs[i]->error_ctx)
623                         eventfd_ctx_put(dev->vqs[i]->error_ctx);
624                 if (dev->vqs[i]->kick)
625                         fput(dev->vqs[i]->kick);
626                 if (dev->vqs[i]->call_ctx)
627                         eventfd_ctx_put(dev->vqs[i]->call_ctx);
628                 vhost_vq_reset(dev, dev->vqs[i]);
629         }
630         vhost_dev_free_iovecs(dev);
631         if (dev->log_ctx)
632                 eventfd_ctx_put(dev->log_ctx);
633         dev->log_ctx = NULL;
634         /* No one will access memory at this point */
635         vhost_umem_clean(dev->umem);
636         dev->umem = NULL;
637         vhost_umem_clean(dev->iotlb);
638         dev->iotlb = NULL;
639         vhost_clear_msg(dev);
640         wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
641         WARN_ON(!llist_empty(&dev->work_list));
642         if (dev->worker) {
643                 kthread_stop(dev->worker);
644                 dev->worker = NULL;
645         }
646         if (dev->mm)
647                 mmput(dev->mm);
648         dev->mm = NULL;
649 }
650 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
651
652 static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
653 {
654         u64 a = addr / VHOST_PAGE_SIZE / 8;
655
656         /* Make sure 64 bit math will not overflow. */
657         if (a > ULONG_MAX - (unsigned long)log_base ||
658             a + (unsigned long)log_base > ULONG_MAX)
659                 return false;
660
661         return access_ok(VERIFY_WRITE, log_base + a,
662                          (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
663 }
664
665 static bool vhost_overflow(u64 uaddr, u64 size)
666 {
667         /* Make sure 64 bit math will not overflow. */
668         return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
669 }
670
671 /* Caller should have vq mutex and device mutex. */
672 static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
673                                 int log_all)
674 {
675         struct vhost_umem_node *node;
676
677         if (!umem)
678                 return false;
679
680         list_for_each_entry(node, &umem->umem_list, link) {
681                 unsigned long a = node->userspace_addr;
682
683                 if (vhost_overflow(node->userspace_addr, node->size))
684                         return false;
685
686
687                 if (!access_ok(VERIFY_WRITE, (void __user *)a,
688                                     node->size))
689                         return false;
690                 else if (log_all && !log_access_ok(log_base,
691                                                    node->start,
692                                                    node->size))
693                         return false;
694         }
695         return true;
696 }
697
698 static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
699                                                u64 addr, unsigned int size,
700                                                int type)
701 {
702         const struct vhost_umem_node *node = vq->meta_iotlb[type];
703
704         if (!node)
705                 return NULL;
706
707         return (void *)(uintptr_t)(node->userspace_addr + addr - node->start);
708 }
709
710 /* Can we switch to this memory table? */
711 /* Caller should have device mutex but not vq mutex */
712 static bool memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
713                              int log_all)
714 {
715         int i;
716
717         for (i = 0; i < d->nvqs; ++i) {
718                 bool ok;
719                 bool log;
720
721                 mutex_lock(&d->vqs[i]->mutex);
722                 log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
723                 /* If ring is inactive, will check when it's enabled. */
724                 if (d->vqs[i]->private_data)
725                         ok = vq_memory_access_ok(d->vqs[i]->log_base,
726                                                  umem, log);
727                 else
728                         ok = true;
729                 mutex_unlock(&d->vqs[i]->mutex);
730                 if (!ok)
731                         return false;
732         }
733         return true;
734 }
735
736 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
737                           struct iovec iov[], int iov_size, int access);
738
739 static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
740                               const void *from, unsigned size)
741 {
742         int ret;
743
744         if (!vq->iotlb)
745                 return __copy_to_user(to, from, size);
746         else {
747                 /* This function should be called after iotlb
748                  * prefetch, which means we're sure that all vq
749                  * could be access through iotlb. So -EAGAIN should
750                  * not happen in this case.
751                  */
752                 struct iov_iter t;
753                 void __user *uaddr = vhost_vq_meta_fetch(vq,
754                                      (u64)(uintptr_t)to, size,
755                                      VHOST_ADDR_USED);
756
757                 if (uaddr)
758                         return __copy_to_user(uaddr, from, size);
759
760                 ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
761                                      ARRAY_SIZE(vq->iotlb_iov),
762                                      VHOST_ACCESS_WO);
763                 if (ret < 0)
764                         goto out;
765                 iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
766                 ret = copy_to_iter(from, size, &t);
767                 if (ret == size)
768                         ret = 0;
769         }
770 out:
771         return ret;
772 }
773
774 static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
775                                 void __user *from, unsigned size)
776 {
777         int ret;
778
779         if (!vq->iotlb)
780                 return __copy_from_user(to, from, size);
781         else {
782                 /* This function should be called after iotlb
783                  * prefetch, which means we're sure that vq
784                  * could be access through iotlb. So -EAGAIN should
785                  * not happen in this case.
786                  */
787                 void __user *uaddr = vhost_vq_meta_fetch(vq,
788                                      (u64)(uintptr_t)from, size,
789                                      VHOST_ADDR_DESC);
790                 struct iov_iter f;
791
792                 if (uaddr)
793                         return __copy_from_user(to, uaddr, size);
794
795                 ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
796                                      ARRAY_SIZE(vq->iotlb_iov),
797                                      VHOST_ACCESS_RO);
798                 if (ret < 0) {
799                         vq_err(vq, "IOTLB translation failure: uaddr "
800                                "%p size 0x%llx\n", from,
801                                (unsigned long long) size);
802                         goto out;
803                 }
804                 iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
805                 ret = copy_from_iter(to, size, &f);
806                 if (ret == size)
807                         ret = 0;
808         }
809
810 out:
811         return ret;
812 }
813
814 static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
815                                           void __user *addr, unsigned int size,
816                                           int type)
817 {
818         int ret;
819
820         ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
821                              ARRAY_SIZE(vq->iotlb_iov),
822                              VHOST_ACCESS_RO);
823         if (ret < 0) {
824                 vq_err(vq, "IOTLB translation failure: uaddr "
825                         "%p size 0x%llx\n", addr,
826                         (unsigned long long) size);
827                 return NULL;
828         }
829
830         if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
831                 vq_err(vq, "Non atomic userspace memory access: uaddr "
832                         "%p size 0x%llx\n", addr,
833                         (unsigned long long) size);
834                 return NULL;
835         }
836
837         return vq->iotlb_iov[0].iov_base;
838 }
839
840 /* This function should be called after iotlb
841  * prefetch, which means we're sure that vq
842  * could be access through iotlb. So -EAGAIN should
843  * not happen in this case.
844  */
845 static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
846                                             void *addr, unsigned int size,
847                                             int type)
848 {
849         void __user *uaddr = vhost_vq_meta_fetch(vq,
850                              (u64)(uintptr_t)addr, size, type);
851         if (uaddr)
852                 return uaddr;
853
854         return __vhost_get_user_slow(vq, addr, size, type);
855 }
856
/*
 * Store @x at user pointer @ptr on behalf of @vq.  Without a device IOTLB
 * @ptr is written directly; otherwise it is first resolved through
 * __vhost_get_user() using the VHOST_ADDR_USED metadata slot.
 *
 * Evaluates to the __put_user() result, or -EFAULT if @ptr cannot be
 * resolved.  The @vq and @ptr arguments are parenthesized so that
 * arbitrary expressions expand correctly.
 */
#define vhost_put_user(vq, x, ptr)		\
({ \
	int ret; \
	if (!(vq)->iotlb) { \
		ret = __put_user(x, ptr); \
	} else { \
		__typeof__(ptr) to = \
			(__typeof__(ptr)) __vhost_get_user((vq), (ptr),	\
					  sizeof(*(ptr)), VHOST_ADDR_USED); \
		if (to != NULL) \
			ret = __put_user(x, to); \
		else \
			ret = -EFAULT;	\
	} \
	ret; \
})
873
/*
 * Load @x from user pointer @ptr on behalf of @vq.  Without a device IOTLB
 * @ptr is read directly; otherwise it is first resolved through
 * __vhost_get_user() using the metadata slot selected by @type.
 *
 * Evaluates to the __get_user() result, or -EFAULT if @ptr cannot be
 * resolved.  The @vq and @ptr arguments are parenthesized so that
 * arbitrary expressions expand correctly.
 */
#define vhost_get_user(vq, x, ptr, type)		\
({ \
	int ret; \
	if (!(vq)->iotlb) { \
		ret = __get_user(x, ptr); \
	} else { \
		__typeof__(ptr) from = \
			(__typeof__(ptr)) __vhost_get_user((vq), (ptr), \
							   sizeof(*(ptr)), \
							   type); \
		if (from != NULL) \
			ret = __get_user(x, from); \
		else \
			ret = -EFAULT; \
	} \
	ret; \
})
891
/* vhost_get_user() bound to the avail-ring / used-ring metadata slot. */
#define vhost_get_avail(vq, x, ptr) vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)

#define vhost_get_used(vq, x, ptr) vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
897
898 static int vhost_new_umem_range(struct vhost_umem *umem,
899                                 u64 start, u64 size, u64 end,
900                                 u64 userspace_addr, int perm)
901 {
902         struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
903
904         if (!node)
905                 return -ENOMEM;
906
907         if (umem->numem == max_iotlb_entries) {
908                 tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
909                 vhost_umem_free(umem, tmp);
910         }
911
912         node->start = start;
913         node->size = size;
914         node->last = end;
915         node->userspace_addr = userspace_addr;
916         node->perm = perm;
917         INIT_LIST_HEAD(&node->link);
918         list_add_tail(&node->link, &umem->umem_list);
919         vhost_umem_interval_tree_insert(node, &umem->umem_tree);
920         umem->numem++;
921
922         return 0;
923 }
924
925 static void vhost_del_umem_range(struct vhost_umem *umem,
926                                  u64 start, u64 end)
927 {
928         struct vhost_umem_node *node;
929
930         while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
931                                                            start, end)))
932                 vhost_umem_free(umem, node);
933 }
934
935 static void vhost_iotlb_notify_vq(struct vhost_dev *d,
936                                   struct vhost_iotlb_msg *msg)
937 {
938         struct vhost_msg_node *node, *n;
939
940         spin_lock(&d->iotlb_lock);
941
942         list_for_each_entry_safe(node, n, &d->pending_list, node) {
943                 struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
944                 if (msg->iova <= vq_msg->iova &&
945                     msg->iova + msg->size - 1 >= vq_msg->iova &&
946                     vq_msg->type == VHOST_IOTLB_MISS) {
947                         mutex_lock(&node->vq->mutex);
948                         vhost_poll_queue(&node->vq->poll);
949                         mutex_unlock(&node->vq->mutex);
950
951                         list_del(&node->node);
952                         kfree(node);
953                 }
954         }
955
956         spin_unlock(&d->iotlb_lock);
957 }
958
959 static bool umem_access_ok(u64 uaddr, u64 size, int access)
960 {
961         unsigned long a = uaddr;
962
963         /* Make sure 64 bit math will not overflow. */
964         if (vhost_overflow(uaddr, size))
965                 return false;
966
967         if ((access & VHOST_ACCESS_RO) &&
968             !access_ok(VERIFY_READ, (void __user *)a, size))
969                 return false;
970         if ((access & VHOST_ACCESS_WO) &&
971             !access_ok(VERIFY_WRITE, (void __user *)a, size))
972                 return false;
973         return true;
974 }
975
976 static int vhost_process_iotlb_msg(struct vhost_dev *dev,
977                                    struct vhost_iotlb_msg *msg)
978 {
979         int ret = 0;
980
981         mutex_lock(&dev->mutex);
982         switch (msg->type) {
983         case VHOST_IOTLB_UPDATE:
984                 if (!dev->iotlb) {
985                         ret = -EFAULT;
986                         break;
987                 }
988                 if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
989                         ret = -EFAULT;
990                         break;
991                 }
992                 vhost_vq_meta_reset(dev);
993                 if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
994                                          msg->iova + msg->size - 1,
995                                          msg->uaddr, msg->perm)) {
996                         ret = -ENOMEM;
997                         break;
998                 }
999                 vhost_iotlb_notify_vq(dev, msg);
1000                 break;
1001         case VHOST_IOTLB_INVALIDATE:
1002                 if (!dev->iotlb) {
1003                         ret = -EFAULT;
1004                         break;
1005                 }
1006                 vhost_vq_meta_reset(dev);
1007                 vhost_del_umem_range(dev->iotlb, msg->iova,
1008                                      msg->iova + msg->size - 1);
1009                 break;
1010         default:
1011                 ret = -EINVAL;
1012                 break;
1013         }
1014
1015         mutex_unlock(&dev->mutex);
1016
1017         return ret;
1018 }
1019 ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
1020                              struct iov_iter *from)
1021 {
1022         struct vhost_iotlb_msg msg;
1023         size_t offset;
1024         int type, ret;
1025
1026         ret = copy_from_iter(&type, sizeof(type), from);
1027         if (ret != sizeof(type))
1028                 goto done;
1029
1030         switch (type) {
1031         case VHOST_IOTLB_MSG:
1032                 /* There maybe a hole after type for V1 message type,
1033                  * so skip it here.
1034                  */
1035                 offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
1036                 break;
1037         case VHOST_IOTLB_MSG_V2:
1038                 offset = sizeof(__u32);
1039                 break;
1040         default:
1041                 ret = -EINVAL;
1042                 goto done;
1043         }
1044
1045         iov_iter_advance(from, offset);
1046         ret = copy_from_iter(&msg, sizeof(msg), from);
1047         if (ret != sizeof(msg))
1048                 goto done;
1049         if (vhost_process_iotlb_msg(dev, &msg)) {
1050                 ret = -EFAULT;
1051                 goto done;
1052         }
1053
1054         ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
1055               sizeof(struct vhost_msg_v2);
1056 done:
1057         return ret;
1058 }
1059 EXPORT_SYMBOL(vhost_chr_write_iter);
1060
1061 __poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1062                             poll_table *wait)
1063 {
1064         __poll_t mask = 0;
1065
1066         poll_wait(file, &dev->wait, wait);
1067
1068         if (!list_empty(&dev->read_list))
1069                 mask |= EPOLLIN | EPOLLRDNORM;
1070
1071         return mask;
1072 }
1073 EXPORT_SYMBOL(vhost_chr_poll);
1074
1075 ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
1076                             int noblock)
1077 {
1078         DEFINE_WAIT(wait);
1079         struct vhost_msg_node *node;
1080         ssize_t ret = 0;
1081         unsigned size = sizeof(struct vhost_msg);
1082
1083         if (iov_iter_count(to) < size)
1084                 return 0;
1085
1086         while (1) {
1087                 if (!noblock)
1088                         prepare_to_wait(&dev->wait, &wait,
1089                                         TASK_INTERRUPTIBLE);
1090
1091                 node = vhost_dequeue_msg(dev, &dev->read_list);
1092                 if (node)
1093                         break;
1094                 if (noblock) {
1095                         ret = -EAGAIN;
1096                         break;
1097                 }
1098                 if (signal_pending(current)) {
1099                         ret = -ERESTARTSYS;
1100                         break;
1101                 }
1102                 if (!dev->iotlb) {
1103                         ret = -EBADFD;
1104                         break;
1105                 }
1106
1107                 schedule();
1108         }
1109
1110         if (!noblock)
1111                 finish_wait(&dev->wait, &wait);
1112
1113         if (node) {
1114                 struct vhost_iotlb_msg *msg;
1115                 void *start = &node->msg;
1116
1117                 switch (node->msg.type) {
1118                 case VHOST_IOTLB_MSG:
1119                         size = sizeof(node->msg);
1120                         msg = &node->msg.iotlb;
1121                         break;
1122                 case VHOST_IOTLB_MSG_V2:
1123                         size = sizeof(node->msg_v2);
1124                         msg = &node->msg_v2.iotlb;
1125                         break;
1126                 default:
1127                         BUG();
1128                         break;
1129                 }
1130
1131                 ret = copy_to_iter(start, size, to);
1132                 if (ret != size || msg->type != VHOST_IOTLB_MISS) {
1133                         kfree(node);
1134                         return ret;
1135                 }
1136                 vhost_enqueue_msg(dev, &dev->pending_list, node);
1137         }
1138
1139         return ret;
1140 }
1141 EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
1142
1143 static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1144 {
1145         struct vhost_dev *dev = vq->dev;
1146         struct vhost_msg_node *node;
1147         struct vhost_iotlb_msg *msg;
1148         bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);
1149
1150         node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
1151         if (!node)
1152                 return -ENOMEM;
1153
1154         if (v2) {
1155                 node->msg_v2.type = VHOST_IOTLB_MSG_V2;
1156                 msg = &node->msg_v2.iotlb;
1157         } else {
1158                 msg = &node->msg.iotlb;
1159         }
1160
1161         msg->type = VHOST_IOTLB_MISS;
1162         msg->iova = iova;
1163         msg->perm = access;
1164
1165         vhost_enqueue_msg(dev, &dev->read_list, node);
1166
1167         return 0;
1168 }
1169
1170 static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
1171                          struct vring_desc __user *desc,
1172                          struct vring_avail __user *avail,
1173                          struct vring_used __user *used)
1174
1175 {
1176         size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1177
1178         return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
1179                access_ok(VERIFY_READ, avail,
1180                          sizeof *avail + num * sizeof *avail->ring + s) &&
1181                access_ok(VERIFY_WRITE, used,
1182                         sizeof *used + num * sizeof *used->ring + s);
1183 }
1184
1185 static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
1186                                  const struct vhost_umem_node *node,
1187                                  int type)
1188 {
1189         int access = (type == VHOST_ADDR_USED) ?
1190                      VHOST_ACCESS_WO : VHOST_ACCESS_RO;
1191
1192         if (likely(node->perm & access))
1193                 vq->meta_iotlb[type] = node;
1194 }
1195
1196 static bool iotlb_access_ok(struct vhost_virtqueue *vq,
1197                             int access, u64 addr, u64 len, int type)
1198 {
1199         const struct vhost_umem_node *node;
1200         struct vhost_umem *umem = vq->iotlb;
1201         u64 s = 0, size, orig_addr = addr, last = addr + len - 1;
1202
1203         if (vhost_vq_meta_fetch(vq, addr, len, type))
1204                 return true;
1205
1206         while (len > s) {
1207                 node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1208                                                            addr,
1209                                                            last);
1210                 if (node == NULL || node->start > addr) {
1211                         vhost_iotlb_miss(vq, addr, access);
1212                         return false;
1213                 } else if (!(node->perm & access)) {
1214                         /* Report the possible access violation by
1215                          * request another translation from userspace.
1216                          */
1217                         return false;
1218                 }
1219
1220                 size = node->size - addr + node->start;
1221
1222                 if (orig_addr == addr && size >= len)
1223                         vhost_vq_meta_update(vq, node, type);
1224
1225                 s += size;
1226                 addr += size;
1227         }
1228
1229         return true;
1230 }
1231
1232 int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
1233 {
1234         size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1235         unsigned int num = vq->num;
1236
1237         if (!vq->iotlb)
1238                 return 1;
1239
1240         return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
1241                                num * sizeof(*vq->desc), VHOST_ADDR_DESC) &&
1242                iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
1243                                sizeof *vq->avail +
1244                                num * sizeof(*vq->avail->ring) + s,
1245                                VHOST_ADDR_AVAIL) &&
1246                iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
1247                                sizeof *vq->used +
1248                                num * sizeof(*vq->used->ring) + s,
1249                                VHOST_ADDR_USED);
1250 }
1251 EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
1252
1253 /* Can we log writes? */
1254 /* Caller should have device mutex but not vq mutex */
1255 bool vhost_log_access_ok(struct vhost_dev *dev)
1256 {
1257         return memory_access_ok(dev, dev->umem, 1);
1258 }
1259 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
1260
1261 /* Verify access for write logging. */
1262 /* Caller should have vq mutex and device mutex */
1263 static bool vq_log_access_ok(struct vhost_virtqueue *vq,
1264                              void __user *log_base)
1265 {
1266         size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1267
1268         return vq_memory_access_ok(log_base, vq->umem,
1269                                    vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
1270                 (!vq->log_used || log_access_ok(log_base, vq->log_addr,
1271                                         sizeof *vq->used +
1272                                         vq->num * sizeof *vq->used->ring + s));
1273 }
1274
1275 /* Can we start vq? */
1276 /* Caller should have vq mutex and device mutex */
1277 bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
1278 {
1279         if (!vq_log_access_ok(vq, vq->log_base))
1280                 return false;
1281
1282         /* Access validation occurs at prefetch time with IOTLB */
1283         if (vq->iotlb)
1284                 return true;
1285
1286         return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
1287 }
1288 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
1289
1290 static struct vhost_umem *vhost_umem_alloc(void)
1291 {
1292         struct vhost_umem *umem = kvzalloc(sizeof(*umem), GFP_KERNEL);
1293
1294         if (!umem)
1295                 return NULL;
1296
1297         umem->umem_tree = RB_ROOT_CACHED;
1298         umem->numem = 0;
1299         INIT_LIST_HEAD(&umem->umem_list);
1300
1301         return umem;
1302 }
1303
1304 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
1305 {
1306         struct vhost_memory mem, *newmem;
1307         struct vhost_memory_region *region;
1308         struct vhost_umem *newumem, *oldumem;
1309         unsigned long size = offsetof(struct vhost_memory, regions);
1310         int i;
1311
1312         if (copy_from_user(&mem, m, size))
1313                 return -EFAULT;
1314         if (mem.padding)
1315                 return -EOPNOTSUPP;
1316         if (mem.nregions > max_mem_regions)
1317                 return -E2BIG;
1318         newmem = kvzalloc(struct_size(newmem, regions, mem.nregions),
1319                         GFP_KERNEL);
1320         if (!newmem)
1321                 return -ENOMEM;
1322
1323         memcpy(newmem, &mem, size);
1324         if (copy_from_user(newmem->regions, m->regions,
1325                            mem.nregions * sizeof *m->regions)) {
1326                 kvfree(newmem);
1327                 return -EFAULT;
1328         }
1329
1330         newumem = vhost_umem_alloc();
1331         if (!newumem) {
1332                 kvfree(newmem);
1333                 return -ENOMEM;
1334         }
1335
1336         for (region = newmem->regions;
1337              region < newmem->regions + mem.nregions;
1338              region++) {
1339                 if (vhost_new_umem_range(newumem,
1340                                          region->guest_phys_addr,
1341                                          region->memory_size,
1342                                          region->guest_phys_addr +
1343                                          region->memory_size - 1,
1344                                          region->userspace_addr,
1345                                          VHOST_ACCESS_RW))
1346                         goto err;
1347         }
1348
1349         if (!memory_access_ok(d, newumem, 0))
1350                 goto err;
1351
1352         oldumem = d->umem;
1353         d->umem = newumem;
1354
1355         /* All memory accesses are done under some VQ mutex. */
1356         for (i = 0; i < d->nvqs; ++i) {
1357                 mutex_lock(&d->vqs[i]->mutex);
1358                 d->vqs[i]->umem = newumem;
1359                 mutex_unlock(&d->vqs[i]->mutex);
1360         }
1361
1362         kvfree(newmem);
1363         vhost_umem_clean(oldumem);
1364         return 0;
1365
1366 err:
1367         vhost_umem_clean(newumem);
1368         kvfree(newmem);
1369         return -EFAULT;
1370 }
1371
/* Handle the per-virtqueue ioctls of device @d.
 *
 * @argp points at the ioctl argument; its first u32 is read as the
 * virtqueue index for every command. The selected vq's mutex is held while
 * the command is processed.
 *
 * Returns 0 on success or a negative errno: the get_user() error if the
 * index cannot be read, -ENOBUFS if the index is not below d->nvqs,
 * -ENOIOCTLCMD for an unknown @ioctl, and the command-specific errors set
 * in the cases below.
 */
long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
        struct file *eventfp, *filep = NULL;
        bool pollstart = false, pollstop = false;
        struct eventfd_ctx *ctx = NULL;
        u32 __user *idxp = argp;
        struct vhost_virtqueue *vq;
        struct vhost_vring_state s;
        struct vhost_vring_file f;
        struct vhost_vring_addr a;
        u32 idx;
        long r;

        r = get_user(idx, idxp);
        if (r < 0)
                return r;
        if (idx >= d->nvqs)
                return -ENOBUFS;

        /* Sanitize idx against speculative out-of-bounds use. */
        idx = array_index_nospec(idx, d->nvqs);
        vq = d->vqs[idx];

        mutex_lock(&vq->mutex);

        switch (ioctl) {
        case VHOST_SET_VRING_NUM:
                /* Resizing ring with an active backend?
                 * You don't want to do that. */
                if (vq->private_data) {
                        r = -EBUSY;
                        break;
                }
                if (copy_from_user(&s, argp, sizeof s)) {
                        r = -EFAULT;
                        break;
                }
                /* Size must be a non-zero power of two no larger than 0xffff. */
                if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
                        r = -EINVAL;
                        break;
                }
                vq->num = s.num;
                break;
        case VHOST_SET_VRING_BASE:
                /* Moving base with an active backend?
                 * You don't want to do that. */
                if (vq->private_data) {
                        r = -EBUSY;
                        break;
                }
                if (copy_from_user(&s, argp, sizeof s)) {
                        r = -EFAULT;
                        break;
                }
                if (s.num > 0xffff) {
                        r = -EINVAL;
                        break;
                }
                vq->last_avail_idx = s.num;
                /* Forget the cached index value. */
                vq->avail_idx = vq->last_avail_idx;
                break;
        case VHOST_GET_VRING_BASE:
                s.index = idx;
                s.num = vq->last_avail_idx;
                if (copy_to_user(argp, &s, sizeof s))
                        r = -EFAULT;
                break;
        case VHOST_SET_VRING_ADDR:
                if (copy_from_user(&a, argp, sizeof a)) {
                        r = -EFAULT;
                        break;
                }
                /* VHOST_VRING_F_LOG is the only flag accepted here. */
                if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
                        r = -EOPNOTSUPP;
                        break;
                }
                /* For 32bit, verify that the top 32bits of the user
                   data are set to zero. */
                if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
                    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
                    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) {
                        r = -EFAULT;
                        break;
                }

                /* Make sure it's safe to cast pointers to vring types. */
                BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
                BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
                if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
                    (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
                    (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1))) {
                        r = -EINVAL;
                        break;
                }

                /* We only verify access here if backend is configured.
                 * If it is not, we don't as size might not have been setup.
                 * We will verify when backend is configured. */
                if (vq->private_data) {
                        if (!vq_access_ok(vq, vq->num,
                                (void __user *)(unsigned long)a.desc_user_addr,
                                (void __user *)(unsigned long)a.avail_user_addr,
                                (void __user *)(unsigned long)a.used_user_addr)) {
                                r = -EINVAL;
                                break;
                        }

                        /* Also validate log access for used ring if enabled. */
                        if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
                            !log_access_ok(vq->log_base, a.log_guest_addr,
                                           sizeof *vq->used +
                                           vq->num * sizeof *vq->used->ring)) {
                                r = -EINVAL;
                                break;
                        }
                }

                /* All checks passed: commit the new ring addresses. */
                vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
                vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
                vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
                vq->log_addr = a.log_guest_addr;
                vq->used = (void __user *)(unsigned long)a.used_user_addr;
                break;
        case VHOST_SET_VRING_KICK:
                if (copy_from_user(&f, argp, sizeof f)) {
                        r = -EFAULT;
                        break;
                }
                /* An fd of -1 means "no kick file". */
                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
                if (IS_ERR(eventfp)) {
                        r = PTR_ERR(eventfp);
                        break;
                }
                /* filep receives the file that is fput() after the switch:
                 * the old kick file when it is being replaced, or the file
                 * just looked up when it is the one already installed.
                 */
                if (eventfp != vq->kick) {
                        pollstop = (filep = vq->kick) != NULL;
                        pollstart = (vq->kick = eventfp) != NULL;
                } else
                        filep = eventfp;
                break;
        case VHOST_SET_VRING_CALL:
                if (copy_from_user(&f, argp, sizeof f)) {
                        r = -EFAULT;
                        break;
                }
                ctx = f.fd == -1 ? NULL : eventfd_ctx_fdget(f.fd);
                if (IS_ERR(ctx)) {
                        r = PTR_ERR(ctx);
                        break;
                }
                /* After the swap ctx holds the previous context, which is
                 * put after the switch.
                 */
                swap(ctx, vq->call_ctx);
                break;
        case VHOST_SET_VRING_ERR:
                if (copy_from_user(&f, argp, sizeof f)) {
                        r = -EFAULT;
                        break;
                }
                ctx = f.fd == -1 ? NULL : eventfd_ctx_fdget(f.fd);
                if (IS_ERR(ctx)) {
                        r = PTR_ERR(ctx);
                        break;
                }
                /* As for VHOST_SET_VRING_CALL: ctx now holds the old context. */
                swap(ctx, vq->error_ctx);
                break;
        case VHOST_SET_VRING_ENDIAN:
                r = vhost_set_vring_endian(vq, argp);
                break;
        case VHOST_GET_VRING_ENDIAN:
                r = vhost_get_vring_endian(vq, idx, argp);
                break;
        case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
                if (copy_from_user(&s, argp, sizeof(s))) {
                        r = -EFAULT;
                        break;
                }
                vq->busyloop_timeout = s.num;
                break;
        case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
                s.index = idx;
                s.num = vq->busyloop_timeout;
                if (copy_to_user(argp, &s, sizeof(s)))
                        r = -EFAULT;
                break;
        default:
                r = -ENOIOCTLCMD;
        }

        /* Stop polling the old kick file before its reference is dropped. */
        if (pollstop && vq->handle_kick)
                vhost_poll_stop(&vq->poll);

        /* ctx may be NULL or an ERR_PTR on the paths that broke out early. */
        if (!IS_ERR_OR_NULL(ctx))
                eventfd_ctx_put(ctx);
        if (filep)
                fput(filep);

        /* The result of starting the poll on the new kick file replaces r. */
        if (pollstart && vq->handle_kick)
                r = vhost_poll_start(&vq->poll, vq->kick);

        mutex_unlock(&vq->mutex);

        /* The flush is issued only after the vq mutex has been released. */
        if (pollstop && vq->handle_kick)
                vhost_poll_flush(&vq->poll);
        return r;
}
EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
1576
1577 int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
1578 {
1579         struct vhost_umem *niotlb, *oiotlb;
1580         int i;
1581
1582         niotlb = vhost_umem_alloc();
1583         if (!niotlb)
1584                 return -ENOMEM;
1585
1586         oiotlb = d->iotlb;
1587         d->iotlb = niotlb;
1588
1589         for (i = 0; i < d->nvqs; ++i) {
1590                 struct vhost_virtqueue *vq = d->vqs[i];
1591
1592                 mutex_lock(&vq->mutex);
1593                 vq->iotlb = niotlb;
1594                 __vhost_vq_meta_reset(vq);
1595                 mutex_unlock(&vq->mutex);
1596         }
1597
1598         vhost_umem_clean(oiotlb);
1599
1600         return 0;
1601 }
1602 EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
1603
1604 /* Caller must have device mutex */
1605 long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
1606 {
1607         struct eventfd_ctx *ctx;
1608         u64 p;
1609         long r;
1610         int i, fd;
1611
1612         /* If you are not the owner, you can become one */
1613         if (ioctl == VHOST_SET_OWNER) {
1614                 r = vhost_dev_set_owner(d);
1615                 goto done;
1616         }
1617
1618         /* You must be the owner to do anything else */
1619         r = vhost_dev_check_owner(d);
1620         if (r)
1621                 goto done;
1622
1623         switch (ioctl) {
1624         case VHOST_SET_MEM_TABLE:
1625                 r = vhost_set_memory(d, argp);
1626                 break;
1627         case VHOST_SET_LOG_BASE:
1628                 if (copy_from_user(&p, argp, sizeof p)) {
1629                         r = -EFAULT;
1630                         break;
1631                 }
1632                 if ((u64)(unsigned long)p != p) {
1633                         r = -EFAULT;
1634                         break;
1635                 }
1636                 for (i = 0; i < d->nvqs; ++i) {
1637                         struct vhost_virtqueue *vq;
1638                         void __user *base = (void __user *)(unsigned long)p;
1639                         vq = d->vqs[i];
1640                         mutex_lock(&vq->mutex);
1641                         /* If ring is inactive, will check when it's enabled. */
1642                         if (vq->private_data && !vq_log_access_ok(vq, base))
1643                                 r = -EFAULT;
1644                         else
1645                                 vq->log_base = base;
1646                         mutex_unlock(&vq->mutex);
1647                 }
1648                 break;
1649         case VHOST_SET_LOG_FD:
1650                 r = get_user(fd, (int __user *)argp);
1651                 if (r < 0)
1652                         break;
1653                 ctx = fd == -1 ? NULL : eventfd_ctx_fdget(fd);
1654                 if (IS_ERR(ctx)) {
1655                         r = PTR_ERR(ctx);
1656                         break;
1657                 }
1658                 swap(ctx, d->log_ctx);
1659                 for (i = 0; i < d->nvqs; ++i) {
1660                         mutex_lock(&d->vqs[i]->mutex);
1661                         d->vqs[i]->log_ctx = d->log_ctx;
1662                         mutex_unlock(&d->vqs[i]->mutex);
1663                 }
1664                 if (ctx)
1665                         eventfd_ctx_put(ctx);
1666                 break;
1667         default:
1668                 r = -ENOIOCTLCMD;
1669                 break;
1670         }
1671 done:
1672         return r;
1673 }
1674 EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
1675
1676 /* TODO: This is really inefficient.  We need something like get_user()
1677  * (instruction directly accesses the data, with an exception table entry
1678  * returning -EFAULT). See Documentation/x86/exception-tables.txt.
1679  */
1680 static int set_bit_to_user(int nr, void __user *addr)
1681 {
1682         unsigned long log = (unsigned long)addr;
1683         struct page *page;
1684         void *base;
1685         int bit = nr + (log % PAGE_SIZE) * 8;
1686         int r;
1687
1688         r = get_user_pages_fast(log, 1, 1, &page);
1689         if (r < 0)
1690                 return r;
1691         BUG_ON(r != 1);
1692         base = kmap_atomic(page);
1693         set_bit(bit, base);
1694         kunmap_atomic(base);
1695         set_page_dirty_lock(page);
1696         put_page(page);
1697         return 0;
1698 }
1699
1700 static int log_write(void __user *log_base,
1701                      u64 write_address, u64 write_length)
1702 {
1703         u64 write_page = write_address / VHOST_PAGE_SIZE;
1704         int r;
1705
1706         if (!write_length)
1707                 return 0;
1708         write_length += write_address % VHOST_PAGE_SIZE;
1709         for (;;) {
1710                 u64 base = (u64)(unsigned long)log_base;
1711                 u64 log = base + write_page / 8;
1712                 int bit = write_page % 8;
1713                 if ((u64)(unsigned long)log != log)
1714                         return -EFAULT;
1715                 r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
1716                 if (r < 0)
1717                         return r;
1718                 if (write_length <= VHOST_PAGE_SIZE)
1719                         break;
1720                 write_length -= VHOST_PAGE_SIZE;
1721                 write_page += 1;
1722         }
1723         return r;
1724 }
1725
1726 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
1727                     unsigned int log_num, u64 len)
1728 {
1729         int i, r;
1730
1731         /* Make sure data written is seen before log. */
1732         smp_wmb();
1733         for (i = 0; i < log_num; ++i) {
1734                 u64 l = min(log[i].len, len);
1735                 r = log_write(vq->log_base, log[i].addr, l);
1736                 if (r < 0)
1737                         return r;
1738                 len -= l;
1739                 if (!len) {
1740                         if (vq->log_ctx)
1741                                 eventfd_signal(vq->log_ctx, 1);
1742                         return 0;
1743                 }
1744         }
1745         /* Length written exceeds what we have stored. This is a bug. */
1746         BUG();
1747         return 0;
1748 }
1749 EXPORT_SYMBOL_GPL(vhost_log_write);
1750
1751 static int vhost_update_used_flags(struct vhost_virtqueue *vq)
1752 {
1753         void __user *used;
1754         if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
1755                            &vq->used->flags) < 0)
1756                 return -EFAULT;
1757         if (unlikely(vq->log_used)) {
1758                 /* Make sure the flag is seen before log. */
1759                 smp_wmb();
1760                 /* Log used flag write. */
1761                 used = &vq->used->flags;
1762                 log_write(vq->log_base, vq->log_addr +
1763                           (used - (void __user *)vq->used),
1764                           sizeof vq->used->flags);
1765                 if (vq->log_ctx)
1766                         eventfd_signal(vq->log_ctx, 1);
1767         }
1768         return 0;
1769 }
1770
1771 static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
1772 {
1773         if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
1774                            vhost_avail_event(vq)))
1775                 return -EFAULT;
1776         if (unlikely(vq->log_used)) {
1777                 void __user *used;
1778                 /* Make sure the event is seen before log. */
1779                 smp_wmb();
1780                 /* Log avail event write */
1781                 used = vhost_avail_event(vq);
1782                 log_write(vq->log_base, vq->log_addr +
1783                           (used - (void __user *)vq->used),
1784                           sizeof *vhost_avail_event(vq));
1785                 if (vq->log_ctx)
1786                         eventfd_signal(vq->log_ctx, 1);
1787         }
1788         return 0;
1789 }
1790
1791 int vhost_vq_init_access(struct vhost_virtqueue *vq)
1792 {
1793         __virtio16 last_used_idx;
1794         int r;
1795         bool is_le = vq->is_le;
1796
1797         if (!vq->private_data)
1798                 return 0;
1799
1800         vhost_init_is_le(vq);
1801
1802         r = vhost_update_used_flags(vq);
1803         if (r)
1804                 goto err;
1805         vq->signalled_used_valid = false;
1806         if (!vq->iotlb &&
1807             !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
1808                 r = -EFAULT;
1809                 goto err;
1810         }
1811         r = vhost_get_used(vq, last_used_idx, &vq->used->idx);
1812         if (r) {
1813                 vq_err(vq, "Can't access used idx at %p\n",
1814                        &vq->used->idx);
1815                 goto err;
1816         }
1817         vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
1818         return 0;
1819
1820 err:
1821         vq->is_le = is_le;
1822         return r;
1823 }
1824 EXPORT_SYMBOL_GPL(vhost_vq_init_access);
1825
1826 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
1827                           struct iovec iov[], int iov_size, int access)
1828 {
1829         const struct vhost_umem_node *node;
1830         struct vhost_dev *dev = vq->dev;
1831         struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
1832         struct iovec *_iov;
1833         u64 s = 0;
1834         int ret = 0;
1835
1836         while ((u64)len > s) {
1837                 u64 size;
1838                 if (unlikely(ret >= iov_size)) {
1839                         ret = -ENOBUFS;
1840                         break;
1841                 }
1842
1843                 node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1844                                                         addr, addr + len - 1);
1845                 if (node == NULL || node->start > addr) {
1846                         if (umem != dev->iotlb) {
1847                                 ret = -EFAULT;
1848                                 break;
1849                         }
1850                         ret = -EAGAIN;
1851                         break;
1852                 } else if (!(node->perm & access)) {
1853                         ret = -EPERM;
1854                         break;
1855                 }
1856
1857                 _iov = iov + ret;
1858                 size = node->size - addr + node->start;
1859                 _iov->iov_len = min((u64)len - s, size);
1860                 _iov->iov_base = (void __user *)(unsigned long)
1861                         (node->userspace_addr + addr - node->start);
1862                 s += size;
1863                 addr += size;
1864                 ++ret;
1865         }
1866
1867         if (ret == -EAGAIN)
1868                 vhost_iotlb_miss(vq, addr, access);
1869         return ret;
1870 }
1871
1872 /* Each buffer in the virtqueues is actually a chain of descriptors.  This
1873  * function returns the next descriptor in the chain,
1874  * or -1U if we're at the end. */
1875 static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
1876 {
1877         unsigned int next;
1878
1879         /* If this descriptor says it doesn't chain, we're done. */
1880         if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT)))
1881                 return -1U;
1882
1883         /* Check they're not leading us off end of descriptors. */
1884         next = vhost16_to_cpu(vq, READ_ONCE(desc->next));
1885         return next;
1886 }
1887
1888 static int get_indirect(struct vhost_virtqueue *vq,
1889                         struct iovec iov[], unsigned int iov_size,
1890                         unsigned int *out_num, unsigned int *in_num,
1891                         struct vhost_log *log, unsigned int *log_num,
1892                         struct vring_desc *indirect)
1893 {
1894         struct vring_desc desc;
1895         unsigned int i = 0, count, found = 0;
1896         u32 len = vhost32_to_cpu(vq, indirect->len);
1897         struct iov_iter from;
1898         int ret, access;
1899
1900         /* Sanity check */
1901         if (unlikely(len % sizeof desc)) {
1902                 vq_err(vq, "Invalid length in indirect descriptor: "
1903                        "len 0x%llx not multiple of 0x%zx\n",
1904                        (unsigned long long)len,
1905                        sizeof desc);
1906                 return -EINVAL;
1907         }
1908
1909         ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
1910                              UIO_MAXIOV, VHOST_ACCESS_RO);
1911         if (unlikely(ret < 0)) {
1912                 if (ret != -EAGAIN)
1913                         vq_err(vq, "Translation failure %d in indirect.\n", ret);
1914                 return ret;
1915         }
1916         iov_iter_init(&from, READ, vq->indirect, ret, len);
1917
1918         /* We will use the result as an address to read from, so most
1919          * architectures only need a compiler barrier here. */
1920         read_barrier_depends();
1921
1922         count = len / sizeof desc;
1923         /* Buffers are chained via a 16 bit next field, so
1924          * we can have at most 2^16 of these. */
1925         if (unlikely(count > USHRT_MAX + 1)) {
1926                 vq_err(vq, "Indirect buffer length too big: %d\n",
1927                        indirect->len);
1928                 return -E2BIG;
1929         }
1930
1931         do {
1932                 unsigned iov_count = *in_num + *out_num;
1933                 if (unlikely(++found > count)) {
1934                         vq_err(vq, "Loop detected: last one at %u "
1935                                "indirect size %u\n",
1936                                i, count);
1937                         return -EINVAL;
1938                 }
1939                 if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) {
1940                         vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
1941                                i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
1942                         return -EINVAL;
1943                 }
1944                 if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
1945                         vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
1946                                i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
1947                         return -EINVAL;
1948                 }
1949
1950                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
1951                         access = VHOST_ACCESS_WO;
1952                 else
1953                         access = VHOST_ACCESS_RO;
1954
1955                 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
1956                                      vhost32_to_cpu(vq, desc.len), iov + iov_count,
1957                                      iov_size - iov_count, access);
1958                 if (unlikely(ret < 0)) {
1959                         if (ret != -EAGAIN)
1960                                 vq_err(vq, "Translation failure %d indirect idx %d\n",
1961                                         ret, i);
1962                         return ret;
1963                 }
1964                 /* If this is an input descriptor, increment that count. */
1965                 if (access == VHOST_ACCESS_WO) {
1966                         *in_num += ret;
1967                         if (unlikely(log)) {
1968                                 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
1969                                 log[*log_num].len = vhost32_to_cpu(vq, desc.len);
1970                                 ++*log_num;
1971                         }
1972                 } else {
1973                         /* If it's an output descriptor, they're all supposed
1974                          * to come before any input descriptors. */
1975                         if (unlikely(*in_num)) {
1976                                 vq_err(vq, "Indirect descriptor "
1977                                        "has out after in: idx %d\n", i);
1978                                 return -EINVAL;
1979                         }
1980                         *out_num += ret;
1981                 }
1982         } while ((i = next_desc(vq, &desc)) != -1);
1983         return 0;
1984 }
1985
1986 /* This looks in the virtqueue and for the first available buffer, and converts
1987  * it to an iovec for convenient access.  Since descriptors consist of some
1988  * number of output then some number of input descriptors, it's actually two
1989  * iovecs, but we pack them into one and note how many of each there were.
1990  *
1991  * This function returns the descriptor number found, or vq->num (which is
1992  * never a valid descriptor number) if none was found.  A negative code is
1993  * returned on error. */
1994 int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1995                       struct iovec iov[], unsigned int iov_size,
1996                       unsigned int *out_num, unsigned int *in_num,
1997                       struct vhost_log *log, unsigned int *log_num)
1998 {
1999         struct vring_desc desc;
2000         unsigned int i, head, found = 0;
2001         u16 last_avail_idx;
2002         __virtio16 avail_idx;
2003         __virtio16 ring_head;
2004         int ret, access;
2005
2006         /* Check it isn't doing very strange things with descriptor numbers. */
2007         last_avail_idx = vq->last_avail_idx;
2008
2009         if (vq->avail_idx == vq->last_avail_idx) {
2010                 if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) {
2011                         vq_err(vq, "Failed to access avail idx at %p\n",
2012                                 &vq->avail->idx);
2013                         return -EFAULT;
2014                 }
2015                 vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2016
2017                 if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
2018                         vq_err(vq, "Guest moved used index from %u to %u",
2019                                 last_avail_idx, vq->avail_idx);
2020                         return -EFAULT;
2021                 }
2022
2023                 /* If there's nothing new since last we looked, return
2024                  * invalid.
2025                  */
2026                 if (vq->avail_idx == last_avail_idx)
2027                         return vq->num;
2028
2029                 /* Only get avail ring entries after they have been
2030                  * exposed by guest.
2031                  */
2032                 smp_rmb();
2033         }
2034
2035         /* Grab the next descriptor number they're advertising, and increment
2036          * the index we've seen. */
2037         if (unlikely(vhost_get_avail(vq, ring_head,
2038                      &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
2039                 vq_err(vq, "Failed to read head: idx %d address %p\n",
2040                        last_avail_idx,
2041                        &vq->avail->ring[last_avail_idx % vq->num]);
2042                 return -EFAULT;
2043         }
2044
2045         head = vhost16_to_cpu(vq, ring_head);
2046
2047         /* If their number is silly, that's an error. */
2048         if (unlikely(head >= vq->num)) {
2049                 vq_err(vq, "Guest says index %u > %u is available",
2050                        head, vq->num);
2051                 return -EINVAL;
2052         }
2053
2054         /* When we start there are none of either input nor output. */
2055         *out_num = *in_num = 0;
2056         if (unlikely(log))
2057                 *log_num = 0;
2058
2059         i = head;
2060         do {
2061                 unsigned iov_count = *in_num + *out_num;
2062                 if (unlikely(i >= vq->num)) {
2063                         vq_err(vq, "Desc index is %u > %u, head = %u",
2064                                i, vq->num, head);
2065                         return -EINVAL;
2066                 }
2067                 if (unlikely(++found > vq->num)) {
2068                         vq_err(vq, "Loop detected: last one at %u "
2069                                "vq size %u head %u\n",
2070                                i, vq->num, head);
2071                         return -EINVAL;
2072                 }
2073                 ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
2074                                            sizeof desc);
2075                 if (unlikely(ret)) {
2076                         vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
2077                                i, vq->desc + i);
2078                         return -EFAULT;
2079                 }
2080                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
2081                         ret = get_indirect(vq, iov, iov_size,
2082                                            out_num, in_num,
2083                                            log, log_num, &desc);
2084                         if (unlikely(ret < 0)) {
2085                                 if (ret != -EAGAIN)
2086                                         vq_err(vq, "Failure detected "
2087                                                 "in indirect descriptor at idx %d\n", i);
2088                                 return ret;
2089                         }
2090                         continue;
2091                 }
2092
2093                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2094                         access = VHOST_ACCESS_WO;
2095                 else
2096                         access = VHOST_ACCESS_RO;
2097                 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2098                                      vhost32_to_cpu(vq, desc.len), iov + iov_count,
2099                                      iov_size - iov_count, access);
2100                 if (unlikely(ret < 0)) {
2101                         if (ret != -EAGAIN)
2102                                 vq_err(vq, "Translation failure %d descriptor idx %d\n",
2103                                         ret, i);
2104                         return ret;
2105                 }
2106                 if (access == VHOST_ACCESS_WO) {
2107                         /* If this is an input descriptor,
2108                          * increment that count. */
2109                         *in_num += ret;
2110                         if (unlikely(log)) {
2111                                 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2112                                 log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2113                                 ++*log_num;
2114                         }
2115                 } else {
2116                         /* If it's an output descriptor, they're all supposed
2117                          * to come before any input descriptors. */
2118                         if (unlikely(*in_num)) {
2119                                 vq_err(vq, "Descriptor has out after in: "
2120                                        "idx %d\n", i);
2121                                 return -EINVAL;
2122                         }
2123                         *out_num += ret;
2124                 }
2125         } while ((i = next_desc(vq, &desc)) != -1);
2126
2127         /* On success, increment avail index. */
2128         vq->last_avail_idx++;
2129
2130         /* Assume notifications from guest are disabled at this point,
2131          * if they aren't we would need to update avail_event index. */
2132         BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
2133         return head;
2134 }
2135 EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
2136
2137 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
2138 void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
2139 {
2140         vq->last_avail_idx -= n;
2141 }
2142 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
2143
2144 /* After we've used one of their buffers, we tell them about it.  We'll then
2145  * want to notify the guest, using eventfd. */
2146 int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
2147 {
2148         struct vring_used_elem heads = {
2149                 cpu_to_vhost32(vq, head),
2150                 cpu_to_vhost32(vq, len)
2151         };
2152
2153         return vhost_add_used_n(vq, &heads, 1);
2154 }
2155 EXPORT_SYMBOL_GPL(vhost_add_used);
2156
2157 static int __vhost_add_used_n(struct vhost_virtqueue *vq,
2158                             struct vring_used_elem *heads,
2159                             unsigned count)
2160 {
2161         struct vring_used_elem __user *used;
2162         u16 old, new;
2163         int start;
2164
2165         start = vq->last_used_idx & (vq->num - 1);
2166         used = vq->used->ring + start;
2167         if (count == 1) {
2168                 if (vhost_put_user(vq, heads[0].id, &used->id)) {
2169                         vq_err(vq, "Failed to write used id");
2170                         return -EFAULT;
2171                 }
2172                 if (vhost_put_user(vq, heads[0].len, &used->len)) {
2173                         vq_err(vq, "Failed to write used len");
2174                         return -EFAULT;
2175                 }
2176         } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
2177                 vq_err(vq, "Failed to write used");
2178                 return -EFAULT;
2179         }
2180         if (unlikely(vq->log_used)) {
2181                 /* Make sure data is seen before log. */
2182                 smp_wmb();
2183                 /* Log used ring entry write. */
2184                 log_write(vq->log_base,
2185                           vq->log_addr +
2186                            ((void __user *)used - (void __user *)vq->used),
2187                           count * sizeof *used);
2188         }
2189         old = vq->last_used_idx;
2190         new = (vq->last_used_idx += count);
2191         /* If the driver never bothers to signal in a very long while,
2192          * used index might wrap around. If that happens, invalidate
2193          * signalled_used index we stored. TODO: make sure driver
2194          * signals at least once in 2^16 and remove this. */
2195         if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
2196                 vq->signalled_used_valid = false;
2197         return 0;
2198 }
2199
2200 /* After we've used one of their buffers, we tell them about it.  We'll then
2201  * want to notify the guest, using eventfd. */
2202 int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
2203                      unsigned count)
2204 {
2205         int start, n, r;
2206
2207         start = vq->last_used_idx & (vq->num - 1);
2208         n = vq->num - start;
2209         if (n < count) {
2210                 r = __vhost_add_used_n(vq, heads, n);
2211                 if (r < 0)
2212                         return r;
2213                 heads += n;
2214                 count -= n;
2215         }
2216         r = __vhost_add_used_n(vq, heads, count);
2217
2218         /* Make sure buffer is written before we update index. */
2219         smp_wmb();
2220         if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
2221                            &vq->used->idx)) {
2222                 vq_err(vq, "Failed to increment used idx");
2223                 return -EFAULT;
2224         }
2225         if (unlikely(vq->log_used)) {
2226                 /* Log used index update. */
2227                 log_write(vq->log_base,
2228                           vq->log_addr + offsetof(struct vring_used, idx),
2229                           sizeof vq->used->idx);
2230                 if (vq->log_ctx)
2231                         eventfd_signal(vq->log_ctx, 1);
2232         }
2233         return r;
2234 }
2235 EXPORT_SYMBOL_GPL(vhost_add_used_n);
2236
2237 static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2238 {
2239         __u16 old, new;
2240         __virtio16 event;
2241         bool v;
2242         /* Flush out used index updates. This is paired
2243          * with the barrier that the Guest executes when enabling
2244          * interrupts. */
2245         smp_mb();
2246
2247         if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2248             unlikely(vq->avail_idx == vq->last_avail_idx))
2249                 return true;
2250
2251         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2252                 __virtio16 flags;
2253                 if (vhost_get_avail(vq, flags, &vq->avail->flags)) {
2254                         vq_err(vq, "Failed to get flags");
2255                         return true;
2256                 }
2257                 return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
2258         }
2259         old = vq->signalled_used;
2260         v = vq->signalled_used_valid;
2261         new = vq->signalled_used = vq->last_used_idx;
2262         vq->signalled_used_valid = true;
2263
2264         if (unlikely(!v))
2265                 return true;
2266
2267         if (vhost_get_avail(vq, event, vhost_used_event(vq))) {
2268                 vq_err(vq, "Failed to get used event idx");
2269                 return true;
2270         }
2271         return vring_need_event(vhost16_to_cpu(vq, event), new, old);
2272 }
2273
2274 /* This actually signals the guest, using eventfd. */
2275 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2276 {
2277         /* Signal the Guest tell them we used something up. */
2278         if (vq->call_ctx && vhost_notify(dev, vq))
2279                 eventfd_signal(vq->call_ctx, 1);
2280 }
2281 EXPORT_SYMBOL_GPL(vhost_signal);
2282
/* Convenience wrapper: return one buffer with vhost_add_used(), then call
 * vhost_signal().  The result of vhost_add_used() is not checked. */
void vhost_add_used_and_signal(struct vhost_dev *dev,
                               struct vhost_virtqueue *vq,
                               unsigned int head, int len)
{
        /* Return the buffer first ... */
        vhost_add_used(vq, head, len);

        /* ... then let the guest know, if it wants to. */
        vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
2292
/* Multi-buffer version of vhost_add_used_and_signal(): return @count
 * buffers with vhost_add_used_n(), then call vhost_signal().  The result of
 * vhost_add_used_n() is not checked. */
void vhost_add_used_and_signal_n(struct vhost_dev *dev,
                                 struct vhost_virtqueue *vq,
                                 struct vring_used_elem *heads, unsigned count)
{
        /* Return the whole batch first ... */
        vhost_add_used_n(vq, heads, count);

        /* ... then let the guest know, if it wants to. */
        vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
2302
2303 /* return true if we're sure that avaiable ring is empty */
2304 bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2305 {
2306         __virtio16 avail_idx;
2307         int r;
2308
2309         if (vq->avail_idx != vq->last_avail_idx)
2310                 return false;
2311
2312         r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
2313         if (unlikely(r))
2314                 return false;
2315         vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2316
2317         return vq->avail_idx == vq->last_avail_idx;
2318 }
2319 EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
2320
2321 /* OK, now we need to know about added descriptors. */
2322 bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2323 {
2324         __virtio16 avail_idx;
2325         int r;
2326
2327         if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
2328                 return false;
2329         vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
2330         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2331                 r = vhost_update_used_flags(vq);
2332                 if (r) {
2333                         vq_err(vq, "Failed to enable notification at %p: %d\n",
2334                                &vq->used->flags, r);
2335                         return false;
2336                 }
2337         } else {
2338                 r = vhost_update_avail_event(vq, vq->avail_idx);
2339                 if (r) {
2340                         vq_err(vq, "Failed to update avail event index at %p: %d\n",
2341                                vhost_avail_event(vq), r);
2342                         return false;
2343                 }
2344         }
2345         /* They could have slipped one in as we were doing that: make
2346          * sure it's written, then check again. */
2347         smp_mb();
2348         r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
2349         if (r) {
2350                 vq_err(vq, "Failed to check avail idx at %p: %d\n",
2351                        &vq->avail->idx, r);
2352                 return false;
2353         }
2354
2355         return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx;
2356 }
2357 EXPORT_SYMBOL_GPL(vhost_enable_notify);
2358
2359 /* We don't need to be notified again. */
2360 void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2361 {
2362         int r;
2363
2364         if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
2365                 return;
2366         vq->used_flags |= VRING_USED_F_NO_NOTIFY;
2367         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2368                 r = vhost_update_used_flags(vq);
2369                 if (r)
2370                         vq_err(vq, "Failed to enable notification at %p: %d\n",
2371                                &vq->used->flags, r);
2372         }
2373 }
2374 EXPORT_SYMBOL_GPL(vhost_disable_notify);
2375
2376 /* Create a new message. */
2377 struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2378 {
2379         struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
2380         if (!node)
2381                 return NULL;
2382
2383         /* Make sure all padding within the structure is initialized. */
2384         memset(&node->msg, 0, sizeof node->msg);
2385         node->vq = vq;
2386         node->msg.type = type;
2387         return node;
2388 }
2389 EXPORT_SYMBOL_GPL(vhost_new_msg);
2390
2391 void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2392                        struct vhost_msg_node *node)
2393 {
2394         spin_lock(&dev->iotlb_lock);
2395         list_add_tail(&node->node, head);
2396         spin_unlock(&dev->iotlb_lock);
2397
2398         wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
2399 }
2400 EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
2401
2402 struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2403                                          struct list_head *head)
2404 {
2405         struct vhost_msg_node *node = NULL;
2406
2407         spin_lock(&dev->iotlb_lock);
2408         if (!list_empty(head)) {
2409                 node = list_first_entry(head, struct vhost_msg_node,
2410                                         node);
2411                 list_del(&node->node);
2412         }
2413         spin_unlock(&dev->iotlb_lock);
2414
2415         return node;
2416 }
2417 EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2418
2419
2420 static int __init vhost_init(void)
2421 {
2422         return 0;
2423 }
2424
2425 static void __exit vhost_exit(void)
2426 {
2427 }
2428
/* Register the load and unload hooks defined above. */
module_init(vhost_init);
module_exit(vhost_exit);

/* Module metadata. */
MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio");