{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Permutation.Private where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Shape as ExtShape
import qualified Numeric.LAPACK.Output as Output
import Numeric.LAPACK.Output (Output, formatAligned)
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed,Transposed),
          Inversion(NonInverted,Inverted))
import Numeric.LAPACK.Matrix.Private (Full, Square, shapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (copyBlock, copyToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), (!))

import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, sizeOf, alignment, poke, peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.ST (ST, runST)
import Control.Monad (when, forM_)
import Control.Applicative (liftA2, (<$>))

import qualified Data.Tuple.HT as Tuple
import Data.Function.HT (powerAssociative)
import Data.Monoid (Monoid, mempty, mappend)
import Data.Semigroup (Semigroup, (<>))

import Prelude hiding (odd)


{- |
A permutation, stored as a vector of 1-based 'Element' indices
over the wrapped shape 'Shape'.
As 'format' and 'toMatrix' show, the entry at position @i@
names the column that holds the 1 in row @i@
of the corresponding permutation matrix.
-}
newtype Permutation sh = Permutation (Vector (Shape sh) (Element sh))
   deriving (Show)

{- |
Render the permutation as a square grid of single-character cells.
Every element of the permutation yields one row
that contains a @1@ in the column named by the element
and a @.@ in every other column.
-}
format :: (Shape.C sh, Output out) => Permutation sh -> out
format (Permutation perm) =
   formatAligned $
   map (cells . rowChars . fromIntegral . deconsElement) $
   Array.toList perm
   where
      -- number of columns of the grid
      n = Shape.size $ Array.shape perm
      -- characters of the row that has its 1 in (1-based) column k
      rowChars k = replicate (k-1) '.' ++ "1" ++ replicate (n-k) '.'
      -- every character becomes a cell consisting of one text fragment
      cells str = map (\c -> [Output.text [c]]) str


{- |
The shape that the permutation acts on,
i.e. the shape of the underlying vector with the 'Shape' wrapper removed.
-}
size :: Permutation sh -> sh
size (Permutation perm) = deconsShape $ Array.shape perm

{- |
The identity permutation for the given shape:
every position stores its own index.
-}
identity :: (Shape.C sh) => sh -> Permutation sh
identity shape =
   let keepIndex ix = ix
   in Permutation (CheckedArray.sample (Shape shape) keepIndex)

{- |
Build a permutation from a pivot vector
that has the same shape as the resulting permutation.
This is 'fromPivotsGen' with the permutation shape
taken from the pivot vector itself.
-}
fromPivots ::
   (Shape.C sh) =>
   Inversion -> Vector (Shape sh) (Element sh) -> Permutation sh
fromPivots inverted ipiv = fromPivotsGen inverted permShape ipiv
   where permShape = Array.shape ipiv

{-
We could use laswp if it were available for CInt elements.
-}
{- |
Build a permutation from a pivot vector whose shape is an 'ExtShape.Min'.
The second component of that shape ('ExtShape.minShape1')
becomes the shape of the permutation;
the pivots themselves are processed by 'fromPivotsGen'.
Presumably the pivot vector may be shorter than the permutation
(to be confirmed against 'ExtShape.Min').
-}
fromTruncatedPivots ::
   (Shape.C sh, Shape.C sh1) =>
   Inversion ->
   Vector (ExtShape.Min sh1 (Shape sh)) (Element sh) -> Permutation sh
fromTruncatedPivots inverted ipiv =
   let permShape = ExtShape.minShape1 (Array.shape ipiv)
   in fromPivotsGen inverted permShape ipiv
{- |
Common worker for 'fromPivots' and 'fromTruncatedPivots'.

Starts with the identity permutation of shape @sh@
and then, for every index @i@ of the pivot vector
in the order given by 'indices',
exchanges the entries at positions @i@ and @ipiv!i@.
-}
fromPivotsGen ::
   (Shape.C sh, Shape.Indexed small, Shape.Index small ~ Element sh) =>
   Inversion -> Shape sh -> Vector small (Element sh) -> Permutation sh
fromPivotsGen inverted sh ipiv =
   Permutation $
   runST (do
      -- identity: every slot holds its own index
      perm <- initMutable sh (\arr ix -> MutArray.write arr ix ix)
      -- apply the row interchanges one after another
      mapM_ (\ix -> swap perm ix (ipiv!ix)) $
         indices inverted (Array.shape ipiv)
      MutArray.unsafeFreeze perm)

