{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Common internals shared by the Sparse.Double and Sparse modes.
--
-- Handle with care.
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Sparse.Common
  ( Monomial(..)
  , emptyMonomial
  , addToMonomial
  , indices
  , skeleton
  , terms
  ) where

import Data.IntMap (IntMap, toAscList, insertWith)
import qualified Data.IntMap as IntMap
import Data.Traversable

-- | A monomial over 'Int'-indexed variables, stored as a map from variable
-- index to its exponent. A variable absent from the map has exponent zero.
newtype Monomial = Monomial (IntMap Int)

-- | The monomial with no variables (every exponent is zero).
emptyMonomial :: Monomial
emptyMonomial = Monomial IntMap.empty
{-# INLINE emptyMonomial #-}

-- | @addToMonomial k m@ multiplies @m@ by the variable with index @k@.
--
-- It increments that variable's exponent, inserting it with exponent @1@
-- when it was absent.
addToMonomial :: Int -> Monomial -> Monomial
addToMonomial k (Monomial m) = Monomial (insertWith (+) k 1 m)
{-# INLINE addToMonomial #-}

-- | Expand a monomial into its variable indices in ascending order.
--
-- Each index is repeated as many times as its exponent. For example, a
-- monomial with exponent 2 at index 0 and exponent 1 at index 3 yields
-- @[0, 0, 3]@.
indices :: Monomial -> [Int]
indices (Monomial m) = concatMap (\(k, e) -> replicate e k) (toAscList m)
{-# INLINE indices #-}

-- | Replace every element of a container with its position in a
-- left-to-right traversal, numbering from @0@.
--
-- For example, @skeleton "abc" == [0, 1, 2]@.
skeleton :: Traversable f => f a -> f Int
-- The bang keeps the running counter evaluated, avoiding a chain of (+1)
-- thunks in the accumulator.
skeleton = snd . mapAccumL (\ !n _ -> (n + 1, n)) 0
{-# INLINE skeleton #-}

-- | Enumerate every way of splitting a monomial into an ordered pair of
-- monomials whose product is the original, weighted by binomial
-- coefficients.
--
-- For each variable @k@ with exponent @a@, the exponent is split as
-- @i + (a - i)@ for every @i@ in @[0 .. a]@, which contributes a factor of
-- @choose a i@.
--
-- * Each result is @(weight, left, right)@.
-- * @left@ carries the @i@ parts and @right@ carries the @a - i@ parts.
-- * @weight@ is the product of the binomial factors over all variables.
--
-- This matches the coefficient pattern of the general Leibniz rule for
-- repeated derivatives of a product.
--
-- Variables are processed in ascending index order. The empty monomial
-- yields the single split @(1, emptyMonomial, emptyMonomial)@.
--
-- Note that @left@ and @right@ keep explicit zero-exponent entries for
-- variables whose share of the split is zero.
terms :: Monomial -> [(Integer,Monomial,Monomial)]
terms (Monomial m) = t (toAscList m) where
  -- All splits of the remaining (variable, exponent) pairs.
  t :: [(Int, Int)] -> [(Integer, Monomial, Monomial)]
  t [] = [(1, emptyMonomial, emptyMonomial)]
  t ((k, a) : ts) = concatMap (f (t ts)) (zip (bins !! a) [0 .. a]) where
    -- Extend every split of the tail as follows:
    --   * put exponent i of variable k on the left;
    --   * put exponent a - i of variable k on the right;
    --   * scale the weight by b, which is choose a i.
    -- NOTE(review): assumes a >= 0. A negative exponent would make
    -- bins !! a fail.
    f :: [(Integer, Monomial, Monomial)] -> (Integer, Int)
      -> [(Integer, Monomial, Monomial)]
    f ps (b, i) =
      [ ( w * b
        , Monomial (IntMap.insert k i mf)
        , Monomial (IntMap.insert k (a - i) mg)
        )
      | (w, Monomial mf, Monomial mg) <- ps
      ]
  -- Rows of Pascal's triangle: row n is [choose n 0, ..., choose n n].
  bins :: [[Integer]]
  bins = iterate next [1]
  -- Next row of Pascal's triangle: add the row to itself shifted by one.
  -- This form is total, so no unreachable error branch is needed.
  next :: [Integer] -> [Integer]
  next xs = zipWith (+) (0 : xs) (xs ++ [0])