# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import asyncio
import configparser
import collections
from collections.abc import Callable
import enum
import functools
import getpass
import os
import pathlib
import platform
import random
import re
import socket
import ssl as ssl_module
import stat
import struct
import sys
import typing
import urllib.parse
import warnings
import inspect

from . import compat
from . import exceptions
from . import protocol


class SSLMode(enum.IntEnum):
    """libpq-compatible ``sslmode`` values, ordered by increasing strictness.

    The integer ordering is significant: callers compare members with
    ``<``/``>=`` (e.g. ``sslmode >= SSLMode.verify_full``).
    """

    disable = 0
    allow = 1
    prefer = 2
    require = 3
    verify_ca = 4
    verify_full = 5

    @classmethod
    def parse(cls, sslmode):
        """Convert *sslmode* to an :class:`SSLMode` member.

        :param sslmode: an existing member, or a mode name using either
            dashes (``'verify-full'``) or underscores (``'verify_full'``).
        :raises AttributeError: if *sslmode* does not name a member.  This
            exception type is kept because callers catch it.
        """
        if isinstance(sslmode, cls):
            return sslmode
        name = sslmode.replace('-', '_')
        # Look the name up among the members only.  A plain getattr() would
        # also "find" class attributes such as ``parse`` or ``mro``.
        try:
            return cls.__members__[name]
        except KeyError:
            raise AttributeError(
                '{!r} is not a valid {}'.format(sslmode, cls.__name__)
            ) from None


class SSLNegotiation(compat.StrEnum):
    """Accepted values of the ``sslnegotiation`` connection option.

    ``direct`` is selected by ``direct_tls=True``.  ``postgres`` is selected
    by ``direct_tls=False`` and is also the default when neither
    ``direct_tls`` nor ``sslnegotiation`` / ``PGSSLNEGOTIATION`` is given.
    """

    postgres = 'postgres'
    direct = 'direct'


# Resolved per-connection parameters; built at the end of
# _parse_connect_dsn_and_args() and returned alongside the address list.
_ConnectionParameters = collections.namedtuple(
    'ConnectionParameters',
    'user password database ssl sslmode ssl_negotiation '
    'server_settings target_session_attrs krbsrvname gsslib',
)


# Client-side statement-cache and command-timeout settings; built by
# _parse_connect_arguments().
_ClientConfiguration = collections.namedtuple(
    'ConnectionConfiguration',
    'command_timeout statement_cache_size '
    'max_cached_statement_lifetime max_cacheable_statement_size',
)


# Operating system name from platform.uname() (e.g. 'Windows'); selects
# platform-specific defaults throughout this module.
_system = platform.uname().system


# Name of the pgpass file, joined onto the PostgreSQL home directory when
# PGPASSFILE / passfile is not given.
PGPASSFILE = 'pgpass.conf' if _system == 'Windows' else '.pgpass'


# Name of the connection service file, joined onto the PostgreSQL home
# directory when PGSERVICEFILE / servicefile is not given.
PG_SERVICEFILE = '.pg_service.conf'


def _read_password_file(passfile: pathlib.Path) \
        -> typing.List[typing.Tuple[str, ...]]:
    """Read and tokenize a pgpass file.

    Every returned tuple holds exactly the five
    ``hostname:port:database:username:password`` fields of one entry.
    Fields are split on colons that are not preceded by a backslash.
    Backslash escapes are only used to locate separators; the field text
    is returned with its escapes intact.

    Blank lines, ``#`` comments and entries with fewer than five fields are
    skipped.  The lookup in ``_read_password_from_pgpass`` unpacks every
    entry into five names, so a short entry would otherwise abort it with
    ``ValueError``.

    :return:
        The parsed entries.  The list is empty if the file does not exist,
        is not a regular file, is accessible by group/others (non-Windows
        only), or cannot be read.  A warning is issued in the two middle
        cases.
    """

    passtab = []

    try:
        if not passfile.exists():
            return []

        if not passfile.is_file():
            warnings.warn(
                'password file {!r} is not a plain file'.format(passfile))

            return []

        if _system != 'Windows':
            if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO):
                warnings.warn(
                    'password file {!r} has group or world access; '
                    'permissions should be u=rw (0600) or less'.format(
                        passfile))

                return []

        with passfile.open('rt') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    # Skip empty lines and comments.
                    continue
                # Backslash escapes both itself and the colon,
                # which is a record separator.  Temporarily replace
                # escaped backslashes with '\n' (which cannot occur inside
                # a stripped line) so that "\\:" is still split.
                line = line.replace(R'\\', '\n')
                fields = tuple(
                    p.replace('\n', R'\\')
                    for p in re.split(r'(?<!\\):', line, maxsplit=4)
                )
                if len(fields) != 5:
                    # Malformed entry: skip it instead of breaking the
                    # five-field unpacking done by the lookup.
                    continue
                passtab.append(fields)
    except IOError:
        pass

    return passtab


def _read_password_from_pgpass(
        *, passfile: typing.Optional[pathlib.Path],
        hosts: typing.List[str],
        ports: typing.List[int],
        database: str,
        user: str):
    """Look up the password for a connection in a pgpass file.

    *hosts* and *ports* are walked pairwise, in order.  For each pair the
    pgpass entries are scanned top to bottom and the first entry whose
    host, port, database and user fields all match wins.  A ``*`` field
    matches any value.  A host starting with ``/`` (a Unix socket
    directory) is compared as ``localhost``.

    :return:
        Password string, if found, ``None`` otherwise.
    """

    entries = _read_password_file(passfile)
    if not entries:
        return None

    def field_matches(pattern, value):
        # '*' is a wildcard; anything else must match exactly.
        return pattern == '*' or pattern == value

    for host, port in zip(hosts, ports):
        wanted_host = 'localhost' if host.startswith('/') else host
        wanted_port = str(port)

        for phost, pport, pdatabase, puser, ppassword in entries:
            if (field_matches(phost, wanted_host)
                    and field_matches(pport, wanted_port)
                    and field_matches(pdatabase, database)
                    and field_matches(puser, user)):
                return ppassword

    return None


