# Exit with various exit codes
import sys
# Path manipulation
import os
# Manipulate style and content of logs
import logging
from rich.logging import RichHandler
# Use a config file
import configparser
# Verify IP address family
import ipaddress
# Resolve host names
import dns.resolver
# Validate if string is fqdn
import validators
# Build XML structure
import lxml.etree
import lxml.builder


# Exit codes
# 1: Config file missing, unreadable or invalid: no sections found in it
# 2: Config file invalid: no section provides all options listed in CONST.CFG_MANDATORY


class CONST:
    """Plain namespace of module-wide constants; attributes are read via ``CONST.<NAME>``."""
    __slots__ = ()
    # Record layout handed to logging.basicConfig()
    LOG_FORMAT = "%(message)s"
    # The config file is looked up next to this script
    CFG_THIS_FILE_DIRNAME = os.path.dirname(__file__)
    CFG_DEFAULT_FILENAME = "config.ini"
    CFG_DEFAULT_ABS_PATH = os.path.join(CFG_THIS_FILE_DIRNAME, CFG_DEFAULT_FILENAME)
    # Built-in defaults for options the config file may omit. The "%(name)s" placeholders are
    # configparser interpolation references to other options.
    CFG_KNOWN_DEFAULTS = [
        {"key": "self_name", "value": "update-firewall-source"},
        {"key": "tmp_base_dir", "value": os.path.join(CFG_THIS_FILE_DIRNAME, "data/tmp/%(self_name)s")},
        {"key": "state_base_dir", "value": os.path.join(CFG_THIS_FILE_DIRNAME, "data/var/lib/%(self_name)s")},
        {"key": "state_files_dir", "value": "%(state_base_dir)s/state"},
        {"key": "state_file_retention", "value": "50"},
        {"key": "state_file_name_prefix", "value": "state-"},
        {"key": "state_file_name_suffix", "value": ".log"},
        {"key": "update_firewall_source_some_option", "value": "http://localhost:8000/api/query"},
        {"key": "another_option", "value": "first"},
    ]
    # Options recognised in every non-default section. Unknown options are ignored silently;
    # entries flagged 'is_mandatory' must be present or the section is rejected.
    CFG_KNOWN_SECTION = [
        {"key": "addr", "is_mandatory": True},
    ]
    # Just the names of the mandatory options, in declaration order
    CFG_MANDATORY = [entry["key"] for entry in CFG_KNOWN_SECTION if entry["is_mandatory"]]


# Root logger: NOTSET lets records from every module through to the Rich handler.
logging.basicConfig(
    handlers=[RichHandler(rich_tracebacks=True)],
    format=CONST.LOG_FORMAT,
    datefmt="[%X]",
    level=logging.NOTSET,
)
# Logger used by this script's own code; it emits DEBUG and above.
log = logging.getLogger("rich")
log.setLevel("DEBUG")


# Use this version of class ConfigParser to log.debug contents of our config file. When parsing sections other than
# 'default' we don't want to reprint defaults over and over again. This custom class achieves that.
class ConfigParser(
        configparser.ConfigParser):
    """ConfigParser whose options() can leave out options inherited from the default section.

    Taken from https://stackoverflow.com/a/12600066.
    """

    def options(self, section, no_defaults=False, **kwargs):
        """Return the option names of *section*.

        Args:
            section: Name of the section to inspect.
            no_defaults: If true, return only the options written in *section* itself,
                leaving out those inherited from the default section.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A new list of option names.

        Raises:
            configparser.NoSectionError: If *section* does not exist. With no_defaults,
                the default section itself also counts as missing, as in the base class.
        """
        if not no_defaults:
            return super().options(section)
        try:
            return list(self._sections[section])
        except KeyError:
            # Mirror the base class: keep the internal KeyError out of the traceback.
            raise configparser.NoSectionError(section) from None


# arg_allow_list = ["77.13.129.237", "2a0b:7080:20::1:f485", "home.seneve.de", "208.87.98.188", "outlook.com",
#                   "uberspace.de"]


# Filled by validate_default_section() with one {key: value} dict per default-section option.
ini_defaults = []
# Built-in defaults, flattened from CONST into the mapping ConfigParser expects.
internal_defaults = {entry["key"]: entry["value"] for entry in CONST.CFG_KNOWN_DEFAULTS}
config = ConfigParser(
    defaults=internal_defaults,
    # Enables config.getlist(): split a value on commas and strip whitespace around each item.
    converters={"list": lambda raw: [item.strip() for item in raw.split(",")]},
)
config.read(CONST.CFG_DEFAULT_ABS_PATH)


def print_section_header(
        header: str) -> str:
    """Build (not print) the log line announcing that config section *header* is being loaded."""
    section_label = f"[{header}]"
    return f"Loading config section '{section_label}' ..."


