-- |
-- Module      :  Math.LinearRecursive.Monad
-- Copyright   :  (c) Bin Jin 2011
-- License     :  BSD3
-- Maintainer  :  bjin1990+haskell@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- A monad to calculate linear recursive sequences efficiently. Matrix
-- multiplication and a fast exponentiation algorithm are used to speed
-- up calculating the number with a particular index in the sequence. This
-- library also provides a monadic DSL to describe the sequence.
--
-- As an example, here is the Fibonacci sequence:
--
-- >fib = do
-- >    [f0, f1] <- newVariables [1, 1]
-- >    f0 <:- f0 <+> f1
-- >    return f1
--
-- >>> map (runLinearRecursive fib) [0..10]
-- [1,1,2,3,5,8,13,21,34,55,89]

module Math.LinearRecursive.Monad
  (
  -- * Vector
  -- ** vector types
    LinearCombination
  , Variable
  -- ** vector operators and constant
  , (<+>)
  , (<->)
  , (<*)
  , (*>)
  , zeroVector
  -- * Monad
  , LinearRecursive
  , newVariable
  , newVariables
  , (<:-)
  , (<+-)
  , runLinearRecursive
  -- * Utility
  , getConstant
  , getPartialSum
  , getStep
  , getPowerOf
  ) where

import Prelude hiding ((<*), (*>))

import Control.Monad (zipWithM_)
import qualified Control.Monad.Fail as Fail

import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap

import Math.LinearRecursive.Internal.Vector
import Math.LinearRecursive.Internal.Matrix

-- | A vector representing a linear combination of several variables.
-- This is the type of the right-hand side of assignments and of the value
-- returned from a 'LinearRecursive' description.
type LinearCombination = Vector

-- | A unit vector representing dependence on a single variable.
-- Values of this type are created by 'newVariable' and are the targets of
-- the assignment operators '<:-' and '<+-'.
type Variable = Vector1

-- Internal state of one declared variable.
data LRVariable a = LRV
    { initialValue :: a                    -- value before step 0
    , dependency   :: LinearCombination a  -- value after a step, in terms of the current variables
    }

-- Transform the dependency of a variable, leaving its initial value untouched.
dmap :: Num a => (LinearCombination a -> LinearCombination a) -> LRVariable a -> LRVariable a
dmap f var = var { dependency = f (dependency var) }

-- Every declared variable, keyed by the index 'newVariable' assigned to it.
type LRVariables a = IntMap (LRVariable a)

-- | The monad used to describe how the next state of a linear recursive
-- sequence is computed.
--
-- Every linear recursive sequence can be generated by iteration: the next
-- number is a linear combination of some previous numbers. Each iteration is
-- therefore a linear transformation between states, i.e. a multiplication by
-- a transformation matrix. To formalize and simplify this procedure, this
-- monad uses mutable-like variables to denote the state, and mutable-like
-- assignments to denote the transformation matrix.
--
-- Conceptually the description is simulated step by step, and after each
-- step all variables are updated; if the monad returns a
-- 'LinearCombination', it denotes one number per step. (The actual
-- calculation uses fast exponentiation of the transformation matrix to
-- speed this up.)
--
-- A 'newtype' is used because this is a single-field wrapper: it has no
-- runtime cost, and the constructor is not exported.
newtype LinearRecursive a b = LR { unLR :: Int -> (b, Int, LRVariables a -> LRVariables a) }

-- @unLR m prevDeclaredVars = (result, newlyDeclaredVars, changes to variables)@:
-- given how many variables were declared before @m@, running @m@ yields its
-- result, the number of variables @m@ itself declares, and the updates @m@
-- applies to the variable table.

-- 'Functor' and 'Applicative' are superclasses of 'Monad' in modern GHC, so
-- they must exist for the 'Monad' instance to compile. Both thread the
-- declared-variable counter and compose the variable updates exactly like
-- '>>=' does.
instance Functor (LinearRecursive a) where
    fmap f m = LR $ \v -> let (r, nv, g) = unLR m v
                          in (f r, nv, g)