def _validate_port_spec(hosts, port):
    """Expand *port* into one port per entry of *hosts*.

    - A list with several ports must have exactly ``len(hosts)`` items and
      is returned unchanged.
    - A single-item list has its item repeated for every host.
    - Anything else (a scalar, a tuple, an empty list) is itself repeated
      for every host.

    :raises exceptions.ClientConfigurationError: if a multi-port list does
        not match the number of hosts.
    """
    host_count = len(hosts)

    if not isinstance(port, list) or not port:
        return [port] * host_count

    if len(port) == 1:
        return [port[0]] * host_count

    if len(port) != host_count:
        raise exceptions.ClientConfigurationError(
            'could not match {} port numbers to {} hosts'.format(
                len(port), host_count))

    return port


def _parse_hostlist(hostlist, port, *, unquote=False):
    """Parse a comma-separated host list into host addresses and ports.

    Each entry may be one of:

    - a Unix socket directory (``/path``);
    - a bracketed IPv6 address with an optional port (``[::1]:5432``);
    - a host name or IPv4 address with an optional port (``host:5432``).

    :param hostlist: the host list string.
    :param port: an explicit port or list of ports.  When truthy, it is
        expanded to one port per host and any ports embedded in
        *hostlist* are ignored.  When falsy, each entry's own port is used.
        Entries without a port fall back to ``PGPORT`` and then 5432.
    :param unquote: percent-decode host addresses and embedded ports
        (used for DSN netlocs).
    :return: ``(hosts, ports)``, two lists of equal length.
    :raises exceptions.ClientConfigurationError: on an empty host entry,
        a malformed IPv6 address, or a port count that does not match the
        number of hosts.
    """
    # str.split() yields a single-item list when there is no comma.
    hostspecs = hostlist.split(',')

    hosts = []
    hostlist_ports = []

    if not port:
        portspec = os.environ.get('PGPORT')
        if portspec:
            if ',' in portspec:
                default_port = [int(p) for p in portspec.split(',')]
            else:
                default_port = int(portspec)
        else:
            default_port = 5432

        default_port = _validate_port_spec(hostspecs, default_port)

    else:
        port = _validate_port_spec(hostspecs, port)

    for i, hostspec in enumerate(hostspecs):
        if not hostspec:
            # e.g. "a,,b" or a trailing comma: reject explicitly rather
            # than failing with an IndexError below.
            raise exceptions.ClientConfigurationError(
                'invalid empty host in the host list: {!r}'.format(
                    hostlist))

        if hostspec.startswith('/'):
            # Unix socket
            addr = hostspec
            hostspec_port = ''
        elif hostspec.startswith('['):
            # IPv6 address
            m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec)
            if m:
                addr = m.group(1)
                hostspec_port = m.group(2)
            else:
                raise exceptions.ClientConfigurationError(
                    'invalid IPv6 address in the connection URI: {!r}'.format(
                        hostspec
                    )
                )
        else:
            # IPv4 address or host name
            addr, _, hostspec_port = hostspec.partition(':')

        if unquote:
            addr = urllib.parse.unquote(addr)

        hosts.append(addr)
        if not port:
            if hostspec_port:
                if unquote:
                    hostspec_port = urllib.parse.unquote(hostspec_port)
                hostlist_ports.append(int(hostspec_port))
            else:
                hostlist_ports.append(default_port[i])

    if not port:
        port = hostlist_ports

    return hosts, port


def _parse_tls_version(tls_version):
    """Map a version name such as ``'TLSv1.2'`` to ``ssl.TLSVersion``.

    Dots are turned into underscores to form the enum member name.

    :raises exceptions.ClientConfigurationError: for ``SSL*`` protocol
        names or names that are not ``TLSVersion`` members.
    """
    if tls_version.startswith('SSL'):
        raise exceptions.ClientConfigurationError(
            f"Unsupported TLS version: {tls_version}"
        )

    member_name = tls_version.replace('.', '_')
    version = ssl_module.TLSVersion.__members__.get(member_name)
    if version is None:
        raise exceptions.ClientConfigurationError(
            f"No such TLS version: {tls_version}"
        )
    return version


