-- Alfred-Margaret: Fast Aho-Corasick string searching
-- Copyright 2022 Channable
--
-- Licensed under the 3-clause BSD license, see the LICENSE file in the
-- repository root.
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Data.TypedByteArray
    ( Data.TypedByteArray.replicate
    , MutableTypedByteArray
    , Prim
    , TypedByteArray
    , fromList
    , generate
    , newTypedByteArray
    , unsafeFreezeTypedByteArray
    , unsafeIndex
    , writeTypedByteArray
    ) where

import Control.DeepSeq (NFData (rnf))
import Control.Monad.Primitive (PrimMonad (PrimState))
import Control.Monad.ST (runST)
import Data.Primitive (ByteArray (ByteArray), MutableByteArray, Prim, byteArrayFromList,
                       indexByteArray, newByteArray, sizeOf, unsafeFreezeByteArray, writeByteArray)

-- | Thin wrapper around 'ByteArray' that makes signatures and indexing nicer to read.
--
-- The type parameter @a@ is phantom: it does not occur in the wrapped
-- 'ByteArray' and only records which element type the array is meant to hold.
newtype TypedByteArray a = TypedByteArray ByteArray
    deriving Show

-- | Thin wrapper around 'MutableByteArray' that makes signatures and indexing nicer to read.
--
-- The first type parameter @a@ is phantom (it does not occur in the wrapped
-- value) and records the intended element type; the second parameter @s@ is
-- passed straight through to the underlying 'MutableByteArray'.
newtype MutableTypedByteArray a s = MutableTypedByteArray (MutableByteArray s)

-- | Evaluating a 'TypedByteArray' to normal form only requires matching both
-- wrapper constructors: the payload is a single primitive byte array, so there
-- are no nested fields left to traverse.
instance NFData (TypedByteArray a) where
    rnf (TypedByteArray (ByteArray !_)) = ()

{-# INLINE newTypedByteArray #-}
-- | Allocate a mutable array with room for the given number of elements of type @a@.
--
-- The argument is an element count; it is multiplied by the element size
-- reported by 'sizeOf' to obtain the byte count handed to 'newByteArray'.
-- This function does not write any element values itself.
newTypedByteArray :: forall a m. (Prim a, PrimMonad m) => Int -> m (MutableTypedByteArray a (PrimState m))
-- The explicit @forall@ (with ScopedTypeVariables) brings @a@ into scope so that
-- @undefined :: a@ selects the right 'Prim' instance. This relies on 'sizeOf'
-- using only the type of its argument and never evaluating it.
newTypedByteArray = fmap MutableTypedByteArray . newByteArray . (* sizeOf (undefined :: a))

{-# INLINE fromList #-}
-- | Build a 'TypedByteArray' from a list of elements by delegating to
-- 'byteArrayFromList' and wrapping the result.
fromList :: Prim a => [a] -> TypedByteArray a
fromList = TypedByteArray . byteArrayFromList

{-# INLINE unsafeIndex #-}
-- | Element index without bounds checking.
--
-- Unwraps the array and delegates directly to 'indexByteArray'; no check on
-- the index is performed here, so the caller must supply a valid one.
unsafeIndex :: Prim a => TypedByteArray a -> Int -> a
unsafeIndex (TypedByteArray arr) = indexByteArray arr

{-# INLINE generate #-}
-- | Construct a 'TypedByteArray' of the given length by applying the function to each index in @[0..n-1]@.
--
-- The array is filled in ascending index order inside 'runST' and then frozen
-- in place. If @n <= 0@ the fill loop does not run.
generate :: Prim a => Int -> (Int -> a) -> TypedByteArray a
generate !n f = runST $ do
    -- Allocate enough space for n elements of type a
    arr <- newTypedByteArray n
    -- Write @f i@ at every index; the index is forced before each write.
    intLoop 0 n $ \i -> i `seq` writeTypedByteArray arr i $ f i

    -- The mutable array is not used after this point, so freezing it without
    -- a copy does not expose later mutation.
    unsafeFreezeTypedByteArray arr

-- | Allocate a new mutable array of @n@ elements and set every index in
-- @[0..n-1]@ to the given value.
--
-- The result is returned still mutable, so the caller can keep writing to it
-- before freezing. If @n <= 0@ the fill loop does not run.
replicate :: (Prim a, PrimMonad m) => Int -> a -> m (MutableTypedByteArray a (PrimState m))
replicate n value = do
    arr <- newTypedByteArray n
    -- Fill in ascending index order; the index is forced before each write.
    intLoop 0 n $ \i -> i `seq` writeTypedByteArray arr i value
    pure arr

{-# INLINE writeTypedByteArray #-}
-- | Write an element at the given index of a mutable typed array.
--
-- Unwraps the array and delegates directly to 'writeByteArray'; no check on
-- the index is performed here.
writeTypedByteArray :: (Prim a, PrimMonad m) => MutableTypedByteArray a (PrimState m) -> Int -> a -> m ()
writeTypedByteArray (MutableTypedByteArray array) = writeByteArray array

{-# INLINE unsafeFreezeTypedByteArray #-}
-- | Turn a mutable typed array into an immutable one by delegating to
-- 'unsafeFreezeByteArray' and re-wrapping the result with the same element type.
--
-- NOTE(review): as the name of the underlying primitive suggests, the mutable
-- array should presumably not be written to after this call - confirm against
-- the @primitive@ documentation.
unsafeFreezeTypedByteArray :: PrimMonad m => MutableTypedByteArray a (PrimState m) -> m (TypedByteArray a)
unsafeFreezeTypedByteArray (MutableTypedByteArray array) = TypedByteArray <$> unsafeFreezeByteArray array

{-# INLINE intLoop #-}
-- | Run a monadic action for every index in the half-open range @[iStart, n)@,
-- in ascending order.
--
-- The action is not run at all when @iStart >= n@.
intLoop :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
intLoop !iStart !n p = go iStart
    where
        -- Strict counter: the current index is forced on every iteration.
        go !i
            | i >= n = pure ()
            | otherwise = do
                p i
                go (i + 1)