----------------------------------------------------------------------------
-- |
-- Module      :  Math.LinearRecursive.Monad
-- Copyright   :  (c) Bin Jin 2011
-- License     :  BSD3
--
-- Maintainer  :  bjin1990+haskell@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- A monad for computing linear recursive sequences efficiently. Matrix
-- multiplication and fast exponentiation are used to speed up computing
-- the element at a particular index of the sequence. This library also
-- provides a monadic DSL to describe the sequence.
--
-- As an example, here is the Fibonacci sequence:
-- 
-- >fib = do
-- >    [f0, f1] <- newVariables [1, 1]
-- >    f0 <:- f0 <+> f1
-- >    return f1
-- 
-- >>> map (runLinearRecursive fib) [0..10]
-- [1,1,2,3,5,8,13,21,34,55,89]
-- 
----------------------------------------------------------------------------

module Math.LinearRecursive.Monad
  (
  -- * Vector
  -- ** vector types

   VectorLike(..)
  , LinearCombination
  , Variable
  -- ** vector operators and constant
  , (<+>)
  , (<->)
  , (<*)
  , (*>)
  , zeroVector
  -- * Polynomial
  , Polynomial
  , P.x
  -- * Monad
  , LinearRecursive
  , newVariable
  , newVariables
  , (<:-)
  , (<+-)
  , runLinearRecursive
  , simulateLinearRecursive
  -- * Utility
  , getConstant
  , getPartialSum
  , getStep
  , getPowerOf
  , getPolynomial
  , getPartialSumWith
 ) where

-- Since base 4.8 the Prelude exports Applicative's (<*) and (*>), which clash
-- with the vector operators of the same name; hide the Prelude versions.
import Prelude hiding ((<*), (*>))

import Control.Monad (zipWithM_)
import Control.Applicative (Applicative(pure, (<*>)), (<$>))

import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap

import Math.LinearRecursive.Internal.Vector
import Math.LinearRecursive.Internal.Matrix
import Math.LinearRecursive.Internal.Polynomial hiding (fromList, toList, x)
import qualified Math.LinearRecursive.Internal.Polynomial as P

-- | A vector representing a linear combination of several variables.
type LinearCombination = Vector

-- | A unit vector representing dependence on a single variable, as returned
-- by 'newVariable'.
type Variable = Vector1

-- | Bookkeeping for one declared variable: its value before step 0, and the
-- linear combination it is assigned after every step.
data LRVariable a = LRV
    { initialValue :: a
    , dependency   :: LinearCombination a
    }

-- | Transform the dependency of a variable, keeping its initial value.
dmap :: Num a => (LinearCombination a -> LinearCombination a) -> LRVariable a -> LRVariable a
dmap f var = var { dependency = f (dependency var) }

-- | Table of all declared variables, keyed by variable index.
type LRVariables a = IntMap (LRVariable a)

-- | The monad used to specify how the next element of a linear recursive
-- sequence is computed.
--
-- Every linear recursive sequence can be generated by iteration: the next
-- element is a linear combination of some previous elements. Each iteration
-- is therefore a linear transformation of the state, i.e. a multiplication
-- by a transform matrix.
--
-- To formalize and simplify this procedure, this monad uses mutable-like
-- variables to denote the state, and mutable-like assignments to denote the
-- transform matrix.
--
-- Conceptually, the monad is simulated step by step, and all variables are
-- updated after each step. If the monad returns a 'LinearCombination', a
-- number is produced at each step. (The actual evaluation in
-- 'runLinearRecursive' uses fast exponentiation of the transform matrix
-- instead of stepping one step at a time.)
newtype LinearRecursive a b = LR
    { unLR :: Int -> (b, Int, LRVariables a -> LRVariables a)
      -- ^ Given the number of variables declared so far (which is also the
      -- index the next declared variable receives), returns the result
      -- value, the number of variables declared by this action, and the
      -- changes this action applies to the variable table.
    }

