{-# LANGUAGE BangPatterns, TemplateHaskell, TypeFamilies, TypeOperators, FlexibleContexts, UndecidableInstances, DeriveDataTypeable, MultiParamTypeClasses, FunctionalDependencies, FlexibleInstances #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
module Numeric.AD.Internal.Sparse 
    ( Index(..)
    , emptyIndex
    , addToIndex
    , indices
    , Sparse(..)
    , apply
    , vars
    , d, d', ds
    , skeleton
    , spartial
    , partial
    , vgrad
    , vgrad'
    , vgrads
    , Grad(..)
    , Grads(..)
    ) where

import Prelude hiding (lookup)
import Control.Applicative
import Numeric.AD.Internal.Classes
import Control.Comonad.Cofree
import Numeric.AD.Internal.Types
import Data.Data
import Data.Typeable ()
import qualified Data.IntMap as IntMap 
import Data.IntMap (IntMap, mapWithKey, unionWith, findWithDefault, toAscList, singleton, insertWith, lookup)
import Data.Traversable
import Language.Haskell.TH

-- | A multiset of variable positions: each key is a variable position and the
-- associated value counts how many times that position has been recorded
-- (see 'addToIndex' and 'indices').
newtype Index = Index (IntMap Int)

-- | The 'Index' in which no variable position has been recorded.
emptyIndex :: Index
emptyIndex = Index IntMap.empty
{-# INLINE emptyIndex #-}

-- | Record one more occurrence of variable position @k@ in an 'Index'.
-- A position that was not present yet starts with a count of 1.
addToIndex :: Int -> Index -> Index
addToIndex k (Index counts) = Index $ IntMap.insertWith (+) k 1 counts
{-# INLINE addToIndex #-}

-- | Flatten an 'Index' into a list of variable positions in ascending order,
-- repeating every position as many times as it was recorded.
indices :: Index -> [Int]
indices (Index counts) = concat [ replicate n k | (k, n) <- toAscList counts ]
{-# INLINE indices #-}

-- | A value paired with a sparse map from variable position to the
-- (recursively sparse) partial derivative with respect to that variable.
--
-- We only store partials in sorted order, so the map contained in a partial
-- will only contain partials with keys greater than or equal to the key under
-- which that partial was found. This should be key for efficiently computing
-- sparse hessians: there are only (n + k - 1) choose n distinct nth partial
-- derivatives of a function with k inputs.
data Sparse a = Sparse !a (IntMap (Sparse a)) deriving (Show, Data, Typeable)

-- | Drop every entry whose key is below @n@, keeping the keys @>= n@.
dropMap :: Int -> IntMap a -> IntMap a
dropMap n = keepUpper . IntMap.split (n - 1)
    where
        -- 'IntMap.split' excludes its pivot, so pivoting on @n - 1@ retains key @n@.
        keepUpper (_, upper) = upper
{-# INLINE dropMap #-}

-- | @times x n y@ combines two sparse towers for use under variable position @n@.
--
-- The resulting primal is the product of both primals. The resulting partials
-- are the partials of each argument whose key is @>= n@, each scaled by the
-- primal of the /other/ argument, and added with '<+>' wherever both
-- arguments carry the same key.
--
-- NOTE(review): only the other argument's primal is used for scaling, not its
-- partials — confirm this is sufficient for the higher-order derivatives
-- expected from this module.
times :: Num a => Sparse a -> Int -> Sparse a -> Sparse a
times (Sparse pa das) n (Sparse pb dbs) = Sparse (pa * pb) partials
    where
        partials = unionWith (<+>) fromLeft fromRight
        fromLeft = fmap (^* pb) (dropMap n das)
        fromRight = fmap (pa *^) (dropMap n dbs)
{-# INLINE times #-}

-- | Turn every element of a container into a sparse AD variable. Elements are
-- numbered from 0 in traversal order; the variable at position @i@ keeps its
-- value as primal and gets the single partial @i -> lift 1@.
vars :: (Traversable f, Num a) => f a -> f (AD Sparse a)
vars = snd . mapAccumL (\ !i x -> (i + 1, AD (Sparse x (singleton i (lift 1))))) 0
{-# INLINE vars #-}

-- | Run a function over the sparse AD variables built from a container of
-- plain values (see 'vars').
apply :: (Traversable f, Num a) => (f (AD Sparse a) -> b) -> f a -> b
apply f = \xs -> f (vars xs)
{-# INLINE apply #-}

-- | Replace every element of a container by its position in traversal order,
-- counting from 0. The shape of the container is preserved.
skeleton :: Traversable f => f a -> f Int
skeleton = snd . mapAccumL number 0
    where
        -- The counter is forced at every step so no chain of additions builds up.
        number !i _ = (i + 1, i)
{-# INLINE skeleton #-}

-- | Read the first-order partials out of a sparse result. The first argument
-- is only used for its shape: the element at traversal position @i@ is
-- replaced by the primal of the partial stored under key @i@, or by 0 when
-- no such partial is stored.
d :: (Traversable f, Num a) => f b -> AD Sparse a -> f a
d fs (AD (Sparse _ da)) = snd (mapAccumL step 0 fs)
    where
        step !i _ = (i + 1, maybe 0 primal (lookup i da))
{-# INLINE d #-}

-- | Like 'd', but also returns the primal of the result: the first component
-- is the primal, the second the container of first-order partials.
d' :: (Traversable f, Num a) => f a -> AD Sparse a -> (a, f a)
d' fs ad@(AD (Sparse a _)) = (a, d fs ad)
{-# INLINE d' #-}

-- | Build the tower of all partial derivatives as a 'Cofree' structure shaped
-- like the first argument. The root holds the primal; following the children
-- at positions @i1, i2, ...@ leads to the node holding
-- @partial (indices ix)@, where @ix@ is the 'Index' collecting those positions.
ds :: (Traversable f, Num a) => f b -> AD Sparse a -> Cofree f a
ds fs (AD tower@(Sparse p _)) = p :< fmap (descend emptyIndex) slots
    where
        -- Positions of the inputs, reused unchanged at every level.
        slots = skeleton fs
        -- descend :: Index -> Int -> Cofree f a
        descend seen i = partial (indices seen') tower :< fmap (descend seen') slots
            where seen' = addToIndex i seen
{-# INLINE ds #-}

{-
vvars :: Num a => Vector a -> Vector (AD Sparse a)
vvars = Vector.imap (\n a -> AD $ Sparse a $ singleton n $ lift 1)
{-# INLINE vvars #-}

vapply :: Num a => (Vector (AD Sparse a) -> b) -> Vector a -> b
vapply f = f . vvars 
{-# INLINE vapply #-}


vd :: Num a => Int -> AD Sparse a -> Vector a
vd n (AD (Sparse _ da)) = Vector.generate n $ \i -> maybe 0 primal $ lookup i da
{-# INLINE vd #-}

vd' :: Num a => Int -> AD Sparse a -> (a, Vector a)
vd' n (AD (Sparse a da)) = (a , Vector.generate n $ \i -> maybe 0 primal $ lookup i da)
{-# INLINE vd' #-}

vds :: Num a => Int -> AD Sparse a -> Cofree Vector a
vds n (AD as@(Sparse a _)) = a :< Vector.generate n (go emptyIndex)
    where
        go ix i = partial (indices ix') as :< Vector.generate n (go ix')
            where ix' = addToIndex i ix
{-# INLINE vds #-}
-}

-- | Follow a path of variable positions through a sparse tower and return the
-- primal found at the end. A position with no stored partial is treated as
-- @lift 0@.
partial :: Num a => [Int] -> Sparse a -> a
partial path (Sparse a da) = case path of
    []       -> a
    (n : ns) -> partial ns (findWithDefault (lift 0) n da)
{-# INLINE partial #-}

-- | Like 'partial', but returns 'Nothing' as soon as a position on the path
-- has no stored partial instead of substituting a default.
spartial :: Num a => [Int] -> Sparse a -> Maybe a
spartial [] (Sparse a _) = Just a
spartial (n:ns) (Sparse _ da) = lookup n da >>= spartial ns
{-# INLINE spartial #-}



-- | The primal of a sparse tower is the value stored at its root; the map of
-- partials is ignored.
instance Primal Sparse where
    primal (Sparse a _) = a

-- | Element-wise arithmetic on sparse towers.
instance Lifted Sparse => Mode Sparse where
    -- A lifted constant carries no partials.
    lift c = Sparse c IntMap.empty
    -- Addition adds the primals and merges the partials key by key.
    Sparse x dx <+> Sparse y dy = Sparse (x + y) (unionWith (<+>) dx dy)
    -- Scaling by a scalar (on either side) is applied to the primal and,
    -- recursively, to every stored partial.
    Sparse x dx ^* k = Sparse (x * k) (IntMap.map (^* k) dx)
    k *^ Sparse x dx = Sparse (k * x) (IntMap.map (k *^) dx)
    -- Division by a scalar, likewise applied recursively.
    Sparse x dx ^/ k = Sparse (x / k) (IntMap.map (^/ k) dx)

-- | Lifting of unary and binary functions to sparse towers. In every method
-- the new primal is the function applied to the argument primal(s), and each
-- stored partial of an argument is combined with the supplied derivative via
-- 'times', passing the partial's key as the position.
instance Lifted Sparse => Jacobian Sparse where
    type D Sparse = Sparse
    -- The derivative @dadb@ is supplied directly by the caller.
    unary f dadb (Sparse pb bs) = Sparse (f pb) $ mapWithKey (times dadb) bs
    -- The derivative is computed from the argument by @df@.
    lift1 f df b@(Sparse pb bs) = Sparse (f pb) $ mapWithKey (times (df b)) bs
    -- The derivative function also receives the result @a@ itself; @a@ is
    -- defined in terms of itself, so this relies on laziness.
    lift1_ f df b@(Sparse pb bs) = a where
        a = Sparse (f pb) $ mapWithKey (times (df a b)) bs

    -- Both derivatives are supplied directly; the contributions of the two
    -- arguments are summed with '<+>' where they share a key.
    binary f dadb dadc (Sparse pb db) (Sparse pc dc) = Sparse (f pb pc) $ 
        unionWith (<+>) 
            (mapWithKey (times dadb) db)
            (mapWithKey (times dadc) dc)

    -- Both derivatives are computed from the two arguments by @df@.
    lift2 f df b@(Sparse pb db) c@(Sparse pc dc) = Sparse (f pb pc) da where
        (dadb, dadc) = df b c
        da = unionWith (<+>) 
            (mapWithKey (times dadb) db)
            (mapWithKey (times dadc) dc)
        
    -- As 'lift2', but @df@ also receives the result @a@; @a@, @da@ and the
    -- derivative pair are mutually recursive, so this relies on laziness.
    lift2_ f df b@(Sparse pb db) c@(Sparse pc dc) = a where
        (dadb, dadc) = df a b c
        a = Sparse (f pb pc) da
        da = unionWith (<+>) 
            (mapWithKey (times dadb) db)
            (mapWithKey (times dadc) dc)

-- Top-level Template Haskell splice for 'Sparse'. 'deriveLifted' is defined
-- outside this module; presumably it generates the @Lifted Sparse@ instance
-- that the 'Mode' and 'Jacobian' instance contexts require — TODO confirm.
deriveLifted id $ conT ''Sparse


-- | Plumbing for the variadic @vgrad@ and @vgrad'@. @i@ is the type of the
-- function being differentiated, @o@ the type produced through 'unpack',
-- @o'@ the type produced through 'unpack'' and @a@ the scalar type. The
-- functional dependencies let any one of @i@, @o@ or @o'@ determine the
-- other parameters.
class Num a => Grad i o o' a | i -> a o o', o -> a i o', o' -> a i o where
    -- | Apply @i@ to the variables in the list, one argument at a time.
    pack :: i -> [AD Sparse a] -> AD Sparse a
    -- | Turn a list-to-gradient function into the result type @o@.
    unpack :: ([a] -> [a]) -> o
    -- | Turn a list-to-(value, gradient) function into the result type @o'@.
    unpack' :: ([a] -> (a, [a])) -> o'

-- | Base case: no arguments are left to consume. The remaining variable list
-- is ignored and the list-based functions are run on the empty list.
instance Num a => Grad (AD Sparse a) [a] (a, [a]) a where
    pack = const
    unpack = ($ [])
    unpack' = ($ [])

-- | Inductive case: consume one argument. 'pack' feeds the head of the
-- variable list to the function; 'unpack' and 'unpack'' prepend the newly
-- received value to the list seen by the list-based function.
instance Grad i o o' a => Grad (AD Sparse a -> i) (a -> o) (a -> o') a where
    pack f vs = case vs of
        (x : xs) -> pack (f x) xs
        []       -> error "Grad.pack: logic error"
    unpack k x = unpack (\rest -> k (x : rest))
    unpack' k x = unpack' (\rest -> k (x : rest))

-- | Variadic gradient: collects the arguments through 'unpack', runs the
-- packed function on the corresponding sparse variables and reads off the
-- first-order partials with 'd'.
vgrad :: Grad i o o' a => i -> o
vgrad i = unpack $ \as -> d as (apply (pack i) as)
{-# INLINE vgrad #-}

-- | Variadic value-and-gradient: like 'vgrad', but goes through 'unpack'' and
-- 'd'' so that the primal result is returned alongside the partials.
vgrad' :: Grad i o o' a => i -> o'
vgrad' i = unpack' $ \as -> d' as (apply (pack i) as)
{-# INLINE vgrad' #-}

-- | Plumbing for the variadic @vgrads@. @i@ is the type of the function being
-- differentiated, @o@ the type produced through 'unpacks' and @a@ the scalar
-- type. The functional dependencies let either of @i@ or @o@ determine the
-- other parameters.
class Num a => Grads i o a | i -> a o, o -> a i where
    -- | Apply @i@ to the variables in the list, one argument at a time.
    packs :: i -> [AD Sparse a] -> AD Sparse a
    -- | Turn a list-to-derivative-tower function into the result type @o@.
    unpacks :: ([a] -> Cofree [] a) -> o

-- | Base case: no arguments are left to consume. The remaining variable list
-- is ignored and the list-based function is run on the empty list.
instance Num a => Grads (AD Sparse a) (Cofree [] a) a where
    packs = const
    unpacks = ($ [])

-- | Inductive case: consume one argument. 'packs' feeds the head of the
-- variable list to the function; 'unpacks' prepends the newly received value
-- to the list seen by the list-based function.
instance Grads i o a => Grads (AD Sparse a -> i) (a -> o) a where
    packs f (a:as) = packs (f a) as
    -- Reached only when fewer variables are supplied than the function has arguments.
    packs _ [] = error "Grads.packs: logic error"
    unpacks f a = unpacks (f . (a:))

-- | Variadic derivative tower: collects the arguments through 'unpacks', runs
-- the packed function on the corresponding sparse variables and builds the
-- tower of partials with 'ds'.
vgrads :: Grads i o a => i -> o
vgrads i = unpacks $ \as -> ds as (apply (packs i) as)
{-# INLINE vgrads #-}