module Numeric.LAPACK.Permutation.Private where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Split as Split
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Private
         (Full, Square, ZeroInt, Inversion(NonInverted, Inverted))
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Format (Format(format))
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (fill, pointerSeq, copyBlock, copyToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import qualified Text.PrettyPrint.Boxes as TextBox

import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import Foreign.Marshal.Array (advancePtr, copyArray)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek, pokeElemOff, peekElemOff)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when, forM_)
import Control.Applicative ((<$>))

import Data.Bool.HT (if')


{- |
A permutation stored as a vector of 'CInt' indices over the shape @sh@.

The functions of this module read and write the entries as one-based indices
(see 'peek1' and 'poke1').
The entry at position @k@ is the column that holds the unit element
in row @k@ of the matrix produced by 'toMatrix'.
-}
newtype Permutation sh = Permutation (Vector sh CInt)
   deriving (Show)

{- |
Render the permutation as a matrix picture:
one text row per entry of the permutation vector,
with a @1@ at the one-based column given by that entry
and dots in all remaining columns.
The format argument is ignored.
-}
instance (Shape.C sh) => Format (Permutation sh) where
   format _fmt (Permutation perm) =
      TextBox.vcat TextBox.top
         [ renderRow (fromIntegral col) | col <- Array.toList perm ]
      where
         size = Shape.size (Array.shape perm)
         -- a row of 'size' characters with the '1' at one-based position k
         renderRow k =
            TextBox.hsep 1 TextBox.right $ map TextBox.char $
            replicate (k-1) '.' ++ "1" ++ replicate (size-k) '.'


{-
We could use laswp if it were available for CInt elements.
-}
{- |
Build a permutation from a pivot vector with one-based entries.

The permutation starts as the identity @1..n@.
Then, for every pivot position @i@, the entries at position @i@
and at the position named by pivot @i@ are exchanged:
in ascending order of @i@ for 'NonInverted',
in descending order for 'Inverted'.

The pivot array must be at most as long as @Shape.size sh@,
otherwise an error is raised.
The pivot values themselves are not range-checked.
-}
fromPivots :: (Shape.C sh) =>
   Inversion -> sh -> Vector ZeroInt CInt -> Permutation sh
fromPivots inverted sh (Array (Shape.ZeroBased numIPiv) ipiv) =
   Permutation $
   if' (numIPiv > Shape.size sh)
      (error "Permutation.fromPivots: too many pivots")
      (Array.unsafeCreateWithSize sh build)
   where
      ascending = [0 .. numIPiv-1]
      -- zero-based pivot positions in the order they are applied
      pivotOrder =
         case inverted of
            NonInverted -> ascending
            Inverted -> reverse ascending
      build n permPtr =
         withForeignPtr ipiv $ \ipivPtr -> do
            -- initialise with the identity permutation 1..n
            forM_ (take n $ zip (pointerSeq 1 permPtr) (iterate (1+) 1)) $
               uncurry poke
            -- apply the exchanges described by the pivots
            forM_ pivotOrder $ \i -> do
               j <- peek1 ipivPtr i
               swapElem permPtr i j

{- |
Exchange the array elements at the zero-based offsets @i@ and @j@
relative to @ptr@.
-}
swapElem :: (Storable a) => Ptr a -> Int -> Int -> IO ()
swapElem ptr i j = swap (elemAt i) (elemAt j)
   where
      elemAt k = advancePtr ptr k

{- |
Exchange the values stored at two addresses.
Both values are read before either address is written,
so passing the same pointer twice leaves the value unchanged.
-}
swap :: (Storable a) => Ptr a -> Ptr a -> IO ()
swap ptr0 ptr1 = do
   x0 <- peek ptr0
   x1 <- peek ptr1
   poke ptr0 x1
   poke ptr1 x0


{- |
Convert a permutation into a vector of one-based pivot indices.

The returned array (@invPtr@ below) and a temporary buffer (@permPtr@)
are initialised with the permutation and its inverse
(computed by 'transposeIO');
which of the two buffers receives which depends on the 'Inversion' argument.
A single ascending pass then updates both buffers in place;
the content left in the returned array is the pivot vector.

NOTE(review): this looks like the counterpart of 'fromPivots' for the same
'Inversion' argument; that was checked by hand for a three-element example
only - confirm with the test suite.
-}
toPivots :: (Shape.C sh) => Inversion -> Permutation sh -> Vector sh CInt
toPivots inverted (Permutation (Array sh perm)) =
   Array.unsafeCreateWithSize sh $ \n invPtr ->
   withForeignPtr perm $ \perm0Ptr ->
   -- temporary buffer of n elements; the input array itself is only read
   ForeignArray.alloca n $ \permPtr -> do
      case inverted of
         Inverted -> do
            -- permPtr := input permutation, invPtr := its inverse
            copyArray permPtr perm0Ptr n
            transposeIO n permPtr invPtr
         NonInverted -> do
            -- invPtr := input permutation, permPtr := its inverse
            copyArray invPtr perm0Ptr n
            transposeIO n perm0Ptr permPtr
      forM_ (take n $ iterate (1+) 0) $ \i -> do
         -- j and k are zero-based; peek1/poke1 convert from/to one-based
         j <- peek1 invPtr i
         k <- peek1 permPtr i
         -- redirect the two entries that referred to position i;
         -- the order of these writes matters when j or k equals i
         poke1 permPtr j k
         poke1 invPtr k j


{- |
Sign of a permutation as returned by 'determinant':
'Negative' for an odd permutation, 'Positive' otherwise.
'numberFromSign' maps the two values to @-1@ and @1@, respectively.
-}
data Sign = Negative | Positive
   deriving (Eq, Show)

{-
We could also count the cycles of even length. This might be a little faster.
-}
{- |
Sign of the permutation.
The permutation is converted to pivots with @'toPivots' 'NonInverted'@
and the result is 'Negative' if 'Split.oddPermutation' reports
that pivot list as odd, 'Positive' otherwise.
-}
determinant :: (Shape.C sh) => Permutation sh -> Sign
determinant perm =
   if Split.oddPermutation $ Array.toList $ toPivots NonInverted perm
      then Negative
      else Positive

{- |
Numeric value of a 'Sign': @-1@ for 'Negative' and @1@ for 'Positive'.
-}
numberFromSign :: (Class.Floating a) => Sign -> a
numberFromSign Negative = -1
numberFromSign Positive = 1


{- |
Inverse permutation, computed by 'transposeIO' into a fresh array
of the same shape.
-}
transpose :: (Shape.C sh) => Permutation sh -> Permutation sh
transpose (Permutation (Array sh source)) =
   Permutation (Array.unsafeCreateWithSize sh invert)
   where
      invert size dstPtr =
         withForeignPtr source $ \srcPtr -> transposeIO size srcPtr dstPtr

{- |
Write the inverse of the permutation of length @n@ stored at @srcPtr@
to @dstPtr@: if the source maps position @i@ to @j@,
then the destination maps @j@ to @i@.
Both buffers hold one-based indices; the two buffers must not overlap.
-}
transposeIO :: Int -> Ptr CInt -> Ptr CInt -> IO ()
transposeIO n srcPtr dstPtr =
   mapM_ invertAt [0 .. n-1]
   where
      invertAt i = peek1 srcPtr i >>= \j -> poke1 dstPtr j i


{- |
Compose two permutations of equal shape.
Entry @i@ of the result is obtained by looking up entry @i@
of the first permutation and using it as position in the second one.
With the matrix convention of 'toMatrix' this corresponds
to the matrix product of the first with the second argument.

Raises an error if the shapes differ.
-}
multiply :: (Shape.C sh, Eq sh) =>
   Permutation sh -> Permutation sh -> Permutation sh
multiply (Permutation (Array shapeA permA)) (Permutation (Array shapeB permB))
   | shapeA /= shapeB = error "Permutation.multiply: sizes mismatch"
   | otherwise =
      Permutation $
      Array.unsafeCreateWithSize shapeA $ \n cPtr ->
      withForeignPtr permA $ \aPtr ->
      withForeignPtr permB $ \bPtr ->
      forM_ [0 .. n-1] $ \i -> do
         viaA <- peek1 aPtr i
         viaB <- peek1 bPtr viaA
         poke1 cPtr i viaB


{- |
Expand the permutation to a square matrix in row-major order.
The matrix is filled with zeros and then row @k@ receives a one
in the column given by entry @k@ of the permutation.
-}
toMatrix :: (Shape.C sh, Class.Floating a) => Permutation sh -> Square sh a
toMatrix (Permutation (Array shape perm)) =
   Array.unsafeCreate (MatrixShape.square RowMajor shape) $ \matPtr ->
   withForeignPtr perm $ \permPtr -> do
      fill zero (size*size) matPtr
      -- pointerSeq advances by one row (size elements) per step
      forM_ (zip [0 .. size-1] (pointerSeq size matPtr)) $ \(row,rowPtr) ->
         peek1 permPtr row >>= \col -> pokeElemOff rowPtr col one
   where
      size = Shape.size shape


{- |
Read the one-based index stored at zero-based offset @i@
and return it converted to a zero-based 'Int'.
-}
peek1 :: Ptr CInt -> Int -> IO Int
peek1 ptr i = toZeroBased <$> peekElemOff ptr i
   where
      toZeroBased oneBased = fromIntegral oneBased - 1

{- |
Store the zero-based index @j@ as one-based 'CInt'
at zero-based offset @i@.
-}
poke1 :: Ptr CInt -> Int -> Int -> IO ()
poke1 ptr i j = pokeElemOff ptr i oneBased
   where
      oneBased = fromIntegral (j+1)


{- |
Permute the rows of a full matrix according to a permutation.

The input matrix is copied to a new array of the same shape
and the copy is permuted in place by LAPACK's @lapmt@ or @lapmr@,
depending on the storage order.
The height of the matrix must equal the shape of the permutation;
this is checked with 'Call.assert'.

The 'Bool' argument is negated and passed as first argument
to the LAPACK routine.
NOTE(review): that argument is LAPACK's FORWRD flag, so @False@ here should
select the forward permutation - confirm against the LAPACK documentation.
-}
apply ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Bool -> Permutation height ->
   Full vert horiz height width a ->
   Full vert horiz height width a
apply inverted
      (Permutation (Array shapeP perm))
      (Array shape@(MatrixShape.Full order extent) a) =

   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> do

   let (height,width) = Extent.dimensions extent
   Call.assert "Permutation.apply: heights mismatch" (height == shapeP)
   let m = Shape.size height
   let n = Shape.size width
   evalContT $ do
      fwdPtr <- Call.bool $ not inverted
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      -- presumably a temporary copy of the m permutation entries, so that
      -- the LAPACK routine never touches the permutation's own storage
      kPtr <- copyToTemp m perm
      aPtr <- ContT $ withForeignPtr a
      liftIO $ do
         -- work on a copy: the result buffer is permuted in place
         copyBlock blockSize aPtr bPtr
         -- the LAPACK call is skipped for empty matrices
         when (m>0 && n>0) $
            case order of
               -- a row-major m-by-n matrix is handed over with swapped
               -- dimensions (n, m, leading dimension n), so its rows
               -- are the columns that lapmt permutes
               RowMajor -> LapackGen.lapmt fwdPtr nPtr mPtr bPtr nPtr kPtr
               -- column-major: dimensions m, n and leading dimension m
               ColumnMajor -> LapackGen.lapmr fwdPtr mPtr nPtr bPtr mPtr kPtr