instance Num a => Functor (LinearRecursive a) where
    -- Only the result changes; the declared-variable count and the table
    -- updates pass through untouched.
    fmap f m = LR $ \v -> let (r, nv, changes) = unLR m v in (f r, nv, changes)

-- Required as a superclass of 'Monad' since GHC 7.10 (base 4.8).
instance Num a => Applicative (LinearRecursive a) where
    -- Declares no variables and leaves the variable table unchanged.
    pure a = LR (const (a, 0, id))
    mf <*> mx = mf >>= \f -> fmap f mx

instance Num a => Monad (LinearRecursive a) where
    return = pure
    -- The second action starts numbering its variables after those declared
    -- by the first, and its table changes are applied after the first's.
    a >>= b = LR $ \v -> let (ra, nva, ma) = unLR a v
                             (rb, nvb, mb) = unLR (b ra) (v + nva)
                         in
                             (rb, nva + nvb, mb . ma)

-- | Declare a new variable with the given initial value (its value before
-- step 0).
--
-- A variable that is never assigned becomes zero after the first step.
--
-- >test = do
-- >    v <- newVariable 1
-- >    v <:- v <+> v
-- >    return v
--
-- >>> map (runLinearRecursive test) [0..10]
-- [1,2,4,8,16,32,64,128,256,512,1024]

newVariable :: (Eq a, Num a) => a -> LinearRecursive a (Variable a)
newVariable val0 = LR declare
  where
    -- The next free index becomes the new variable; exactly one slot is used.
    declare idx = (vector1 idx, 1, IntMap.insert idx fresh)
    fresh = LRV { initialValue = val0
                , dependency   = zeroVector
                }

-- | Declare a sequence of new variables with the given initial values.
--
-- After each step, every variable except the first is assigned the value
-- its predecessor had before that step, so the variables behave like a shift
-- register. An empty list declares no variables.
--
-- Assigning to any variable other than the first is discouraged: a later
-- '<:-' overrides the shift assignment.
--
-- >test = do
-- >     [v1, v2, v3] <- newVariables [1,2,3]
-- >     v1 <:- v3
-- >     return v3
--
-- >>> map (runLinearRecursive test) [0..10]
-- [3,2,1,3,2,1,3,2,1,3,2]

newVariables :: (Eq a, Num a) => [a] -> LinearRecursive a [Variable a]
newVariables vals = do
    vars <- mapM newVariable vals
    -- 'drop 1' rather than 'tail': for an empty list, 'tail' would crash as
    -- soon as a later action needs this action's variable count.
    zipWithM_ (<:-) (drop 1 vars) vars
    return vars

-- | Return a constant number at every step. Uses one extra variable, which
-- starts at 1 and is reassigned to itself every step.
--
-- >>> map (runLinearRecursive (getConstant 3)) [0..10]
-- [3,3,3,3,3,3,3,3,3,3,3]
getConstant :: (Eq a, Num a) => a -> LinearRecursive a (LinearCombination a)
getConstant val =
    newVariable 1 >>= \unit ->
        (unit <:- unit) >> return (toVector unit *> val)

-- | Return the sum of a linear combination over all steps strictly before
-- the current one. Uses one extra variable as the accumulator.
--
-- >>> map (runLinearRecursive (getConstant 3 >>= getPartialSum)) [0..10]
-- [0,3,6,9,12,15,18,21,24,27,30]
getPartialSum :: (Eq a, Num a) => LinearCombination a -> LinearRecursive a (LinearCombination a)
getPartialSum val =
    newVariable 0 >>= \acc ->
        (acc <:- acc <+> val) >> return (toVector acc)

-- | Return the current step number. Uses two extra variables: one for the
-- constant 1 and one for its running sum.
--
-- >>> map (runLinearRecursive getStep) [0..10]
-- [0,1,2,3,4,5,6,7,8,9,10]
getStep :: (Eq a, Num a) => LinearRecursive a (LinearCombination a)
getStep = do
    ones <- getConstant 1
    getPartialSum ones

