import time
import traceback

from clients.nsx_utils import NsxUtils
from clients.vc_utils import VcUtils

from common import constants as const
from common.logger import setup_logger

from model.common import HostTnModel, TnProcessingInfo, OpStatus

# Module-level logger used throughout this module; built once at import time by the project's
# setup_logger() helper.
logger = setup_logger()


class TnProcessingUnit:
    """Process a single host transport node (TN).

    The unit puts the TN's host into MM (presumably vSphere maintenance mode), optionally
    resets the high-performance host switch profile, changes every STANDARD mode host switch
    to ENS_INTERRUPT, waits for the TN state to settle and finally takes the host out of MM.
    """

    # Seconds slept between two TN state polls; also used for the extra settle delay.
    _STATE_POLL_INTERVAL_SECONDS = 15

    def __init__(self, tn_id: str, nsx_utils: NsxUtils, vc_utils: VcUtils,
                 enter_mm_timeout_mins: str, hp_prof_path: str):
        """Store the collaborators and settings needed to process one TN.

        Args:
            tn_id: Id of the host transport node to process.
            nsx_utils: NSX helper used to read, modify and PUT the TN.
            vc_utils: Helper used to look up the host and to enter/exit MM.
            enter_mm_timeout_mins: Timeout, in minutes, for entering MM. It is converted to
                seconds in ``process_tn``; numeric strings and plain numbers are both accepted.
            hp_prof_path: Value written into the "HighPerformanceHostSwitchProfile" entry of
                STANDARD mode switches. A falsy value disables the reset step.
        """
        self.tn_id = tn_id
        self.nsx_utils = nsx_utils
        self.vc_utils = vc_utils
        self.enter_mm_timeout_mins = enter_mm_timeout_mins
        self.hp_prof_path = hp_prof_path

    def _get_tn_for_update(self, tn_id: str) -> HostTnModel:
        """Fetch the TN from NSX and log its current content at debug level."""
        tn_before_update: HostTnModel = self.nsx_utils.get_tn(tn_id)
        logger.debug("TN %s before switch mode update: \n%s", tn_id,
                     tn_before_update.pretty_print())
        return tn_before_update

    def _put_tn(self, tn_id: str, tn_api_update_body) -> None:
        """PUT the (already modified) TN body back through the NSX policy API."""
        self.nsx_utils.nsx_client.put("policy/api/v1/infra/sites/default/enforcement-points/default"
                                      "/host-transport-nodes/%s" % tn_id, body=tn_api_update_body)

    def change_switch_mode_from_standard_to_ens_interrupt_for_tn(
            self, tn_id: str, tn_state_before_update: str, tn_processing_info: TnProcessingInfo):
        """Change every STANDARD mode host switch of the TN to ENS_INTERRUPT.

        Each converted switch also has its high-performance profile removed through
        ``nsx_utils.remove_high_performance_profile``. When no switch is in STANDARD mode
        nothing is sent to NSX and the method returns immediately.

        Args:
            tn_id: Id of the TN to update.
            tn_state_before_update: TN state observed before the update; forwarded to
                ``tn_state_check_after_update``.
            tn_processing_info: Result object updated by the state check.

        Raises:
            Exception: Propagated from ``tn_state_check_after_update`` when the TN was in
                "success" state before the update and is not afterwards.
        """
        logger.info("Changing TN %s STANDARD mode host switches to ENS_INTERRUPT.", tn_id)
        tn_api_update_body = self._get_tn_for_update(tn_id).tn_api_resp_json
        host_switch_spec = tn_api_update_body['host_switch_spec']

        tn_spec_changed = False
        for host_switch in host_switch_spec.get("host_switches", []):
            if host_switch["host_switch_mode"] != const.STANDARD_MODE:
                continue
            host_switch["host_switch_mode"] = const.ENS_INTERRUPT_MODE
            tn_spec_changed = True
            self.nsx_utils.remove_high_performance_profile(host_switch)

        if not tn_spec_changed:
            logger.info("TN %s has no switch mode to update", tn_id)
            return

        self.nsx_utils.remove_tz_profiles_without_profile_id(host_switch_spec)
        logger.info("Updating TN %s to change switch mode to ENS_INTERRUPT...", tn_id)
        self._put_tn(tn_id, tn_api_update_body)

        self.tn_state_check_after_update(tn_state_before_update, tn_processing_info)

    def reset_high_performance_params_for_tn(
            self, tn_id: str, tn_state_before_update: str, tn_processing_info: TnProcessingInfo):
        """Point the high-performance profile of every STANDARD mode switch at ``hp_prof_path``.

        A switch whose profile already equals ``self.hp_prof_path`` is left untouched. A switch
        without such a profile entry gets a new "HighPerformanceHostSwitchProfile" entry.

        Args:
            tn_id: Id of the TN to update.
            tn_state_before_update: TN state observed before the update.
            tn_processing_info: Result object updated by the state check.

        Returns:
            ``tn_state_before_update`` when nothing had to change, otherwise the TN state
            polled after the update (which may be something other than "success").
        """
        logger.info("Resetting high-performance params for TN %s.", tn_id)
        tn_api_update_body = self._get_tn_for_update(tn_id).tn_api_resp_json
        host_switch_spec = tn_api_update_body['host_switch_spec']

        tn_spec_changed = False
        for host_switch in host_switch_spec.get("host_switches", []):
            if host_switch["host_switch_mode"] != const.STANDARD_MODE:
                continue
            hp_prof_kv, hs_prof_list = self.nsx_utils.get_high_performance_profile_kv_pair(
                host_switch)
            if hp_prof_kv:
                if hp_prof_kv['value'] == self.hp_prof_path:
                    # Already pointing at the desired profile: nothing to do for this switch.
                    continue
                hp_prof_kv['value'] = self.hp_prof_path
            else:
                hs_prof_list.append({"key": "HighPerformanceHostSwitchProfile",
                                     "value": self.hp_prof_path})
                host_switch['host_switch_profile_ids'] = hs_prof_list
            tn_spec_changed = True

        if not tn_spec_changed:
            logger.info("TN %s already reset high-performance params", tn_id)
            return tn_state_before_update

        self.nsx_utils.remove_tz_profiles_without_profile_id(host_switch_spec)
        logger.info("Updating TN %s to reset high-performance params...", tn_id)
        self._put_tn(tn_id, tn_api_update_body)
        # Do not raise if the TN state is not success after the update: resetting the
        # high-performance params is not critical.
        return self.tn_state_check_after_update(tn_state_before_update, tn_processing_info,
                                                raise_ex=False)

    def tn_state_check_after_update(self, tn_state_before_update: str,
                                    tn_processing_info: TnProcessingInfo, raise_ex=True):
        """Poll the state of ``self.tn_id`` until it is "success" or the timeout is reached.

        The state is polled every ``_STATE_POLL_INTERVAL_SECONDS`` seconds for up to
        ``const.NSX_SWITCH_MODE_CHANGE_STATE_CHECK_TIMEOUT_IN_SECONDS`` seconds of accumulated
        sleep time. ``tn_processing_info.status`` and ``error_message`` are set from the result.

        Args:
            tn_state_before_update: TN state observed before the update.
            tn_processing_info: Result object to update.
            raise_ex: When True (default), a TN that was "success" before and is not
                "success" afterwards aborts processing with an exception.

        Returns:
            The last polled TN state, or 'unknown' if no poll happened.

        Raises:
            Exception: If the state regressed from "success" and ``raise_ex`` is True; in that
                case ``tn_processing_info.pause_after_failure`` is set to True first.
        """
        poll_interval = self._STATE_POLL_INTERVAL_SECONDS
        elapsed_time = 0
        tn_state_after_update = 'unknown'
        if tn_state_before_update == "success":
            # Sleep one extra interval to give the state time to leave "success" before the
            # first poll.
            time.sleep(poll_interval)
        while elapsed_time < const.NSX_SWITCH_MODE_CHANGE_STATE_CHECK_TIMEOUT_IN_SECONDS:
            time.sleep(poll_interval)
            tn_state_after_update = self.nsx_utils.get_tn_state_json_by_id(self.tn_id)["state"]
            if tn_state_after_update == "success":
                break
            logger.debug("TN %s state after switch mode update is %s. Sleep for %s seconds before "
                         "polling state again.", self.tn_id, tn_state_after_update, poll_interval)
            elapsed_time += poll_interval

        if tn_state_after_update == "success":
            tn_processing_info.status = OpStatus.SUCCESS
            tn_processing_info.error_message = None
            return tn_state_after_update

        logger.info("TN %s final state after switch mode update is %s", self.tn_id,
                    tn_state_after_update)
        tn_processing_info.status = OpStatus.FAILED
        tn_processing_info.error_message = "TN state after ENS update is " + \
            tn_state_after_update
        if tn_state_before_update == "success" and raise_ex:
            tn_processing_info.pause_after_failure = True
            raise Exception("TN %s state changed to %s after switch mode is updated, so stop "
                            "execution here without starting operation on the next host." % (
                                self.tn_id, tn_state_after_update))

        logger.info("Still continue to next host because TN %s state before switch mode "
                    "update is %s", self.tn_id, tn_state_before_update)
        return tn_state_after_update

    def process_tn(self) -> TnProcessingInfo:
        """Run the whole workflow for ``self.tn_id`` and report the outcome.

        Any exception raised along the way is logged (with traceback) and converted into a
        FAILED result; it is never propagated to the caller.

        Returns:
            A ``TnProcessingInfo`` whose ``status`` / ``error_message`` describe the outcome.
        """
        tn_processing_info = TnProcessingInfo(status=OpStatus.SUCCESS,
                                              error_message=None)
        try:
            logger.info("Initiating processing of TN : %s", self.tn_id)
            tn_before_update: HostTnModel = self.nsx_utils.get_tn(tn_id=self.tn_id)
            host_mo_id = tn_before_update.host_mo_id
            host = self.vc_utils.get_host_by_moid(mo_id=host_mo_id)

            if not self.nsx_utils.switch_list_has_standard_mode_switch(
                    tn_before_update.host_switch_list):
                logger.info("Skip updating TN %s that has no STANDARD mode switch", host.name)
            else:
                # The timeout is annotated as ``str``; multiplying a string by 60 would repeat
                # its text ("30" -> "3030...") instead of converting minutes to seconds, so
                # convert to a number first.
                enter_mm_timeout_seconds = int(float(self.enter_mm_timeout_mins) * 60)
                self.vc_utils.enter_host_into_mm(host, timeout_seconds=enter_mm_timeout_seconds)

                tn_state_before_update = self.nsx_utils.get_tn_state_json_by_id(self.tn_id)["state"]
                logger.info("TN %s state before switch mode update is %s", self.tn_id,
                            tn_state_before_update)

                if not self.vc_utils.is_host_in_mm(mo_id=host_mo_id):
                    raise Exception("Skip updating switch mode in host %s for it is not in MM."
                                    % host.name)

                if self.hp_prof_path:
                    # Reset high-performance params for the TN before changing the switch mode
                    # to avoid concurrent changes in data-path kernel modules.
                    tn_state_before_update = self.reset_high_performance_params_for_tn(
                        self.tn_id, tn_state_before_update, tn_processing_info)
                # Change the switch mode from STANDARD to ENS_INTERRUPT.
                self.change_switch_mode_from_standard_to_ens_interrupt_for_tn(
                    self.tn_id, tn_state_before_update, tn_processing_info)

            self.vc_utils.exit_host_from_mm(host)

            logger.info("Processed TN : %s", self.tn_id)

        except Exception as e:
            # Top-level boundary for one TN: record the failure instead of propagating it.
            err_msg = "Error while enabling UENS on TN: %s. %s" % (self.tn_id, str(e))
            logger.error(err_msg)
            logger.error(traceback.format_exc())
            tn_processing_info.status = OpStatus.FAILED
            tn_processing_info.error_message = err_msg

        return tn_processing_info

    def run(self) -> TnProcessingInfo:
        """Entry point; delegates to ``process_tn``."""
        return self.process_tn()
