# Exit with various exit codes
import sys
# Path manipulation
import os
# Manipulate style and content of logs
import logging
from rich.logging import RichHandler
# Use a config file
import configparser
# Verify IP address family
import ipaddress
# Resolve host names
import dns.resolver
# Validate if string is fqdn
import validators
# Build XML structure
import lxml.etree
import lxml.builder
# Correctly generate plurals, singular nouns etc.
import inflect
# Restart firewalld systemd service unit
import dbus
# Find physical network interface via 'find' command
import subprocess
# Diff new and existing firewalld direct rules XML structure
import difflib


# Exit codes
# 1 : Config file invalid, it has no sections
# 2 : Config file invalid, sections must define at least CONST.CFG_MANDATORY
# 3 : No physical network device found at "/sys/class/net"
# 4 : Linux find command exited non-zero trying to find a physical network device at "/sys/class/net"
# 5 : Unable to open firewalld direct rules file for reading
# 6 : Kernel sysfs export for network devices at "/sys/class/net" doesn't exist
# 7 : An option that must have a non-null value is either unset or null
# 8 : Exception while adding a chain XML element to firewalld direct rules
# 9 : Unable to open firewalld direct rules file for updating
# 10: Unable to restart systemd firewalld.service unit


class CONST(object):
    """Namespace of script-wide constants, read as class attributes (e.g. ``CONST.LOG_FORMAT``)."""
    # No instance attributes; the class is only used as a namespace
    __slots__ = ()
    # Log records are formatted as the bare message
    LOG_FORMAT = "%(message)s"
    # How to find the config file: 'config.ini' in the directory that contains this script
    CFG_THIS_FILE_DIRNAME = os.path.dirname(__file__)
    CFG_DEFAULT_FILENAME = "config.ini"
    CFG_DEFAULT_ABS_PATH = os.path.join(CFG_THIS_FILE_DIRNAME, CFG_DEFAULT_FILENAME)
    # Options you don't have to set; 'value' is the internal default used when the config file leaves them out.
    # Every entry must define both 'is_global' and 'empty_ok': the module-level lookups index these keys directly, so
    # leaving one off raises KeyError rather than defaulting to False.
    # 'is_global' equal to True marks an option that only makes sense globally for the entire script; it is not
    # overridable in a '[section]'. 'empty_ok' equal to True means the option can safely be unset or set to an empty
    # string. An example config.ini file may give a sane example value for such an option; removing that value still
    # results in a valid file.
    CFG_KNOWN_DEFAULTS = [
        {"key": "target", "value": "ACCEPT", "is_global": False, "empty_ok": False},
        {"key": "addr", "value": "", "is_global": False, "empty_ok": True},
        {"key": "ports", "value": "80, 443", "is_global": False, "empty_ok": True},
        {"key": "proto", "value": "tcp", "is_global": False, "empty_ok": True},
        {"key": "state", "value": "NEW", "is_global": False, "empty_ok": True},
        {"key": "do_ipv6", "value": "false", "is_global": False, "empty_ok": False},
        {"key": "firewalld_direct_abs", "value": "/etc/firewalld/direct.xml", "is_global": True, "empty_ok": False},
        {"key": "restart_firewalld_after_change", "value": "true", "is_global": True, "empty_ok": False}
    ]
    # Options known in sections other than '[DEFAULT]'. Per CFG_KNOWN_DEFAULTS above, most '[DEFAULT]' options are
    # accepted in a '[section]' anyway, because they are defaults and therefore overridable. The exception is options
    # where "is_global" equals True: they can't be overridden in a '[section]', and any attempt to do so is ignored.
    # The main purpose of this list is to name options that have no default value but can - if set - influence how a
    # '[section]' behaves. Repeating a '[DEFAULT]' option here makes no sense. 'is_mandatory' equal to True means the
    # option must be given in the '[section]' itself.
    # NOTE(review): an empty value only passes validation for options marked 'empty_ok' in CFG_KNOWN_DEFAULTS.
    # Options listed here never are, so an empty mandatory option currently fails validation.
    CFG_KNOWN_SECTION = [
        # {"key": "an_option", "is_mandatory": True},
        # {"key": "another_one", "is_mandatory": False}
    ]
    # Keys every '[section]' must set itself. Before Python 3.12 only the outermost iterable of a class-body
    # comprehension can see class attributes, which is why the condition only touches the loop variable.
    CFG_MANDATORY = [section_cfg["key"] for section_cfg in CFG_KNOWN_SECTION if section_cfg["is_mandatory"]]


