{-# LANGUAGE TypeOperators #-}
{- |
Do not import this module. It is only for demonstration purposes.
-}
module Numeric.LAPACK.Example.EconomicAllocation where

import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Square ((|=|))
import Numeric.LAPACK.Matrix (ShapeInt, shapeInt, (#-#), (#*|), (#\|), (\\#))
import Numeric.LAPACK.Vector ((|-|))
import Numeric.LAPACK.Format ((##))

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((::+)((::+)))

import qualified Data.Stream as Stream


{- $setup
>>> import Numeric.LAPACK.Example.EconomicAllocation
>>> import Test.Utility (approxVector)
>>>
>>> import qualified Numeric.LAPACK.Vector as Vector
>>> import Numeric.LAPACK.Vector ((+++))
>>>
>>> import qualified Data.Array.Comfort.Storable as Array
-}


-- | Two 'ShapeInt' shapes joined with @::+@ (a left part and a right part).
type ZeroInt2 = (::+) ShapeInt ShapeInt
-- | Vector of 'Double's indexed by @shape@.
type Vector shape = Vector.Vector shape Double
-- | General (rectangular) matrix of 'Double's with the given row and column shapes.
type Matrix rows cols = Matrix.General rows cols Double
-- | Square matrix of 'Double's over the shape @sh@.
type SquareMatrix sh = Square.Square sh Double


-- | Example starting balances on the 2+2 shape 'ZeroInt2'.
--
-- NOTE(review): presumably the first two values fill the left part and the
-- last two the right part of the shape; confirm against 'Vector.fromList'.
balances0 :: Vector ZeroInt2
balances0 = Vector.fromList balanceShape [100000, 90000, -50000, -120000]
   where
      balanceShape = shapeInt 2 ::+ shapeInt 2

-- | Example expense matrix with two rows ('ShapeInt' of size 2) and four
-- columns (the 2+2 shape 'ZeroInt2').
--
-- The eight values are laid out below as two rows of four.
-- NOTE(review): this assumes 'Matrix.fromList' reads row-major; confirm.
expenses0 :: Matrix ShapeInt ZeroInt2
expenses0 =
   Matrix.fromList (shapeInt 2) (shapeInt 2 ::+ shapeInt 2)
      [16000,  4000,  8000, 12000,
       10000, 30000, 40000, 20000]

-- | Divide every row of a matrix by the sum of that row, so that each row of
-- the result sums to one.
--
-- NOTE(review): a row whose sum is zero would produce non-finite entries;
-- the example expense rows all have non-zero sums.
normalize ::
   (Eq height, Shape.C height, Shape.C width) =>
   Matrix height width -> Matrix height width
normalize matrix =
   let totals = Matrix.rowSums matrix
   -- scale each row by the reciprocal of the matching entry of 'totals'
   in totals \\# matrix

-- | Row-normalise the expense matrix with 'normalize', transpose it, and split
-- the transposed rows along @sh0::+sh1@.
--
-- The @sh0@ rows form the first component. The @sh1@ rows form the second
-- component, returned as a square matrix.
normalizeSplit ::
   (Shape.C sh0, Shape.C sh1, Eq sh1) =>
   Matrix sh1 (sh0::+sh1) -> (Matrix sh0 sh1, SquareMatrix sh1)
normalizeSplit expenses =
   (Matrix.takeTop shares, Square.fromFull (Matrix.takeBottom shares))
   where
      shares = Matrix.transpose (normalize expenses)


-- | Assemble the square block matrix
--
-- > ( I  p )
-- > ( 0  k )
--
-- where @(p, k) = normalizeSplit x@. Here @I@ is the identity on @sh0@ and
-- @0@ is a zero block shaped like the transpose of @p@ (@sh1@ by @sh0@).
completeIdSquare ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1) =>
   Matrix sh1 (sh0::+sh1) -> SquareMatrix (sh0::+sh1)
completeIdSquare x =
   (Square.identityFromHeight top, top)
      |=|
   (Matrix.zero (ArrMatrix.shape (Matrix.transpose top)), bottom)
   where
      (top, bottom) = normalizeSplit x

-- | Repeatedly multiply the balance vector by @completeIdSquare expenses@.
-- Return the first iterate, starting with the input vector itself, whose
-- right (@sh1@) part has an infinity norm below @1e-5@.
--
-- Does not terminate if the right part never drops below that threshold.
iterated ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1) =>
   Matrix sh1 (sh0::+sh1) -> Vector (sh0::+sh1) -> Vector (sh0::+sh1)
iterated expenses =
   -- bound outside the lambda so the transition matrix is built only once
   -- per partial application
   let transition = completeIdSquare expenses
   in \start ->
         Stream.head $
         Stream.dropWhile
            (\v -> Vector.normInf (Vector.takeRight v) >= 1e-5) $
         Stream.iterate (transition #*|) start


-- | Closed-form counterpart of 'iterated', restricted to the left (@sh0@) part.
--
-- With @x@ and @y@ the left and right parts of the balances and
-- @(p, k) = normalizeSplit expenses@, the result is
-- @x - p * (k - I)^-1 * y@, i.e. @x + p * (I - k)^-1 * y@.
--
-- NOTE(review): this presumably requires @k - I@ to be invertible.
compensated ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1) =>
   Matrix sh1 (sh0::+sh1) -> Vector (sh0::+sh1) -> Vector sh0
compensated expenses balances = direct |-| p #*| shortfall
   where
      (p, k) = normalizeSplit expenses
      direct = Vector.takeLeft balances
      -- solve (k - I) z = y for the right part y
      shortfall = (k #-# Square.identityFrom k) #\| Vector.takeRight balances


{- |
Print the result of 'iterated' and of 'compensated' for the example data
'expenses0' and 'balances0', each formatted with @%10.2f@.

The property below checks that the iterated vector approximately equals the
compensated vector padded with zeros for the right part.

prop> let result = iterated expenses0 balances0 in approxVector result $ compensated expenses0 balances0 +++ Vector.zero (Array.shape $ Vector.takeRight result)
-}
main :: IO ()
main = do
   let numberFormat = "%10.2f"
   iterated expenses0 balances0 ## numberFormat
   compensated expenses0 balances0 ## numberFormat