{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Permutation.Private where
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Output as Output
import Numeric.LAPACK.Output (Output, formatAligned)
import Numeric.LAPACK.Matrix.Layout.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Modifier
(Transposition(NonTransposed,Transposed),
Inversion(NonInverted,Inverted))
import Numeric.LAPACK.Matrix.Private (Full, Square)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (copyBlock, copyToTemp)
import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), (!))
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, sizeOf, alignment, poke, peek)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.ST (ST, runST)
import Control.Monad (when, forM_)
import Control.Applicative (liftA2, (<$>))
import qualified Data.Tuple.HT as Tuple
import Data.Functor.Identity (Identity(Identity))
import Data.Function.HT (powerAssociative)
import Data.Monoid (Monoid, mempty, mappend)
import Data.Semigroup (Semigroup, (<>))
import Prelude hiding (odd)
-- | A permutation of the index set @sh@, stored as a vector with one
-- one-based 'CInt' entry per element of @sh@.
--
-- Entry @k@ is the column of the single one in row @k@ of the corresponding
-- permutation matrix (see 'toMatrix').
newtype Permutation sh = Permutation (Vector (Shape sh) (Element sh))
   deriving (Eq, Show)
-- | Render a permutation as an aligned square grid of text cells.
--
-- Row @k@ shows @"1"@ in column @perm!k@ (one-based) and @"."@ in every
-- other column, i.e. the non-zero pattern of the permutation matrix.
format :: (Shape.C sh, Output out) => Permutation sh -> out
format (Permutation perm) =
   formatAligned $ map (map Identity . row) targets
   where
      n = Shape.size $ Array.shape perm
      targets = map (fromIntegral . deconsElement) $ Array.toList perm
      dotCell = Output.text "."
      markCell = Output.text "1"
      row k = replicate (k-1) dotCell ++ markCell : replicate (n-k) dotCell
-- | The permutation matrix as dense rows.
--
-- 'Nothing' marks a structural zero. @'Just' one@ marks the single non-zero
-- entry of each row, which for row @k@ sits in column @perm!k@ (one-based).
layout :: (Shape.C sh, Class.Floating a) => Permutation sh -> [[Maybe a]]
layout (Permutation perm) =
   [ replicate (k-1) Nothing ++ Just one : replicate (n-k) Nothing
   | k <- map (fromIntegral . deconsElement) (Array.toList perm) ]
   where
      n = Shape.size $ Array.shape perm
-- | The shape (index set) on which the permutation acts.
size :: Permutation sh -> sh
size (Permutation perm) = deconsShape (Array.shape perm)
-- | Change the shape of a permutation without any consistency check.
--
-- The element buffer is reused unchanged; only its phantom shape type is
-- cast.
mapSizeUnchecked :: (shA -> shB) -> Permutation shA -> Permutation shB
mapSizeUnchecked f (Permutation (Array sh buffer)) =
   Permutation (Array (fmap f sh) (castForeignPtr buffer))
-- | Like 'mapSizeUnchecked', but the shape function is wrapped in
-- 'Layout.mapChecked'.
--
-- NOTE(review): presumably this checks that the size is preserved; confirm
-- against 'Layout.mapChecked'.
mapSize ::
   (Shape.C shA, Shape.C shB) =>
   (shA -> shB) -> Permutation shA -> Permutation shB
mapSize f p = mapSizeUnchecked checkedF p
   where
      checkedF = Layout.mapChecked "Permutation.mapSize" f
-- | The identity permutation on the given shape: every index maps to itself.
identity :: (Shape.C sh) => sh -> Permutation sh
identity shape = Permutation $ CheckedArray.sample wrapped id
   where
      wrapped = Shape shape
-- | Build a permutation from a pivot vector whose shape equals the
-- permutation's shape.
--
-- See 'fromPivotsGen' for how the pivots are applied.
fromPivots ::
   (Shape.C sh) =>
   Inversion -> Vector (Shape sh) (Element sh) -> Permutation sh
