-- | Sparse forward-mode automatic differentiation.
--
-- A 'Sparse' value pairs a primal value with an 'IntMap' from variable index
-- to the derivative tower with respect to that variable; absent entries are
-- treated as zero.  'vgrad', 'vgrad'' and 'vgrads' differentiate curried
-- functions of any arity via the 'Grad' and 'Grads' classes.
module Numeric.AD.Internal.Sparse 
    ( Index(..)
    , emptyIndex
    , addToIndex
    , indices
    , Sparse(..)
    , apply
    , vars
    , d, d', ds
    , skeleton
    , spartial
    , partial
    , vgrad
    , vgrad'
    , vgrads
    , Grad(..)
    , Grads(..)
    ) where
import Prelude hiding (lookup)
import Control.Applicative
import Numeric.AD.Internal.Classes
import Numeric.AD.Internal.Stream
import Numeric.AD.Internal.Types
import Data.Data
import Data.Typeable ()
import qualified Data.IntMap as IntMap 
import Data.IntMap (IntMap, mapWithKey, unionWith, findWithDefault, toAscList, singleton, insertWith, lookup)
import Data.Traversable
import Language.Haskell.TH
-- | A multiset of variable indices, stored as a map from each index to its
-- multiplicity.  'ds' uses it to track which variables have been
-- differentiated so far.
newtype Index = Index (IntMap Int)
-- | The multiset containing no variable indices.
emptyIndex :: Index
emptyIndex = Index $ IntMap.fromList []
-- | Record one more occurrence of variable index @k@.
addToIndex :: Int -> Index -> Index
addToIndex k (Index counts) = Index $ IntMap.alter (Just . maybe 1 (+ 1)) k counts
-- | Expand the multiset into an ascending list of indices, repeating each
-- index as many times as it occurs.
indices :: Index -> [Int]
indices (Index counts) = [ k | (k, c) <- toAscList counts, _ <- [1 .. c] ]
-- | A value together with a sparse tower of its partial derivatives.
--
-- @Sparse x m@ holds the primal value @x@ and a map from variable index @n@
-- to the derivative tower of this value with respect to variable @n@.
-- A missing key means the corresponding derivative is zero.
data Sparse a = Sparse a (IntMap (Sparse a)) deriving (Show, Data, Typeable)
-- | Keep only the entries whose key is at least @n@.
--
-- 'IntMap.split' @k@ returns the entries strictly below and strictly above
-- @k@, so splitting at @n - 1@ and keeping the upper part retains every key
-- @>= n@.  @n == minBound@ is handled separately because @minBound - 1@
-- would wrap around to 'maxBound' and drop everything.
dropMap :: Int -> IntMap a -> IntMap a
dropMap n
    | n == minBound = id
    | otherwise     = snd . IntMap.split (n - 1)
-- | Multiply two derivative towers, producing the tower that a chain-rule
-- step stores under variable index @n@.
--
-- The product rule is applied recursively at every level.  The entry for
-- variable @k@ of @l * r@ is @dl_k * r + l * dr_k@, and both of those
-- products are again full tower products.  Only keys @>= n@ are kept
-- ('dropMap'), which matches the ascending index paths walked by 'partial'
-- and 'ds'.
--
-- Why full products: if each side were only scaled by the other side's
-- primal (as '^*' / '*^' do), the cross terms would be lost from the third
-- derivative onward.  For example, the third derivative of @x * x * x@
-- would come out as 2 instead of 6.  First- and second-order entries are
-- the same either way.
times :: Num a => Sparse a -> Int -> Sparse a -> Sparse a
times l@(Sparse a as) n r@(Sparse b bs) = Sparse (a * b) $
    unionWith (<+>)
        (mapWithKey (\k da -> times da k r) (dropMap n as))
        (mapWithKey (\k db -> times l k db) (dropMap n bs))
-- | Number the elements of a container from 0 in traversal order and turn
-- each one into an independent variable.  The element at position @i@ gets
-- derivative 1 with respect to variable @i@ and no other derivative entries.
vars :: (Traversable f, Num a) => f a -> f (AD Sparse a)
vars xs = fresh
    where
        (_, fresh) = mapAccumL seed 0 xs
        -- `seq` keeps the running index evaluated at every step.
        seed i x = i `seq` (i + 1, AD (Sparse x (singleton i (lift 1))))
-- | Run @f@ on the inputs after promoting them to independent variables with
-- 'vars'.
apply :: (Traversable f, Num a) => (f (AD Sparse a) -> b) -> f a -> b
apply f xs = f (vars xs)
-- | Replace every element by its position, counting from 0 in traversal
-- order.
skeleton :: Traversable f => f a -> f Int
skeleton xs = positions
    where
        (_, positions) = mapAccumL number 0 xs
        -- `seq` keeps the running counter evaluated at every step.
        number i _ = i `seq` (i + 1, i)
