from sage.structure.sage_object import SageObject
from sage.combinat.root_system.cartan_type import CartanType
from sage.data_structures.blas_dict import add as blas_dict_add
from sage.data_structures.blas_dict import axpy
from sage.misc.cachefunc import cached_method
from sage.misc.latex import latex

class Monomial(SageObject):
    """
    A monomial of a `q`-character.

    The monomial is stored in ``self._data`` as a ``dict`` mapping a pair
    ``(i, k)``, standing for the variable `Y_{i,q^k}`, to its exponent.
    Zero exponents are never stored, so that equal monomials compare (and
    hash) equal.
    """
    def __init__(self, d):
        # Drop zero exponents: a stored `Y_{i,q^k}^0` would make equal
        #   monomials compare unequal, and the q-string decomposition loop
        #   in ``FM_algorithm`` never removes a power whose exponent is 0.
        self._data = {key: e for key, e in dict(d).items() if e != 0}

    def __hash__(self):
        return hash(tuple(sorted(self._data.items())))

    def __eq__(self, other):
        return type(self) == type(other) and self._data == other._data

    def __ne__(self, other):
        return not (self == other)

    def __mul__(self, other):
        """
        Return the product of ``self`` and ``other`` as a new monomial.

        Neither factor is modified.
        """
        # Copy the monomial with more variables and fold the other one in.
        big, small = self, other
        if len(big._data) < len(small._data):
            big, small = small, big
        ret = Monomial(big._data)
        for (i, k), e in small._data.items():
            ret.mult_Y_iq(i, k, e)
        return ret

    def _repr_(self):
        if not self._data:
            return '1'
        def exp_repr(e):
            if e == 1:
                return ''
            return '^%s' % e
        return '*'.join('Y{}[{}]'.format(i, qpow) + exp_repr(self._data[i,qpow])
                        for i,qpow in sorted(self._data))

    def _latex_(self):
        if not self._data:
            return '1'
        def exp_repr(e):
            if e == 1:
                return ''
            return '^{%s}' % e
        def q_exp(e):
            if e == 0:
                return '1'
            if e == 1:
                return 'q'
            return 'q^{{{}}}'.format(e)
        return ' '.join('Y_{{{},{}}}'.format(i, q_exp(qpow)) + exp_repr(self._data[i,qpow])
                        for i,qpow in sorted(self._data))

    def Y_iq_powers(self, i):
        """
        Return a ``dict`` whose keys are the powers of `q` for the variables
        `Y_{i,q^k}` occurring in ``self`` and the values are the exponents.

        The returned ``dict`` is a fresh copy; mutating it does not affect
        ``self``.
        """
        return {k[1]: self._data[k] for k in self._data if k[0] == i}

    def mult_Y_iq(self, i, k, e):
        """
        Multiply (and mutate) ``self`` by `Y_{i,q^k}^e`.

        If the resulting exponent of `Y_{i,q^k}` is zero, the variable is
        removed from ``self`` (this includes the case ``e == 0`` when the
        variable was absent, which previously stored a zero exponent).
        """
        key = (i, k)
        new_exp = self._data.get(key, 0) + e
        if new_exp != 0:
            self._data[key] = new_exp
        else:
            self._data.pop(key, None)

    def is_dominant(self):
        """
        Return whether every exponent of ``self`` is nonnegative.
        """
        return all(k >= 0 for k in self._data.values())

