/*#############################################################################
#                                                                             #
# IPFire.org - A linux based firewall                                         #
# Copyright (C) 2023 IPFire Network Development Team                          #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
#############################################################################*/

#include <errno.h>
#include <fcntl.h>
#include <getopt.h>
#include <limits.h>
#include <stdlib.h>
#include <unistd.h>

#include <systemd/sd-bus.h>
#include <systemd/sd-daemon.h>
#include <systemd/sd-device.h>
#include <systemd/sd-event.h>
#include <systemd/sd-netlink.h>

#include "bus.h"
#include "config.h"
#include "daemon.h"
#include "devmon.h"
#include "link.h"
#include "links.h"
#include "logging.h"
#include "ports.h"
#include "stats-collector.h"
#include "string.h"
#include "zone.h"
#include "zones.h"

// Receive buffer size (128 MiB) requested for the device monitor and the
// netlink socket. The expression is parenthesised so that it stays intact
// next to other operators wherever the macro is expanded.
#define RCVBUF_SIZE						(128 * 1024 * 1024)

/*
	Central state of the network daemon. Instances are reference counted
	through nw_daemon_ref() and nw_daemon_unref().
*/
struct nw_daemon {
	// Reference counter; set to 1 in nw_daemon_create()
	int nrefs;

	// Configuration
	// configd is the configuration directory opened by nw_daemon_config_open(),
	// config is the "settings" file opened from it in nw_daemon_load_config()
	nw_configd* configd;
	nw_config* config;

	// Event Loop (the default loop, fetched in nw_daemon_setup_loop())
	sd_event* loop;

	// Netlink Connection (opened in nw_daemon_connect_rtnl())
	sd_netlink* rtnl;

	// DBus Connection (set up by nw_bus_connect())
	sd_bus* bus;

	// udev Connection (created in nw_start_device_monitor())
	sd_device_monitor* devmon;

	// Links
	nw_links* links;

	// Zones
	nw_zones* zones;

	// Ports
	nw_ports* ports;

	// Stats Collector (timer event source that runs nw_stats_collector)
	sd_event_source* stats_collector_event;
};

/*
	Signal callback registered for SIGTERM and SIGINT: asks the event loop
	that the signal source belongs to, to exit with code 0.
*/
static int __nw_daemon_terminate(sd_event_source* source, const struct signalfd_siginfo* si,
		void* data) {
	sd_event* loop = NULL;

	DEBUG("Received signal to terminate...\n");

	// Look up the loop this source is attached to and stop it
	loop = sd_event_source_get_event(source);

	return sd_event_exit(loop, 0);
}

/*
	Signal callback registered for SIGHUP: triggers a reload of the daemon.

	data is the nw_daemon that was passed as userdata when the signal
	source was registered in nw_daemon_setup_loop().

	Always returns 0; the result of nw_daemon_reload() is not propagated
	to the event loop.
*/
static int __nw_daemon_reload(sd_event_source* source, const struct signalfd_siginfo* si,
		void* data) {
	// Fetch the daemon from the userdata pointer (not from the local itself)
	nw_daemon* daemon = (nw_daemon*)data;

	DEBUG("Received signal to reload...\n");

	// Reload the daemon
	nw_daemon_reload(daemon);

	return 0;
}

/*
	Configuration
*/

/*
	Opens the configuration directory at path and stores it in the daemon.

	Returns the negative result of nw_configd_create() on failure, 0 otherwise.
*/
static int nw_daemon_config_open(nw_daemon* daemon, const char* path) {
	const int status = nw_configd_create(&daemon->configd, path);

	// Only negative results are treated as errors
	return (status < 0) ? status : 0;
}

/*
	Parses the command line. The only supported option is --config=PATH
	which opens PATH as the configuration directory.

	Returns 0 on success, -EINVAL for any unrecognised option, or the
	negative result of nw_daemon_config_open().
*/
static int nw_daemon_parse_argv(nw_daemon* daemon, int argc, char* argv[]) {
	enum {
		ARG_CONFIG,
	};

	static const struct option options[] = {
		{ "config", required_argument, NULL, ARG_CONFIG },
		{ NULL },
	};

	int opt;
	int r;

	// getopt_long() returns a negative value once all options are consumed
	while ((opt = getopt_long(argc, argv, "", options, NULL)) >= 0) {
		// Abort on any unrecognised option
		if (opt != ARG_CONFIG)
			return -EINVAL;

		r = nw_daemon_config_open(daemon, optarg);
		if (r < 0)
			return r;
	}

	return 0;
}

