module Basement.Alg.Foreign.PrimArray
( findIndexElem
, revFindIndexElem
, findIndexPredicate
, revFindIndexPredicate
, foldl
, foldr
, foldl1
, all
, any
, filter
, primIndex
, inplaceSortBy
) where
import GHC.Types
import GHC.Prim
import Basement.Compat.Base
import Basement.Numerical.Additive
import Basement.Numerical.Multiplicative
import Basement.Types.OffsetSize
import Basement.PrimType
import Basement.Monad
import Basement.Alg.Foreign.Prim
-- | Scan forward through the offsets @[startIndex, endIndex)@ of the buffer
-- looking for the first element equal to @ty@.
--
-- Returns the offset of the first match, or @endIndex@ when no element
-- matches. An empty or inverted range (@startIndex >= endIndex@) reads
-- nothing and also yields @endIndex@, which is the same "not found" marker
-- that 'revFindIndexElem' uses.
findIndexElem :: PrimType ty => ty -> Immutable -> Offset ty -> Offset ty -> Offset ty
findIndexElem ty ba startIndex endIndex = loop startIndex
  where
    loop !i
        -- Answer with endIndex itself (not the cursor) so that an inverted
        -- range cannot be mistaken for a hit at an out-of-range offset.
        | i >= endIndex        = endIndex
        | primIndex ba i == ty = i
        | otherwise            = loop (i+1)
-- | Scan backward through the offsets @[startIndex, endIndex)@ of the buffer
-- looking for the last element equal to @ty@.
--
-- The scan starts at @endIndex - 1@ and walks down to @startIndex@.
-- Returns the offset of the match closest to the end, or @endIndex@ when the
-- range is empty or nothing matches.
revFindIndexElem :: PrimType ty => ty -> Immutable -> Offset ty -> Offset ty -> Offset ty
revFindIndexElem ty ba startIndex endIndex
    | startIndex < endIndex = scan (endIndex `offsetMinusE` 1)
    | otherwise             = endIndex
  where
    scan !pos
        | primIndex ba pos == ty = pos
        -- Only step down while there is still an in-range offset below us.
        | pos > startIndex       = scan (pos `offsetMinusE` 1)
        | otherwise              = endIndex
-- | Scan forward through the offsets @[startIndex, endIndex)@ of the buffer
-- looking for the first element that satisfies @predicate@.
--
-- Returns the offset of the first satisfying element, or @endIndex@ when
-- there is none. An empty or inverted range (@startIndex >= endIndex@) reads
-- nothing and also yields @endIndex@, which is the same "not found" marker
-- that 'revFindIndexPredicate' uses.
findIndexPredicate :: PrimType ty => (ty -> Bool) -> Immutable -> Offset ty -> Offset ty -> Offset ty
findIndexPredicate predicate ba !startIndex !endIndex = loop startIndex
  where
    loop !i
        -- Answer with endIndex itself (not the cursor) so that an inverted
        -- range cannot be mistaken for a hit at an out-of-range offset.
        | i >= endIndex              = endIndex
        | predicate (primIndex ba i) = i
        | otherwise                  = loop (i+1)
-- | Scan backward through the offsets @[startIndex, endIndex)@ of the buffer
-- looking for the last element that satisfies @predicate@.
--
-- The scan starts at @endIndex - 1@ and walks down to @startIndex@.
-- Returns the offset of the satisfying element closest to the end, or
-- @endIndex@ when the range is empty or no element satisfies the predicate.
revFindIndexPredicate :: PrimType ty => (ty -> Bool) -> Immutable -> Offset ty -> Offset ty -> Offset ty
revFindIndexPredicate predicate ba startIndex endIndex
    | startIndex < endIndex = scan (endIndex `offsetMinusE` 1)
    | otherwise             = endIndex
  where
    scan !pos
        | predicate (primIndex ba pos) = pos
        -- Only step down while there is still an in-range offset below us.
        | pos > startIndex             = scan (pos `offsetMinusE` 1)
        | otherwise                    = endIndex
-- | Left fold over the elements at offsets @[startIndex, endIndex)@ of the
-- buffer, combining them into the accumulator from the lowest offset upward.
--
-- The accumulator is forced at every step. The walk stops only when the
-- cursor is exactly equal to @endIndex@, so @startIndex@ must not be greater
-- than @endIndex@.
foldl :: PrimType ty => (a -> ty -> a) -> a -> Immutable -> Offset ty -> Offset ty -> a
foldl f !initialAcc ba !startIndex !endIndex = go startIndex initialAcc
  where
    go !pos !acc
        | pos /= endIndex = go (pos+1) (f acc (primIndex ba pos))
        | otherwise       = acc
-- | Right fold over the elements at offsets @[startIndex, endIndex)@ of the
-- buffer: each element is combined with the fold of everything after it, and
-- @initialAcc@ stands in for the fold of the empty remainder.
--
-- The walk stops only when the cursor is exactly equal to @endIndex@, so
-- @startIndex@ must not be greater than @endIndex@.
foldr :: PrimType ty => (ty -> a -> a) -> a -> Immutable -> Offset ty -> Offset ty -> a
foldr f !initialAcc ba startIndex endIndex = go startIndex
  where
    go !pos
        | pos /= endIndex = f (primIndex ba pos) (go (pos+1))
        | otherwise       = initialAcc
