{-# LANGUAGE NoMonomorphismRestriction, NoImplicitPrelude, ConstraintKinds,
      TypeFamilies, RebindableSyntax, DeriveFunctor, DeriveFoldable #-}
module Knots.Morphism where

import Knots.Prelude hiding (Rational)

import Control.DeepSeq
import qualified Data.Map as Map
import qualified Data.Set as Set

import Knots.Free
import Knots.Graded
import Knots.Util

-- | A graded family of linear maps ('Lin') whose source and target bases are
-- indexed by the same type @b@, with coefficients in @r@.
type Morphism b r = Graded (Lin b b r)

-- | Combine two graded morphisms component-wise with 'o' (presumably
-- composition of the underlying linear maps, left operand applied last).
--
-- The grade of the result is the sum of both grades.  The components of the
-- left operand are re-keyed by subtracting the grade of the right operand and
-- then paired with the right operand's components key by key.  A key that is
-- present on only one side is paired with 'zero' on the other side.
oo :: (Default b, Ord b, RingEq r) => Morphism b r -> Morphism b r -> Morphism b r
oo (Graded g x) (Graded h y) =
    Graded { grade = g + h, components = paired `Map.union` rightOnly `Map.union` leftOnly }
  where
    -- left components, re-indexed so their keys line up with those of y
    leftShifted = Map.mapKeys (subtract h) x
    -- keys present on both sides
    paired      = Map.intersectionWith o leftShifted y
    -- keys only the right operand has
    rightOnly   = fmap (zero `o`) (y Map.\\ leftShifted)
    -- keys only the (re-indexed) left operand has
    leftOnly    = fmap (`o` zero) (leftShifted Map.\\ y)

-- | The two basis labels used to build basis words (see 'degreeB' for their
-- degrees).  Presumably the generators @1@ and @x@ of the underlying algebra.
--
-- The constructor order matters: the derived 'Ord' instance puts 'B1' first.
data B
    = B1
    | Bx
    deriving (Eq,Show,Read,Ord)

-- Both constructors are nullary, so matching on them already evaluates the
-- value completely.
instance NFData B where
    rnf B1 = ()
    rnf Bx = ()

-- | Degree of a single basis label: @1@ for 'B1' and @-1@ for 'Bx'.
degreeB :: B -> Int
degreeB label = case label of
    B1 -> 1
    Bx -> -1

-- | Degree of a basis word: the sum of 'degreeB' over its labels.
degree :: [B] -> Int
degree word = sum [ degreeB label | label <- word ]

-- | 'True' exactly for the label 'B1'.
isB1 :: B -> Bool
isB1 label = case label of
    B1 -> True
    Bx -> False

-- | All words of length @n@ over the labels 'B1' and 'Bx'.
--
-- Words starting with 'B1' come first, followed by those starting with 'Bx';
-- within each half the order is that of @basis (n-1)@.  A negative @n@ is a
-- programming error and raises an exception.
basis :: Int -> [[B]]
basis n
    | n > 0     = prefixWith B1 ++ prefixWith Bx
    | n == 0    = [[]]
    | otherwise = error "basis: negative argument"
  where
    -- computed once and shared by both halves
    shorter           = basis (n-1)
    prefixWith label  = map (label :) shorter

-- | The two words of length one, in the order produced by @basis 1@.
b1, bx  ::  [B]
b1 = [B1]
bx = [Bx]

-- | The four words of length two, in the order produced by @basis 2@.
b11, b1x, bx1, bxx  ::  [B]
b11 = [B1, B1]
b1x = [B1, Bx]
bx1 = [Bx, B1]
bxx = [Bx, Bx]

-- | The eight words of length three, in the order produced by @basis 3@.
b111, b11x, b1x1, b1xx, bx11, bx1x, bxx1, bxxx  ::  [B]
b111 = [B1, B1, B1]
b11x = [B1, B1, Bx]
b1x1 = [B1, Bx, B1]
b1xx = [B1, Bx, Bx]
bx11 = [Bx, B1, B1]
bx1x = [Bx, B1, Bx]
bxx1 = [Bx, Bx, B1]
bxxx = [Bx, Bx, Bx]

-- | The building-block morphisms defined below.  All ten share one signature:
-- graded linear maps on bases made of @[B]@ words, for any coefficient type
-- @r@ satisfying 'RingEq'.
mult, multReducedLeft, multReducedRight,
    comult, comultReducedLeft, comultReducedRight,
    perm, permReduced,
    idA, idB1
    :: RingEq r => Morphism [B] r

-- Multiplication table, built with @graded (-1)@.  Each entry is keyed by the
-- 'degree' of its source words; in every term the outer label is the source
-- word and the inner label is its image.
mult = graded (-1) [fromXX, fromMixed, fromOneOne]
  where
    -- bxx has an empty target basis: the zero map
    fromXX     = (-2, lin [bxx] [] 0)
    -- bx1 |-> bx  and  b1x |-> bx
    fromMixed  = (0,  lin [b1x,bx1] [bx] ( (1 .# bx) .# bx1 + (1 .# bx) .# b1x ))
    -- b11 |-> b1
    fromOneOne = (2,  lin [b11]     [b1] ( (1 .# b1)           .# b11 ))

-- Variant of the multiplication table whose source words all start with B1
-- (b1x and b11), built with @graded (-1)@ and keyed by source 'degree'.
multReducedLeft = graded (-1) [fromOneX, fromOneOne]
  where
    -- b1x has an empty target basis: the zero map
    fromOneX   = (0,   lin [b1x]    []   0)
    -- b11 |-> b1
    fromOneOne = (2,   lin [b11]    [b1] ( (1 .# b1) .# b11 ) )

-- Variant of the multiplication table whose source words all end with B1
-- (bx1 and b11), built with @graded (-1)@ and keyed by source 'degree'.
multReducedRight = graded (-1) [fromXOne, fromOneOne]
  where
    -- bx1 has an empty target basis: the zero map
    fromXOne   = (0,   lin [bx1]    []   0 )
    -- b11 |-> b1
    fromOneOne = (2,   lin [b11]    [b1] ( (1 .# b1) .# b11 ) )

-- Comultiplication table, built with @graded (-1)@.  In every term the outer
-- label is the source word and the inner label is (a summand of) its image.
comult = graded (-1) [fromX, fromOne, ontoOneOne]
  where
    -- bx |-> bxx
    fromX      = (-1, lin [bx] [bxx]     ( (1 .# bxx) .# bx ))
    -- b1 |-> bx1 + b1x
    fromOne    = (1,  lin [b1] [b1x,bx1] ( (1 .# bx1) .# b1 +
                                           (1 .# b1x) .# b1 ))
    -- empty source basis, target b11: the zero map
    ontoOneOne = (3,  lin []   [b11]     0)

-- Variant of the comultiplication table whose target words all start with B1
-- (b1x and b11), built with @graded (-1)@.
comultReducedLeft = graded (-1) [fromOne, ontoOneOne]
  where
    -- b1 |-> b1x
    fromOne    = (1,  lin [b1] [b1x] ( (1 .# b1x) .# b1 ))
    -- empty source basis, target b11: the zero map
    ontoOneOne = (3,  lin []   [b11] 0)

-- Variant of the comultiplication table whose target words all end with B1
-- (bx1 and b11), built with @graded (-1)@.
comultReducedRight = graded (-1) [fromOne, ontoOneOne]
  where
    -- b1 |-> bx1
    fromOne    = (1,  lin [b1] [bx1] ( (1 .# bx1) .# b1 ))
    -- empty source basis, target b11: the zero map
    ontoOneOne = (3,  lin []   [b11] 0)

-- Swap of the two letters of a length-two word, built with @graded 0@ and
-- keyed by the 'degree' of the source words.
perm = graded 0 [onXX, onMixed, onOneOne]
  where
    -- bxx |-> bxx
    onXX     = (-2, lin [bxx] [bxx] ( (1 .# bxx)   .# bxx ))
    -- bx1 |-> b1x  and  b1x |-> bx1
    onMixed  = (0,  lin [b1x,bx1] [b1x,bx1]
                        ( (1 .# b1x)   .# bx1 +
                          (1 .# bx1)   .# b1x))
    -- b11 |-> b11
    onOneOne = (2,  lin [b11] [b11] ( (1 .# b11)   .# b11 ))

-- Variant of the letter swap with source words b1x, b11 and target words
-- bx1, b11, built with @graded 0@.
permReduced = graded 0 [onOneX, onOneOne]
  where
    -- b1x |-> bx1
    onOneX   = (0,  lin [b1x] [bx1] ( (1 .# bx1)   .# b1x ))
    -- b11 |-> b11
    onOneOne = (2,  lin [b11] [b11] ( (1 .# b11)   .# b11 ))

-- Identity on the length-one words, built with @graded 0@.
idA = graded 0 [onX, onOne]
  where
    -- bx |-> bx
    onX   = (-1, lin [bx] [bx] ( (1 .# bx) .# bx ))
    -- b1 |-> b1
    onOne = (1,  lin [b1] [b1] ( (1 .# b1) .# b1 ))

-- Identity on the single word b1, built with @graded 0@.
idB1 = graded 0 [onOne]
  where
    -- b1 |-> b1
    onOne = (1, lin [b1] [b1] ( (1 .# b1) .# b1 ))

-- | Recursively assemble a morphism from @phi@ and @psi@.
--
-- For @k == 0@ the result is @phi@.  For @k >= 1@ the result for @k-1@ is
-- extended by @phi@ on the right (via '*', presumably a tensor product) and
-- then combined, using 'oo', with @psi@ preceded by @k-1@ copies of 'idA'.
-- A negative @k@ is a programming error and raises an exception.
backPermute :: RingEq r => Morphism [B] r -> Morphism [B] r -> Int -> Morphism [B] r
backPermute phi psi k
    | k >= 1    = (previous * phi)   `oo`   (idA ^ steps * psi)
    | k == 0    = phi
    | otherwise = error "backPermute: negative argument"
  where
    steps    = k-1
    -- only evaluated in the recursive branch
    previous = backPermute phi psi steps

-- | Mirror image of 'backPermute'.
--
-- For @k == 0@ the result is @phi@.  For @k >= 1@ the result for @k-1@ is
-- extended by @phi@ on the left (via '*', presumably a tensor product) and
-- then combined, using 'oo', with @psi@ followed by @k-1@ copies of 'idA'.
-- A negative @k@ is a programming error and raises an exception.
forwardPermute :: RingEq r => Morphism [B] r -> Morphism [B] r -> Int -> Morphism [B] r
forwardPermute phi psi k
    | k >= 1    = (phi * previous)   `oo`   (psi * idA ^ steps)
    | k == 0    = phi
    | otherwise = error "forwardPermute: negative argument"
  where
    steps    = k-1
    -- only evaluated in the recursive branch
    previous = forwardPermute phi psi steps


-- | An 'Int'-indexed family of morphisms whose basis elements are pairs of a
-- set of 'Int's and a word of 'B' labels.  The meaning of the outer index and
-- of the set is not fixed in this module (presumably homological position and
-- a resolution vertex -- TODO confirm against the callers).
type Complex' r = Map Int (Morphism (Set Int,[B]) r)

-- | Split a linear map whose coefficients are maps indexed by @i@ into one
-- linear map per index.
--
-- Every (source, target, index) coefficient becomes a single-term map with
-- the same source and target bases.  Terms that share an index are added up.
convert :: (Ord i, Ord b, Ord c, AbelianGroup r) => Lin b c (Map i r) -> Map i (Lin b c r)
convert (Lin from to cs) = Map.fromListWith (+) singleTerms
  where
    singleTerms =
        [ (ix, Lin from to ((scalar .# tgt) .# src))
        | (src, image)   <- monomials cs
        , (tgt, byIndex) <- monomials image
        , (ix, scalar)   <- Map.toList byIndex
        ]

-- | Write a linear map out as a nested list of coefficients, together with
-- the pair @(size of target basis, size of source basis)@.
--
-- The outer list runs over the source basis and each inner list over the
-- target basis.
-- NOTE(review): the nested list is therefore source-major while the reported
-- shape is (|to|, |from|) -- confirm this is the layout consumers expect.
toMatrix :: (Ord a, Ord b, AbelianGroup r) => Lin a b r -> ((Int,Int), [[r]])
toMatrix (Lin from to f) = (shape, entries)
  where
    shape   = (Set.size to, Set.size from)
    entries = [ [ (f `coeff` i) `coeff` j | j <- toList to ]
              | i <- toList from
              ]

-- | Convert every component of a complex with 'toMatrix'.  The components of
-- each morphism are first re-indexed by 'convertMap1' under a single pair key
-- (presumably @(outer index, degree)@ -- TODO confirm against 'convertMap1').
toMatrices :: (AbelianGroup r) => Complex' r -> Map (Int,Int) ((Int,Int), [[r]])
toMatrices cx = fmap toMatrix (convertMap1 (fmap components cx))

-- | 'True' when every morphism in the complex passes 'isNullMorphism'.
isNullComplex' :: (AbEq r) => Complex' r -> Bool
isNullComplex' cx = all isNullMorphism cx

-- | 'True' when every component of the morphism passes 'isNullMatrix'.
isNullMorphism :: (AbEq r) => Morphism b r -> Bool
isNullMorphism phi = all isNullMatrix (components phi)

-- | Dimensions obtained from an incoming table @d1@ and an outgoing table
-- @d2@ whose entries are @(domain size, codomain size, rank)@ triples.
--
-- The keys of @d1@ are first shifted by its grade ('shiftGraded') so that
-- they line up with the keys of @d2@.  The result carries the grade of @d2@.
homology' :: Graded (Int,Int,Int) -> Graded (Int,Int,Int) -> Graded Int
homology' d1 d2 = Graded (grade d2) dims
  where
    incoming = components (shiftGraded d1)
    outgoing = components d2
    dims     = Map.mergeWithKey
                   -- both present: kernel dimension minus incoming rank
                   (\_ (_, _, rkIn) (m, _, rkOut) -> Just (m - rkOut - rkIn))
                   -- incoming only: cokernel dimension
                   (fmap (\(_, n, rk) -> n - rk))
                   -- outgoing only: kernel dimension
                   (fmap (\(m, _, rk) -> m - rk))
                   incoming
                   outgoing

-- | Dimensions at every index of a family of @(domain size, codomain size,
-- rank)@ tables.
--
-- The table stored at index @i@ acts as the incoming one at index @i+1@ and
-- as the outgoing one at index @i@.  Where both exist they are combined with
-- 'homology''; otherwise only the cokernel or kernel dimension is taken.
homology :: Map Int (Graded (Int,Int,Int)) -> Map Int (Graded Int)
homology x = Map.mergeWithKey
                 (\_ d1 d2 -> Just (homology' d1 d2))
                 -- incoming only: cokernel dimension, keys shifted by the grade
                 (\ds -> map2 (\(_, n, rk) -> n - rk) (fmap shiftGraded ds))
                 -- outgoing only: kernel dimension
                 (\ds -> map2 (\(m, _, rk) -> m - rk) ds)
                 incoming
                 x
  where
    -- the same tables, moved one index up
    incoming = Map.mapKeys (1 +) x

-- | Replace every linear map in the complex by the triple
-- @(size of 'dom', size of 'cod', 'rank')@.
computeDims :: (Field r, Eq r, NFData r) => Complex' r -> Map Int (Graded (Int,Int,Int))
computeDims cx =
    map2 (\phi -> ( Set.size (dom phi)
                  , Set.size (cod phi)
                  , rank phi
                  ))
         cx

-- | Re-key the components by adding the grade to every key.  The grade itself
-- is left unchanged.
shiftGraded :: Graded a -> Graded a
shiftGraded (Graded offset byDegree) =
    Graded { grade = offset, components = rekeyed }
  where
    rekeyed = Map.mapKeys (offset +) byDegree