diff --git a/drivers/block/nbd.c b/drivers/block/nbd.c index d61a04155d99..438f4dc549db 100644 --- a/drivers/block/nbd.c +++ b/drivers/block/nbd.c @@ -60,7 +60,8 @@ struct nbd_device { bool disconnect; /* a disconnect has been requested by user */ struct timer_list timeout_timer; - spinlock_t tasks_lock; + /* protects initialization and shutdown of the socket */ + spinlock_t sock_lock; struct task_struct *task_recv; struct task_struct *task_send; @@ -129,13 +130,20 @@ static void nbd_end_request(struct nbd_device *nbd, struct request *req) */ static void sock_shutdown(struct nbd_device *nbd) { - if (!nbd->sock) + spin_lock_irq(&nbd->sock_lock); + + if (!nbd->sock) { + spin_unlock_irq(&nbd->sock_lock); return; + } dev_warn(disk_to_dev(nbd->disk), "shutting down socket\n"); kernel_sock_shutdown(nbd->sock, SHUT_RDWR); + sockfd_put(nbd->sock); nbd->sock = NULL; - del_timer_sync(&nbd->timeout_timer); + spin_unlock_irq(&nbd->sock_lock); + + del_timer(&nbd->timeout_timer); } static void nbd_xmit_timeout(unsigned long arg) @@ -148,17 +156,15 @@ static void nbd_xmit_timeout(unsigned long arg) nbd->disconnect = true; - spin_lock_irqsave(&nbd->tasks_lock, flags); + spin_lock_irqsave(&nbd->sock_lock, flags); - if (nbd->task_recv) - force_sig(SIGKILL, nbd->task_recv); - if (nbd->task_send) - force_sig(SIGKILL, nbd->task_send); + if (nbd->sock) + kernel_sock_shutdown(nbd->sock, SHUT_RDWR); - spin_unlock_irqrestore(&nbd->tasks_lock, flags); + spin_unlock_irqrestore(&nbd->sock_lock, flags); - dev_err(nbd_to_dev(nbd), "Connection timed out, killed receiver and sender, shutting down connection\n"); + dev_err(nbd_to_dev(nbd), "Connection timed out, shutting down connection\n"); } /* @@ -171,7 +177,6 @@ static int sock_xmit(struct nbd_device *nbd, int send, void *buf, int size, int result; struct msghdr msg; struct kvec iov; - sigset_t blocked, oldset; unsigned long pflags = current->flags; if (unlikely(!sock)) { @@ -181,11 +186,6 @@ static int sock_xmit(struct nbd_device *nbd, int send, void *buf, int size, return -EINVAL; } - /* Allow interception of SIGKILL only - * Don't allow other signals to interrupt the transmission */ - siginitsetinv(&blocked, sigmask(SIGKILL)); - sigprocmask(SIG_SETMASK, &blocked, &oldset); - current->flags |= PF_MEMALLOC; do { sock->sk->sk_allocation = GFP_NOIO | __GFP_MEMALLOC; @@ -212,7 +212,6 @@ static int sock_xmit(struct nbd_device *nbd, int send, void *buf, int size, buf += result; } while (size > 0); - sigprocmask(SIG_SETMASK, &oldset, NULL); tsk_restore_flags(current, pflags, PF_MEMALLOC); if (!send && nbd->xmit_timeout) @@ -406,23 +405,18 @@ static int nbd_thread_recv(struct nbd_device *nbd) { struct request *req; int ret; - unsigned long flags; BUG_ON(nbd->magic != NBD_MAGIC); sk_set_memalloc(nbd->sock->sk); - spin_lock_irqsave(&nbd->tasks_lock, flags); nbd->task_recv = current; - spin_unlock_irqrestore(&nbd->tasks_lock, flags); ret = device_create_file(disk_to_dev(nbd->disk), &pid_attr); if (ret) { dev_err(disk_to_dev(nbd->disk), "device_create_file failed!\n"); - spin_lock_irqsave(&nbd->tasks_lock, flags); nbd->task_recv = NULL; - spin_unlock_irqrestore(&nbd->tasks_lock, flags); return ret; } @@ -439,19 +433,7 @@ static int nbd_thread_recv(struct nbd_device *nbd) device_remove_file(disk_to_dev(nbd->disk), &pid_attr); - spin_lock_irqsave(&nbd->tasks_lock, flags); nbd->task_recv = NULL; - spin_unlock_irqrestore(&nbd->tasks_lock, flags); - - if (signal_pending(current)) { - ret = kernel_dequeue_signal(NULL); - dev_warn(nbd_to_dev(nbd), "pid %d, %s, got signal %d\n", - task_pid_nr(current), current->comm, ret); - mutex_lock(&nbd->tx_lock); - sock_shutdown(nbd); - mutex_unlock(&nbd->tx_lock); - ret = -ETIMEDOUT; - } return ret; } @@ -544,11 +526,8 @@ static int nbd_thread_send(void *data) { struct nbd_device *nbd = data; struct request *req; - unsigned long flags; - spin_lock_irqsave(&nbd->tasks_lock, flags); nbd->task_send = current; - spin_unlock_irqrestore(&nbd->tasks_lock, flags); set_user_nice(current, MIN_NICE); while (!kthread_should_stop() || !list_empty(&nbd->waiting_queue)) { @@ -557,17 +536,6 @@ static int nbd_thread_send(void *data) kthread_should_stop() || !list_empty(&nbd->waiting_queue)); - if (signal_pending(current)) { - int ret = kernel_dequeue_signal(NULL); - - dev_warn(nbd_to_dev(nbd), "pid %d, %s, got signal %d\n", - task_pid_nr(current), current->comm, ret); - mutex_lock(&nbd->tx_lock); - sock_shutdown(nbd); - mutex_unlock(&nbd->tx_lock); - break; - } - /* extract request */ if (list_empty(&nbd->waiting_queue)) continue; @@ -582,13 +550,7 @@ static int nbd_thread_send(void *data) nbd_handle_req(nbd, req); } - spin_lock_irqsave(&nbd->tasks_lock, flags); nbd->task_send = NULL; - spin_unlock_irqrestore(&nbd->tasks_lock, flags); - - /* Clear maybe pending signals */ - if (signal_pending(current)) - kernel_dequeue_signal(NULL); return 0; } @@ -636,6 +598,25 @@ static void nbd_request_handler(struct request_queue *q) } } +static int nbd_set_socket(struct nbd_device *nbd, struct socket *sock) +{ + int ret = 0; + + spin_lock_irq(&nbd->sock_lock); + + if (nbd->sock) { + ret = -EBUSY; + goto out; + } + + nbd->sock = sock; + +out: + spin_unlock_irq(&nbd->sock_lock); + + return ret; +} + static int nbd_dev_dbg_init(struct nbd_device *nbd); static void nbd_dev_dbg_close(struct nbd_device *nbd); @@ -668,32 +649,26 @@ static int __nbd_ioctl(struct block_device *bdev, struct nbd_device *nbd, return 0; } - case NBD_CLEAR_SOCK: { - struct socket *sock = nbd->sock; - nbd->sock = NULL; + case NBD_CLEAR_SOCK: + sock_shutdown(nbd); nbd_clear_que(nbd); BUG_ON(!list_empty(&nbd->queue_head)); BUG_ON(!list_empty(&nbd->waiting_queue)); kill_bdev(bdev); - if (sock) - sockfd_put(sock); return 0; - } case NBD_SET_SOCK: { - struct socket *sock; int err; - if (nbd->sock) - return -EBUSY; - sock = sockfd_lookup(arg, &err); - if (sock) { - nbd->sock = sock; - if (max_part > 0) - bdev->bd_invalidated = 1; - nbd->disconnect = false; /* we're connected now */ - return 0; - } - return -EINVAL; + struct socket *sock = sockfd_lookup(arg, &err); + + if (!sock) + return err; + + err = nbd_set_socket(nbd, sock); + if (!err && max_part) + bdev->bd_invalidated = 1; + + return err; } case NBD_SET_BLKSIZE: @@ -734,7 +709,6 @@ static int __nbd_ioctl(struct block_device *bdev, struct nbd_device *nbd, case NBD_DO_IT: { struct task_struct *thread; - struct socket *sock; int error; if (nbd->task_recv) @@ -769,14 +743,10 @@ static int __nbd_ioctl(struct block_device *bdev, struct nbd_device *nbd, mutex_lock(&nbd->tx_lock); sock_shutdown(nbd); - sock = nbd->sock; - nbd->sock = NULL; nbd_clear_que(nbd); kill_bdev(bdev); queue_flag_clear_unlocked(QUEUE_FLAG_DISCARD, nbd->disk->queue); set_device_ro(bdev, false); - if (sock) - sockfd_put(sock); nbd->flags = 0; nbd->bytesize = 0; bdev->bd_inode->i_size = 0; @@ -1042,7 +1012,7 @@ static int __init nbd_init(void) nbd_dev[i].magic = NBD_MAGIC; INIT_LIST_HEAD(&nbd_dev[i].waiting_queue); spin_lock_init(&nbd_dev[i].queue_lock); - spin_lock_init(&nbd_dev[i].tasks_lock); + spin_lock_init(&nbd_dev[i].sock_lock); INIT_LIST_HEAD(&nbd_dev[i].queue_head); mutex_init(&nbd_dev[i].tx_lock); init_timer(&nbd_dev[i].timeout_timer);