diff --git a/io_uring/msg_ring.c b/io_uring/msg_ring.c index 47a754e83b49..c2171495098b 100644 --- a/io_uring/msg_ring.c +++ b/io_uring/msg_ring.c @@ -86,16 +86,21 @@ static void io_msg_tw_complete(struct io_kiocb *req, struct io_tw_state *ts) percpu_ref_put(&ctx->refs); } -static void io_msg_remote_post(struct io_ring_ctx *ctx, struct io_kiocb *req, - int res, u32 cflags, u64 user_data) +static int io_msg_remote_post(struct io_ring_ctx *ctx, struct io_kiocb *req, + int res, u32 cflags, u64 user_data) { + req->task = READ_ONCE(ctx->submitter_task); + if (!req->task) { + kmem_cache_free(req_cachep, req); + return -EOWNERDEAD; + } req->cqe.user_data = user_data; io_req_set_res(req, res, cflags); percpu_ref_get(&ctx->refs); req->ctx = ctx; - req->task = READ_ONCE(ctx->submitter_task); req->io_task_work.func = io_msg_tw_complete; io_req_task_work_add_remote(req, ctx, IOU_F_TWQ_LAZY_WAKE); + return 0; } static struct io_kiocb *io_msg_get_kiocb(struct io_ring_ctx *ctx) @@ -125,8 +130,8 @@ static int io_msg_data_remote(struct io_kiocb *req) if (msg->flags & IORING_MSG_RING_FLAGS_PASS) flags = msg->cqe_flags; - io_msg_remote_post(target_ctx, target, msg->len, flags, msg->user_data); - return 0; + return io_msg_remote_post(target_ctx, target, msg->len, flags, + msg->user_data); } static int io_msg_ring_data(struct io_kiocb *req, unsigned int issue_flags)