logging.basicConfig(
    # Default for all modules is NOTSET so log everything
    level="NOTSET",
    format=CONST.LOG_FORMAT,
    datefmt="[%X]",
    handlers=[RichHandler(
        # Hide our own timestamps when started by systemd (detected via variables it sets for services), presumably
        # because the journal timestamps every line anyway
        show_time=not any(systemd_env_var in os.environ for systemd_env_var in ("SYSTEMD_EXEC_PID", "INVOCATION_ID")),
        rich_tracebacks=True,
        show_path=False,
        show_level=False
    )]
)
log = logging.getLogger("rich")
# Our own code logs with this level. UFS_LOGLEVEL takes a logging level name in any letter case ("debug", "INFO", ...).
# Unset or empty means INFO. An unknown name falls back to INFO with a warning instead of crashing at start-up.
_ufs_log_level = (os.environ.get("UFS_LOGLEVEL") or "INFO").upper()
try:
    log.setLevel(_ufs_log_level)
except ValueError:
    log.setLevel(logging.INFO)
    log.warning(f"Unknown log level UFS_LOGLEVEL={_ufs_log_level!r}, using INFO instead ...")

p = inflect.engine()


# Use this version of class ConfigParser to log.debug contents of our config file. When parsing sections other than
# 'default' we don't want to reprint defaults over and over again. This custom class achieves that.
class ConfigParser(configparser.ConfigParser):
    """ConfigParser whose options() can leave out the options inherited from the default section.

    Adapted from https://stackoverflow.com/a/12600066.
    """

    def options(self, section, no_defaults=False, **kwargs):
        """Return the option names of *section*.

        With ``no_defaults=True`` only the options written in *section* itself are returned. Defaults are not
        included, and ``configparser.NoSectionError`` is raised for an unknown section. Otherwise this defers to
        the stock implementation. Extra keyword arguments are accepted and ignored.
        """
        if not no_defaults:
            return super().options(section)
        own_options = self._sections.get(section)
        if own_options is None:
            raise configparser.NoSectionError(section)
        return list(own_options)


# '[DEFAULT]' values as read from the config file, one {key: value} dict per option; filled by
# validate_default_section()
ini_defaults = []
# Built-in defaults and option classifications derived from CONST.CFG_KNOWN_DEFAULTS
internal_defaults = {default["key"]: default["value"] for default in CONST.CFG_KNOWN_DEFAULTS}
internal_globals = [default["key"] for default in CONST.CFG_KNOWN_DEFAULTS if default["is_global"]]
internal_empty_ok = [default["key"] for default in CONST.CFG_KNOWN_DEFAULTS if default["empty_ok"]]
# getlist() splits a comma-separated value and strips whitespace around every item. Empty items - from an empty
# value or stray commas as in "80, , 443" or "80, 443," - are dropped so they never turn into empty ports or
# addresses.
config = ConfigParser(defaults=internal_defaults,
                      converters={'list': lambda x: [i.strip() for i in x.split(',') if i.strip()]})
# read() silently skips a file it can not open; that shows up later as a config without any sections
config.read(CONST.CFG_DEFAULT_ABS_PATH)
# Fallback messages for exit codes, used by ose_handler() when no explicit text is passed
exit_code_desc = {
    5: "Unable to open firewalld direct rules file for reading",
    9: "Unable to open firewalld direct rules file for updating"
}


def print_section_header(
        header: str) -> str:
    """Build (not print) the log line announcing that config section *header* is being loaded."""
    template = "Loading config section '[{}]' ..."
    return template.format(header)


def validate_default_section(
        config_obj: configparser.ConfigParser) -> None:
    """Log the default section of *config_obj* and record its values in the global ``ini_defaults``.

    Also logs a legend for the one-character prefixes used when options are logged. Exits with code 1 when
    *config_obj* has no sections at all. This is also how a missing or unreadable config file shows up, because
    ``ConfigParser.read()`` skips files it can not open.
    """
    log.debug(f"Loading config from file '{CONST.CFG_DEFAULT_ABS_PATH}' ...")
    if not config_obj.sections():
        # Error level (not debug) so the reason for exiting is visible at the default INFO log level
        log.error(f"No config sections found in '{CONST.CFG_DEFAULT_ABS_PATH}'. Exiting 1 ...")
        sys.exit(1)
    # Inspect the object we were given, not the module-level 'config'
    if not config_obj.defaults():
        log.debug(f"No defaults defined")
        return

    log.debug(f"Symbol legend:\n"
              f"* Default from section '[{config_obj.default_section}]'\n"
              f": Global option from '[{config_obj.default_section}]', can not be overridden in local sections\n"
              f"~ Local option, doesn't exist in '[{config_obj.default_section}]'\n"
              f"+ Local override of a value from '[{config_obj.default_section}]'\n"
              f"= Local override, same value as in '[{config_obj.default_section}]'\n"
              f"# Local attempt at overriding a global, will be ignored")
    log.debug(print_section_header(config_obj.default_section))
    default_section = config_obj[config_obj.default_section]
    for default in config_obj.defaults():
        value = default_section[default]
        ini_defaults.append({default: value})
        prefix = ":" if default in internal_globals else "*"
        log.debug(f"{prefix} {default} = {value}")


