{-# LANGUAGE TypeOperators #-}
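{- |
Do not import this module. It is only for demonstration purposes.

The example redistributes the balances held in the right part of a balance
vector onto its left part, in proportion to a row-normalized expense matrix.
The result is computed twice: by fixed-point iteration ('iterated') and in
closed form by solving a linear system ('compensated').
-}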
module Numeric.LAPACK.Example.EconomicAllocation where

import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Square ((|=|))
import Numeric.LAPACK.Matrix (ShapeInt, shapeInt, (#-#), (#*|), (#\|), (\\#))
import Numeric.LAPACK.Vector ((|-|))
import Numeric.LAPACK.Format ((##))

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((:+:)((:+:)))


-- | Shape consisting of two concatenated 'ShapeInt' parts.
type ZeroInt2 = ShapeInt :+: ShapeInt

-- | Vector of 'Double's indexed by shape @sh@.
type Vector sh = Vector.Vector sh Double

-- | General matrix of 'Double's with height shape @h@ and width shape @w@.
type Matrix h w = Matrix.General h w Double

-- | Square matrix of 'Double's whose both dimensions have shape @n@.
type SquareMatrix n = Square.Square n Double


-- | Example balance vector: two positive entries in the left part and two
-- negative entries in the right part.
balances0 :: Vector ZeroInt2
balances0 = Vector.fromList shape (leftPart ++ rightPart)
   where
      shape = shapeInt 2 :+: shapeInt 2
      leftPart = [100000, 90000]
      rightPart = [-50000, -120000]

-- | Example expense matrix.
--
-- It has two rows, matching the size of the right part of 'balances0'.
-- It has 2+2 columns, one per entry of the whole balance vector.
-- Each inner list below is meant as one row; 'Matrix.fromList' is assumed
-- to fill row by row, as the 4-wide layout suggests.
expenses0 :: Matrix ShapeInt ZeroInt2
expenses0 = Matrix.fromList rows columns (rowA ++ rowB)
   where
      rows = shapeInt 2
      columns = shapeInt 2 :+: shapeInt 2
      rowA = [16000,  4000,  8000, 12000]
      rowB = [10000, 30000, 40000, 20000]

-- | Scale every row of the matrix by the reciprocal of its row sum, using
-- 'Matrix.rowSums' and '(\\#)'. Each resulting row should therefore sum to one.
--
-- NOTE(review): a row whose sum is zero is not special-cased here;
-- presumably it yields non-finite entries.
normalize ::
   (Eq height, Shape.C height, Shape.C width) =>
   Matrix height width -> Matrix height width
normalize m = rowTotals \\# m
   where
      rowTotals = Matrix.rowSums m

-- | Normalize the expense rows, transpose, and cut the result horizontally.
--
-- The first component is the top part, with one row per @sh0@ index.
-- The second component is the bottom part, with one row per @sh1@ index,
-- viewed as a square matrix.
normalizeSplit ::
   (Shape.C sh0, Shape.C sh1, Eq sh1) =>
   Matrix sh1 (sh0:+:sh1) -> (Matrix sh0 sh1, SquareMatrix sh1)
normalizeSplit expenses =
   (Matrix.takeTop shares, Square.fromGeneral (Matrix.takeBottom shares))
   where
      shares = Matrix.transpose (normalize expenses)


-- | Assemble the square block matrix
--
-- > [ I  P ]
-- > [ 0  K ]
--
-- Here @(P, K)@ come from 'normalizeSplit'. @I@ is the identity matching
-- the height of @P@, and @0@ is a zero block shaped like the transpose of @P@.
-- Multiplying a balance vector by this matrix keeps the left part and adds
-- @P@ times the right part to it. It also replaces the right part by @K@
-- times itself.
completeIdSquare ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1) =>
   Matrix sh1 (sh0:+:sh1) -> SquareMatrix (sh0:+:sh1)
completeIdSquare expenses =
   (Square.identityFromHeight shares, shares)
   |=|
   (Matrix.zero (ArrMatrix.shape (Matrix.transpose shares)), selfShares)
   where
      (shares, selfShares) = normalizeSplit expenses

-- | Redistribute the right part of the balance vector step by step.
--
-- The vector is multiplied by 'completeIdSquare' repeatedly. The first vector
-- (starting with the input itself) whose right part has a maximum norm below
-- @1e-5@ is returned.
--
-- See 'iteratedWithin' for a custom tolerance.
iterated ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1) =>
   Matrix sh1 (sh0:+:sh1) -> Vector (sh0:+:sh1) -> Vector (sh0:+:sh1)
iterated = iteratedWithin 1e-5

-- | Like 'iterated', but with the stopping tolerance as the first argument.
--
-- 'until' is total, unlike @head . dropWhile ... . iterate@.
-- The stop test is written as @not (... >= tol)@ instead of @(... < tol)@,
-- so a NaN norm stops the loop instead of running forever.
--
-- NOTE(review): this does not terminate if the right part never becomes
-- small, i.e. if repeated multiplication by the bottom-right block does not
-- drive it towards zero. Consider an iteration cap if inputs are untrusted.
iteratedWithin ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1) =>
   Double ->
   Matrix sh1 (sh0:+:sh1) -> Vector (sh0:+:sh1) -> Vector (sh0:+:sh1)
iteratedWithin tol expenses =
   until
      (not . (>= tol) . Vector.normInf . Vector.takeRight)
      (completeIdSquare expenses #*|)


-- | Closed-form counterpart of 'iterated', restricted to the left part.
--
-- With @(P, K)@ from 'normalizeSplit', the left balances @x@ and the right
-- balances @y@, this computes @x - P (K - I)^-1 y@, which equals
-- @x + P (I - K)^-1 y@. When the iteration in 'iterated' converges, its left
-- part should agree with this result up to the iteration tolerance.
compensated ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1) =>
   Matrix sh1 (sh0:+:sh1) -> Vector (sh0:+:sh1) -> Vector sh0
compensated expenses balances =
   left |-| (shares #*| redistributed)
   where
      (shares, selfShares) = normalizeSplit expenses
      left = Vector.takeLeft balances
      right = Vector.takeRight balances
      -- solve (K - I) z = y for z
      redistributed =
         (selfShares #-# Square.identityFrom selfShares) #\| right


-- | Print the iterated result and the closed-form result for the example
-- data, both with the same printf-style format.
main :: IO ()
main =
   sequence_
      [ iterated expenses0 balances0 ## fmt
      , compensated expenses0 balances0 ## fmt
      ]
   where
      fmt = "%10.2f"