aboutsummaryrefslogtreecommitdiff
path: root/nbd/server.c
diff options
context:
space:
mode:
Diffstat (limited to 'nbd/server.c')
-rw-r--r--nbd/server.c343
1 files changed, 119 insertions, 224 deletions
diff --git a/nbd/server.c b/nbd/server.c
index 49b55f6ede..8a70c054a6 100644
--- a/nbd/server.c
+++ b/nbd/server.c
@@ -81,7 +81,7 @@ static QTAILQ_HEAD(, NBDExport) exports = QTAILQ_HEAD_INITIALIZER(exports);
struct NBDClient {
int refcount;
- void (*close)(NBDClient *client);
+ void (*close_fn)(NBDClient *client, bool negotiated);
bool no_zeroes;
NBDExport *exp;
@@ -104,69 +104,6 @@ struct NBDClient {
static void nbd_client_receive_next_request(NBDClient *client);
-static gboolean nbd_negotiate_continue(QIOChannel *ioc,
- GIOCondition condition,
- void *opaque)
-{
- qemu_coroutine_enter(opaque);
- return TRUE;
-}
-
-static int nbd_negotiate_read(QIOChannel *ioc, void *buffer, size_t size)
-{
- ssize_t ret;
- guint watch;
-
- assert(qemu_in_coroutine());
- /* Negotiation are always in main loop. */
- watch = qio_channel_add_watch(ioc,
- G_IO_IN,
- nbd_negotiate_continue,
- qemu_coroutine_self(),
- NULL);
- ret = read_sync(ioc, buffer, size, NULL);
- g_source_remove(watch);
- return ret;
-
-}
-
-static int nbd_negotiate_write(QIOChannel *ioc, const void *buffer, size_t size)
-{
- ssize_t ret;
- guint watch;
-
- assert(qemu_in_coroutine());
- /* Negotiation are always in main loop. */
- watch = qio_channel_add_watch(ioc,
- G_IO_OUT,
- nbd_negotiate_continue,
- qemu_coroutine_self(),
- NULL);
- ret = write_sync(ioc, buffer, size, NULL);
- g_source_remove(watch);
- return ret;
-}
-
-static int nbd_negotiate_drop_sync(QIOChannel *ioc, size_t size)
-{
- ssize_t ret;
- uint8_t *buffer = g_malloc(MIN(65536, size));
-
- while (size > 0) {
- size_t count = MIN(65536, size);
- ret = nbd_negotiate_read(ioc, buffer, count);
- if (ret < 0) {
- g_free(buffer);
- return ret;
- }
-
- size -= count;
- }
-
- g_free(buffer);
- return 0;
-}
-
/* Basic flow for negotiation
Server Client
@@ -205,22 +142,22 @@ static int nbd_negotiate_send_rep_len(QIOChannel *ioc, uint32_t type,
type, opt, len);
magic = cpu_to_be64(NBD_REP_MAGIC);
- if (nbd_negotiate_write(ioc, &magic, sizeof(magic)) < 0) {
+ if (nbd_write(ioc, &magic, sizeof(magic), NULL) < 0) {
LOG("write failed (rep magic)");
return -EINVAL;
}
opt = cpu_to_be32(opt);
- if (nbd_negotiate_write(ioc, &opt, sizeof(opt)) < 0) {
+ if (nbd_write(ioc, &opt, sizeof(opt), NULL) < 0) {
LOG("write failed (rep opt)");
return -EINVAL;
}
type = cpu_to_be32(type);
- if (nbd_negotiate_write(ioc, &type, sizeof(type)) < 0) {
+ if (nbd_write(ioc, &type, sizeof(type), NULL) < 0) {
LOG("write failed (rep type)");
return -EINVAL;
}
len = cpu_to_be32(len);
- if (nbd_negotiate_write(ioc, &len, sizeof(len)) < 0) {
+ if (nbd_write(ioc, &len, sizeof(len), NULL) < 0) {
LOG("write failed (rep data length)");
return -EINVAL;
}
@@ -255,7 +192,7 @@ nbd_negotiate_send_rep_err(QIOChannel *ioc, uint32_t type,
if (ret < 0) {
goto out;
}
- if (nbd_negotiate_write(ioc, msg, len) < 0) {
+ if (nbd_write(ioc, msg, len, NULL) < 0) {
LOG("write failed (error message)");
ret = -EIO;
} else {
@@ -274,27 +211,27 @@ static int nbd_negotiate_send_rep_list(QIOChannel *ioc, NBDExport *exp)
uint32_t len;
const char *name = exp->name ? exp->name : "";
const char *desc = exp->description ? exp->description : "";
- int rc;
+ int ret;
TRACE("Advertising export name '%s' description '%s'", name, desc);
name_len = strlen(name);
desc_len = strlen(desc);
len = name_len + desc_len + sizeof(len);
- rc = nbd_negotiate_send_rep_len(ioc, NBD_REP_SERVER, NBD_OPT_LIST, len);
- if (rc < 0) {
- return rc;
+ ret = nbd_negotiate_send_rep_len(ioc, NBD_REP_SERVER, NBD_OPT_LIST, len);
+ if (ret < 0) {
+ return ret;
}
len = cpu_to_be32(name_len);
- if (nbd_negotiate_write(ioc, &len, sizeof(len)) < 0) {
+ if (nbd_write(ioc, &len, sizeof(len), NULL) < 0) {
LOG("write failed (name length)");
return -EINVAL;
}
- if (nbd_negotiate_write(ioc, name, name_len) < 0) {
+ if (nbd_write(ioc, name, name_len, NULL) < 0) {
LOG("write failed (name buffer)");
return -EINVAL;
}
- if (nbd_negotiate_write(ioc, desc, desc_len) < 0) {
+ if (nbd_write(ioc, desc, desc_len, NULL) < 0) {
LOG("write failed (description buffer)");
return -EINVAL;
}
@@ -308,7 +245,7 @@ static int nbd_negotiate_handle_list(NBDClient *client, uint32_t length)
NBDExport *exp;
if (length) {
- if (nbd_negotiate_drop_sync(client->ioc, length) < 0) {
+ if (nbd_drop(client->ioc, length, NULL) < 0) {
return -EIO;
}
return nbd_negotiate_send_rep_err(client->ioc,
@@ -328,7 +265,6 @@ static int nbd_negotiate_handle_list(NBDClient *client, uint32_t length)
static int nbd_negotiate_handle_export_name(NBDClient *client, uint32_t length)
{
- int rc = -EINVAL;
char name[NBD_MAX_NAME_SIZE + 1];
/* Client sends:
@@ -337,11 +273,11 @@ static int nbd_negotiate_handle_export_name(NBDClient *client, uint32_t length)
TRACE("Checking length");
if (length >= sizeof(name)) {
LOG("Bad length received");
- goto fail;
+ return -EINVAL;
}
- if (nbd_negotiate_read(client->ioc, name, length) < 0) {
+ if (nbd_read(client->ioc, name, length, NULL) < 0) {
LOG("read failed");
- goto fail;
+ return -EINVAL;
}
name[length] = '\0';
@@ -350,14 +286,13 @@ static int nbd_negotiate_handle_export_name(NBDClient *client, uint32_t length)
client->exp = nbd_export_find(name);
if (!client->exp) {
LOG("export not found");
- goto fail;
+ return -EINVAL;
}
QTAILQ_INSERT_TAIL(&client->exp->clients, client, next);
nbd_export_get(client->exp);
- rc = 0;
-fail:
- return rc;
+
+ return 0;
}
/* Handle NBD_OPT_STARTTLS. Return NULL to drop connection, or else the
@@ -372,7 +307,7 @@ static QIOChannel *nbd_negotiate_handle_starttls(NBDClient *client,
TRACE("Setting up TLS");
ioc = client->ioc;
if (length) {
- if (nbd_negotiate_drop_sync(ioc, length) < 0) {
+ if (nbd_drop(ioc, length, NULL) < 0) {
return NULL;
}
nbd_negotiate_send_rep_err(ioc, NBD_REP_ERR_INVALID, NBD_OPT_STARTTLS,
@@ -436,7 +371,7 @@ static int nbd_negotiate_options(NBDClient *client)
... Rest of request
*/
- if (nbd_negotiate_read(client->ioc, &flags, sizeof(flags)) < 0) {
+ if (nbd_read(client->ioc, &flags, sizeof(flags), NULL) < 0) {
LOG("read failed");
return -EIO;
}
@@ -462,7 +397,7 @@ static int nbd_negotiate_options(NBDClient *client)
uint32_t clientflags, length;
uint64_t magic;
- if (nbd_negotiate_read(client->ioc, &magic, sizeof(magic)) < 0) {
+ if (nbd_read(client->ioc, &magic, sizeof(magic), NULL) < 0) {
LOG("read failed");
return -EINVAL;
}
@@ -472,15 +407,15 @@ static int nbd_negotiate_options(NBDClient *client)
return -EINVAL;
}
- if (nbd_negotiate_read(client->ioc, &clientflags,
- sizeof(clientflags)) < 0)
+ if (nbd_read(client->ioc, &clientflags,
+ sizeof(clientflags), NULL) < 0)
{
LOG("read failed");
return -EINVAL;
}
clientflags = be32_to_cpu(clientflags);
- if (nbd_negotiate_read(client->ioc, &length, sizeof(length)) < 0) {
+ if (nbd_read(client->ioc, &length, sizeof(length), NULL) < 0) {
LOG("read failed");
return -EINVAL;
}
@@ -510,7 +445,7 @@ static int nbd_negotiate_options(NBDClient *client)
return -EINVAL;
default:
- if (nbd_negotiate_drop_sync(client->ioc, length) < 0) {
+ if (nbd_drop(client->ioc, length, NULL) < 0) {
return -EIO;
}
ret = nbd_negotiate_send_rep_err(client->ioc,
@@ -548,7 +483,7 @@ static int nbd_negotiate_options(NBDClient *client)
return nbd_negotiate_handle_export_name(client, length);
case NBD_OPT_STARTTLS:
- if (nbd_negotiate_drop_sync(client->ioc, length) < 0) {
+ if (nbd_drop(client->ioc, length, NULL) < 0) {
return -EIO;
}
if (client->tlscreds) {
@@ -567,7 +502,7 @@ static int nbd_negotiate_options(NBDClient *client)
}
break;
default:
- if (nbd_negotiate_drop_sync(client->ioc, length) < 0) {
+ if (nbd_drop(client->ioc, length, NULL) < 0) {
return -EIO;
}
ret = nbd_negotiate_send_rep_err(client->ioc,
@@ -598,16 +533,10 @@ static int nbd_negotiate_options(NBDClient *client)
}
}
-typedef struct {
- NBDClient *client;
- Coroutine *co;
-} NBDClientNewData;
-
-static coroutine_fn int nbd_negotiate(NBDClientNewData *data)
+static coroutine_fn int nbd_negotiate(NBDClient *client)
{
- NBDClient *client = data->client;
char buf[8 + 8 + 8 + 128];
- int rc;
+ int ret;
const uint16_t myflags = (NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_TRIM |
NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA |
NBD_FLAG_SEND_WRITE_ZEROES);
@@ -633,7 +562,6 @@ static coroutine_fn int nbd_negotiate(NBDClientNewData *data)
*/
qio_channel_set_blocking(client->ioc, false, NULL);
- rc = -EINVAL;
TRACE("Beginning negotiation.");
memset(buf, 0, sizeof(buf));
@@ -654,21 +582,21 @@ static coroutine_fn int nbd_negotiate(NBDClientNewData *data)
if (oldStyle) {
if (client->tlscreds) {
TRACE("TLS cannot be enabled with oldstyle protocol");
- goto fail;
+ return -EINVAL;
}
- if (nbd_negotiate_write(client->ioc, buf, sizeof(buf)) < 0) {
+ if (nbd_write(client->ioc, buf, sizeof(buf), NULL) < 0) {
LOG("write failed");
- goto fail;
+ return -EINVAL;
}
} else {
- if (nbd_negotiate_write(client->ioc, buf, 18) < 0) {
+ if (nbd_write(client->ioc, buf, 18, NULL) < 0) {
LOG("write failed");
- goto fail;
+ return -EINVAL;
}
- rc = nbd_negotiate_options(client);
- if (rc != 0) {
+ ret = nbd_negotiate_options(client);
+ if (ret != 0) {
LOG("option negotiation failed");
- goto fail;
+ return ret;
}
TRACE("advertising size %" PRIu64 " and flags %x",
@@ -676,25 +604,25 @@ static coroutine_fn int nbd_negotiate(NBDClientNewData *data)
stq_be_p(buf + 18, client->exp->size);
stw_be_p(buf + 26, client->exp->nbdflags | myflags);
len = client->no_zeroes ? 10 : sizeof(buf) - 18;
- if (nbd_negotiate_write(client->ioc, buf + 18, len) < 0) {
+ ret = nbd_write(client->ioc, buf + 18, len, NULL);
+ if (ret < 0) {
LOG("write failed");
- goto fail;
+ return ret;
}
}
TRACE("Negotiation succeeded.");
- rc = 0;
-fail:
- return rc;
+
+ return 0;
}
-static ssize_t nbd_receive_request(QIOChannel *ioc, NBDRequest *request)
+static int nbd_receive_request(QIOChannel *ioc, NBDRequest *request)
{
uint8_t buf[NBD_REQUEST_SIZE];
uint32_t magic;
- ssize_t ret;
+ int ret;
- ret = read_sync(ioc, buf, sizeof(buf), NULL);
+ ret = nbd_read(ioc, buf, sizeof(buf), NULL);
if (ret < 0) {
return ret;
}
@@ -726,7 +654,7 @@ static ssize_t nbd_receive_request(QIOChannel *ioc, NBDRequest *request)
return 0;
}
-static ssize_t nbd_send_reply(QIOChannel *ioc, NBDReply *reply)
+static int nbd_send_reply(QIOChannel *ioc, NBDReply *reply)
{
uint8_t buf[NBD_REPLY_SIZE];
@@ -745,7 +673,7 @@ static ssize_t nbd_send_reply(QIOChannel *ioc, NBDReply *reply)
stl_be_p(buf + 4, reply->error);
stq_be_p(buf + 8, reply->handle);
- return write_sync(ioc, buf, sizeof(buf), NULL);
+ return nbd_write(ioc, buf, sizeof(buf), NULL);
}
#define MAX_NBD_REQUESTS 16
@@ -778,7 +706,7 @@ void nbd_client_put(NBDClient *client)
}
}
-static void client_close(NBDClient *client)
+static void client_close(NBDClient *client, bool negotiated)
{
if (client->closing) {
return;
@@ -793,8 +721,8 @@ static void client_close(NBDClient *client)
NULL);
/* Also tell the client, so that they release their reference. */
- if (client->close) {
- client->close(client);
+ if (client->close_fn) {
+ client->close_fn(client, negotiated);
}
}
@@ -975,7 +903,7 @@ void nbd_export_close(NBDExport *exp)
nbd_export_get(exp);
QTAILQ_FOREACH_SAFE(client, &exp->clients, next, next) {
- client_close(client);
+ client_close(client, true);
}
nbd_export_set_name(exp, NULL);
nbd_export_set_description(exp, NULL);
@@ -1032,25 +960,24 @@ void nbd_export_close_all(void)
}
}
-static ssize_t nbd_co_send_reply(NBDRequestData *req, NBDReply *reply,
- int len)
+static int nbd_co_send_reply(NBDRequestData *req, NBDReply *reply, int len)
{
NBDClient *client = req->client;
- ssize_t rc, ret;
+ int ret;
g_assert(qemu_in_coroutine());
qemu_co_mutex_lock(&client->send_lock);
client->send_coroutine = qemu_coroutine_self();
if (!len) {
- rc = nbd_send_reply(client->ioc, reply);
+ ret = nbd_send_reply(client->ioc, reply);
} else {
qio_channel_set_cork(client->ioc, true);
- rc = nbd_send_reply(client->ioc, reply);
- if (rc >= 0) {
- ret = write_sync(client->ioc, req->data, len, NULL);
+ ret = nbd_send_reply(client->ioc, reply);
+ if (ret == 0) {
+ ret = nbd_write(client->ioc, req->data, len, NULL);
if (ret < 0) {
- rc = -EIO;
+ ret = -EIO;
}
}
qio_channel_set_cork(client->ioc, false);
@@ -1058,28 +985,23 @@ static ssize_t nbd_co_send_reply(NBDRequestData *req, NBDReply *reply,
client->send_coroutine = NULL;
qemu_co_mutex_unlock(&client->send_lock);
- return rc;
+ return ret;
}
-/* Collect a client request. Return 0 if request looks valid, -EAGAIN
- * to keep trying the collection, -EIO to drop connection right away,
- * and any other negative value to report an error to the client
- * (although the caller may still need to disconnect after reporting
- * the error). */
-static ssize_t nbd_co_receive_request(NBDRequestData *req,
- NBDRequest *request)
+/* nbd_co_receive_request
+ * Collect a client request. Return 0 if request looks valid, -EIO to drop
+ * connection right away, and any other negative value to report an error to
+ * the client (although the caller may still need to disconnect after reporting
+ * the error).
+ */
+static int nbd_co_receive_request(NBDRequestData *req, NBDRequest *request)
{
NBDClient *client = req->client;
- ssize_t rc;
g_assert(qemu_in_coroutine());
assert(client->recv_coroutine == qemu_coroutine_self());
- rc = nbd_receive_request(client->ioc, request);
- if (rc < 0) {
- if (rc != -EAGAIN) {
- rc = -EIO;
- }
- goto out;
+ if (nbd_receive_request(client->ioc, request) < 0) {
+ return -EIO;
}
TRACE("Decoding type");
@@ -1093,8 +1015,7 @@ static ssize_t nbd_co_receive_request(NBDRequestData *req,
/* Special case: we're going to disconnect without a reply,
* whether or not flags, from, or len are bogus */
TRACE("Request type is DISCONNECT");
- rc = -EIO;
- goto out;
+ return -EIO;
}
/* Check for sanity in the parameters, part 1. Defer as many
@@ -1102,31 +1023,27 @@ static ssize_t nbd_co_receive_request(NBDRequestData *req,
* payload, so we can try and keep the connection alive. */
if ((request->from + request->len) < request->from) {
LOG("integer overflow detected, you're probably being attacked");
- rc = -EINVAL;
- goto out;
+ return -EINVAL;
}
if (request->type == NBD_CMD_READ || request->type == NBD_CMD_WRITE) {
if (request->len > NBD_MAX_BUFFER_SIZE) {
LOG("len (%" PRIu32" ) is larger than max len (%u)",
request->len, NBD_MAX_BUFFER_SIZE);
- rc = -EINVAL;
- goto out;
+ return -EINVAL;
}
req->data = blk_try_blockalign(client->exp->blk, request->len);
if (req->data == NULL) {
- rc = -ENOMEM;
- goto out;
+ return -ENOMEM;
}
}
if (request->type == NBD_CMD_WRITE) {
TRACE("Reading %" PRIu32 " byte(s)", request->len);
- if (read_sync(client->ioc, req->data, request->len, NULL) < 0) {
+ if (nbd_read(client->ioc, req->data, request->len, NULL) < 0) {
LOG("reading from socket failed");
- rc = -EIO;
- goto out;
+ return -EIO;
}
req->complete = true;
}
@@ -1136,28 +1053,19 @@ static ssize_t nbd_co_receive_request(NBDRequestData *req,
LOG("operation past EOF; From: %" PRIu64 ", Len: %" PRIu32
", Size: %" PRIu64, request->from, request->len,
(uint64_t)client->exp->size);
- rc = request->type == NBD_CMD_WRITE ? -ENOSPC : -EINVAL;
- goto out;
+ return request->type == NBD_CMD_WRITE ? -ENOSPC : -EINVAL;
}
if (request->flags & ~(NBD_CMD_FLAG_FUA | NBD_CMD_FLAG_NO_HOLE)) {
LOG("unsupported flags (got 0x%x)", request->flags);
- rc = -EINVAL;
- goto out;
+ return -EINVAL;
}
if (request->type != NBD_CMD_WRITE_ZEROES &&
(request->flags & NBD_CMD_FLAG_NO_HOLE)) {
LOG("unexpected flags (got 0x%x)", request->flags);
- rc = -EINVAL;
- goto out;
+ return -EINVAL;
}
- rc = 0;
-
-out:
- client->recv_coroutine = NULL;
- nbd_client_receive_next_request(client);
-
- return rc;
+ return 0;
}
/* Owns a reference to the NBDClient passed as opaque. */
@@ -1168,8 +1076,9 @@ static coroutine_fn void nbd_trip(void *opaque)
NBDRequestData *req;
NBDRequest request = { 0 }; /* GCC thinks it can be used uninitialized */
NBDReply reply;
- ssize_t ret;
+ int ret;
int flags;
+ int reply_data_len = 0;
TRACE("Reading request.");
if (client->closing) {
@@ -1179,11 +1088,10 @@ static coroutine_fn void nbd_trip(void *opaque)
req = nbd_request_get(client);
ret = nbd_co_receive_request(req, &request);
- if (ret == -EAGAIN) {
- goto done;
- }
+ client->recv_coroutine = NULL;
+ nbd_client_receive_next_request(client);
if (ret == -EIO) {
- goto out;
+ goto disconnect;
}
reply.handle = request.handle;
@@ -1191,7 +1099,7 @@ static coroutine_fn void nbd_trip(void *opaque)
if (ret < 0) {
reply.error = -ret;
- goto error_reply;
+ goto reply;
}
if (client->closing) {
@@ -1212,7 +1120,7 @@ static coroutine_fn void nbd_trip(void *opaque)
if (ret < 0) {
LOG("flush failed");
reply.error = -ret;
- goto error_reply;
+ break;
}
}
@@ -1221,12 +1129,12 @@ static coroutine_fn void nbd_trip(void *opaque)
if (ret < 0) {
LOG("reading from file failed");
reply.error = -ret;
- goto error_reply;
+ break;
}
+ reply_data_len = request.len;
TRACE("Read %" PRIu32" byte(s)", request.len);
- if (nbd_co_send_reply(req, &reply, request.len) < 0)
- goto out;
+
break;
case NBD_CMD_WRITE:
TRACE("Request type is WRITE");
@@ -1234,7 +1142,7 @@ static coroutine_fn void nbd_trip(void *opaque)
if (exp->nbdflags & NBD_FLAG_READ_ONLY) {
TRACE("Server is read-only, return error");
reply.error = EROFS;
- goto error_reply;
+ break;
}
TRACE("Writing to device");
@@ -1248,21 +1156,16 @@ static coroutine_fn void nbd_trip(void *opaque)
if (ret < 0) {
LOG("writing to file failed");
reply.error = -ret;
- goto error_reply;
}
- if (nbd_co_send_reply(req, &reply, 0) < 0) {
- goto out;
- }
break;
-
case NBD_CMD_WRITE_ZEROES:
TRACE("Request type is WRITE_ZEROES");
if (exp->nbdflags & NBD_FLAG_READ_ONLY) {
TRACE("Server is read-only, return error");
reply.error = EROFS;
- goto error_reply;
+ break;
}
TRACE("Writing to device");
@@ -1279,14 +1182,9 @@ static coroutine_fn void nbd_trip(void *opaque)
if (ret < 0) {
LOG("writing to file failed");
reply.error = -ret;
- goto error_reply;
}
- if (nbd_co_send_reply(req, &reply, 0) < 0) {
- goto out;
- }
break;
-
case NBD_CMD_DISC:
/* unreachable, thanks to special case in nbd_co_receive_request() */
abort();
@@ -1299,9 +1197,7 @@ static coroutine_fn void nbd_trip(void *opaque)
LOG("flush failed");
reply.error = -ret;
}
- if (nbd_co_send_reply(req, &reply, 0) < 0) {
- goto out;
- }
+
break;
case NBD_CMD_TRIM:
TRACE("Request type is TRIM");
@@ -1311,21 +1207,19 @@ static coroutine_fn void nbd_trip(void *opaque)
LOG("discard failed");
reply.error = -ret;
}
- if (nbd_co_send_reply(req, &reply, 0) < 0) {
- goto out;
- }
+
break;
default:
LOG("invalid request type (%" PRIu32 ") received", request.type);
reply.error = EINVAL;
- error_reply:
- /* We must disconnect after NBD_CMD_WRITE if we did not
- * read the payload.
- */
- if (nbd_co_send_reply(req, &reply, 0) < 0 || !req->complete) {
- goto out;
- }
- break;
+ }
+
+reply:
+ /* We must disconnect after NBD_CMD_WRITE if we did not
+ * read the payload.
+ */
+ if (nbd_co_send_reply(req, &reply, reply_data_len) < 0 || !req->complete) {
+ goto disconnect;
}
TRACE("Request/Reply complete");
@@ -1335,9 +1229,9 @@ done:
nbd_client_put(client);
return;
-out:
+disconnect:
nbd_request_put(req);
- client_close(client);
+ client_close(client, true);
nbd_client_put(client);
}
@@ -1352,8 +1246,7 @@ static void nbd_client_receive_next_request(NBDClient *client)
static coroutine_fn void nbd_co_client_start(void *opaque)
{
- NBDClientNewData *data = opaque;
- NBDClient *client = data->client;
+ NBDClient *client = opaque;
NBDExport *exp = client->exp;
if (exp) {
@@ -1362,25 +1255,28 @@ static coroutine_fn void nbd_co_client_start(void *opaque)
}
qemu_co_mutex_init(&client->send_lock);
- if (nbd_negotiate(data)) {
- client_close(client);
- goto out;
+ if (nbd_negotiate(client)) {
+ client_close(client, false);
+ return;
}
nbd_client_receive_next_request(client);
-
-out:
- g_free(data);
}
+/*
+ * Create a new client listener on the given export @exp, using the
+ * given channel @sioc. Begin servicing it in a coroutine. When the
+ * connection closes, call @close_fn with an indication of whether the
+ * client completed negotiation.
+ */
void nbd_client_new(NBDExport *exp,
QIOChannelSocket *sioc,
QCryptoTLSCreds *tlscreds,
const char *tlsaclname,
- void (*close_fn)(NBDClient *))
+ void (*close_fn)(NBDClient *, bool))
{
NBDClient *client;
- NBDClientNewData *data = g_new(NBDClientNewData, 1);
+ Coroutine *co;
client = g_malloc0(sizeof(NBDClient));
client->refcount = 1;
@@ -1394,9 +1290,8 @@ void nbd_client_new(NBDExport *exp,
object_ref(OBJECT(client->sioc));
client->ioc = QIO_CHANNEL(sioc);
object_ref(OBJECT(client->ioc));
- client->close = close_fn;
+ client->close_fn = close_fn;
- data->client = client;
- data->co = qemu_coroutine_create(nbd_co_client_start, data);
- qemu_coroutine_enter(data->co);
+ co = qemu_coroutine_create(nbd_co_client_start, client);
+ qemu_coroutine_enter(co);
}