{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Low-level helpers for 'Layout.Mosaic'-shaped arrays: packing and
-- unpacking via LAPACK, block copies for stacking, conversion from banded
-- storage, and a small framework for calling packed/unpacked LAPACK
-- routine pairs.
module Numeric.LAPACK.Matrix.Mosaic.Private where

import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor,ColumnMajor), flipOrder, uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Scalar (zero)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Private
         (pointerSeq, copyBlock, copyCondConjugateToTemp, pokeCInt, fill,
          withAutoWorkspaceInfo, withInfo, errorCodeMsg)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((::+)((::+)))

import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (pure, (<*>))

import Data.Foldable (forM_)


-- | An array whose shape is a 'Layout.Mosaic' with the given packing,
-- mirroring and triangle part.
type Mosaic pack mirror uplo sh = Array (Layout.Mosaic pack mirror uplo sh)
type MosaicPacked mirror uplo sh = Mosaic Layout.Packed mirror uplo sh
type MosaicUnpacked mirror uplo sh = Mosaic Layout.Unpacked mirror uplo sh
type MosaicLower mirror sh = MosaicPacked mirror Shape.Lower sh
type MosaicUpper mirror sh = MosaicPacked mirror Shape.Upper sh


-- | Pointers to the first @n@ diagonal elements of a packed triangle of
-- size @n@ starting at @aPtr@.
--
-- Consecutive diagonal elements are @2, 3, 4, ...@ elements apart for
-- 'ColumnMajor' and @n, n-1, n-2, ...@ elements apart for 'RowMajor'.
diagonalPointers :: (Storable a) => Order -> Int -> Ptr a -> [Ptr a]
diagonalPointers order n aPtr =
   take n $ scanl advancePtr aPtr $
      case order of
         RowMajor -> iterate pred n
         ColumnMajor -> iterate succ 2

-- | Pair the pointers of a contiguous (stride 1) sequence starting at
-- @aPtr@ with the diagonal pointers of the packed triangle at @bPtr@
-- (see 'diagonalPointers').  The result has at most @n@ elements.
diagonalPointerPairs ::
   (Storable a, Storable b) => Order -> Int -> Ptr a -> Ptr b -> [(Ptr a, Ptr b)]
diagonalPointerPairs order n aPtr bPtr =
   zip (pointerSeq 1 aPtr) $ diagonalPointers order n bPtr

-- | For @i = 0 .. n-1@ yield
-- @(i+1, ((fullPtr+i, fullPtr+i*n), packedPtr + i*(i+1)/2))@.
--
-- For a column-major @n×n@ matrix at @fullPtr@ these are the starts of
-- row @i@ and column @i@.  The packed offsets @0, 1, 3, 6, ...@ are the
-- column starts of a column-major packed triangle.
columnMajorPointers ::
   (Storable a) => Int -> Ptr a -> Ptr a -> [(Int, ((Ptr a, Ptr a), Ptr a))]
columnMajorPointers n fullPtr packedPtr =
   let ds = iterate succ 1
   in take n $ zip ds $
      zip
         (zip (pointerSeq 1 fullPtr) (pointerSeq n fullPtr))
         (scanl advancePtr packedPtr ds)

-- | For @i = 0 .. n-1@ yield @(n-i, (fullPtr + i*(n+1), packedRowStart))@.
--
-- @fullPtr + i*(n+1)@ is the diagonal element @(i,i)@ of an @n×n@ matrix.
-- @packedRowStart@ advances by @n, n-1, ...@, i.e. it walks the row
-- starts of a row-major packed triangle.
rowMajorPointers ::
   (Storable a) => Int -> Ptr a -> Ptr a -> [(Int, (Ptr a, Ptr a))]
rowMajorPointers n fullPtr packedPtr =
   let ds = iterate pred n
   in take n $ zip ds $
      zip (pointerSeq (n+1) fullPtr) (scanl advancePtr packedPtr ds)

-- | Run @act@ for every element of @xs@.  Before each call, the 'Int'
-- component is written into a single temporary 'CInt' cell, and that cell
-- is passed to @act@.  This suits FFI routines that take their length by
-- reference.
forPointers :: [(Int, a)] -> (Ptr CInt -> a -> IO ()) -> IO ()
forPointers xs act =
   alloca $ \nPtr ->
   forM_ xs $ \(d,ptrs) -> do
      pokeCInt nPtr d
      act nPtr ptrs