class qCharacter(SageObject):
    """
    A `q`-character.

    Stored in ``self._poly_dict`` as a ``dict`` mapping each
    :class:`Monomial` to its coefficient.  Zero coefficients are discarded
    on construction, so results of ``+``, ``-`` and ``*`` that represent
    the same polynomial compare equal.
    """
    def __init__(self, d):
        # ``dict(d)`` also accepts an iterable of ``(monomial, coeff)`` pairs.
        #   Insertion order is preserved, which fixes the printed term order.
        self._poly_dict = {m: c for m, c in dict(d).items() if c != 0}

    def _repr_(self):
        if not self._poly_dict:
            return '0'
        def coeff(c):
            if c == 1:
                return ''
            return repr(c) + '*'
        # The trivial monomial prints as its coefficient alone.
        return " + ".join(coeff(c) + repr(m) if m._data else repr(c)
                          for m, c in self._poly_dict.items())

    def _latex_(self):
        if not self._poly_dict:
            return '0'
        def coeff(c):
            if c == 1:
                return ''
            return latex(c)
        return " + ".join(coeff(c) + latex(m) if m._data else latex(c)
                          for m, c in self._poly_dict.items())

    def __eq__(self, other):
        return type(self) == type(other) and self._poly_dict == other._poly_dict

    def __ne__(self, other):
        return not (self == other)

    def __len__(self):
        """
        Return the number of monomials of ``self`` counted with
        multiplicity, i.e., the sum of all coefficients.
        """
        return sum(self._poly_dict.values())

    def __getitem__(self, m):
        """
        Return the coefficient of the monomial ``m`` in ``self`` (``0`` if
        ``m`` does not occur).
        """
        return self._poly_dict.get(m, 0)

    def __add__(self, other):
        return qCharacter(blas_dict_add(self._poly_dict, other._poly_dict))

    def __sub__(self, other):
        return qCharacter(axpy(-1, other._poly_dict, self._poly_dict))

    def __mul__(self, other):
        """
        Return the product of ``self`` and ``other``.

        Coefficients of coinciding products are summed; terms whose
        coefficients cancel to zero are dropped by the constructor.
        """
        coeffs = {}
        for m, c in self._poly_dict.items():
            for mp, cp in other._poly_dict.items():
                mpp = m * mp
                coeffs[mpp] = coeffs.get(mpp, 0) + c * cp
        return qCharacter(coeffs)

    @cached_method
    def dominant_monomials(self):
        """
        Return a tuple of the dominant monomials of ``self`` (each listed
        once, in the storage order of ``self``).
        """
        return tuple([m for m in self._poly_dict if m.is_dominant()])

