#  This file is part of TALER
#  (C) 2016 INRIA
#
#  TALER is free software; you can redistribute it and/or modify it under the
#  terms of the GNU Affero General Public License as published by the Free Software
#  Foundation; either version 3, or (at your option) any later version.
#
#  TALER is distributed in the hope that it will be useful, but WITHOUT ANY
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
#  A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along with
#  TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
#
#  @author Florian Dold

"""
Parse GNUnet-style configurations in pure Python
"""

import logging
import collections
import os
import weakref
import sys
import re
from typing import Callable, Any

LOGGER = logging.getLogger(__name__)

# The exception types are raised out of the public value accessors
# (value_string/value_int/value_filename), so export them alongside the
# main class for callers that use a star import.
__all__ = ["TalerConfig", "ConfigurationError", "ExpansionSyntaxError"]

# Fallback base directory for TalerConfig.load_defaults(); taken from the
# optional "talerpaths" module, or None when it cannot be imported.
try:
    # not clear if this is a good idea ...
    from talerpaths import TALER_DATADIR as t
except ImportError:
    TALER_DATADIR = None
else:
    TALER_DATADIR = t

class ConfigurationError(Exception):
    """Raised when a configuration option is required but missing, or has an invalid value."""

class ExpansionSyntaxError(Exception):
    """Raised by expand() when a "${" is not closed by a matching "}"."""


def expand(var: str, getter: Callable[[str], str]) -> str:
    """
    Do shell-style parameter expansion.

    Supported syntax:
    - ${X}     replaced by getter("X"); kept literally if that returns None
    - ${X:-Y}  replaced by getter("X"), or by the expansion of Y if None
    - $X       replaced by getter("X"); X is the longest run of alphanumeric
               characters and underscores following the "$", and the
               reference is kept literally if getter returns None

    All text outside these forms is copied unchanged.  A "$" that is not
    followed by "{" or a name character is copied literally together with
    the character after it (e.g. "$$" or "$/").

    :param var: the string to expand
    :param getter: maps a variable name to its value, or None if unset
    :return: the expanded string
    :raises ExpansionSyntaxError: if a "${" has no matching "}"
    """
    pos = 0
    pieces = []
    while True:
        start = var.find("$", pos)
        if start == -1:
            break
        # Keep the literal text between the previous reference and this "$".
        pieces.append(var[pos:start])
        if var.startswith("${", start):
            # Find the matching "}", allowing nested "${...}" in a default.
            balance = 1
            end = start + 2
            while balance > 0 and end < len(var):
                balance += {"{": 1, "}": -1}.get(var[end], 0)
                end += 1
            if balance != 0:
                raise ExpansionSyntaxError("unbalanced parentheses")
            piece = var[start + 2:end - 1]
            if piece.find(":-") > 0:
                varname, alt = piece.split(":-", 1)
                replace = getter(varname)
                if replace is None:
                    replace = expand(alt, getter)
            else:
                replace = getter(piece)
                if replace is None:
                    replace = var[start:end]
        else:
            end = start + 1
            while end < len(var) and (var[end].isalnum() or var[end] == "_"):
                end += 1
            if end == start + 1:
                # No variable name after "$": copy "$" and the next
                # character verbatim instead of looking up a bogus name.
                end = start + 2
                replace = var[start:end]
            else:
                replace = getter(var[start + 1:end])
                if replace is None:
                    replace = var[start:end]
        pieces.append(replace)
        pos = end
    pieces.append(var[pos:])
    return "".join(pieces)