instance Applicative (LinearRecursive a) where
    pure a = LR (const (a, 0, id))
    mf <*> mx = LR $ \v -> let (f, nvf, gf) = unLR mf v
                               -- mx sees the variables mf declared.
                               (x, nvx, gx) = unLR mx (v + nvf)
                           in (f x, nvf + nvx, gx . gf)

instance Num a => Monad (LinearRecursive a) where
    return = pure
    a >>= b = LR $ \v -> let (ra, nva, ma) = unLR a v
                             -- The continuation's variables are numbered
                             -- after the ones declared by @a@.
                             (rb, nvb, mb) = unLR (b ra) (v + nva)
                             -- Updates from @a@ are applied before those of @b ra@.
                         in (rb, nva + nvb, mb . ma)

-- Needed for refutable binds such as @[f0, f1] <- newVariables [1, 1]@,
-- which the documented examples use; a mismatch is a programming error.
instance Num a => Fail.MonadFail (LinearRecursive a) where
    fail = error

-- | Declare a new variable with its initial value (the value before step 0).
--
-- The variable's pending assignment starts as 'zeroVector'. Unless it is
-- assigned with '<:-' (or accumulated into with '<+-'), its value after a
-- step is therefore zero.
--
-- >test = do
-- >    v <- newVariable 1
-- >    v <:- v <+> v
-- >    return v
--
-- >>> map (runLinearRecursive test) [0..10]
-- [1,2,4,8,16,32,64,128,256,512,1024]
newVariable :: Num a => a -> LinearRecursive a (Variable a)
-- @v@ is the number of variables declared so far. Using it as the index
-- allocates indices densely from 0, and this declares exactly one more.
newVariable val0 = LR $ \v -> (vector1 v, 1, IntMap.insert v variable)
  where
    variable = LRV { initialValue = val0, dependency = zeroVector }

-- | Declare a sequence of new variables with their initial values.
--
-- After each step, every variable except the first one is assigned the value
-- its predecessor had before that step. Assigning to any variable other than
-- the first one is discouraged: a later '<:-' replaces this shifting
-- assignment.
--
-- >test = do
-- >     [v1, v2, v3] <- newVariables [1,2,3]
-- >     v1 <:- v3
-- >     return v3
--
-- >>> map (runLinearRecursive test) [0..10]
-- [3,2,1,3,2,1,3,2,1,3,2]
newVariables :: Num a => [a] -> LinearRecursive a [Variable a]
newVariables vals = do
    ret <- mapM newVariable vals
    -- Shift: each variable receives its predecessor's value. 'drop 1' is
    -- total (unlike 'tail'), so an empty list of initial values is accepted.
    zipWithM_ (<:-) (drop 1 ret) ret
    return ret

-- | Return a constant number, independent of the current step.
--
-- >>> map (runLinearRecursive (getConstant 3)) [0..10]
-- [3,3,3,3,3,3,3,3,3,3,3]
getConstant :: Num a => a -> LinearRecursive a (LinearCombination a)
getConstant val = newVariable 1 >>= scaled
  where
    -- A variable assigned to itself never changes, so it stays 1 forever.
    -- Scaling it by @val@ gives the requested constant.
    scaled one = (one <:- one) >> return (toVector one *> val)

-- | Return the running sum of a linear combination over all steps before the
-- current one (so the value at step 0 is 0).
--
-- >>> map (runLinearRecursive (getConstant 3 >>= getPartialSum)) [0..10]
-- [0,3,6,9,12,15,18,21,24,27,30]
getPartialSum :: Num a => LinearCombination a -> LinearRecursive a (LinearCombination a)
getPartialSum val = newVariable 0 >>= accumulate
  where
    -- Every step adds the current value of @val@ to the total.
    accumulate total = (total <:- total <+> val) >> return (toVector total)