{- |
Exchange the elements stored at two indices of a mutable array.
Both elements are read before either of them is written.
-}
swap ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   MutArray.Array (ST s) sh a -> ix -> ix -> ST s ()
swap arr i j = do
   atI <- MutArray.read arr i
   atJ <- MutArray.read arr j
   MutArray.write arr i atJ
   MutArray.write arr j atI

{- |
The 1-based indices of a pivot vector in processing order:
ascending from 1 for 'NonInverted',
descending from the vector size down to 1 for 'Inverted'.
-}
indices ::
   (Shape.C sh, Shape.Indexed small, Shape.Index small ~ Element sh) =>
   Inversion -> small -> [Element sh]
indices inverted sh =
   map Element $ take count $ iterate step start
   where
      count = Shape.size sh
      -- direction and starting point of the enumeration
      (step, start) =
         case inverted of
            NonInverted -> ((1+), 1)
            Inverted -> (subtract 1, fromIntegral count)


{- |
Convert a permutation into a vector of pivot indices.

Two mutable arrays are maintained:
a copy of the permutation and its inverse (see 'transposeToMutable').
For 'Inverted' their roles are exchanged.
All indices are visited in ascending order;
in every step the element @j@ of the first array
and the element @k@ of the second array at the current index
are written to position @k@ of the first
and position @j@ of the second array, respectively.
The first array is returned as the result.
This appears to be the counterpart of 'fromPivots'.
-}
toPivots ::
   (Shape.C sh) => Inversion -> Permutation sh -> Vector sh (Element sh)
toPivots inverted (Permutation a) =
   Array.reshape (deconsShape sh) $
   runST (do
      (inv,perm) <-
         fmap orient $
         liftA2 (,) (MutArray.thaw a) (transposeToMutable a)
      forM_ (Shape.indices sh) $ \pos -> do
         j <- MutArray.read inv pos
         k <- MutArray.read perm pos
         MutArray.write perm j k
         MutArray.write inv k j
      MutArray.unsafeFreeze inv)
   where
      sh = Array.shape a
      -- decide which of the two arrays plays which role
      orient =
         case inverted of
            NonInverted -> id
            Inverted -> Tuple.swap


{- |
Sign of a permutation.
The constructor order matters:
'fromEnum' maps 'Positive' to 0 and 'Negative' to 1.
-}
data Sign = Positive | Negative
   deriving (Eq, Show, Enum, Bounded)

{- |
Multiplication of signs:
equal signs give 'Positive', different signs give 'Negative'.
-}
instance Semigroup Sign where
   Positive <> s = s
   Negative <> Positive = Negative
   Negative <> Negative = Positive

-- | 'Positive' is the neutral element of sign multiplication.
instance Monoid Sign where
   mempty = Positive
   mappend = (<>)

{-
We could also count the cycles of even length. This might be a little faster.
-}
{- |
Sign of the permutation.
The permutation is converted to pivots ('toPivots' with 'NonInverted')
and the result is 'Negative' if and only if 'odd' holds for those pivots.
-}
determinant :: (Shape.C sh) => Permutation sh -> Sign
determinant perm =
   if odd (map deconsElement (Array.toList (toPivots NonInverted perm)))
      then Negative
      else Positive

{- |
Turn a sign into the number 1 or -1.

> numberFromSign s == (-1)^fromEnum s
-}
numberFromSign :: (Class.Floating a) => Sign -> a
numberFromSign Positive = 1
numberFromSign Negative = -1


{- |
Select 'negate' if the pivot list is 'odd', and 'id' otherwise.
-}
condNegate :: (Class.Floating a) => [CInt] -> a -> a
condNegate ipiv
   | odd ipiv = negate
   | otherwise = id

{- |
Check whether a pivot list contains an odd number of proper interchanges,
i.e. an odd number of entries that differ from their own 1-based position.
-}
odd :: [CInt] -> Bool
odd ipiv =
   not $ null $ dropEven
      [() | (position, pivot) <- zip [1..] ipiv, position /= pivot]

{- |
Remove elements from the front of a list in pairs
until fewer than two elements remain.
The result is empty for lists of even length
and holds the last element for lists of odd length.
-}
dropEven :: [a] -> [a]
dropEven xs =
   case xs of
      _:_:rest -> dropEven rest
      _ -> xs


{- |
Transpose the permutation.
The result stores @i@ at position @perm!i@ (see 'transposeToMutable'),
which makes it the inverse of the input permutation.
-}
transpose :: (Shape.C sh) => Permutation sh -> Permutation sh
transpose (Permutation perm) =
   Permutation (runST (transposeToMutable perm >>= MutArray.unsafeFreeze))