-- | First-order partials.  For each position @i@ of @fs@ (only its shape is
-- used), return the derivative of the result with respect to variable @i@,
-- or 0 when the tower has no entry for @i@.
d :: (Traversable f, Num a) => f b -> AD Sparse a -> f a
d fs (AD (Sparse _ da)) = firstOrder
    where
        (_, firstOrder) = mapAccumL pick 0 fs
        pick i _ = i `seq` (i + 1, maybe 0 primal (lookup i da))
-- | Like 'd', but also return the primal value of the result.
d' :: (Traversable f, Num a) => f a -> AD Sparse a -> (a, f a)
d' fs (AD (Sparse a da)) = (,) a (snd (mapAccumL pick 0 fs))
    where
        pick i _ = i `seq` (i + 1, maybe 0 primal (lookup i da))
-- | All partial derivatives of the result as a lazily unfolding 'Stream'
-- shaped like @fs@.  The head is the primal value.  Following the child for
-- position @i@ differentiates once more with respect to variable @i@.
ds :: (Traversable f, Num a) => f b -> AD Sparse a -> Stream f a
ds fs (AD tower@(Sparse a0 _)) = a0 :< fmap (descend emptyIndex) slots
    where
        slots = skeleton fs

        -- @seen@ is the multiset of variables differentiated so far.
        descend seen v = partial (indices seen') tower :< fmap (descend seen') slots
            where
                seen' = addToIndex v seen
-- | Walk the tower along a path of variable indices and return the primal
-- value found there.  A missing entry yields 0.
--
-- 'ds' calls this with the ascending paths produced by 'indices'.
partial :: Num a => [Int] -> Sparse a -> a
partial path tower = case (path, tower) of
    ([], Sparse a _) -> a
    (n : ns, Sparse _ da) -> partial ns (findWithDefault (lift 0) n da)
-- | Like 'partial', but return 'Nothing' instead of 0 when the path leaves
-- the stored tower.
spartial :: Num a => [Int] -> Sparse a -> Maybe a
spartial [] (Sparse a _) = Just a
spartial (n:ns) (Sparse _ da) = lookup n da >>= spartial ns
-- The primal (undifferentiated) value is the first field of 'Sparse'.
instance Primal Sparse where
    primal (Sparse a _) = a
-- Vector-space operations on towers, applied recursively at every level.
-- Addition merges derivative maps with 'unionWith'; a key present in only one
-- operand keeps that operand's tower.  Scalar multiplication and division
-- act on the primal and on every nested derivative entry.
instance Lifted Sparse => Mode Sparse where
    lift x = Sparse x IntMap.empty

    (<+>) (Sparse x dx) (Sparse y dy) = Sparse (x + y) (unionWith (<+>) dx dy)

    (^*) (Sparse x dx) k = Sparse (x * k) (fmap (^* k) dx)

    (*^) k (Sparse y dy) = Sparse (k * y) (fmap (k *^) dy)

    (^/) (Sparse x dx) k = Sparse (x / k) (fmap (^/ k) dx)
-- Chain-rule combinators for 'Sparse'.
--
-- In every method:
--   * the primal of the result is @f@ applied to the argument primal(s);
--   * each entry (variable index @n@) of an argument's derivative map is
--     multiplied by the supplied derivative tower via 'times';
--   * with two arguments, the two resulting maps are summed with
--     'unionWith' ('<+>').
--
-- NOTE(review): the @Lifted Sparse@ context presumably exists because the
-- 'Lifted' instance is generated by the 'deriveLifted' splice later in this
-- module — confirm.
instance Lifted Sparse => Jacobian Sparse where
    -- Derivatives are themselves full 'Sparse' towers.
    type D Sparse = Sparse
    -- One argument; its derivative tower @dadb@ is given directly.
    unary f dadb (Sparse pb bs) = Sparse (f pb) $ mapWithKey (times dadb) bs
    -- One argument; the derivative tower is computed from the argument by @df@.
    lift1 f df b@(Sparse pb bs) = Sparse (f pb) $ mapWithKey (times (df b)) bs
    -- As 'lift1', but @df@ also receives the result @a@.  @a@ is defined in
    -- terms of itself, so this relies on laziness.
    lift1_ f df b@(Sparse pb bs) = a where
        a = Sparse (f pb) $ mapWithKey (times (df a b)) bs
    -- Two arguments; derivative towers @dadb@ and @dadc@ are given directly.
    binary f dadb dadc (Sparse pb db) (Sparse pc dc) = Sparse (f pb pc) $ 
        unionWith (<+>) 
            (mapWithKey (times dadb) db)
            (mapWithKey (times dadc) dc)
    -- Two arguments; both derivative towers are computed from the arguments
    -- by @df@.
    lift2 f df b@(Sparse pb db) c@(Sparse pc dc) = Sparse (f pb pc) da where
        (dadb, dadc) = df b c
        da = unionWith (<+>) 
            (mapWithKey (times dadb) db)
            (mapWithKey (times dadc) dc)
        
    -- As 'lift2', but @df@ also receives the result @a@.  @a@ is defined in
    -- terms of itself (lazily, via the where-bindings).
    lift2_ f df b@(Sparse pb db) c@(Sparse pc dc) = a where
        (dadb, dadc) = df a b c
        a = Sparse (f pb pc) da
        da = unionWith (<+>) 
            (mapWithKey (times dadb) db)
            (mapWithKey (times dadc) dc)
