{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
module Builder
    ( Builder, builderLength, bytes, copyBuilderToPtr, create, run, runRelaxed
    , runToBlock, unsafeCreate
    ) where

import Basement.Types.OffsetSize

import Data.ByteArray (ByteArray)

import Control.Monad.ST
import Control.Monad.ST.Unsafe

import Data.Semigroup
import Data.Word

import Foreign.Ptr (Ptr, castPtr, plusPtr)

import Block (Block)
import Marking (Classified, Leak(..), SecurityMarking(..))
import SecureBytes (SecureBytes)
import qualified Block
import qualified ByteArrayST as ST
import qualified SecureBytes

-- | A deferred writer of a fixed number of bytes, tagged with a
-- 'SecurityMarking'.
--
-- The @marking@ index is not used by any field; it is a phantom type.
-- NOTE(review): presumably it records the classification of the produced
-- bytes (see "Marking") — confirm.
data Builder (marking :: SecurityMarking) = Builder
    { builderLength :: {-# UNPACK #-} !Int
      -- ^ Output size in bytes. 'run', 'runRelaxed' and 'runToBlock' allocate
      -- exactly this many bytes before calling 'copyBuilderToPtr', and '<>'
      -- uses it as the offset at which the right operand is written.
    , copyBuilderToPtr :: forall a s. Ptr a -> ST s ()
      -- ^ Write the builder's bytes starting at the given destination
      -- pointer. Polymorphic in the pointer's element type and the 'ST'
      -- state thread so it can be run from any 'ST' computation.
    }

-- | Concatenation. The result's length is the sum of both operands' lengths.
-- The left operand is written at the destination pointer, and the right
-- operand immediately after it, at an offset equal to the left operand's
-- length.
instance Semigroup (Builder marking) where
    left <> right = create (leftLen + builderLength right) writeBoth
      where
        leftLen = builderLength left

        -- Explicit signature keeps the writer polymorphic in the ST state
        -- thread, as 'create' requires.
        writeBoth :: Ptr Word8 -> ST s ()
        writeBoth dst = do
            copyBuilderToPtr left dst
            copyBuilderToPtr right (dst `plusPtr` leftLen)

-- | 'mempty' is the zero-length builder. 'mconcat' is overridden so that the
-- whole list is written into a single buffer sized to the combined length,
-- rather than going through the default right fold of '<>'.
instance Monoid (Builder marking) where
    mempty = empty
    mconcat builders = create (sumLengths 0 builders) (`writeAll` builders)
      where
        -- Total output size. The accumulator is forced at every step (bang
        -- pattern), so a long list does not build a chain of pending
        -- additions the way a lazy right fold over 'Sum' does.
        sumLengths :: Int -> [Builder m] -> Int
        sumLengths !acc [] = acc
        sumLengths !acc (b : bs) = sumLengths (acc + builderLength b) bs

        -- Copy each builder in list order. After each one, advance the
        -- destination pointer by the length of the builder just written.
        writeAll :: Ptr Word8 -> [Builder m] -> ST s ()
        writeAll !_ [] = return ()
        writeAll !p (b : bs) =
            copyBuilderToPtr b p >> writeAll (p `plusPtr` builderLength b) bs

-- | 'Leak' instance for 'Builder'. No methods are defined here, so it relies
-- entirely on the default implementations of the 'Leak' class (see
-- "Marking").
-- NOTE(review): presumably this allows relaxing or relabelling a builder's
-- security marking — confirm against the 'Leak' class definition.
instance Leak Builder

-- | A builder that copies the whole of a 'SecureBytes' value. Its length is
-- 'SecureBytes.length' of the input, and the copy is done by
-- 'SecureBytes.copyByteArrayToPtr' through 'unsafeCreate'.
bytes :: Classified marking => SecureBytes marking -> Builder marking
bytes src = unsafeCreate srcLen (SecureBytes.copyByteArrayToPtr src)
  where
    srcLen = SecureBytes.length src

-- | Wrap a writer as a builder that reports @n@ bytes. The writer receives
-- the destination pointer, cast to its own element type.
--
-- The writer is expected to write exactly @n@ bytes. 'run', 'runRelaxed' and
-- 'runToBlock' size their buffers from 'builderLength', and '<>' places the
-- next builder @n@ bytes after this one.
-- NOTE(review): this contract is not checked here.
create :: Int -> (forall s. Ptr a -> ST s ()) -> Builder marking
create n f =
    Builder { builderLength = n, copyBuilderToPtr = \dst -> f (castPtr dst) }
{-# INLINE create #-}

-- | The zero-length builder (used as 'mempty'). Its writer performs no writes.
empty :: Builder marking
empty = Builder { builderLength = 0, copyBuilderToPtr = const (pure ()) }

-- | Materialise a builder as a 'SecureBytes' of the same marking. The result
-- is created with 'SecureBytes.unsafeCreate', sized by 'builderLength' and
-- filled by the builder's writer.
run :: Classified marking => Builder marking -> SecureBytes marking
run builder = SecureBytes.unsafeCreate size fill
  where
    size :: Int
    size = builderLength builder

    -- Signature keeps the writer polymorphic in pointer type and ST thread.
    fill :: Ptr p -> ST s ()
    fill = copyBuilderToPtr builder

-- | Materialise a public ('Pub') builder into any 'ByteArray' type via
-- @ByteArrayST.unsafeCreate@. The buffer is sized by 'builderLength' and
-- filled by the builder's writer. Only 'Pub' builders are accepted, so the
-- result is an ordinary, unmarked byte array.
runRelaxed :: ByteArray ba => Builder Pub -> ba
runRelaxed builder = ST.unsafeCreate (builderLength builder) writeInto
  where
    writeInto :: Ptr p -> ST s ()
    writeInto = copyBuilderToPtr builder

-- | Materialise a public ('Pub') builder into a 'Block' of bytes.
--
-- A pinned mutable block of 'builderLength' elements is allocated, the
-- writer fills it through its raw pointer, and the block is then frozen
-- in place without copying ('Block.unsafeFreeze').
runToBlock :: Builder Pub -> Block Word8
runToBlock builder = runST action
  where
    action :: ST s (Block Word8)
    action = do
        buffer <- Block.newPinned (CountOf (builderLength builder))
        Block.withMutablePtr buffer (copyBuilderToPtr builder)
        Block.unsafeFreeze buffer

-- | Like 'create', but the writer is an 'IO' action. It is lifted into 'ST'
-- with 'unsafeIOToST'.
-- NOTE(review): the purity of 'run' and friends then depends on the caller;
-- the action should presumably only write the @n@ bytes at the given pointer.
unsafeCreate :: Int -> (Ptr a -> IO ()) -> Builder marking
unsafeCreate len write = create len (\dst -> unsafeIOToST (write dst))
{-# INLINE unsafeCreate #-}
