import configparser
import datetime as d
import json
import logging
import os
import re
import sys
import time
import humanize
import requests
import inflect
from rich.logging import RichHandler
from rich.traceback import install
import typing as t
from rich.console import Console

import type_def.mvw_json_response
from type_def.mvw_json_request import MVWJSONRequest
from type_def.mvw_json_response import MVWJSONResponse

# Downloading
import os.path
import sys
from concurrent.futures import ThreadPoolExecutor
import signal
from functools import partial
from threading import Event
from typing import Iterable
from urllib.request import urlopen

# TODO set locale for datetime and others to globally stick to en_US
download_start_time = 0
download_last_update_time = 0
size_downloaded = 0

from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    TaskID,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)

progress = Progress(
    TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
    BarColumn(bar_width=None),
    "[progress.percentage]{task.percentage:>3.1f}%",
    "•",
    DownloadColumn(),
    "•",
    TransferSpeedColumn(),
    "•",
    TimeRemainingColumn(),
)
# Downloading

# Without width
console = Console(width=180)
p = inflect.engine()

# We use Python 3.5+ type hints; we're working with JSON objects; we're following a 2016 suggestion from
# Python's "typing" GitHub issue tracker on how to create a "JSONType" hint since such a thing does not yet
# officially exist: https://github.com/python/typing/issues/182#issuecomment-186684288
# JSONType = t.Union[str, int, float, bool, None, t.Dict[str, t.Any], t.List[t.Any]]
JSONType = t.Union[str, int, float, bool, None, t.Dict[str, t.Any], t.List[t.Any]]


# Exit codes
# 1: Config file invalid, it has no sections
# 2: Config file invalid, sections must define at least CONST.CFG_MANDATORY
# 3: No search results to download


class CONST(object):
    __slots__ = ()
    LOG_FORMAT = "%(message)s"
    CFG_THIS_FILE_DIRNAME = os.path.dirname(__file__)
    CFG_DEFAULT_FILENAME = "config.ini"
    CFG_DEFAULT_ABS_PATH = os.path.join(CFG_THIS_FILE_DIRNAME, CFG_DEFAULT_FILENAME)
    CFG_KNOWN_DEFAULTS = [
        {"key": "self_name", "value": "mvw-dl"},
        {"key": "tmp_base_dir", "value": os.path.join(CFG_THIS_FILE_DIRNAME, "data/tmp/%(self_name)s")},
        {"key": "state_base_dir", "value": os.path.join(CFG_THIS_FILE_DIRNAME, "data/var/lib/%(self_name)s")},
        {"key": "state_files_dir", "value": "%(state_base_dir)s/state"},
        {"key": "state_file_retention", "value": "50"},
        {"key": "state_file_name_prefix", "value": "state-"},
        {"key": "state_file_name_suffix", "value": ".log"},
        {"key": "mvw_endpoint", "value": "http://localhost:8000/api/query"},
        {"key": "title_dedup_winner", "value": "first"},
        {"key": "dl_progress_update_interval", "value": "10"},
        {"key": "dl_threads", "value": "2"},
        {"key": "dl_filename_pattern", "value": "&(channel)s - &(publish_date)s - &(topic)s - &(title)s"},
        {"key": "dl_filename_spaces_to_underscores", "value": "yes"},
        {"key": "dl_filename_all_lowercase", "value": "yes"}
    ]
    CFG_KNOWN_SECTION = [
        {"key": "min_duration", "is_mandatory": False},
        {"key": "max_duration", "is_mandatory": False},
        {"key": "title_not_regex", "is_mandatory": False},
        {"key": "query", "is_mandatory": True},
        {"key": "dl_dir", "is_mandatory": True}
    ]
    CFG_MANDATORY = [section_cfg["key"] for section_cfg in CFG_KNOWN_SECTION if section_cfg["is_mandatory"]]


CONST = CONST()
logging.basicConfig(
    # Default for all modules in NOTSET so log everything
    level="NOTSET",
    format=CONST.LOG_FORMAT,
    datefmt="[%X]",
    handlers=[RichHandler(
        show_time=False if "SYSTEMD_EXEC_PID" in os.environ else True,
        rich_tracebacks=True
    )]
)
log = logging.getLogger("rich")
# Our own code logs with this level
log.setLevel(logging.DEBUG)
# connectionpool logs with WARNING, we don't need its verbosity
log_connectionpool = logging.getLogger("urllib3.connectionpool")
log_connectionpool.setLevel(logging.WARNING)
install(show_locals=True)


