from typing import List

from clients.nsx_utils import NsxUtils
from clients.vc_utils import VcUtils

from common import constants as const
from common.logger import setup_logger

from model.common import HostSwitch, TncModel
from model.config import Config, ClusterEntry

from .process_tnc import TncProcessingUnit

# Module-wide logger obtained from the project's common logging setup (common.logger).
logger = setup_logger()


class UensEnablement:
    """Orchestrate UENS enablement (STANDARD -> ENS_INTERRUPT host switches) per cluster.

    For each entry in ``config.cluster_entry_list`` the vCenter cluster, its NSX host
    transport node collection (TNC) and the transport node profiles (TNPs) involved are
    validated, the TNP-to-apply is prepared if needed, and the TNC is then handed to
    ``TncProcessingUnit`` for processing.
    """

    def __init__(self, config: Config, cm_id: str, nsx_utils: NsxUtils, vc_utils: VcUtils):
        """
        :param config: tool configuration; ``cluster_entry_list`` lists the clusters to handle.
        :param cm_id: id combined with a cluster moID as ``"<cm_id>:<moID>"`` to look up the
            cluster's TNC (presumably the NSX compute manager id — confirm).
        :param nsx_utils: helper used for all NSX calls.
        :param vc_utils: helper used for all vCenter calls.
        """
        self.config = config
        self.cm_id = cm_id
        self.nsx_utils = nsx_utils
        self.vc_utils = vc_utils
        # policy path -> high-performance host switch profile json, filled by
        # get_high_performance_hostswitch_profiles()
        self.hp_prof_map = {}
        # the high-performance profile that will reset the params to default for ENS in a DVS
        self.hp_prof_path = None

    def get_high_performance_hostswitch_profiles(self):
        """Cache all high-performance host switch profiles and pick the "reset" profile.

        A profile qualifies as the reset profile when its ``auto_config`` is 0 (or absent) and
        its ``high_performance_configs`` is empty/absent. Among qualifying profiles the last
        system-owned one wins; if none is system-owned, the first qualifying one is used. The
        chosen path is stored in ``self.hp_prof_path``.

        :raises Exception: if no profile qualifies while some cluster entry has
            ``reset_high_performance_params`` set.
        """
        for hp_prof in self.nsx_utils.get_high_performance_hostswitch_profiles():
            self.hp_prof_map[hp_prof['path']] = hp_prof
            is_reset_profile = (hp_prof.get('auto_config', 0) == 0
                                and not hp_prof.get('high_performance_configs'))
            if not is_reset_profile:
                continue
            # a system-owned candidate always overrides; otherwise keep the first candidate
            if hp_prof.get('_system_owned', False) or not self.hp_prof_path:
                self.hp_prof_path = hp_prof['path']

        if self.hp_prof_path:
            chosen = self.hp_prof_map[self.hp_prof_path]
            logger.info("HighPerformanceHostSwitchProfile %s is chosen",
                        chosen.get('display_name', self.hp_prof_path))
            return

        r_cls = [c for c in self.config.cluster_entry_list if c.reset_high_performance_params]
        if r_cls:
            raise Exception("No PolicyHighPerformanceHostSwitchProfile with auto_config=0 and"
                            " empty high_performance_configs is found to reset the high "
                            "performance params for ENS in cluster "
                            "%s." % r_cls[0].vcenter_cluster_name)

    def run_host_switch_modify_checks(self, current_tnp_json, new_tnp_json):
        """Validate that ``new_tnp_json`` only converts STANDARD host switches to ENS_INTERRUPT.

        The new TNP must contain exactly the host switch names of the current TNP; every
        STANDARD switch must become ENS_INTERRUPT and every other switch must keep its mode.

        :raises Exception: if the current TNP has no host switch or no STANDARD switch, or if
            the new TNP breaks any rule above (all violations are reported in one message).
        """
        old_switch_list: List[HostSwitch] = self.nsx_utils.get_host_switch_list(current_tnp_json)
        if not old_switch_list:
            raise Exception("No UENS enablement op is needed because the current TNP has no host "
                            "switch")
        if not self.nsx_utils.switch_list_has_standard_mode_switch(old_switch_list):
            raise Exception("No UENS enablement op is needed because the current TNP does not "
                            "have STANDARD switch mode")
        # guard against a TNP without switches so the "missing switch" error is reported
        new_switch_list: List[HostSwitch] = \
            self.nsx_utils.get_host_switch_list(new_tnp_json) or []
        # entries still left in this map after the loop exist only in the new TNP
        new_switch_map = {switch.name: switch for switch in new_switch_list}
        miss_switch_names = []
        errors = []
        for old_switch in old_switch_list:
            new_switch = new_switch_map.pop(old_switch.name, None)
            if new_switch is None:
                miss_switch_names.append(old_switch.name)
            elif old_switch.mode == const.STANDARD_MODE:
                if new_switch.mode != const.ENS_INTERRUPT_MODE:
                    errors.append("Host switch %s in new TNP must have mode ENS_INTERRUPT instead "
                                  "of %s. " % (old_switch.name, new_switch.mode))
            elif old_switch.mode != new_switch.mode:
                errors.append("Host switch %s in new TNP must have mode %s instead "
                              "of %s. " % (old_switch.name, old_switch.mode, new_switch.mode))
        if miss_switch_names:
            errors.append("New TNP must include host switch(es) %s. " % miss_switch_names)
        if new_switch_map:
            errors.append("New TNP must not include new host switch(es) %s."
                          % list(new_switch_map.keys()))
        if errors:
            raise Exception("".join(errors))

    def change_host_switch_spec_for_ens(self, host_switch_spec):
        """Convert, in place, every STANDARD host switch of ``host_switch_spec`` to ENS_INTERRUPT.

        High-performance profiles are removed from each converted switch, so the resulting
        TNP does not carry any.

        :return: True if at least one switch was converted, False otherwise.
        """
        changed = False
        for host_switch in host_switch_spec.get("host_switches") or []:
            if host_switch.get("host_switch_mode") == const.STANDARD_MODE:
                changed = True
                host_switch["host_switch_mode"] = const.ENS_INTERRUPT_MODE
                # do not have any high-performance profile in the new TNP
                self.nsx_utils.remove_high_performance_profile(host_switch)
        return changed

    def run_tnp_checks(self, cluster_entry: ClusterEntry, current_tnc: TncModel):
        """Validate, and if needed prepare, the TNP that will be applied to the cluster.

        If the TNP-to-apply already exists it is validated against the TNP currently attached
        to the TNC and cleaned of reset-only high-performance profiles. Otherwise it is created
        from the current TNP's host switch spec with STANDARD switches converted.

        :raises Exception: if the TNC has no TNP, the current and target TNPs are the same,
            the current TNP cannot be fetched, or no/invalid UENS change is detected.
        """
        current_tnp_path = current_tnc.transport_node_profile_path
        if not current_tnp_path:
            raise Exception("No TNP is attached to TNC %s" % current_tnc.id)
        new_tnp_path = self.nsx_utils.construct_tnp_path_from_tnp_id(cluster_entry.tnp_id_to_apply)
        if current_tnp_path == new_tnp_path:
            raise Exception("No UENS enablement op is needed because the current TNP %s and "
                            "TNP-to-apply are the same" % current_tnp_path)

        current_tnp_json = self.nsx_utils.get_tnp_json_from_path(current_tnp_path)
        if not current_tnp_json:
            raise Exception("No TNP of path %s is found" % current_tnp_path)

        new_tnp_json = self.nsx_utils.get_tnp_json_from_path(new_tnp_path)
        if new_tnp_json:
            self.run_host_switch_modify_checks(current_tnp_json, new_tnp_json)
            self._reset_hp_profiles_in_tnp(new_tnp_path, new_tnp_json)
        else:
            self._create_ens_tnp(cluster_entry, new_tnp_path, current_tnp_json)

    def _reset_hp_profiles_in_tnp(self, tnp_path, tnp_json):
        """Strip reset-only high-performance profiles from an existing TNP; PUT it back if changed."""
        host_switch_spec = tnp_json['host_switch_spec']
        changed = False
        for host_switch in host_switch_spec.get("host_switches", []):
            # deliberately no short-circuit: every switch must be processed
            if self.nsx_utils.remove_high_performance_profile(host_switch, reset_only=True):
                changed = True
        if not changed:
            return
        self.nsx_utils.remove_tz_profiles_without_profile_id(host_switch_spec)
        logger.info("Removing HighPerformanceHostSwitchProfile from TNP %s ...",
                    tnp_json['display_name'])
        self.nsx_utils.nsx_client.put("policy/api/v1%s" % tnp_path, body=tnp_json)

    def _create_ens_tnp(self, cluster_entry: ClusterEntry, new_tnp_path, current_tnp_json):
        """Create the TNP-to-apply from the current TNP with STANDARD switches made ENS_INTERRUPT."""
        # the spec is modified in place; it belongs to the freshly fetched current TNP json
        host_switch_spec = current_tnp_json.get('host_switch_spec') or {}
        if not self.change_host_switch_spec_for_ens(host_switch_spec):
            raise Exception("No UENS enablement op is needed because the current TNP does not"
                            " have STANDARD switch mode")
        self.nsx_utils.remove_tz_profiles_without_profile_id(host_switch_spec)
        self.nsx_utils.create_tnp(new_tnp_path, host_switch_spec, cluster_entry.tnp_id_to_apply)
        logger.info("Created TNP %s from TNP %s attached to cluster name %s",
                    cluster_entry.tnp_id_to_apply, current_tnp_json['display_name'],
                    cluster_entry.vcenter_cluster_name)

    def run_cluster_entry_checks(self, cluster_entry: ClusterEntry) -> TncModel:
        """Run all pre-checks for one cluster entry and return its TNC.

        Checks: the cluster exists in vCenter, DRS is enabled in fullyAutomated mode, a
        matching host TNC exists in NSX for the cluster, the TNC state is allowed, and the
        TNP checks of run_tnp_checks pass.

        :raises Exception: on the first failed check.
        """
        cluster_name = cluster_entry.vcenter_cluster_name
        logger.info("Starting checks for input vCenter cluster : %s", cluster_name)

        cluster = self.vc_utils.get_cluster_details_by_name(cluster_name=cluster_name)
        if not cluster:
            raise Exception("Could NOT find cluster in vCenter inventory")
        logger.info("For input cluster name %s, found a cluster in vCenter inventory with moID %s",
                    cluster_name, cluster._moId)

        drs_cfg = cluster.configurationEx.drsConfig
        if not drs_cfg.enabled:
            raise Exception("DRS is NOT enabled on cluster. Please enable DRS and set mode to "
                            "fullyAutomated")
        if (drs_cfg.defaultVmBehavior or '') != 'fullyAutomated':
            raise Exception("Cluster DRS mode is NOT fullyAutomated.")

        cm_collection_id = "%s:%s" % (self.cm_id, cluster._moId)
        tnc: TncModel = self.nsx_utils.get_tnc_by_cm_collection_id(cm_collection_id)
        if not tnc:
            raise Exception("Could NOT find a NSX TNC of id %s." % cm_collection_id)
        logger.info("For input vCenter cluster %s, found a NSX TNC %s", cluster_name, tnc.id)
        if tnc.resource_type != "HostTransportNodeCollection":
            raise Exception("NSX TNC %s is not a host transport node collection." % tnc.id)
        if tnc.cluster_id != cluster._moId:
            raise Exception("The cluster moID %s does not match cluster ID %s in the TNC." % (
                            cluster._moId, tnc.cluster_id))

        tnc_state_json = self.nsx_utils.get_tnc_state_json(tnc_id=tnc.id)
        tnc_state: str = tnc_state_json["state"]
        logger.info("For TNC : %s, current state is : %s", tnc.id, tnc_state)
        if tnc_state not in const.TNC_STATE_ALLOWED_STATUS_VALUES:
            raise Exception("UENS enablement is not supported in current TNC state %s. "
                            "It is only supported when the TNC state is in %s" % (
                            tnc_state, ", ".join(const.TNC_STATE_ALLOWED_STATUS_VALUES)))

        self.run_tnp_checks(cluster_entry=cluster_entry, current_tnc=tnc)

        logger.info("Checks successfully completed for input vCenter cluster : %s", cluster_name)
        self.nsx_utils.log_tnc_and_tn_details(tnc_id=tnc.id)
        return tnc

    def run(self):
        """Run UENS enablement for every configured cluster entry.

        Global preparation (high-performance profile selection, the vCenter vMotion key) runs
        once and may raise. After that, failures are isolated per cluster: an exception while
        checking or processing one cluster is logged and the next cluster is attempted.
        """
        self.get_high_performance_hostswitch_profiles()
        self.vc_utils.enable_vmotion_between_ens_modes_key()

        for cluster_entry in self.config.cluster_entry_list:
            try:
                tnc: TncModel = self.run_cluster_entry_checks(cluster_entry=cluster_entry)
                tnc_process_unit = TncProcessingUnit(tnc=tnc,
                                                     cluster_entry=cluster_entry,
                                                     nsx_utils=self.nsx_utils,
                                                     vc_utils=self.vc_utils,
                                                     hp_prof_path=self.hp_prof_path)
                tnc_process_unit.run()
            except Exception as ex:
                # per-cluster boundary: report and continue with the next cluster
                logger.error("Cluster %s failed: %s", cluster_entry.vcenter_cluster_name, str(ex))
