#!/usr/bin/env python
import logging
import sys
import os
import subprocess

from spyt.dependency_utils import require_yt_client
require_yt_client()

from yt.wrapper import YtClient
from yt.wrapper.cli_helpers import ParseStructuredArgument
from yt.wrapper.http_helpers import get_user_name
from spyt.standalone import Worker, start_spark_cluster, find_spark_cluster, SparkDefaultArguments, SpytEnablers
from spyt import utils as spark_utils


def _add_toggle(parser, enable_flag, disable_flag, dest, default):
    """Register a mutually exclusive on/off flag pair that writes to ``dest``.

    ``enable_flag`` stores True and ``disable_flag`` stores False; ``default``
    is the value used when neither flag appears on the command line.
    """
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument(enable_flag, dest=dest, action='store_true')
    group.add_argument(disable_flag, dest=dest, action='store_false')
    parser.set_defaults(**{dest: default})


def main():
    """Parse command-line options and start a Spark cluster.

    Builds the argument parser on top of ``spark_utils.get_default_arg_parser``,
    validates the option combination, creates a ``YtClient`` and forwards every
    option to ``start_spark_cluster``.

    Exits the process with status -1 (after writing a message to stderr) when
    ``--autoscaler-period`` is given without multi-operation mode.
    """
    # NOTE(review): ``args.proxy`` and ``args.discovery_path`` are read below but
    # not declared here; presumably the default parser provides them.
    parser = spark_utils.get_default_arg_parser(description="Spark Launch")

    # Worker resources.
    parser.add_argument("--worker-cores", required=True, type=int)
    parser.add_argument("--worker-memory", required=True)
    parser.add_argument("--worker-num", required=True, type=int)
    parser.add_argument("--worker-cores-overhead", required=False, type=int)
    parser.add_argument("--worker-timeout", required=False, default="10m")
    parser.add_argument("--pool", required=False)
    parser.add_argument("--tmpfs-limit", required=False, default="150G")
    parser.add_argument("--ssd-limit", required=False, default=None)
    parser.add_argument("--ssd-account", required=False, default=None)

    # Master and history server resources.
    parser.add_argument("--master-memory-limit", required=False, default="4G")
    parser.add_argument("--history-server-memory-limit",
                        required=False, default="16G")
    parser.add_argument("--history-server-memory-overhead",
                        required=False, default="4G")
    parser.add_argument("--history-server-cpu-limit",
                        required=False, default=8, type=int)

    # Miscellaneous cluster settings.
    parser.add_argument("--operation-alias", required=False)
    parser.add_argument("--network-project", required=False)
    parser.add_argument("--params", required=False, action=ParseStructuredArgument, dest="params",
                        default=SparkDefaultArguments.get_params())
    parser.add_argument('--abort-existing', required=False,
                        action='store_true', default=False)
    parser.add_argument("--spark-cluster-version", required=False)
    parser.add_argument("--worker-log-update-interval",
                        required=False, default="10m")
    parser.add_argument("--worker-log-table-ttl", required=False, default="7d")
    parser.add_argument("--shs-location", required=False)
    # The dashed spelling matches every other flag; the underscore spelling is
    # kept as an alias so existing invocations keep working.
    parser.add_argument("--preemption-mode", "--preemption_mode",
                        dest="preemption_mode", required=False, default="normal")

    # Unlike the toggles registered below, these two flags are not placed in a
    # mutually exclusive group: when both are given, the last one wins.
    parser.add_argument('--enable-multi-operation-mode',
                        dest='enable_multi_operation_mode', action='store_true')
    parser.add_argument('--disable-multi-operation-mode',
                        dest='enable_multi_operation_mode', action='store_false')
    parser.set_defaults(enable_multi_operation_mode=False)

    # On/off toggles: (enable flag, disable flag, destination, default).
    default_enablers = SpytEnablers()
    toggles = (
        ('--enable-byop', '--disable-byop',
         'enable_byop', default_enablers.enable_byop),
        ('--enable-solomon-agent', '--disable-solomon-agent',
         'enable_solomon_agent', default_enablers.enable_solomon_agent),
        ('--enable-profiling', '--disable-profiling',
         'enable_profiling', default_enablers.enable_profiling),
        ('--enable-mtn', '--disable-mtn',
         'enable_mtn', default_enablers.enable_mtn),
        ('--prefer-ipv6', '--prefer-ipv4',
         'enable_preference_ipv6', default_enablers.enable_preference_ipv6),
        ('--enable-tcp-proxy', '--disable-tcp-proxy',
         'enable_tcp_proxy', default_enablers.enable_tcp_proxy),
        ('--enable-advanced-event-log', '--disable-advanced-event-log',
         'advanced_event_log', False),
        ('--enable-worker-log-transfer', '--disable-worker-log-transfer',
         'worker_log_transfer', False),
        ('--enable-worker-log-json-mode', '--disable-worker-log-json-mode',
         'worker_log_json_mode', False),
        ('--enable-dedicated-driver-operation-mode',
         '--disable-dedicated-driver-operation-mode',
         'dedicated_operation_mode', False),
    )
    for enable_flag, disable_flag, dest, default in toggles:
        _add_toggle(parser, enable_flag, disable_flag, dest, default)

    subgroup = parser.add_argument_group(
        '(experimental) run driver in dedicated operation')
    subgroup.add_argument("--driver-cores", required=False,
                          type=int, help="same as worker-cores by default")
    # NOTE(review): --worker-memory is parsed as a string while this option is
    # parsed as int, so a value such as "4G" is rejected here -- confirm that
    # this is intended.
    subgroup.add_argument("--driver-memory", required=False,
                          type=int, help="same as worker-memory by default")
    subgroup.add_argument("--driver-num", required=False,
                          type=int, help="Number of driver workers")
    subgroup.add_argument("--driver-cores-overhead", required=False,
                          type=int, help="same as worker-cores-overhead by default")
    subgroup.add_argument("--driver-timeout", required=False,
                          help="same as worker-timeout by default")

    subgroup = parser.add_argument_group("(experimental) autoscaler")
    subgroup.add_argument("--autoscaler-period", required=False,
                          type=str,
                          help="""
                            Start autoscaler process with provided period between autoscaling actions.
                            Period format is '<number> <time_unit>', for example '1s', '5 seconds', '100millis' etc.
                          """.strip())
    subgroup.add_argument("--autoscaler-metrics-port", required=False,
                          type=int, help="expose autoscaler metrics on provided port")
    subgroup.add_argument("--autoscaler-sliding-window", required=False,
                          type=int, help="size of autoscaler actions sliding window (in number of action) to downscale")
    subgroup.add_argument("--autoscaler-max-free-workers", required=False,
                          type=int, help="autoscaler maximum number of free workers")
    subgroup.add_argument("--autoscaler-slot-increment-step", required=False,
                          type=int, help="autoscaler worker slots increment step")

    subgroup = parser.add_argument_group("(experimental) Livy server")
    subgroup.add_argument('--enable-livy', action='store_true', default=False)
    subgroup.add_argument('--livy-driver-cores', required=False, default=1, type=int)
    subgroup.add_argument('--livy-driver-memory', required=False, default="768mb")
    subgroup.add_argument('--livy-max-sessions', required=False, default=2, type=int)

    # Arguments the parser does not recognise are returned separately and are
    # not used by this script.
    args, _unknown_args = spark_utils.parse_args(parser)

    # Validate the option combination before creating the client so that a bad
    # command line fails fast. sys.exit is used rather than the builtin exit,
    # which is injected by the site module and is not always available.
    if args.autoscaler_period and not args.enable_multi_operation_mode:
        sys.stderr.write("Autoscaler could be enabled only with multi-operation mode\n")
        sys.exit(-1)

    yt_client = YtClient(proxy=args.proxy, token=spark_utils.default_token())

    start_spark_cluster(worker_cores=args.worker_cores,
                        worker_memory=args.worker_memory,
                        worker_num=args.worker_num,
                        worker_cores_overhead=args.worker_cores_overhead,
                        worker_timeout=args.worker_timeout,
                        operation_alias=args.operation_alias,
                        discovery_path=args.discovery_path,
                        # Fall back to the current user's name when no pool is given.
                        pool=args.pool or get_user_name(client=yt_client),
                        tmpfs_limit=args.tmpfs_limit,
                        ssd_limit=args.ssd_limit,
                        ssd_account=args.ssd_account,
                        master_memory_limit=args.master_memory_limit,
                        history_server_memory_limit=args.history_server_memory_limit,
                        history_server_memory_overhead=args.history_server_memory_overhead,
                        history_server_cpu_limit=args.history_server_cpu_limit,
                        network_project=args.network_project,
                        tvm_id=spark_utils.default_tvm_id(),
                        tvm_secret=spark_utils.default_tvm_secret(),
                        abort_existing=args.abort_existing,
                        advanced_event_log=args.advanced_event_log,
                        worker_log_transfer=args.worker_log_transfer,
                        worker_log_json_mode=args.worker_log_json_mode,
                        worker_log_update_interval=args.worker_log_update_interval,
                        worker_log_table_ttl=args.worker_log_table_ttl,
                        params=args.params,
                        shs_location=args.shs_location,
                        spark_cluster_version=args.spark_cluster_version,
                        enablers=SpytEnablers(
                            enable_byop=args.enable_byop,
                            enable_profiling=args.enable_profiling,
                            enable_mtn=args.enable_mtn,
                            enable_solomon_agent=args.enable_solomon_agent,
                            enable_preference_ipv6=args.enable_preference_ipv6,
                            enable_tcp_proxy=args.enable_tcp_proxy
                        ),
                        client=yt_client,
                        preemption_mode=args.preemption_mode,
                        enable_multi_operation_mode=args.enable_multi_operation_mode,
                        dedicated_operation_mode=args.dedicated_operation_mode,
                        driver_cores=args.driver_cores,
                        driver_memory=args.driver_memory,
                        driver_num=args.driver_num,
                        driver_cores_overhead=args.driver_cores_overhead,
                        driver_timeout=args.driver_timeout,
                        autoscaler_period=args.autoscaler_period,
                        autoscaler_metrics_port=args.autoscaler_metrics_port,
                        autoscaler_sliding_window=args.autoscaler_sliding_window,
                        autoscaler_max_free_workers=args.autoscaler_max_free_workers,
                        autoscaler_slot_increment_step=args.autoscaler_slot_increment_step,
                        enable_livy=args.enable_livy,
                        livy_driver_cores=args.livy_driver_cores,
                        livy_driver_memory=args.livy_driver_memory,
                        livy_max_sessions=args.livy_max_sessions)


# Script entry point: launch the cluster only when this file is executed
# directly, not when it is imported as a module.
if "__main__" == __name__:
    main()
