{-# LANGUAGE PatternGuards, RankNTypes, TupleSections #-}
{- |

Description: find out which derivatives are unnecessary

Gradients and hessians tend to be sparse. The sparsity pattern can be probed
by evaluating them with NaNs as inputs: entries that do not become NaN do not
depend on the NaN inputs. http://www.gpops2.com/ uses this strategy, but there
may be relatively common cases (calls to BLAS/LAPACK, for example) that do not
propagate NaNs correctly, which would make the detected pattern too sparse.

All functions report the indices that are affected (and should thus be included in the sparsity pattern).

Functions whose names end in 1 set one input variable at a time to NaN, and
additionally report which variable was set (a hint as to which variable to
change).

Functions whose names do not end in 1 set all input variables to NaN at once.

It seems that adding a 1 achieves the same thing as taking one more derivative.
In other words, 'nanPropagateG1' gives the same information as 'nanPropagateH'.

-}
module Ipopt.Sparsity (
    -- $setup
    -- * gradient sparsity
    nanPropagate1,
    nanPropagateG,

    -- * jacobian
    nanPropagateJ,
    -- * hessian sparsity
    nanPropagateH,
    nanPropagateG1,

    nanPropagateHF,

    -- * sparse derivatives (faked for now)
    jacobianSS,
    ) where

import Control.Applicative
import Control.Monad
import qualified Data.Vector as V
import Data.Vector ((!))
import Numeric.AD

import qualified Data.IntMap as M
import qualified Data.IntSet as S
import Data.Monoid

{- $setup
>>> let g x = x!0 + x!1 * x!1 * x!2

-}

{- | Probe hessian sparsity by setting one input at a time to NaN and
seeing which gradient entries become NaN.

If setting input @i@ to NaN leaves gradient entry @j@ non-NaN, then
(provided NaNs propagate faithfully through @f@) that gradient entry does
not depend on input @i@, so hessian entry @(i,j)@ is zero and need not be
included.

>>> nanPropagateG1 4 g
[(1,fromList [1,2]),(2,fromList [1])]

The probes start from the first vector whose entries are drawn from
@[0,0.5,1]@ (all zeros first) at which @f@ is not NaN; finding it may take
up to @3^nx@ evaluations of @f@. If @f@ is NaN at every such vector,
nothing can be learned by probing, so every input is conservatively paired
with every gradient index (a dense hessian) instead of failing with a
non-exhaustive-guard error.
-}
nanPropagateG1 :: Int -- ^ size of input vector
    -> (forall a. RealFloat a => V.Vector a -> a)
    -> [(Int, V.Vector Int)] -- ^ @(i,js)@
        -- shows that a NaN at index @i@ produces NaNs in gradient indexes
        -- @js@... so the hessians only need to include
        -- @i@ and @js@
nanPropagateG1 nx f = case dropWhile (isNaN . f) (trialV nx [0,0.5,1]) of
    x0 : _ ->
        [ (i, js)
        | i <- [0 .. nx - 1]
        , let js = V.findIndices isNaN (grad f (x0 V.// [(i, 0/0)]))
        , not (V.null js)
        ]
    -- f is NaN at every trial point: fall back to the dense answer, which
    -- is always safe ("include everything")
    [] -> [ (i, V.enumFromN 0 nx) | i <- [0 .. nx - 1] ]

{- | Gradient sparsity with every input set to NaN at once: the indices of
the gradient entries that are not exactly zero. NaN compares unequal to 0,
so entries that became NaN are included.

>>> nanPropagateG 4 g
fromList [0,1,2]

-}
nanPropagateG :: Int -- ^ size of input vector
    -> (forall a. RealFloat a => V.Vector a -> a)
    -> V.Vector Int
nanPropagateG nx f = V.findIndices (/= 0) gradAtNaN
  where
    -- the same element type that defaulting selects for the literal
    gradAtNaN = grad f (V.replicate nx (0/0 :: Double))

{- | Set one input at a time to NaN, starting from a point where @f@ is not
NaN, and report the inputs for which the output of @f@ becomes NaN. Inputs
that are not listed presumably do not affect @f@.

>>> nanPropagate1 4 g
[0,1,2]

variable 3 isn't even used.

The starting point is the first vector whose entries are drawn from
@[0,0.5,1]@ (all zeros first) at which @f@ is not NaN; finding it may take
up to @3^nx@ evaluations of @f@. If @f@ is NaN at every such vector, every
input index is returned (the conservative answer) instead of failing with
a non-exhaustive-guard error.
-}
nanPropagate1 :: Int -- ^ size of input vector
    -> (forall a. RealFloat a => V.Vector a -> a)
    -> [Int] -- ^ inputs that make the output NaN when they are set to NaN
nanPropagate1 nx f = case dropWhile (isNaN . f) (trialV nx [0,0.5,1]) of
    x0 : _ -> [ i | i <- [0 .. nx - 1], isNaN (f (x0 V.// [(i, 0/0)])) ]
    -- no usable starting point: assume every input matters
    [] -> [0 .. nx - 1]


{- | Hessian sparsity with every input set to NaN at once: the
@(row, column)@ indices of the hessian entries that are not exactly zero
(entries that became NaN count as nonzero).

>>> nanPropagateH 4 g
fromList [(1,1),(1,2),(2,1)]

-}
nanPropagateH :: Int -- ^ size of input vector
    -> (forall a. RealFloat a => V.Vector a -> a)
    -> V.Vector (Int, Int) -- ^ (i,j) indexes
nanPropagateH nx f = nonzeroIxs (hessian f allNaN)
  where
    allNaN = V.replicate nx (0/0) :: V.Vector Double

{- | Jacobian sparsity of a vector-valued function with every input set to
NaN at once: the @(i,j)@ indices of the jacobian entries that are not
exactly zero (entries that became NaN count as nonzero). @i@ indexes the
outputs of @f@ and @j@ its inputs.
-}
nanPropagateJ :: Int -- ^ size of input vector
    -> (forall a. RealFloat a => V.Vector a -> V.Vector a)
    -> V.Vector (Int,Int)
nanPropagateJ nx f = nonzeroIxs jac
  where
    jac = jacobian f (V.replicate nx (0/0 :: Double))

{- | Per-output hessian sparsity of a vector-valued function with every
input set to NaN at once: for each output of @f@, the @(i,j)@ indices of
the entries of its hessian that are not exactly zero (entries that became
NaN count as nonzero).
-}
nanPropagateHF :: Int -- ^ size of input vector
    -> (forall a. RealFloat a => V.Vector a -> V.Vector a)
    -> V.Vector (V.Vector (Int,Int))
nanPropagateHF nx f = V.map nonzeroIxs hessians
  where
    hessians = hessianF f (V.replicate nx (0/0 :: Double))

{-
complement :: (Int,Int) -> V.Vector (Int,Int) -> V.Vector (Int,Int)
complement (mx,my) = M.fromList . M.toList
-}

-- | Merge several sparsity patterns into one, for example the per-output
-- patterns of 'nanPropagateHF' (TODO confirm that is the intended input).
-- The result is the union of all @(i,j)@ pairs restricted to the lower
-- triangle (@j <= i@), sorted by @i@ and then @j@, without duplicates.
collapse :: V.Vector (V.Vector (Int,Int)) -> V.Vector (Int,Int)
collapse patterns = V.fromList
    [ (i, j) | (i, js) <- M.toList lowerRows, j <- S.toList js ]
  where
    -- row index => set of column indices, unioned over every pattern
    rows = M.fromListWith S.union
        [ (i, S.singleton j) | p <- V.toList patterns, (i, j) <- V.toList p ]
    -- keep only the columns strictly below i+1, i.e. j <= i
    lowerRows = M.mapWithKey (\i js -> fst (S.split (i + 1) js)) rows

-- | Flatten a matrix, given as a vector of rows, into the @(row, column)@
-- indices of its entries that compare unequal to zero. NaN /= 0, so NaN
-- entries are included, which the NaN-probing functions rely on.
--
-- The explicit signature keeps the element type polymorphic; without it
-- the monomorphism restriction fixes it to whatever the uses in this module
-- force.
nonzeroIxs :: (Eq a, Num a) => V.Vector (V.Vector a) -> V.Vector (Int, Int)
nonzeroIxs rows = V.concat
    [ V.map (i,) (V.findIndices (/= 0) row)
    | (i, row) <- zip [0 ..] (V.toList rows)
    ]

-- | Every vector of length @nx@ whose entries are drawn from @cand@. The
-- vectors appear in the order 'replicateM' gives for lists: the first one
-- uses the first candidate everywhere, and the last position varies
-- fastest. There are @length cand ^ nx@ of them, so callers should only
-- force a prefix.
trialV :: Int -> [Double] -> [V.Vector Double]
trialV nx cand = map V.fromList (replicateM nx cand)



{- | Jacobian entries at the requested @(i,j)@ positions, returned in the
order of the index vector. @i@ indexes the outputs of @f@ and @j@ its
inputs.

This is only a stand-in for a real sparse jacobian: the full dense
jacobian is computed first and then indexed, so no work is saved. An
out-of-range index raises an error via 'V.!'.
-}
jacobianSS :: RealFloat a => V.Vector (Int,Int)
    -> (forall b. RealFloat b => V.Vector b -> V.Vector b)
    -> V.Vector a
    -> V.Vector a
jacobianSS ijs f x = V.map (\(i, j) -> dense V.! i V.! j) ijs
  where
    -- rows are outputs, columns are inputs
    dense = jacobian f x

{-
hessianFSS :: RealFloat a => V.Vector (Int,Int)
    -> (forall a. RealFloat a => V.Vector a -> V.Vector a)
    -> V.Vector a
    -> V.Vector a
hessianFSS = undefined
-}