/*
	Acquires the default event loop, enables the watchdog and registers
	the signal callbacks: SIGTERM and SIGINT terminate, SIGHUP reloads.

	Returns 0 on success and 1 as soon as any step fails.
*/
static int nw_daemon_setup_loop(nw_daemon* daemon) {
	int r;

	// Fetch a reference to the default event loop
	if ((r = sd_event_default(&daemon->loop)) < 0) {
		ERROR("Could not setup event loop: %s\n", strerror(-r));
		return 1;
	}

	// Enable the watchdog
	if ((r = sd_event_set_watchdog(daemon->loop, 1)) < 0) {
		ERROR("Could not activate watchdog: %s\n", strerror(-r));
		return 1;
	}

	// Listen for SIGTERM
	if ((r = sd_event_add_signal(daemon->loop, NULL, SIGTERM|SD_EVENT_SIGNAL_PROCMASK,
			__nw_daemon_terminate, daemon)) < 0) {
		ERROR("Could not register handling SIGTERM: %s\n", strerror(-r));
		return 1;
	}

	// Listen for SIGINT
	if ((r = sd_event_add_signal(daemon->loop, NULL, SIGINT|SD_EVENT_SIGNAL_PROCMASK,
			__nw_daemon_terminate, daemon)) < 0) {
		ERROR("Could not register handling SIGINT: %s\n", strerror(-r));
		return 1;
	}

	// Listen for SIGHUP
	if ((r = sd_event_add_signal(daemon->loop, NULL, SIGHUP|SD_EVENT_SIGNAL_PROCMASK,
			__nw_daemon_reload, daemon)) < 0) {
		ERROR("Could not register handling SIGHUP: %s\n", strerror(-r));
		return 1;
	}

	return 0;
}

/*
	Opens the "settings" configuration file. If no configuration directory
	has been opened so far, CONFIG_DIR is opened first.
*/
static int nw_daemon_load_config(nw_daemon* daemon) {
	// Fall back to the default directory unless one is open already
	if (daemon->configd == NULL) {
		int r = nw_daemon_config_open(daemon, CONFIG_DIR);
		if (r < 0)
			return r;
	}

	// Open the configuration file
	return nw_configd_open_config(&daemon->config, daemon->configd, "settings");
}

/*
	Creates the device monitor, restricts it to the "net", "ieee80211" and
	"rfkill" subsystems, attaches it to the daemon's event loop and starts
	delivering events to nw_devmon_handle_uevent.

	Returns 0 on success and 1 as soon as any step fails.
*/
static int nw_start_device_monitor(nw_daemon* daemon) {
	int r;

	static const char* const subsystems[] = {
		"net",
		"ieee80211",
		"rfkill",
		NULL,
	};

	// Create a new connection to monitor any devices
	r = sd_device_monitor_new(&daemon->devmon);
	if (r < 0) {
		// The error is carried in the return value, so %m (errno) must not be used
		ERROR("Could not inititalize the device monitor: %s\n", strerror(-r));
		return 1;
	}

	// Increase the receive buffer
	r = sd_device_monitor_set_receive_buffer_size(daemon->devmon, RCVBUF_SIZE);
	if (r < 0) {
		ERROR("Could not increase buffer size for the device monitor: %s\n",
			strerror(-r));
		return 1;
	}

	// Filter for events for all relevant subsystems
	for (const char* const* subsystem = subsystems; *subsystem; subsystem++) {
		r = sd_device_monitor_filter_add_match_subsystem_devtype(
			daemon->devmon, *subsystem, NULL);

		// Only negative values are errors; zero is a successful result, too
		if (r < 0) {
			ERROR("Could not add device monitor for the %s subsystem: %s\n",
				*subsystem, strerror(-r));
			return 1;
		}
	}

	// Attach the device monitor to the event loop
	r = sd_device_monitor_attach_event(daemon->devmon, daemon->loop);
	if (r < 0) {
		ERROR("Could not attach the device monitor to the event loop: %s\n",
			strerror(-r));
		return 1;
	}

	// Start processing events...
	r = sd_device_monitor_start(daemon->devmon, nw_devmon_handle_uevent, daemon);
	if (r < 0) {
		ERROR("Could not start the device monitor: %s\n", strerror(-r));
		return 1;
	}

	return 0;
}