-- | Copy an array to temporary storage with 'copyCondConjugateToTemp'.
-- The requested conjugation is honoured only for 'RowMajor';
-- 'ColumnMajor' always copies 'NonConjugated'.
copyTriangleToTemp ::
   Class.Floating a =>
   Conjugation -> Order -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyTriangleToTemp conj order =
   copyCondConjugateToTemp $
      case order of
         RowMajor -> conj
         ColumnMajor -> NonConjugated

-- | Allocate a temporary @n*n@ array and fill it by calling
-- @f n packedPtr tempPtr@.  Returns the temporary pointer; it stays valid
-- for the rest of the surrounding 'ContT' computation.
unpackToTemp ::
   Storable a =>
   (Int -> Ptr a -> Ptr a -> IO ()) -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
unpackToTemp f n a = do
   apPtr <- ContT $ withForeignPtr a
   aPtr <- Call.allocaArray (n*n)
   liftIO $ f n apPtr aPtr
   return aPtr

-- | Expand a packed triangle of size @n@ into an @n×n@ full array with
-- leading dimension @n@, using LAPACK @?tpttr@.  Only the selected
-- triangle of the full array is written.
unpack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
unpack order n packedPtr fullPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   ldaPtr <- Call.leadingDim n
   liftIO $ withInfo errorCodeMsg "tpttr" $
      LapackGen.tpttr uploPtr nPtr packedPtr fullPtr ldaPtr

-- | Pack the triangle of an @n×n@ full array (leading dimension @n@).
pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
pack order n = packRect order n n

-- | Pack the @n×n@ triangle of a full array with leading dimension @ld@,
-- using LAPACK @?trttp@.
packRect :: Class.Floating a => Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
packRect order n ld fullPtr packedPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   ldaPtr <- Call.leadingDim ld
   liftIO $ withInfo errorCodeMsg "trttp" $
      LapackGen.trttp uploPtr nPtr fullPtr ldaPtr packedPtr

-- | Like 'unpack', but the part of the full array that 'unpack' does not
-- write is set to zero.
--
-- '_unpackZero' clears the whole @n*n@ array first.  'unpackZero' clears
-- only the opposite triangle, by calling 'fillTriangle' with the flipped
-- order; 'unpack' then overwrites the diagonal.
unpackZero, _unpackZero ::
   Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_unpackZero order n packedPtr fullPtr = do
   fill zero (n*n) fullPtr
   unpack order n packedPtr fullPtr

unpackZero order n packedPtr fullPtr = do
   fillTriangle zero (flipOrder order) n fullPtr
   unpack order n packedPtr fullPtr

-- | Set the triangle selected by @uploFromOrder order@ of an @n×n@ array
-- (leading dimension @n@) to @z@, diagonal included, using LAPACK
-- @?laset@ with both off-diagonal and diagonal values equal to @z@.
fillTriangle :: Class.Floating a => a -> Order -> Int -> Ptr a -> IO ()
fillTriangle z order n aPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   zPtr <- Call.number z
   liftIO $ LapackGen.laset uploPtr nPtr nPtr zPtr zPtr aPtr nPtr


-- | Wrap the size shape in 'Unchecked'; the array data is untouched.
uncheck ::
   Mosaic pack mirror uplo sh a -> Mosaic pack mirror uplo (Unchecked sh) a
uncheck =
   Array.mapShape $ \(Layout.Mosaic packing mirror uplo order sh) ->
      Layout.Mosaic packing mirror uplo order (Unchecked sh)

-- | Inverse of 'uncheck'.
recheck ::
   Mosaic pack mirror uplo (Unchecked sh) a -> Mosaic pack mirror uplo sh a
recheck =
   Array.mapShape $ \(Layout.Mosaic packing mirror uplo order (Unchecked sh)) ->
      Layout.Mosaic packing mirror uplo order sh


-- | Build the packed upper block matrix
--
-- > / A B \
-- > \ 0 C /
--
-- from upper triangles @A@ (size @height@) and @C@ (size @width@) and the
-- general @height×width@ block @B@.
--
-- The result uses @B@'s storage order and @A@'s mirror.  @A@ and @C@ are
-- read assuming that same order; that is not checked here
-- (NOTE(review): presumably callers ensure it).  Their sizes are
-- asserted against @B@'s dimensions.
stack ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   MosaicUpper mirror height a ->
   Matrix.General height width a ->
   MosaicUpper mirror width a ->
   MosaicUpper mirror (height::+width) a
