----------------------------------------------------------------------------
-- |
-- Module      :  Foreign.Ptr.Builder
-- Copyright   :  (c) Sergey Vinokurov 2022
-- License     :  Apache-2.0 (see LICENSE)
-- Maintainer  :  serg.foo@gmail.com
----------------------------------------------------------------------------

{-# LANGUAGE MagicHash        #-}
{-# LANGUAGE UnboxedTuples    #-}
{-# LANGUAGE UnliftedNewtypes #-}

module Foreign.Ptr.Builder
  ( Builder
  , withByteArrayLen
  , withPtrLen
  , storable
  , prim
  , Int#
  , ByteArray#

  , BuilderCache
  , coerceBuilderCache
  , withBuilderCache
  ) where

import Data.Primitive.Types as Prim
import Emacs.Module.Assert
import Foreign
import Foreign.Storable as Storable
import GHC.Exts
import GHC.IO

-- | An action that writes a builder's elements into memory. It receives the
-- base address of the target buffer and an offset counted in elements (not
-- bytes) at which its first element should be placed.
type Writer = Addr# -> Int# -> IO ()

-- | A sequence of elements of type @a@ destined for a contiguous buffer:
-- the number of elements it produces, paired with the 'Writer' that emits
-- them. The type parameter only appears in signatures that tie it to the
-- element type (e.g. 'storable', 'withByteArrayLen').
data Builder a = Builder Int# Writer

-- | Debug rendering that reports only the element count, e.g. @Builder 3@;
-- the writer itself is opaque and is not shown.
instance Show (Builder a) where
  showsPrec d (Builder k _) =
    -- Parenthesise only above function-application precedence (10), as
    -- derived 'Show' instances do. The previous @d >= 10@ test added
    -- redundant parentheses when shown at exactly precedence 10.
    showParen (d > appPrec) $
      showString "Builder " . showsPrec (appPrec + 1) (I# k)
    where
      appPrec :: Int
      appPrec = 10

-- | Concatenation places the right builder's elements immediately after
-- the left builder's. Element counts add up, and the right writer starts at
-- the incoming offset shifted by the left builder's element count.
instance Semigroup (Builder a) where
  {-# INLINE (<>) #-}
  (<>) (Builder lenL writeL) (Builder lenR writeR) =
    Builder (lenL +# lenR) writeBoth
    where
      writeBoth :: Writer
      writeBoth base start = do
        writeL base start
        writeR base (start +# lenL)

-- | The empty builder: zero elements and a writer that performs no writes.
instance Monoid (Builder a) where
  {-# INLINE mempty #-}
  mempty = Builder 0# skip
    where
      skip :: Writer
      skip _ _ = return ()

-- | Check whether an integer is an exact (positive) power of two.
--
-- A positive power of two has exactly one bit set, so clearing its lowest
-- set bit (@x .&. (x - 1)@) yields zero. Non-positive inputs are rejected
-- explicitly, because the bit test alone wrongly accepts them:
--
-- * @0@, since @0 .&. (-1) == 0@;
-- * 'minBound', since @minBound - 1@ wraps around to 'maxBound'.
isPowerOfTwo :: Int# -> Bool
isPowerOfTwo x =
  isTrue# (x ># 0#) && isTrue# (and# x' y' `eqWord#` 0##)
  where
    x' :: Word#
    x' = int2Word# x
    y' :: Word#
    y' = int2Word# (x -# 1#)

{-# INLINE withByteArrayLen #-}
-- | Materialise a 'Builder' into a 'ByteArray#' and pass it, together with
-- the builder's element count, to @action@.
--
-- The required byte size is the element count times
-- @sizeOf (undefined :: a)@:
--
-- * If the cache's current size is at least that, the builder writes
--   straight into the cache (from element offset 0). The cache is then
--   frozen in place and handed over.
-- * Otherwise a fresh pinned array, aligned to @alignment (undefined :: a)@,
--   is allocated for this call only. The cache is never resized.
--
-- NOTE(review): in the cache path the array given to @action@ aliases the
-- cache. It presumably must not escape @action@, nor overlap with another
-- use of the same cache — confirm with callers.
withByteArrayLen
  :: forall a b. (WithCallStack, Storable a)
  => BuilderCache a
  -> Builder a
  -> (Int# -> ByteArray# -> IO b)
  -> IO b
withByteArrayLen (BuilderCache cache#) (Builder count writer) action =
  emacsAssert (isPowerOfTwo align#) "Alignment should be a power of two" $
  IO $ \st0 ->
    case getSizeofMutableByteArray# cache# st0 of
      (# st1, cacheBytes #) ->
        case materialise cacheBytes st1 of
          (# st2, arr# #) ->
            -- keepAlive# arr# st2 (unIO (action count arr#))
            -- touch# is measurably faster, but unsound if the action diverges.
            case unIO (action count arr#) st2 of
              (# st3, result #) ->
                case touch# arr# st3 of
                  st4 -> (# st4, result #)
  where
    -- Reuse the cache when it is large enough; otherwise allocate a fresh
    -- pinned array with the element type's alignment.
    materialise
      :: Int# -> State# RealWorld -> (# State# RealWorld, ByteArray# #)
    materialise available st
      | isTrue# (available >=# byteCount) = fillAndFreeze cache# st
      | otherwise =
          case newAlignedPinnedByteArray# byteCount align# st of
            (# st', fresh# #) -> fillAndFreeze fresh# st'

    -- Run the builder's writer over the array starting at element offset 0,
    -- then freeze the array in place (no copy).
    fillAndFreeze
      :: MutableByteArray# RealWorld
      -> State# RealWorld
      -> (# State# RealWorld, ByteArray# #)
    fillAndFreeze marr# st =
      case unIO (writer (mutableByteArrayContents# marr#) 0#) st of
        (# st', () #) -> unsafeFreezeByteArray# marr# st'

    !byteCount     = count *# elemSize
    !(I# elemSize) = Storable.sizeOf    (undefined :: a)
    !(I# align#)   = Storable.alignment (undefined :: a)

{-# INLINE withPtrLen #-}
-- | Like 'withByteArrayLen', but hands @action@ a boxed element count and a
-- typed 'Ptr' to the first byte of the materialised array.
--
-- The pointer should only be used inside @action@; 'withByteArrayLen'
-- touches the array only after @action@ returns.
withPtrLen
  :: forall a b. (WithCallStack, Storable a)
  => BuilderCache a -> Builder a -> (Int -> Ptr a -> IO b) -> IO b
withPtrLen cache builder action = withByteArrayLen cache builder viaPtr
  where
    viaPtr :: Int# -> ByteArray# -> IO b
    viaPtr n# arr# = action (I# n#) (Ptr (byteArrayContents# arr#))


{-# INLINE storable #-}
-- | A one-element builder that stores @x@ via its 'Storable' instance.
-- It uses 'pokeElemOff', so the incoming offset is interpreted in elements.
storable :: Storable a => a -> Builder a
storable x = Builder 1# pokeIt
  where
    pokeIt :: Writer
    pokeIt base idx = pokeElemOff (Ptr base) (I# idx) x

{-# INLINE prim #-}
-- | A one-element builder that stores @x@ via its 'Prim' instance, using
-- 'Prim.writeOffAddr#' at the incoming element offset.
prim :: Prim a => a -> Builder a
prim x = Builder 1# store
  where
    store :: Writer
    store base idx = IO $ \st ->
      case Prim.writeOffAddr# base idx x st of
        st' -> (# st', () #)

-- | A reusable scratch buffer for 'withByteArrayLen' and 'withPtrLen',
-- wrapping a mutable byte array. The type parameter is phantom: it does not
-- occur in the representation and only records the intended element type.
newtype BuilderCache a = BuilderCache
  { _unBuilderCache :: MutableByteArray# RealWorld
  }

-- | Reinterpret a cache as one for a different element type. The underlying
-- byte array is shared, not copied.
--
-- NOTE(review): the buffer keeps the alignment it was allocated with, which
-- may be weaker than what @b@ requires — callers presumably account for this.
coerceBuilderCache :: BuilderCache a -> BuilderCache b
coerceBuilderCache (BuilderCache bytes) = BuilderCache bytes

-- | Allocate a pinned buffer large enough for the given number of elements
-- of type @a@, aligned to @alignment (undefined :: a)@. Pass it to @body@ as
-- a 'BuilderCache'; 'keepAlive#' keeps the buffer alive for the whole of
-- @body@.
withBuilderCache :: forall a b. Storable a => Int -> (BuilderCache a -> IO b) -> IO b
withBuilderCache (I# count) body = IO run
  where
    run :: State# RealWorld -> (# State# RealWorld, b #)
    run st0 =
      case newAlignedPinnedByteArray# byteCount align# st0 of
        (# st1, marr #) ->
          keepAlive# marr st1 (unIO (body (BuilderCache marr)))

    !byteCount     = count *# elemSize
    !(I# elemSize) = Storable.sizeOf    (undefined :: a)
    !(I# align#)   = Storable.alignment (undefined :: a)