def config_has_valid_section(
        config_obj: configparser.ConfigParser) -> bool:
    """Return True if at least one section sets every option in ``CONST.CFG_MANDATORY`` itself.

    Options inherited from the default section do not count. This matches ``validate_config_sections``, which
    drops sections lacking their own mandatory options, and the CONST note that mandatory options must be given in
    the '[section]'. Requires the script's ``ConfigParser`` subclass (``options(..., no_defaults=True)``).
    """
    mandatory = set(CONST.CFG_MANDATORY)
    return any(mandatory.issubset(config_obj.options(section_name, no_defaults=True))
               for section_name in config_obj.sections())


def is_default(
        config_key: str) -> bool:
    """Tell whether *config_key* was set in the config file's default section, as recorded in ``ini_defaults``."""
    for recorded_default in ini_defaults:
        if config_key in recorded_default:
            return True
    return False


def is_global(
        config_key: str) -> bool:
    """Tell whether *config_key* is a global-only option (listed in ``internal_globals``)."""
    return any(config_key == global_key for global_key in internal_globals)


def is_same_as_default(
        config_kv_pair: dict) -> bool:
    """Tell whether the ``{key: value}`` dict *config_kv_pair* equals one recorded in ``ini_defaults``."""
    return any(config_kv_pair == recorded_pair for recorded_pair in ini_defaults)


def we_have_unset_options(
        config_obj: configparser.ConfigParser(),
        section_name: str) -> list:
    """Collect the options of *section_name* that are empty although they must not be.

    Every option visible in the section, including inherited defaults, is checked. An empty value is only
    acceptable for options listed in ``internal_empty_ok``. Each offender is logged as a warning. Returns the
    offending option names, in option order.
    """
    offenders = []
    for option_name in config_obj.options(section_name):
        if config_obj.get(section_name, option_name):
            continue
        if option_name in internal_empty_ok:
            continue
        log.warning(f"In section '[{section_name}]' option '{option_name}' is empty, it mustn't be.")
        offenders.append(option_name)
    return offenders


def validate_config_sections(
        config_obj: configparser.ConfigParser) -> None:
    """Vet every non-default section of *config_obj*, pruning what must not be used.

    For each section:
    * exit with code 7 if an option that must not be empty is empty (see ``we_have_unset_options``),
    * remove the whole section if it does not itself set every option in ``CONST.CFG_MANDATORY``,
    * otherwise log every option the section sets itself, with a legend prefix, and remove options that try to
      override a global-only default.

    Requires the script's ``ConfigParser`` subclass, because it calls ``options(..., no_defaults=True)``.
    """
    # sections() and options() return new lists, so removing sections/options while looping over them is safe
    for this_section in config_obj.sections():
        log.debug(print_section_header(this_section))

        unset_options = we_have_unset_options(config_obj, this_section)
        if unset_options:
            log.error(f"""{p.plural("Option", len(unset_options))} {unset_options} """
                      f"""{p.plural("is", len(unset_options))} unset. """
                      f"""{p.singular_noun("They", len(unset_options))} """
                      f"must have a non-null value. "
                      f"""{p.plural("Default", len(unset_options))} {p.plural("is", len(unset_options))}:""")
            for unset_option in unset_options:
                # Section-only options have no built-in default; indexing would raise KeyError here instead of
                # letting us exit 7
                log.error(f"{unset_option} = {internal_defaults.get(unset_option, '(no default)')}")
            log.error(f"Exiting 7 ...")
            sys.exit(7)

        if not set(CONST.CFG_MANDATORY).issubset(config_obj.options(this_section, no_defaults=True)):
            log.warning(f"Config section '[{this_section}]' does not have all mandatory options "
                        f"{CONST.CFG_MANDATORY} set, skipping section ...")
            config_obj.remove_section(this_section)
            continue

        for key in config_obj.options(this_section, no_defaults=True):
            # Prefixes: ~ local-only, # ignored override of a global, + override of a default, = same as default
            kv_prefix = "~"
            remove_from_section = False
            if is_global(key):
                kv_prefix = "#"
                remove_from_section = True
            elif is_default(key):
                kv_prefix = "+"
                if is_same_as_default({key: config_obj[this_section][key]}):
                    kv_prefix = "="
            log.debug(f"{kv_prefix} {key} = {config_obj[this_section][key]}")
            if remove_from_section:
                config_obj.remove_option(this_section, key)


