From e11014c0443ea687ad65a14b9124aa366da7984a Mon Sep 17 00:00:00 2001
From: David Härdeman <david@hardeman.nu>
Date: Sat, 20 Jun 2020 12:53:25 +0200
Subject: Introduce helper for checking if a task is dead

---
 announce.c | 16 ++++++++++++----
 cfgdir.c   | 13 +++----------
 idle.c     | 13 +++++++++----
 igmp.c     |  6 +++++-
 main.c     |  5 +----
 main.h     | 15 +++++++++++++++
 proxy.c    | 11 +++++++++++
 rcon.c     | 29 ++++++++++++-----------------
 server.c   | 12 +++---------
 uring.c    |  2 ++
 utils.c    |  9 +++++++--
 11 files changed, 80 insertions(+), 51 deletions(-)

diff --git a/announce.c b/announce.c
index b90d983..ecb48af 100644
--- a/announce.c
+++ b/announce.c
@@ -27,10 +27,20 @@ mcast_free(struct uring_task *task)
 static void
 mcast_sent(struct cfg *cfg, struct uring_task *task, int res)
 {
+	struct server *server;
+
 	if (res < 0)
 		error("failure %i\n", res);
 	else
 		debug(DBG_ANN, "result %i\n", res);
+
+	if (!task || !task->tbuf) {
+		error("task or task->tbuf not set\n");
+		return;
+	}
+
+	server = container_of(task->tbuf, struct server, mcast_buf);
+	uring_task_put(cfg, &server->task);
 }
 
 static void
@@ -52,6 +62,7 @@ mcast_send(struct cfg *cfg, struct announce *aev, struct server *server)
 
 	server->mcast_buf.len = len;
 	uring_task_set_buf(&aev->mcast_task, &server->mcast_buf);
+	uring_task_get(cfg, &server->task);
 	uring_tbuf_sendmsg(cfg, &aev->mcast_task, mcast_sent);
 }
 