-- Template Haskell splice over the 'Sparse' type constructor.  It presumably
-- generates the @Lifted Sparse@ instance that the 'Mode' and 'Jacobian'
-- instances above require.  'deriveLifted' is imported; check its definition
-- for the meaning of the @id@ argument.
deriveLifted id $ conT ''Sparse
-- | Bridges curried functions of any arity and list-based functions, so that
-- 'vgrad' and 'vgrad'' can accept functions such as
-- @AD Sparse a -> AD Sparse a -> AD Sparse a@.
--
-- * 'pack' applies a curried function @i@ to a list of 'AD' arguments, one
--   list element per curried parameter.
-- * 'unpack' / 'unpack'' turn a list-consuming function into a curried
--   function (@o@ / @o'@) that collects its arguments, in order, into that
--   list.
--
-- The functional dependencies let any one of @i@, @o@, @o'@ determine the
-- others and the scalar type @a@.
class Num a => Grad i o o' a | i -> a o o', o -> a i o', o' -> a i o where
    pack :: i -> [AD Sparse a] -> AD Sparse a
    unpack :: ([a] -> [a]) -> o
    unpack' :: ([a] -> (a, [a])) -> o'
-- Base case: a bare 'AD' result consumes no arguments, and the collected
-- argument list is complete, so the continuation is applied to the empty
-- remainder.
instance Num a => Grad (AD Sparse a) [a] (a, [a]) a where
    pack = const
    unpack = ($ [])
    unpack' = ($ [])
-- Inductive case: one more curried argument.  'pack' feeds the head of the
-- argument list to @f@.  'unpack' / 'unpack'' record each received argument
-- ahead of the later ones, so the continuation sees arguments in call order.
instance Grad i o o' a => Grad (AD Sparse a -> i) (a -> o) (a -> o') a where
    pack f args = case args of
        x : rest -> pack (f x) rest
        []       -> error "Grad.pack: logic error"
    unpack f x = unpack (\xs -> f (x : xs))
    unpack' f x = unpack' (\xs -> f (x : xs))
-- | Gradient of a curried function of any arity (see 'Grad').  The result
-- takes the same number of plain arguments and returns the list of
-- first-order partials.
vgrad :: Grad i o o' a => i -> o
vgrad i = unpack (\xs -> d xs (apply (pack i) xs))
-- | Like 'vgrad', but the result also carries the function's value:
-- @(value, partials)@.
vgrad' :: Grad i o o' a => i -> o'
vgrad' i = unpack' (\xs -> d' xs (apply (pack i) xs))
-- | The 'Grad' analogue for 'vgrads'.  It bridges curried functions of any
-- arity and list-based functions whose result is the full 'Stream' of
-- derivatives.
--
-- * 'packs' applies a curried function @i@ to a list of 'AD' arguments.
-- * 'unpacks' turns a list-consuming function into a curried function @o@
--   that collects its arguments, in order, into that list.
class Num a => Grads i o a | i -> a o, o -> a i where
    packs :: i -> [AD Sparse a] -> AD Sparse a
    unpacks :: ([a] -> Stream [] a) -> o
-- Base case: a bare 'AD' result consumes no arguments, and the continuation
-- is applied to the empty remainder of the argument list.
instance Num a => Grads (AD Sparse a) (Stream [] a) a where
    packs = const
    unpacks = ($ [])
-- Inductive case: one more curried argument.  'packs' feeds the head of the
-- argument list to @f@.  'unpacks' records each received argument ahead of
-- the later ones, so the continuation sees arguments in call order.
--
-- The error message names this method ('Grads.packs'); it previously said
-- "Grad.pack", which pointed at the wrong class.
instance Grads i o a => Grads (AD Sparse a -> i) (a -> o) a where
    packs f (a:as) = packs (f a) as
    packs _ [] = error "Grads.packs: logic error"
    unpacks f a = unpacks (f . (a:))
-- | All-order derivatives of a curried function of any arity (see 'Grads').
-- The result takes the same number of plain arguments and returns the
-- derivative 'Stream' built by 'ds'.
vgrads :: Grads i o a => i -> o
vgrads i = unpacks (\xs -> ds xs (apply (packs i) xs))