def has_child_elem(elem_name: str, attr_value: str) -> bool:
    """Tell whether the global XML tree has a child ``<elem_name ipv="attr_value" .../>``.

    Only children matched by ``arg_fw_rule_data.findall(elem_name)`` are examined. Each examined element is
    expected to carry an ``ipv`` attribute; a missing one raises ``KeyError``.
    """
    attr_name = "ipv"
    matching = (child for child in arg_fw_rule_data.findall(elem_name) if child.attrib[attr_name] == attr_value)
    # Compare against None explicitly: childless lxml elements are falsy
    if next(matching, None) is None:
        log.debug(f"""No XML element '<{elem_name} {attr_name}="{attr_value}" .../>'""")
        return False
    log.debug(f"""XML has element '<{elem_name} {attr_name}="{attr_value}" .../>'""")
    return True


def add_chain_elem(elem_name: str, addr_family: str) -> bool:
    """Append one ``<elem_name ipv=addr_family table="filter" chain=...>`` child per managed chain.

    Children are added to the global XML tree ``arg_fw_rule_data`` for the chains "FILTERS" and "DOCKER-USER".
    Exits with code 8 if lxml refuses to create an element. Returns True otherwise.
    """
    log.debug(f"""Adding new '<{elem_name} ipv="{addr_family}" .../>' elements ...""")
    for chain in ["FILTERS", "DOCKER-USER"]:
        try:
            lxml.etree.SubElement(arg_fw_rule_data, elem_name,
                                  ipv=f"{addr_family}",
                                  table="filter",
                                  chain=chain)
        # lxml reports invalid tag names / attribute values as ValueError, not LxmlError
        except (lxml.etree.LxmlError, ValueError) as le:
            log.error(f"""Failed to add XML '<{elem_name} ipv="{addr_family}" chain="{chain}" .../>'\n"""
                      f"Verbatim exception was:\n"
                      f"{le}\n"
                      f"Exiting 8 ...")
            sys.exit(8)

    return True


def rules_count(
        arg_ipv: str = "ipv4",
        arg_chain: str = "FILTERS") -> int:
    """Count the ``<rule>`` elements of the global XML tree ``arg_fw_rule_data`` that match a family and chain.

    arg_ipv: value the rule's ``ipv`` attribute must have, e.g. "ipv4". A falsy value (``None`` or ``""``)
        disables this filter; the debug line below likewise leaves such a filter out of its description.
    arg_chain: value the rule's ``chain`` attribute must have. A falsy value disables this filter.
    Returns the number of matching rules. Rules lacking an attribute never match a filter on it.
    """
    arg_rules_count = sum(
        1 for rule in arg_fw_rule_data.findall("rule")
        if (not arg_ipv or rule.get("ipv") == arg_ipv) and (not arg_chain or rule.get("chain") == arg_chain))

    log.debug(f"""Counted {arg_rules_count} {p.plural("rule", arg_rules_count)} matching """
              f"""{"ipv=" + arg_ipv + " " if arg_ipv else ""}"""
              f"""{"chain=" + arg_chain + " " if arg_chain else ""}""")
    return arg_rules_count


def add_rule_elem(
        address_family: str,
        prio: int,
        target: str,
        /, *,
        arg_section: str = None,
        arg_proto: str = None,
        arg_state: str = None,
        arg_ports: list = None,
        arg_address: str = None,
        arg_chain: str = "FILTERS",
        arg_in_interface: str = None) -> bool:
    """Append a ``<rule ipv=... table="filter" chain=... priority=...>`` element to the global XML tree.

    The element text holds ip(6)tables arguments in this order, each only when its argument is truthy:
    ``--in-interface``, ``--protocol``, ``--match state --state``, ``--match multiport --destination-ports``
    (ports joined with ","), ``--source``, then always ``--jump target``. Finally, if *arg_section* is given, a
    ``--match comment --comment "<section>"`` with the section name cut to 256 characters.

    Exits with code 8 if lxml refuses to create the element. Returns True otherwise.
    """
    rule_args = []
    if arg_in_interface:
        rule_args.append(f"--in-interface {arg_in_interface}")
    if arg_proto:
        rule_args.append(f"--protocol {arg_proto}")
    if arg_state:
        rule_args.append(f"--match state --state {arg_state}")
    if arg_ports:
        rule_args.append(f"--match multiport --destination-ports {','.join(arg_ports)}")
    if arg_address:
        rule_args.append(f"--source {arg_address}")
    rule_args.append(f"--jump {target}")
    if arg_section:
        rule_args.append(f'--match comment --comment "{arg_section[:256]}"')

    try:
        lxml.etree.SubElement(arg_fw_rule_data, "rule",
                              ipv=f"{address_family}",
                              table="filter",
                              chain=arg_chain,
                              priority=f"{prio}").text = " ".join(rule_args)
    # lxml reports invalid attribute values / text as ValueError, not LxmlError
    except (lxml.etree.LxmlError, ValueError) as le:
        log.error(f"""Failed to add XML '<rule ipv="{address_family}" chain="{arg_chain}" priority="{prio}" .../>'\n"""
                  f"Verbatim exception was:\n"
                  f"{le}\n"
                  f"Exiting 8 ...")
        sys.exit(8)
    return True