-- | @getPowerOf a@ returns @a@ raised to the current step number.
-- Uses one extra variable, which starts at 1 and is multiplied by @a@
-- every step.
--
-- >>> map (runLinearRecursive (getPowerOf 3)) [0..10]
-- [1,3,9,27,81,243,729,2187,6561,19683,59049]
getPowerOf :: (Eq a, Num a) => a -> LinearRecursive a (LinearCombination a)
getPowerOf a =
    newVariable 1 >>= \acc ->
        (acc <:- acc *> a) >> return (toVector acc)

-- | Given @n@ polynomials where the @i@-th (0-indexed) has degree @i@ and
-- leading coefficient 1, return the matrix whose row @i@ expresses @x^i@ as
-- a linear combination of those polynomials.
--
-- Row @i@ of the matrix being inverted holds the coefficients of the @i@-th
-- polynomial. Under the precondition above it is lower triangular with ones
-- on the diagonal, presumably the case 'inverseMatrixDiag1' is written for
-- (NOTE(review): confirm against its definition).
inverseTrans :: (Eq a, Num a) => [Polynomial a] -> Matrix a
inverseTrans polys = inverseMatrixDiag1 (matrix coeffRows)
  where
    size = length polys
    coeffRows = map (\p -> map (vcomponent (unPoly p)) [0 .. size - 1]) polys

-- | @getPartialSumWith poly v@ returns, at step @t@, the sum of the values of
-- @v@ over steps @0..t@ (inclusive). The value at step @j@ is weighted by
-- @poly@ evaluated at @t - j@:
--
-- > sum [ eval poly (t - j) * v_j | j <- [0..t] ]
--
-- Here @eval@ and @v_j@ are notation only, not functions of this module.
--
-- Uses @degree poly + 1@ extra variables. When @degree poly < 0@ it uses
-- none and returns 'zeroVector'.
--
-- >>> map (runLinearRecursive (getConstant 1 >>= getPartialSumWith 1)) [0..5]
-- [1,2,3,4,5,6]
getPartialSumWith :: (Eq a, Num a, VectorLike v) => Polynomial a -> v a -> LinearRecursive a (LinearCombination a)
getPartialSumWith poly v
    | n < 0     = return zeroVector
    | otherwise = do
        -- basisValues !! i is the weighted sum using weight basisPolys !! i.
        basisValues <- go (toVector v) 0
        -- Re-express each monomial x^i through the basis polynomials.
        let monomialSums = map (sumVectors . zipWith (*>) basisValues) trans
        return $ sumVectors [ powi *> coeffi
                            | (i, powi) <- zip [0..] monomialSums
                            , let coeffi = vcomponent coeffs i
                            ]
  where
    n = degree poly
    -- [1, x+1, (x+1)(x+2), ..., (x+1)...(x+n)]: the i-th has degree i and
    -- leading coefficient 1, as 'inverseTrans' requires.
    basisPolys = scanl (*) 1 [P.x + fromIntegral i | i <- [1..n]]

    -- Level @pos@ is the inclusive running sum of the previous level (the
    -- partial sum of earlier steps plus the current value), scaled by
    -- @max pos 1@. This turns the weight (x+1)...(x+pos-1) into
    -- (x+1)...(x+pos), with x = t - j. Each level declares one variable.
    go prev pos | pos > n   = return []
                | otherwise = do
                    next <- (*> fromIntegral (pos `max` 1)) . (<+> prev) <$> getPartialSum prev
                    (next:) <$> go next (pos + 1)

    trans = unMatrix (inverseTrans basisPolys)
    coeffs = unPoly poly
    sumVectors vs = foldl (<+>) zeroVector vs

-- | @getPolynomial poly@ evaluates @poly@ with @x@ replaced by the current
-- step number.
--
-- Uses @degree poly + 2@ extra variables. One is the seed variable; the
-- other @degree poly + 1@ are declared by 'getPartialSumWith'. When
-- @degree poly < 0@, only the seed is declared.
--
-- >>> map (runLinearRecursive (getPolynomial ((x+1)^2))) [0..10]
-- [1,4,9,16,25,36,49,64,81,100,121]
getPolynomial :: (Eq a, Num a) => Polynomial a -> LinearRecursive a (LinearCombination a)
getPolynomial poly = do
    -- Never assigned, so it is 1 at step 0 and 0 afterwards (a unit
    -- impulse). The weighted partial sum then reduces to poly at the step.
    impulse <- newVariable 1
    getPartialSumWith poly impulse