def FM_algorithm(ct, initial):
    """
    Return the `q`-character with highest monomial given by ``initial``,
    for the Cartan type ``ct``, computed using the Frenkel-Mukhin (FM)
    algorithm.

    INPUT:

    - ``ct`` -- a Cartan type (anything accepted by ``CartanType``)
    - ``initial`` -- the initial monomial data as a ``dict`` whose keys
      are pairs `(i, k)` corresponding to `Y_{i,q^k}` and the value is
      the corresponding exponent

    OUTPUT:

    A :class:`qCharacter`.

    If the algorithm has to perform an `i`-expansion of a monomial that
    is not `i`-dominant, an ``AssertionError`` ("FM algorithm failed at
    ...") is raised; note that this check is skipped when Python runs
    with ``-O``.

    EXAMPLES:

    We create the `q`-character whose dominant monomial
    is `Y_{1,q^2} Y_{2,q^{-1}}` in type `A_2`::

        sage: FM_algorithm(['A',2], {(1,2):1, (2,-1):1})
        Y1[0]*Y2[1]^-1*Y2[5]^-1 + Y2[-1]*Y2[5]^-1 + Y1[4]^-1*Y2[-1]*Y2[3]
         + Y1[2]*Y2[-1] + Y1[2]^-1*Y1[4]^-1*Y2[3] + Y1[0]*Y1[4]^-1*Y2[1]^-1*Y2[3]
         + Y1[0]*Y1[2]*Y2[1]^-1 + Y1[2]^-1*Y2[5]^-1

    Next, we compute the `q`-character of `Y_{2,q^{-1}}` in type `C_2`::

        sage: FM_algorithm(['C',2], {(2,-1):1})
        Y1[0]*Y1[4]^-1 + Y2[-1] + Y2[5]^-1 + Y1[0]*Y1[2]*Y2[3]^-1
         + Y1[2]^-1*Y1[4]^-1*Y2[1]

    Now `Y_{2,q^{-1}} Y_{2,q^1}`::

        sage: FM_algorithm(['C',2], {(2,-1):1, (2,1):1})
        Y1[2]*Y1[6]^-1*Y2[-1] + Y1[0]*Y1[4]^-1*Y2[1] + Y1[0]*Y1[2]*Y2[3]^-1*Y2[7]^-1
         + Y2[-1]*Y2[1] + Y1[2]*Y1[6]^-1*Y2[5]^-1 + Y1[0]*Y1[2]^2*Y1[6]^-1*Y2[3]^-1
         + Y1[2]*Y1[4]*Y2[-1]*Y2[5]^-1 + 2*Y1[0]*Y1[2]*Y1[4]^-1*Y1[6]^-1
         + Y1[2]^-1*Y1[4]^-2*Y1[6]^-1*Y2[1]*Y2[3] + Y1[4]^-1*Y1[6]^-1*Y2[3]*Y2[5]^-1
         + Y1[2]^-1*Y1[4]^-1*Y2[1]*Y2[7]^-1 + Y1[0]*Y1[2]^2*Y1[4]*Y2[3]^-1*Y2[5]^-1
         + Y1[0]*Y1[4]^-1*Y2[7]^-1 + Y1[0]*Y1[2]*Y2[5]^-1 + Y1[2]^-1*Y1[4]^-1*Y2[1]^2
         + Y1[4]^-1*Y1[6]^-1*Y2[1] + Y1[0]*Y1[2]*Y2[1]*Y2[3]^-1
         + Y1[4]^-1*Y1[6]^-1*Y2[-1]*Y2[3] + Y2[-1]*Y2[7]^-1 + Y2[5]^-1*Y2[7]^-1
         + 2*Y2[1]*Y2[5]^-1 + Y1[2]*Y1[4]*Y2[5]^-2 + Y1[0]*Y1[4]^-2*Y1[6]^-1*Y2[3]

    Finally, `Y_{2,q^{-1}} Y_{2,q^3}`, which gives the KR module `W^{2,2}`::

        sage: FM_algorithm(['C',2], {(2,-1):1, (2,3):1})
        Y1[2]^-1*Y1[4]^-1*Y2[1]*Y2[9]^-1 + Y1[4]*Y1[8]^-1*Y2[-1]
         + Y1[2]^-1*Y1[4]^-1*Y1[6]^-1*Y1[8]^-1*Y2[1]*Y2[5]
         + Y1[0]*Y1[2]*Y1[4]*Y1[8]^-1*Y2[3]^-1 + Y2[5]^-1*Y2[9]^-1
         + Y1[0]*Y1[4]^-1*Y1[6]^-1*Y1[8]^-1*Y2[5]
         + Y1[0]*Y1[2]*Y1[4]*Y1[6]*Y2[3]^-1*Y2[7]^-1 + Y1[0]*Y1[2]*Y2[3]^-1*Y2[9]^-1
         + Y2[-1]*Y2[9]^-1 + Y2[-1]*Y2[3] + Y1[0]*Y1[4]^-1*Y2[9]^-1
         + Y1[6]^-1*Y1[8]^-1*Y2[-1]*Y2[5] + Y1[4]*Y1[6]*Y2[-1]*Y2[7]^-1
         + Y1[0]*Y1[2]*Y1[6]^-1*Y1[8]^-1*Y2[3]^-1*Y2[5]
    """
    from itertools import product
    ct = CartanType(ct)
    I = ct.index_set()
    d = ct.symmetrizer()
    m = Monomial(initial)
    # ``ret`` maps every monomial found so far to its coefficient.
    ret = {m: 1}
    # ``coloring[m][i]`` is the multiplicity contributed to ``m`` so far by
    #   i-expansions of other monomials (the i-coloring of ``m``).
    coloring = {m: {i: 0 for i in I}}
    CM = ct.cartan_matrix()
    # adjacent[i] maps each j != i with a nonzero Cartan entry C_{ji}
    #   (row of j, column of i) to that entry.
    adjacent = {i: {j: CM[ij,ii] for ij,j in enumerate(I) if i != j and CM[ij,ii]}
                for ii,i in enumerate(I)}
    # We go through weights (and hence monomials) by depth
    # After popping the current level, next[0] holds the monomials one level
    #   deeper, next[1] two levels deeper, and so on.
    next = [[m]]
    num = -1  # depth of ``cur``; incremented but not otherwise used
    while next:
        cur = next.pop(0)
        num += 1
        for m in cur:
            for i in I:
                # There is nothing to check/do for i and m such that
                #   the i-color is equal to the coefficient.
                if coloring[m][i] == ret[m]:
                    #print("coloring equals coefficient {} and {}".format(m, i))
                    continue

                # Get the powers k of q for the variables Y_{i,q^k}
                powers = m.Y_iq_powers(i)

                # Check for failure of FM algorithm by verifying m is i-dominant
                assert all(k >= 0 for k in powers.values()), "FM algorithm failed at %s" % m

                ##### Perform the i-expansion #####
                #print("Performing {}-expansion of {}".format(i,m))

                # Compute the q-strings in general position to determine
                #   which sl_2 representations m corresponds to.
                # Each string is built greedily from the smallest remaining
                #   power in steps of 2*d_i; ``top_q`` records its largest
                #   power and ``q_str_len`` its length.  ``powers`` is
                #   consumed in the process.
                di = d[i]
                adj = adjacent[i]
                top_q = []
                q_str_len = []
                while powers:
                    cur_pow = min(powers)
                    cur_len = 0
                    while cur_pow in powers:
                        if powers[cur_pow] == 1:
                            del powers[cur_pow]
                        else:
                            powers[cur_pow] -= 1
                        cur_len += 1
                        cur_pow += 2 * di
                    top_q.append(cur_pow - 2*di)
                    q_str_len.append(cur_len)
                #print("q-strings: {}, {}".format(top_q, q_str_len))

                # Build the i-string and add in the corresponding variables
                # ``t`` is the part of the coefficient of m not yet explained
                #   by i-colorings; the whole i-string is added t times.
                t = (ret[m] - coloring[m][i])
                # exponent[ind] says how many times string ``ind`` is lowered;
                #   ``depth`` is the total number of A^{-1} factors, i.e. how
                #   many levels below m the new monomial lies.
                for exponent in product(*[range(ell+1) for ell in q_str_len]):
                    depth = sum(exponent)
                    if depth == 0: # Nothing to do for this monomial
                        continue

                    while depth > len(next):
                        next.append([])

                    # Compute the new monomial
                    mp = Monomial(m._data)
                    for ind, e in enumerate(exponent):
                        # Multiply by prod_{k=0}^{e-1} A_{i,a_k}^{-1}, where a_k
                        #   is the q-power top_q[ind] + d_i*(1 - 2k)
                        for k in range(e):
                            a = top_q[ind] + di*(1 - 2*k)
                            mp.mult_Y_iq(i, a-di, -1)
                            mp.mult_Y_iq(i, a+di, -1)
                            # Neighbour factors of A_{i,a}^{-1}, depending on C_{ji}
                            for j in adj:
                                if adj[j] == -1:
                                    mp.mult_Y_iq(j, a, 1)
                                elif adj[j] == -2:
                                    mp.mult_Y_iq(j, a-1, 1)
                                    mp.mult_Y_iq(j, a+1, 1)
                                elif adj[j] == -3:
                                    mp.mult_Y_iq(j, a-2, 1)
                                    mp.mult_Y_iq(j, a, 1)
                                    mp.mult_Y_iq(j, a+2, 1)

                    # Now add it to the colored polynomial
                    # A newly found monomial gets i-color t and is queued at
                    #   its depth; a known one has its i-color raised, and its
                    #   coefficient becomes the maximum of coefficient and color.
                    if mp not in ret:
                        coloring[mp] = {i: 0 for i in I}
                        coloring[mp][i] = t
                        ret[mp] = t
                        next[depth-1].append(mp)
                    else:
                        coloring[mp][i] += t
                        if coloring[mp][i] > ret[mp]:
                            ret[mp] = coloring[mp][i]

                #print(qCharacter(ret))

    return qCharacter(ret)

