# Copyright (c) Facebook, Inc. and its affiliates. 
#   
# This source code is licensed under the MIT license found in the 
# LICENSE file in the root directory of this source tree.
import torch
import bitsandbytes.functional as F

from copy import deepcopy
from itertools import chain
from collections import defaultdict, abc as container_abcs

class MockArgs(object):
    """Wrap a mapping so each of its keys can be read as an attribute.

    ``MockArgs({'lr': 0.1}).lr == 0.1``.
    """

    def __init__(self, initial_data):
        # Walk the keys in mapping order, fetch each value, and expose it
        # as an attribute of the same name.
        pairs = ((name, initial_data[name]) for name in initial_data)
        for attr_name, attr_value in pairs:
            setattr(self, attr_name, attr_value)


class GlobalOptimManager(object):
    """Process-wide singleton that stores per-parameter optimizer config overrides.

    Overrides are first recorded by parameter identity (``id(p)``) through
    :meth:`override_config`. :meth:`register_parameters` then re-keys them by
    ``(group_index, param_index)`` position in ``index2config``, which is the
    mapping ``Optimizer8bit.get_config`` consults.
    """
    _instance = None

    def __init__(self):
        raise RuntimeError('Call get_instance() instead')

    def initialize(self):
        """Reset all recorded overrides."""
        # id(param) -> dict of config overrides
        self.pid2config = {}
        # (group_index, param_index) -> dict of config overrides
        self.index2config = {}
        self.optimizer = None
        self.uses_config_override = False

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating and initializing it on first use."""
        if cls._instance is None:
            # __init__ deliberately raises, so build the object via __new__.
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def register_parameters(self, params):
        """Map id-keyed overrides onto (group_index, param_index) positions.

        Parameters
        ----------
        params : iterable of torch.Tensor, or iterable of dict
            Either a flat iterable of parameters (treated as one group) or an
            iterable of param-group dicts, each with a ``'params'`` entry.
            Positions are taken from the order given here.
        """
        param_groups = list(params)
        if not param_groups:
            # Nothing to register; also avoids IndexError on param_groups[0].
            return
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for group_index, group in enumerate(param_groups):
            for p_index, p in enumerate(group['params']):
                if id(p) in self.pid2config:
                    self.index2config[(group_index, p_index)] = self.pid2config[id(p)]

    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
        '''
        Overrides initial optimizer config for specific parameters.

        The key-values of the optimizer config for the input parameters are overridden.
        These can be optimizer hyperparameters like "betas" or "lr", or 8-bit
        specific parameters like "optim_bits" or "percentile_clipping".

        Parameters
        ----------
        parameters : torch.Tensor or list(torch.Tensors)
            The input parameters.
        key : str
            The hyperparameter to override.
        value : object
            The value for the hyperparameter.
        key_value_dict : dict
            A dictionary with multiple key-values to override.

        Notes
        -----
        Each parameter gets its own copy of the overrides. Later overrides on
        one parameter therefore do not leak into other parameters, and do not
        leak into the caller's ``key_value_dict``.
        '''
        self.uses_config_override = True
        # torch.nn.Parameter is a Tensor subclass, so this covers both.
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        if key is not None and value is not None:
            assert key_value_dict is None
            key_value_dict = {key: value}

        if key_value_dict is not None:
            for p in parameters:
                if id(p) in self.pid2config:
                    self.pid2config[id(p)].update(key_value_dict)
                else:
                    # Copy so parameters never share one mutable dict.
                    self.pid2config[id(p)] = dict(key_value_dict)


class Optimizer8bit(torch.optim.Optimizer):
    """Base class for the bitsandbytes optimizers.

    Subclasses implement :meth:`init_state` and :meth:`update_step`. This
    class provides:

    - the per-parameter step loop;
    - config lookup, including overrides registered with the global manager;
    - a ``load_state_dict`` that leaves quantization-related state entries
      in their stored dtype.
    """

    def __init__(self, params, defaults, optim_bits=32):
        super(Optimizer8bit, self).__init__(params, defaults)
        self.checked_if_on_gpu = False
        # Quantization maps keyed by name ('dynamic' / 'udynamic').
        self.name2qmap = {}

        self.mng = GlobalOptimManager.get_instance()
        # State entries that load_state_dict only moves to the parameter's
        # device; they are never cast to the parameter's dtype.
        self.non_castable_tensor_keys = set(
                ['qmap1', 'qmap2',
                 'max1', 'max2',
                 'new_max1', 'new_max2',
                 'state1', 'state2',
                 'gnorm_vec', 'absmax1', 'absmax2',
                 'unorm_vec'])

        if optim_bits == 8: self.fill_qmap()

    def fill_qmap(self):
        """Build the signed ('dynamic') and unsigned ('udynamic') quantization maps."""
        self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True)
        self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False)

    def __setstate__(self, state):
        super(Optimizer8bit, self).__setstate__(state)

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            ValueError: if the number of parameter groups, or the size of any
                group, differs from this optimizer's.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Map saved parameter ids (positional) to the live parameter objects.
        id_map = {old_id: p for old_id, p in
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}

        def cast(param, value):
            r"""Return ``value`` with its tensors adapted to ``param``.

            - When ``param`` is floating point, any tensor that is not uint8
              is cast to ``param``'s dtype. Its device is left unchanged.
            - Dict entries whose key is in ``non_castable_tensor_keys`` and
              whose value is a tensor are only moved to ``param``'s device.
            - Containers are processed recursively; strings and bytes are
              returned as-is.
            """
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                if param.is_floating_point() and value.dtype != torch.uint8:
                    value = value.to(param.dtype)
                return value
            elif isinstance(value, dict):
                for k, v in value.items():
                    if k in self.non_castable_tensor_keys and isinstance(v, torch.Tensor):
                        value[k] = v.to(param.device)
                    else:
                        value[k] = cast(param, v)
                return value
            elif isinstance(value, (str, bytes)):
                # str/bytes are Iterable; rebuilding them from a generator
                # below would turn them into the generator's repr.
                return value
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})

    def to_gpu(self):
        """Move every tensor in the optimizer state onto its parameter's device.

        Sets ``checked_if_on_gpu``, so :meth:`step` only triggers this once.
        """
        self.checked_if_on_gpu = True
        for group in self.param_groups:
            for p in group['params']:
                if p in self.state:
                    values = self.state[p]
                    for k, v in values.items():
                        if isinstance(v, torch.Tensor):
                            self.state[p][k] = v.to(p.device)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                state = self.state[p]
                # Lazily allocate state the first time a parameter has a gradient.
                if len(state) == 0:
                    self.init_state(group, p, gindex, pindex)

                self.update_step(group, p, gindex, pindex)

        return loss

    def get_config(self, gindex, pindex, group):
        """Return the effective config dict for one parameter.

        Hyperparameters (betas, eps, weight_decay, lr) come from ``group``.
        Bitsandbytes settings come from ``self.args``. Any override registered
        for position ``(gindex, pindex)`` in the global manager's
        ``index2config`` is applied last.
        """
        config = {}
        config['betas'] = group['betas']
        config['eps'] = group['eps']
        config['weight_decay'] = group['weight_decay']
        config['lr'] = group['lr']
        config['optim_bits'] = self.args.optim_bits
        config['min_8bit_size'] = self.args.min_8bit_size
        config['percentile_clipping'] = self.args.percentile_clipping
        config['block_wise'] = self.args.block_wise
        config['max_unorm'] = self.args.max_unorm
        config['skip_zeros'] = self.args.skip_zeros

        if (gindex, pindex) in self.mng.index2config:
            config.update(self.mng.index2config[(gindex, pindex)])
        return config

    def init_state(self, group, p, gindex, pindex):
        raise NotImplementedError(f'init_state method needs to be overidden')

    def update_step(self, group, p, gindex, pindex):
        raise NotImplementedError(f'The update_step method needs to be overidden')

class Optimizer2State(Optimizer8bit):
    """Optimizer with two state buffers per parameter (``state1`` and ``state2``).

    The update rule is chosen by ``optimizer_name``, which is forwarded to the
    ``F.optimizer_update_*`` kernels. State is kept either as 32-bit floats or
    as 8-bit quantized buffers, depending on the resolved config.

    Parameters
    ----------
    optimizer_name : str
        Name forwarded to the ``F.optimizer_update_*`` kernels.
    params : iterable
        Parameters or param-group dicts, as for ``torch.optim.Optimizer``.
    lr, betas, eps, weight_decay :
        Hyperparameters stored as param-group defaults. ``betas`` may also be
        given as a string, which is evaluated with ``eval``.
    optim_bits : int
        32 or 8; bit width of the optimizer state.
    args : object, optional
        Object exposing ``optim_bits``, ``min_8bit_size``,
        ``percentile_clipping``, ``block_wise``, ``max_unorm`` and
        ``skip_zeros``. When given, config lookup reads these attributes
        instead of the corresponding keyword arguments.
    min_8bit_size : int
        Parameters with fewer elements keep 32-bit state.
    percentile_clipping : int
        Values below 100 enable ``F.percentile_clipping`` on gradients.
    block_wise : bool
        Use block-wise absmax tensors (one per 2048 elements) for 8-bit state.
    max_unorm : float
        Values above 0.0 allocate a ``unorm_vec`` and pass ``max_unorm`` to
        the 32-bit and non-blockwise 8-bit kernels.
    skip_zeros : bool
        Forwarded to the 32-bit and block-wise 8-bit kernels.
    """
    def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
            weight_decay=0.0, optim_bits=32, args=None,
            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
            skip_zeros=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if isinstance(betas, str):
            # NOTE(review): eval() executes arbitrary code. betas strings such
            # as "(0.9, 0.999)" must come from a trusted source; consider
            # ast.literal_eval if only literal tuples need to be supported.
            betas = eval(betas)
        for i, beta in enumerate(betas):
            if not 0.0 <= beta < 1.0:
                raise ValueError(f"Invalid beta parameter at index {i}: {beta}")
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super(Optimizer2State, self).__init__(params, defaults, optim_bits)

        if args is None:
            self.args = MockArgs({
                'optim_bits': optim_bits,
                'min_8bit_size': min_8bit_size,
                'percentile_clipping': percentile_clipping,
                'block_wise': block_wise,
                'max_unorm': max_unorm,
                'skip_zeros': skip_zeros,
            })
        else:
            self.args = args

        self.optimizer_name = optimizer_name

    @torch.no_grad()
    def init_state(self, group, p, gindex, pindex):
        """Allocate optimizer state for parameter ``p``.

        8-bit state is used only when ``optim_bits`` is 8 and ``p`` has at
        least ``max(min_8bit_size, 4096)`` elements. Otherwise two float32
        buffers shaped like ``p`` are allocated.

        Raises:
            NotImplementedError: if ``optim_bits`` is neither 32 nor 8.
        """
        config = self.get_config(gindex, pindex, group)

        if config['optim_bits'] == 32:
            dtype = torch.float32
        elif config['optim_bits'] == 8:
            dtype = torch.uint8
        else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')

        if p.numel() < config['min_8bit_size']: dtype = torch.float32

        state = self.state[p]
        state['step'] = 0

        # Tensors under 4096 elements always get 32-bit state, independent
        # of min_8bit_size.
        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
            state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
        elif dtype == torch.uint8:
            # Build the quantization maps lazily and keep them on p's device.
            if 'dynamic' not in self.name2qmap: self.fill_qmap()
            self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
            self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)

            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
            state['qmap1'] = self.name2qmap['dynamic']

            state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
            state['qmap2'] = self.name2qmap['udynamic']

            if config['block_wise']:
                # One absmax entry per 2048-element block (ceil division).
                n = p.numel()
                blocks = n//2048
                blocks += 1 if n % 2048 > 0 else 0

                state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
                state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
            else:
                state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
                state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
                state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
                state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)

        if config['percentile_clipping'] < 100:
            state['gnorm_vec'] = torch.zeros((100,), device=p.device)

        if config['max_unorm'] > 0.0:
            state['unorm_vec'] = torch.zeros((1,), device=p.device)

    @torch.no_grad()
    def update_step(self, group, p, gindex, pindex):
        """Apply one update to ``p`` using its gradient and stored state.

        The kernel is picked from the state dtype: 32-bit, 8-bit
        non-blockwise, or 8-bit block-wise. Increments ``state['step']``.
        """
        state = self.state[p]
        grad = p.grad

        config = self.get_config(gindex, pindex, group)

        state['step'] += 1
        step = state['step']

        if config['percentile_clipping'] < 100:
            _, _, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
        else:
            gnorm_scale = 1.0

        if state['state1'].dtype == torch.float:
            F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                    state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])

        elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
            F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
                          config['eps'],  step, config['lr'],
                          state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'],
                          config['weight_decay'], gnorm_scale=gnorm_scale,
                          unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])

            # swap maxes: the kernel writes into new_max*, which becomes current
            state['max1'], state['new_max1'] = state['new_max1'], state['max1']
            state['max2'], state['new_max2'] = state['new_max2'], state['max2']
        elif state['state1'].dtype == torch.uint8 and config['block_wise']:
            # NOTE(review): max_unorm / unorm_vec are not passed on this path.
            F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
                          config['eps'],  step, config['lr'],
                          state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
                          config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])


class Optimizer1State(Optimizer8bit):
    """Optimizer with a single state buffer per parameter (``state1``).

    The update rule is chosen by ``optimizer_name``, which is forwarded to the
    ``F.optimizer_update_*`` kernels. State is kept either as 32-bit floats or
    as 8-bit quantized buffers, depending on the resolved config.

    Parameters
    ----------
    optimizer_name : str
        Name forwarded to the ``F.optimizer_update_*`` kernels.
    params : iterable
        Parameters or param-group dicts, as for ``torch.optim.Optimizer``.
    lr, betas, eps, weight_decay :
        Hyperparameters stored as param-group defaults. Each beta must be in
        ``[0.0, 1.0)``.
    optim_bits : int
        32 or 8; bit width of the optimizer state.
    args : object, optional
        Object exposing ``optim_bits``, ``min_8bit_size``,
        ``percentile_clipping``, ``block_wise``, ``max_unorm`` and
        ``skip_zeros``. When given, config lookup reads these attributes
        instead of the corresponding keyword arguments.
    min_8bit_size : int
        Parameters with fewer elements keep 32-bit state.
    percentile_clipping : int
        Values below 100 enable ``F.percentile_clipping`` on gradients.
    block_wise : bool
        Use block-wise absmax tensors (one per 2048 elements) for 8-bit state.
    max_unorm : float
        Values above 0.0 allocate a ``unorm_vec`` and pass ``max_unorm`` to
        the 32-bit and non-blockwise 8-bit kernels.
    skip_zeros : bool
        Forwarded to the 32-bit and block-wise 8-bit kernels.
    """
    def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
            weight_decay=0.0, optim_bits=32, args=None,
            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
            skip_zeros=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        for i, beta in enumerate(betas):
            if not 0.0 <= beta < 1.0:
                raise ValueError(f"Invalid beta parameter at index {i}: {beta}")
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super(Optimizer1State, self).__init__(params, defaults, optim_bits)

        if args is None:
            self.args = MockArgs({
                'optim_bits': optim_bits,
                'min_8bit_size': min_8bit_size,
                'percentile_clipping': percentile_clipping,
                'block_wise': block_wise,
                'max_unorm': max_unorm,
                'skip_zeros': skip_zeros,
            })
        else:
            self.args = args

        self.optimizer_name = optimizer_name

    @torch.no_grad()
    def init_state(self, group, p, gindex, pindex):
        """Allocate optimizer state for parameter ``p``.

        8-bit state is used only when ``optim_bits`` is 8 and ``p`` has at
        least ``max(min_8bit_size, 4096)`` elements. Otherwise one float32
        buffer shaped like ``p`` is allocated.

        Raises:
            NotImplementedError: if ``optim_bits`` is neither 32 nor 8.
        """
        config = self.get_config(gindex, pindex, group)

        if config['optim_bits'] == 32:
            dtype = torch.float32
        elif config['optim_bits'] == 8:
            dtype = torch.uint8
        else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')

        if p.numel() < config['min_8bit_size']: dtype = torch.float32

        state = self.state[p]
        state['step'] = 0

        # Tensors under 4096 elements always get 32-bit state, independent
        # of min_8bit_size.
        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
        elif dtype == torch.uint8:
            # Build the quantization map lazily and keep it on p's device.
            if 'dynamic' not in self.name2qmap: self.fill_qmap()
            self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)

            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
            state['qmap1'] = self.name2qmap['dynamic']

            if config['block_wise']:
                # One absmax entry per 2048-element block (ceil division).
                n = p.numel()
                blocks = n//2048
                blocks += 1 if n % 2048 > 0 else 0

                state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
            else:
                state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
                state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)

        if config['percentile_clipping'] < 100:
            state['gnorm_vec'] = torch.zeros((100,), device=p.device)

        if config['max_unorm'] > 0.0:
            state['unorm_vec'] = torch.zeros((1,), device=p.device)


    @torch.no_grad()
    def update_step(self, group, p, gindex, pindex):
        """Apply one update to ``p`` using its gradient and stored state.

        The kernel is picked from the state dtype: 32-bit, 8-bit
        non-blockwise, or 8-bit block-wise. The second-state slots are passed
        as None. Increments ``state['step']``.
        """
        state = self.state[p]
        grad = p.grad

        config = self.get_config(gindex, pindex, group)

        state['step'] += 1
        step = state['step']

        if config['percentile_clipping'] < 100:
            _, _, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
        else:
            gnorm_scale = 1.0

        if state['state1'].dtype == torch.float:
            F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                    None, 0.0, config['weight_decay'], gnorm_scale,
                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
                    skip_zeros=config['skip_zeros'])

        elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
            F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
                    config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None,
                    config['weight_decay'], gnorm_scale=gnorm_scale,
                    unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])

            # swap maxes: the kernel writes into new_max1, which becomes current
            state['max1'], state['new_max1'] = state['new_max1'], state['max1']
        elif state['state1'].dtype == torch.uint8 and config['block_wise']:
            # NOTE(review): max_unorm / unorm_vec are not passed on this path.
            F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
                          config['eps'],  step, config['lr'],
                          state['qmap1'], None, state['absmax1'], None,
                          config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])