-- | Variable accumulated assignment. @v \<+- a@ adds @a@ to the right-hand
-- side of the pending assignment to @v@, i.e. to the value @v@ takes after
-- this step.
--
-- Be aware that this right-hand side starts as zero. Without a prior
-- @v \<:- ...@, @v \<+- a@ alone makes @v@ become @a@, not @v \<+\> a@.
(<+-) :: (Eq a, Num a, VectorLike v) => Variable a -> v a -> LinearRecursive a ()
var <+- dep = LR (\_ -> ((), 0, IntMap.adjust extend (unVector1 var)))
  where
    extend = dmap (<+> toVector dep)

-- | Variable assignment. @v \<:- a@ makes @v@ take the value of @a@ after
-- this step, where @a@ is computed from the variable values before the step.
-- If there are multiple assignments to one variable, only the last one
-- counts; it also discards any earlier '<+-' accumulation.
(<:-) :: (Eq a, Num a, VectorLike v) => Variable a -> v a -> LinearRecursive a ()
var <:- dep = LR (\_ -> ((), 0, IntMap.adjust overwrite (unVector1 var)))
  where
    overwrite = dmap (const (toVector dep))

infix 1 <:-,<+-

-- | Build the transform matrix and the initial column from the variable
-- table. Row @k@ of the transform holds the coefficients of variable @k@'s
-- dependency. Row @k@ of the column holds variable @k@'s initial value.
-- Coefficients are read for indices @0 .. count-1@, matching how
-- 'newVariable' numbers variables from 0.
buildMatrix :: (Eq a, Num a) => LRVariables a -> (Matrix a, Matrix a)
buildMatrix mapping = (matrix rows, matrix column)
  where
    vars = IntMap.elems mapping
    column = [[initialValue var] | var <- vars]
    indices = [0 .. length vars - 1]
    rows = [ map (\i -> IntMap.findWithDefault 0 i coeffs) indices
           | var <- vars
           , let coeffs = unVector' (dependency var)
           ]

-- | /O(v^3 * log n)/, where /v/ is the number of variables, and /n/ is steps to simulate.
--
-- @runLinearRecursive m n@ simulates the monad for @n@ steps and returns the
-- value denoted by the returned 'LinearCombination'.
--
-- @n@ must be non-negative; a negative @n@ raises an error.
runLinearRecursive :: (Eq a, Num a, Integral b, VectorLike v) => LinearRecursive a (v a) -> b -> a
runLinearRecursive m steps
    | steps < 0 = error "runLinearRecursive: steps must be non-negative"
    | otherwise = sum (map term (IntMap.toList (unVector' target)))
  where
    (target, _, changes) = unLR m 0
    (trans, initCol) = buildMatrix (changes IntMap.empty)
    -- Column matrix: row i holds the value of variable i after the steps.
    finalCol = unMatrix' (trans ^ steps * initCol)
    term (i, ai) = head (finalCol !! i) * ai

-- | /O(v^2 * n)/ for the first /n/ elements. Similar to 'runLinearRecursive',
-- but returns the infinite list of values at steps 0, 1, 2, ... instead of
-- the value at a particular index.
simulateLinearRecursive :: (Eq a, Num a, VectorLike v) => LinearRecursive a (v a) -> [a]
simulateLinearRecursive m = [valueAt col | col <- states]
  where
    (target, _, changes) = unLR m 0
    (trans, initCol) = buildMatrix (changes IntMap.empty)
    weights = IntMap.toList (unVector' target)
    -- State column after 0, 1, 2, ... steps.
    states = map unMatrix' (iterate (trans *) initCol)
    valueAt col = sum [head (col !! i) * ai | (i, ai) <- weights]