/*
	Opens the netlink connection, enlarges its receive buffer, attaches it
	to the daemon's event loop and registers nw_link_process for
	RTM_NEWLINK and RTM_DELLINK messages.

	The fd parameter is currently not used.

	Returns 0 on success and 1 as soon as any step fails.
*/
static int nw_daemon_connect_rtnl(nw_daemon* daemon, int fd) {
	int r;

	(void)fd;

	// Connect to Netlink
	r = sd_netlink_open(&daemon->rtnl);
	if (r < 0) {
		// The error is carried in the return value, so %m (errno) must not be used
		ERROR("Could not connect to the kernel's netlink interface: %s\n",
			strerror(-r));
		return 1;
	}

	// Increase the receive buffer
	r = sd_netlink_increase_rxbuf(daemon->rtnl, RCVBUF_SIZE);
	if (r < 0) {
		ERROR("Could not increase receive buffer for the netlink socket: %s\n",
			strerror(-r));
		return 1;
	}

	// Connect it to the event loop
	r = sd_netlink_attach_event(daemon->rtnl, daemon->loop, 0);
	if (r < 0) {
		ERROR("Could not connect the netlink socket to the event loop: %s\n",
			strerror(-r));
		return 1;
	}

	// Register callback for new interfaces
	r = sd_netlink_add_match(daemon->rtnl, NULL, RTM_NEWLINK, nw_link_process, NULL,
			daemon, "networkd-RTM_NEWLINK");
	if (r < 0) {
		ERROR("Could not register RTM_NEWLINK: %s\n", strerror(-r));
		return 1;
	}

	// Register callback for deleted interfaces
	r = sd_netlink_add_match(daemon->rtnl, NULL, RTM_DELLINK, nw_link_process, NULL,
			daemon, "networkd-RTM_DELLINK");
	if (r < 0) {
		ERROR("Could not register RTM_DELLINK: %s\n", strerror(-r));
		return 1;
	}

	return 0;
}

// Creates the links container and, if that worked, enumerates all links.
// Returns the first non-zero result, or whatever nw_links_enumerate() returns.
static int nw_daemon_enumerate_links(nw_daemon* daemon) {
	const int r = nw_links_create(&daemon->links, daemon);

	return r ? r : nw_links_enumerate(daemon->links);
}

// Creates the ports container and, if that worked, enumerates all ports.
// Returns the first non-zero result, or whatever nw_ports_enumerate() returns.
static int nw_daemon_enumerate_ports(nw_daemon* daemon) {
	const int r = nw_ports_create(&daemon->ports, daemon);

	return r ? r : nw_ports_enumerate(daemon->ports);
}

// Creates the zones container and, if that worked, enumerates all zones.
// Returns the first non-zero result, or whatever nw_zones_enumerate() returns.
static int nw_daemon_enumerate_zones(nw_daemon* daemon) {
	const int r = nw_zones_create(&daemon->zones, daemon);

	return r ? r : nw_zones_enumerate(daemon->zones);
}

/*
	Enumerates links, ports and zones, in that order. The first step that
	returns non-zero stops the sequence and its result is returned.
*/
static int nw_daemon_enumerate(nw_daemon* daemon) {
	// Links
	int r = nw_daemon_enumerate_links(daemon);

	// Ports
	if (r == 0)
		r = nw_daemon_enumerate_ports(daemon);

	// Zones
	if (r == 0)
		r = nw_daemon_enumerate_zones(daemon);

	return r;
}