fromPivots inverted ipiv = fromPivotsGen id inverted sh ipiv
   where
      sh = Array.shape ipiv
-- | Build a permutation of a rectangular matrix's height from a pivot vector
-- indexed by that matrix's diagonal shape.
--
-- The permutation shape is the 'Layout.bandedHeight' of the pivot vector's
-- shape.
--
-- NOTE(review): presumably the pivot vector is shorter than the height
-- when the matrix is wide (hence "truncated"); confirm.
fromTruncatedPivots ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width) =>
   (diagShape ~ Layout.RectangularDiagonal meas vert horiz height width) =>
   Inversion -> Vector (Shape diagShape) (Element height) -> Permutation height
fromTruncatedPivots inverted ipiv =
   fromPivotsGen retag inverted fullShape ipiv
   where
      -- re-tag a diagonal position as a position of the height shape
      retag = Element . deconsElement
      fullShape = fmap Layout.bandedHeight (Array.shape ipiv)
-- | Build a permutation by applying a sequence of pivot swaps to the identity.
--
-- The process starts from the identity permutation on @sh@. It then visits
-- each position @i@ of the pivot vector, ascending for 'NonInverted' and
-- descending for 'Inverted' (see 'indices'). At each position it swaps the
-- entries at @mapIx i@ and @ipiv!i@.
--
-- @mapIx@ converts a position of the pivot vector's shape into a position of
-- the permutation's shape.
--
-- NOTE(review): @ipiv!i@ uses unchecked indexing; out-of-range pivot values
-- are not detected here.
fromPivotsGen ::
   (Shape.C sh, Shape.C small) =>
   (Element small -> Element sh) ->
   Inversion -> Shape sh -> Vector (Shape small) (Element sh) -> Permutation sh
fromPivotsGen mapIx inverted sh ipiv =
   Permutation $
   runST (do
      -- start from the identity: every position holds its own index
      perm <- initMutable sh $ \perm i -> MutArray.write perm i i
      -- apply the swaps in the order dictated by the inversion flag
      forM_ (indices inverted $ Array.shape ipiv) $ \i ->
         swap perm (mapIx i) (ipiv!i)
      MutArray.unsafeFreeze perm)
-- | Exchange the entries at indices @i@ and @j@ of a mutable array.
--
-- The result is a no-op when @i == j@.
swap ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   MutArray.Array (ST s) sh a -> ix -> ix -> ST s ()
swap arr i j = do
   atI <- MutArray.read arr i
   atJ <- MutArray.read arr j
   MutArray.write arr i atJ
   MutArray.write arr j atI
-- | The pivot positions of a pivot vector of shape @sh@, in the order in
-- which 'fromPivotsGen' applies them.
--
-- The order is @1, 2, .., n@ for 'NonInverted' and @n, n-1, .., 1@ for
-- 'Inverted', where @n = size sh@.
indices :: (Shape.C sh) => Inversion -> Shape sh -> [Element sh]
indices inverted sh = map Element $ take count $ iterate step start
   where
      count = Shape.size sh
      (start, step) =
         case inverted of
            Inverted -> (fromIntegral count, subtract 1)
            NonInverted -> (1, (1+))
-- | Convert a permutation into a pivot vector.
--
-- NOTE(review): this is presumably the inverse of 'fromPivots' for the same
-- 'Inversion'. A hand trace agrees for 'NonInverted': @[2,3,1]@ yields the
-- pivots @[2,3,3]@, and 'fromPivots' maps those back to @[2,3,1]@. TODO
-- confirm in general.
--
-- Two working arrays are built: a copy of the permutation and its inverse
-- (via 'transposeToMutable'). For 'Inverted' their roles are exchanged. The
-- positions are then visited in ascending order; at position @i@ the values
-- @j@ and @k@ are read from the two arrays and written back crosswise. The
-- first array is returned as the pivot vector.
toPivots ::
   (Shape.C sh) => Inversion -> Permutation sh -> Vector (Shape sh) (Element sh)