stack (Array sha a) (Array (Layout.Full order extent) b) (Array shc c) =
   let name = show $ Layout.mosaicMirror sha
       (height,width) = Extent.dimensions extent
   in Array.unsafeCreate
         (Layout.Mosaic Layout.Packed (Layout.mosaicMirror sha)
            Layout.Upper order (height ::+ width)) $ \xPtr -> do
      Call.assert (name++".stack: height shapes mismatch") $
         height == Layout.mosaicSize sha
      Call.assert (name++".stack: width shapes mismatch") $
         width == Layout.mosaicSize shc
      let m = Shape.size height
      let n = Shape.size width
      withForeignPtr a $ \aPtr -> copyTriangleA copyBlock order m n aPtr xPtr
      withForeignPtr b $ \bPtr -> copyRectangle copyBlock order m n bPtr xPtr
      withForeignPtr c $ \cPtr -> copyTriangleC copyBlock order m n cPtr xPtr

-- | Extract the general top-right block @B@ of a stacked upper matrix
-- (see 'stack').
takeTopRight ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   MosaicUpper mirror (height::+width) a -> Matrix.General height width a
takeTopRight
   (Array (Layout.Mosaic _packed _mirror _upper order (height::+width)) x) =
      Array.unsafeCreate (Layout.general order height width) $ \bPtr -> do
         let m = Shape.size height
         let n = Shape.size width
         -- flip swaps source and destination: copy from x into bPtr.
         withForeignPtr x $ copyRectangle (flip . copyBlock) order m n bPtr

-- | Extract the top-left triangle @A@ of a stacked upper matrix.
takeTopLeft ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   MosaicUpper mirror (height::+width) a -> MosaicUpper mirror height a
takeTopLeft
   (Array (Layout.Mosaic packing mirror upper order (height::+width)) x) =
      Array.unsafeCreate (Layout.Mosaic packing mirror upper order height) $
         \aPtr -> do
            let m = Shape.size height
            let n = Shape.size width
            withForeignPtr x $ copyTriangleA (flip . copyBlock) order m n aPtr

-- | Extract the bottom-right triangle @C@ of a stacked upper matrix.
takeBottomRight ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   MosaicUpper mirror (height::+width) a -> MosaicUpper mirror width a
takeBottomRight
   (Array (Layout.Mosaic packing mirror upper order (height::+width)) x) =
      Array.unsafeCreate (Layout.Mosaic packing mirror upper order width) $
         \cPtr -> do
            let m = Shape.size height
            let n = Shape.size width
            withForeignPtr x $ copyTriangleC (flip . copyBlock) order m n cPtr