class ConfigParser(
        configparser.ConfigParser):
    """Can get options() without defaults

    Taken from https://stackoverflow.com/a/12600066.
    """

    def options(self, section, no_defaults=False, **kwargs):
        if no_defaults:
            try:
                return list(self._sections[section].keys())
            except KeyError:
                raise configparser.NoSectionError(section)
        else:
            return super().options(section)


ini_defaults = []
internal_defaults = {default["key"]: default["value"] for default in CONST.CFG_KNOWN_DEFAULTS}
config = ConfigParser(defaults=internal_defaults)
config.read(CONST.CFG_DEFAULT_FILENAME)


def print_section_header(
        header: str) -> str:
    return f"Loading config section '[{header}]' ..."


def validate_default_section(
        config_obj: configparser.ConfigParser()) -> None:
    log.debug(f"Loading config from file '{CONST.CFG_DEFAULT_ABS_PATH}' ...")
    if not config_obj.sections():
        log.error(f"No config sections found in '{CONST.CFG_DEFAULT_ABS_PATH}'. Exiting 1 ...")
        sys.exit(1)
    if config.defaults():
        log.debug(f"Symbol legend:\n"
                  f"* Global default from section '[{config_obj.default_section}]'\n"
                  f"~ Local option, doesn't exist in '[{config_obj.default_section}]'\n"
                  f"+ Local override of a value from '[{config_obj.default_section}]'\n"
                  f"= Local override, same value as in '[{config_obj.default_section}]'")
        log.debug(print_section_header(config_obj.default_section))
        for default in config_obj.defaults():
            ini_defaults.append({default: config_obj[config_obj.default_section][default]})
            log.debug(f"* {default} = {config_obj[config_obj.default_section][default]}")
    else:
        log.debug(f"No defaults defined")


def config_has_valid_section(
        config_obj: configparser.ConfigParser()) -> bool:
    has_valid_section = False
    for config_obj_section in config_obj.sections():
        if set(CONST.CFG_MANDATORY).issubset(config_obj.options(config_obj_section)):
            has_valid_section = True
            break
    return has_valid_section


def is_default(
        config_key: str) -> bool:
    return any(config_key in ini_default for ini_default in ini_defaults)


def is_same_as_default(
        config_kv_pair: dict) -> bool:
    return config_kv_pair in ini_defaults


def validate_config_sections(
        config_obj: configparser.ConfigParser()) -> None:
    for this_section in config_obj.sections():
        log.debug(print_section_header(this_section))
        if not set(CONST.CFG_MANDATORY).issubset(config_obj.options(this_section, no_defaults=True)):
            log.warning(f"Config section '[{this_section}]' does not have all mandatory options "
                        f"{CONST.CFG_MANDATORY} set, skipping section ...")
            config_obj.remove_section(this_section)
        else:
            for key in config_obj.options(this_section, no_defaults=True):
                kv_prefix = "~"
                if is_default(key):
                    kv_prefix = "+"
                    if is_same_as_default({key: config_obj[this_section][key]}):
                        kv_prefix = "="
                log.debug(f"{kv_prefix} {key} = {config_obj[this_section][key]}")


def query_string_from_file(
        filename: str) -> str:
    with open(filename, "r") as jsonfile:
        query_string = jsonfile.read()
        return query_string


def get_query_payload(
        section_name: str,
        config_obj: configparser.ConfigParser()) -> MVWJSONRequest:
    log.debug(f"Generating HTTP POST JSON payload ...")
    query = config_obj.get(section_name, "query")
    if query[0] == "@":
        query = query.split("@", 1)[1]
        query = query_string_from_file(query)
    got_query_payload = MVWJSONRequest(**json.loads(query))
    return got_query_payload


def get_json_response(
        section_name: str,
        config_obj: configparser.ConfigParser(),
        payload: MVWJSONRequest) -> MVWJSONResponse:
    log.debug(f"Downloading JSON list of Mediathek files that match search criteria")
    serialized_payload = payload.json()
    url = config_obj.get(section_name, "mvw_endpoint")
    req_header = {"Content-Type": "text/plain"}
    s = requests.Session()
    req = requests.Request("POST", url, data=serialized_payload, headers=req_header)
    prepped = req.prepare()
    newline = "\n"
    log.debug(f"Request method: {req.method}\n"
              f"URL: {req.url}\n"
              f"""{newline.join(f"Header '{header}': '{value}'" for header, value in list(req.headers.items()))}\n"""
              f"Payload: {payload}")
    with s.send(prepped) as s:
        got_json_response = MVWJSONResponse(**json.loads(s.content))
        return got_json_response


def no_downloads_needed() -> None:
    log.info(f"No search results to download, exiting 3 ...")
    sys.exit(3)


