#!/usr/bin/env python3

import os
import shutil
import qarnot
from airflow import AirflowException
import subprocess
import configparser
import time
import signal
from custom_operators.qarnot_utils import qarnot_general


# Handler must have a signal number and the interrupted stack frame as args
def handler_timeout(signum, frame):
    """Abort the worker-connection check when the SIGALRM timer expires.

    Registered through ``signal.signal``, which calls its handlers with the
    signal number and the interrupted stack frame. Neither is used here.

    Raises:
        AirflowException: always.
    """
    # The trailing space keeps the two sentences apart once the adjacent
    # literals are joined.
    raise AirflowException("SSH timeout while checking worker connections. "
                           "Check ssh config")


def check_workers_connection(trust_ssh_host, task, port, host, logger,
                             job_timeout, init_conn_timeout, path_ssh_key,
                             prepare_ssh_func, ssh_timeout):
    """Poll the master over ssh until ``/tmp/cluster_ready`` exists.

    Runs ``test -f /tmp/cluster_ready`` as root on ``host``:``port`` and
    retries every 5 seconds while the command fails. A successful command is
    taken to mean that all workers are connected to the master.

    The whole check is bounded by a SIGALRM timer of ``init_conn_timeout``
    seconds. Without it the retry loop would never end and the Airflow
    process would not stop. The timer is always cancelled before returning
    or propagating an exception.

    Args:
        trust_ssh_host: when true, the check runs with ``StrictHostKeyChecking``
            and ``CheckHostIP`` disabled. When false, no check is performed:
            a message is logged and the function returns.
        task: object whose ``wait(seconds)`` is used to pause between retries.
        port: ssh port of the master (converted with ``str``).
        host: ssh host of the master.
        logger: logger receiving the debug messages.
        job_timeout: not used by this function; kept for call compatibility.
        init_conn_timeout: number of seconds passed to ``signal.alarm``.
        path_ssh_key: path of the private key to use, or None.
        prepare_ssh_func: optional callable invoked with ``logger`` before
            every ssh attempt, or None.
        ssh_timeout: duration handed to the ``timeout`` command wrapping each
            ssh attempt (converted with ``str``).

    Raises:
        AirflowException: from the SIGALRM handler when the timer expires.
    """
    logger.debug("Automatically trust ssh host: " + str(trust_ssh_host))
    logger.debug("Check workers connection")
    logger.debug("Connecting to port " + str(port))

    # Set the timer for ssh connection, to raise Airflow exception
    # if ssh connection never works
    signal.signal(signal.SIGALRM, handler_timeout)
    logger.debug("Timer starts :" + str(init_conn_timeout) + "sec")
    signal.alarm(init_conn_timeout)

    try:
        if not trust_ssh_host:
            # Nothing is checked in this mode, so do not report the workers
            # as connected.
            logger.debug("automatically_trust_ssh_host is set to False, "
                         + "this is not supported by Airflow. Start again "
                         + "with automatically_trust_ssh_host set to True")
            return

        while True:
            try:
                # ssh_timeout prevents from being stuck because of an
                # ssh prompt (only happens if public key is incorrect)
                # if no ssh timeout, the signal timer wouldn't work
                # Every argv item must be a string, hence str(ssh_timeout).
                msg = ['timeout', str(ssh_timeout), 'ssh',
                       '-o', 'StrictHostKeyChecking=no',
                       '-o', 'CheckHostIP=no',
                       '-p', str(port), 'root@' + host, 'test', '-f',
                       '/tmp/cluster_ready']
                # add the private key path if needed
                if path_ssh_key is not None:
                    qarnot_general.insert_in_list(msg, 3, '-i', path_ssh_key)

                # Add prepare_ssh_function if specified
                if prepare_ssh_func is not None:
                    prepare_ssh_func(logger)

                subprocess.check_call(msg)
            except subprocess.CalledProcessError:
                logger.debug("SSH command not successful, "
                             "retry in 5 seconds...")
                task.wait(5)
                continue

            logger.debug("Workers are all connected to the master")
            return
    finally:
        # Cancel the timer on every exit path so that a pending alarm cannot
        # fire later, outside of this check.
        logger.debug("Timer stops")
        signal.alarm(0)


