import configparser
from custom_operators.qarnot_utils import start_cluster, \
    submit_app, fetch_output, stop_cluster, sync_workers
from airflow.models.baseoperator import BaseOperator
from airflow import AirflowException


def fetch_config(path_config_cluster):
    """Read the cluster connection settings from an INI-style config file.

    The file must contain a ``[cluster_connection]`` section with the keys
    ``port``, ``host``, ``task_uuid`` and ``job_uuid``.

    Args:
        path_config_cluster: Path of the cluster configuration file.

    Returns:
        A ``(port, host, task_uuid, job_uuid)`` tuple. Values are returned
        exactly as read by :mod:`configparser`, i.e. as strings; no type
        conversion is performed.

    Raises:
        AirflowException: If the file cannot be opened, cannot be parsed,
            or lacks the ``cluster_connection`` section or one of its keys.
    """
    config = configparser.ConfigParser()
    try:
        # ConfigParser.read() does NOT raise for a missing/unreadable file:
        # it skips it and returns the list of files it managed to read.
        files_read = config.read(path_config_cluster)
    except (configparser.Error, UnicodeDecodeError) as err:
        raise AirflowException(
            "Could not parse the cluster config file {!r}. "
            "Check: content.".format(path_config_cluster)) from err
    if not files_read:
        raise AirflowException(
            "Could not read the cluster config file {!r}. "
            "Check: path and permissions.".format(path_config_cluster))

    # The missing-section / missing-key case surfaces as KeyError here, so
    # this is where the lookup must be guarded.
    try:
        section = config['cluster_connection']
        port = section['port']
        host = section['host']
        task_uuid = section['task_uuid']
        job_uuid = section['job_uuid']
    except KeyError as err:
        raise AirflowException(
            "An issue occurred with the cluster config file {!r}: missing "
            "section or key {}. Check: content.".format(
                path_config_cluster, err)) from err

    return port, host, task_uuid, job_uuid


