from typing import List

from clients.nsx_utils import NsxUtils
from clients.vc_utils import VcUtils

from common import constants as const
from common.logger import setup_logger

from model.common import HostTnModel, TncModel, TnProcessingInfo, OpStatus
from model.config import ClusterEntry

from .process_tn import TnProcessingUnit

logger = setup_logger()


class TncProcessingUnit:
    """Process every transport node (TN) of one transport node collection (TNC).

    Each TN is handled by a ``TnProcessingUnit``. Only when all TNs succeed is
    ``cluster_entry.tnp_id_to_apply`` applied to the TNC.
    """

    def __init__(self, tnc: TncModel, cluster_entry: ClusterEntry, nsx_utils: NsxUtils,
                 vc_utils: VcUtils, hp_prof_path: str):
        """Store the collaborators used while processing ``tnc``.

        :param tnc: the TNC whose TNs are processed.
        :param cluster_entry: per-cluster configuration; this class reads
            ``vcenter_cluster_name``, ``tnp_id_to_apply``,
            ``enter_mm_timeout_minutes`` and ``reset_high_performance_params``.
        :param nsx_utils: NSX helper used to list TNs and apply the TNP.
        :param vc_utils: vCenter helper used to look up hosts by MoID.
        :param hp_prof_path: forwarded as ``hp_prof_path`` to each
            ``TnProcessingUnit``, but only when
            ``cluster_entry.reset_high_performance_params`` is truthy.
        """
        self.tnc = tnc
        self.cluster_entry = cluster_entry
        self.nsx_utils = nsx_utils
        self.vc_utils = vc_utils
        self.hp_prof_path = hp_prof_path

    def get_sorted_tn_list_by_mm(self, failed_list: List[TnProcessingInfo]):
        """Return this TNC's TNs, with TNs whose host is in maintenance mode first.

        Relative order inside each group (maintenance mode / not) is preserved.
        A TN whose host is not returned by ``vc_utils.get_hosts_by_moids`` is
        left out of the result. For such a TN, a FAILED ``TnProcessingInfo``
        is appended to ``failed_list``.

        :param failed_list: receives failure records; mutated in place.
        :return: ordered list of ``HostTnModel``; empty when the TNC has no TNs.
        """
        tn_list: List[HostTnModel] = self.nsx_utils.get_tn_list_for_tnc(tnc=self.tnc)
        if not tn_list:
            return []

        # One bulk lookup for all hosts instead of one call per TN.
        moid_to_host = self.vc_utils.get_hosts_by_moids({tn.host_mo_id for tn in tn_list})

        maintenance_mode_tns: List[HostTnModel] = []
        normal_tns: List[HostTnModel] = []
        for tn in tn_list:
            host = moid_to_host.get(tn.host_mo_id)
            if not host:
                err_msg = "TN %s is ignored because its host is not found by %s" % (
                    tn.id, tn.host_mo_id)
                # Surface the problem now; otherwise it is only reported after
                # every other TN has been processed.
                logger.warning(err_msg)
                failed_list.append(TnProcessingInfo(status=OpStatus.FAILED, error_message=err_msg))
            elif host.runtime.inMaintenanceMode:
                maintenance_mode_tns.append(tn)
            else:
                normal_tns.append(tn)

        return maintenance_mode_tns + normal_tns

    def apply_new_tnp(self):
        """Apply ``cluster_entry.tnp_id_to_apply`` to this TNC through NSX."""
        self.nsx_utils.apply_tnp_to_tnc(tnc_id=self.tnc.id,
                                        tnp_id_to_apply=self.cluster_entry.tnp_id_to_apply)

    def check_if_new_tnp_can_be_applied_for_tnc(self):
        """Ensure no TN of this TNC still has a host switch in STANDARD mode.

        :raises Exception: for the first TN / host switch found with
            ``mode == const.STANDARD_MODE``; the message is also logged at
            error level.
        """
        # Same guard as get_sorted_tn_list_by_mm: treat a falsy (e.g. None) TN
        # list as empty rather than failing with a TypeError while iterating.
        tn_list: List[HostTnModel] = self.nsx_utils.get_tn_list_for_tnc(tnc=self.tnc) or []

        for tn in tn_list:
            for switch in tn.host_switch_list:
                if switch.mode == const.STANDARD_MODE:
                    err_msg = ("Cannot apply new TNP %s on TNC %s because TN %s still has STANDARD "
                               "mode host switch %s.") % (
                               self.cluster_entry.tnp_id_to_apply, self.tnc.id, tn.id, switch.name)
                    logger.error(err_msg)
                    raise Exception(err_msg)

    def process_tnc(self):
        """Process every TN of the TNC, then apply the new TNP.

        The steps are:

        1. Order the TNs so that those whose host is already in maintenance
           mode come first.
        2. Call ``nsx_utils.check_host_details_mismatch_for_tnc`` on that list.
        3. Run a ``TnProcessingUnit`` for each TN. Stop early if a failed TN
           sets ``pause_after_failure``.
        4. If nothing failed, verify that no STANDARD-mode host switch remains
           and apply ``cluster_entry.tnp_id_to_apply``.

        :raises Exception: when any TN failed, including TNs whose host was not
            found; the message joins the collected error messages. Exceptions
            from ``check_if_new_tnp_can_be_applied_for_tnc`` propagate as well.
        """
        logger.info("Initiating processing of cluster %s, TNC %s",
                    self.cluster_entry.vcenter_cluster_name, self.tnc.id)

        # The ordering is important: hosts already in maintenance mode (MM) must be
        # at the start of the list. NOTE(review): presumably so they are handled
        # before any other host is asked to enter MM; confirm.
        failed_list: List[TnProcessingInfo] = []
        ordered_tn_list: List[HostTnModel] = self.get_sorted_tn_list_by_mm(failed_list)

        self.nsx_utils.check_host_details_mismatch_for_tnc(tnc=self.tnc, tn_list=ordered_tn_list,
                                                           vc_utils=self.vc_utils)

        # Loop-invariant: the profile path is only forwarded when a reset is requested.
        hp_prof_path = (self.hp_prof_path
                        if self.cluster_entry.reset_high_performance_params else None)

        for tn in ordered_tn_list:
            tn_process_unit = TnProcessingUnit(
                tn_id=tn.id, nsx_utils=self.nsx_utils, vc_utils=self.vc_utils,
                enter_mm_timeout_mins=self.cluster_entry.enter_mm_timeout_minutes,
                hp_prof_path=hp_prof_path)
            tn_status_info: TnProcessingInfo = tn_process_unit.run()
            if tn_status_info.status == OpStatus.SUCCESS:
                continue
            failed_list.append(tn_status_info)
            if tn_status_info.pause_after_failure:
                # The failed TN requested that no further host be started.
                break

        if failed_list:
            error_message = "\n".join(ent.error_message for ent in failed_list if ent.error_message)
            if not error_message:
                # Avoid raising with an empty "Details" section.
                error_message = "%d TN(s) failed without an error message." % len(failed_list)
            raise Exception("Details : \n%s" % error_message)

        self.check_if_new_tnp_can_be_applied_for_tnc()
        self.apply_new_tnp()

        self.nsx_utils.log_tnc_and_tn_details(tnc_id=self.tnc.id)
        logger.info("Successfully processed cluster %s TNC %s.",
                    self.cluster_entry.vcenter_cluster_name, self.tnc.id)

    def run(self):
        """Entry point; equivalent to ``process_tnc()``."""
        self.process_tnc()