-- | Transfer the top-left triangle (size @m@) between its own packed
-- storage @aPtr@ and the packed triangle of size @m+n@ at @xPtr@, using
-- @copy count aPart xPart@.
--
-- In 'ColumnMajor' order this block is a contiguous prefix.  In
-- 'RowMajor' order it is copied row by row, row @i@ holding @m-i@
-- elements.
{-# INLINE copyTriangleA #-}
copyTriangleA ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleA copy order m n aPtr xPtr =
   case order of
      ColumnMajor -> copy (Shape.triangleSize m) aPtr xPtr
      RowMajor ->
         forM_
            (zip (iterate pred m) $
               zip
                  (diagonalPointers order m aPtr)
                  (diagonalPointers order (m+n) xPtr)) $
            \(k,(aiPtr,xiPtr)) -> copy k aiPtr xiPtr

-- | Transfer the bottom-right triangle (size @n@) between its own packed
-- storage @cPtr@ and the packed triangle of size @m+n@ at @xPtr@.
--
-- In 'RowMajor' order this block is a contiguous suffix.  In
-- 'ColumnMajor' order it is copied column by column, column @k@ holding
-- @k+1@ elements that end at the diagonal.
{-# INLINE copyTriangleC #-}
copyTriangleC ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleC copy order m n cPtr xPtr =
   case order of
      RowMajor ->
         let triSize = Shape.triangleSize n
         in copy triSize cPtr
               (advancePtr xPtr $ Shape.triangleSize (m+n) - triSize)
      ColumnMajor ->
         forM_
            (zip (iterate succ 0) $
               zip
                  (diagonalPointers order n cPtr)
                  (drop m $ diagonalPointers order (m+n) xPtr)) $
            \(k,(aiPtr,xiPtr)) ->
               -- step back from the diagonal to the start of the column part
               copy (k+1) (advancePtr aiPtr (-k)) (advancePtr xiPtr (-k))

-- | Transfer the general @m×n@ block between dense storage @bPtr@ and the
-- top-right part of the packed triangle of size @m+n@ at @xPtr@.
--
-- 'RowMajor' copies @m@ rows of @n@ elements each.  Each row starts
-- @m-i@ elements after diagonal @i@.  'ColumnMajor' copies @n@ columns of
-- @m@ elements each, each starting @m+j@ elements before diagonal @m+j@.
{-# INLINE copyRectangle #-}
copyRectangle ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyRectangle copy order m n bPtr xPtr =
   case order of
      RowMajor ->
         forM_
            (take m $ zip (iterate pred m) $
               zip (pointerSeq n bPtr) (diagonalPointers order (m+n) xPtr)) $
            \(k,(biPtr,xiPtr)) -> copy n biPtr (advancePtr xiPtr k)
      ColumnMajor ->
         forM_
            (take n $ zip (iterate succ m) $
               zip (pointerSeq m bPtr)
                  (drop m $ diagonalPointers order (m+n) xPtr)) $
            \(k,(biPtr,xiPtr)) -> copy m biPtr (advancePtr xiPtr (-k))


type Triangular uplo sh = Array (Layout.Triangular uplo sh)
type Lower sh = Triangular Shape.Lower sh
type Upper sh = Triangular Shape.Upper sh

-- | A function on triangular matrices, wrapped so that @uplo@ is the last
-- type parameter.
newtype MultiplyRight sh a b uplo =
   MultiplyRight {getMultiplyRight :: Triangular uplo sh a -> b}

-- | A shape-changing map on mosaics, wrapped so that @uplo@ is the last
-- type parameter.
newtype Map pack mirror sh0 sh1 a uplo =
   Map {
      getMap :: Mosaic pack mirror uplo sh0 a -> Mosaic pack mirror uplo sh1 a
   }


-- | Copy a banded matrix with @k@ superdiagonals into a packed triangle of
-- size @n@.
--
-- The banded input @a@ uses @k+1@ elements per column or row.  The
-- destination @bPtr@ (with @bSize@ elements) is zero-filled first.  Then,
-- for each index @i@:
--
-- * 'ColumnMajor': the at most @k+1@ band elements ending at diagonal @i@
--   are copied.
--
-- * 'RowMajor': the at most @k+1@ band elements starting at diagonal @i@
--   are copied.
fromBanded ::
   (Class.Floating a) =>
   Int -> Order -> Int -> ForeignPtr a -> Int -> Ptr a -> IO ()
fromBanded k order n a bSize bPtr =
   withForeignPtr a $ \aPtr -> do
      fill zero bSize bPtr
      let lda = k+1
      let pointers =
            zip [0..] $ zip (pointerSeq lda aPtr) $
               diagonalPointers order n bPtr
      case order of
         ColumnMajor ->
            forM_ pointers $ \(i,(xPtr,yPtr)) ->
               -- the first columns contain fewer than k superdiagonal entries
               let j = min i k
               in copyBlock (j+1) (advancePtr xPtr (k-j)) (advancePtr yPtr (-j))
         RowMajor ->
            forM_ pointers $ \(i,(xPtr,yPtr)) ->
               -- the last rows are truncated at the matrix border
               copyBlock (min lda (n-i)) xPtr yPtr


-- | Reinterpret an upper mosaic as a non-mirrored upper triangle.  Only
-- the shape changes (via 'Array.mapShape'): its mirror becomes
-- 'Layout.NoMirror'.
--
-- The naming is inconsistent with @Triangular.takeUpper@: here the
-- Hermitian matrix is the input, whereas in @Triangular.takeUpper@ the
-- triangular matrix is the output.
takeUpper :: MosaicUpper mirror sh a -> Upper sh a
takeUpper =
   Array.mapShape
      (\(Layout.Mosaic packing _mirror upper order sh) ->
         Layout.Mosaic packing Layout.NoMirror upper order sh)

-- | Inverse of 'takeUpper'.  Only the shape changes; the mirror is chosen
-- by 'Layout.autoMirror'.
fromUpper :: (Layout.Mirror mirror) => Upper sh a -> MosaicUpper mirror sh a
fromUpper =
   Array.mapShape
      (\(Layout.Mosaic packing Layout.NoMirror upper order sh) ->
         Layout.Mosaic packing Layout.autoMirror upper order sh)

-- | Pack the lower @height×height@ triangle of a full matrix into a
-- packed lower mosaic with the same order.  Afterwards @fillDiag@ is
-- called on the packed result (with the flipped order and size @m@),
-- e.g. to overwrite the diagonal.
fromLowerPart ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   (Order -> Int -> Ptr a -> IO ()) ->
   Layout.MirrorSingleton mirror ->
   Full meas Extent.Small horiz height width a ->
   MosaicLower mirror height a
fromLowerPart fillDiag mirror (Array (Layout.Full order extent) a) =
   let (height,width) = Extent.dimensions extent
       m = Shape.size height
       n = Shape.size width
       -- leading dimension of the full source array
       k =
          case order of
             RowMajor -> n
             ColumnMajor -> m
   in Array.unsafeCreate
         (Layout.Mosaic Layout.Packed mirror Layout.Lower order height) $
      \lPtr -> withForeignPtr a $ \aPtr -> do
         -- a lower triangle is packed as the upper one of the transposition
         let dstOrder = flipOrder order
         packRect dstOrder m k aPtr lPtr
         fillDiag dstOrder m lPtr

-- | A no-op with the type of the @fillDiag@ argument of 'fromLowerPart';
-- it leaves the packed diagonal as copied.
leaveDiagonal :: Order -> Int -> Ptr a -> IO ()
leaveDiagonal _order _m _ptr = return ()


-- | A 'ContT' computation paired with a label.  'runLabelledLinear' and
-- 'runLabelledWorkspace' pass the label as the routine name to
-- 'withInfo' / 'withAutoWorkspaceInfo'.
data Labelled r label a = Labelled label (ContT r IO a)

-- | Attach a label to a pure value.
label :: label -> a -> Labelled r label a
label lab a = Labelled lab (pure a)

-- | A pure value with the unit label.
noLabel :: a -> Labelled r () a
noLabel a = Labelled () (pure a)

-- Maps the result of the computation; the label is kept.
instance Functor (Labelled r label) where
   fmap f (Labelled lab a) = Labelled lab $ fmap f a

-- | Run the computation and then execute the 'IO' action it produced.
runUnlabelled :: Labelled r () (IO ()) -> ContT r IO ()
runUnlabelled (Labelled () m) = liftIO =<< m

-- | Run an @info@-returning routine through 'withInfo', using the label
-- as the routine name.
runLabelledLinear ::
   String -> Labelled r String (Ptr CInt -> IO ()) -> ContT r IO ()
runLabelledLinear msg (Labelled name m) =
   liftIO . withInfo msg name =<< m

-- | Run a routine that takes workspace and @info@ through
-- 'withAutoWorkspaceInfo', using the label as the routine name.
runLabelledWorkspace ::
   (Class.Floating a) =>
   String ->
   Labelled r String (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) ->
   ContT r IO ()
runLabelledWorkspace msg (Labelled name m) =
   liftIO . withAutoWorkspaceInfo msg name =<< m


-- | A pair of labelled computations: the first is used for packed
-- storage, the second for unpacked storage (see 'runPacking').
data Labelled2 r label a b = Labelled2 (Labelled r label a) (Labelled r label b)

-- Maps only the unpacked (second) component.
instance Functor (Labelled2 r label a) where
   fmap f (Labelled2 a b) = Labelled2 a (fmap f b)

-- Both application operators bind tightly and associate to the left, so
-- that @fg $* x $** (y,n)@ reads like ordinary function application.
infixl 9 $*, $**

-- | Apply both variants to the same argument.
($*) :: Labelled2 r label (a -> f) (a -> g) -> a -> Labelled2 r label f g
Labelled2 f g $* a = Labelled2 (fmap ($ a) f) (fmap ($ a) g)

-- | Apply both variants to a pointer.  The unpacked variant additionally
-- receives the leading dimension @n@ (via 'Call.leadingDim').
($**) ::
   Labelled2 r label (a -> f) (a -> Ptr CInt -> g) ->
   (a,Int) -> Labelled2 r label f g
Labelled2 f (Labelled lab g) $** (a,n) =
   Labelled2
      (fmap ($ a) f)
      (Labelled lab $ fmap ($ a) g <*> Call.leadingDim n)

-- | Select the packed or the unpacked variant.
runPacking ::
   Layout.PackingSingleton pack ->
   Labelled2 r label func func -> Labelled r label func
runPacking pck (Labelled2 lp lu) =
   case pck of
      Layout.Packed -> lp
      Layout.Unpacked -> lu

-- | Select a variant by packing and run it with 'runUnlabelled'.
withPacking ::
   Layout.PackingSingleton pack ->
   Labelled2 r () (IO ()) (IO ()) -> ContT r IO ()
withPacking pck = runUnlabelled . runPacking pck

-- | Select a variant by packing and run it with 'runLabelledLinear'.
withPackingLinear ::
   (func ~ (Ptr CInt -> IO ())) =>
   String -> Layout.PackingSingleton pack ->
   Labelled2 r String func func -> ContT r IO ()
withPackingLinear msg pck = runLabelledLinear msg . runPacking pck


-- | A triangle pointer with its size.  The packed variant receives only
-- the pointer; the unpacked variant additionally receives a leading
-- dimension derived from the size (see the 'FunctionArg' instance).
data TriArg a = TriArg (Ptr a) Int

triArg :: Ptr a -> Int -> TriArg a
triArg = TriArg

-- | Combine a packed and an unpacked routine into a function that takes
-- the arguments for both at once (see 'FunctionPair').
applyFuncPair ::
   (m ~ Labelled (FuncCont f) (FuncLabel f), FunctionPair f) =>
   m (FuncPacked f) -> m (FuncUnpacked f) -> f
applyFuncPair f g = apply (Labelled2 f g)

-- | Types @f@ (a chain of argument types ending in 'Labelled2') to which
-- a packed/unpacked routine pair can be applied argument by argument.
class FunctionPair f where
   -- result type parameter of the underlying 'ContT'
   type FuncCont f
   -- label type
   type FuncLabel f
   -- type of the packed routine
   type FuncPacked f
   -- type of the unpacked routine
   type FuncUnpacked f
   apply ::
      Labelled2 (FuncCont f) (FuncLabel f) (FuncPacked f) (FuncUnpacked f) -> f

-- Extracts the result type of a 'Labelled' computation.
type family LabelResult a
type instance LabelResult (Labelled r label a) = a

-- Base case: no more arguments.
instance FunctionPair (Labelled2 r label a b) where
   type FuncCont (Labelled2 r label a b) = r
   type FuncLabel (Labelled2 r label a b) = label
   type FuncPacked (Labelled2 r label a b) = a
   type FuncUnpacked (Labelled2 r label a b) = b
   apply = id

-- Inductive case: consume one argument via 'FunctionArg'.
instance (FunctionArg a, FunctionPair f) => FunctionPair (a -> f) where
   type FuncCont (a -> f) = FuncCont f
   type FuncLabel (a -> f) = FuncLabel f
   type FuncPacked (a -> f) = FuncArgPacked a f
   type FuncUnpacked (a -> f) = FuncArgUnpacked a f
   apply = applyArg

-- | Argument types that can be fed to both routines of a pair, possibly
-- in different forms for the packed and unpacked variant.
class FunctionArg a where
   type FuncArgPacked a f
   type FuncArgUnpacked a f
   applyArg ::
      (FunctionPair f) =>
      Labelled2 (FuncCont f) (FuncLabel f)
         (FuncArgPacked a f) (FuncArgUnpacked a f) ->
      a -> f

-- A plain pointer is passed unchanged to both variants.
instance FunctionArg (Ptr a) where
   type FuncArgPacked (Ptr a) f = Ptr a -> FuncPacked f
   type FuncArgUnpacked (Ptr a) f = Ptr a -> FuncUnpacked f
   applyArg fg a = apply (fg $* a)

-- A triangle pointer additionally supplies a leading dimension to the
-- unpacked variant.
instance FunctionArg (TriArg a) where
   type FuncArgPacked (TriArg a) f = Ptr a -> FuncPacked f
   type FuncArgUnpacked (TriArg a) f = Ptr a -> Ptr CInt -> FuncUnpacked f
   applyArg fg (TriArg a n) = apply (fg $** (a,n))