{-# language FlexibleContexts, GeneralizedNewtypeDeriving, DeriveFunctor #-}
module Control.Iterative where
import Control.Exception.Common
import Numeric.LinearAlgebra.Class
import Numeric.Eps
import Control.Monad.Catch
import Data.Typeable
import Control.Monad (when)
import Control.Monad.State.Strict
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.State.Strict as MTS
import Data.Foldable (foldrM)
-- | Verdict produced after inspecting the window of recent iterates.
data ConvergenceStatus a
  = BufferNotReady  -- ^ too few iterates collected to judge yet
  | Converging      -- ^ window full, no stopping condition met; keep going
  | Converged a     -- ^ stopping condition met; carries the newest summary
  | Diverging a a   -- ^ divergence detected; newer summary, then older summary
  | NotConverged    -- ^ iteration budget exhausted without a verdict
  deriving (Eq, Show)
-- | Settings shared by the iteration drivers in this module.
--
-- The field order matters: values are also built positionally with 'IterConf'.
data IterationConfig a b = IterConf
  { numIterationsMax :: Int         -- ^ iteration budget
  , printDebugInfo   :: Bool        -- ^ print per-iteration diagnostics?
  , iterationView    :: a -> b      -- ^ projection of an iterate that is inspected and printed
  , printDebugIO     :: b -> IO ()  -- ^ printer for the projected iterate
  }
-- | Shows only the budget and the debug flag; the two function-valued fields
-- have no 'Show' instance and are omitted.
instance Show (IterationConfig a b) where
  show cfg =
    unwords
      [ "Max. # of iterations:"
      , show (numIterationsMax cfg)
      , ", print debug information:"
      , show (printDebugInfo cfg)
      ]
-- | Pure-step variant of 'modifyUntilM'.
--
-- Replaces the state with @f@ applied to it (at least once) until @q@ holds
-- on the new state, then returns that state.
modifyUntil :: MonadState s m => (s -> Bool) -> (s -> s) -> m s
modifyUntil q f = modifyUntilM q (\s -> return (f s))
-- | Effectful do-while loop over the state.
--
-- Each round feeds the current state to @f@ and stores the result. The loop
-- stops as soon as @q@ holds on the freshly stored state, which is returned.
-- @f@ always runs at least once.
modifyUntilM :: MonadState s m => (s -> Bool) -> (s -> m s) -> m s
modifyUntilM q f = loop
  where
    loop = do
      next <- f =<< get
      put next
      if q next then pure next else loop
-- | Pure-step variant of 'modifyUntilM'': iterates @f@ from @x0@ until @q@
-- holds, with the optional debug printing configured in @config@.
modifyUntil' :: MonadIO m =>
  IterationConfig a b -> (a -> Bool) -> (a -> a) -> a -> m a
modifyUntil' config q f = modifyUntilM' config q (return . f)
-- | Effectful iteration from @x0@ with optional debug output.
--
-- Applies @f@ repeatedly (at least once) and returns the first iterate on
-- which @q@ holds. When 'printDebugInfo' is set, each iterate is projected
-- with 'iterationView' and handed to 'printDebugIO', preceded by an
-- @Iteration k@ header line.
--
-- NOTE(review): 'numIterationsMax' is never consulted here, so the loop does
-- not terminate if @q@ is never satisfied.
modifyUntilM' :: MonadIO m =>
  IterationConfig a b -> (a -> Bool) -> (a -> m a) -> a -> m a
modifyUntilM' config q f x0 = MTS.execStateT (loop (0 :: Integer)) x0
  where
    pview = iterationView config
    loop k = do
      next <- get >>= lift . f
      when (printDebugInfo config) $ liftIO $ do
        putStrLn $ unwords ["Iteration", show k, "\n"]
        printDebugIO config (pview next)
      put next
      if q next then pure next else loop (k + 1)
-- | Iterate the pure step @f@ from @x0@ against a known reference solution.
--
-- Stops successfully when 'nearZero' holds for 'norm2Diff' of the two most
-- recent views in the inspection window, or when the current view is within
-- 'nearZero' (in 'norm2') of @xKnown@. Divergence is detected with
-- 'qdiverg'. The error behaviour is inherited from 'modifyInspectGuardedM'.
untilConvergedG0 ::
  (Normed v, MonadThrow m, MonadIO m, Typeable (Magnitude v), Typeable s, Show s) =>
  String
  -> IterationConfig s v
  -> v
  -> (s -> s)
  -> s
  -> m s
untilConvergedG0 fname config xKnown f x0 =
  modifyInspectGuarded fname config norm2Diff nearZero qdiverg closeToKnown f x0
  where
    closeToKnown view = nearZero (norm2 (xKnown ^-^ view))
-- | Iterate the pure step @f@ from @x0@ until the iterates settle.
--
-- Success occurs when 'nearZero' holds for 'norm2Diff' of the two most
-- recent views, or when @qfin@ holds on the current view. Divergence uses
-- 'qdiverg'. The error behaviour comes from 'modifyInspectGuardedM'.
untilConvergedG :: (Normed v, MonadThrow m, MonadIO m, Typeable (Magnitude v), Typeable s, Show s) =>
  String
  -> IterationConfig s v
  -> (v -> Bool)
  -> (s -> s)
  -> s
  -> m s
untilConvergedG fname config qfin f x0 =
  modifyInspectGuarded fname config norm2Diff nearZero qdiverg qfin f x0
-- | Effectful counterpart of 'untilConvergedG': the step @f@ runs in @m@.
--
-- Uses the same stopping rules: 'nearZero' on 'norm2Diff' of the two newest
-- views, or @qfin@ on the current view; 'qdiverg' flags divergence.
untilConvergedGM ::
  (Normed v, MonadThrow m, MonadIO m, Typeable (Magnitude v), Typeable s, Show s) =>
  String
  -> IterationConfig s v
  -> (v -> Bool)
  -> (s -> m s)
  -> s
  -> m s
untilConvergedGM fname config qfin f x0 =
  modifyInspectGuardedM fname config norm2Diff nearZero qdiverg qfin f x0
-- | Pure-step variant of 'modifyInspectGuardedM'.
--
-- The step @f@ is lifted into @m@ with 'return'. All other arguments are
-- forwarded unchanged.
modifyInspectGuarded ::
  (MonadThrow m, MonadIO m, Typeable s, Typeable a, Show s, Show a) =>
  String
  -> IterationConfig s v
  -> ([v] -> a)
  -> (a -> Bool)
  -> (a -> a -> Bool)
  -> (v -> Bool)
  -> (s -> s)
  -> s
  -> m s
modifyInspectGuarded fname config sf qc qd qfin f =
  modifyInspectGuardedM fname config sf qc qd qfin (return . f)
-- | Effectful iteration guarded by convergence, divergence and budget checks.
--
-- Starting from @x0@, the step @f@ is applied repeatedly. Every iterate is
-- stored as the state and, if 'printDebugInfo' is set, printed via
-- 'printDebugIO' after projection with 'iterationView'. A window of the last
-- three previous iterates (newest first) drives the checks:
--
-- * Until the window is full, iteration simply continues.
-- * After that, @sf@ summarises the newer pair of views as @qi@ and the
--   older pair as @qt@.
-- * If @qdiv qi qt@ holds and @qconverg qi@ does not, 'DivergingE' is thrown.
-- * If @qconverg qi@ holds, or @qfinal@ holds on the view of the current
--   iterate, that iterate is returned.
-- * Once the counter reaches @numIterationsMax - 1@ without either outcome,
--   'NotConvergedE' is thrown.
--
-- The window is inspected before the current iterate is pushed into it.
-- Because no check runs before the window is full, at least four steps are
-- always performed, even when @numIterationsMax@ is smaller.
--
-- Throws 'NonNegError' when @numIterationsMax@ is not positive.
modifyInspectGuardedM ::
  (MonadThrow m, MonadIO m, Typeable s, Show s, Typeable a, Show a) =>
  String
  -> IterationConfig s v
  -> ([v] -> a)
  -> (a -> Bool)
  -> (a -> a -> Bool)
  -> (v -> Bool)
  -> (s -> m s)
  -> s
  -> m s
modifyInspectGuardedM fname config sf qconverg qdiv qfinal f x0
  | nitermax > 0 = MTS.execStateT (go 0 []) x0
  | otherwise    = throwM (NonNegError fname nitermax)
  where
    lwindow = 3
    nitermax = numIterationsMax config
    pf = iterationView config
    checkConvergStatus y i ll
      | length ll < lwindow = BufferNotReady
      | qdiv qi qt && not (qconverg qi) = Diverging qi qt
      | qconverg qi || qfinal (pf y) = Converged qi
      -- '>=' rather than '==': these guards are only reached once the window
      -- is full (i >= lwindow). If nitermax - 1 < lwindow, an equality test
      -- would never fire and the loop would never stop on budget.
      | i >= nitermax - 1 = NotConverged
      | otherwise = Converging
      where
        llf = pf <$> ll
        qi = sf (init llf)  -- newer pair of views
        qt = sf (tail llf)  -- older pair of views
    go i ll = do
      x <- MTS.get
      y <- lift (f x)
      when (printDebugInfo config) $ liftIO $ do
        putStrLn $ unwords ["Iteration", show i]
        printDebugIO config (pf y)
      MTS.put y
      case checkConvergStatus y i ll of
        BufferNotReady  -> go (i + 1) (y : ll)
        -- push the new iterate and drop the oldest, keeping lwindow entries
        Converging      -> go (i + 1) (init (y : ll))
        Converged _     -> pure ()
        Diverging qi qt -> throwM (DivergingE fname i qi qt)
        NotConverged    -> throwM (NotConvergedE fname nitermax y)
-- | Tabulate @f@ over the given indices, keeping only the non-zero results
-- (according to 'isNz') together with their index, in input order.
--
-- @f i@ is bound once per index, so the value tested by 'isNz' is the same
-- shared value that is stored. The previous version wrote @f x@ twice.
onRangeSparse :: Epsilon b => (Int -> b) -> [Int] -> [(Int, b)]
onRangeSparse f ixs = [(i, y) | i <- ixs, let y = f i, isNz y]
-- | Monadic 'onRangeSparse': runs @f@ on every element and keeps the pairs
-- @(x, f x)@ whose result is non-zero according to 'isNz'.
--
-- @f@ runs exactly once per element. The previous version ran it a second
-- time for every kept element, which duplicated its effects and work. As
-- with 'foldrM', effects are sequenced starting from the last element.
onRangeSparseM :: (Epsilon b, Foldable t, Monad m) =>
  (a -> m b) -> t a -> m [(a, b)]
onRangeSparseM f ixs = foldrM keepNz [] ixs
  where
    keepNz x acc = do
      y <- f x
      pure $ if isNz y then (x, y) : acc else acc
-- | 'unfoldZipM' with a pure predicate and a pure pairing function. Each is
-- lifted into @m@ with 'pure'.
unfoldZipM0 :: (Foldable t, Monad m) =>
  (a -> Bool) -> (a -> b) -> t a -> m [(a, b)]
unfoldZipM0 q f xs = unfoldZipM (\x -> pure (q x)) (\x -> pure (f x)) xs
-- | Effectful filter-and-pair.
--
-- For every element @x@ for which @q x@ yields 'True', the result contains
-- @(x, y)@ with @y@ produced by @f x@. Element order is preserved.
-- @f@ only runs for kept elements. Effects are sequenced by 'foldrM', i.e.
-- starting from the last element.
unfoldZipM :: (Foldable t, Monad m) =>
  (a -> m Bool) -> (a -> m b) -> t a -> m [(a, b)]
unfoldZipM q f ixs = foldrM step [] ixs
  where
    step x acc =
      q x >>= \keep ->
        if keep
          then (\y -> (x, y) : acc) <$> f x
          else pure acc
-- | Post-compose a pure function onto a functor-valued function:
-- @combx g f x == fmap g (f x)@.
combx :: Functor f => (a -> b) -> (t -> f a) -> t -> f b
combx g f = fmap g . f
-- | Residual norm relative to the norm of @b@:
-- @norm2 ((aa #> x) ^-^ b) / norm2 b@.
--
-- NOTE(review): presumably the relative residual ‖A x − b‖ / ‖b‖, with @#>@
-- being matrix–vector application; confirm. There is no guard against a
-- zero @b@.
relRes :: (Normed t, LinearVectorSpace t) =>
  MatrixType t -> t -> t -> Magnitude t
relRes aa b x = norm2 residual / norm2 b
  where
    residual = (aa #> x) ^-^ b
-- | Squared difference of the first two elements of a list.
--
-- Any elements after the second are ignored. A list with fewer than two
-- elements is a caller error and is reported with a descriptive message
-- (previously it failed inside 'head' / '!!').
--
-- The square is computed as @d * d@ rather than @d ** 2@. The default
-- 'Floating' method for '**' goes through 'exp' and 'log', which breaks for
-- negative @d@.
diffSqL :: Floating a => [a] -> a
diffSqL (x1 : x2 : _) = d * d
  where
    d = x1 - x2
diffSqL _ = error "diffSqL: list must contain at least two elements"
-- | Distance between @a@ and @b@ scaled by one plus the smaller of their
-- norms: @norm2 (a ^-^ b) / (1 + min (norm2 a) (norm2 b))@.
-- The added @1@ keeps the denominator away from zero.
relTol :: Normed v => v -> v -> Magnitude v
relTol a b = distance / scale
  where
    distance = norm2 (a ^-^ b)
    scale = 1 + min (norm2 a) (norm2 b)
-- | Divergence test used by the @untilConverged*@ drivers. It holds when the
-- newer summary (first argument) is strictly larger than the older one.
qdiverg :: Ord a => a -> a -> Bool
qdiverg newer older = newer > older
-- | 'norm2' distance between the two views of a two-element window.
-- Any other list length yields @1/0@, so the result never looks converged.
norm2Diff xs = case xs of
  [newer, older] -> norm2 (newer ^-^ older)
  _              -> 1 / 0
-- | Small example state: a number together with a text label.
data S = S
  { unS1 :: Double
  , unS2 :: String
  }
  deriving (Eq, Show)
-- | Apply a function to the numeric component of an 'S', keeping its label.
liftS1 :: (Double -> Double) -> S -> S
liftS1 f (S x i) = S (f x) i
-- | Example initial state.
s0 :: S
s0 = S 1 "blah"
-- | Example configuration: budget of 2 iterations, debug output on. The
-- numeric component of 'S' is inspected and printed with 'print'.
ic1 :: IterationConfig S Double
ic1 = IterConf 2 True unS1 print