-- | Left fold over the elements at offsets @[startIndex, endIndex)@ of the
-- buffer, using the element at @startIndex@ as the initial accumulator.
--
-- The element at @startIndex@ is read unconditionally and the walk then
-- begins at @startIndex + 1@, stopping only when the cursor is exactly equal
-- to @endIndex@. The range must therefore contain at least one element.
foldl1 :: PrimType ty => (ty -> ty -> ty) -> Immutable -> Offset ty -> Offset ty -> ty
foldl1 f ba startIndex endIndex = go (startIndex+1) (primIndex ba startIndex)
  where
    go !pos !acc
        | pos /= endIndex = go (pos+1) (f acc (primIndex ba pos))
        | otherwise       = acc
-- | Copy the elements of @src@ at offsets @[start, end)@ that satisfy
-- @predicate@ into @dst@, in their original order.
--
-- Kept elements are written contiguously into @dst@ beginning at the zero
-- offset ('azero'). The result is the number of elements written. The walk
-- stops only when the source cursor is exactly equal to @end@, so @start@
-- must not be greater than @end@.
filter :: (PrimMonad prim, PrimType ty)
       => (ty -> Bool) -> MutableByteArray# (PrimState prim) -> Immutable -> Offset ty -> Offset ty -> prim (CountOf ty)
filter predicate dst src start end = go azero start
  where
    -- 'd' is the next free slot in the destination, 's' the source cursor.
    go !d !s
        | s == end  = pure (offsetAsSize d)
        | otherwise =
            let v = primIndex src s
             in if predicate v
                    then primMbaWrite dst d v >> go (d+Offset 1) (s+Offset 1)
                    else go d (s+Offset 1)
-- | Check whether every element at offsets @[start, end)@ of the buffer
-- satisfies @predicate@.
--
-- Stops at the first element that fails the predicate. An empty range
-- (@start == end@) yields 'True'. The walk stops only when the cursor is
-- exactly equal to @end@, so @start@ must not be greater than @end@.
all :: PrimType ty => (ty -> Bool) -> Immutable -> Offset ty -> Offset ty -> Bool
all predicate ba start end = go start
  where
    go !pos
        | pos == end = True
        | otherwise  =
            case predicate (primIndex ba pos) of
                True  -> go (pos+1)
                False -> False
-- | Check whether at least one element at offsets @[start, end)@ of the
-- buffer satisfies @predicate@.
--
-- Stops at the first element that satisfies the predicate. An empty range
-- (@start == end@) yields 'False'. The walk stops only when the cursor is
-- exactly equal to @end@, so @start@ must not be greater than @end@.
any :: PrimType ty => (ty -> Bool) -> Immutable -> Offset ty -> Offset ty -> Bool
any predicate ba start end = go start
  where
    go !pos
        | pos == end = False
        | otherwise  =
            case predicate (primIndex ba pos) of
                True  -> True
                False -> go (pos+1)
-- | Sort, in place, the elements of the mutable buffer at offsets
-- @[start, end)@ according to the comparison function @ford@.
--
-- This is a recursive quicksort: each step picks the element in the middle
-- of the current range as pivot, partitions the range around it, then sorts
-- the two sides independently. Ranges of fewer than two elements are left
-- untouched.
inplaceSortBy :: (PrimType ty, PrimMonad prim)
              => (ty -> ty -> Ordering)
              -> Mutable (PrimState prim)
              -> Offset ty
              -> Offset ty
              -> prim ()
inplaceSortBy ford ma start end = sortRange start (end `offsetSub` 1)
  where
    -- Sort the inclusive range lo..hi.
    sortRange lo hi
        | lo < hi   = do
            p <- partition lo hi
            sortRange lo (pred p)
            sortRange (p+1) hi
        | otherwise = pure ()

    -- Pick the middle element as pivot and park it at the high end of the
    -- range (swapping it with whatever was there); returns the pivot value.
    choosePivot (Offset low) hi@(Offset high) = do
        let mid = Offset ((low + high) `div` 2)
        pivot  <- primRead ma mid
        atHigh <- primRead ma hi
        primWrite ma mid atHigh
        primWrite ma hi pivot
        pure pivot

    -- Rearrange lo..hi so that elements comparing 'LT' to the pivot come
    -- first, then the pivot, then the rest; returns the pivot's final offset.
    partition lo hi = do
        pivot <- choosePivot lo hi
        let belowPivot x = ford x pivot == LT

            -- Walk upward to the first element that is not below the pivot.
            -- The pivot parked at 'hi' is expected to stop the walk.
            scanUp k = do
                x <- primRead ma k
                if belowPivot x
                    then scanUp (k+1)
                    else pure (k, x)

            -- Walk downward to the first element that is below the pivot,
            -- giving up (without reading) upon reaching offset 'stop', whose
            -- already-known value is 'stopVal'.
            scanDown stop stopVal k
                | k == stop = pure (stop, stopVal)
                | otherwise = do
                    x <- primRead ma k
                    if belowPivot x
                        then pure (k, x)
                        else scanDown stop stopVal (pred k)

            go i0 j0 = do
                (i, ai) <- scanUp i0
                (j, aj) <- scanDown i ai j0
                if i < j
                    then do
                        -- Out-of-place pair: exchange and keep closing in.
                        primWrite ma i aj
                        primWrite ma j ai
                        go (i+1) (pred j)
                    else do
                        -- Cursors met: move the pivot from 'hi' to its final
                        -- position 'i', and the displaced element to 'hi'.
                        primWrite ma hi ai
                        primWrite ma i pivot
                        pure i
        go lo hi