def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
    """Return the resolved path of *filename* under ``~/.postgresql``.

    ``None`` is returned when the user's home directory cannot be
    determined.  The file itself is not required to exist.
    """
    try:
        config_dir = pathlib.Path.home() / '.postgresql'
    except (RuntimeError, KeyError):
        return None

    return config_dir.joinpath(filename).resolve()


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
                                password, passfile, database, ssl,
                                service, servicefile,
                                direct_tls, server_settings,
                                target_session_attrs, krbsrvname, gsslib):
    """Resolve the final connection addresses and connection parameters.

    Each setting is taken from the first source that provides it, roughly
    in this order:

    1. the explicit keyword arguments;
    2. the *dsn* URI (netloc, path and query string);
    3. the connection service file entry named by *service*, the DSN
       ``service`` query parameter, or ``PGSERVICE``;
    4. the ``PG*`` environment variables;
    5. built-in defaults.

    Unrecognized DSN query parameters are merged into *server_settings*;
    explicitly passed *server_settings* win on conflicts.

    :return:
        ``(addrs, params)``.  *addrs* is a list whose items are either
        Unix socket paths (``str``) or ``(host, port)`` tuples.  *params*
        is a ``_ConnectionParameters`` instance.
    :raises exceptions.ClientConfigurationError:
        if the resulting configuration is invalid.
    """
    # `auth_hosts` is the version of host information for the purposes
    # of reading the pgpass file.
    auth_hosts = None
    sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
    ssl_min_protocol_version = ssl_max_protocol_version = None
    sslnegotiation = None

    if dsn:
        parsed = urllib.parse.urlparse(dsn)

        query = None
        if parsed.query:
            query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
            for key, val in query.items():
                if isinstance(val, list):
                    # Repeated query parameters: the last one wins.
                    query[key] = val[-1]

            if 'service' in query:
                val = query.pop('service')
                if not service and val:
                    service = val

        if parsed.scheme not in {'postgresql', 'postgres'}:
            raise exceptions.ClientConfigurationError(
                'invalid DSN: scheme is expected to be either '
                '"postgresql" or "postgres", got {!r}'.format(parsed.scheme))

        if parsed.netloc:
            if '@' in parsed.netloc:
                dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@')
            else:
                dsn_hostspec = parsed.netloc
                dsn_auth = ''
        else:
            dsn_auth = dsn_hostspec = ''

        if dsn_auth:
            dsn_user, _, dsn_password = dsn_auth.partition(':')
        else:
            dsn_user = dsn_password = ''

        if not host and dsn_hostspec:
            host, port = _parse_hostlist(dsn_hostspec, port, unquote=True)

        if parsed.path and database is None:
            dsn_database = parsed.path
            if dsn_database.startswith('/'):
                dsn_database = dsn_database[1:]
            database = urllib.parse.unquote(dsn_database)

        if user is None and dsn_user:
            user = urllib.parse.unquote(dsn_user)

        if password is None and dsn_password:
            password = urllib.parse.unquote(dsn_password)

        if query:

            if 'port' in query:
                val = query.pop('port')
                if not port and val:
                    port = [int(p) for p in val.split(',')]

            if 'host' in query:
                val = query.pop('host')
                if not host and val:
                    host, port = _parse_hostlist(val, port)

            if 'dbname' in query:
                val = query.pop('dbname')
                if database is None:
                    database = val

            if 'database' in query:
                val = query.pop('database')
                if database is None:
                    database = val

            if 'user' in query:
                val = query.pop('user')
                if user is None:
                    user = val

            if 'password' in query:
                val = query.pop('password')
                if password is None:
                    password = val

            if 'passfile' in query:
                val = query.pop('passfile')
                if passfile is None:
                    passfile = val

            if 'sslmode' in query:
                val = query.pop('sslmode')
                if ssl is None:
                    ssl = val

            if 'sslcert' in query:
                sslcert = query.pop('sslcert')

            if 'sslkey' in query:
                sslkey = query.pop('sslkey')

            if 'sslrootcert' in query:
                sslrootcert = query.pop('sslrootcert')

            if 'sslnegotiation' in query:
                sslnegotiation = query.pop('sslnegotiation')

            if 'sslcrl' in query:
                sslcrl = query.pop('sslcrl')

            if 'sslpassword' in query:
                sslpassword = query.pop('sslpassword')

            if 'ssl_min_protocol_version' in query:
                ssl_min_protocol_version = query.pop(
                    'ssl_min_protocol_version'
                )

            if 'ssl_max_protocol_version' in query:
                ssl_max_protocol_version = query.pop(
                    'ssl_max_protocol_version'
                )

            if 'target_session_attrs' in query:
                dsn_target_session_attrs = query.pop(
                    'target_session_attrs'
                )
                if target_session_attrs is None:
                    target_session_attrs = dsn_target_session_attrs

            if 'krbsrvname' in query:
                val = query.pop('krbsrvname')
                if krbsrvname is None:
                    krbsrvname = val

            if 'gsslib' in query:
                val = query.pop('gsslib')
                if gsslib is None:
                    gsslib = val

            # Whatever is left is treated as server settings.
            if query:
                if server_settings is None:
                    server_settings = query
                else:
                    server_settings = {**query, **server_settings}

    # The service name falls back to PGSERVICE.  It must be resolved before
    # the service file is consulted, and the service file must be read
    # regardless of whether a DSN was given.  Otherwise PGSERVICE and a
    # `service` argument passed without a DSN are silently ignored.
    if not service:
        service = os.environ.get('PGSERVICE')

    if service is not None:
        connection_service_file = servicefile

        if connection_service_file is None:
            connection_service_file = os.getenv('PGSERVICEFILE')

        if connection_service_file is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                connection_service_file = homedir / PG_SERVICEFILE
            else:
                connection_service_file = None
        else:
            connection_service_file = pathlib.Path(connection_service_file)

        if connection_service_file is not None:
            # Interpolation is disabled so that values containing '%'
            # (e.g. passwords) are read literally instead of raising an
            # InterpolationSyntaxError.
            pg_service = configparser.ConfigParser(interpolation=None)
            pg_service.read(connection_service_file)
            if service in pg_service.sections():
                # Work on a plain dict copy: SectionProxy.pop() raises
                # KeyError for keys that only come from a [DEFAULT]
                # section.
                service_params = dict(pg_service[service])
                if 'port' in service_params:
                    val = service_params.pop('port')
                    if not port and val:
                        port = [int(p) for p in val.split(',')]

                if 'host' in service_params:
                    val = service_params.pop('host')
                    if not host and val:
                        host, port = _parse_hostlist(val, port)

                if 'dbname' in service_params:
                    val = service_params.pop('dbname')
                    if database is None:
                        database = val

                if 'database' in service_params:
                    val = service_params.pop('database')
                    if database is None:
                        database = val

                if 'user' in service_params:
                    val = service_params.pop('user')
                    if user is None:
                        user = val

                if 'password' in service_params:
                    val = service_params.pop('password')
                    if password is None:
                        password = val

                if 'passfile' in service_params:
                    val = service_params.pop('passfile')
                    if passfile is None:
                        passfile = val

                if 'sslmode' in service_params:
                    val = service_params.pop('sslmode')
                    if ssl is None:
                        ssl = val

                if 'sslcert' in service_params:
                    val = service_params.pop('sslcert')
                    if sslcert is None:
                        sslcert = val

                if 'sslkey' in service_params:
                    val = service_params.pop('sslkey')
                    if sslkey is None:
                        sslkey = val

                if 'sslrootcert' in service_params:
                    val = service_params.pop('sslrootcert')
                    if sslrootcert is None:
                        sslrootcert = val

                if 'sslnegotiation' in service_params:
                    val = service_params.pop('sslnegotiation')
                    if sslnegotiation is None:
                        sslnegotiation = val

                if 'sslcrl' in service_params:
                    val = service_params.pop('sslcrl')
                    if sslcrl is None:
                        sslcrl = val

                if 'sslpassword' in service_params:
                    val = service_params.pop('sslpassword')
                    if sslpassword is None:
                        sslpassword = val

                if 'ssl_min_protocol_version' in service_params:
                    val = service_params.pop(
                        'ssl_min_protocol_version'
                    )
                    if ssl_min_protocol_version is None:
                        ssl_min_protocol_version = val

                if 'ssl_max_protocol_version' in service_params:
                    val = service_params.pop(
                        'ssl_max_protocol_version'
                    )
                    if ssl_max_protocol_version is None:
                        ssl_max_protocol_version = val

                if 'target_session_attrs' in service_params:
                    dsn_target_session_attrs = service_params.pop(
                        'target_session_attrs'
                    )
                    if target_session_attrs is None:
                        target_session_attrs = dsn_target_session_attrs

                if 'krbsrvname' in service_params:
                    val = service_params.pop('krbsrvname')
                    if krbsrvname is None:
                        krbsrvname = val

                if 'gsslib' in service_params:
                    val = service_params.pop('gsslib')
                    if gsslib is None:
                        gsslib = val

    if not host:
        hostspec = os.environ.get('PGHOST')
        if hostspec:
            host, port = _parse_hostlist(hostspec, port)

    if not host:
        auth_hosts = ['localhost']

        if _system == 'Windows':
            host = ['localhost']
        else:
            host = ['/run/postgresql', '/var/run/postgresql',
                    '/tmp', '/private/tmp', 'localhost']

    if not isinstance(host, (list, tuple)):
        host = [host]

    if auth_hosts is None:
        auth_hosts = host

    if not port:
        portspec = os.environ.get('PGPORT')
        if portspec:
            if ',' in portspec:
                port = [int(p) for p in portspec.split(',')]
            else:
                port = int(portspec)
        else:
            port = 5432

    elif isinstance(port, (list, tuple)):
        port = [int(p) for p in port]

    else:
        port = int(port)

    port = _validate_port_spec(host, port)

    if user is None:
        user = os.getenv('PGUSER')
        if not user:
            user = getpass.getuser()

    if password is None:
        password = os.getenv('PGPASSWORD')

    if database is None:
        database = os.getenv('PGDATABASE')

    if database is None:
        database = user

    if user is None:
        raise exceptions.ClientConfigurationError(
            'could not determine user name to connect with')

    if database is None:
        raise exceptions.ClientConfigurationError(
            'could not determine database name to connect to')

    if password is None:
        if passfile is None:
            passfile = os.getenv('PGPASSFILE')

        if passfile is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                passfile = homedir / PGPASSFILE
            else:
                passfile = None
        else:
            passfile = pathlib.Path(passfile)

        if passfile is not None:
            password = _read_password_from_pgpass(
                hosts=auth_hosts, ports=port,
                database=database, user=user,
                passfile=passfile)

    addrs = []
    have_tcp_addrs = False
    for h, p in zip(host, port):
        if h.startswith('/'):
            # UNIX socket name
            if '.s.PGSQL.' not in h:
                h = os.path.join(h, '.s.PGSQL.{}'.format(p))
            addrs.append(h)
        else:
            # TCP host/port
            addrs.append((h, p))
            have_tcp_addrs = True

    if not addrs:
        raise exceptions.InternalClientError(
            'could not determine the database address to connect to')

    if ssl is None:
        ssl = os.getenv('PGSSLMODE')

    if ssl is None and have_tcp_addrs:
        ssl = 'prefer'

    if direct_tls is not None:
        sslneg = (
            SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
        )
    else:
        if sslnegotiation is None:
            sslnegotiation = os.environ.get("PGSSLNEGOTIATION")

        if sslnegotiation is not None:
            try:
                sslneg = SSLNegotiation(sslnegotiation)
            except ValueError:
                modes = ', '.join(
                    m.name.replace('_', '-')
                    for m in SSLNegotiation
                )
                raise exceptions.ClientConfigurationError(
                    f'`sslnegotiation` parameter must be one of: {modes}'
                ) from None
        else:
            sslneg = SSLNegotiation.postgres

    if isinstance(ssl, (str, SSLMode)):
        try:
            sslmode = SSLMode.parse(ssl)
        except AttributeError:
            modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
            raise exceptions.ClientConfigurationError(
                '`sslmode` parameter must be one of: {}'.format(modes)
            ) from None

        # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
        if sslmode < SSLMode.allow:
            ssl = False
        else:
            ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
            ssl.check_hostname = sslmode >= SSLMode.verify_full
            if sslmode < SSLMode.require:
                ssl.verify_mode = ssl_module.CERT_NONE
            else:
                if sslrootcert is None:
                    sslrootcert = os.getenv('PGSSLROOTCERT')
                if sslrootcert:
                    ssl.load_verify_locations(cafile=sslrootcert)
                    ssl.verify_mode = ssl_module.CERT_REQUIRED
                else:
                    try:
                        sslrootcert = _dot_postgresql_path('root.crt')
                        if sslrootcert is not None:
                            ssl.load_verify_locations(cafile=sslrootcert)
                        else:
                            raise exceptions.ClientConfigurationError(
                                'cannot determine location of user '
                                'PostgreSQL configuration directory'
                            )
                    except (
                        exceptions.ClientConfigurationError,
                        FileNotFoundError,
                        NotADirectoryError,
                    ):
                        if sslmode > SSLMode.require:
                            if sslrootcert is None:
                                sslrootcert = '~/.postgresql/root.crt'
                                detail = (
                                    'Could not determine location of user '
                                    'home directory (HOME is either unset, '
                                    'inaccessible, or does not point to a '
                                    'valid directory)'
                                )
                            else:
                                detail = None
                            raise exceptions.ClientConfigurationError(
                                f'root certificate file "{sslrootcert}" does '
                                f'not exist or cannot be accessed',
                                hint='Provide the certificate file directly '
                                     f'or make sure "{sslrootcert}" '
                                     'exists and is readable.',
                                detail=detail,
                            )
                        elif sslmode == SSLMode.require:
                            # Without a root certificate, "require" only
                            # encrypts and does not verify the server.
                            ssl.verify_mode = ssl_module.CERT_NONE
                        else:
                            assert False, 'unreachable'
                    else:
                        ssl.verify_mode = ssl_module.CERT_REQUIRED

                if sslcrl is None:
                    sslcrl = os.getenv('PGSSLCRL')
                if sslcrl:
                    ssl.load_verify_locations(cafile=sslcrl)
                    ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
                else:
                    sslcrl = _dot_postgresql_path('root.crl')
                    if sslcrl is not None:
                        try:
                            ssl.load_verify_locations(cafile=sslcrl)
                        except (
                            FileNotFoundError,
                            NotADirectoryError,
                        ):
                            pass
                        else:
                            ssl.verify_flags |= \
                                ssl_module.VERIFY_CRL_CHECK_CHAIN

            if sslkey is None:
                sslkey = os.getenv('PGSSLKEY')
            if not sslkey:
                sslkey = _dot_postgresql_path('postgresql.key')
                if sslkey is not None and not sslkey.exists():
                    sslkey = None
            if not sslpassword:
                sslpassword = ''
            if sslcert is None:
                sslcert = os.getenv('PGSSLCERT')
            if sslcert:
                ssl.load_cert_chain(
                    sslcert, keyfile=sslkey, password=lambda: sslpassword
                )
            else:
                sslcert = _dot_postgresql_path('postgresql.crt')
                if sslcert is not None:
                    try:
                        ssl.load_cert_chain(
                            sslcert,
                            keyfile=sslkey,
                            password=lambda: sslpassword
                        )
                    except (FileNotFoundError, NotADirectoryError):
                        pass

            # OpenSSL 1.1.1 keylog file, copied from create_default_context()
            if hasattr(ssl, 'keylog_filename'):
                keylogfile = os.environ.get('SSLKEYLOGFILE')
                if keylogfile and not sys.flags.ignore_environment:
                    ssl.keylog_filename = keylogfile

            if ssl_min_protocol_version is None:
                ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
            if ssl_min_protocol_version:
                ssl.minimum_version = _parse_tls_version(
                    ssl_min_protocol_version
                )
            else:
                ssl.minimum_version = _parse_tls_version('TLSv1.2')

            if ssl_max_protocol_version is None:
                ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
            if ssl_max_protocol_version:
                ssl.maximum_version = _parse_tls_version(
                    ssl_max_protocol_version
                )

    elif ssl is True:
        ssl = ssl_module.create_default_context()
        sslmode = SSLMode.verify_full
    else:
        sslmode = SSLMode.disable

    if server_settings is not None and (
            not isinstance(server_settings, dict) or
            not all(isinstance(k, str) for k in server_settings) or
            not all(isinstance(v, str) for v in server_settings.values())):
        raise exceptions.ClientConfigurationError(
            'server_settings is expected to be None or '
            'a Dict[str, str]')

    if target_session_attrs is None:
        target_session_attrs = os.getenv(
            "PGTARGETSESSIONATTRS", SessionAttribute.any
        )
    try:
        target_session_attrs = SessionAttribute(target_session_attrs)
    except ValueError:
        # List the accepted values.  The previous code formatted the
        # unbound `__members__.values` method instead of calling it.
        raise exceptions.ClientConfigurationError(
            "target_session_attrs is expected to be one of "
            "{!r}"
            ", got {!r}".format(
                [attr.value for attr in SessionAttribute],
                target_session_attrs,
            )
        ) from None

    if krbsrvname is None:
        krbsrvname = os.getenv('PGKRBSRVNAME')

    if gsslib is None:
        gsslib = os.getenv('PGGSSLIB')
        if gsslib is None:
            gsslib = 'sspi' if _system == 'Windows' else 'gssapi'
    if gsslib not in {'gssapi', 'sspi'}:
        raise exceptions.ClientConfigurationError(
            "gsslib parameter must be either 'gssapi' or 'sspi'"
            ", got {!r}".format(gsslib))

    params = _ConnectionParameters(
        user=user, password=password, database=database, ssl=ssl,
        sslmode=sslmode, ssl_negotiation=sslneg,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname, gsslib=gsslib)

    return addrs, params