class QarnotStartCluster(BaseOperator):
    """Airflow operator that delegates to ``start_cluster.start_cluster``.

    Every constructor argument is kept verbatim on the instance. At run time
    they are all forwarded, by keyword, to ``start_cluster.start_cluster``.
    Units and accepted formats of values such as ``job_timeout`` and
    ``ssh_timeout`` are defined by that helper (not visible here).
    """

    def __init__(
            self,
            nb_workers: int,
            path_resources: str,
            ssh_key: str,
            path_config_cluster: str,
            path_outputs: str,
            path_config_qarnot: str,
            path_local_logging_config: str,
            job_timeout: str,
            name_input_bucket: str = "airflow-spark-in",
            name_output_bucket: str = "airflow-spark-out",
            task_name_qarnot: str = "Hello World - Airflow Spark",
            path_remote_log_sh_config: str = "/opt/log.sh",
            automatically_trust_ssh_host: bool = False,
            spark_master_webui_port: int = 6060,
            spark_master_log: str = "/job/log_master",
            spark_worker_log: str = "/job/log_workers",
            log_shell_config: str = "/opt/log.sh",
            init_conn_timeout: int = 180,
            rm_file_path: str = None,
            path_ssh_key: str = None,
            prepare_ssh_func=None,
            delete_task: bool = False,
            ssh_timeout: str = "5s",
            *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Store every argument unchanged; execute() hands them all over to
        # start_cluster.start_cluster.
        self.nb_workers = nb_workers
        self.task_name_qarnot = task_name_qarnot
        self.path_config_qarnot = path_config_qarnot
        self.job_timeout = job_timeout
        self.delete_task = delete_task

        self.path_resources = path_resources
        self.path_outputs = path_outputs
        self.path_config_cluster = path_config_cluster
        self.name_input_bucket = name_input_bucket
        self.name_output_bucket = name_output_bucket
        self.rm_file_path = rm_file_path

        self.ssh_key = ssh_key
        self.path_ssh_key = path_ssh_key
        self.prepare_ssh_func = prepare_ssh_func
        self.automatically_trust_ssh_host = automatically_trust_ssh_host
        self.ssh_timeout = ssh_timeout
        self.init_conn_timeout = init_conn_timeout

        self.path_local_logging_config = path_local_logging_config
        self.path_remote_log_sh_config = path_remote_log_sh_config
        self.log_shell_config = log_shell_config
        self.spark_master_webui_port = spark_master_webui_port
        self.spark_master_log = spark_master_log
        self.spark_worker_log = spark_worker_log

    def execute(self, context):
        """Call ``start_cluster.start_cluster`` and return a status message."""
        message = "Executing the StartCluster operator."
        print(message)

        cluster_kwargs = dict(
            nb_workers=self.nb_workers,
            task_name_qarnot=self.task_name_qarnot,
            path_config_qarnot=self.path_config_qarnot,
            job_timeout=self.job_timeout,
            delete_task=self.delete_task,
            path_resources=self.path_resources,
            path_outputs=self.path_outputs,
            path_config_cluster=self.path_config_cluster,
            name_input_bucket=self.name_input_bucket,
            name_output_bucket=self.name_output_bucket,
            rm_file_path=self.rm_file_path,
            ssh_key=self.ssh_key,
            path_ssh_key=self.path_ssh_key,
            prepare_ssh_func=self.prepare_ssh_func,
            automatically_trust_ssh_host=self.automatically_trust_ssh_host,
            ssh_timeout=self.ssh_timeout,
            init_conn_timeout=self.init_conn_timeout,
            path_local_logging_config=self.path_local_logging_config,
            path_remote_log_sh_config=self.path_remote_log_sh_config,
            log_shell_config=self.log_shell_config,
            spark_master_webui_port=self.spark_master_webui_port,
            spark_master_log=self.spark_master_log,
            spark_worker_log=self.spark_worker_log,
        )
        start_cluster.start_cluster(**cluster_kwargs)
        return message


class QarnotFetchOutput(BaseOperator):
    """Airflow operator that delegates to ``fetch_output.fetch_output``.

    The task and job UUIDs are not passed to the constructor. They are read
    at run time from the ``[cluster_connection]`` section of the file at
    ``path_config_cluster`` (via ``fetch_config``). All other settings are
    forwarded unchanged.
    """

    def __init__(
            self,
            path_config_cluster: str,
            path_outputs: str,
            path_config_qarnot: str,
            path_local_logging_config: str,
            download_locally: bool = False,
            snapshot_wait_sec: int = 30,
            delete_task: bool = False,
            *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Configuration files.
        self.path_config_cluster = path_config_cluster
        self.path_config_qarnot = path_config_qarnot
        self.path_local_logging_config = path_local_logging_config
        # Options forwarded to fetch_output.fetch_output.
        self.path_outputs = path_outputs
        self.download_locally = download_locally
        self.snapshot_wait_sec = snapshot_wait_sec
        self.delete_task = delete_task

    def execute(self, context):
        """Call ``fetch_output.fetch_output`` and return a status message."""
        message = "Executing the FetchOutput operator."
        print(message)

        _port, _host, task_uuid, job_uuid = fetch_config(
            self.path_config_cluster)
        output_kwargs = dict(
            task_uuid=task_uuid,
            job_uuid=job_uuid,
            path_config_qarnot=self.path_config_qarnot,
            path_local_logging_config=self.path_local_logging_config,
            path_outputs=self.path_outputs,
            download_locally=self.download_locally,
            snapshot_wait_sec=self.snapshot_wait_sec,
            delete_task=self.delete_task,
        )
        fetch_output.fetch_output(**output_kwargs)
        return message


class QarnotSyncWorkers(BaseOperator):
    """Airflow operator that delegates to ``sync_workers.sync_workers``.

    The task and job UUIDs are read at run time from the
    ``[cluster_connection]`` section of the file at ``path_config_cluster``
    (via ``fetch_config``). The remaining settings are forwarded unchanged.
    """

    def __init__(
            self,
            path_config_cluster: str,
            path_config_qarnot: str,
            path_local_logging_config: str,
            snapshot_wait_sec: int = 30,
            delete_task: bool = False,
            *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.path_config_cluster = path_config_cluster
        self.snapshot_wait_sec = snapshot_wait_sec
        self.path_config_qarnot = path_config_qarnot
        self.path_local_logging_config = path_local_logging_config
        self.delete_task = delete_task

    def execute(self, context):
        """Call ``sync_workers.sync_workers`` and return a status message.

        Returns:
            The status message naming this operator (SyncWorkers).
        """
        # The message must name this operator so task logs and the returned
        # value identify the right step.
        message = "Executing the SyncWorkers operator."
        print(message)

        (_, _, task_uuid, job_uuid) = fetch_config(self.path_config_cluster)
        sync_workers.sync_workers(
                task_uuid=task_uuid,
                job_uuid=job_uuid,
                path_config_qarnot=self.path_config_qarnot,
                snapshot_wait_sec=self.snapshot_wait_sec,
                path_local_logging_config=self.path_local_logging_config,
                delete_task=self.delete_task)
        return message


class QarnotStopCluster(BaseOperator):
    """Airflow operator that delegates to ``stop_cluster.stop_cluster``.

    The task and job UUIDs are read at run time from the
    ``[cluster_connection]`` section of the file at ``path_config_cluster``
    (via ``fetch_config``). The path itself is also forwarded to the helper.
    """

    def __init__(
            self,
            path_config_cluster: str,
            path_config_qarnot: str,
            path_local_logging_config: str,
            delete_task: bool = False,
            *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.delete_task = delete_task
        self.path_local_logging_config = path_local_logging_config
        self.path_config_qarnot = path_config_qarnot
        self.path_config_cluster = path_config_cluster

    def execute(self, context):
        """Call ``stop_cluster.stop_cluster`` and return a status message."""
        message = "Executing the StopCluster operator."
        print(message)

        _port, _host, task_uuid, job_uuid = fetch_config(
            self.path_config_cluster)
        stop_kwargs = dict(
            task_uuid=task_uuid,
            job_uuid=job_uuid,
            path_config_cluster=self.path_config_cluster,
            path_config_qarnot=self.path_config_qarnot,
            path_local_logging_config=self.path_local_logging_config,
            delete_task=self.delete_task,
        )
        stop_cluster.stop_cluster(**stop_kwargs)
        return message


class QarnotSubmitApp(BaseOperator):
    """Airflow operator that delegates to ``submit_app.submit_app``.

    The connection ``port`` and ``host`` and the task/job UUIDs are read at
    run time from the ``[cluster_connection]`` section of the file at
    ``path_config_cluster`` (via ``fetch_config``). Every other setting is
    forwarded unchanged. ``arguments`` is passed through as-is; its expected
    shape is defined by ``submit_app`` (not visible here).
    """

    def __init__(
            self,
            path_app: str,
            path_config_cluster: str,
            path_config_qarnot: str,
            path_local_logging_config: str,
            arguments=None,
            spark_driver_memory: str = "2G",
            spark_executor_memory: str = "14G",
            path_remote_logging_config: str = "/opt/python_logging.conf",
            automatically_trust_ssh_host: bool = False,
            path_remote_log_sh_config: str = "/opt/log.sh",
            path_remote_submit_script: str = "/opt/launch-spark-app.sh",
            path_ssh_key: str = None,
            prepare_ssh_func=None,
            delete_task: bool = False,
            *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Application to submit and its arguments.
        self.path_app = path_app
        self.arguments = arguments
        self.spark_driver_memory = spark_driver_memory
        self.spark_executor_memory = spark_executor_memory
        # Configuration / script paths.
        self.path_config_cluster = path_config_cluster
        self.path_config_qarnot = path_config_qarnot
        self.path_local_logging_config = path_local_logging_config
        self.path_remote_logging_config = path_remote_logging_config
        self.path_remote_log_sh_config = path_remote_log_sh_config
        self.path_remote_submit_script = path_remote_submit_script
        # SSH handling.
        self.path_ssh_key = path_ssh_key
        self.prepare_ssh_func = prepare_ssh_func
        self.automatically_trust_ssh_host = automatically_trust_ssh_host
        self.delete_task = delete_task

    def execute(self, context):
        """Call ``submit_app.submit_app`` and return a status message."""
        message = "Executing the SubmitApp operator."
        print(message)

        conn_port, conn_host, task_uuid, job_uuid = fetch_config(
            self.path_config_cluster)
        submit_kwargs = dict(
            port=conn_port,
            host=conn_host,
            task_uuid=task_uuid,
            job_uuid=job_uuid,
            path_app=self.path_app,
            arguments=self.arguments,
            spark_driver_memory=self.spark_driver_memory,
            spark_executor_memory=self.spark_executor_memory,
            path_config_qarnot=self.path_config_qarnot,
            path_local_logging_config=self.path_local_logging_config,
            path_remote_logging_config=self.path_remote_logging_config,
            path_remote_log_sh_config=self.path_remote_log_sh_config,
            path_remote_submit_script=self.path_remote_submit_script,
            path_ssh_key=self.path_ssh_key,
            prepare_ssh_func=self.prepare_ssh_func,
            automatically_trust_ssh_host=self.automatically_trust_ssh_host,
            delete_task=self.delete_task,
        )
        submit_app.submit_app(**submit_kwargs)
        return message