{-# language FlexibleContexts, GeneralizedNewtypeDeriving, DeriveFunctor #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Iterative
-- Copyright   :  (c) Marco Zocca 2017
-- License     :  GPL-style (see the file LICENSE)
--
-- Maintainer  :  zocca marco gmail
-- Stability   :  experimental
-- Portability :  portable
--
-- Combinators and helper functions for iterative algorithms, with support for monitoring and exceptions.
--
-----------------------------------------------------------------------------
module Control.Iterative where

import Control.Exception.Common
import Numeric.LinearAlgebra.Class
import Numeric.Eps

import Control.Monad.Catch
import Data.Typeable

import Control.Monad (when)
-- import Control.Monad.Trans.Reader
import Control.Monad.State.Strict
-- import Control.Monad.Trans.Writer.CPS
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.State.Strict  as MTS -- (runStateT)
import Data.Foldable (foldrM)

import Data.VectorSpace


-- | Outcome of one convergence check in the guarded iteration loop.
data ConvergenceStatus a
  = BufferNotReady
    -- ^ Too few states have been buffered to assess anything yet
  | Converging
    -- ^ Neither converged nor diverging: keep iterating
  | Converged a
    -- ^ Converged; carries the summary of the latest states
  | Diverging a a
    -- ^ Diverging; carries the summary of the latest states and that of the previous ones
  | NotConverged
    -- ^ The iteration budget ran out without convergence
  deriving (Eq, Show)

     
-- | Settings of a bounded iteration over states of type @a@, monitored
-- through a view of type @b@.
data IterationConfig a b = IterConf
  { numIterationsMax :: Int
    -- ^ Iteration budget
  , printDebugInfo :: Bool
    -- ^ Whether debug information should be printed
  , iterationView :: a -> b
    -- ^ Projection from the iteration state to the monitored quantity
  , printDebugIO :: b -> IO ()
    -- ^ Printer for the monitored quantity
  }

-- | Only the iteration budget and the debug flag are rendered; the two
-- function-valued fields are left out.
instance Show (IterationConfig a b) where
  show IterConf{numIterationsMax = budget, printDebugInfo = dbg} =
    unwords
      [ "Max. # of iterations:"
      , show budget
      , ", print debug information:"
      , show dbg
      ]
  


-- * Control primitives for bounded iteration with convergence check

-- | Keep updating the state with a pure function until the predicate holds
-- for the updated state, which is then returned. Pure wrapper around
-- `modifyUntilM`.
modifyUntil :: MonadState s m => (s -> Bool) -> (s -> s) -> m s
modifyUntil q f = modifyUntilM q step
  where
    step s = return (f s)
       
-- | Keep updating the state with a monadic function until the predicate
-- holds for the updated state, which is then returned.
--
-- The update runs before the first test, so it is always applied at least once.
modifyUntilM :: MonadState s m => (s -> Bool) -> (s -> m s) -> m s
modifyUntilM q f = loop
  where
    loop = do
      next <- get >>= f
      put next
      if q next
        then pure next
        else loop







-- | Variant of `untilConvergedG` for problems with a known solution: the
-- final-state criterion accepts a state view when the `norm2` of its
-- difference from @xKnown@ is `nearZero`.
untilConvergedG0 ::
  (Normed v, MonadThrow m, MonadIO m, Typeable (Magnitude v), Typeable s, Show s) =>
     String                  -- ^ Calling function name
     -> IterationConfig s v  -- ^ Configuration
     -> v                    -- ^ Known value
     -> (s -> s)             -- ^ State evolution
     -> s                    -- ^ Initial state
     -> m s                  -- ^ Final state
untilConvergedG0 fname config xKnown f x0 =
  modifyInspectGuarded fname config norm2Diff nearZero qdiverg closeToKnown f x0
  where
    -- acceptance test on the state view (not on the full state)
    closeToKnown v = nearZero (norm2 (xKnown ^-^ v))
  


-- | `modifyInspectGuarded` with default choices: the summary of a pair of
-- states is `norm2Diff`, convergence is declared when that summary is
-- `nearZero`, and divergence is flagged by `qdiverg`, i.e. when the summary
-- of the latest pair is greater than that of the previous pair.
untilConvergedG :: (Normed v, MonadThrow m, MonadIO m, Typeable (Magnitude v), Typeable s, Show s) =>
        String               -- ^ Calling function name
     -> IterationConfig s v  -- ^ Configuration
     -> (v -> Bool)          -- ^ Final state acceptance criterion
     -> (s -> s)             -- ^ State evolution
     -> s                    -- ^ Initial state
     -> m s                  -- ^ Final state
untilConvergedG fname config qfinal f x0 =
  modifyInspectGuarded fname config norm2Diff nearZero qdiverg qfinal f x0


-- | Monadic version of `untilConvergedG`: same default summary, convergence
-- and divergence criteria, but the state evolution step runs in the monad @m@.
untilConvergedGM ::
  (Normed v, MonadThrow m, MonadIO m, Typeable (Magnitude v), Typeable s, Show s) =>
     String                  -- ^ Calling function name
     -> IterationConfig s v  -- ^ Configuration
     -> (v -> Bool)          -- ^ Final state acceptance criterion
     -> (s -> m s)           -- ^ Monadic state evolution
     -> s                    -- ^ Initial state
     -> m s                  -- ^ Final state
untilConvergedGM fname config qfinal f x0 =
  modifyInspectGuardedM fname config norm2Diff nearZero qdiverg qfinal f x0





-- | `modifyInspectGuarded` is a high-order abstraction of a numerical
-- iterative process with a pure evolution step; it delegates to
-- `modifyInspectGuardedM`.
--
-- The process keeps a rolling window of 3 states and compares a summary of
-- the latest 2 with the summary of the previous 2 in order to assess
-- divergence (e.g. the process is diverging if the summary of the latest pair
-- is greater than that of the previous pair). It ends when either the
-- iteration budget is hit or convergence is verified, where the final-state
-- predicate also counts as convergence (e.g. a test against a known solution).
modifyInspectGuarded ::
  (MonadThrow m, MonadIO m, Typeable s, Typeable a, Show s, Show a) =>
        String              -- ^ Calling function name
     -> IterationConfig s v -- ^ Configuration
     -> ([v] -> a)          -- ^ State summary array projection
     -> (a -> Bool)         -- ^ Convergence criterion
     -> (a -> a -> Bool)    -- ^ Divergence criterion
     -> (v -> Bool)         -- ^ Final state acceptance criterion
     -> (s -> s)            -- ^ State evolution
     -> s                   -- ^ Initial state
     -> m s                 -- ^ Final state
modifyInspectGuarded fname config sf qc qd qfin f x0 =
  modifyInspectGuardedM fname config sf qc qd qfin step x0
  where
    -- lift the pure evolution step into the monad
    step s = return (f s)

  


-- | Monadic version of `modifyInspectGuarded`: the state evolution step runs
-- in the monad @m@.
--
-- The loop buffers the 3 most recent states (newest first). Once the buffer is
-- full, every new step is classified by comparing the summary of the two
-- latest buffered states with the summary of the two previous ones:
--
-- * divergence without convergence throws `DivergingE`;
--
-- * convergence of the summary, or acceptance of the new state view by the
--   final criterion, ends the loop and yields the new state;
--
-- * an exhausted iteration budget throws `NotConvergedE`.
--
-- A non-positive iteration budget throws `NonNegError` before any step is taken.
--
-- NB: the first 3 steps only fill the buffer and cannot be checked, so at
-- least 4 steps are performed even when the budget is smaller than that.
modifyInspectGuardedM ::
  (MonadThrow m, MonadIO m, Typeable s, Show s, Typeable a, Show a) =>
     String                  -- ^ Calling function name
     -> IterationConfig s v  -- ^ Configuration
     -> ([v] -> a)           -- ^ State summary array projection
     -> (a -> Bool)          -- ^ Convergence criterion
     -> (a -> a -> Bool)     -- ^ Divergence criterion
     -> (v -> Bool)          -- ^ Final state acceptance criterion
     -> (s -> m s)           -- ^ Monadic state evolution
     -> s                    -- ^ Initial state
     -> m s                  -- ^ Final state
modifyInspectGuardedM fname config sf qconverg qdiverg qfinal f x0 
  | nitermax > 0 = MTS.execStateT (go 0 []) x0
  | otherwise = throwM (NonNegError fname nitermax)
  where
    lwindow = 3
    nitermax = numIterationsMax config
    pf = iterationView config
    -- Classify the new state `y` at iteration `i`, given the buffer `ll`
    -- (which does not contain `y` yet).
    checkConvergStatus y i ll
      | length ll < lwindow = BufferNotReady
      | qdiverg qi qt && not (qconverg qi) = Diverging qi qt
      | qconverg qi || qfinal (pf y) = Converged qi
      -- '>=' rather than '==': while the buffer fills up the counter can step
      -- past `nitermax - 1` unchecked (this happens whenever nitermax <= 3),
      -- and an equality test would then never fire, looping forever.
      | i >= nitermax - 1 = NotConverged
      | otherwise = Converging
      where llf = pf <$> ll
            -- `init`/`tail` are safe here: the first guard ensures a full buffer
            qi = sf $ init llf         -- summary of latest 2 states
            qt = sf $ tail llf         -- summary of previous 2 states
    go i ll = do
      x <- MTS.get
      y <- lift $ f x
      when (printDebugInfo config) $ liftIO $ do
        putStrLn $ unwords ["Iteration", show i]
        printDebugIO config (pf y)
      -- the new state is stored whatever the outcome of the check
      MTS.put y
      case checkConvergStatus y i ll of
        BufferNotReady -> go (i + 1) (y : ll)         -- cons current state to buffer
        Converging -> go (i + 1) (init (y : ll))      -- rolling state window
        Converged _ -> pure ()
        Diverging qi qt -> throwM (DivergingE fname i qi qt)
        NotConverged -> throwM (NotConvergedE fname nitermax y)
             


-- * Some useful combinators


-- | Apply a function over a list of integer indices, pair each result with
-- its index and keep only the entries for which `isNz` holds (i.e. filter out
-- the almost-zero entries). The function is evaluated once per index.
onRangeSparse :: Epsilon b => (Int -> b) -> [Int] -> [(Int, b)]
onRangeSparse f ixs = [(i, y) | i <- ixs, let y = f i, isNz y]

-- | Monadic version of `onRangeSparse`: run an action on every element of a
-- container, pair each result with its input and keep only the pairs whose
-- result satisfies `isNz`.
--
-- The action is run exactly once per element, so the value that is tested is
-- the value that is stored. Elements are visited with `foldrM`, while the
-- resulting list follows the order of the container.
onRangeSparseM :: (Epsilon b, Foldable t, Monad m) =>
     (a -> m b) -> t a -> m [(a, b)]
onRangeSparseM f ixs = foldrM ins [] ixs where
  ins x acc = do
    y <- f x
    pure (if isNz y then (x, y) : acc else acc)
  


-- | `unfoldZipM` with a pure predicate and a pure mapping function: keep the
-- elements that satisfy the predicate and pair each one with its image.
unfoldZipM0 :: (Foldable t, Monad m) =>
     (a -> Bool) -> (a -> b) -> t a -> m [(a, b)]
unfoldZipM0 q f ixs = unfoldZipM (\x -> return (q x)) (\x -> return (f x)) ixs


-- | Keep the elements of a container for which the monadic predicate yields
-- 'True', and pair each of them with the result of the monadic function.
--
-- The function is only run on the elements that pass the predicate. Elements
-- are visited with `foldrM`, while the resulting list follows the order of
-- the container.
unfoldZipM :: (Foldable t, Monad m) =>
     (a -> m Bool) -> (a -> m b) -> t a -> m [(a, b)]
unfoldZipM q f = foldrM step []
  where
    step x acc =
      q x >>= \keep ->
        if keep
          then (\y -> (x, y) : acc) <$> f x
          else return acc

-- | Post-compose a pure function with a functor-valued one:
-- @combx g f x = fmap g (f x)@.
combx :: Functor f => (a -> b) -> (t -> f a) -> t -> f b
combx g f = fmap g . f


          
    



-- * Helpers


-- meanl :: (Foldable t, Fractional a) => t a -> a
-- meanl xx = 1/fromIntegral (length xx) * sum xx

-- norm2l :: (Foldable t, Functor t, Floating a) => t a -> a
-- norm2l xx = sqrt $ sum (fmap (**2) xx)

-- | Squared difference of the first two elements of a list; any further
-- elements are ignored.
--
-- NB: partial — calls `error` when the list has fewer than two elements.
diffSqL :: Floating a => [a] -> a
diffSqL (x1 : x2 : _) = (x1 - x2) ** 2
diffSqL _ = error "diffSqL: expected a list with at least two elements"


-- | Relative tolerance:
--
-- @relTol a b = norm2 (a ^-^ b) / (1 + min (norm2 a) (norm2 b))@
relTol :: Normed v => v -> v -> Magnitude v
relTol a b = dist / (1 + smallest)
  where
    dist = norm2 (a ^-^ b)
    smallest = min (norm2 a) (norm2 b)


-- | Default divergence criterion: holds when the first argument is strictly
-- greater than the second.
qdiverg :: Ord a => a -> a -> Bool
qdiverg latest previous = latest > previous

-- | `norm2` of the difference between the two elements of a 2-element list;
-- any other list yields @1/0@.
--
-- NOTE(review): there is no type signature; the inferred type depends on the
-- `Normed` class declared elsewhere — confirm before adding one.
norm2Diff xs = case xs of
  [latest, previous] -> norm2 (latest ^-^ previous)
  _ -> 1 / 0






-- test data

-- | Toy state for experiments: a numeric field and a textual label.
data S = S
  { unS1 :: Double
  , unS2 :: String
  } deriving (Eq, Show)

-- | Apply a function to the numeric field of an `S`, leaving the label untouched.
liftS1 :: (Double -> Double) -> S -> S
liftS1 f (S x label) = S (f x) label

-- | Sample initial state.
s0 :: S
s0 = S 1 "blah"
-- | Sample configuration: a budget of 2 iterations with debug printing
-- enabled, monitoring the numeric field of `S` and printing it with `print`.
ic1 :: IterationConfig S Double
ic1 = IterConf
  { numIterationsMax = 2
  , printDebugInfo = True
  , iterationView = unS1
  , printDebugIO = print
  }





-- playground

-- instance MonadThrow m => MonadThrow (WriterT w m) where
--   throwM = lift . throwM

-- -- | iter0 also accepts a configuration, e.g. for optional printing of debug info
-- -- iter0 :: MonadIO m =>
-- --      Int -> (s -> m s) -> (s -> String) -> IterationConfig s -> s -> m s
-- iter0 nmax f sf config x0 = flip runReaderT config $ MTS.execStateT (go (0 :: Int)) x0
--  where
--   go i = do
--     x <- get
--     c <- lift $ asks printDebugInfo  -- neat
--     y <- lift . lift $ f x           -- not neat
--     when c $ liftIO $ putStrLn $ sf y 
--     put y
--     unless (i >= nmax) (go $ i + 1)
 

-- -- | iter1 prints output at every iteration until the loop terminates OR is interrupted by an exception, whichever happens first
-- -- iter1 :: (MonadThrow m, MonadIO m, Typeable t, Show t) =>
-- --      (t -> m t) -> (t -> String) -> (t -> Bool) -> (t -> Bool) -> t -> m t
-- iter1 f wf qe qx x0 = execStateT (go 0) x0 where
--  go i = do
--    x <- get
--    y <- lift $ f x
--    _ <- liftIO $ wf y
--    when (qx y) $ throwM (NotConvergedE "bla" (i+1)  y)
--    put y
--    unless (qe y) $ go (i + 1) 


-- -- | iter2 concatenates output with WriterT but does NOT `tell` any output if an exception is raised before the end of the loop
-- iter2 :: (MonadThrow m, Monoid w, Typeable t, Show t) => (t -> m t)
--      -> (t -> w) -> (t -> Bool) -> (t -> Bool) -> t -> m (t, w)
-- iter2 f wf qe qx x0 = runWriterT $ execStateT (go 0) x0 where
--  go i = do
--    x <- get
--    y <- lift . lift $ f x
--    lift $ tell $ wf y
--    when (qx y) $ throwM (NotConvergedE "bla" (i+1)  y)
--    put y
--    unless (qe y) $ go (i + 1) 


-- -- test :: IO (Int, [String])
-- test :: IO Int
-- -- test :: IO ()
-- test = do
--   (yt, w ) <- iter2 f wf qe qexc x0
--   putStrLn w
--   return yt
--   -- iter1 f wf qe qexc x0
--   where
--     f = pure . (+ 1)
--     wf v = unwords ["state =", show v]
--     qe = (== 5)
--     qexc = (== 3)
--     x0 = 0 :: Int