def get_phy_nics() -> list:
    """Return the names of the physical network interfaces listed in "/sys/class/net".

    An entry counts as physical when its symlink target does not contain "virtual" (``find ... -not -lname
    '*virtual*'``). Exits with code 6 if "/sys/class/net" is missing, 4 if ``find`` can not be run or exits
    non-zero, and 3 if it lists nothing.
    """
    phy_nics = []
    linux_sysfs_nics_abs = "/sys/class/net"
    find_phy_nics = ["find", linux_sysfs_nics_abs, "-mindepth", "1", "-maxdepth", "1", "-not", "-lname", "*virtual*"]

    if not os.path.isdir(linux_sysfs_nics_abs):
        log.error(f"Path {linux_sysfs_nics_abs!r} does not exist. This might not be a Linux-y operating system. "
                  f"Without that location we'll not be able to separate physical network interfaces from virtual ones. "
                  f"Exiting 6 ...")
        sys.exit(6)

    try:
        phy_nics_find = subprocess.run(find_phy_nics,
                                       stdout=subprocess.PIPE,
                                       stderr=subprocess.STDOUT,
                                       check=True,
                                       encoding="UTF-8")
    except subprocess.CalledProcessError as cpe:
        log.error(f"Failed to find physical network device in {linux_sysfs_nics_abs!r}.\n"
                  f"Command was:\n"
                  f"{cpe.cmd}\n"
                  f"Verbatim command output was:\n"
                  f"{cpe.output.rstrip()}\n"
                  f"Exiting 4 ...")
        sys.exit(4)
    except OSError as ose:
        # The command could not be started at all, e.g. 'find' is not installed or not on PATH
        log.error(f"Failed to run {find_phy_nics[0]!r} to find physical network devices in {linux_sysfs_nics_abs!r}.\n"
                  f"Command was:\n"
                  f"{find_phy_nics}\n"
                  f"Verbatim exception was:\n"
                  f"{ose}\n"
                  f"Exiting 4 ...")
        sys.exit(4)

    # Ignore blank lines so whitespace-only output can't produce an empty interface name
    found_paths = [line for line in phy_nics_find.stdout.splitlines() if line.strip()]
    if not found_paths:
        log.error(f"No physical network device found at {linux_sysfs_nics_abs!r}.\n"
                  f"Command was:\n"
                  f"{phy_nics_find.args}\n"
                  f"Exiting 3 ...")
        sys.exit(3)
    for line in found_paths:
        phy_nic = os.path.basename(line)
        log.debug(f"Found physical network device {phy_nic!r}")
        phy_nics.append(phy_nic)

    log.debug(f"List of identified physical network interfaces: {phy_nics}")
    return phy_nics


def add_fw_rule_to_xml(
        config_obj: configparser.ConfigParser,
        section_name: str,
        target: str,
        ports: list,
        proto: str) -> bool:
    """Append the "FILTERS" chain rules for config section *section_name* to the global XML tree.

    Source addresses come from the global ``arg_allow_sources`` (``{"ipv4": [...], "ipv6": [...]}``), and one
    rule is added per address and address family. When the section configures no 'addr', one unrestricted rule is
    added for IPv4 and, if the section's 'do_ipv6' is true, one for IPv6. Missing ``<chain>`` elements are created
    first. Priorities continue after the rules already present in "FILTERS" for that address family.

    If the section does configure 'addr' but no entry yielded a usable address (failed lookups, invalid entries,
    IPv6-only sources with 'do_ipv6' false, ...), no rule is added. Falling back to an unrestricted rule would
    open the ports to every source - the opposite of what the section asked for.

    Returns True if rules were added, False if the section was skipped for the reason above.
    """
    addr = arg_allow_sources
    restrict_sources = bool(addr["ipv4"] or addr["ipv6"])

    # Same condition the caller uses to decide whether the section restricts its sources at all
    if config_obj.getlist(section_name, "addr") and not restrict_sources:
        log.warning(f"Section '[{section_name}]' restricts sources via 'addr' but none of them yielded a usable "
                    f"address. Not adding any rule for this section ...")
        return False

    state = config_obj.get(section_name, "state")
    rules_already_added = {"ipv4": rules_count(arg_ipv="ipv4") + 1, "ipv6": rules_count(arg_ipv="ipv6") + 1}
    log.debug(f"Current rules count: {rules_already_added}")

    for address_family in ["ipv4", "ipv6"]:
        if restrict_sources:
            sources = addr[address_family]
        elif address_family == "ipv4" or config_obj.getboolean(section_name, "do_ipv6"):
            # A single rule without '--source', i.e. for all sources
            sources = [None]
        else:
            sources = []
        if not sources:
            continue

        if not has_child_elem("chain", address_family):
            add_chain_elem("chain", address_family)
        for address in sources:
            add_rule_elem(
                address_family,
                rules_already_added[address_family],
                target,
                arg_section=section_name,
                arg_proto=proto,
                arg_state=state,
                arg_ports=ports,
                arg_address=address)
            rules_already_added[address_family] += 1

    return True