toPivots inverted (Permutation a) =
   let sh = Array.shape a
   in Array.reshape sh $
      runST (do
         -- for NonInverted: 'inv' starts as a copy of the permutation and
         -- 'perm' as its inverse; Inverted exchanges the two
         (inv,perm) <-
            (case inverted of Inverted -> Tuple.swap; NonInverted -> id)
            <$>
            liftA2 (,)
               (MutArray.thaw a)
               (transposeToMutable a)
         forM_ (Shape.indices sh) $ \i -> do
            j <- MutArray.read inv i
            k <- MutArray.read perm i
            MutArray.write perm j k
            MutArray.write inv k j
         MutArray.unsafeFreeze inv)
-- | Sign of a permutation (and thus of its matrix's determinant).
data Sign = Positive | Negative
   deriving (Eq, Show, Enum, Bounded)

-- | Signs multiply: equal signs give 'Positive', different signs 'Negative'.
instance Semigroup Sign where
   Positive <> s = s
   Negative <> Positive = Negative
   Negative <> Negative = Positive

-- | 'Positive' is the neutral element of sign multiplication.
instance Monoid Sign where
   mempty = Positive
   mappend = (<>)
-- | Sign (parity) of a permutation.
--
-- The result is 'Negative' when its pivot representation contains an odd
-- number of non-trivial swaps, and 'Positive' otherwise.
determinant :: (Shape.C sh) => Permutation sh -> Sign
determinant p =
   if odd pivots then Negative else Positive
   where
      pivots = map deconsElement $ Array.toList $ toPivots NonInverted p
-- | The scalar value of a 'Sign': @1@ or @-1@.
numberFromSign :: (Class.Floating a) => Sign -> a
numberFromSign Positive = 1
numberFromSign Negative = -1
-- | Negate the value iff the one-based pivot list describes an odd number
-- of swaps (see 'odd'); otherwise return it unchanged.
condNegate :: (Class.Floating a) => [CInt] -> a -> a
condNegate ipiv
   | odd ipiv = negate
   | otherwise = id
-- | Whether a one-based pivot list describes an odd number of swaps.
--
-- Position @k@ (counting from 1) contributes one swap when its pivot
-- differs from @k@. Pivots equal to their own position are no-ops.
odd :: [CInt] -> Bool
odd ipiv = not (null (dropEven swaps))
   where
      swaps = [() | (k, p) <- zip [1 ..] ipiv, k /= p]
-- | Drop elements two at a time.
--
-- Returns @[]@ for even-length input and a singleton holding the last
-- element for odd-length input; used as a cheap parity test.
dropEven :: [a] -> [a]
dropEven xs =
   case xs of
      _ : _ : rest -> dropEven rest
      _ -> xs
-- | The inverse permutation.
--
-- If entry @k@ holds @j@, the result holds @k@ at entry @j@. On permutation
-- matrices this is the matrix transpose.
transpose :: (Shape.C sh) => Permutation sh -> Permutation sh
transpose (Permutation perm) =
   Permutation $
   runST (do
      inverse <- transposeToMutable perm
      MutArray.unsafeFreeze inverse)
-- | Allocate a fresh mutable array with the shape of @perm@ and fill it with
-- the inverse mapping: position @perm!i@ receives @i@.
transposeToMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable ix) =>
   Array sh ix -> ST s (MutArray.Array (ST s) sh ix)
transposeToMutable perm = initMutable (Array.shape perm) invert
   where
      invert inv i = MutArray.write inv (perm ! i) i
-- | Applying the transposed permutation matrix means applying the inverse
-- permutation.
inversionFromTransposition :: Transposition -> Inversion
inversionFromTransposition NonTransposed = NonInverted
inversionFromTransposition Transposed = Inverted
-- | Compose two permutations (see 'multiplyUnchecked').
--
-- Calls 'error' if their shapes differ.
multiply :: (Shape.C sh, Eq sh) =>
   Permutation sh -> Permutation sh -> Permutation sh