-- | Return the current step number, computed as the running sum of the
-- constant 1.
--
-- >>> map (runLinearRecursive getStep) [0..10]
-- [0,1,2,3,4,5,6,7,8,9,10]
getStep :: Num a => LinearRecursive a (LinearCombination a)
getStep = getPartialSum =<< getConstant 1

-- | @getPowerOf a@ returns @a@ raised to the current step number.
--
-- >>> map (runLinearRecursive (getPowerOf 3)) [0..10]
-- [1,3,9,27,81,243,729,2187,6561,19683,59049]
getPowerOf :: Num a => a -> LinearRecursive a (LinearCombination a)
getPowerOf a = newVariable 1 >>= multiplyEachStep
  where
    -- Starting from 1 and multiplying by @a@ once per step yields a^step.
    multiplyEachStep power = (power <:- power *> a) >> return (toVector power)

-- | Variable accumulated assignment. @v \<+- a@ adds @a@ to the value that
-- will be assigned to @v@ after this step: a pending assignment @e@ becomes
-- @e \<+\> a@.
-- Be aware that the pending assignment is zero until @v@ is assigned, so
-- without a preceding '<:-' this does not add to the current value of @v@.
(<+-) :: (Num a, VectorLike v) => Variable a -> v a -> LinearRecursive a ()
var <+- dep = LR $ \_ -> ((), 0, IntMap.adjust extend (unVector1 var))
  where
    extend = dmap (<+> toVector dep)

-- | Variable assignment. @v \<:- a@ sets @v@ to @a@ after this step.
-- When a variable is assigned several times, the last assignment wins.
(<:-) :: (Num a, VectorLike v) => Variable a -> v a -> LinearRecursive a ()
var <:- dep = LR $ \_ -> ((), 0, IntMap.adjust overwrite (unVector1 var))
  where
    overwrite = dmap (const (toVector dep))

infix 1 <:-, <+-

-- | Turn the declared variables into the transformation matrix and the column
-- of initial values.
--
-- Row @i@ of the transformation matrix holds the coefficients of variable
-- @i@'s dependency. Multiplying it by the column of current values therefore
-- gives the values after one step. Rows follow 'IntMap' key order.
-- 'newVariable' allocates keys densely from 0, so @[0 .. varCount - 1]@
-- enumerates the columns in the same order.
buildMatrix :: Num a => LRVariables a -> (Matrix a, Matrix a)
buildMatrix mapping = (matrix trans, matrix $ map (: []) initValues)
  where
    vars       = IntMap.elems mapping
    initValues = map initialValue vars
    rawDep     = map (unVector' . dependency) vars
    varCount   = length initValues
    -- Coefficients absent from a sparse dependency are zero.
    trans      = map (\m -> [IntMap.findWithDefault 0 i m | i <- [0 .. varCount - 1]]) rawDep

-- | /O(v^3 * log n)/, where /v/ is the number of variables and /n/ is the
-- number of steps to simulate.
-- @runLinearRecursive m n@ simulates the monad for @n@ steps and returns the
-- actual value denoted by the returned 'LinearCombination'.
-- @n@ must be non-negative; a negative @n@ raises an error.
runLinearRecursive :: (Num a, Integral b, VectorLike v) => LinearRecursive a (v a) -> b -> a
runLinearRecursive m steps
    | steps < 0 = error "runLinearRecursive: steps must be non-negative"
    | otherwise = sum [head (res !! i) * ai | (i, ai) <- IntMap.assocs (unVector' target)]
  where
    -- Run the description with no previously declared variables, then
    -- apply all of its updates to an empty variable table.
    (target, _, g) = unLR m 0
    dep = g IntMap.empty
    (trans, initCol) = buildMatrix dep

    -- Values of all variables after @steps@ steps: a column with one
    -- single-element row per variable, indexed by variable number.
    res = unMatrix' (trans ^ steps * initCol)