def create_cluster_config_file(task, job, path_config_cluster, logger):
    """Write the cluster connection details to a config file and return them.

    The port and host are read from the first active forward of the first
    running instance of ``task``. While that forward is not listed yet, the
    function waits 5 seconds with ``task.wait`` and tries again.

    The file is written with ``configparser`` and holds one section,
    ``cluster_connection``, with the keys ``port``, ``host``, ``task_uuid``
    and ``job_uuid``.

    Args:
        task: task object exposing ``status``, ``uuid`` and ``wait(seconds)``.
        job: job object exposing ``uuid``.
        path_config_cluster: path of the config file to write. Missing parent
            folders are created; a bare file name is written to the current
            directory.
        logger: logger receiving the debug messages.

    Returns:
        The tuple ``(port, host, task_uuid, job_uuid)``.
    """
    logger.debug("Creating cluster config")

    # Only the forward lookup is retried: an IndexError means that one of
    # the two lists is still empty.
    while True:
        try:
            forward = (task.status.running_instances_info
                       .per_running_instance_info[0].active_forward[0])
        except IndexError:
            logger.debug("Forward port of master is not available yet, "
                         + "retry in 5 seconds...")
            task.wait(5)
        else:
            break

    port = forward.forwarder_port
    host = forward.forwarder_host
    task_uuid = task.uuid
    job_uuid = job.uuid

    # rmtree only clears a directory left at that path; its errors (missing
    # path, regular file) are ignored. A previous regular file is simply
    # overwritten by open(..., 'w') below.
    logger.debug("Remove previous config cluster file if exists.")
    shutil.rmtree(path_config_cluster, ignore_errors=True)

    # A bare file name has an empty dirname, which os.makedirs rejects.
    path_tmp_config = os.path.dirname(path_config_cluster)
    if path_tmp_config:
        logger.debug("Create folder for config cluster file")
        os.makedirs(path_tmp_config, exist_ok=True)

    # Write config file
    logger.debug("Write config cluster file")
    config = configparser.ConfigParser()
    config['cluster_connection'] = {'port': port, 'host': host,
                                    'task_uuid': task_uuid,
                                    'job_uuid': job_uuid}
    with open(path_config_cluster, 'w') as configfile:
        config.write(configfile)
    logger.debug("Cluster config created successfully")
    return port, host, task_uuid, job_uuid