def resolve_domain(domain: str) -> list[str]:
    """Return the addresses in *domain*'s DNS A records followed by those in its AAAA records.

    A record type without an answer (``dns.resolver.NoAnswer``) is logged and skipped. Other resolver errors
    (e.g. ``NXDOMAIN``) are not handled here and propagate to the caller.
    """
    log.debug(f"Resolving DNS A and AAAA records for '{domain}' ...")
    dns_records = []
    for rdtype, record_label in ((dns.rdatatype.A, "an A"), (dns.rdatatype.AAAA, "a AAAA")):
        try:
            answer = dns.resolver.resolve(domain, rdtype=rdtype)
        except dns.resolver.NoAnswer:
            log.debug(f"DNS didn't return {record_label} record for '{domain}', ignoring ...")
            continue
        dns_records.extend(dns_record.address for dns_record in answer)

    log.info(f"""For {domain!r} found {p.plural("record", len(dns_records))}: {dns_records}""")
    return dns_records


def _parse_ip_source(text: str, address_cls, network_cls):
    """Return *text* normalized as an *address_cls* address or, failing that, a *network_cls* network; else None."""
    try:
        return str(address_cls(text))
    except ipaddress.AddressValueError:
        pass
    try:
        # strict (default): networks with host bits set such as '10.0.0.1/8' are rejected, not silently widened
        return str(network_cls(text))
    except ValueError:
        return None


def resolve_addresses(
        config_obj: configparser.ConfigParser,
        section_name: str,
        allow_list_mixed: list[str]) -> dict[str, list]:
    """Turn the 'addr' entries of *section_name* into IPv4/IPv6 sources in the global ``arg_allow_sources``.

    Entries that validate as domains are resolved with ``resolve_domain``; all others are taken literally. Every
    resulting entry is accepted as an IPv4 address or network (e.g. ``192.0.2.0/24``). If the section's 'do_ipv6'
    is true, it may instead be an IPv6 address or network. Entries that are none of these are skipped with a
    warning.

    Returns the (mutated) global ``arg_allow_sources`` with keys "ipv4" and "ipv6".
    """
    allow_list_ip_only = []

    log.info(f"""Verifying {p.plural("address", len(allow_list_mixed))} {allow_list_mixed!r} ...""")
    for allow_source in allow_list_mixed:
        log.debug(f"Checking if '{allow_source}' is a domain ...")
        if validators.domain(allow_source):
            log.debug(f"'{allow_source}' is a domain.")
            allow_list_ip_only.extend(resolve_domain(allow_source))
        else:
            log.debug(f"'{allow_source}' is not a domain.")
            allow_list_ip_only.append(allow_source)

    for allow_source in allow_list_ip_only:
        ipv4_source = _parse_ip_source(allow_source, ipaddress.IPv4Address, ipaddress.IPv4Network)
        if ipv4_source is not None:
            log.info(f"Adding IPv4 address '{allow_source}' ...")
            arg_allow_sources["ipv4"].append(ipv4_source)
            continue

        log.debug(f"Address '{allow_source}' is not a valid IPv4 address.")
        if not config_obj.getboolean(section_name, "do_ipv6"):
            log.info(f"For section '[{section_name}]' option 'do_ipv6' equals false. "
                     f"Skipping IPv6 handling of '{allow_source}' ...")
            continue

        ipv6_source = _parse_ip_source(allow_source, ipaddress.IPv6Address, ipaddress.IPv6Network)
        if ipv6_source is None:
            # Warning, not debug: a configured source is being dropped
            log.warning(f"Address '{allow_source}' is not a valid IPv6 address either. Ignoring ...")
            continue
        log.info(f"Adding IPv6 address '{allow_source}' ...")
        arg_allow_sources["ipv6"].append(ipv6_source)

    return arg_allow_sources


def gen_fwd_direct_scaffolding() -> lxml.etree._Element:
    """Create the empty root element ``<direct/>`` of a firewalld direct rules document.

    Returns an lxml element, not the ``ElementMaker`` factory the previous annotation claimed.
    """
    return lxml.etree.Element("direct")