def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
                             database, command_timeout,
                             statement_cache_size,
                             max_cached_statement_lifetime,
                             max_cacheable_statement_size,
                             ssl, direct_tls, server_settings,
                             target_session_attrs, krbsrvname, gsslib,
                             service, servicefile):
    """Validate client-side options and resolve the connection settings.

    The three statement-cache limits must be non-negative.  ``None`` and
    booleans are rejected for them.  *command_timeout*, when not ``None``,
    must be a non-boolean value convertible to a ``float`` that is not
    ``<= 0``.  All connection-related arguments are forwarded to
    ``_parse_connect_dsn_and_args``.

    :return: ``(addrs, params, config)`` where *config* is a
        ``_ClientConfiguration``.
    :raises ValueError: if a cache limit or *command_timeout* is invalid.
    """
    cache_limits = (
        ('max_cacheable_statement_size', max_cacheable_statement_size),
        ('max_cached_statement_lifetime', max_cached_statement_lifetime),
        ('statement_cache_size', statement_cache_size),
    )
    for limit_name, limit_value in cache_limits:
        # bool is an int subclass, so True/False are rejected explicitly.
        rejected = (
            limit_value is None
            or isinstance(limit_value, bool)
            or limit_value < 0
        )
        if rejected:
            raise ValueError(
                '{} is expected to be greater '
                'or equal to 0, got {!r}'.format(limit_name, limit_value))

    if command_timeout is not None:
        timeout_ok = not isinstance(command_timeout, bool)
        if timeout_ok:
            try:
                command_timeout = float(command_timeout)
            except ValueError:
                timeout_ok = False
            else:
                # NOTE: written as "not <= 0", so NaN passes this check.
                timeout_ok = not command_timeout <= 0
        if not timeout_ok:
            raise ValueError(
                'invalid command_timeout value: '
                'expected greater than 0 float (got {!r})'.format(
                    command_timeout)) from None

    addrs, params = _parse_connect_dsn_and_args(
        dsn=dsn, host=host, port=port, user=user,
        password=password, passfile=passfile, ssl=ssl,
        direct_tls=direct_tls, database=database,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname, gsslib=gsslib,
        service=service, servicefile=servicefile)

    config = _ClientConfiguration(
        command_timeout=command_timeout,
        statement_cache_size=statement_cache_size,
        max_cached_statement_lifetime=max_cached_statement_lifetime,
        max_cacheable_statement_size=max_cacheable_statement_size,
    )

    return addrs, params, config