def validate_default_section(
        config_obj: configparser.ConfigParser) -> None:
    """Ensure *config_obj* has sections, then record and log its default-section options.

    Every option of the default section is appended to the module-level ``ini_defaults``
    list as a one-item ``{key: value}`` dict. The value is the one returned by the section
    proxy, i.e. with interpolation applied. is_default() and is_same_as_default() consult
    that list later.

    Exits the process with status 1 if *config_obj* has no sections.
    """
    log.debug(f"Loading config from file '{CONST.CFG_DEFAULT_ABS_PATH}' ...")
    if not config_obj.sections():
        # Fatal: log at ERROR so the exit reason survives a raised log level.
        log.error(f"No config sections found in '{CONST.CFG_DEFAULT_ABS_PATH}'. Exiting 1 ...")
        sys.exit(1)

    default_section = config_obj.default_section
    # Query the parser we were given, not the module-level ``config``.
    defaults = config_obj.defaults()
    if not defaults:
        log.debug("No defaults defined")
        return

    log.debug(f"Symbol legend:\n"
              f"* Global default from section '[{default_section}]'\n"
              f"~ Local option, doesn't exist in '[{default_section}]'\n"
              f"+ Local override of a value from '[{default_section}]'\n"
              f"= Local override, same value as in '[{default_section}]'")
    log.debug(print_section_header(default_section))
    default_proxy = config_obj[default_section]
    for key in defaults:
        value = default_proxy[key]
        ini_defaults.append({key: value})
        log.debug(f"* {key} = {value}")


def config_has_valid_section(
        config_obj: configparser.ConfigParser) -> bool:
    """Return True if at least one section provides every option in CONST.CFG_MANDATORY.

    Options inherited from the default section count here, because the plain
    ``options()`` call includes them.

    NOTE(review): validate_config_sections() checks only options set locally in each
    section. A mandatory option defined solely in the default section therefore passes
    this check, yet every section is skipped later. Confirm which rule is intended.
    """
    mandatory = set(CONST.CFG_MANDATORY)
    return any(mandatory.issubset(config_obj.options(section_name))
               for section_name in config_obj.sections())


def is_default(
        config_key: str) -> bool:
    """Tell whether *config_key* was recorded from the default section into ``ini_defaults``."""
    for recorded in ini_defaults:
        if config_key in recorded:
            return True
    return False


def is_same_as_default(
        config_kv_pair: dict) -> bool:
    """Tell whether the exact ``{key: value}`` pair was recorded in ``ini_defaults``."""
    return any(recorded == config_kv_pair for recorded in ini_defaults)


def validate_config_sections(
        config_obj: "ConfigParser") -> None:
    """Log every non-default section; remove those lacking a mandatory option.

    Only options written in the section itself count towards CONST.CFG_MANDATORY.
    That is why *config_obj* must be this module's ConfigParser subclass, which
    supports ``options(..., no_defaults=True)``.

    Each kept option is logged with a prefix:
    ``~`` option not present in the default section,
    ``+`` overrides a default with a different value,
    ``=`` overrides a default with the same value.
    """
    mandatory = set(CONST.CFG_MANDATORY)
    # sections() returns a new list, so removing sections inside the loop is safe.
    for this_section in config_obj.sections():
        log.debug(print_section_header(this_section))
        local_keys = config_obj.options(this_section, no_defaults=True)
        if not mandatory.issubset(local_keys):
            # A configured section is being dropped: make that visible above DEBUG.
            log.warning(f"Config section '[{this_section}]' does not have all mandatory options "
                        f"{CONST.CFG_MANDATORY} set, skipping section ...")
            config_obj.remove_section(this_section)
            continue
        section_proxy = config_obj[this_section]
        for key in local_keys:
            value = section_proxy[key]
            if not is_default(key):
                kv_prefix = "~"
            elif is_same_as_default({key: value}):
                kv_prefix = "="
            else:
                kv_prefix = "+"
            log.debug(f"{kv_prefix} {key} = {value}")


def gen_fw_rule_xml(ip_addresses: dict[str, list],
                    *,
                    source_target: str = "DROP",
                    final_target: str = "",
                    table: str = "filter",
                    chain: str = "DOCKER-USER") -> lxml.etree._Element:
    """Build a ``<direct>`` XML tree (presumably firewalld direct.xml format) with per-source rules.

    For each family, "ipv4" then "ipv6", the tree declares *chain* in *table*. It then adds
    one ``-s <addr> -j <source_target>`` rule per address, with priorities 0, 1, ... in
    list order. Element order is: both chain declarations, all IPv4 rules, all IPv6
    rules, then the optional catch-all rules.

    The original version always appended ``-s -j DROP`` per family. That rule is
    malformed, since ``-s`` has no address. A catch-all is now emitted only when
    *final_target* is given, as ``-j <final_target>``, right after the family's last
    per-source rule.

    NOTE(review): resolve_addresses() calls these addresses an allow list. That suggests
    ``source_target="RETURN"`` (or "ACCEPT") together with ``final_target="DROP"``.
    Confirm the intended policy before deploying generated rules.

    Args:
        ip_addresses: Mapping with keys "ipv4" and/or "ipv6", each a list of address
            strings (as returned by resolve_addresses()). A missing key counts as empty.
        source_target: iptables target for the per-source rules.
        final_target: iptables target of an optional catch-all rule per family;
            empty string (default) emits none.
        table: iptables table used in every element.
        chain: iptables chain used in every element.

    Returns:
        The root ``<direct>`` element.
    """
    families = ("ipv4", "ipv6")
    data = lxml.builder.ElementMaker()

    chains = [data.chain(ipv=family, table=table, chain=chain) for family in families]
    rules = [
        data.rule(f"-s {addr} -j {source_target}",
                  ipv=family, table=table, chain=chain, priority=str(priority))
        for family in families
        for priority, addr in enumerate(ip_addresses.get(family, []))
    ]
    if final_target:
        # Priority equal to the address count places the catch-all after the last per-source rule.
        rules.extend(
            data.rule(f"-j {final_target}",
                      ipv=family, table=table, chain=chain,
                      priority=str(len(ip_addresses.get(family, []))))
            for family in families
        )

    return data.direct(*chains, *rules)