def ose_handler(
        os_error: OSError,
        human_text: str = None,
        exit_code: int = None) -> None:
    """Log *os_error* as an error (no exit).

    The first line is *human_text* or, when that is empty, the ``exit_code_desc`` text for *exit_code*. It is
    followed by the verbatim exception and, if *exit_code* is truthy, an "Exiting N ..." line. Exiting is left to
    the caller.
    """
    headline = human_text if human_text else exit_code_desc.get(exit_code)
    message_lines = [f"{headline}", "Verbatim exception was:", f"{os_error}"]
    if exit_code:
        message_lines.append(f"Exiting {exit_code} ...")
    log.error("\n".join(message_lines))


def get_xml_str_repr() -> str:
    """Serialize the global tree ``arg_fw_rule_data`` to pretty-printed XML text with a UTF-8 XML declaration."""
    serialized = lxml.etree.tostring(
        arg_fw_rule_data,
        encoding="UTF-8",
        xml_declaration=True,
        pretty_print=True,
    )
    return serialized.decode()


def write_new_fwd_direct_xml(
        config_obj: configparser.ConfigParser) -> bool:
    """Overwrite the firewalld direct rules file with the XML generated from ``arg_fw_rule_data``.

    The path is option 'firewalld_direct_abs' of the default section. The file is opened in "r+" mode, so it must
    already exist; it is rewritten from the start and truncated to the new length. The text is encoded as UTF-8,
    matching the encoding named in the generated XML declaration instead of depending on the process locale.

    Exits with code 9 if the file can not be opened or written. Returns True otherwise.
    """
    fwd_direct_xml_str = get_xml_str_repr()
    fwd_file_abs = config_obj.get(configparser.DEFAULTSECT, "firewalld_direct_abs")

    try:
        with open(fwd_file_abs, "r+", encoding="utf-8") as fwd_file_handle:
            log.info(f"Writing new firewalld direct config ...")
            log.debug(f"New content:\n"
                      f"{fwd_direct_xml_str.rstrip()}")
            fwd_file_handle.seek(0)
            fwd_file_handle.write(fwd_direct_xml_str)
            fwd_file_handle.truncate()
    except OSError as ose:
        ose_handler(os_error=ose, exit_code=9)
        sys.exit(9)
    return True


def restart_systemd_firewalld() -> bool:
    """Try-restart systemd's firewalld.service unit over the D-Bus system bus.

    Returns False, leaving the unit alone, when its ActiveState is "inactive". Returns True once the
    ``TryRestartUnit`` call (mode "fail") returned. Every D-Bus error - connecting to the system bus, loading the
    unit, reading its state or restarting it - is logged and exits with code 10. Otherwise it would escape as an
    uncaught exception whose exit status 1 collides with the script's "invalid config" exit code.
    """
    try:
        sysbus = dbus.SystemBus()
        systemd1 = sysbus.get_object("org.freedesktop.systemd1", "/org/freedesktop/systemd1")
        manager = dbus.Interface(systemd1, "org.freedesktop.systemd1.Manager")

        firewalld_unit = manager.LoadUnit("firewalld.service")
        firewalld_proxy = sysbus.get_object("org.freedesktop.systemd1", str(firewalld_unit))
        firewalld_active_state = firewalld_proxy.Get("org.freedesktop.systemd1.Unit",
                                                     "ActiveState",
                                                     dbus_interface="org.freedesktop.DBus.Properties")

        if firewalld_active_state == "inactive":
            log.info(f"systemd firewalld.service unit is inactive, ignoring restart instruction, leaving as-is ...")
            return False

        log.info(f"Restarting systemd firewalld.service unit ...")
        manager.TryRestartUnit('firewalld.service', 'fail')
    except dbus.exceptions.DBusException as dbe:
        log.error(f"Failed to restart systemd firewalld.service unit.\n"
                  f"Verbatim exception was:\n"
                  f"{dbe}\n"
                  f"You're going to want to check firewalld.service health.\n"
                  f"Exiting 10 ...")
        sys.exit(10)

    log.info(f"Done")
    return True


