{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- {-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoImplicitPrelude #-}

module Numerical.Array.Layout.Builder where

import Control.Monad.Primitive ( PrimMonad, PrimState )
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import Numerical.Array.Layout.Base
import Numerical.Array.Layout.Dense as Dense
--import Numerical.Array.Layou.Sparse as Sparse
--import Numerical.Data.Vector.Pair
import Control.Monad.ST (runST)
import Data.Typeable
import qualified Data.Foldable as F
import Data.Traversable as T
import Control.Applicative as A
import Numerical.Data.Vector.Pair
import Numerical.Array.Layout.Sparse
import Data.Vector.Algorithms.Intro as IntroSort
import Data.List (group)
import Numerical.InternalUtils
import Prelude hiding (error)

-- | A recipe for filling a buffer of 'batchInitSize' elements: the elements
-- come either from a list ('Left') or from an index-to-value action ('Right').
data BatchInit v = BatchInit
  { batchInitSize :: !Int
  , batchInitKV   :: !(Either [v] (IntFun v))
  }
  deriving (Typeable)

-- | Allocate a mutable vector of 'batchInitSize' elements and write the
-- described values into it.
--
-- For a 'Left' list only the first @size@ elements are written; if the list is
-- shorter than @size@ the remaining slots keep whatever 'VGM.new' left there.
-- For a 'Right' function every index in @[0 .. size-1]@ is queried in order.
materializeBatchMV :: (PrimMonad m, VGM.MVector mv a) => BatchInit a -> m (mv (PrimState m) a)
materializeBatchMV (BatchInit size (Left ls)) = do
  newMV <- VGM.new size
  -- 'take size' keeps every index below the allocated length, which is what
  -- justifies the unchecked write. 'mapM_' avoids building a list of units.
  Prelude.mapM_ (\(ix, val) -> VGM.unsafeWrite newMV ix val) (zip [0 ..] (take size ls))
  return newMV
materializeBatchMV (BatchInit size (Right (IntFun f))) = do
  newMV <- VGM.new size
  Prelude.mapM_ (\ix -> f ix >>= VGM.unsafeWrite newMV ix) (take size [0 ..])
  return newMV

--- not sure if this is EVER useful
newtype AnyMV mv e = AMV (forall s . mv s e )

-- | Debug rendering. At most the first 100 elements are printed when the
-- declared size exceeds 100. A 'Right' initializer is evaluated in 'runST' and
-- printed as the list of values it produces.
instance (Show a) => Show (BatchInit a) where
  show (BatchInit size (Left ls))
    | size > 100 =
        "(BatchInit " ++ show size ++ "-- only showing the first 100 elements\n"
          ++ "(Left " ++ show (take 100 ls) ++ "))\n"
    | otherwise =
        "(BatchInit " ++ show size ++ " (Left " ++ show ls ++ "))\n"
  show (BatchInit size (Right (IntFun f)))
    | size > 100 =
        -- indices 0..99 are exactly the 100 elements the header announces
        "(BatchInit " ++ show size ++ "-- only showing the first 100 elements\n"
          ++ "(Left " ++ show (runST (Prelude.mapM f [0 .. 99])) ++ "))\n"
    | otherwise =
        "(BatchInit " ++ show size
          ++ "(Left " ++ show (runST (Prelude.mapM f [0, 1 .. size - 1])) ++ "))\n"

-- | An index-to-value action that can run in any 'PrimMonad'.
newtype IntFun a = IntFun (forall m. (PrimMonad m) => Int -> m a)
  -- This may change substantially in a future release, but for now
  -- acts like
  deriving (Typeable)

instance Functor IntFun where
  fmap f (IntFun g) = IntFun (\x -> g x >>= (\y -> return (f y)))
  {-# INLINE fmap #-}

instance Functor BatchInit where
  {-# INLINE fmap #-}
  fmap = \f bival ->
    case bival of
      (BatchInit size (Left ls))    -> BatchInit size (Left (Prelude.map f ls))
      (BatchInit size (Right gfun)) -> BatchInit size (Right $ fmap f gfun)

-- batchInit size should be Word rather than Int for size, but Vector is lame like that

{-
ChoiceT from monad lib is tempting
as is one of the ListT done right
Bundle from Vector 0.11 and Stream from 0.10 are both alluring too
but all of them make things complicated, punt for now

Also: I may want/need to distinguish sparse vs dense builders and put them
into different classes, punting that for now
-}

-- | Wrap a list; the batch size is the length of the list.
fromListBI :: [a] -> BatchInit a
fromListBI ls = BatchInit (length ls) (Left ls)

-- | Wrap an immutable vector; the batch size is the length of the vector.
-- Querying an index outside @[0, size)@ calls 'error'.
fromVectorBI :: VG.Vector v e => v e -> BatchInit e
fromVectorBI v = BatchInit size (Right (IntFun lookupAt))
  where
    size = VG.length v
    lookupAt i
      | i < 0 || i >= size =
          error $ " out of bounds index on IntFun of size: " ++ show size
                    ++ ", index: " ++ show i
      | otherwise = return $ v VG.! i

-- | Wrap a state-polymorphic mutable vector; the batch size is its length.
-- Querying an index outside @[0, size)@ calls 'error'.
fromMVectorBI :: (VGM.MVector mv e ) => AnyMV mv e -> BatchInit e
fromMVectorBI (AMV v) =
    BatchInit size
      (Right (IntFun $ \i ->
         if i < 0 || i >= size
           then error $ " out of bounds index on IntFun of size: " ++ show size
                          ++ ", index: " ++ show i
           else v `VGM.read` i ))
  where
    size = VGM.length v

-- | Formats that can construct their descriptor together with a suitably
-- sized and initialized mutable value buffer.
class Layout form (rank::Nat) => LayoutBuilder form (rank::Nat) | form -> rank where
  -- Arguments: the shape, a proxy selecting the format, a default value,
  -- and an optional batch of (index, value) pairs.
  buildFormatM :: (store~FormatStorageRep form,Buffer store Int ,Buffer store a,PrimMonad m)=>
         Index rank -> proxy form -> a
         -> Maybe (BatchInit (Index rank ,a))
         -> m (form, BufferMut store (PrimState m) a )

-- | Pure counterpart of 'buildFormatM': runs the builder inside 'runST',
-- freezes the value buffer with 'VG.unsafeFreeze', and forces both the format
-- and the buffer to weak head normal form before returning them in @m@.
buildFormatPure:: forall store form rank proxy m a.
    (LayoutBuilder form (rank::Nat) ,store~FormatStorageRep form,Buffer store Int ,Buffer store a, Monad m ) =>
     Index rank -> proxy form -> a -> Maybe (BatchInit (Index rank ,a)) -> m (form, BufferPure store a )
buildFormatPure shape prox defaultValue builder =
    do  res@(!_,!_) <- return $! theComputation
        return res
  where
    theComputation :: (form,BufferPure store a )
    !theComputation = runST $
        do  (form,buf) <- buildFormatM shape prox defaultValue builder
            pureBuff <- VG.unsafeFreeze buf
            return (form, pureBuff)

{-
this is a funky api for both dense and sparse arrays general builder format.

given the target shape, logical dimensions, a default value (only used for dense arrays)
and the list of manifest values (mostly only used for sparse), build the format
descriptor and the suitably initialized and sized values buffer

this api is only meant for internal use for building new array values

TODO: compare using a catenable priority heap vs just doing fast sorting.
-}

{-
the dense instances ignore the builder structure, which does suggest that maybe
there should be a dense builder layout class and a sparse layout class separately
-}

-- | Rank-1 direct dense layout: a buffer of @size@ copies of the default value.
instance LayoutBuilder (Format Direct 'Contiguous ('S 'Z) rep) ('S 'Z) where
  buildFormatM (size :* _) _ defaultValue _ = do
    buf <- VGM.replicate size defaultValue
    return (FormatDirectContiguous size, buf)

-- really wish I didn't have to write the foldable and traversable constraints
-- seems like a code smell?!
instance (F.Foldable (Shape r),T.Traversable (Shape r) ,A.Applicative (Shape r))
  => LayoutBuilder (Format Row 'Contiguous r rep) r where
  buildFormatM ix _ defaultValue _ = do
    -- The element count is the product of the extents, so the fold must be
    -- seeded with 1; a seed of 0 would always allocate an empty buffer.
    buf <- VGM.replicate (F.foldl' (*) 1 ix) defaultValue
    return (FormatRowContiguous ix, buf)

instance (F.Foldable (Shape r),T.Traversable (Shape r) ,A.Applicative (Shape r))
  => LayoutBuilder (Format Column 'Contiguous r rep) r where
  buildFormatM ix _ defaultValue _ = do
    -- Product of the extents, seeded with 1 (see the Row instance).
    buf <- VGM.replicate (F.foldl' (*) 1 ix) defaultValue
    return (FormatColumnContiguous ix, buf)

-- | Check that every adjacent pair of elements compares as 'LT' under @cmp@.
-- Returns 'Nothing' when the whole vector is strictly increasing (including
-- vectors with fewer than two elements), otherwise @'Just' i@ where @i@ is the
-- first position whose element is not strictly less than its successor.
isStrictlyMonotonicV :: (VG.Vector v e) => (e -> e -> Ordering) -> v e -> Maybe Int
isStrictlyMonotonicV cmp v = go 0 (VG.length v)
  where
    go !i !len
      | i + 1 >= len = Nothing
      | (v VG.! i) `lt` (v VG.! (i + 1)) = go (i + 1) len
      | otherwise = Just i
    lt a b = case cmp a b of
      LT -> True
      _  -> False

instance (Buffer rep Int) => LayoutBuilder (Format DirectSparse 'Contiguous ('S 'Z) rep ) ('S 'Z) where
  -- No manifest entries: empty index table and empty value buffer.
  buildFormatM (size :* _) _ _ Nothing = do
    mvI <- VGM.new 0
    vI <- VG.unsafeFreeze mvI
    mvV <- VGM.new 0
    return $! (FormatDirectSparseContiguous size 0 vI, mvV)
  -- Materialize the (index, value) pairs, sort them by index, and reject
  -- duplicated indices.
  buildFormatM (size :* _) _ _ (Just builder) = do
    -- need to use let so type inference doesn't totally barf
    mvt@(MVPair (MVLeaf ix) (MVLeaf val)) <-
        materializeBatchMV $ fmap (\((i :* _), v) -> (i, v)) builder
    -- if i swap to using this instead of ix <- mat.. ; val <- mat..
    --i get CRAZY type errors
    -- could this be a ghc bug?
    --ix <- materializeBatchMV $ fmap fst builtTup
    --val <- materializeBatchMV $ fmap snd builtTup
    _ <- IntroSort.sortBy (\x y -> compare (fst x) (fst y)) mvt
    -- this lets me sort a pair of arrays!
    vIx <- VG.unsafeFreeze ix
    optFail <- return $ isStrictlyMonotonicV compare vIx
    --_hoelly
    case optFail of
      Nothing -> return (FormatDirectSparseContiguous size 0 vIx, val)
      Just ixWrong -> error $ "DirectSparse Index duplication at index " ++ show (vIx VG.! ixWrong)

instance (Buffer rep Int) => LayoutBuilder (Format CompressedSparseRow 'Contiguous ('S ('S 'Z)) rep ) ('S ('S 'Z)) where
  -- No manifest entries: empty column-index and value buffers, and a
  -- row-start table of y+1 zeros. This is the same table the 'Just' case
  -- computes for an empty batch, so both paths agree on its length.
  buildFormatM (x :* y :* _) _ _ Nothing = do
    mvi <- VGM.new 0
    vi <- VG.unsafeFreeze mvi
    mvval <- VGM.new 0
    let emptyRowStarts = VG.replicate (y + 1) (0 :: Int)
    return $ ( FormatContiguousCompressedSparseRow
                 (FormatContiguousCompressedSparseInternal y x vi emptyRowStarts)
             , mvval )
  buildFormatM (x :* y :* _) proxyFormat _ (Just builder) = do
    -- Materialize as ((y, x), value) triples stored in parallel buffers.
    mvtup@(MVPair (MVPair (MVLeaf mvectYs) (MVLeaf mvectXs)) (MVLeaf mvectVals)) <-
        materializeBatchMV $ fmap (\((xix :* yix :* _), val) -> ((yix, xix), val)) builder
    -- Sort all three buffers together using the format's index ordering.
    _ <- IntroSort.sortBy
           (\((y1, x1), _) ((y2, x2), _) ->
              basicCompareIndex proxyFormat (x1 :* y1 :* Nil) (x2 :* y2 :* Nil))
           mvtup
    vectXs <- unsafeBufferFreeze mvectXs
    vectYs <- unsafeBufferFreeze mvectYs
    --- predicate check here wrt monotonicity of
    --- compute runlength partial sums of where ys go
    -- need to actually check
    yRunsVect <- return $
        VG.replicate (y + 1) (0 :: Int) VG.// computeStarts (computeRunLengths vectYs) 0 y
    --_ <- (error "computeRUnCount") vectYs yRunsMVect
    --yRunsVect <- unsafeBufferFreeze yRunsMVect
    let xyVect = (VPair (VLeaf vectXs) (VLeaf vectYs))
    optFail <- return $
        isStrictlyMonotonicV
          (\(x1, y1) (x2, y2) ->
             basicCompareIndex proxyFormat (x1 :* y1 :* Nil) (x2 :* y2 :* Nil))
          xyVect
    case optFail of
      Nothing ->
        return $ ( FormatContiguousCompressedSparseRow
                     (FormatContiguousCompressedSparseInternal y x vectXs yRunsVect)
                 , mvectVals )
      Just i ->
        error $ "illegal duplication in CSR builder (x,y) coordinates "
                  ++ show (xyVect VG.! i) ++ " and " ++ show (xyVect VG.! (i+1))
                  ++ "starting at position " ++ show i

-- | Run-length encode a vector: each maximal run of equal adjacent elements
-- becomes one @(element, run length)@ pair, in order of appearance.
computeRunLengths :: (VG.Vector v e, Eq e) => v e -> [(e, Int)]
computeRunLengths = \y ->
  -- 'group' never yields an empty run, so the pattern always matches.
  [ (e, length run) | run@(e : _) <- group (VG.toList y) ]

{-# SPECIALIZE INLINE computeStarts :: [(Int,Int)]->Int->Int ->[(Int,Int)] #-}
-- | Turn @(position, count)@ pairs, ordered by position, into exclusive prefix
-- sums: the result pairs each position from @start@ to @end@ with the sum of
-- the counts at all strictly earlier positions. Positions absent from the
-- input contribute a count of zero.
--
-- Calls 'error' when @start > end@, when the input positions are not
-- increasing relative to the walk, or when the input is exhausted only after
-- the walk has already passed @end@.
computeStarts :: (Enum a, Ord a, Num b) => [(a, b)] -> a -> a -> [(a, b)]
computeStarts ls start end
  | start <= end = go start 0 ls
  | otherwise = error "bad start end arguments to computeStarts"
  where
    --go :: a ->b->[(a,b)]-> [(a,b)]
    -- Input exhausted: every remaining position shares the final sum.
    go !posNext preSum []
      | posNext <= end = fmap (\x -> (x, preSum)) [posNext .. end]
      | otherwise = error "impossible go computeStarts "
    go !posNext !preSum gls@((posAt, atSum) : rest)
      -- gap before the next recorded position: the sum is unchanged
      | posNext < posAt = (posNext, preSum) : go (succ posNext) preSum gls
      -- recorded position: emit the sum so far, then add this count
      | posNext == posAt = (posNext, preSum) : go (succ posNext) (preSum + atSum) rest
      | otherwise = error "bad position in prefix stream for computeStarts go, literally unpossible "

--computeStarts :: (Eq a, Num a)=> [(a,Int)]->Int -> [(a,Int)]
--computeStarts [] len = map (\x -> (x ,0)) [0..len]
--computeStarts ls len = go 0 0 ls
--  where
--    go preSum place [] | place > len = []
--                       | place == len = [(place,preSum)]
--                       | otherwise = map