def remove_result(
        json_obj: MVWJSONResponse,
        result_obj: type_def.mvw_json_response.Show) -> MVWJSONResponse:
    json_obj.result.results.remove(result_obj)
    json_obj.result.queryInfo.resultCount -= 1
    if json_obj.result.queryInfo.resultCount:
        return json_obj
    else:
        no_downloads_needed()


def log_result_count(result_count: int, pre_filter: bool = True) -> None:
    if pre_filter:
        log.debug(f"""Search result contains {result_count} {p.plural("show", result_count)} going in""")
    else:
        log.debug(f"""Search result now contains {result_count} {p.plural("show", result_count)}""")


def filter_json_by_duration(
        section_name: str,
        config_obj: configparser.ConfigParser(),
        json_obj: MVWJSONResponse) -> MVWJSONResponse:
    min_duration = config_obj.getint(section_name, "min_duration")
    max_duration = config_obj.getint(section_name, "max_duration")
    log_result_count(json_obj.result.queryInfo.resultCount)
    if min_duration >= 0:
        log.debug(f"Filtering '[{section_name}]' JSON for minimum length of {min_duration} "
                  f"""{p.plural("second", min_duration)} ...""")
        for result in json_obj.result.results.copy():
            if not result.duration >= min_duration:
                remove_result(json_obj, result)
    if max_duration >= 0:
        log.debug(f"Filtering '[{section_name}]' JSON for maximum length of {max_duration} "
                  f"""{p.plural("second", max_duration)} ...""")
        for result in json_obj.result.results.copy():
            if not result.duration <= max_duration:
                remove_result(json_obj, result)
    log_result_count(json_obj.result.queryInfo.resultCount, False)
    return json_obj


def filter_json_by_title_regex(
        section_name: str,
        config_obj: configparser.ConfigParser(),
        json_obj: MVWJSONResponse) -> MVWJSONResponse:
    title_not_regex = re.compile(config_obj.get(section_name, "title_not_regex"), re.IGNORECASE)
    log_result_count(json_obj.result.queryInfo.resultCount)
    log.debug(f"Filtering '[{section_name}]' JSON by title regular expression")
    for result in json_obj.result.results.copy():
        if title_not_regex.search(result.title):
            remove_result(json_obj, result)
    log_result_count(json_obj.result.queryInfo.resultCount, False)
    return json_obj


def dedup_json_titles(
        section_name: str,
        config_obj: configparser.ConfigParser(),
        json_obj: MVWJSONResponse) -> MVWJSONResponse:
    title_dedup_winner = config_obj.get(section_name, "title_dedup_winner")
    titles_list = {}
    log_result_count(json_obj.result.queryInfo.resultCount)
    for result in json_obj.result.results.copy():
        if result.title not in titles_list:
            titles_list[result.title] = {}
        if result.id not in titles_list[result.title]:
            titles_list[result.title][result.id] = result.timestamp
    for result in titles_list.copy():
        if title_dedup_winner == "first":
            dd_winner = min(titles_list[result], key=str)
        else:
            dd_winner = max(titles_list[result], key=str)
        titles_list[result] = dd_winner
    for result in json_obj.result.results.copy():
        if result.title in titles_list:
            if result.id != titles_list[result.title]:
                log.debug(f"""Deduplicating '[{section_name}]' result "{result.title}" ...""")
                remove_result(json_obj, result)
    log_result_count(json_obj.result.queryInfo.resultCount, False)
    return json_obj


done_event = Event()


def handle_sigint(signum, frame):
    done_event.set()


signal.signal(signal.SIGINT, handle_sigint)


def get_safe_filename(
        dirty_filename: str) -> str:
    """https://stackoverflow.com/a/71199182"""

    clean_filename = re.sub(r"[/\\?%*:|\"<>\x7F\x00-\x1F]", "-", dirty_filename)
    return clean_filename


def log_successful_download(
        show: type_def.mvw_json_response.Show) -> None:
    pass