multiply a b
   | size a /= size b = error "Permutation.multiply: sizes mismatch"
   | otherwise = multiplyUnchecked a b
-- | Compose a permutation with itself.
square :: (Shape.C sh) => Permutation sh -> Permutation sh
square p = p `multiplyUnchecked` p
-- | Raise a permutation to an integer power by repeated squaring.
--
-- Negative exponents are supported: @power (-n) p == power n (transpose p)@,
-- because 'transpose' yields the inverse permutation.
--
-- 'powerAssociative' is specified only for non-negative exponents. With
-- its @div@-based halving, a negative exponent never reaches 0, so such
-- exponents are redirected to the inverse here instead of being passed
-- through.
power :: (Shape.C sh) => Integer -> Permutation sh -> Permutation sh
power n p
   | n < 0 = power (negate n) (transpose p)
   | otherwise = powerAssociative multiplyUnchecked (identity $ size p) p n
-- | Compose two permutations without checking that their shapes agree.
--
-- Entry @i@ of the result is @b!(a!i)@. Since 'toMatrix' puts the one of row
-- @i@ in column @perm!i@, this matches the matrix product
-- @toMatrix a * toMatrix b@.
multiplyUnchecked :: (Shape.C sh) =>
   Permutation sh -> Permutation sh -> Permutation sh
multiplyUnchecked (Permutation a) (Permutation b) =
   Permutation $ CheckedArray.sample (Array.shape a) ((b !) . (a !))
-- | Diagonal of the permutation matrix.
--
-- Holds @1@ at every position the permutation fixes and @0@ elsewhere.
takeDiagonal ::
   (Shape.C sh, Class.Floating a) => Permutation sh -> Vector sh a
takeDiagonal (Permutation perm) =
   Array.mapShape deconsShape $ CheckedArray.sample (Array.shape perm) isFixed
   where
      isFixed i
         | perm ! i == i = 1
         | otherwise = 0
-- | Dense square row-major matrix of the permutation.
--
-- Row @k@ has a one in column @perm!k@ and zeros elsewhere.
toMatrix :: (Shape.C sh, Class.Floating a) => Permutation sh -> Square sh a
toMatrix (Permutation perm) =
   Array.reshape (Layout.square RowMajor (deconsShape sh)) dense
   where
      sh = Array.shape perm
      dense =
         runST (do
            arr <- MutArray.new (sh, sh) zero
            forM_ (Shape.indices sh) $ \row ->
               MutArray.write arr (row, perm ! row) one
            MutArray.unsafeFreeze arr)
-- | Permute the rows of a full matrix by a permutation of its height.
--
-- The LAPACK forward flag is set for 'NonInverted' and cleared otherwise.
--
-- NOTE(review): forward mode presumably makes row @i@ of the result row
-- @perm!i@ of the input, i.e. left multiplication by 'toMatrix'; confirm.
--
-- Column-major matrices are handled directly by @lapmr@ (row permutation).
-- A row-major @height x width@ matrix is, in column-major terms, a
-- @width x height@ matrix with leading dimension @width@, so its rows are
-- permuted with @lapmt@ (column permutation).
--
-- The height of the matrix must equal the permutation's shape; this is
-- checked with 'Call.assert'.
apply ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Inversion -> Permutation height ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
apply inverted
      (Permutation (Array (Shape shapeP) perm))
      (Array shape@(Layout.Full order extent) a) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> do
      let (height,width) = Extent.dimensions extent
      Call.assert "Permutation.apply: heights mismatch" (height == shapeP)
      let m = Shape.size height
      let n = Shape.size width
      evalContT $ do
         fwdPtr <- Call.bool $ inverted==NonInverted
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         -- LAPACK documents K as workspace that it restores on exit, so the
         -- immutable permutation is copied to a temporary buffer first
         kPtr <- deconsElementPtr <$> copyToTemp m perm
         aPtr <- ContT $ withForeignPtr a
         liftIO $ do
            -- the routines work in place, so copy the input into the result
            copyBlock blockSize aPtr bPtr
            -- skip empty matrices; presumably because LAPACK requires a
            -- leading dimension of at least 1
            when (m>0 && n>0) $
               case order of
                  RowMajor -> LapackGen.lapmt fwdPtr nPtr mPtr bPtr nPtr kPtr
                  ColumnMajor -> LapackGen.lapmr fwdPtr mPtr nPtr bPtr mPtr kPtr