/*
	Deferred event callback: reconfigures all zones and then all ports.
	data is the nw_daemon that was passed in when the task was scheduled.

	Returns the first non-zero result, or 0 if both steps succeeded.
*/
static int __nw_daemon_reconfigure(sd_event_source* s, void* data) {
	nw_daemon* self = data;
	int r;

	DEBUG("Reconfiguring...\n");

	// Zones come first; ports are only touched if the zones succeeded
	r = nw_zones_reconfigure(self->zones);
	if (r == 0)
		r = nw_ports_reconfigure(self->ports);

	return r;
}

/*
	Schedules __nw_daemon_reconfigure as a deferred task on the daemon's
	event loop.

	Returns 0 on success, or the negative errno-style code returned by
	sd_event_add_defer() on failure.
*/
static int nw_daemon_reconfigure(nw_daemon* daemon) {
	int r;

	r = sd_event_add_defer(daemon->loop, NULL, __nw_daemon_reconfigure, daemon);

	// Only negative values signal failure; the error is carried in the
	// return value, so it is logged from there instead of from errno
	if (r < 0) {
		ERROR("Could not schedule re-configuration task: %s\n", strerror(-r));
		return r;
	}

	return 0;
}

/*
	Registers nw_stats_collector as a timer on the daemon's event loop,
	enables it permanently and stores a reference to the event source in
	daemon->stats_collector_event.

	Returns 0 on success, or a negative errno-style code on failure.
*/
static int nw_daemon_starts_stats_collector(nw_daemon* daemon) {
	sd_event_source* s = NULL;
	int r;

	// Register the stats collector main function
	r = sd_event_add_time_relative(daemon->loop, &s, CLOCK_MONOTONIC, 0, 0,
			nw_stats_collector, daemon);
	if (r < 0) {
		ERROR("Could not start the stats collector: %s\n", strerror(-r));
		goto ERROR;
	}

	// Keep calling the stats collector for forever
	r = sd_event_source_set_enabled(s, SD_EVENT_ON);
	if (r < 0) {
		ERROR("Could not enable the stats collector: %s\n", strerror(-r));
		goto ERROR;
	}

	// Keep a reference to the event source
	daemon->stats_collector_event = sd_event_source_ref(s);

	// The caller treats any non-zero value as failure, so normalise success
	r = 0;

ERROR:
	// Drop the local reference; on success the daemon holds its own
	if (s)
		sd_event_source_unref(s);

	return r;
}

/*
	Brings the daemon up step by step. The order of the steps is kept
	exactly as listed; the first step that returns non-zero aborts the
	setup and its result is returned.
*/
static int nw_daemon_setup(nw_daemon* daemon) {
	int r;

	// Read the configuration
	if ((r = nw_daemon_load_config(daemon)))
		return r;

	// Setup the event loop
	if ((r = nw_daemon_setup_loop(daemon)))
		return r;

	// Connect to the kernel's netlink interface
	if ((r = nw_daemon_connect_rtnl(daemon, 0)))
		return r;

	// Connect to the system bus
	if ((r = nw_bus_connect(&daemon->bus, daemon->loop, daemon)))
		return r;

	// Connect to udev
	if ((r = nw_start_device_monitor(daemon)))
		return r;

	// Enumerate everything we need to know
	if ((r = nw_daemon_enumerate(daemon)))
		return r;

	// (Re-)configure everything
	if ((r = nw_daemon_reconfigure(daemon)))
		return r;

	// Start the stats collector
	return nw_daemon_starts_stats_collector(daemon);
}

/*
	Allocates a new daemon, parses the command line and runs the setup.

	On success, the new object (holding one reference) is stored in *daemon
	and 0 is returned. Returns 1 if the allocation fails; otherwise the
	non-zero result of the failing step after the object has been released.
*/
int nw_daemon_create(nw_daemon** daemon, int argc, char* argv[]) {
	nw_daemon* self = calloc(1, sizeof(*self));
	if (self == NULL)
		return 1;

	// Initialize reference counter
	self->nrefs = 1;

	// Parse command line arguments, then setup the daemon
	int r = nw_daemon_parse_argv(self, argc, argv);
	if (r == 0)
		r = nw_daemon_setup(self);

	// Something went wrong: drop the half-initialized object
	if (r != 0) {
		nw_daemon_unref(self);
		return r;
	}

	// Set the reference
	*daemon = self;

	return 0;
}

