-- Alfred-Margaret: Fast Aho-Corasick string searching
-- Copyright 2022 Channable
--
-- Licensed under the 3-clause BSD license, see the LICENSE file in the
-- repository root.
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Data.TypedByteArray
    ( Data.TypedByteArray.replicate
    , MutableTypedByteArray
    , Prim
    , TypedByteArray
    , fromList
    , toList
    , generate
    , newTypedByteArray
    , unsafeFreezeTypedByteArray
    , unsafeIndex
    , writeTypedByteArray
    , null
    , length
    , foldr
    ) where

import Prelude hiding (foldr, length, null)

import Control.DeepSeq (NFData (rnf))
import Control.Monad.Primitive (PrimMonad (PrimState))
import Control.Monad.ST (runST)
import Data.Primitive (ByteArray (ByteArray), MutableByteArray, Prim, byteArrayFromList,
                       indexByteArray, newByteArray, sizeOf, unsafeFreezeByteArray, writeByteArray)

import qualified Data.Primitive as Primitive


-- | A 'ByteArray' whose bytes are read as a packed sequence of elements of
-- type @a@.
--
-- The type parameter is phantom (it does not occur in the representation); it
-- only selects which 'Prim' instance the accessors in this module use, so that
-- signatures say "array of @a@" instead of "some bytes".
--
-- 'Eq' and 'Show' are derived from the wrapped 'ByteArray', so they operate on
-- the raw bytes rather than on decoded elements.
newtype TypedByteArray a
    = TypedByteArray ByteArray
    deriving (Eq, Show)

-- | Mutable counterpart of 'TypedByteArray': a 'MutableByteArray' living in
-- state thread @s@, tagged with the phantom element type @a@.
newtype MutableTypedByteArray a s
    = MutableTypedByteArray (MutableByteArray s)

-- | The 'ByteArray' box wraps an unlifted 'ByteArray#', which cannot contain
-- thunks, so forcing the box to weak head normal form is already full
-- normal form.
instance NFData (TypedByteArray a) where
    rnf (TypedByteArray bytes) = bytes `seq` ()

-- | Allocate a mutable array with room for @n@ elements of type @a@, i.e.
-- @n * 'sizeOf' (undefined :: a)@ bytes. 'undefined' is passed to 'sizeOf'
-- only to select the element type. This function does not write to the new
-- array.
{-# INLINE newTypedByteArray #-}
newTypedByteArray :: forall a m. (Prim a, PrimMonad m) => Int -> m (MutableTypedByteArray a (PrimState m))
newTypedByteArray n = MutableTypedByteArray <$> newByteArray byteCount
    where
        byteCount = n * sizeOf (undefined :: a)

-- | Pack the elements of a list into a new 'TypedByteArray'.
{-# INLINE fromList #-}
fromList :: Prim a => [a] -> TypedByteArray a
fromList xs = TypedByteArray (byteArrayFromList xs)

-- | Unpack the elements into a list by right-folding them with @(:)@.
{-# INLINE toList #-}
toList :: Prim a => TypedByteArray a -> [a]
toList (TypedByteArray bytes) = Primitive.foldrByteArray (:) [] bytes

-- | Read the element at the given index.
--
-- No bounds check is performed here: the caller must ensure the index lies in
-- @[0, 'length' arr)@.
{-# INLINE unsafeIndex #-}
unsafeIndex :: Prim a => TypedByteArray a -> Int -> a
unsafeIndex (TypedByteArray bytes) = indexByteArray bytes

-- | Build a 'TypedByteArray' of length @n@ whose element at index @i@ is
-- @f i@, for every @i@ in @[0..n-1]@.
{-# INLINE generate #-}
generate :: Prim a => Int -> (Int -> a) -> TypedByteArray a
generate !n f = runST $ do
    -- Room for exactly n elements; each slot is written once before freezing.
    target <- newTypedByteArray n
    let fill !i = writeTypedByteArray target i (f i)
    intLoop 0 n fill
    unsafeFreezeTypedByteArray target

-- | Allocate a mutable array of @n@ elements and set every element to
-- @value@.
--
-- Unlike 'generate', the result stays mutable, so the caller can update it
-- further (e.g. with 'writeTypedByteArray') before freezing it with
-- 'unsafeFreezeTypedByteArray'.
--
-- Marked INLINE like the other helpers in this module, so that callers in
-- other modules get a version specialised to their monad and element type.
{-# INLINE replicate #-}
replicate :: (Prim a, PrimMonad m) => Int -> a -> m (MutableTypedByteArray a (PrimState m))
replicate n value = do
    arr <- newTypedByteArray n
    -- intLoop's counter is bang-patterned, so the index is already evaluated.
    intLoop 0 n $ \i -> writeTypedByteArray arr i value
    pure arr

-- | Store an element at the given index of a mutable array.
--
-- This wrapper performs no bounds check of its own.
{-# INLINE writeTypedByteArray #-}
writeTypedByteArray :: (Prim a, PrimMonad m) => MutableTypedByteArray a (PrimState m) -> Int -> a -> m ()
writeTypedByteArray (MutableTypedByteArray target) = writeByteArray target

-- | Reinterpret a mutable array as an immutable 'TypedByteArray' via
-- 'unsafeFreezeByteArray'.
--
-- The mutable array must not be written to afterwards.
{-# INLINE unsafeFreezeTypedByteArray #-}
unsafeFreezeTypedByteArray :: PrimMonad m => MutableTypedByteArray a (PrimState m) -> m (TypedByteArray a)
unsafeFreezeTypedByteArray (MutableTypedByteArray source) =
    fmap TypedByteArray (unsafeFreezeByteArray source)

-- | Run @body i@ for every @i@ from @start@ up to, but excluding, @end@, in
-- increasing order.
--
-- Does nothing when @start >= end@. The counter is strict, so @body@ always
-- receives an evaluated 'Int'.
{-# INLINE intLoop #-}
intLoop :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
intLoop !start !end body = step start
    where
        step !i
            | i < end   = body i >> step (i + 1)
            | otherwise = pure ()

-- | 'True' when the underlying buffer has no bytes.
--
-- This equals "has zero elements" under the assumption that no element type
-- used with this module has size 0.
{-# INLINE null #-}
null :: TypedByteArray a -> Bool
null (TypedByteArray bytes) = byteCount == 0
    where
        byteCount = Primitive.sizeofByteArray bytes

-- | Number of elements: the buffer's byte size divided (truncating) by the
-- size of one element.
--
-- This is the same element count 'Primitive.foldrByteArray' uses, so
-- 'length' agrees with the number of elements seen by 'foldr' and 'toList'.
{-# INLINE length #-}
length :: forall a. Prim a => TypedByteArray a -> Int
length (TypedByteArray bytes) = byteCount `quot` elemSize
    where
        byteCount = Primitive.sizeofByteArray bytes
        -- 'undefined' only selects the element type for 'sizeOf'.
        elemSize = sizeOf (undefined :: a)

-- | Right fold over the elements of the array, delegating to
-- 'Primitive.foldrByteArray' on the underlying buffer.
{-# INLINE foldr #-}
foldr :: Prim a => (a -> b -> b) -> b -> TypedByteArray a -> b
foldr step initial (TypedByteArray bytes) =
    Primitive.foldrByteArray step initial bytes