@@ -71,10 +82,7 @@ announce_cb(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct announce *aev = container_of(task, struct announce, task);
 
-	if (task->dead) {
-		debug(DBG_ANN, "task is dead\n");
-		return;
-	}
+	assert_task_alive(DBG_ANN, task);
 
 	if (res != sizeof(aev->value))
 		perrordie("timerfd_read");
diff --git a/cfgdir.c b/cfgdir.c
index 364f58e..fc1633b 100644
--- a/cfgdir.c
+++ b/cfgdir.c
@@ -341,8 +341,7 @@ scfg_read_cb(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct server *scfg = container_of(task, struct server, task);
 
-	if (task->dead)
-		return;
+	assert_task_alive(DBG_CFG, task);
 
 	if (res <= 0) {
 		error("error reading config file for %s: %s\n",
@@ -361,10 +360,7 @@ scfg_open_cb(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct server *scfg = container_of(task, struct server, task);
 
-	if (task->dead) {
-		debug(DBG_CFG, "task dead\n");
-		return;
-	}
+	assert_task_alive(DBG_CFG, task);
 
 	if (res < 0) {
 		error("open(%s) failed: %s\n", scfg->name, strerror(-res));
@@ -464,10 +460,7 @@ inotify_cb(struct cfg *cfg, struct uring_task *task, int res)
 	char *ptr;
 	struct server *scfg;
 
-	if (task->dead) {
-		debug(DBG_CFG, "task is dead\n");
-		return;
-	}
+	assert_task_alive(DBG_CFG, task);
 
 	if (res <= 0)
 		perrordie("inotify_read (%i)", res);
diff --git a/idle.c b/idle.c
index 5f7ed17..3be8974 100644
--- a/idle.c
+++ b/idle.c
@@ -121,6 +121,8 @@ idle_check_handshake_complete(struct cfg *cfg, struct uring_task *task, int res)
 	int32_t mclen;
 	int r;
 
+	assert_task_alive_or(DBG_IDLE, task, return -EINTR);
+
 	remain = task->tbuf->len;
 	pos = task->tbuf->buf;
 
@@ -195,6 +197,8 @@ idle_check_handshake_reply(struct cfg *cfg, struct uring_task *task, int res)
 	int player_count;
 	int r;
 
+	assert_task_alive(DBG_IDLE, task);
+
 	debug(DBG_IDLE, "res: %i\n", res);
 	if (res < 0)
 		goto out;
@@ -277,6 +281,8 @@ idle_check_handshake_sent(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct idle *idle = container_of(task, struct idle, idlecheck);
 
+	assert_task_alive(DBG_IDLE, task);
+
 	debug(DBG_IDLE, "sent %i bytes\n", res);
 	if (res < 0) {
 		uring_task_close_fd(cfg, task);
@@ -298,6 +304,8 @@ idle_check_connected_cb(struct cfg *cfg, struct connection *conn, bool connected
 	uint16_t port;
 	char hostname[INET6_ADDRSTRLEN];
 
+	assert_task_alive(DBG_IDLE, &idle->idlecheck);
+
 	if (!connected) {
 		debug(DBG_IDLE,
 		      "idle check connection to remote server (%s) failed\n",
@@ -337,10 +345,7 @@ idle_cb(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct idle *idle = container_of(task, struct idle, task);
 
-	if (task->dead) {
-		debug(DBG_IDLE, "task is dead\n");
-		return;
-	}
+	assert_task_alive(DBG_IDLE, task);
 
 	if (res != sizeof(idle->value)) {
 		error("timerfd_read returned %i\n", res);
diff --git a/igmp.c b/igmp.c
index 26fe56f..36f63e2 100644
--- a/igmp.c
+++ b/igmp.c
@@ -391,8 +391,12 @@ igmp_read_cb(struct cfg *cfg, struct uring_task *task, int res)
 
 	debug(DBG_IGMP, "task %p, igmp %p, res %i\n", task, igmp, res);
 
-	if (res < 0 || task->dead)
+	assert_task_alive(DBG_IGMP, task);
+
+	if (res < 0) {
+		error("res: %i\n", res);
 		return;
+	}
 
 	task->tbuf->len = res;
 
diff --git a/main.c b/main.c
index 29ef40c..749d1e8 100644
--- a/main.c
+++ b/main.c
@@ -576,10 +576,7 @@ signalfd_read(struct cfg *cfg, struct uring_task *task, int res)
 	struct signalfd_ev *sev = container_of(task, struct signalfd_ev, task);
 	struct server *server, *stmp;
 
-	if (task->dead) {
-		debug(DBG_SIG, "task dead\n");
-		return;
-	}
+	assert_task_alive(DBG_SIG, task);
 
 	if (res != sizeof(sev->buf))
 		die("error in signalfd (%i)", res);
diff --git a/main.h b/main.h
index 39c2440..7d14118 100644
--- a/main.h
+++ b/main.h
@@ -96,6 +96,21 @@ struct uring_task {
 	void *priv;
 };
 
+#define assert_task_alive_or(lvl, t, cmd) 	\
+do {						\
+	if (!(t)) {				\
+		error("invalid task\n");	\
+		cmd;				\
+	}					\
+						\
+	if ((t)->dead) {			\
+		debug((lvl), "task dead\n");	\
+		cmd;				\
+	}					\
+} while(0)
+
+#define assert_task_alive(lvl, t) assert_task_alive_or((lvl), (t), return)
+
 struct cfg {
 	uid_t uid;
 	gid_t gid;
diff --git a/proxy.c b/proxy.c
index bcc2b6d..7fb6c34 100644
--- a/proxy.c
+++ b/proxy.c
@@ -122,6 +122,8 @@ proxy_client_data_out(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct server_proxy *proxy = container_of(task, struct server_proxy, clienttask);
 
+	assert_task_alive(DBG_PROXY, task);
+
 	if (res <= 0) {
 		debug(DBG_PROXY, "%s: result was %i\n", proxy->scfg->name, res);
 		uring_task_close_fd(cfg, task);
@@ -139,6 +141,8 @@ proxy_client_data_in(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct server_proxy *proxy = container_of(task, struct server_proxy, clienttask);
 
+	assert_task_alive(DBG_PROXY, task);
+
 	if (res <= 0) {
 		debug(DBG_PROXY, "%s: result was %i\n", proxy->scfg->name, res);
 		uring_task_close_fd(cfg, task);
@@ -157,6 +161,8 @@ proxy_server_data_out(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct server_proxy *proxy = container_of(task, struct server_proxy, servertask);
 
+	assert_task_alive(DBG_PROXY, task);
+
 	if (res <= 0) {
 		debug(DBG_PROXY, "%s: result was %i\n", proxy->scfg->name, res);
 		uring_task_close_fd(cfg, task);
@@ -174,6 +180,8 @@ proxy_server_data_in(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct server_proxy *proxy = container_of(task, struct server_proxy, servertask);
 
+	assert_task_alive(DBG_PROXY, task);
+
 	if (res <= 0) {
 		debug(DBG_PROXY, "%s: result was %i\n", proxy->scfg->name, res);
 		uring_task_close_fd(cfg, task);
@@ -190,6 +198,9 @@ proxy_connected_cb(struct cfg *cfg, struct connection *conn, bool connected)
 {
 	struct server_proxy *proxy = container_of(conn, struct server_proxy, server_conn);
 
+	assert_task_alive(DBG_PROXY, &proxy->clienttask);
+	assert_task_alive(DBG_PROXY, &proxy->servertask);
+
 	if (!connected) {
 		error("%s: proxy connection to remote server failed\n",
 		      proxy->scfg->name);
diff --git a/rcon.c b/rcon.c
index 33fcdb7..e7c37ce 100644
--- a/rcon.c
+++ b/rcon.c
@@ -8,6 +8,7 @@
 #include <arpa/inet.h>
 #include <stdint.h>
 #include <inttypes.h>
+#include <errno.h>
 
 #include "main.h"
 #include "uring.h"
@@ -123,7 +124,8 @@ enum rcon_packet_type {
 };
 
 static void
-create_packet(struct cfg *cfg, struct rcon *rcon, int32_t reqid, enum rcon_packet_type type, const char *msg)
+create_packet(struct cfg *cfg, struct rcon *rcon, int32_t reqid,
+	      enum rcon_packet_type type, const char *msg)
 {
 	char *pos = &rcon->tbuf.buf[4];
 
@@ -150,6 +152,8 @@ packet_complete(struct cfg *cfg, struct uring_task *task, int res)
 	size_t len = task->tbuf->len;
 	int32_t plen;
 
+	assert_task_alive_or(DBG_RCON, task, return -EINTR);
+
 	if (task->tbuf->len < 14)
 		return 0;
 
@@ -212,10 +216,7 @@ rcon_stop_reply(struct cfg *cfg, struct uring_task *task, int res)
 	int32_t type;
 	char *msg;
 
-	if (task->dead) {
-		debug(DBG_RCON, "task dead\n");
-		return;
-	}
+	assert_task_alive(DBG_RCON, task);
 
 	if (res < 0) {
 		debug(DBG_RCON, "res: %i\n", res);
@@ -244,10 +245,7 @@ rcon_stop_sent(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct rcon *rcon = container_of(task, struct rcon, task);
 
-	if (task->dead) {
-		debug(DBG_RCON, "task dead\n");
-		return;
-	}
+	assert_task_alive(DBG_RCON, task);
 
 	if (res < 0) {
 		debug(DBG_RCON, "res: %i\n", res);
@@ -267,10 +265,7 @@ rcon_login_reply(struct cfg *cfg, struct uring_task *task, int res)
 	int32_t type;
 	char *msg;
 
-	if (task->dead) {
-		debug(DBG_RCON, "task dead\n");
-		return;
-	}
+	assert_task_alive(DBG_RCON, task);
 
 	if (res < 0) {
 		debug(DBG_RCON, "res: %i\n", res);
@@ -305,10 +300,7 @@ rcon_login_sent(struct cfg *cfg, struct uring_task *task, int res)
 {
 	struct rcon *rcon = container_of(task, struct rcon, task);
 
-	if (task->dead) {
-		debug(DBG_RCON, "task dead\n");
-		return;
-	}
+	assert_task_alive(DBG_RCON, task);
 
 	if (res < 0) {
 		debug(DBG_RCON, "res: %i\n", res);
@@ -325,9 +317,12 @@ rcon_connected_cb(struct cfg *cfg, struct connection *conn, bool connected)
 {
 	struct rcon *rcon = container_of(conn, struct rcon, conn);
 
+	assert_task_alive(DBG_RCON, &rcon->task);
+
 	if (!connected) {
 		error("rcon connection to remote server (%s) failed\n",
 		      rcon->server->name);
+		uring_task_put(cfg, &rcon->task);
 		return;
 	}
 
diff --git a/server.c b/server.c
index c4bbc0c..edb0551 100644
--- a/server.c
+++ b/server.c
@@ -195,10 +195,7 @@ server_local_accept(struct cfg *cfg, struct uring_task *task, int res)
 
 	debug(DBG_SRV, "task %p, res %i, scfg %s\n", task, res, scfg->name);
 
-	if (task->dead) {
-		debug(DBG_SRV, "task dead\n");
-		return;
-	}
+	assert_task_alive(DBG_SRV, task);
 
 	if (res < 0) {
 		error("result was %i\n", res);
@@ -299,11 +296,8 @@ server_exec_done(struct cfg *cfg, struct uring_task *task, int res)
 	int r;
 	siginfo_t info;
 
-	if (task->dead) {
-		/* Should we leave child processes running? */
-		debug(DBG_SRV, "task dead\n");
-		goto out;
-	}
+	/* Should we leave child processes running? */
+	assert_task_alive_or(DBG_SRV, task, goto out);
 
 	if (!(res & POLLIN)) {
 		error("unexpected result: %i\n", res);
diff --git a/uring.c b/uring.c
index e1fad53..5e1b168 100644
--- a/uring.c
+++ b/uring.c
@@ -386,6 +386,8 @@ uring_tbuf_read_until(struct cfg *cfg, struct uring_task *task,
 static int
 uring_tbuf_eof(struct cfg *cfg, struct uring_task *task, int res)
 {
+	assert_task_alive_or(DBG_UR, task, return -EINTR);
+
 	if (task->tbuf->len + 1 >= sizeof(task->tbuf->buf))
 		return -E2BIG;
 
diff --git a/utils.c b/utils.c
index 8c7b663..b07fdff 100644
--- a/utils.c
+++ b/utils.c
@@ -237,6 +237,8 @@ connect_next(struct cfg *cfg, struct uring_task *task, struct connection *conn)
         unsigned i;
 
 again:
+	assert_task_alive_or(DBG_UR, task, goto out);
+
 	i = 0;
         remote = NULL;
         list_for_each_entry(tmp, conn->addrs, list) {
@@ -250,8 +252,7 @@ again:
         if (!remote) {
 		debug(DBG_UR, "%s: no more remote addresses to attempt\n",
 		      task->name);
-		conn->callback(cfg, conn, false);
-                return;
+		goto out;
         }
 
 	conn->next_addr++;
@@ -270,6 +271,10 @@ again:
 	task->priv = conn;
         uring_task_set_fd(task, sfd);
         uring_connect(cfg, task, &conn->remote, connect_cb);
+	return;
+
+out:
+	conn->callback(cfg, conn, false);
 }
 
 void
-- 
cgit v1.2.3