{- |
Create a mutable array of the same shape as the input
where, for every index @i@, the value @i@ is stored at position @perm!i@.
The array is completely initialized only if @perm@ hits every index,
which is the case for a permutation.
-}
transposeToMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable ix) =>
   Array sh ix -> ST s (MutArray.Array (ST s) sh ix)
transposeToMutable perm =
   initMutable (Array.shape perm) writeInverse
   where
      writeInverse inv i = MutArray.write inv (perm!i) i

{- |
Translate a transposition flag into the corresponding inversion flag:
'NonTransposed' becomes 'NonInverted' and 'Transposed' becomes 'Inverted'.
-}
inversionFromTransposition :: Transposition -> Inversion
inversionFromTransposition NonTransposed = NonInverted
inversionFromTransposition Transposed = Inverted


{- |
Product of two permutations of equal size.
Calls 'error' if the sizes differ,
otherwise the work is done by 'multiplyUnchecked'.
-}
multiply :: (Shape.C sh, Eq sh) =>
   Permutation sh -> Permutation sh -> Permutation sh
multiply a b
   | size a /= size b = error "Permutation.multiply: sizes mismatch"
   | otherwise = multiplyUnchecked a b

{- |
Product of a permutation with itself.
No size check is needed since both factors are the same permutation.
-}
square :: (Shape.C sh) => Permutation sh -> Permutation sh
square perm = perm `multiplyUnchecked` perm

{- |
@power n p@ is the @n@-fold product of @p@ with itself,
starting from the identity permutation of the same size
and computed by 'powerAssociative'.

A negative exponent yields the corresponding power of the inverse
permutation, which is obtained by 'transpose'.
'powerAssociative' only knows the multiplication
and thus cannot handle negative exponents on its own
(as far as we can tell it would not terminate for them).
-}
power :: (Shape.C sh) => Integer -> Permutation sh -> Permutation sh
power n p =
   if n < 0
      -- p^n == (p^-1)^(-n), and the transpose is the inverse
      then power (negate n) (transpose p)
      else powerAssociative multiplyUnchecked (identity $ size p) p n

{- |
Product of two permutations without a check for equal sizes.
The result has the shape of the first operand
and stores @b!(a!i)@ at position @i@.
-}
multiplyUnchecked :: (Shape.C sh) =>
   Permutation sh -> Permutation sh -> Permutation sh
multiplyUnchecked (Permutation a) (Permutation b) =
   Permutation (CheckedArray.sample (Array.shape a) ((b!) . (a!)))


{- |
Diagonal of the permutation matrix:
1 at every position that the permutation maps to itself, 0 elsewhere.
-}
takeDiagonal ::
   (Shape.C sh, Class.Floating a) => Permutation sh -> Vector sh a
takeDiagonal (Permutation a) =
   Array.mapShape deconsShape
      (CheckedArray.sample (Array.shape a) diagonalEntry)
   where
      diagonalEntry i
         | a!i == i = 1
         | otherwise = 0


{- |
Expand the permutation to a square matrix in row-major order.
The matrix is filled with zeros
and then gets a one at position @(k, perm!k)@ for every index @k@.
-}
toMatrix :: (Shape.C sh, Class.Floating a) => Permutation sh -> Square sh a
toMatrix (Permutation perm) =
   Array.reshape (MatrixShape.square RowMajor $ deconsShape permShape) $
   runST (do
      mat <- MutArray.new (permShape,permShape) zero
      forM_ (Shape.indices permShape) $ \row ->
         MutArray.write mat (row, perm!row) one
      MutArray.unsafeFreeze mat)
   where
      permShape = Array.shape perm


{- |
Apply a permutation along the height dimension of a full matrix.

The input buffer is copied into a newly created array of the same shape
and that copy is then rearranged by LAPACK's @lapmt@ or @lapmr@,
depending on the storage order of the matrix.
The flag passed as first argument to these routines
is true for 'NonInverted' and false for 'Inverted'.
The height of the matrix must equal the size of the permutation,
which is checked by 'Call.assert'.
-}
apply ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Inversion -> Permutation height ->
   Full vert horiz height width a ->
   Full vert horiz height width a