/*
	Releases the ports, zones, links and the configuration.

	This function is called from nw_daemon_run() and again from
	nw_daemon_free(). Every pointer is therefore reset to NULL after it
	has been released, so that a second call does not unref the same
	object twice.
*/
static void nw_daemon_cleanup(nw_daemon* daemon) {
	if (daemon->ports) {
		nw_ports_unref(daemon->ports);
		daemon->ports = NULL;
	}

	if (daemon->zones) {
		nw_zones_unref(daemon->zones);
		daemon->zones = NULL;
	}

	if (daemon->links) {
		nw_links_unref(daemon->links);
		daemon->links = NULL;
	}

	if (daemon->config) {
		nw_config_unref(daemon->config);
		daemon->config = NULL;
	}
}

/*
	Releases everything the daemon holds and frees the object itself.

	The device monitor and the netlink connection are released as well;
	both are created during setup and stored in the daemon. They are
	dropped before the event loop that they have been attached to.
*/
static void nw_daemon_free(nw_daemon* daemon) {
	// Cleanup common objects
	nw_daemon_cleanup(daemon);

	if (daemon->configd)
		nw_configd_unref(daemon->configd);
	if (daemon->stats_collector_event)
		sd_event_source_unref(daemon->stats_collector_event);
	if (daemon->devmon)
		sd_device_monitor_unref(daemon->devmon);
	if (daemon->rtnl)
		sd_netlink_unref(daemon->rtnl);
	if (daemon->bus)
		sd_bus_unref(daemon->bus);
	if (daemon->loop)
		sd_event_unref(daemon->loop);

	free(daemon);
}

// Takes an additional reference on the daemon and returns the same pointer.
nw_daemon* nw_daemon_ref(nw_daemon* daemon) {
	daemon->nrefs += 1;

	return daemon;
}

/*
	Drops one reference. Once the counter is no longer positive, the
	daemon is freed and NULL is returned; otherwise the daemon is returned.
*/
nw_daemon* nw_daemon_unref(nw_daemon* daemon) {
	daemon->nrefs -= 1;

	// That was the last reference
	if (daemon->nrefs <= 0) {
		nw_daemon_free(daemon);
		return NULL;
	}

	return daemon;
}

/*
	Runs the main loop of the daemon until the event loop exits, then
	saves the configuration and releases the common objects.

	Returns 0 on a clean shutdown. Returns 1 if the event loop failed or
	the configuration could not be saved; in that case the negated result
	is reported as ERRNO= through sd_notifyf().
*/
int nw_daemon_run(nw_daemon* daemon) {
	int failed = 0;
	int r;

	// We are now ready to process any requests
	sd_notify(0, "READY=1\n" "STATUS=Processing requests...");

	// Launch the event loop
	r = sd_event_loop(daemon->loop);
	if (r < 0) {
		ERROR("Could not run the event loop: %s\n", strerror(-r));
		failed = 1;
	} else {
		// Let systemd know that we are shutting down
		sd_notify(0, "STOPPING=1\n" "STATUS=Shutting down...");

		// Save the configuration
		r = nw_daemon_save(daemon);
		if (r)
			failed = 1;
	}

	if (failed)
		sd_notifyf(0, "ERRNO=%i", -r);

	// Cleanup everything
	nw_daemon_cleanup(daemon);

	return failed;
}

// Reloads the daemon. Nothing is reloaded yet: the request is only logged
// and 0 is returned.
int nw_daemon_reload(nw_daemon* daemon) {
	DEBUG("Reloading daemon...\n");

	// XXX TODO

	return 0;
}

/*
	Saves the configuration to disk: the ports first, then the zones.
	Writing the settings file itself is currently compiled out.

	Returns the first non-zero result, or 0 if everything was saved.
*/
int nw_daemon_save(nw_daemon* daemon) {
	int r;

	DEBUG("Saving configuration...\n");

#if 0
	// Save settings
	r = nw_config_write(daemon->config, f);
	if (r)
		return r;
#endif

	// Save ports; zones are only saved if that succeeded
	r = nw_ports_save(daemon->ports);
	if (r == 0)
		r = nw_zones_save(daemon->zones);

	return r;
}