def start_cluster(
        nb_workers,
        path_resources,
        path_config_cluster,
        ssh_key,
        path_config_qarnot,
        name_input_bucket,
        name_output_bucket,
        task_name_qarnot,
        path_outputs,
        automatically_trust_ssh_host,
        path_local_logging_config,
        path_remote_log_sh_config,
        job_timeout,
        spark_master_webui_port,
        spark_master_log,
        spark_worker_log,
        log_shell_config,
        init_conn_timeout,
        rm_file_path,
        path_ssh_key,
        prepare_ssh_func,
        delete_task,
        ssh_timeout):
    """Launch a 'spark-cluster' Qarnot task and wait for it to be reachable.

    This function:
     - creates the local output folder and configures the logger
     - removes the file at rm_file_path when one is given
     - connects to the Qarnot API and submits a job limited by job_timeout
     - creates an input bucket, synchronizes it with a local directory and
       attaches it to the task
     - creates an output bucket and attaches it to the task
     - launches the task and waits for it to be 'FullyExecuting' (or
       'Failure' / 'Cancelled')
     - creates a local file with the following information:
        - port to communicate with for ssh
        - host
        - task uuid
        - job uuid
     - checks that all workers are connected
     - hands the task over to qarnot_general.check_failure and
       qarnot_general.check_cancellation

    Any exception raised once the logger exists is passed on to
    qarnot_general.handle_error, together with the task and job when the
    task was already created. An exception raised before the logger exists
    propagates unchanged.

    Args:
        nb_workers: number of workers; the task gets nb_workers + 1 instances
            (one master).
        path_resources: local folder synchronized into the input bucket.
        path_config_cluster: path of the cluster config file to write.
        ssh_key: value of the DOCKER_SSH task constant.
        path_config_qarnot: passed to qarnot_general.qarnot_connection.
        name_input_bucket: name of the input bucket.
        name_output_bucket: name of the output bucket.
        task_name_qarnot: name of the Qarnot task.
        path_outputs: local folder created up front, as the logger will
            write in it.
        automatically_trust_ssh_host: forwarded to check_workers_connection.
        path_local_logging_config: passed to qarnot_general.configure_logger.
        path_remote_log_sh_config: not used by this function.
        job_timeout: assigned to the job's max_wall_time.
        spark_master_webui_port: value of SPARK_MASTER_WEBUI_PORT.
        spark_master_log: value of SPARK_MASTER_LOG.
        spark_worker_log: value of SPARK_WORKER_LOG.
        log_shell_config: value of LOG_SHELL_CONFIG.
        init_conn_timeout: forwarded to check_workers_connection.
        rm_file_path: path of a file to delete before connecting, or None.
        path_ssh_key: forwarded to check_workers_connection.
        prepare_ssh_func: forwarded to check_workers_connection.
        delete_task: forwarded to the qarnot_general check/error helpers.
        ssh_timeout: forwarded to check_workers_connection.
    """
    # Known up front so the error handler can tell how far the setup went.
    logger = None
    job = None
    task = None

    try:
        # Create the output folder, as the logger will write in it
        os.makedirs(path_outputs, exist_ok=True)

        # Configure the logger
        logger = qarnot_general.configure_logger(path_local_logging_config)
        logger.debug("Logger configured successfully")

        # Remove known_host if arg
        if rm_file_path is not None:
            logger.warning("Remove file from rm_file_path")
            try:
                os.remove(rm_file_path)
            except (FileNotFoundError, TypeError):
                logger.warning("The path rm_file_path is incorrect")

        # Start the connection to Qarnot API
        logger.debug("Connecting to Qarnot API")
        conn = qarnot_general.qarnot_connection(path_config_qarnot, logger)

        # Submit the job
        logger.debug("Submitting the job")
        # str() so that a non-string timeout cannot break the log call
        logger.debug("Job timeout = " + str(job_timeout))
        job = qarnot.connection.Connection.create_job(conn, "Airflow job")
        job._uuid = None
        job.max_wall_time = job_timeout
        job.submit()

        # Create and sync input bucket
        logger.debug("Creating the input bucket")
        input_bucket = conn.retrieve_or_create_bucket(name_input_bucket)
        logger.debug("Syncing the input bucket with local folder "
                     + path_resources)
        input_bucket.sync_directory(path_resources)

        # Create task
        # number instances = nb workers + 1 master
        logger.debug("Creating the Qarnot task")
        nb_instances = nb_workers + 1

        task = conn.create_task(task_name_qarnot, 'spark-cluster',
                                nb_instances, job=job)

        task.resources = [input_bucket]

        # Create the output bucket and attach it to the task
        logger.debug("Creating the output bucket")
        output_bucket = conn.retrieve_or_create_bucket(name_output_bucket)
        task.results = output_bucket

        # Set task constants
        task.constants['DOCKER_SSH'] = ssh_key
        task.constants['SPARK_MASTER_WEBUI_PORT'] = spark_master_webui_port
        task.constants['SPARK_MASTER_LOG'] = spark_master_log
        task.constants['SPARK_WORKER_LOG'] = spark_worker_log
        task.constants['LOG_SHELL_CONFIG'] = log_shell_config

        task.snapshot(5)

        # Submit the task to the Api, that will launch it on the cluster
        try:
            logger.debug("Submitting the task")
            task.submit()
        except ValueError:
            # NOTE(review): this path does not call task.submit() again;
            # confirm that the task is submitted when ValueError is raised.
            # give time to the race condition to "resolve itself"
            time.sleep(5)
            # force the snapshots now that it should work
            task.snapshot(5)

        # Wait for creation of the object task.status
        task.wait(5)

        # Wait for the task to be 'FullyExecuting'
        # NOTE(review): the loop only ends on the three states below; a task
        # reaching another final state first (e.g. 'Success') would keep it
        # waiting. Confirm whether that can happen.
        last_state = ''
        while last_state not in ['Cancelled', 'Failure', 'FullyExecuting']:
            if task.state != last_state:
                last_state = task.state
                logger.debug("Status of the task: {}".format(last_state))
            task.wait(2)

        # Only set up the connection when the task is still alive
        if task.state not in ['Cancelled', 'Failure', 'Success']:
            # Create a config file to easily connect to the cluster
            port, host, _, _ = create_cluster_config_file(
                task, job, path_config_cluster, logger)

            # Check that all workers are connected to the master
            check_workers_connection(automatically_trust_ssh_host, task, port,
                                     host, logger, job_timeout,
                                     init_conn_timeout, path_ssh_key,
                                     prepare_ssh_func, ssh_timeout)
            # Cluster is available
            logger.debug("Cluster is ready")

        # Display errors on failure and raise airflow exception
        qarnot_general.check_failure(task, job, delete_task, logger)

        # raise an airflow exception if the task was cancelled
        qarnot_general.check_cancellation(task, job, delete_task, logger)

    except Exception:
        if logger is None:
            # The failure happened before the logger was configured, so the
            # error helper cannot be given one: surface the original error.
            raise
        if task is not None:
            qarnot_general.handle_error(logger, delete_task, task, job)
        else:
            qarnot_general.handle_error(logger, delete_task)