class TLSUpgradeProto(asyncio.Protocol):
    """Temporary protocol used while negotiating a TLS upgrade.

    It interprets the server's reply to an ``SSLRequest`` and resolves the
    ``on_data`` future with one of:

    - ``True`` when the server answered exactly ``b'S'``;
    - ``False`` when it answered ``b'N'`` and a plaintext fallback is
      permitted;
    - a ``ConnectionError`` (or the transport's error) otherwise.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        host: str,
        port: int,
        ssl_context: ssl_module.SSLContext,
        ssl_is_advisory: bool,
    ) -> None:
        self.on_data = _create_future(loop)
        self.host = host
        self.port = port
        self.ssl_context = ssl_context
        self.ssl_is_advisory = ssl_is_advisory

    def data_received(self, data: bytes) -> None:
        if data == b'S':
            self.on_data.set_result(True)
            return

        # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
        # since the only way to get ssl_is_advisory is from
        # sslmode=prefer. But be extra sure to disallow insecure
        # connections when the ssl context asks for real security.
        plaintext_allowed = (
            self.ssl_is_advisory
            and self.ssl_context.verify_mode == ssl_module.CERT_NONE
            and data == b'N'
        )
        if plaintext_allowed:
            self.on_data.set_result(False)
            return

        message = 'PostgreSQL server at "{host}:{port}" ' \
                  'rejected SSL upgrade'.format(host=self.host, port=self.port)
        self.on_data.set_exception(ConnectionError(message))

    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
        if self.on_data.done():
            return
        if exc is None:
            exc = ConnectionError('unexpected connection_lost() call')
        self.on_data.set_exception(exc)


# Type variable for the protocol produced by the `protocol_factory` argument
# of _create_ssl_connection().  It lets that function's return annotation
# follow the factory's return type.  The misspelling ("Proctol") is kept
# because _create_ssl_connection refers to this exact name.
_ProctolFactoryR = typing.TypeVar(
    "_ProctolFactoryR", bound=asyncio.protocols.Protocol
)


async def _create_ssl_connection(
    # TODO: The return type is a specific combination of subclasses of
    # asyncio.protocols.Protocol that we can't express. For now, having the
    # return type be dependent on signature of the factory is an improvement
    protocol_factory: Callable[[], _ProctolFactoryR],
    host: str,
    port: int,
    *,
    loop: asyncio.AbstractEventLoop,
    ssl_context: ssl_module.SSLContext,
    ssl_is_advisory: bool = False,
) -> typing.Tuple[asyncio.Transport, _ProctolFactoryR]:
    """Open a TCP connection and negotiate TLS through an SSLRequest.

    An SSLRequest message is written on a plain connection, and the
    server's answer is awaited through TLSUpgradeProto. If the server
    agrees, the connection is upgraded with *ssl_context*. If it declines
    and TLSUpgradeProto allows a fallback (see *ssl_is_advisory*), the
    plain connection is kept. The transport is then attached to a protocol
    built by *protocol_factory*, whose ``is_ssl`` attribute is set to
    whether the upgrade happened.

    Returns the ``(transport, protocol)`` pair. Every transport or socket
    opened here is closed again if a later step fails or is cancelled.
    """
    tr, pr = await loop.create_connection(
        lambda: TLSUpgradeProto(loop, host, port,
                                ssl_context, ssl_is_advisory),
        host, port)

    tr.write(struct.pack('!ll', 8, 80877103))  # SSLRequest message.

    try:
        do_ssl_upgrade = await pr.on_data
    except (Exception, asyncio.CancelledError):
        tr.close()
        raise

    if hasattr(loop, 'start_tls'):
        if do_ssl_upgrade:
            try:
                new_tr = await loop.start_tls(
                    tr, pr, ssl_context, server_hostname=host)
                assert new_tr is not None
            except (Exception, asyncio.CancelledError):
                tr.close()
                raise
        else:
            new_tr = tr

        try:
            pg_proto = protocol_factory()
            pg_proto.is_ssl = do_ssl_upgrade
            pg_proto.connection_made(new_tr)
            new_tr.set_protocol(pg_proto)
        except (Exception, asyncio.CancelledError):
            # Do not leak the (possibly TLS-wrapped) transport when the
            # real protocol cannot be attached to it.
            new_tr.close()
            raise

        return new_tr, pg_proto
    else:
        # Loops without start_tls(): give a duplicate of the raw socket to
        # create_connection(), which performs the TLS handshake itself when
        # ssl= is passed.
        conn_factory = functools.partial(
            loop.create_connection, protocol_factory)

        if do_ssl_upgrade:
            conn_factory = functools.partial(
                conn_factory, ssl=ssl_context, server_hostname=host)

        try:
            sock = _get_socket(tr).dup()
        finally:
            # The plain transport is not needed any more. The duplicated
            # descriptor keeps the underlying connection open. Closing here
            # also covers a failing _get_socket()/dup() (e.g. EMFILE).
            tr.close()

        try:
            _set_nodelay(sock)
            new_tr, pg_proto = await conn_factory(sock=sock)
            pg_proto.is_ssl = do_ssl_upgrade
            return new_tr, pg_proto
        except (Exception, asyncio.CancelledError):
            sock.close()
            raise


async def _connect_addr(
    *,
    addr,
    loop,
    params,
    config,
    connection_class,
    record_class
):
    """Connect to a single address, applying the sslmode allow/prefer fallback.

    A callable ``params.password`` is resolved once, before any attempt,
    and awaited if it returns an awaitable.

    - sslmode=allow: the first attempt goes without SSL, the second with it.
    - sslmode=prefer: the first attempt goes with SSL, the second without it.
    - Any other sslmode: exactly one attempt.

    The second attempt is only made when the first raises
    _RetryConnectSignal. The caller's unmodified *params* are forwarded so
    they can be handed to *connection_class*.
    """
    assert loop is not None

    original_params = params
    if callable(params.password):
        resolved = params.password()
        if inspect.isawaitable(resolved):
            resolved = await resolved
        params = params._replace(password=resolved)

    shared_args = (
        addr, loop, config, connection_class, record_class, original_params)

    sslmode = params.sslmode
    if sslmode == SSLMode.allow:
        first_try, second_try = params._replace(ssl=None), params
    elif sslmode == SSLMode.prefer:
        first_try, second_try = params, params._replace(ssl=None)
    else:
        # No fallback for this sslmode: one attempt, retry signalling off.
        return await __connect_addr(params, False, *shared_args)

    try:
        return await __connect_addr(first_try, True, *shared_args)
    except _RetryConnectSignal:
        pass

    # The first attempt asked for a retry with the opposite SSL setting.
    return await __connect_addr(second_try, False, *shared_args)


class _RetryConnectSignal(Exception):
    pass


async def __connect_addr(
    params,
    retry,
    addr,
    loop,
    config,
    connection_class,
    record_class,
    params_input,
):
    """Make one connection attempt to *addr* with *params*.

    The transport is chosen as follows:

    - UNIX socket when *addr* is a string.
    - Direct TLS when ``params.ssl`` is set and ``ssl_negotiation`` is
      ``direct``.
    - An SSLRequest-negotiated upgrade (_create_ssl_connection) when only
      ``params.ssl`` is set.
    - Plain TCP otherwise.

    *retry* is True only for the first of the two sslmode=allow/prefer
    attempts made by _connect_addr(). In that case, an authorization or
    connection-lost error that the opposite SSL setting might avoid is
    turned into _RetryConnectSignal.

    *params_input* (the caller's original parameters) is what gets passed
    to *connection_class*.

    Returns the new *connection_class* instance. The transport is closed
    if the attempt fails after it was created.
    """
    # Handed to protocol.Protocol and awaited below; presumably the
    # protocol resolves it once the startup handshake succeeds or fails.
    connected = _create_future(loop)

    proto_factory = lambda: protocol.Protocol(
        addr, connected, params, record_class, loop)

    if isinstance(addr, str):
        # UNIX socket
        connector = loop.create_unix_connection(proto_factory, addr)

    elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
        # if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
        # direct SSL connection
        # NOTE(review): pr.is_ssl is not set on this path (unlike in
        # _create_ssl_connection), so the sslmode=prefer retry check below
        # sees the protocol's default value -- confirm this is intended.
        connector = loop.create_connection(
            proto_factory, *addr, ssl=params.ssl
        )

    elif params.ssl:
        connector = _create_ssl_connection(
            proto_factory, *addr, loop=loop, ssl_context=params.ssl,
            ssl_is_advisory=params.sslmode == SSLMode.prefer)
    else:
        connector = loop.create_connection(proto_factory, *addr)

    # If this raises, no transport exists yet, so there is nothing to close.
    tr, pr = await connector

    try:
        await connected
    except (
        exceptions.InvalidAuthorizationSpecificationError,
        exceptions.ConnectionDoesNotExistError,  # seen on Windows
    ):
        tr.close()

        # retry=True here is a redundant check because we don't want to
        # accidentally raise the internal _RetryConnectSignal to the user
        if retry and (
            params.sslmode == SSLMode.allow and not pr.is_ssl or
            params.sslmode == SSLMode.prefer and pr.is_ssl
        ):
            # Trigger retry when:
            #   1. First attempt with sslmode=allow, ssl=None failed
            #   2. First attempt with sslmode=prefer, ssl=ctx failed while the
            #      server claimed to support SSL (returning "S" for SSLRequest)
            #      (likely because pg_hba.conf rejected the connection)
            raise _RetryConnectSignal()

        else:
            # but will NOT retry if:
            #   1. First attempt with sslmode=prefer failed but the server
            #      doesn't support SSL (returning 'N' for SSLRequest), because
            #      we already tried to connect without SSL thru ssl_is_advisory
            #   2. Second attempt with sslmode=prefer, ssl=None failed
            #   3. Second attempt with sslmode=allow, ssl=ctx failed
            #   4. Any other sslmode
            raise

    except (Exception, asyncio.CancelledError):
        # Any other failure (or cancellation) while waiting for the
        # handshake: release the transport and propagate unchanged.
        tr.close()
        raise

    con = connection_class(pr, tr, loop, addr, config, params_input)
    pr.set_connection(con)
    return con


class SessionAttribute(str, enum.Enum):
    """Accepted values of the ``target_session_attrs`` connection option.

    Each member's value is the option's spelling, so members compare equal
    to the plain strings.
    """

    any = 'any'
    primary = 'primary'
    standby = 'standby'
    prefer_standby = 'prefer-standby'
    read_write = 'read-write'
    read_only = 'read-only'


def _accept_in_hot_standby(should_be_in_hot_standby: bool):
    """
    If the server didn't report "in_hot_standby" at startup, we must determine
    the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
    If the server allows a connection and states it is in recovery it must
    be a replica/standby server.
    """
    async def can_be_used(connection):
        settings = connection.get_settings()
        hot_standby_status = getattr(settings, 'in_hot_standby', None)
        if hot_standby_status is not None:
            is_in_hot_standby = hot_standby_status == 'on'
        else:
            is_in_hot_standby = await connection.fetchval(
                "SELECT pg_catalog.pg_is_in_recovery()"
            )
        return is_in_hot_standby == should_be_in_hot_standby

    return can_be_used


def _accept_read_only(should_be_read_only: bool):
    """
    Verify the server has not set default_transaction_read_only=True
    """
    async def can_be_used(connection):
        settings = connection.get_settings()
        is_readonly = getattr(settings, 'default_transaction_read_only', 'off')

        if is_readonly == "on":
            return should_be_read_only

        return await _accept_in_hot_standby(should_be_read_only)(connection)
    return can_be_used


async def _accept_any(_):
    return True


# Maps each target_session_attrs value to an async predicate that takes a
# connection and returns whether it may be used.
target_attrs_check = {
    SessionAttribute.any: _accept_any,
    # primary / standby: decided purely by hot-standby state.
    SessionAttribute.primary: _accept_in_hot_standby(False),
    SessionAttribute.standby: _accept_in_hot_standby(True),
    # prefer-standby looks for a standby first; _connect() falls back to
    # any connected candidate when none matches.
    SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
    # read-write / read-only: default_transaction_read_only, then standby.
    SessionAttribute.read_write: _accept_read_only(False),
    SessionAttribute.read_only: _accept_read_only(True),
}


async def _can_use_connection(connection, attr: SessionAttribute):
    """Return whether *connection* satisfies target session attribute *attr*."""
    return await target_attrs_check[attr](connection)


# Strong references to the fire-and-forget cleanup tasks spawned by
# _connect(). The event loop keeps only weak references to tasks, so
# without this set a pending close task could be garbage-collected before
# it finishes.
_background_tasks: typing.Set[asyncio.Task] = set()


async def _connect(*, loop, connection_class, record_class, **kwargs):
    """Connect to the first host that satisfies ``target_session_attrs``.

    Addresses parsed from *kwargs* are tried in order. A host that fails
    with ``OSError`` is skipped. The first connection accepted by the
    target-attribute predicate is returned. For ``prefer-standby``, if no
    host matched, a random successfully connected candidate is returned
    instead. All other connections opened along the way are closed in the
    background.

    Raises the last ``OSError`` seen when no connection is chosen and any
    host failed that way. Otherwise raises
    ``TargetServerAttributeNotMatched`` when no connection is chosen.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    addrs, params, config = _parse_connect_arguments(**kwargs)
    target_attr = params.target_session_attrs

    candidates = []
    chosen_connection = None
    last_error = None
    try:
        for addr in addrs:
            try:
                conn = await _connect_addr(
                    addr=addr,
                    loop=loop,
                    params=params,
                    config=config,
                    connection_class=connection_class,
                    record_class=record_class,
                )
                candidates.append(conn)
                if await _can_use_connection(conn, target_attr):
                    chosen_connection = conn
                    break
            except OSError as ex:
                last_error = ex
        else:
            if target_attr == SessionAttribute.prefer_standby and candidates:
                chosen_connection = random.choice(candidates)
    finally:
        # Close every connection that was opened but is not handed back,
        # including after an error or cancellation.
        rejected = [c for c in candidates if c is not chosen_connection]
        if rejected:
            async def _close_rejected():
                await asyncio.gather(
                    *(c.close() for c in rejected),
                    return_exceptions=True
                )

            close_task = asyncio.create_task(_close_rejected())
            # Keep the task alive until it completes (see _background_tasks).
            _background_tasks.add(close_task)
            close_task.add_done_callback(_background_tasks.discard)

    if chosen_connection is not None:
        return chosen_connection

    raise last_error or exceptions.TargetServerAttributeNotMatched(
        'None of the hosts match the target attribute requirement '
        '{!r}'.format(target_attr)
    )


async def _cancel(*, loop, addr, params: _ConnectionParameters,
                  backend_pid, backend_secret):
    """Send a CancelRequest for *backend_pid* to the server at *addr*.

    A fresh connection is opened for the request:

    - a UNIX socket when *addr* is a string;
    - TCP with SSLRequest-negotiated TLS when ``params.ssl`` is set and
      sslmode is not ``allow``;
    - plain TCP otherwise.

    After writing the request, this waits until the connection is lost
    (presumably the server closes it once the request is processed). The
    transport is always closed before returning or raising.
    """

    class CancelProto(asyncio.Protocol):

        def __init__(self):
            self.on_disconnect = _create_future(loop)
            self.is_ssl = False

        def connection_lost(self, exc):
            if not self.on_disconnect.done():
                self.on_disconnect.set_result(True)

    # Pack the CancelRequest message before connecting. Invalid
    # pid/secret values (struct.error) then fail without opening a
    # connection, instead of leaking it.
    msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

    plain_tcp = False
    if isinstance(addr, str):
        tr, pr = await loop.create_unix_connection(CancelProto, addr)
    elif params.ssl and params.sslmode != SSLMode.allow:
        tr, pr = await _create_ssl_connection(
            CancelProto,
            *addr,
            loop=loop,
            ssl_context=params.ssl,
            ssl_is_advisory=params.sslmode == SSLMode.prefer)
    else:
        tr, pr = await loop.create_connection(
            CancelProto, *addr)
        plain_tcp = True

    try:
        if plain_tcp:
            # Done inside the try so that a failure here still closes tr.
            _set_nodelay(_get_socket(tr))
        tr.write(msg)
        await pr.on_disconnect
    finally:
        tr.close()


def _get_socket(transport):
    sock = transport.get_extra_info('socket')
    if sock is None:
        # Shouldn't happen with any asyncio-complaint event loop.
        raise ConnectionError(
            'could not get the socket for transport {!r}'.format(transport))
    return sock


def _set_nodelay(sock):
    if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)


def _create_future(loop):
    try:
        create_future = loop.create_future
    except AttributeError:
        return asyncio.Future(loop=loop)
    else:
        return create_future()