/*
	Returns the daemon's configuration directory, or NULL if none is open.
	If path is given, the result of nw_configd_descend() for that path is
	returned; otherwise the result of nw_configd_ref() on the directory.
*/
nw_configd* nw_daemon_configd(nw_daemon* daemon, const char* path) {
	nw_configd* root = daemon->configd;

	if (root == NULL)
		return NULL;

	return path ? nw_configd_descend(root, path) : nw_configd_ref(root);
}

/*
	Bus
*/
// Returns the daemon's DBus connection as stored; no reference is taken here.
sd_bus* nw_daemon_get_bus(nw_daemon* daemon) {
	sd_bus* connection = daemon->bus;

	return connection;
}

/*
	Netlink
*/
// Returns the daemon's netlink connection as stored; no reference is taken here.
sd_netlink* nw_daemon_get_rtnl(nw_daemon* daemon) {
	sd_netlink* connection = daemon->rtnl;

	return connection;
}

/*
	Links
*/
// Passes the links container through nw_links_ref() and returns the result.
nw_links* nw_daemon_links(nw_daemon* daemon) {
	nw_links* container = daemon->links;

	return nw_links_ref(container);
}

// Drops link from the links container. Does nothing if there is no container.
void nw_daemon_drop_link(nw_daemon* daemon, nw_link* link) {
	if (daemon->links != NULL)
		nw_links_drop_link(daemon->links, link);
}

// Looks up a link by its interface index. Returns NULL if there is no
// links container; otherwise the result of nw_links_get_by_ifindex().
nw_link* nw_daemon_get_link_by_ifindex(nw_daemon* daemon, int ifindex) {
	nw_links* container = daemon->links;

	return container ? nw_links_get_by_ifindex(container, ifindex) : NULL;
}

// Looks up a link by its name. Returns NULL if there is no links
// container; otherwise the result of nw_links_get_by_name().
nw_link* nw_daemon_get_link_by_name(nw_daemon* daemon, const char* name) {
	nw_links* container = daemon->links;

	return container ? nw_links_get_by_name(container, name) : NULL;
}

/*
	Ports
*/
// Passes the ports container through nw_ports_ref() and returns the result.
nw_ports* nw_daemon_ports(nw_daemon* daemon) {
	nw_ports* container = daemon->ports;

	return nw_ports_ref(container);
}

// Hands callback and data to nw_ports_walk() and returns its result.
// Returns 0 without calling anything if there is no ports container.
int nw_daemon_ports_walk(nw_daemon* daemon, nw_ports_walk_callback callback, void* data) {
	nw_ports* container = daemon->ports;

	return container ? nw_ports_walk(container, callback, data) : 0;
}

// Looks up a port by its name. Returns NULL if there is no ports
// container; otherwise the result of nw_ports_get_by_name().
nw_port* nw_daemon_get_port_by_name(nw_daemon* daemon, const char* name) {
	nw_ports* container = daemon->ports;

	return container ? nw_ports_get_by_name(container, name) : NULL;
}

/*
	Zones
*/

// Passes the zones container through nw_zones_ref() and returns the result.
nw_zones* nw_daemon_zones(nw_daemon* daemon) {
	nw_zones* container = daemon->zones;

	return nw_zones_ref(container);
}

// Hands callback and data to nw_zones_walk() and returns its result.
// Returns 0 without calling anything if there is no zones container.
int nw_daemon_zones_walk(nw_daemon* daemon, nw_zones_walk_callback callback, void* data) {
	nw_zones* container = daemon->zones;

	return container ? nw_zones_walk(container, callback, data) : 0;
}

// Looks up a zone by its name. Returns NULL if there is no zones
// container; otherwise the result of nw_zones_get_by_name().
nw_zone* nw_daemon_get_zone_by_name(nw_daemon* daemon, const char* name) {
	nw_zones* container = daemon->zones;

	return container ? nw_zones_get_by_name(container, name) : NULL;
}
