{-# LANGUAGE NoMonomorphismRestriction, NoImplicitPrelude, ConstraintKinds,
      TypeFamilies, RebindableSyntax, DeriveFunctor, DeriveFoldable,
      FlexibleInstances, ScopedTypeVariables, GeneralizedNewtypeDeriving #-}
module Knots.Free where

import Knots.Prelude hiding (Rational)

import Control.DeepSeq
import Data.Ratio
import Data.List (intercalate)
import qualified Data.Map as Map
import qualified Data.Set as Set

import Knots.Util

-- | Constraint synonym for types usable as basis elements: only an 'Ord'
-- instance is required, so that they can be used as 'Map' / 'Set' keys.
type Basis b = Ord b

-- | Free vector space with basis @b@ and coefficient ring @r@.
--
-- A vector is stored as a finite map from basis elements to coefficients;
-- basis elements missing from the map have coefficient zero.  Explicit zero
-- entries may still be stored, so the derived 'Eq' is structural.
newtype Free b r = F { unF :: Map b r }
    deriving ( Eq
             , Read
             , Functor
             , Foldable
             , NFData
             )

-- | Lift a transformation of the underlying coefficient map to 'Free'.
liftF :: (Map b r -> Map b' r') -> Free b r -> Free b' r'
liftF f (F m) = F (f m)

-- | Shows a vector as the sum of its stored monomials,
-- @c1.#b1 + c2.#b2 + ...@, in ascending basis order.  A vector without
-- stored entries is shown as @0@.  The sum is parenthesised when the context
-- precedence exceeds 5.
instance (Show b, Show r) => Show (Free b r) where
    showsPrec p x = case monomials x of
        []    -> showString "0"
        terms -> showParen (p > 5) (showString (intercalate " + " (map renderTerm terms)))
      where
        -- One monomial: coefficient at precedence 6, basis element at 7.
        renderTerm (basisElem, coef) =
            showsPrec 6 coef (showString ".#" (showsPrec 7 basisElem ""))

-- | Monomial: @r .# b@ is the vector whose only stored entry is coefficient
-- @r@ at basis element @b@.  No zero check is made, so @zero .# b@ stores an
-- explicit zero.
(.#) :: r -> b -> Free b r
coef .# basisElem = F (Map.singleton basisElem coef)

-- | Relabel basis elements with @f@.
--
-- Coefficients are not combined: if @f@ sends several basis elements to the
-- same one, only the coefficient of the greatest original basis element is
-- kept (this is the behaviour of 'Map.mapKeys').
mapBasis :: (Basis b, Basis c) => (b -> c) -> Free b r -> Free c r
mapBasis f (F m) = F (Map.mapKeys f m)

-- | The default vector is the one with no stored entries.
instance (Basis b) => Default (Free b r) where
    def = F Map.empty

-- | Point-wise addition.
--
-- '+' keeps every key of both operands and does not drop coefficients that
-- cancel.  For a non-empty @x@, @x + negate x@ therefore compares unequal to
-- 'zero' under the structural 'Eq'; use 'isNullVector' to test for zero.
instance (AbelianGroup r, Basis b) => AbelianGroup (Free b r) where
    {-# SPECIALISE instance (Basis b) => AbelianGroup (Free b Rational) #-}
    zero = F Map.empty
    negate (F m) = F (Map.map negate m)
    {-# INLINE negate #-}
    (+) (F lhs) (F rhs) = F (Map.unionWith (+) lhs rhs)
    {-# INLINE (+) #-}

-- | A sort of Kronecker (tensor) product: the monoid-algebra product, where
-- basis elements are multiplied with 'mappend' and coefficients with '*'.
--
-- Several monomial pairs can multiply to the same basis element.  For lists,
-- for example, @[1] <> [2,3] == [1,2] <> [3]@.  Their products are summed
-- rather than one of them being discarded, and coefficients that cancel to
-- zero are dropped (both via 'plus').  'fromInteger' places the constant at
-- the unit basis element 'mempty'.
instance (Basis b, Monoid b, Ring r, Eq r) => Ring (Free b r) where
    fromInteger n = plus [ (mempty, fromInteger n) ]
    -- Previously implemented as @mapBasis (uncurry mappend) (tensor x y)@,
    -- which (through 'Map.mapKeys') lost terms whenever two pairs collided.
    x * y = plus [ (mappend u v, ru * rv)
                 | (u, ru) <- monomials x
                 , (v, rv) <- monomials y
                 ]

-- | Encoding of a linear map from `Free b r' to `Free c r'. The choice of
-- encoding is limited to such linear maps `f' where `f(x)' vanishes for all
-- but finitely many basis elements `x' in `b'.
--
-- Speaking matrix-wise, the intended interpretation is column-major format; use
-- `dual' to convert from row-major format.
--
-- The derived 'Eq' compares the declared domain and codomain as well as the
-- matrix, so encodings that differ only in 'dom' or 'cod' are unequal.
data Lin b c r = Lin
    { dom    :: Set b             -- ^ Domain basis elements (column indices).
    , cod    :: Set c             -- ^ Codomain basis elements (row indices).
    , matrix :: Free b (Free c r) -- ^ Columns: for each domain basis element,
                                  -- its image as a vector in @Free c r@.
    } deriving (Eq,Show,Read,Functor)

-- | The matrix of a linear map with the 'Free' wrappers stripped via 'unF'.
--
-- NOTE(review): `map'i' is defined outside this module.  Judging by the same
-- expression in 'join_domains', the result is presumably a @Map b (Map c r)@
-- keyed first by column (domain) index.  Confirm, then add a type signature.
columns = map'i unF . matrix

-- | Transform the matrix of a linear map, keeping its domain and codomain.
onMatrix :: (Free b (Free c r) -> Free b (Free c r)) -> Lin b c r -> Lin b c r
onMatrix f = \(Lin d c m) -> Lin d c (f m)

-- | Folds over all matrix entries column by column: columns in ascending
-- domain order, and entries within a column in ascending codomain order.
instance Foldable (Lin b c) where
    foldr f z l = Map.foldr foldColumn z (Map.map unF (unF (matrix l)))
      where
        -- Fold one column's entries into the accumulator.
        foldColumn column acc = Map.foldr f acc column

-- | Addition of linear maps.  Matrices are added point-wise; the sum is
-- declared on the union of both domains and on the union of both codomains.
-- 'negate' keeps domain and codomain unchanged.
instance (Basis b, Basis c, AbelianGroup r) => AbelianGroup (Lin b c r) where
    zero = Lin mempty mempty zero
    negate l = fmap negate l
    (+) (Lin d1 c1 m1) (Lin d2 c2 m2) =
        Lin (Set.union d1 d2) (Set.union c1 c2) (m1 + m2)

-- | (Partial) ring instance for certain endomorphisms, implementing tensor
-- product.
--
-- BEWARE: This is the tensor product; it is *not* composition of morphisms
-- (that is 'o').
--
-- 'fromInteger' yields the 1x1 matrix with entry @n@ at @(mempty, mempty)@.
-- In '*', domains and codomains are combined with @prod@ (defined outside
-- this module).  Column indices are combined with 'mappend' through
-- 'mapBasis'.  If two column pairs combine to the same index, 'mapBasis'
-- keeps only one of them, so the product is only meaningful when that cannot
-- happen.
instance (Ring r, Eq r, Monoid b, Basis b) => Ring (Lin b b r) where
    fromInteger n = Lin unitSet unitSet entry
      where
        unitSet = Set.singleton mempty
        entry   = (fromInteger n .# mempty) .# mempty
    Lin dom1 cod1 mx1 * Lin dom2 cod2 mx2 =
        Lin (prod dom1 dom2) (prod cod1 cod2) (mapBasis (uncurry mappend) (tensor mx1 mx2))

-- | Fully evaluates domain, codomain and matrix, in that order.
instance (NFData b, NFData c, NFData r) => NFData (Lin b c r) where
    rnf (Lin d c m) = rnf d `seq` rnf c `seq` rnf m

-- | Build a vector from a list of monomials.  Coefficients of repeated basis
-- elements are summed, and entries whose coefficient equals 'zero' are dropped.
plus :: (Basis b, AbEq r) => [(b, r)] -> Free b r
plus terms = F (Map.filter nonZero (Map.fromListWith (+) terms))
  where
    nonZero coef = coef /= zero

-- | Coefficient of a basis element; 'zero' when it is not stored.
coeff :: (Basis b, AbelianGroup r) => Free b r -> b -> r
coeff (F m) b = case Map.lookup b m of
    Just coef -> coef
    Nothing   -> zero
{-# INLINE coeff #-}

-- | The stored (basis element, coefficient) pairs in ascending basis order.
monomials :: Free b r -> [(b, r)]
monomials v = Map.toList (unF v)

-- | Construct a 'Lin' from domain and codomain bases given as lists (the
-- lists are converted to sets) and a column-major matrix.
lin :: (Basis b, Basis c, AbelianGroup r, Eq r) => [b] -> [c] -> Free b (Free c r) -> Lin b c r
lin from to = Lin (Set.fromList from) (Set.fromList to)

-- | Transpose a nested vector: the coefficient stored at inner index @a@ of
-- the outer entry @b@ becomes the coefficient at inner index @b@ of the outer
-- entry @a@.  The resulting singletons are merged with 'plus'.
transpose :: (Basis b, Basis c, AbEq r) => Free c (Free b r) -> Free b (Free c r)
transpose mx = plus entries
  where
    entries = [ (inner, coef .# outer)
              | (outer, vec)  <- monomials mx
              , (inner, coef) <- monomials vec
              ]

-- | Transposition: swaps domain and codomain and transposes the matrix.
dual :: (Basis a, Basis b, AbEq r) => Lin a b r -> Lin b a r
dual (Lin d c m) = Lin c d (transpose m)

-- | Composition of linear maps: @g `o` f@ applies @f@ first, then @g@.  Each
-- column of @f@ is mapped through the matrix of @g@.
o :: (Basis a, Basis b, Basis c, RingEq r) => Lin b c r -> Lin a b r -> Lin a c r
o g f = Lin (dom f) (cod g) (fmap throughG (matrix f))
  where
    throughG = apply (matrix g)

-- | Matrix-vector multiplication.
--
-- @apply mx x@ is the sum over column indices @a@ of @x_a@ times column @a@
-- of @mx@.  Each product is formed as @x_a * m@, with @m@ the matrix entry.
-- Zero results are dropped by 'plus'.
--
-- Only columns whose index is stored in @x@ can contribute.  Intersecting the
-- two maps looks up each coefficient of @x@ once per column, rather than once
-- per matrix entry, and never walks the other columns.  The order in which
-- terms reach 'plus' is the same as before (ascending column, then row).
apply :: (Basis a, Basis b, RingEq r) => Free a (Free b r) -> Free a r -> Free b r
apply mx x = plus [ (b, r * rb)
                  | (col, r) <- Map.elems (Map.intersectionWith (,) (unF mx) (unF x))
                  , (b, rb)  <- monomials col
                  ]

-- | Outer tensor product. Works also for linear maps, but for endomorphisms,
-- @*@ is preferred.
--
-- The coefficient at @(b, c)@ is @x_b * y_c@.  Products equal to 'zero' are
-- dropped by 'plus'.
tensor :: (Basis b, Basis c, Ring r, Eq r) => Free b r -> Free c r -> Free (b,c) r
tensor x y = plus products
  where
    products = [ ((bx, by), cx * cy)
               | (bx, cx) <- monomials x
               , (by, cy) <- monomials y
               ]

-- | Checks whether the linear map sends everything to zero, i.e. whether
-- every column of its matrix is a null vector (see 'isNullVector').
isNullMatrix :: (AbEq r) => Lin a b r -> Bool
isNullMatrix l = all isNullVector (matrix l)
{-# INLINE isNullMatrix #-}

-- | True when every stored coefficient equals 'zero' (vacuously true with no
-- stored entries).  Unlike comparing with '==' against 'zero', this ignores
-- explicitly stored zeros.
isNullVector :: (AbEq r) => Free b r -> Bool
isNullVector v = all (\coef -> coef == zero) v
{-# INLINE isNullVector #-}

-- | Flatten a nested vector into one indexed by (outer, inner) pairs of basis
-- elements.
join_free :: (Basis b, Basis c, AbelianGroup r) => Free b (Free c r) -> Free (b,c) r
join_free f = sum terms
  where
    terms = [ coef .# (outer, inner)
            | (outer, vec)  <- monomials f
            , (inner, coef) <- monomials vec
            ]

-- | Union of a collection of sets.
--
-- A strict left fold evaluates each intermediate union immediately.  The lazy
-- 'foldl' used previously first built a chain of pending 'Set.union' thunks,
-- one per element.
unions :: (Basis a, Foldable f) => f (Set a) -> Set a
unions = foldl' Set.union Set.empty

-- | Flatten a map of sets into the set of all (key, element) pairs.
mapToPairs :: (Basis k, Basis a) => Map k (Set a) -> Set (k,a)
mapToPairs m = unions (Map.mapWithKey tagWith m)
  where
    -- Pair every element of a set with the given key.
    tagWith key elems = Set.map (\e -> (key, e)) elems

-- | Codomain of 'join_lin'.  It is computed as the 'join_domains' of the map
-- obtained by transposing both the inner maps and the outer matrix.
join_codomains :: (Basis b, Basis b', Basis c, Basis c', AbEq r) => Lin b c (Lin b' c' r) -> Set (c,c')
join_codomains f = join_domains (dual (fmap dual f))

-- | Domain of 'join_lin'.  Each outer column index is paired with every
-- element of the union of the domains of the inner maps stored in that
-- column.
join_domains :: (Basis b, Basis b', Basis c, Basis c', AbEq r)  => Lin b c (Lin b' c' r) -> Set (b,b')
join_domains f = mapToPairs (fmap columnDomain (map'i unF (matrix f)))
  where
    -- Union of the domains of all inner maps in one column.
    columnDomain col = unions (Map.map dom col)

-- | Flatten a linear map whose entries are linear maps into one linear map on
-- pairs.  The inner entry at (column @w@, row @x@) of the outer entry at
-- (column @u@, row @v@) becomes the entry at column @(u,w)@, row @(v,x)@.
join_lin :: (Basis a, Basis b, Basis c, Basis d, AbEq r) => Lin a c (Lin b d r) -> Lin (a,b) (c,d) r
join_lin f = Lin
    { dom    = join_domains f
    , cod    = join_codomains f
    , matrix = sum entries
    }
  where
    entries = [ (coef .# (outerRow, innerRow)) .# (outerCol, innerCol)
              | (outerCol, outerColumn) <- monomials (matrix f)
              , (outerRow, block)       <- monomials outerColumn
              , (innerCol, innerColumn) <- monomials (matrix block)
              , (innerRow, coef)        <- monomials innerColumn
              ]

-- | @seqMatrix mx x@ returns @x@ after reducing every entry of @mx@ to weak
-- head normal form.  The forcing happens when the result is demanded.
seqMatrix :: Free c (Free b r) -> x -> x
seqMatrix mx x = foldl' forceColumn x mx
  where
    -- Force each entry of one column, threading the accumulator through.
    forceColumn acc column = foldl' (\acc' entry -> entry `seq` acc') acc column

-- | Apply `seqMatrix' to the argument before applying the function to it.
-- When @f $!!! mx@ is demanded, every entry of @mx@ is first reduced to weak
-- head normal form.
($!!!) :: (Free c (Free b r) -> x) -> Free c (Free b r) -> x
($!!!) f mx = seqMatrix mx (f mx)

infixr 0 $!!!

-- | Applies elementary column transformations until column echelon form is
-- achieved.
--
-- Rows are the codomain basis elements (@to@) and columns the domain basis
-- elements (@from@); both are visited in the order given by 'toList'.  Null
-- columns are removed from the result, which is what 'rank' relies on.  The
-- declared domain and codomain are returned unchanged.
gauss :: forall b c r. (Basis b, Basis c, Field r, Eq r, NFData b, NFData c, NFData r) => Lin b c r -> Lin b c r
gauss (Lin from to mx) = Lin from to (liftF (Map.filter (not . isNullVector)) $ go (toList to) (toList from) mx) where
    -- @go rows cols a@: @rows@ are the rows still to be examined, and @cols@
    -- the column positions that do not yet hold a pivot.  The matrix is fully
    -- evaluated ('$!!') before every recursive call.
    go :: [c] -> [b] -> Free b (Free c r) -> Free b (Free c r)
    go []     _      a = a
    go _      []     a = a
    go (i:is) (j:js) a =
        -- Columns at position >= j with a non-zero entry in row i:
        case filter (>= j) . findIndices ((/= 0) . (`coeff` i)) $ a of
             []      -> go is (j:js) $!! a -- No pivot in row i: next row, same column position
             (k:ks)  -> -- Use column k as pivot column. We want to have
                        -- zeroes in row i of all other columns ks.
                        let -- Make a ``1'' at a # (k,i) by dividing column k by that entry:
                            b = adjust (fmap (/ a # (k,i))) k a
                            pivot_col = b `coeff` k
                            -- Make a ``0'' at b # (k_,i) by subtracting a multiple of the
                            -- pivot column (`.*' is defined outside this module, presumably
                            -- scalar multiplication); zero entries are then dropped:
                            col_transformation c k_ = adjust (\r -> liftF (Map.filter (/= 0)) $ r - (r `coeff` i) .* pivot_col) k_ c
                        -- Eliminate row i in the columns ks, then swap the pivot column
                        -- into position j and continue with the next row and column.
                        in  go is js $!! exchange j k $!! foldl' col_transformation b ks

-- | Rank of a linear map.
--
-- This is the number of columns left in the column echelon form computed by
-- 'gauss'.  Null columns have already been removed by 'gauss', so every
-- remaining column is a pivot column.
rank :: (Basis b, Basis c, Field r, Eq r, NFData b, NFData c, NFData r) => Lin b c r -> Int
rank f = Map.size (unF (matrix (gauss f)))

-- | Apply a function to the coefficient stored at one basis element; the
-- vector is unchanged if that basis element is not stored.
adjust :: (Basis b) => (r -> r) -> b -> Free b r -> Free b r
adjust f k v = liftF (Map.adjust f k) v
{-# INLINE adjust #-}

-- | Swap the entries stored at basis elements @i@ and @j@.
--
-- Absence is swapped too.  If only one of the two keys is stored, its value
-- moves to the other key and the original key is removed.  If neither is
-- stored, the vector is unchanged.
exchange :: (Basis b) => b -> b -> Free b r -> Free b r
exchange i j (F m) = F (place i atJ (place j atI m))
  where
    atI = Map.lookup i m
    atJ = Map.lookup j m
    -- Store a value at a key, or remove the key when there is no value.
    place k v = Map.alter (\_ -> v) k

-- | Basis elements, in ascending order, whose stored coefficient satisfies the
-- predicate.
findIndices :: (r -> Bool) -> Free b r -> [b]
findIndices p v = Map.keys (Map.filter p (unF v))

-- | Entry lookup in a nested vector.  @a # (i, j)@ is the coefficient at
-- inner index @j@ of the outer entry at index @i@, or 'zero' when either is
-- absent.
(#) :: (Basis b, Basis c, AbelianGroup r) => Free c (Free b r) -> (c,b) -> r
a # (i, j) = coeff (coeff a i) j
{-# INLINE (#) #-}