-- | Allocate a mutable array of shape @sh@ and fill it by calling @f@ once
-- for every index, in 'Shape.indices' order.
--
-- The allocation itself writes nothing, so every element must be written
-- by @f@.
initMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   sh -> (MutArray.Array (ST s) sh a -> ix -> ST s ()) ->
   ST s (MutArray.Array (ST s) sh a)
initMutable sh f = do
   arr <- MutArray.unsafeCreate sh (const (return ()))
   forM_ (Shape.indices sh) (f arr)
   return arr
-- | Wrapper that indexes an arbitrary shape @sh@ by one-based 'Element'
-- values.
--
-- These are the indices handed to the LAPACK permutation routines in
-- 'apply'.
newtype Shape sh = Shape {deconsShape :: sh}
   deriving (Eq, Show)
-- | A one-based index into a @'Shape' sh@, stored as a 'CInt'.
--
-- Arrays of these can be passed to LAPACK directly (see 'deconsElementPtr').
-- The phantom parameter @sh@ records which shape the index belongs to.
newtype Element sh = Element {deconsElement :: CInt}
   deriving (Eq, Show)
-- | View a pointer to 'Element's as a pointer to the underlying 'CInt's.
--
-- This is valid because the 'Storable' instance of 'Element' delegates to
-- 'CInt'.
deconsElementPtr :: Ptr (Element sh) -> Ptr CInt
deconsElementPtr ptr = castPtr ptr
-- | Map a function over the wrapped shape.
instance Functor Shape where
   fmap f = Shape . f . deconsShape
-- | A wrapped shape has as many elements as the shape it wraps.
instance (Shape.C sh) => Shape.C (Shape sh) where
   size = Shape.size . deconsShape
-- | A one-based 'Int' range with as many elements as @sh@.
--
-- Offset and bounds computations of 'Shape' delegate to it.
shapeInt :: (Shape.C sh) => sh -> Shape.OneBased Int
shapeInt sh = Shape.OneBased (Shape.size sh)
-- | The indices of @'Shape' sh@ are the one-based 'Element's
-- @1 .. size sh@.
--
-- Offsets and bounds are delegated to the one-based 'Int' shape of the same
-- size.
instance (Shape.C sh) => Shape.Indexed (Shape sh) where
   type Index (Shape sh) = Element sh
   indices (Shape sh) = take (Shape.size sh) (map Element [1 ..])
   unifiedOffset (Shape sh) =
      Shape.unifiedOffset (shapeInt sh) . fromIntegral . deconsElement
   inBounds (Shape sh) =
      Shape.inBounds (shapeInt sh) . fromIntegral . deconsElement
-- | Turn a linear offset back into a one-based 'Element'.
--
-- The conversion goes through the one-based 'Int' shape of the same size.
instance (Shape.C sh) => Shape.InvIndexed (Shape sh) where
   unifiedIndexFromOffset (Shape sh) k =
      Element . fromIntegral <$> Shape.unifiedIndexFromOffset (shapeInt sh) k
-- | An 'Element' is stored exactly like the 'CInt' it wraps.
instance Storable (Element sh) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   sizeOf = sizeOf . deconsElement
   alignment = alignment . deconsElement
   peek ptr = Element <$> peek (deconsElementPtr ptr)
   poke ptr = poke (deconsElementPtr ptr) . deconsElement