from typing import List, Set

from common import constants as const
from common.logger import setup_logger
from model.common import HostSwitch, HostTnModel, TncModel

from .nsx_client import NSXClient
from .vc_utils import VcUtils

logger = setup_logger()


class NsxUtils:
    """Helper methods built on top of an ``NSXClient``.

    The methods of this class read and update host transport nodes (TN), transport node
    collections (TNC), transport node profiles (TNP) and host switch profiles by issuing
    requests through the wrapped client.
    """

    def __init__(self, nsx_client: NSXClient):
        """Keep a reference to the client that every method uses for its requests."""
        self.nsx_client = nsx_client

    def get_version(self):
        version = self.nsx_client.get(endpoint="api/v1/node/version")

        return version.json() if version else None

    def get_cm_json_by_server(self, server: str):
        cm_list = self.nsx_client.get_list_results(
            endpoint="policy/api/v1/fabric/compute-managers?server=%s" % (server))
        return cm_list[0] if cm_list else None

    # Get host switch list from TN API response json or TNP API response json
    def get_host_switch_list(self, json_data) -> List[HostSwitch]:
        result_list: List[HostSwitch] = []

        if not json_data:
            return result_list

        for hs_data in json_data["host_switch_spec"]["host_switches"]:
            hs = HostSwitch(
                id=hs_data["host_switch_id"],
                name=hs_data["host_switch_name"],
                mode=hs_data["host_switch_mode"],
                spec=hs_data
            )
            result_list.append(hs)

        return result_list

    def get_tn_from_tn_api_resp_json(self, tn_api_resp_json) -> HostTnModel:
        if not tn_api_resp_json:
            return None

        tn_id: str = tn_api_resp_json["id"]
        dn_id: str = tn_api_resp_json["node_deployment_info"]["discovered_node_id"]
        host_mo_id: str = dn_id.split(":")[-1]

        host_switch_list: List[HostSwitch] = self.get_host_switch_list(tn_api_resp_json)

        tn = HostTnModel(id=tn_id,
                         host_mo_id=host_mo_id,
                         host_switch_list=host_switch_list,
                         tn_api_resp_json=tn_api_resp_json)

        return tn

    def get_tn(self, tn_id: str) -> HostTnModel:
        tn_resp = self.nsx_client.get("policy/api/v1/infra/sites/default/enforcement-points/"
                                      "default/host-transport-nodes/%s" % tn_id)

        return self.get_tn_from_tn_api_resp_json(tn_resp.json()) if tn_resp else None

    def get_tn_by_dn_id(self, discovered_node_id: str) -> HostTnModel:
        tn_list = self.nsx_client.get_list_results(
            "policy/api/v1/infra/sites/default/enforcement-points/default/host-transport-nodes"
            "?discovered_node_id=%s" % discovered_node_id)
        return self.get_tn_from_tn_api_resp_json(tn_list[0]) if tn_list else None

    def get_tn_state_json_by_id(self, tn_id: str):
        tn_resp = self.nsx_client.get("policy/api/v1/infra/sites/default/enforcement-points/"
                                      "default/host-transport-nodes/%s/state" % tn_id)

        return tn_resp.json() if tn_resp else None

    def get_tn_list_for_tnc(self, tnc: TncModel) -> List[HostTnModel]:
        dn_list_json = self.get_dn_list_json_by_compute_collection(tnc.parent_compute_collection_id)
        dn_id_set: Set[str] = {r.get("external_id") for r in dn_list_json if r.get("external_id")}

        result_list: List[HostTnModel] = []
        for dn_id in dn_id_set:
            tn: HostTnModel = self.get_tn_by_dn_id(dn_id)
            if tn:
                result_list.append(tn)

        return result_list

    def get_dn_list_json_by_compute_collection(self, cm_collection_id: str):
        return self.nsx_client.get_list_results("policy/api/v1/fabric/discovered-nodes",
                                                {'parent_compute_collection': cm_collection_id})

    def switch_list_has_standard_mode_switch(self, host_switch_list: List[HostSwitch]):
        for host_switch in host_switch_list:
            if host_switch.mode == const.STANDARD_MODE:
                return True

        return False

    def get_tnc_from_tnc_api_resp_json(self, tnc_api_resp_json):
        if not tnc_api_resp_json:
            return None

        tnc_id = tnc_api_resp_json["id"]
        resource_type: str = tnc_api_resp_json["resource_type"]
        cm_collection_id: str = tnc_api_resp_json["compute_collection_id"]
        cluster_id: str = None
        if cm_collection_id:
            cluster_id = cm_collection_id.split(":")[-1]
        tn_profile_path: str = tnc_api_resp_json["transport_node_profile_id"]

        tnc = TncModel(id=tnc_id,
                       resource_type=resource_type,
                       parent_compute_collection_id=cm_collection_id,
                       cluster_id=cluster_id,
                       transport_node_profile_path=tn_profile_path,
                       tnc_api_resp_json=tnc_api_resp_json)

        return tnc

    def get_tnc(self, tnc_id: str) -> TncModel:
        tnc_resp = self.nsx_client.get("policy/api/v1/infra/sites/default/enforcement-points/"
                                       "default/transport-node-collections/%s" % tnc_id)

        return self.get_tnc_from_tnc_api_resp_json(tnc_resp.json()) if tnc_resp else None

    def get_tnc_by_cm_collection_id(self, cm_collection_id: str) -> TncModel:
        tnc_list = self.nsx_client.get_list_results(
            "policy/api/v1/infra/sites/default/enforcement-points/default/"
            "transport-node-collections?compute_collection_id=%s" % cm_collection_id)
        return self.get_tnc_from_tnc_api_resp_json(tnc_list[0]) if tnc_list else None

    def get_tnc_state_json(self, tnc_id: str) -> str:
        tnc_state_resp = self.nsx_client.get("policy/api/v1/infra/sites/default/enforcement-points/"
                                             "default/transport-node-collections/%s/state" % tnc_id)

        return tnc_state_resp.json() if tnc_state_resp else None

    def get_tnp_json_from_path(self, tnp_path: str):
        tnp_resp = self.nsx_client.get(endpoint="policy/api/v1%s" % (tnp_path))
        return tnp_resp.json() if tnp_resp else None

    def construct_tnp_path_from_tnp_id(self, tnp_id: str):
        return "/infra/host-transport-node-profiles/%s" % (tnp_id)

    def create_tnp(self, tnp_path: str, host_switch_spec, name):
        tnp_json = {"resource_type": "PolicyHostTransportNodeProfile",
                    "display_name": name,
                    "host_switch_spec": host_switch_spec
                   }
        return self.nsx_client.put("policy/api/v1%s" % tnp_path, tnp_json)

    def apply_tnp_to_tnc(self, tnc_id: str, tnp_id_to_apply: str):
        logger.info("For TNC : %s, initiating apply TNP profile : %s", tnc_id, tnp_id_to_apply)
        tnc_before_update: TncModel = self.get_tnc(tnc_id)
        tnp_path_before_update: str = tnc_before_update.transport_node_profile_path

        tnp_path_to_apply: str = self.construct_tnp_path_from_tnp_id(tnp_id_to_apply)

        logger.info("For TNC : %s, current profile : %s, profile to apply : %s", tnc_id,
                    tnp_path_before_update, tnp_path_to_apply)

        tnc_api_update_body = tnc_before_update.tnc_api_resp_json
        tnc_api_update_body["transport_node_profile_id"] = tnp_path_to_apply

        self.nsx_client.put("policy/api/v1/infra/sites/default/enforcement-points/default/"
                            "transport-node-collections/%s" % tnc_id, body=tnc_api_update_body)

        tnp_path_after_update: str = self.get_tnc(tnc_id).transport_node_profile_path
        if tnp_path_after_update == tnp_path_to_apply:
            logger.info("For TNC %s successfully applied profile %s. Current profile %s",
                        tnc_id, tnp_path_to_apply, tnp_path_after_update)
        else:
            logger.error("For TNC %s, could not apply profile %s. Current profile %s",
                         tnc_id, tnp_path_to_apply, tnp_path_after_update)
            raise Exception("For TNC %s, could not apply profile %s." % (tnc_id, tnp_path_to_apply))

    def check_host_details_mismatch_for_tnc(self, tnc: TncModel, tn_list: List[HostTnModel],
                                            vc_utils: VcUtils):
        tnc_id: str = tnc.id
        tn_count: int = len(tn_list)

        cluster = vc_utils.get_cluster_details_by_id(tnc.cluster_id)
        host_count: int = len(cluster.host)

        logger.debug("TN count of TNC %s is %d, vSphere cluster %s host count is %d", tnc_id,
                     tn_count, cluster.name, host_count)

        if tn_count != host_count:
            tn_count_mismatch_message = ("TN count %d of TNC %s does not match host count %d of "
                "cluster %s") % (tn_count, tnc_id, host_count, cluster.name)
            logger.error(tn_count_mismatch_message)
            raise Exception(tn_count_mismatch_message)

        host_id_set_from_tn_list: Set[str] = {tn.host_mo_id for tn in tn_list}

        host_id_set_from_cluster: Set[str] = {host._moId for host in cluster.host}

        if host_id_set_from_tn_list != host_id_set_from_cluster:
            err_msg = ("The host moIDs retrieved from TNs of TNC %s do not match "
                       "those retrieved from vSphere cluster %s corresponding to TNC.") % (
                       tnc_id, cluster.name)
            logger.error("%s Difference: %s %s", err_msg,
                         list(host_id_set_from_tn_list - host_id_set_from_cluster),
                         list(host_id_set_from_cluster - host_id_set_from_tn_list))
            raise Exception(err_msg)

    def log_tnc_and_tn_details(self, tnc_id: str, throw_error: bool = False):
        try:
            tnc: TncModel = self.get_tnc(tnc_id=tnc_id)
            tn_list: List[HostTnModel] = self.get_tn_list_for_tnc(tnc=tnc)
            logger.debug("TNC details : \n%s", tnc.pretty_print())
            logger.debug("TN details of TNC : \n%s", HostTnModel.pretty_print_list(tn_list=tn_list))
        except Exception as e:
            logger.warning("Failed to log latest information about TNC and its TNs. Error: %s",
                           str(e))
            if throw_error:
                raise

    def get_high_performance_hostswitch_profiles(self):
        return self.nsx_client.get_list_results(
            "policy/api/v1/infra/host-switch-profiles",
            {'hostswitch_profile_type': 'PolicyHighPerformanceHostSwitchProfile',
             'include_system_owned': True})

    def get_hostswitch_profile(self, hs_prof_path):
        resp = self.nsx_client.get("policy/api/v1%s" % hs_prof_path)
        return resp.json() if resp else None

    def get_high_performance_profile_kv_pair(self, host_switch):
        hs_prof_list = host_switch.get('host_switch_profile_ids', [])
        for hs_prof in hs_prof_list:
            if hs_prof["key"] == "HighPerformanceHostSwitchProfile":
                return hs_prof, hs_prof_list
        return None, hs_prof_list

    def remove_high_performance_profile(self, host_switch, reset_only=False):
        hs_prof_list = host_switch.get('host_switch_profile_ids', [])
        for prof_kv in hs_prof_list:
            if prof_kv["key"] == "HighPerformanceHostSwitchProfile" and prof_kv.get('value'):
                remove_prof = True
                if reset_only:
                    hp_prof = self.get_hostswitch_profile(prof_kv['value'])
                    if hp_prof.get('auto_config', 0) or hp_prof.get('high_performance_configs'):
                        remove_prof = False
                if remove_prof:
                    new_prof_list = [p for p in hs_prof_list if p['key'] != prof_kv['key']]
                    host_switch['host_switch_profile_ids'] = new_prof_list
                    return True
                break
        return False

    def remove_tz_profiles_without_profile_id(self, host_switch_spec):
        host_switch_list = host_switch_spec.get('host_switches', [])
        for host_switch in host_switch_list:
            tz_endpoints = host_switch.get('transport_zone_endpoints', [])
            for tz_endpoint in tz_endpoints:
                tz_profiles = tz_endpoint.get('transport_zone_profile_ids', [])
                tz_endpoint['transport_zone_profile_ids'] = \
                    [tp for tp in tz_profiles if tp.get('profile_id')]