def resolve_domain(domain: str) -> list[str]:
    """Return the addresses from *domain*'s A records followed by those from its AAAA records.

    A record type without an answer (NoAnswer) is skipped with a debug message.
    If the domain does not exist (NXDOMAIN), a warning is logged and resolution stops.
    The addresses found so far (normally none) are returned, so one misspelled config
    entry no longer aborts the whole run. Other DNS errors (e.g. timeouts) propagate
    to the caller.
    """
    log.debug(f"Resolving DNS A and AAAA records for '{domain}' ...")
    dns_records = []
    for rdtype, record_name in ((dns.rdatatype.A, "an A"), (dns.rdatatype.AAAA, "a AAAA")):
        try:
            answer = dns.resolver.resolve(domain, rdtype=rdtype)
        except dns.resolver.NoAnswer:
            log.debug(f"DNS didn't return {record_name} record for '{domain}', ignoring ...")
            continue
        except dns.resolver.NXDOMAIN:
            log.warning(f"Domain '{domain}' does not exist (NXDOMAIN), ignoring ...")
            break
        dns_records.extend(dns_record.address for dns_record in answer)
    log.debug(f"Found records: {dns_records}")
    return dns_records


def resolve_addresses(allow_list_mixed: list[str]) -> dict[str, list]:
    """Turn a mixed list of domains and IP addresses into unique IPv4 and IPv6 addresses.

    Entries accepted by ``validators.domain`` are resolved via resolve_domain(); every
    other entry is treated as a literal address. Addresses are normalized through
    ``ipaddress`` (e.g. IPv6 in compressed form). Entries that are neither a valid IPv4
    nor a valid IPv6 address are logged as a warning and skipped. Each normalized
    address is kept only once, in first-seen order.

    Returns:
        ``{"ipv4": [...], "ipv6": [...]}`` with address strings.
    """
    allow_sources = {"ipv4": [], "ipv6": []}
    family_labels = {"ipv4": "IPv4", "ipv6": "IPv6"}
    seen = set()

    # Step 1: replace each domain with the addresses it resolves to.
    allow_list_ip_only = []
    for allow_source in allow_list_mixed:
        if validators.domain(allow_source):
            log.debug(f"'{allow_source}' is a domain.")
            allow_list_ip_only.extend(resolve_domain(allow_source))
        else:
            allow_list_ip_only.append(allow_source)

    # Step 2: validate, normalize and sort every address into its family.
    for allow_source in allow_list_ip_only:
        try:
            family, addr = "ipv4", str(ipaddress.IPv4Address(allow_source))
        except ipaddress.AddressValueError:
            log.debug(f"Address '{allow_source}' is not a valid IPv4 address. Trying to match against IPv6 ...")
            try:
                family, addr = "ipv6", str(ipaddress.IPv6Address(allow_source))
            except ipaddress.AddressValueError:
                log.warning(f"Address '{allow_source}' is not a valid IPv6 address either. Ignoring ...")
                continue
        if addr in seen:
            # The same address can arrive twice, e.g. listed literally and via a domain.
            # A second copy would only yield a duplicate firewall rule.
            log.debug(f"Address '{allow_source}' was already added, skipping duplicate ...")
            continue
        seen.add(addr)
        log.debug(f"Adding {family_labels[family]} address '{allow_source}' ...")
        allow_sources[family].append(addr)

    return allow_sources


if __name__ == '__main__':
    # Exits with status 1 from inside validate_default_section() if the config has no sections.
    validate_default_section(config)
    if not config_has_valid_section(config):
        # Fatal: log at ERROR so the exit reason survives a raised log level.
        log.error(f"No valid config section found. A valid config section has at least the mandatory options "
                  f"{CONST.CFG_MANDATORY} set. Exiting 2 ...")
        sys.exit(2)
    # Logs each section and removes the ones lacking locally set mandatory options.
    validate_config_sections(config)

    log.debug("Iterating over config sections ...")
    for section in config.sections():
        log.debug(f"Processing section '[{section}]' ...")
        log.debug(config.getlist(section, "addr"))

# arg_allow_sources = resolve_addresses(arg_allow_list)
# gen_fw_rule_xml(arg_allow_sources)