apply inverted
      (Permutation (Array (Shape shapeP) perm))
      (Array shape@(MatrixShape.Full order extent) a) =

   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> do

   let (height,width) = Extent.dimensions extent
   Call.assert "Permutation.apply: heights mismatch" (height == shapeP)
   let m = Shape.size height
   let n = Shape.size width
   evalContT $ do
      fwdPtr <- Call.bool $ inverted==NonInverted
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      -- The LAPACK routines do not get the permutation buffer itself
      -- but what 'copyToTemp' returns for its m entries
      -- (presumably a temporary copy - confirm against 'copyToTemp').
      kPtr <- deconsElementPtr <$> copyToTemp m perm
      aPtr <- ContT $ withForeignPtr a
      liftIO $ do
         -- The result starts as a copy of the input
         -- and is then modified through bPtr only.
         copyBlock blockSize aPtr bPtr
         -- no LAPACK call for matrices without elements
         when (m>0 && n>0) $
            case order of
               -- Row-major: dimensions are passed in swapped order (n,m)
               -- with leading dimension n to lapmt.
               RowMajor -> LapackGen.lapmt fwdPtr nPtr mPtr bPtr nPtr kPtr
               -- Column-major: dimensions (m,n) with leading dimension m
               -- are passed to lapmr.
               ColumnMajor -> LapackGen.lapmr fwdPtr mPtr nPtr bPtr mPtr kPtr


{- |
Allocate a mutable array of the given shape without initializing its memory
and then run the given action for every index of the shape.
It is the job of the action to fill the array.
-}
initMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   sh -> (MutArray.Array (ST s) sh a -> ix -> ST s ()) ->
   ST s (MutArray.Array (ST s) sh a)
initMutable sh f = do
   arr <- MutArray.unsafeCreate sh (const (return ()))
   forM_ (Shape.indices sh) (f arr)
   return arr



-- cf. Shape.Deferred
{- |
Wrapper around an arbitrary shape.
It has the same size as the wrapped shape
but, according to its 'Shape.Indexed' instance,
it is indexed by 1-based 'Element' values
instead of the indices of the wrapped shape.
-}
newtype Shape sh = Shape {deconsShape :: sh}
   deriving (Eq, Show)

{- |
Index into a 'Shape', represented by a 1-based 'CInt'.
The type parameter does not occur in the representation;
it only ties the index to the shape it belongs to.
-}
newtype Element sh = Element {deconsElement :: CInt}
   deriving (Eq, Show)

{- |
Reinterpret a pointer to 'Element's as a pointer to the wrapped 'CInt's.
This is the pointer counterpart of 'deconsElement'.
-}
deconsElementPtr :: Ptr (Element sh) -> Ptr CInt
deconsElementPtr ptr = castPtr ptr

-- | The wrapper has the same size as the wrapped shape.
instance (Shape.C sh) => Shape.C (Shape sh) where
   size = Shape.size . deconsShape
   uncheckedSize = Shape.uncheckedSize . deconsShape

{- |
Indices are 1-based 'Element's.
Offsets and bounds checks are delegated to the zero-based integer shape
of the same size after subtracting one from the index.
-}
instance (Shape.C sh) => Shape.Indexed (Shape sh) where
   type Index (Shape sh) = Element sh
   indices (Shape sh) = take (Shape.size sh) $ map Element [1 ..]
   offset (Shape sh) (Element k) =
      let zeroBased = fromIntegral k - 1
      in Shape.offset (shapeInt (Shape.size sh)) zeroBased
   uncheckedOffset _ ix = fromIntegral (deconsElement ix) - 1
   inBounds (Shape sh) (Element k) =
      let zeroBased = fromIntegral k - 1
      in Shape.inBounds (shapeInt (Shape.size sh)) zeroBased

{- |
Conversion from a zero-based offset to a 1-based 'Element'.
The checked variant delegates to the zero-based integer shape
of the same size.
-}
instance (Shape.C sh) => Shape.InvIndexed (Shape sh) where
   indexFromOffset (Shape sh) k =
      let zeroBased = Shape.indexFromOffset (shapeInt $ Shape.size sh) k
      in Element (fromIntegral zeroBased + 1)
   uncheckedIndexFromOffset _sh offs = Element (fromIntegral offs + 1)

{- |
An 'Element' is stored exactly like the 'CInt' it wraps:
size, alignment, reading and writing are those of 'CInt'.
-}
instance Storable (Element sh) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   sizeOf = sizeOf . deconsElement
   alignment = alignment . deconsElement
   poke p = poke (deconsElementPtr p) . deconsElement
   peek p = Element <$> peek (deconsElementPtr p)