class Entry:
    """
    One option of one section, together with the place it was defined.
    """
    def __init__(self, config, section: str, option: str, **kwargs) -> None:
        # Optional keyword arguments: value, filename, lineno (default None).
        self.value = kwargs.get("value")
        self.filename = kwargs.get("filename")
        self.lineno = kwargs.get("lineno")
        self.section = section
        self.option = option
        # Weak reference to the configuration's section mapping; used by
        # _getsubst() to resolve "$VAR" references against [paths].
        self.config = weakref.ref(config)

    def __repr__(self) -> str:
        return "<Entry section=%s, option=%s, value=%s>" \
               % (self.section, self.option, repr(self.value),)

    def __str__(self) -> str:
        # __str__ must return a str; an unset option renders as "".
        return "" if self.value is None else str(self.value)

    def value_string(self, default=None, required=False, warn=False) -> str:
        """
        Return the option's raw value.

        :param default: returned when the option is not set.
        :param required: if true, a missing option raises ConfigurationError.
        :param warn: if true, log a warning when the option is missing.
        :return: the configured value, or `default` (possibly None).
        :raises ConfigurationError: if `required` and the option is unset.
        """
        if self.value is not None:
            return self.value
        if required:
            raise ConfigurationError("Missing required option '%s' in section '%s'"
                                     % (self.option.upper(), self.section.upper()))
        if warn:
            if default is not None:
                LOGGER.warning("Configuration is missing option '%s' in section '%s', "
                               "falling back to '%s'",
                               self.option.upper(), self.section.upper(), default)
            else:
                LOGGER.warning("Configuration is missing option '%s' in section '%s'",
                               self.option.upper(), self.section.upper())
        return default

    def value_int(self, default=None, required=False, warn=False) -> int:
        """
        Return the option's value converted with int().

        Parameters are as for value_string().  Returns None if the option
        is unset and `default` is None.

        :raises ConfigurationError: if the value is not a valid integer, or
            if `required` and the option is unset.
        """
        value = self.value_string(default, required, warn)
        if value is None:
            return None
        try:
            return int(value)
        except ValueError as err:
            raise ConfigurationError("Expected number for option '%s' in section '%s'"
                                     % (self.option.upper(), self.section.upper())) from err

    def _getsubst(self, key: str) -> Any:
        """
        Resolve variable `key` for expand(): first from the [paths] section,
        then from the process environment.  Returns None if neither has it.
        """
        sections = self.config()
        # Use plain dict.get() (with the lowercased key the mappings store)
        # so that a failed lookup does not auto-create empty entries.
        paths = sections.get("paths") if sections is not None else None
        if paths is not None:
            entry = paths.get(key.lower())
            if entry is not None and entry.value is not None:
                return entry.value
        return os.environ.get(key)

    def value_filename(self, default=None, required=False, warn=False) -> str:
        """
        Like value_string(), but "$VAR" / "${VAR}" references in the result
        are expanded using [paths] and then the environment; unresolved
        references are left in place.
        """
        value = self.value_string(default, required, warn)
        if value is None:
            return None
        return expand(value, self._getsubst)

    def location(self) -> str:
        """Return "filename:lineno" of the definition, or "<unknown>"."""
        if self.filename is None or self.lineno is None:
            return "<unknown>"
        return "%s:%s" % (self.filename, self.lineno)


class OptionDict(collections.defaultdict):
    """
    The options of a single section, keyed case-insensitively.

    Keys are folded to lower case before they reach the underlying
    dictionary.  Reading an option that was never set does not raise
    KeyError: a fresh, value-less Entry is created, stored and returned.
    """

    def __init__(self, config, section_name: str) -> None:
        self.section_name = section_name
        # Only a weak reference to the owning section mapping
        # (presumably to avoid a reference cycle).
        self.config = weakref.ref(config)
        super().__init__()

    def __missing__(self, key: str) -> Entry:
        owner = self.config()
        blank = Entry(owner, self.section_name, key)
        self[key] = blank
        return blank

    def __getitem__(self, chunk: str) -> Entry:
        folded = chunk.lower()
        return super().__getitem__(folded)

    def __setitem__(self, chunk: str, value: Entry) -> None:
        folded = chunk.lower()
        super().__setitem__(folded, value)


class SectionDict(collections.defaultdict):
    """
    All sections of a configuration, keyed case-insensitively.

    Section names are folded to lower case on lookup and assignment.
    Looking up a section that does not exist yet creates an empty
    OptionDict for it (which keeps a weak reference back to this
    mapping), stores it and returns it.
    """

    def __getitem__(self, chunk: str) -> OptionDict:
        lowered = chunk.lower()
        return super().__getitem__(lowered)

    def __setitem__(self, chunk: str, value: OptionDict) -> None:
        lowered = chunk.lower()
        super().__setitem__(lowered, value)

    def __missing__(self, key):
        section = OptionDict(self, key)
        self[key] = section
        return section