def add_firewall_shim(arg_phy_nics: list) -> None:
    """Hook the "FILTERS" chain into the INPUT and DOCKER-USER chains.

    For every address family that has a ``<chain>`` element in the XML tree, this adds once:
    * an INPUT rule accepting traffic that arrives on the loopback interface "lo",
    * an INPUT rule jumping to "FILTERS",
    and, per interface in *arg_phy_nics*, a DOCKER-USER rule jumping to "FILTERS" for traffic arriving on it.
    The first two rules used to be repeated for every interface, which duplicated them on multi-NIC hosts.
    Nothing is added when *arg_phy_nics* is empty. Each rule's priority is the current rule count of its chain.
    """
    log.debug(f"Adding ip(6)tables jump targets to INPUT and DOCKER-USER chains ...")
    if not arg_phy_nics:
        return

    for addr_family in ["ipv4", "ipv6"]:
        if not has_child_elem("chain", addr_family):
            continue
        add_rule_elem(
            addr_family,
            rules_count(addr_family, arg_chain="INPUT"),
            "ACCEPT",
            arg_chain="INPUT",
            arg_in_interface="lo"
        )
        add_rule_elem(
            addr_family,
            rules_count(addr_family, arg_chain="INPUT"),
            "FILTERS",
            arg_chain="INPUT"
        )
        for phy_nic in arg_phy_nics:
            add_rule_elem(
                addr_family,
                rules_count(addr_family, arg_chain="DOCKER-USER"),
                "FILTERS",
                arg_chain="DOCKER-USER",
                arg_in_interface=phy_nic
            )


def has_xml_changed(
        config_obj: configparser.ConfigParser) -> bool:
    """Compare the generated XML with the current firewalld direct rules file and log a diff if they differ.

    The path is option 'firewalld_direct_abs' of the default section. This is the option with a built-in default
    and the one ``write_new_fwd_direct_xml`` writes to; the former 'firewalld_direct_file_abs' does not exist.
    Returns True if the contents differ, False if they are identical. Exits with code 5 if the file can not be
    opened or read.
    """
    arg_fwd_file_abs = os.path.abspath(config_obj.get(configparser.DEFAULTSECT, "firewalld_direct_abs"))

    try:
        # errors="replace": the content is only compared and diffed, so undecodable bytes must not crash us
        with open(arg_fwd_file_abs, "r", encoding="utf-8", errors="replace") as fwd_file_abs_handle:
            fwd_file_abs_content = fwd_file_abs_handle.read()
    except OSError as ose:
        ose_handler(os_error=ose, exit_code=5)
        sys.exit(5)

    fwd_direct_xml_str = get_xml_str_repr()
    # Plain string equality; a character-level SequenceMatcher ratio is a needlessly expensive equality test
    if fwd_file_abs_content == fwd_direct_xml_str:
        log.info(f"No diff in firewalld XML config, no need to write new file.")
        return False

    diff_result = difflib.Differ().compare(fwd_file_abs_content.splitlines(), fwd_direct_xml_str.splitlines())
    nl = "\n"
    log.info(f"Changing firewalld rules. Diff as follows:\n"
             f"""{nl.join(diff_result)}""")
    return True


if __name__ == "__main__":
    # Validate the config: log '[DEFAULT]', require at least one usable section, then vet every section
    validate_default_section(config)
    if config_has_valid_section(config):
        validate_config_sections(config)
    else:
        log.error(f"No valid config section found. A valid config section has at least the mandatory options "
                  f"{CONST.CFG_MANDATORY} set. Exiting 2 ...")
        sys.exit(2)

    # Root '<direct/>' element all chain and rule elements get appended to. It is a module-level global because
    # the helper functions access it as 'arg_fw_rule_data'.
    arg_fw_rule_data = gen_fwd_direct_scaffolding()

    log.debug(f"Iterating over config sections ...")
    for section in config.sections():
        log.info(f"Generating rules from section '[{section}]' ...")
        arg_fwd_addr = config.getlist(section, "addr")
        # Reset for every section. resolve_addresses() appends to this module-level dict and add_fw_rule_to_xml()
        # reads it.
        arg_allow_sources = {"ipv4": [], "ipv6": []}
        if arg_fwd_addr:
            arg_allow_sources = resolve_addresses(config, section, arg_fwd_addr)
            log.debug(arg_allow_sources)
        else:
            log.info(f"No source address given. Rules will apply to all sources.")

        add_fw_rule_to_xml(config,
                           section,
                           target=config.get(section, "target"),
                           ports=config.getlist(section, "ports"),
                           proto=config.get(section, "proto"))
    # For each address family that has rules in the FILTERS chain, add an ESTABLISHED,RELATED accept rule to
    # FILTERS (add_rule_elem's default chain) with priority 0
    for arg_address_family in ["ipv4", "ipv6"]:
        if rules_count(arg_address_family):
            add_rule_elem(
                arg_address_family,
                0,
                "ACCEPT",
                arg_state="ESTABLISHED,RELATED")
    # Hook the FILTERS chain into INPUT / DOCKER-USER for the detected physical network interfaces
    add_firewall_shim(get_phy_nics())

    # Only rewrite the direct rules file (and optionally restart firewalld) when the generated XML differs from it
    if has_xml_changed(config):
        write_new_fwd_direct_xml(config)
        if config.getboolean(configparser.DEFAULTSECT, "restart_firewalld_after_change"):
            restart_systemd_firewalld()