def product_dominant_monomials(V1, V2):
    """
    Return a list of the dominant monomials of the product ``V1 * V2``.

    For every monomial ``m1`` of ``V1`` and ``m2`` of ``V2`` whose product
    is dominant, the product is listed as many times as the product of
    their coefficients (so a non-positive coefficient product contributes
    nothing).  Products coming from different pairs are listed separately,
    even when they coincide.
    """
    dominant = []
    for m1, c1 in V1._poly_dict.items():
        for m2, c2 in V2._poly_dict.items():
            mono = m1 * m2
            if not mono.is_dominant():
                continue
            dominant += [mono] * (c1 * c2)
    return dominant

def is_product_simple(V1, V2):
    """
    Return whether the product ``V1 * V2`` has at most one dominant
    monomial, counted with multiplicity.

    A dominant product ``m * mp`` of a monomial ``m`` of ``V1`` and ``mp``
    of ``V2`` counts as many times as the product of their coefficients
    (when positive), exactly as in ``product_dominant_monomials``.  Hence
    the result agrees with ``len(product_dominant_monomials(V1, V2)) <= 1``,
    but the scan stops as soon as a second dominant monomial is found.

    NOTE(review): presumably used as the standard sufficient criterion for
    the tensor product to be simple (unique dominant monomial); ``False``
    does not by itself prove the product is not simple.
    """
    count = 0
    for m, c in V1._poly_dict.items():
        for mp, cp in V2._poly_dict.items():
            mpp = m * mp
            if not mpp.is_dominant():
                continue
            # Count multiplicities: a single pair with coefficient product
            #   2 already gives two dominant monomials.
            mult = c * cp
            if mult > 0:
                count += mult
                if count > 1:
                    return False
    return True