class TalerConfig:
    """
    One loaded taler configuration, including base configuration
    files and included files.

    Sections and options are looked up case-insensitively via
    ``cfg["section"]["option"]``; a section or option that was never
    defined yields an empty entry instead of raising KeyError.
    """
    def __init__(self) -> None:
        """
        Initialize an empty configuration
        """
        # Case-insensitive mapping: section name -> OptionDict of entries.
        self.sections = SectionDict()

    # defaults != config file: the first is the 'base'
    # whereas the second overrides things from the first.
    @staticmethod
    def from_file(filename=None, load_defaults=True):
        """
        Create a configuration from `filename`, optionally on top of the
        base (default) configuration.

        :param filename: file to load; when None, $XDG_CONFIG_HOME/taler.conf
            is used, or ~/.config/taler.conf if XDG_CONFIG_HOME is unset/empty.
        :param load_defaults: if true, load_defaults() runs first, so values
            in `filename` override the base values.
        :return: the new TalerConfig.  load_file() terminates the process
            (exit status 3) if `filename` does not exist.
        """
        cfg = TalerConfig()
        if filename is None:
            xdg = os.environ.get("XDG_CONFIG_HOME")
            if xdg:
                filename = os.path.join(xdg, "taler.conf")
            else:
                filename = os.path.expanduser("~/.config/taler.conf")
        if load_defaults:
            cfg.load_defaults()
        cfg.load_file(filename)
        return cfg

    def value_string(self, section, option, **kwargs) -> str:
        """
        Return the raw value of `option` in `section`.

        Keyword arguments `default`, `required` and `warn` are forwarded to
        Entry.value_string().
        """
        return self.sections[section][option].value_string(
            kwargs.get("default"), kwargs.get("required", False), kwargs.get("warn", False))

    def value_filename(self, section, option, **kwargs) -> str:
        """
        Return the value of `option` in `section` with "$VAR" references
        expanded; keyword arguments as for value_string().
        """
        return self.sections[section][option].value_filename(
            kwargs.get("default"), kwargs.get("required", False), kwargs.get("warn", False))

    def value_int(self, section, option, **kwargs) -> int:
        """
        Return the value of `option` in `section` as an int; keyword
        arguments as for value_string().
        """
        return self.sections[section][option].value_int(
            kwargs.get("default"), kwargs.get("required", False), kwargs.get("warn", False))

    def load_defaults(self) -> None:
        """
        Load the base configuration directory.

        The first applicable source wins:
        1. the directory named by $TALER_BASE_CONFIG;
        2. $TALER_PREFIX/share/taler/config.d, where a TALER_PREFIX whose
           last path component starts with "lib" is replaced by its parent;
        3. TALER_DATADIR/share/taler/config.d, if TALER_DATADIR is set.
        If none applies, only a warning is logged.
        """
        base_dir = os.environ.get("TALER_BASE_CONFIG")
        if base_dir:
            self.load_dir(base_dir)
            return
        prefix = os.environ.get("TALER_PREFIX")
        if prefix:
            parent, last = os.path.split(os.path.normpath(prefix))
            # Presumably maps e.g. ".../lib" or ".../lib64" back to the
            # installation prefix above it.
            if last.startswith("lib"):
                prefix = parent
            self.load_dir(os.path.join(prefix, "share/taler/config.d"))
            return
        if TALER_DATADIR:
            self.load_dir(os.path.join(TALER_DATADIR, "share/taler/config.d"))
            return
        LOGGER.warning("no base directory found")

    @staticmethod
    def from_env(*args, **kwargs):
        """
        Load configuration from environment variable TALER_CONFIG_FILE
        or from default location if the variable is not set.
        """
        filename = os.environ.get("TALER_CONFIG_FILE")
        return TalerConfig.from_file(filename, *args, **kwargs)

    def load_dir(self, dirname) -> None:
        """
        Load every "*.conf" file in `dirname`.

        Files are loaded in sorted name order, so when several files set
        the same option the result is deterministic (the last one wins).
        A missing or unreadable directory only logs a warning.
        """
        try:
            files = sorted(os.listdir(dirname))
        except OSError:
            LOGGER.warning("can't read config directory '%s'", dirname)
            return
        for file in files:
            if not file.endswith(".conf"):
                continue
            self.load_file(os.path.join(dirname, file))

    def load_file(self, filename) -> None:
        """
        Parse `filename` and merge its options into this configuration.

        Options defined here replace earlier definitions of the same
        section/option.  Blank lines and lines starting with "#" are
        ignored.  Options outside any section and lines without "=" are
        logged and skipped.  A section header without a closing "]", or a
        value with an unmatched opening quote, is logged but still used.

        If the file does not exist, an error is logged and the process
        exits with status 3.
        """
        sections = self.sections
        try:
            # Decode explicitly as UTF-8 so parsing does not depend on locale.
            with open(filename, "r", encoding="utf-8") as file:
                current_section = None
                for lineno, line in enumerate(file, start=1):
                    line = line.strip()
                    if not line or line.startswith("#"):
                        # empty line or comment
                        continue
                    if line.startswith("["):
                        if not line.endswith("]"):
                            LOGGER.error("invalid section header in line %s: %s",
                                         lineno, repr(line))
                        current_section = line.strip("[]").strip().strip('"')
                        continue
                    if current_section is None:
                        LOGGER.error("option outside of section in line %s: %s",
                                     lineno, repr(line))
                        continue
                    key, sep, value = line.partition("=")
                    if not sep:
                        # No "=": there is no value to store, so skip the
                        # line rather than fail on the missing half.
                        LOGGER.error("invalid option in line %s: %s", lineno, repr(line))
                        continue
                    key = key.strip()
                    value = value.strip()
                    if value.startswith('"'):
                        value = value[1:]
                        if not value.endswith('"'):
                            LOGGER.error("mismatched quotes in line %s: %s",
                                         lineno, repr(line))
                        else:
                            value = value[:-1]
                    sections[current_section][key] = Entry(
                        sections, current_section, key,
                        value=value, filename=filename, lineno=lineno)
        except FileNotFoundError:
            LOGGER.error("Configuration file (%s) not found", filename)
            sys.exit(3)

    def dump(self) -> None:
        """
        Print every section and option to stdout, each option followed by
        its definition location.
        """
        for options in self.sections.values():
            print("[%s]" % (options.section_name,))
            for entry in options.values():
                print("%s = %s # %s" % (entry.option, entry.value, entry.location()))

    def __getitem__(self, chunk: str) -> OptionDict:
        """
        Return the options of section `chunk`.

        :raises TypeError: if `chunk` is not a string.
        """
        if isinstance(chunk, str):
            return self.sections[chunk]
        raise TypeError("index must be string")


if __name__ == "__main__":
    import argparse

    # Command-line front end: print a single option's value, or dump the
    # whole configuration unless both --section and --option are given.
    PARSER = argparse.ArgumentParser()
    PARSER.add_argument("--section", "-s", dest="section", default=None,
                        metavar="SECTION")
    PARSER.add_argument("--option", "-o", dest="option", default=None,
                        metavar="OPTION")
    PARSER.add_argument("--config", "-c", dest="config", default=None,
                        metavar="FILE")
    PARSER.add_argument("--filename", "-f", dest="expand_filename",
                        action='store_true', default=False)
    ARGS = PARSER.parse_args()

    TC = TalerConfig.from_file(ARGS.config)

    if ARGS.section is None or ARGS.option is None:
        TC.dump()
    else:
        # -f/--filename additionally expands $VAR references in the value.
        LOOKUP = TC.value_filename if ARGS.expand_filename else TC.value_string
        X = LOOKUP(ARGS.section, ARGS.option)
        if X is not None:
            print(X)
