{-# language FlexibleContexts, GeneralizedNewtypeDeriving, DeriveFunctor #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Iterative
-- Copyright   :  (c) Marco Zocca 2017
-- License     :  GPL-style (see the file LICENSE)
--
-- Maintainer  :  zocca marco gmail
-- Stability   :  experimental
-- Portability :  portable
--
-- Combinators and helper functions for iterative algorithms, with support for monitoring and exceptions.
--
-----------------------------------------------------------------------------
module Control.Iterative where

import Control.Exception.Common
import Numeric.LinearAlgebra.Class
import Numeric.Eps

import Control.Monad.Catch
import Data.Typeable

import Control.Monad (when)
-- import Control.Monad.Trans.Reader
import Control.Monad.State.Strict
-- import Control.Monad.Trans.Writer.CPS
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.State.Strict  as MTS -- (runStateT)
import Data.Foldable (foldrM)




-- | Verdict of one convergence check performed by 'modifyInspectGuardedM'.
data ConvergenceStatus a
  = BufferNotReady   -- ^ Not enough states buffered yet to compare summaries
  | Converging       -- ^ No stopping condition met; keep iterating
  | Converged a      -- ^ Stop successfully; carries the latest summary
  | Diverging a a    -- ^ Stop with an error; carries the latest and the previous summaries
  | NotConverged     -- ^ Iteration budget used up without converging
  deriving (Eq, Show)

     
-- | Settings shared by the iteration combinators in this module.
data IterationConfig a b =
  IterConf
    { numIterationsMax :: Int        -- ^ Maximum number of iterations
    , printDebugInfo   :: Bool       -- ^ Whether to print per-iteration information to stdout
    , iterationView    :: a -> b     -- ^ Projection of the iteration state onto a type @b@
    , printDebugIO     :: b -> IO () -- ^ How to print a projected state
    }

-- | Renders only the two scalar settings; the function-valued fields are omitted.
instance Show (IterationConfig a b) where
  show cfg = unwords
    [ "Max. # of iterations:"
    , show (numIterationsMax cfg)
    , ", print debug information:"
    , show (printDebugInfo cfg)
    ]
  


-- * Control primitives for bounded iteration with convergence check

-- | Repeatedly replace the state by its image under the pure update @f@ until
-- the new state satisfies @q@; that final state is returned (and left as the
-- current state). The update is applied at least once.
modifyUntil :: MonadState s m => (s -> Bool) -> (s -> s) -> m s
modifyUntil q f = modifyUntilM q (return . f)
       
-- | Monadic variant of 'modifyUntil'. It runs the effectful update @f@ on the
-- current state, stores the result, and stops as soon as the stored state
-- satisfies @q@. The update runs at least once, even if the initial state
-- already satisfies @q@.
modifyUntilM :: MonadState s m => (s -> Bool) -> (s -> m s) -> m s
modifyUntilM q f = loop
  where
    loop = do
      next <- get >>= f
      put next
      if q next then return next else loop

-- | 'modifyUntil' starting from an explicit initial state, with optional
-- per-iteration logging to stdout controlled by @config@ (see 'modifyUntilM'').
modifyUntil' :: MonadIO m =>
   IterationConfig a b -> (a -> Bool) -> (a -> a) -> a -> m a
modifyUntil' config q f = modifyUntilM' config q (return . f)


-- | Monadic, logging variant of 'modifyUntilM' that starts from the explicit
-- initial state @x0@.
--
-- Each step applies @f@ to the current state. When @printDebugInfo config@ is
-- set, it prints an @Iteration i@ header, followed by @printDebugIO config@
-- applied to @iterationView config@ of the new state. Iteration stops at the
-- first new state that satisfies @q@, and that state is returned.
--
-- NOTE(review): @numIterationsMax config@ is not consulted here, so the loop
-- does not terminate if @q@ never holds.
modifyUntilM' :: MonadIO m =>
   IterationConfig a b -> (a -> Bool) -> (a -> m a) -> a -> m a
modifyUntilM' config q f = MTS.execStateT (step 0)
  where
    project = iterationView config
    step i = do
      next <- get >>= lift . f
      when (printDebugInfo config) . liftIO $ do
        putStrLn (unwords ["Iteration", show i, "\n"])
        printDebugIO config (project next)
      put next
      if q next then return next else step (i + 1)




-- | Special case of 'untilConvergedG' in which the final-state acceptance
-- criterion is closeness to a known solution. Iteration is accepted once
-- @nearZero (norm2 (xKnown ^-^ v))@ holds for the projected state @v@.
-- The consecutive-state convergence check ('norm2Diff' with 'nearZero') and
-- the divergence check ('qdiverg') are the same as in 'untilConvergedG'.
untilConvergedG0 ::
  (Normed v, MonadThrow m, MonadIO m, Typeable (Magnitude v), Typeable s, Show s) =>
     String
     -> IterationConfig s v
     -> v                    -- ^ Known value
     -> (s -> s)
     -> s
     -> m s
untilConvergedG0 fname config xKnown =
  modifyInspectGuarded fname config norm2Diff nearZero qdiverg closeToKnown
  where
    closeToKnown v = nearZero (norm2 (xKnown ^-^ v))
  


-- | 'modifyInspectGuarded' with default choices for the pure-update case:
--
-- * the summary of two consecutive projected states is the 'norm2' of their
--   difference ('norm2Diff');
-- * convergence is declared when that summary is 'nearZero';
-- * divergence is flagged ('qdiverg') when that summary grows from one pair of
--   states to the next.
--
-- The remaining arguments are the final-state acceptance predicate, the state
-- update and the initial state.
untilConvergedG :: (Normed v, MonadThrow m, MonadIO m, Typeable (Magnitude v), Typeable s, Show s) =>
        String
     -> IterationConfig s v
     -> (v -> Bool)
     -> (s -> s)
     -> s
     -> m s
untilConvergedG fname config qfinal f x0 =
  modifyInspectGuarded fname config norm2Diff nearZero qdiverg qfinal f x0


-- | Monadic version of 'untilConvergedG'. It uses the same 'norm2Diff',
-- 'nearZero' and 'qdiverg' defaults, but the state update runs in @m@
-- (see 'modifyInspectGuardedM').
untilConvergedGM ::
  (Normed v, MonadThrow m, MonadIO m, Typeable (Magnitude v), Typeable s, Show s) =>
     String
     -> IterationConfig s v
     -> (v -> Bool)
     -> (s -> m s)
     -> s
     -> m s
untilConvergedGM fname config qfinal f x0 =
  modifyInspectGuardedM fname config norm2Diff nearZero qdiverg qfinal f x0





-- | Higher-order driver for a guarded numerical iteration with a pure update.
--
-- Applies @f@ repeatedly from @x0@, keeping a rolling window of the 3 most
-- recent states. Once the window is full, the summary of the projections of
-- the two most recent buffered states is compared with that of the two
-- preceding ones. The iteration then stops in one of three ways:
--
-- * with a divergence exception, when the divergence criterion holds
--   (e.g. @latest > previous@) and the convergence criterion does not;
-- * successfully, when the convergence criterion holds or the new state's
--   projection satisfies the final-state acceptance criterion;
-- * with a not-converged exception, when the iteration budget runs out.
--
-- The acceptance criterion is OR-ed with the convergence criterion, so
-- @const False@ is the neutral choice when no known final state is available.
-- See 'modifyInspectGuardedM' for the exact stopping rules.
modifyInspectGuarded ::
  (MonadThrow m, MonadIO m, Typeable s, Typeable a, Show s, Show a) =>
        String              -- ^ Calling function name
     -> IterationConfig s v -- ^ Configuration
     -> ([v] -> a)          -- ^ Summary of a list of projected states
     -> (a -> Bool)         -- ^ Convergence criterion
     -> (a -> a -> Bool)    -- ^ Divergence criterion (latest summary, previous summary)
     -> (v -> Bool)         -- ^ Final state acceptance criterion
     -> (s -> s)            -- ^ State evolution
     -> s                   -- ^ Initial state
     -> m s                 -- ^ Final state
modifyInspectGuarded fname config summarize converged diverged accept step =
  modifyInspectGuardedM fname config summarize converged diverged accept (return . step)

  


-- | Monadic version of 'modifyInspectGuarded': the state evolution @f@ runs in @m@.
--
-- Starting from @x0@, each step computes @y@ from the current state with @f@,
-- optionally prints it (when @printDebugInfo config@ is set), and stores it as
-- the new state. The last 3 states are kept in a rolling buffer. Once the
-- buffer is full, the summary @sf@ of the projections (@iterationView config@)
-- of the two most recent buffered states (@qi@) is compared with that of the
-- two states before them (@qt@). The freshly computed @y@ is added to the
-- buffer only after this check. The checks are tried in order:
--
-- * @qdiverg qi qt && not (qconverg qi)@: throws 'DivergingE';
-- * @qconverg qi@, or @qfinal@ holds on the projection of @y@: returns @y@;
-- * the iteration budget @numIterationsMax config@ is used up: throws
--   'NotConvergedE';
-- * otherwise: iterates again.
--
-- No check runs until the buffer is full, so @f@ is evaluated at least 4 times
-- (unless it throws). The budget therefore only takes effect from the 4th
-- iteration on. Throws 'NonNegError' if @numIterationsMax config@ is not positive.
modifyInspectGuardedM ::
  (MonadThrow m, MonadIO m, Typeable s, Show s, Typeable a, Show a) =>
     String                 -- ^ Calling function name (reported in exceptions)
     -> IterationConfig s v -- ^ Configuration
     -> ([v] -> a)          -- ^ Summary of a list of projected states
     -> (a -> Bool)         -- ^ Convergence criterion
     -> (a -> a -> Bool)    -- ^ Divergence criterion (latest summary, previous summary)
     -> (v -> Bool)         -- ^ Final state acceptance criterion
     -> (s -> m s)          -- ^ State evolution
     -> s                   -- ^ Initial state
     -> m s                 -- ^ Final state
modifyInspectGuardedM fname config sf qconverg qdiverg qfinal f x0
  | nitermax > 0 = MTS.execStateT (go 0 []) x0
  | otherwise = throwM (NonNegError fname nitermax)
  where
    lwindow = 3
    nitermax = numIterationsMax config
    pf = iterationView config
    checkConvergStatus y i ll
      | length ll < lwindow = BufferNotReady
      | qdiverg qi qt && not (qconverg qi) = Diverging qi qt
      | qconverg qi || qfinal (pf y) = Converged qi
      -- '>=' rather than '==': the first check happens at i == lwindow, so
      -- with a budget of at most lwindow iterations an equality test would
      -- never fire and a non-converging process would loop forever.
      | i >= nitermax - 1 = NotConverged
      | otherwise = Converging
      where llf = pf <$> ll
            qi = sf $ init llf  -- summary of the 2 most recently buffered states
            qt = sf $ tail llf  -- summary of the 2 states buffered before those
    go i ll = do
      x <- MTS.get
      y <- lift $ f x
      when (printDebugInfo config) $ liftIO $ do
        putStrLn $ unwords ["Iteration", show i]
        printDebugIO config (pf y)
      -- the new state is stored whatever the verdict, as before each exit/recursion
      MTS.put y
      case checkConvergStatus y i ll of
        BufferNotReady  -> go (i + 1) (y : ll)          -- still filling the window
        Converging      -> go (i + 1) (init (y : ll))   -- roll the window: drop the oldest state
        Converged _     -> return ()
        Diverging qi qt -> throwM (DivergingE fname i qi qt)
        NotConverged    -> throwM (NotConvergedE fname nitermax y)
             


-- * Some useful combinators


-- | Apply @f@ to each index in @ixs@ and keep only the @(index, value)@ pairs
-- whose value is not (approximately) zero according to 'isNz'.
-- The indices keep their input order, and @f@ is evaluated once per index.
onRangeSparse :: Epsilon b => (Int -> b) -> [Int] -> [(Int, b)]
onRangeSparse f ixs = [ (i, y) | i <- ixs, let y = f i, isNz y ]

-- | Monadic version of 'onRangeSparse'. It runs @f@ on each element of @ixs@
-- and keeps the @(element, result)@ pairs whose result passes 'isNz'.
--
-- @f@ runs exactly once per element, so the value that is tested is the value
-- that is stored. Elements are visited with 'foldrM': effects are sequenced
-- starting from the last element, while the result list keeps the
-- container's order.
onRangeSparseM :: (Epsilon b, Foldable t, Monad m) =>
     (a -> m b) -> t a -> m [(a, b)]
onRangeSparseM f = foldrM keepNz []
  where
    keepNz x acc = do
      y <- f x
      return $ if isNz y then (x, y) : acc else acc
  


-- | 'unfoldZipM' for a pure predicate and a pure function. It keeps the
-- elements satisfying @q@, each paired with its image under @f@, and returns
-- the result in any 'Monad'.
unfoldZipM0 :: (Foldable t, Monad m) =>
     (a -> Bool) -> (a -> b) -> t a -> m [(a, b)]
unfoldZipM0 q f xs = unfoldZipM (return . q) (return . f) xs


-- | Effectful filter-and-pair. For every element @x@, run the predicate @q x@;
-- when it yields 'True', run @f x@ and keep the pair of @x@ and its result.
-- Elements are visited with 'foldrM': effects are sequenced starting from the
-- last element, while the result list keeps the container's order.
unfoldZipM :: (Foldable t, Monad m) =>
     (a -> m Bool) -> (a -> m b) -> t a -> m [(a, b)]
unfoldZipM q f = foldrM step []
  where
    step x acc = q x >>= \keep ->
      if keep
        then f x >>= \y -> return ((x, y) : acc)
        else return acc

-- | Post-compose a pure function onto the result of a functor-valued one:
-- @combx g f x == g \<$\> f x@.
combx :: Functor f => (a -> b) -> (t -> f a) -> t -> f b
combx g f = fmap g . f


          
    



-- * Helpers

-- | Relative residual of the linear system @aa x = b@, i.e.
-- @norm2 ((aa #> x) ^-^ b) / norm2 b@.
-- NOTE(review): there is no guard against @norm2 b == 0@; the result then
-- depends on how division behaves for @Magnitude t@.
relRes :: (Normed t, LinearVectorSpace t) =>
     MatrixType t -> t -> t -> Magnitude t
relRes aa b x = norm2 residual / norm2 b
  where
    residual = (aa #> x) ^-^ b


-- meanl :: (Foldable t, Fractional a) => t a -> a
-- meanl xx = 1/fromIntegral (length xx) * sum xx

-- norm2l :: (Foldable t, Functor t, Floating a) => t a -> a
-- norm2l xx = sqrt $ sum (fmap (**2) xx)

-- | Squared difference of the first two elements of a list: @(x1 - x2)^2@.
-- Any elements beyond the second are ignored.
--
-- Calls 'error' with a descriptive message when the list has fewer than two
-- elements.
diffSqL :: Floating a => [a] -> a
diffSqL (x1 : x2 : _) = (x1 - x2) ^ (2 :: Int)  -- exact square; avoids exp/log-based (**)
diffSqL _ = error "Control.Iterative.diffSqL: the list must have at least 2 elements"


-- | Relative tolerance between two vectors:
--
-- > relTol a b = norm2 (a - b) / (1 + min (norm2 a) (norm2 b))
relTol :: Normed v => v -> v -> Magnitude v
relTol a b = numerator / denominator
  where
    numerator = norm2 (a ^-^ b)
    denominator = 1 + min (norm2 a) (norm2 b)


-- | Default divergence criterion: the latest summary is strictly greater than
-- the previous one.
qdiverg :: Ord a => a -> a -> Bool
qdiverg latest previous = latest > previous

-- | Summary used by 'untilConvergedG' and friends: the 'norm2' of the
-- difference of the two states in a two-element list. Any other list length
-- yields @1 / 0@ (infinity for IEEE floating-point magnitudes).
norm2Diff xs = case xs of
  [newer, older] -> norm2 (newer ^-^ older)
  _              -> 1 / 0






-- test data

-- | Toy iteration state for experiments: a numeric value plus a text label.
data S = S
  { unS1 :: Double
  , unS2 :: String
  } deriving (Eq, Show)

-- | Apply a function to the numeric field of an 'S', leaving the label untouched.
liftS1 :: (Double -> Double) -> S -> S
liftS1 g s = s { unS1 = g (unS1 s) }

-- | Sample initial state.
s0 :: S
s0 = S 1 "blah"
-- | Sample configuration: at most 2 iterations and debug printing enabled;
-- each state is printed through its numeric field.
ic1 :: IterationConfig S Double
ic1 = IterConf
  { numIterationsMax = 2
  , printDebugInfo   = True
  , iterationView    = unS1
  , printDebugIO     = print
  }





-- playground

-- instance MonadThrow m => MonadThrow (WriterT w m) where
--   throwM = lift . throwM

-- -- | iter0 also accepts a configuration, e.g. for optional printing of debug info
-- -- iter0 :: MonadIO m =>
-- --      Int -> (s -> m s) -> (s -> String) -> IterationConfig s -> s -> m s
-- iter0 nmax f sf config x0 = flip runReaderT config $ MTS.execStateT (go (0 :: Int)) x0
--  where
--   go i = do
--     x <- get
--     c <- lift $ asks printDebugInfo  -- neat
--     y <- lift . lift $ f x           -- not neat
--     when c $ liftIO $ putStrLn $ sf y 
--     put y
--     unless (i >= nmax) (go $ i + 1)
 

-- -- | iter1 prints output at every iteration until the loop terminates OR is interrupted by an exception, whichever happens first
-- -- iter1 :: (MonadThrow m, MonadIO m, Typeable t, Show t) =>
-- --      (t -> m t) -> (t -> String) -> (t -> Bool) -> (t -> Bool) -> t -> m t
-- iter1 f wf qe qx x0 = execStateT (go 0) x0 where
--  go i = do
--    x <- get
--    y <- lift $ f x
--    _ <- liftIO $ wf y
--    when (qx y) $ throwM (NotConvergedE "bla" (i+1)  y)
--    put y
--    unless (qe y) $ go (i + 1) 


-- -- | iter2 concatenates output with WriterT but does NOT `tell` any output if an exception is raised before the end of the loop
-- iter2 :: (MonadThrow m, Monoid w, Typeable t, Show t) => (t -> m t)
--      -> (t -> w) -> (t -> Bool) -> (t -> Bool) -> t -> m (t, w)
-- iter2 f wf qe qx x0 = runWriterT $ execStateT (go 0) x0 where
--  go i = do
--    x <- get
--    y <- lift . lift $ f x
--    lift $ tell $ wf y
--    when (qx y) $ throwM (NotConvergedE "bla" (i+1)  y)
--    put y
--    unless (qe y) $ go (i + 1) 


-- -- test :: IO (Int, [String])
-- test :: IO Int
-- -- test :: IO ()
-- test = do
--   (yt, w ) <- iter2 f wf qe qexc x0
--   putStrLn w
--   return yt
--   -- iter1 f wf qe qexc x0
--   where
--     f = pure . (+ 1)
--     wf v = unwords ["state =", show v]
--     qe = (== 5)
--     qexc = (== 3)
--     x0 = 0 :: Int