def copy_url(
        section_name: str,
        config_obj: configparser.ConfigParser(),
        show: type_def.mvw_json_response.Show,
        video_metadata: dict,
        total_content_length: int) -> None:
    """Copy data from a url to a local file."""

    global download_start_time
    global download_last_update_time
    global size_downloaded

    update_interval = config_obj.getint(section_name, "dl_progress_update_interval")
    max_quality_url = video_metadata["url"]
    filename = max_quality_url.split("/")[-1]
    dest_dir = config_obj.get(section_name, "dl_dir")
    dest_path = os.path.join(dest_dir, filename)
    dest_path = os.path.expanduser(dest_path)
    dest_path = os.path.expandvars(dest_path)
    show_name = f"{show.topic} - {show.title}"
    publish_date = d.datetime.utcfromtimestamp(show.timestamp).strftime('%Y%m%d')

    os.makedirs(os.path.dirname(dest_path), exist_ok=True)
    with open(dest_path, "wb") as dest_file:
        log.info(f"""Downloading "{show_name}" ...""")
        log.info(f"Download location resolved to {dest_path}")
        r = requests.get(max_quality_url, stream=True)
        for chunk in r.iter_content(32768):
            size_downloaded += len(chunk)
            dest_file.write(chunk)
            if time.time() - download_last_update_time >= update_interval:
                download_last_update_time = time.time()
                dl_speed_so_far = size_downloaded / (download_last_update_time - download_start_time)
                human_dl_speed_so_far = f"{humanize.naturalsize(dl_speed_so_far, binary=True)}/s"
                percentage_done = size_downloaded / total_content_length * 100
                human_pct = "{:.1f}".format(percentage_done)
                human_size_dl = humanize.naturalsize(size_downloaded, binary=True)
                human_total_dl = humanize.naturalsize(total_content_length, binary=True)
                log.debug(f"Downloaded {human_pct}% ({human_size_dl}/{human_total_dl} at an average "
                          f"{human_dl_speed_so_far})")
            if done_event.is_set():
                log.info(f"""Download of "{show_name}" interrupted""")
                return
    log.info(f"""Download of "{show_name}" done""")
    log_successful_download(show)


def get_max_quality_url(
        show: type_def.mvw_json_response.Show) -> str:
    if show.url_video_hd:
        max_quality_url = show.url_video_hd
    elif show.url_video:
        max_quality_url = show.url_video
    else:
        max_quality_url = show.url_video_low
    return max_quality_url


def get_content_length(
        video_url: str) -> int:
    r = requests.head(video_url)
    if r.status_code == requests.codes.ok:
        return int(r.headers["content-length"])
    else:
        return 0


def download_media(
        section_name: str,
        config_obj: configparser.ConfigParser(),
        json_obj: MVWJSONResponse) -> None:

    global download_start_time
    global download_last_update_time

    dl_threads = config_obj.getint(section_name, "dl_threads")
    video_metadata = {}

    for result in json_obj.result.results.copy():
        max_quality_url = get_max_quality_url(result)
        content_length = get_content_length(max_quality_url)
        video_metadata[result.id] = {"url": max_quality_url, "content_length": content_length}
    total_content_length = 0
    for video in video_metadata:
        total_content_length += video_metadata[video]["content_length"]
    video_metadata["total_content_length"] = total_content_length
    log.info(f"""Download location is {config_obj.get(section_name, "dl_dir")}""")
    log.info(f"Limiting parallel downloads to {dl_threads} ...")
    with ThreadPoolExecutor(max_workers=dl_threads) as pool:
        download_last_update_time = time.time()
        download_start_time = download_last_update_time
        update_interval = config_obj.getint(section_name, "dl_progress_update_interval")
        log.debug(f"""Will provide updates every {update_interval} {p.plural("second", update_interval)}""")
        for result in json_obj.result.results.copy():
            pool.submit(
                copy_url,
                section_name,
                config_obj,
                result,
                video_metadata[result.id],
                video_metadata["total_content_length"])


if __name__ == '__main__':
    validate_default_section(config)
    if config_has_valid_section(config):
        validate_config_sections(config)
    else:
        log.error(f"No valid config section found. A valid config section has at least the mandatory options "
                  f"{CONST.CFG_MANDATORY} set. Exiting 2 ...")
        sys.exit(2)

    log.debug(f"Iterating over config sections ...")
    for section in config.sections():
        log.debug(f"Processing section '[{section}]' ...")
        query_payload = get_query_payload(section, config)
        json_response = get_json_response(section, config, query_payload)

        log.debug(f"Filtering results by duration where applicable ...")
        if config.has_option(section, "min_duration") or config.has_option(section, "max_duration"):
            json_response = filter_json_by_duration(section, config, json_response)

        log.debug(f"Filtering results by title regular expression where applicable ...")
        if config.has_option(section, "title_not_regex"):
            json_response = filter_json_by_title_regex(section, config, json_response)

        log.debug(f"Deduplicating results by title where needed ...")
        if config.has_option(section, "title_not_regex"):
            json_response = dedup_json_titles(section, config, json_response)

        log.debug(f"Downloading {json_response.result.queryInfo.resultCount} "
                  f"""{p.plural("show", json_response.result.queryInfo.resultCount)} ...""")
        download_media(section, config, json_response)

            # console.print_json(json_response.json())