module Dahdit.LiftedPrim
  ( LiftedPrim (..)
  , LiftedPrimArray (..)
  , MutableLiftedPrimArray (..)
  , emptyLiftedPrimArray
  , indexLiftedPrimArray
  , writeLiftedPrimArray
  , freezeLiftedPrimArray
  , thawLiftedPrimArray
  , unsafeFreezeLiftedPrimArray
  , unsafeThawLiftedPrimArray
  , liftedPrimArrayFromListN
  , liftedPrimArrayFromList
  , generateLiftedPrimArray
  , sizeofLiftedPrimArray
  , cloneLiftedPrimArray
  )
where

import Control.Monad.Primitive (PrimMonad (..))
import Dahdit.Internal (ViaFromIntegral (..))
import Dahdit.Proxy (proxyForF)
import Data.Default (Default (..))
import Data.Foldable (for_)
import Data.Int (Int8)
import Data.Primitive.ByteArray
  ( ByteArray
  , MutableByteArray
  , cloneByteArray
  , emptyByteArray
  , freezeByteArray
  , indexByteArray
  , newByteArray
  , runByteArray
  , sizeofByteArray
  , thawByteArray
  , unsafeFreezeByteArray
  , unsafeThawByteArray
  , writeByteArray
  )
import Data.Proxy (Proxy (..))
import Data.STRef (modifySTRef', newSTRef, readSTRef)
import Data.Word (Word8)

-- | A stripped-down version of 'Prim' that a human can realistically implement.
-- It is only concerned with reading and writing values from and to byte arrays.
--
-- Methods suffixed @InBytes@ take a byte offset. Methods suffixed @InElems@
-- take an element index; by default they scale that index by
-- 'elemSizeLifted' and delegate to the matching @InBytes@ method.
class LiftedPrim a where
  {-# MINIMAL elemSizeLifted, indexByteArrayLiftedInBytes, writeByteArrayLiftedInBytes #-}

  -- | Number of bytes occupied by one stored element.
  elemSizeLifted :: Proxy a -> Int

  -- | Read a value starting at the given byte offset.
  indexByteArrayLiftedInBytes :: ByteArray -> Int -> a

  -- | Read the value at the given element index.
  indexByteArrayLiftedInElems :: ByteArray -> Int -> a
  indexByteArrayLiftedInElems bytes ix =
    indexByteArrayLiftedInBytes bytes (ix * stride)
    where
      !stride = elemSizeLifted (Proxy :: Proxy a)

  -- | Write a value starting at the given byte offset.
  writeByteArrayLiftedInBytes :: PrimMonad m => a -> MutableByteArray (PrimState m) -> Int -> m ()

  -- | Write a value at the given element index.
  writeByteArrayLiftedInElems :: PrimMonad m => a -> MutableByteArray (PrimState m) -> Int -> m ()
  writeByteArrayLiftedInElems value mbytes ix =
    writeByteArrayLiftedInBytes value mbytes (ix * stride)
    where
      !stride = elemSizeLifted (Proxy :: Proxy a)

-- | A 'Word8' occupies exactly one byte, so byte offsets and element indices
-- coincide and every method delegates directly to the 'Prim' primitives.
instance LiftedPrim Word8 where
  elemSizeLifted = const 1
  indexByteArrayLiftedInBytes bytes i = indexByteArray bytes i
  indexByteArrayLiftedInElems bytes i = indexByteArray bytes i
  writeByteArrayLiftedInBytes w mbytes i = writeByteArray mbytes i w
  writeByteArrayLiftedInElems w mbytes i = writeByteArray mbytes i w

-- | An 'Int8' occupies exactly one byte, so byte offsets and element indices
-- coincide and every method delegates directly to the 'Prim' primitives.
instance LiftedPrim Int8 where
  elemSizeLifted _ = 1
  indexByteArrayLiftedInBytes bytes i = indexByteArray bytes i
  indexByteArrayLiftedInElems bytes i = indexByteArray bytes i
  writeByteArrayLiftedInBytes v mbytes i = writeByteArray mbytes i v
  writeByteArrayLiftedInElems v mbytes i = writeByteArray mbytes i v

-- | Stores a @y@ using the in-memory representation of @x@.
--
-- NOTE: Relies on same byte width of both types! The element size is taken
-- from @x@, and values are converted with 'fromIntegral' on every read and
-- write. A @y@ value that does not fit in @x@ is therefore not round-tripped
-- faithfully.
instance (Integral x, LiftedPrim x, Integral y) => LiftedPrim (ViaFromIntegral x y) where
  elemSizeLifted _ = elemSizeLifted (Proxy :: Proxy x)

  indexByteArrayLiftedInBytes bytes off =
    let stored = indexByteArrayLiftedInBytes bytes off :: x
    in  ViaFromIntegral (fromIntegral stored)

  writeByteArrayLiftedInBytes val mbytes off = do
    -- Convert to the storage type strictly before delegating the write.
    let !stored = fromIntegral (unViaFromIntegral val) :: x
    writeByteArrayLiftedInBytes stored mbytes off

-- | An immutable array of 'LiftedPrim' elements stored as raw bytes.
--
-- The element type @a@ is phantom: the representation is just a 'ByteArray'.
-- 'Eq', 'Semigroup' and 'Monoid' are inherited from 'ByteArray', so they
-- operate on the underlying bytes.
newtype LiftedPrimArray a = LiftedPrimArray {unLiftedPrimArray :: ByteArray}
  deriving stock (Show)
  deriving newtype (Eq, Semigroup, Monoid)

-- | The default array is the empty one.
instance Default (LiftedPrimArray a) where
  def = LiftedPrimArray emptyByteArray

-- | A mutable array of 'LiftedPrim' elements backed by a 'MutableByteArray'.
--
-- @m@ is the state token (used in this module as @'PrimState' m@), and @a@ is
-- a phantom element type. 'Eq' is inherited from 'MutableByteArray'.
newtype MutableLiftedPrimArray m a = MutableLiftedPrimArray {unMutableLiftedPrimArray :: MutableByteArray m}
  deriving newtype (Eq)

-- | An array with no elements, backed by 'emptyByteArray'.
emptyLiftedPrimArray :: LiftedPrimArray a
emptyLiftedPrimArray = LiftedPrimArray {unLiftedPrimArray = emptyByteArray}

-- | Read the element at the given element index (not byte offset).
-- No bounds check is performed here.
indexLiftedPrimArray :: LiftedPrim a => LiftedPrimArray a -> Int -> a
indexLiftedPrimArray pa ix = indexByteArrayLiftedInElems (unLiftedPrimArray pa) ix

-- | Write a value at the given element index (not byte offset).
--
-- Note the argument order: value first, then array, then index.
-- No bounds check is performed here.
writeLiftedPrimArray :: (LiftedPrim a, PrimMonad m) => a -> MutableLiftedPrimArray (PrimState m) a -> Int -> m ()
writeLiftedPrimArray val marr ix = writeByteArrayLiftedInElems val (unMutableLiftedPrimArray marr) ix

-- | Copy a slice of a mutable array into a new immutable array.
--
-- NOTE(review): @off@ and @len@ are handed straight to 'freezeByteArray', so
-- they are byte counts. 'cloneLiftedPrimArray' takes element counts instead.
-- The two agree only for 1-byte element types.
freezeLiftedPrimArray :: PrimMonad m => MutableLiftedPrimArray (PrimState m) a -> Int -> Int -> m (LiftedPrimArray a)
freezeLiftedPrimArray (MutableLiftedPrimArray marr) off len =
  LiftedPrimArray <$> freezeByteArray marr off len

-- | Reinterpret a mutable array as immutable without copying, via
-- 'unsafeFreezeByteArray'. The mutable array must not be written afterwards.
unsafeFreezeLiftedPrimArray :: PrimMonad m => MutableLiftedPrimArray (PrimState m) a -> m (LiftedPrimArray a)
unsafeFreezeLiftedPrimArray =
  fmap LiftedPrimArray . unsafeFreezeByteArray . unMutableLiftedPrimArray

-- | Copy a slice of an immutable array into a new mutable array.
--
-- NOTE(review): @off@ and @len@ are handed straight to 'thawByteArray', so
-- they are byte counts. 'cloneLiftedPrimArray' takes element counts instead.
-- The two agree only for 1-byte element types.
thawLiftedPrimArray :: PrimMonad m => LiftedPrimArray a -> Int -> Int -> m (MutableLiftedPrimArray (PrimState m) a)
thawLiftedPrimArray (LiftedPrimArray bytes) off len =
  MutableLiftedPrimArray <$> thawByteArray bytes off len

-- | Reinterpret an immutable array as mutable without copying, via
-- 'unsafeThawByteArray'. The immutable array must not be used afterwards.
unsafeThawLiftedPrimArray :: PrimMonad m => LiftedPrimArray a -> m (MutableLiftedPrimArray (PrimState m) a)
unsafeThawLiftedPrimArray =
  fmap MutableLiftedPrimArray . unsafeThawByteArray . unLiftedPrimArray

-- | Build an array holding exactly @n@ elements taken from the front of a list.
--
-- * At most @n@ list elements are written. A longer list is truncated rather
--   than written past the end of the allocation.
-- * If the list has fewer than @n@ elements, the unfilled tail of the array is
--   zeroed rather than left as uninitialised memory.
-- * A negative @n@ is treated as @0@.
liftedPrimArrayFromListN :: LiftedPrim a => Int -> [a] -> LiftedPrimArray a
liftedPrimArrayFromListN n xs = LiftedPrimArray $ runByteArray $ do
  let !elemSize = elemSizeLifted (proxyForF xs)
      !count = max 0 n
      !len = count * elemSize
  arr <- newByteArray len
  -- Byte offset at which the next element will be written.
  offRef <- newSTRef 0
  -- 'take' bounds the writes: the writes are unchecked, so an over-long list
  -- would otherwise overrun the allocation.
  for_ (take count xs) $ \x -> do
    off <- readSTRef offRef
    writeByteArrayLiftedInBytes x arr off
    modifySTRef' offRef (elemSize +)
  -- Zero any bytes a too-short list did not cover. This keeps the contents,
  -- and hence the derived 'Eq', deterministic.
  written <- readSTRef offRef
  for_ [written .. len - 1] $ \i -> writeByteArray arr i (0 :: Word8)
  pure arr

-- | Build an array from every element of a list. The list is traversed once
-- to take its length, then again to write the elements.
liftedPrimArrayFromList :: LiftedPrim a => [a] -> LiftedPrimArray a
liftedPrimArrayFromList xs = liftedPrimArrayFromListN count xs
  where
    count = length xs

-- | Build an array of @n@ elements where element @i@ is @f i@, for
-- @i@ in @[0 .. n - 1]@.
generateLiftedPrimArray :: LiftedPrim a => Int -> (Int -> a) -> LiftedPrimArray a
generateLiftedPrimArray n f = liftedPrimArrayFromListN n [f i | i <- [0 .. n - 1]]

-- | Number of elements in the array. This is the byte size of the backing
-- 'ByteArray' divided by 'elemSizeLifted' with 'div' (any trailing partial
-- element is ignored).
sizeofLiftedPrimArray :: LiftedPrim a => LiftedPrimArray a -> Int
sizeofLiftedPrimArray pa@(LiftedPrimArray bytes) = totalBytes `div` width
  where
    !width = elemSizeLifted (proxyForF pa)
    !totalBytes = sizeofByteArray bytes

-- | Copy @len@ elements starting at element index @off@ into a new array.
--
-- Unlike 'freezeLiftedPrimArray' and 'thawLiftedPrimArray', @off@ and @len@
-- are element counts; they are scaled by 'elemSizeLifted' before calling
-- 'cloneByteArray'. No bounds check is performed here.
cloneLiftedPrimArray :: LiftedPrim a => LiftedPrimArray a -> Int -> Int -> LiftedPrimArray a
cloneLiftedPrimArray pa@(LiftedPrimArray bytes) off len = LiftedPrimArray copied
  where
    !width = elemSizeLifted (proxyForF pa)
    !startByte = off * width
    !byteCount = len * width
    !copied = cloneByteArray bytes startByte byteCount