import atexit
import ssl
import time
import traceback
import urllib
import urllib.error

from pyVim.connect import Disconnect, SmartConnect
from pyVmomi import vim

from common import constants as const
from common.logger import setup_logger

from .vc_client import VCClient

logger = setup_logger()


class VcUtils:
    """Helpers for working with a vCenter server through pyVmomi.

    Holds one ServiceInstance connection (``self.si``) built from the
    credentials in a :class:`VCClient`. It provides inventory lookups
    (clusters, hosts), an update of one VC advanced setting, and host
    maintenance-mode operations on top of that connection.
    """

    def __init__(self, vc_client: VCClient, verify_ssl=True):
        """Store connection parameters; no network I/O happens here.

        :param vc_client: Supplies ``fqdn``, ``port``, ``username`` and
            ``password`` for :meth:`connect_to_vc`.
        :param verify_ssl: When False, TLS certificate and hostname
            verification are disabled for the vCenter connection.
        """
        self.vc_client: VCClient = vc_client
        self.si = None  # ServiceInstance returned by SmartConnect; set by connect_to_vc()
        self.content = None  # cached result of self.si.RetrieveContent()
        self.verify_ssl = verify_ssl

    def connect_to_vc(self, max_retry=const.VC_CONNECT_MAX_RETRIES,
                      retry_delay=const.VC_CONNECT_RETRY_DELAY):
        """Connect to vCenter, retrying transient failures.

        Does nothing if a connection is already held. Rejected credentials
        (``vim.fault.InvalidLogin``) and URL-level errors
        (``urllib.error.URLError``) are re-raised immediately. Any other error
        is retried up to ``max_retry`` more times, sleeping ``retry_delay``
        seconds between attempts. On success, ``Disconnect`` is registered
        with :mod:`atexit`.

        :param max_retry: Number of retries after the first attempt.
        :param retry_delay: Seconds to sleep between attempts.
        :raises vim.fault.InvalidLogin: The credentials were rejected.
        :raises urllib.error.URLError: A URL-level error occurred.
        :raises Exception: No attempt produced a connection.
        """
        fqdn, port = self.vc_client.fqdn, self.vc_client.port
        if self.si:
            logger.info("vCenter connection is already established to %s:%s", fqdn, port)
            return

        err_msg = None
        for attempt in range(max_retry + 1):
            try:
                context = ssl.create_default_context()
                if not self.verify_ssl:
                    context.check_hostname = False
                    context.verify_mode = ssl.CERT_NONE
                logger.debug("Connecting to VC %s:%s", fqdn, port)
                self.si = SmartConnect(host=fqdn,
                                       port=int(port),
                                       user=self.vc_client.username,
                                       pwd=self.vc_client.password,
                                       sslContext=context)
                break
            except (vim.fault.InvalidLogin, urllib.error.URLError) as ex:
                # Not retried: these are surfaced to the caller right away.
                logger.error("Failed connecting to VC at %s:%s: %s", fqdn, port, str(ex))
                raise
            except Exception as e:
                self.si = None
                err_msg = "Failed connecting to VC at %s:%s: %s" % (fqdn, port, str(e))
                if attempt < max_retry:
                    logger.warning("%s. Will retry in %s seconds; %s retries left",
                                   err_msg, retry_delay, max_retry - attempt)
                    logger.warning(traceback.format_exc())
                    time.sleep(retry_delay)
                else:
                    logger.error("%s. All re-tries exhausted.", err_msg)
                    logger.error(traceback.format_exc())

        # Decide success from the connection itself, not from err_msg: a
        # failed attempt followed by a successful retry leaves err_msg set.
        if not self.si:
            raise Exception(err_msg or "Failed connecting to VC at %s:%s" % (fqdn, port))

        atexit.register(Disconnect, self.si)
        logger.info("Connected to VC %s:%s", fqdn, port)

    def get_content(self, use_cache=True):
        """Return ``self.si.RetrieveContent()``, cached on the instance.

        Requires :meth:`connect_to_vc` to have been called first.

        :param use_cache: When False, always re-fetch and refresh the cache.
        """
        if self.content and use_cache:
            return self.content
        self.content = self.si.RetrieveContent()
        return self.content

    def destroy_container_throw_no_error(self, container):
        """Best-effort ``container.Destroy()``; failures are logged, never raised.

        :param container: A container view, or None (ignored).
        """
        try:
            if container:
                container.Destroy()
        except Exception as e:
            # Cleanup must not mask the caller's result or original exception.
            logger.warning("Error when destroying vCenter container view : %s", str(e))

    def _list_inventory_objects(self, obj_type):
        """Return every ``obj_type`` object under the root folder (recursive).

        The temporary container view is destroyed before returning, also on
        failure. ``container`` starts as None so the cleanup is safe even when
        creating the view itself raised.
        """
        container = None
        try:
            content = self.get_content()
            container = content.viewManager.CreateContainerView(content.rootFolder,
                                                                [obj_type],
                                                                True)
            return list(container.view)
        finally:
            self.destroy_container_throw_no_error(container=container)

    def _find_first_inventory_object(self, obj_type, predicate):
        """Return the first ``obj_type`` object for which ``predicate(obj)`` is true, else None."""
        return next((obj for obj in self._list_inventory_objects(obj_type) if predicate(obj)),
                    None)

    def get_cluster_details_by_id(self, mo_id: str):
        """Return the ClusterComputeResource whose ``_moId`` equals ``mo_id``, or None.

        Errors are logged and re-raised.
        """
        try:
            return self._find_first_inventory_object(vim.ClusterComputeResource,
                                                     lambda cluster: cluster._moId == mo_id)
        except Exception as e:
            logger.error("Failed to get cluster details for mo id: %s, error : %s", mo_id, str(e))
            raise

    def get_cluster_details_by_name(self, cluster_name: str):
        """Return the first ClusterComputeResource named ``cluster_name``, or None.

        Errors are logged and re-raised.
        """
        try:
            return self._find_first_inventory_object(vim.ClusterComputeResource,
                                                     lambda cluster: cluster.name == cluster_name)
        except Exception as e:
            logger.error("Failed to get cluster details for name: %s, error : %s",
                         cluster_name, str(e))
            raise

    def get_host_by_moid(self, mo_id: str):
        """Return the HostSystem whose ``_moId`` equals ``mo_id``, or None.

        Errors are logged and re-raised.
        """
        try:
            return self._find_first_inventory_object(vim.HostSystem,
                                                     lambda host: host._moId == mo_id)
        except Exception as e:
            logger.error("Failed to get host details for mo id: %s, error : %s", mo_id, str(e))
            raise

    def get_hosts_by_moids(self, moid_set):
        """Map each requested host moid that exists in vCenter to its HostSystem.

        :param moid_set: Iterable of host moids to look up. It is copied, so
            the caller's collection is not modified.
        :returns: Dict ``{moid: HostSystem}``. Moids not found are absent.
        :raises Exception: Errors are logged and re-raised.
        """
        try:
            wanted = set(moid_set)
            return {host._moId: host
                    for host in self._list_inventory_objects(vim.HostSystem)
                    if host._moId in wanted}
        except Exception as e:
            logger.error("Failed to get hosts from VC, error : %s", str(e))
            raise

    def _query_vc_setting_value(self, content, key):
        """Return the value of the first entry from ``content.setting.QueryView(key)``, or None."""
        kv_list = content.setting.QueryView(key)
        return kv_list[0].value if kv_list else None

    def enable_vmotion_between_ens_modes_key(self):
        """Ensure the VC setting ``const.VC_SETTING_VMOTION_BETWEEN_ENS_MODES`` is "true".

        Returns early if it already reads "true". Otherwise it writes "true"
        via ``UpdateValues`` and reads the value back to confirm.

        :raises Exception: The read-back value is not "true" after the update.
        """
        content = self.get_content()
        key = const.VC_SETTING_VMOTION_BETWEEN_ENS_MODES

        try:
            if self._query_vc_setting_value(content, key) == "true":
                logger.info("VC setting to allow vMotion between ENS modes is already set to true")
                return
        except vim.fault.InvalidName:
            # Treated as "setting not present yet"; fall through and create it.
            logger.info("VC setting to allow vMotion between ENS modes is not present; will set it")

        content.setting.UpdateValues([vim.option.OptionValue(key=key, value="true")])

        value_after_update = self._query_vc_setting_value(content, key)
        if value_after_update != "true":
            msg = "Could not update VC setting : %s to true. Current setting : %s" % (
                key, value_after_update)
            logger.error(msg)
            raise Exception(msg)

        logger.info("The VC setting %s is successfully set to %s", key, value_after_update)

    def wait_for_task_to_complete(self, task, timeout_seconds: int, mo_id: str, task_name: str,
                                  poll_interval_seconds: int = 15):
        """Poll ``task`` until it succeeds, fails, or ``timeout_seconds`` elapse.

        :param task: vSphere task whose ``info`` is polled.
        :param timeout_seconds: Upper bound on polling time. The wait may
            overshoot by up to one poll interval.
        :param mo_id: Host moid, used only in log and error messages.
        :param task_name: Human-readable task name for messages.
        :param poll_interval_seconds: Seconds between polls. Progress is
            logged at most once per 60 s.
        :returns: None on success, otherwise an error message string. The
            failure can be an error state, still queued/running at timeout,
            or any other state.
        :raises ValueError: ``poll_interval_seconds`` is not positive.
        """
        if poll_interval_seconds <= 0:
            raise ValueError("poll_interval_seconds must be positive, got %r"
                             % (poll_interval_seconds,))

        logger.info("Starting polling on %s task for host: %s.", task_name, mo_id)

        finished_states = (vim.TaskInfo.State.success, vim.TaskInfo.State.error)
        log_interval = 60  # seconds between "still running" log lines
        elapsed_time = 0
        last_logged_at = None
        # Take one TaskInfo snapshot per poll and judge the state from it, so a
        # task finishing between two comparisons cannot be misclassified.
        info = task.info
        while elapsed_time < timeout_seconds and info.state not in finished_states:
            if last_logged_at is None or elapsed_time - last_logged_at >= log_interval:
                logger.info("%s task is still running for host %s with status %s. Time elapsed "
                            "in seconds : %s. Total timeout in seconds %s", task_name, mo_id,
                            info.state, elapsed_time, timeout_seconds)
                last_logged_at = elapsed_time
            time.sleep(poll_interval_seconds)
            elapsed_time += poll_interval_seconds
            info = task.info

        state = info.state
        if state == vim.TaskInfo.State.success:
            logger.info("%s task succeeded for host %s", task_name, mo_id)
            return None
        if state == vim.TaskInfo.State.error:
            fault = info.error
            logger.error("%s task failed for host %s, task status %s, Error: %s", task_name,
                         mo_id, state, fault)
            err_msg = "%s task failed for host: %s, task status: %s, Message: %s" % (
                task_name, mo_id, state, fault.msg if fault is not None else None)
        elif state in (vim.TaskInfo.State.queued, vim.TaskInfo.State.running):
            err_msg = ("%s task still running for host: %s, task state: %s. Task did not finish "
                       "with in the timeout specified in config file. So marking the operation "
                       "as error and moving on." % (task_name, mo_id, state))
            logger.error(err_msg)
        else:
            err_msg = "%s task for host: %s failed with unexpected error." % (task_name, mo_id)
            logger.error(err_msg)

        return err_msg

    def enter_host_into_mm(self, host, timeout_seconds: int):
        """Put ``host`` into maintenance mode and wait for the task.

        No-op if the host already reports ``runtime.inMaintenanceMode``.
        ``timeout_seconds`` is passed to ``EnterMaintenanceMode`` and also
        bounds the polling.

        :raises Exception: The task failed or timed out, or the host is not in
            maintenance mode afterwards.
        """
        mo_id = host._moId

        logger.info("Entering host %s %s into maintenance mode.", host.name, mo_id)

        if host.runtime.inMaintenanceMode:
            logger.info("Host %s %s is already in maintenance mode.", host.name, mo_id)
            return

        task = host.EnterMaintenanceMode(timeout=timeout_seconds)

        err_msg = self.wait_for_task_to_complete(task=task, timeout_seconds=timeout_seconds,
                                                 mo_id=mo_id, task_name="Enter MM")
        if err_msg:
            msg = "Enter MM operation failed for host %s: %s" % (host.name, err_msg)
            logger.error(msg)
            raise Exception(msg)

        if host.runtime.inMaintenanceMode:
            logger.info("Enter MM successful for host %s.", host.name)
        else:
            msg = "Enter MM operation failed for host %s with unknown error" % host.name
            logger.error(msg)
            raise Exception(msg)

    def exit_host_from_mm(self, host):
        """Take ``host`` out of maintenance mode and wait for the task.

        No-op if ``runtime.inMaintenanceMode`` is False. The timeout is
        ``const.EXIT_MM_TIMEOUT_IN_MINUTES`` minutes.

        :raises Exception: The task failed or timed out, or the host is still
            in maintenance mode afterwards.
        """
        mo_id = host._moId

        logger.info("Exiting host %s %s from maintenance mode.", host.name, mo_id)

        if host.runtime.inMaintenanceMode is False:
            logger.info("Host %s %s is already out of maintenance mode.", host.name, mo_id)
            return

        timeout_seconds = const.EXIT_MM_TIMEOUT_IN_MINUTES * 60

        task = host.ExitMaintenanceMode(timeout=timeout_seconds)

        err_msg = self.wait_for_task_to_complete(task=task, timeout_seconds=timeout_seconds,
                                                 mo_id=mo_id, task_name="Exit MM")
        if err_msg:
            msg = "Exit MM operation failed for host %s: %s" % (host.name, err_msg)
            logger.error(msg)
            raise Exception(msg)

        if host.runtime.inMaintenanceMode is False:
            logger.info("Exit MM successful for host %s.", host.name)
        else:
            msg = "Exit MM operation failed for host %s with unknown error" % host.name
            logger.error(msg)
            raise Exception(msg)

    def is_host_in_mm(self, mo_id):
        """Return ``runtime.inMaintenanceMode`` for the host with moid ``mo_id``.

        :raises Exception: No host with that moid exists in vCenter.
        """
        host = self.get_host_by_moid(mo_id)
        if host is None:
            raise Exception("Host with mo id %s not found in vCenter" % mo_id)
        return host.runtime.inMaintenanceMode
