-- |
-- Module      : Data.CompactString.Internal
-- License     : BSD-style
-- Maintainer  : twanvl@gmail.com
-- Stability   : experimental
-- Portability : untested
-- 
-- Internal functions for the CompactString type.
--
module Data.CompactString.Internal (
        CompactString(..),
        Proxy, encoding, Encoding(..),
        PairS(..), MaybeS(..), unSP,
        AccEFL, FoldEFL, ImperativeLoop, ImperativeLoop_,
        ByteString(..), memcpy, inlinePerformIO,
        withBuffer, withBufferEnd, unsafeWithBuffer, unsafeWithBufferEnd, create,
        ord, unsafeChr, returnChr,
        plusPtr, peekByteOff, pokeByteOff, peek, poke,
        failMessage, moduleError, errorEmptyList, unsafeTry, unsafeTryIO
        ) where

import Foreign.Ptr              (Ptr)
import qualified Foreign.Ptr    (plusPtr)
import Foreign.Storable         (Storable, peek, poke)
import qualified Foreign.Storable
import Foreign.ForeignPtr       (withForeignPtr)

import Data.Word                (Word8, Word32)
import Data.Char                (ord)

import Control.Monad
import Control.Exception

#if defined(__GLASGOW_HASKELL__)
import GHC.Base                 (unsafeChr)
#else
import Data.Char                (chr)
#endif

import System.IO.Unsafe

import Data.ByteString.Internal (ByteString(..), memcpy, inlinePerformIO)
import qualified Data.ByteString.Internal as B

-- -----------------------------------------------------------------------------
--
-- Useful macros, until we have bang patterns
--

-- Each STRICTn macro expands to an extra first equation for the function f.
-- Its guard evaluates the first n arguments to weak head normal form with
-- seq and then yields False, so the undefined right-hand side is never
-- reached and evaluation always falls through to the real equations that
-- follow.  Write the macro on the line just before the equations of f.
#define STRICT1(f) f _a             | _a                                     `seq` False = undefined
#define STRICT2(f) f _a _b          | _a `seq` _b                            `seq` False = undefined
#define STRICT3(f) f _a _b _c       | _a `seq` _b `seq` _c                   `seq` False = undefined
#define STRICT4(f) f _a _b _c _d    | _a `seq` _b `seq` _c `seq` _d          `seq` False = undefined
#define STRICT5(f) f _a _b _c _d _e | _a `seq` _b `seq` _c `seq` _d `seq` _e `seq` False = undefined

-- -----------------------------------------------------------------------------
--
-- Utilities
--

-- | Strict pair: both components are evaluated to WHNF when the pair is
--   constructed.
--
--   The fields are deliberately not marked UNPACK.  A field whose type is a
--   bare type variable cannot be unpacked, so such a pragma is ignored and
--   only produces a compiler warning.
data PairS a b = !a :*: !b
infixl 2 :*:

-- | Strict 'Maybe': the payload of 'JustS' is evaluated to WHNF when the
--   constructor is built.
data MaybeS a = NothingS | JustS !a

-- | Convert a strict pair into an ordinary (lazy) tuple.
unSP :: PairS a b -> (a,b)
unSP (a :*: b) = (a,b)

-- -----------------------------------------------------------------------------
--
-- Type
--

-- | A String using a compact, strict representation.
--   A @CompactString a@ is encoded using encoding @a@, for example @CompactString 'UTF8'@.
--
--   Invariants relied upon by the rest of the library:
--
--   * every character stored in the underlying bytes is complete and valid
--     in encoding @a@;
--
--   * every character uses the shortest possible encoding.
newtype CompactString a = CS
    { unCS :: ByteString  -- ^ the raw encoded bytes
    }

-- -----------------------------------------------------------------------------
--
-- Encoding
--

-- | A type-level tag that selects an 'Encoding' instance (modelled on the
--   Data.Proxy proposal).  It has no constructors, so it carries only type
--   information.
data Proxy e

-- | A way to encode characters into bytes.
--   The @'Proxy' a@ argument of every method only selects the instance;
--   'Proxy' has no constructors, so implementations should not force it.
class Encoding a where
        -- | Given a character, return the length of its encoding
        --   and a function that writes it to a memory buffer.
        --   If the encoding cannot represent the character, the IO function should @fail@.
        pokeCharFun :: Proxy a -> Char -> (Int, Ptr Word8 -> IO ())
        -- | The size needed to store a character.
        --   Default: the length component of 'pokeCharFun'.
        pokeCharLen :: Proxy a -> Char -> Int
        pokeCharLen a = fst . pokeCharFun a
        -- | Write a character starting at the given pointer and return the size used.
        pokeChar :: Proxy a -> Ptr Word8 -> Char -> IO Int
        pokeChar enc p c = case pokeCharFun enc c of (l,f) -> f p >> return l
        {-# INLINE pokeChar #-}
        -- | Write a character given a pointer to its last byte, and return the size used.
        --   The default starts writing @l - 1@ bytes before the pointer, where @l@
        --   is the encoded length reported by 'pokeCharFun'.
        pokeCharRev :: Proxy a -> Ptr Word8 -> Char -> IO Int
        pokeCharRev enc p c = case pokeCharFun enc c of (l,f) -> f (p `plusPtr` (1-l)) >> return l
        {-# INLINE pokeCharRev #-}
        
        -- | Read a character from a memory buffer and return its length and the character.
        --   The buffer is guaranteed to contain a valid character.
        peekChar :: Proxy a -> Ptr Word8 -> IO (Int, Char)
        -- | Return the length of the character in a memory buffer.
        peekCharLen :: Proxy a -> Ptr Word8 -> IO Int
        -- | Read a character from a memory buffer and return its length and the character,
        --   given a pointer to the /last/ byte.
        --   The buffer is guaranteed to contain a valid character.
        peekCharRev :: Proxy a -> Ptr Word8 -> IO (Int, Char)
        -- | Return the length of the character in a memory buffer,
        --   given a pointer to the /last/ byte.
        peekCharLenRev :: Proxy a -> Ptr Word8 -> IO Int
        
        -- | Read a character from a memory buffer and return its length and the character.
        --   The buffer is not guaranteed to contain a valid character, so this must be
        --   verified. The buffer length (also given) is not guaranteed to be large
        --   enough to hold a whole character either.
        peekCharSafe :: Proxy a -> Int -> Ptr Word8 -> IO (Int, Char)
        -- | Validate a buffer length; should be called before 'peekCharSafe'.
        --   Instances can use it to reduce the number of checks 'peekCharSafe' performs.
        --   The default accepts every length.
        validateLength :: Proxy a -> Int -> IO ()
        validateLength _ _ = return ()
        
        -- | Copy one character from the source buffer to the destination buffer and
        --   return its length. The default decodes with 'peekChar' and re-encodes
        --   with 'pokeChar'.
        copyChar :: Proxy a -> Ptr Word8 -> Ptr Word8 -> IO Int
        copyChar enc src dst = do
                (l,c) <- peekChar enc src
                pokeChar enc dst c
                return l
        -- | Copy one character from the source buffer to the destination buffer,
        --   where the source pointer points to the last byte of the character.
        --   The character is still written forwards, starting at the destination pointer.
        --   Returns the length of the character.
        copyCharRev :: Proxy a -> Ptr Word8 -> Ptr Word8 -> IO Int
        copyCharRev enc src dst = do
                (l,c) <- peekCharRev enc src
                pokeChar enc dst c
                return l
        
        -- | Is ASCII a valid subset of the encoding?
        containsASCII :: Proxy a -> Bool
        -- | Is @(a == b) == (toBS a == toBS b)@? Default: 'True'.
        validEquality  :: Proxy a -> Bool
        validEquality _ = True
        -- | Is @(a `compare` b) == (toBS a `compare` toBS b)@?
        validOrdering  :: Proxy a -> Bool
        -- | Is @(a `isSubstringOf` b) == (toBS a `isSubstringOf` toBS b)@?
        validSubstring :: Proxy a -> Bool
        
        -- | What is the maximum number of characters a string with the given number of bytes contains?
        --   The default returns the byte count itself, i.e. at most one character per byte.
        charCount :: Proxy a -> Int -> Int
        charCount _ n = n
        -- | What is the maximum number of bytes a string with the given number of characters occupies?
        byteCount :: Proxy a -> Int -> Int
        -- | What is the maximum size in bytes after transforming (using map) a string
        --   of the given size in bytes? Default: 'byteCount' applied to 'charCount'.
        newSize :: Proxy a -> Int -> Int
        newSize e = byteCount e . charCount e
        
        -----------------------------------------------------------------------------
        --
        -- Fusion
        --
        -- Loop skeletons over raw buffers. They are class methods with defaults,
        -- so an instance may override them with specialised versions.
        
        -- | Map/filter the characters of a source buffer front to back.
        --   Each character is decoded with 'peekChar' and passed, together with the
        --   accumulator, to the step function. Characters it emits are written
        --   consecutively from the start of the destination with 'pokeChar'.
        --   Returns the final accumulator, offset 0, and the number of bytes written.
        doUpLoop :: Proxy a -> AccEFL acc -> acc -> ImperativeLoop acc
        doUpLoop enc f acc0 src dest len = loop 0 0 acc0
          -- The STRICT3 macro line forces the source offset, destination offset
          -- and accumulator to WHNF on every iteration (emulating bang patterns).
          where STRICT3(loop)
                loop src_off dest_off acc
                    | src_off >= len = return (acc :*: 0 :*: dest_off)
                    | otherwise      = do
                        (l,x) <- peekChar enc (src `plusPtr` src_off)
                        case f acc x of
                          (acc' :*: NothingS) ->    loop (src_off+l)  dest_off     acc'
                          (acc' :*: JustS x') -> do l' <- pokeChar enc (dest `plusPtr` dest_off) x'
                                                    loop (src_off+l) (dest_off+l') acc'
        
        -- | Map/filter the characters of a source buffer back to front.
        --   Characters are decoded with 'peekCharRev', starting at the last source byte.
        --   Characters the step function emits are written backwards with 'pokeCharRev',
        --   starting at destination byte @newSize enc len - 1@. The destination
        --   therefore presumably must hold at least @newSize enc len@ bytes.
        --   Returns the final accumulator, the offset of the first written destination
        --   byte, and the number of bytes written.
        doDownLoop :: Proxy a -> AccEFL acc -> acc -> ImperativeLoop acc
        doDownLoop enc f acc0 src dest len = loop (len-1) (newSize enc len-1) acc0
          -- Both offsets move downwards; the loop stops once the source offset
          -- drops below 0. The STRICT3 macro line keeps the arguments evaluated.
          where STRICT3(loop)
                loop src_off dest_off acc
                    | src_off < 0    = return (acc :*: dest_off + 1 :*: newSize enc len - (dest_off+1))
                    | otherwise      = do
                        (l,x) <- peekCharRev enc (src `plusPtr` src_off)
                        case f acc x of
                          (acc' :*: NothingS) ->    loop (src_off-l)  dest_off     acc'
                          (acc' :*: JustS x') -> do l' <- pokeCharRev enc (dest `plusPtr` dest_off) x'
                                                    loop (src_off-l) (dest_off-l') acc'
        
        -- | Strict fold over the characters of a source buffer, front to back,
        --   decoding each character with 'peekChar'.
        doUpLoopFold :: Proxy a -> FoldEFL acc -> acc -> ImperativeLoop_ acc
        doUpLoopFold enc f acc0 src len = loop 0 acc0
          -- The STRICT2 macro line forces the offset and the accumulator each step.
          where STRICT2(loop)
                loop src_off acc
                    | src_off >= len = return acc
                    | otherwise      = do
                        (l,x) <- peekChar enc (src `plusPtr` src_off)
                        loop (src_off + l) (f acc x)
        
        -- | Strict fold over the characters of a source buffer, back to front,
        --   starting at the last byte and decoding each character with 'peekCharRev'.
        doDownLoopFold :: Proxy a -> FoldEFL acc -> acc -> ImperativeLoop_ acc
        doDownLoopFold enc f acc0 src len = loop (len-1) acc0
          -- The STRICT2 macro line forces the offset and the accumulator each step.
          where STRICT2(loop)
                loop src_off acc
                    | src_off < 0 = return acc
                    | otherwise   = do
                        (l,x) <- peekCharRev enc (src `plusPtr` src_off)
                        loop (src_off - l) (f acc x)

-- -----------------------------------------------------------------------------
--
-- Fusion types
--

-- | Step function of an accumulating map/filter loop.
--   Given the current accumulator and a character, it returns the new
--   accumulator and either 'NothingS' (emit nothing) or @'JustS' c@ (emit @c@).
type AccEFL acc = acc -> Char -> PairS acc (MaybeS Char)

-- | Step function of a fold: combine the accumulator with one character.
type FoldEFL acc = acc -> Char -> acc

-- | An imperative loop that transforms a string while threading an
--   accumulator (modelled on Data.ByteString.Fusion).
--
--   The result has the shape @acc :*: off :*: n@. Here @off@ is the offset
--   in the destination where the output starts, and @n@ is the number of
--   destination bytes that were filled.
type ImperativeLoop acc
    =  Ptr Word8    -- source: pointer to its first byte
    -> Ptr Word8    -- destination: pointer to its first byte
    -> Int          -- source length in bytes
    -> IO (PairS (PairS acc Int) Int)

-- | Like 'ImperativeLoop', but with no destination buffer: the loop only
--   produces its final accumulator.
type ImperativeLoop_ acc
    =  Ptr Word8    -- source: pointer to its first byte
    -> Int          -- source length in bytes
    -> IO acc

-- -----------------------------------------------------------------------------
--
-- Utilities : buffer stuff
--

-- | Run an IO action on a pointer to the first byte of the string contents.
--   The ByteString offset is already applied. The underlying ForeignPtr is
--   accessed through 'withForeignPtr' for the duration of the action.
withBuffer :: CompactString a -> (Ptr Word8 -> IO b) -> IO b
withBuffer (CS (PS fptr off _)) act =
    withForeignPtr fptr (\base -> act (plusPtr base off))
{-# INLINE withBuffer #-}

-- | Run an IO action on a pointer to the /last/ byte of the string contents.
--   For an empty string the pointer lies one byte before the contents.
--   Presumably callers check for emptiness before dereferencing it.
withBufferEnd :: CompactString a -> (Ptr Word8 -> IO b) -> IO b
withBufferEnd (CS (PS fptr off len)) act =
    withForeignPtr fptr (\base -> act (plusPtr base lastByte))
  where lastByte = off + len - 1
{-# INLINE withBufferEnd #-}

-- | Pure variant of 'withBuffer' that runs the action with 'inlinePerformIO'.
--   The caller must ensure the action is effectively pure, for example that
--   it only reads from the buffer.
unsafeWithBuffer :: CompactString a -> (Ptr Word8 -> IO b) -> b
unsafeWithBuffer str act = inlinePerformIO (withBuffer str act)
{-# INLINE unsafeWithBuffer #-}

-- | Pure variant of 'withBufferEnd' that runs the action with 'inlinePerformIO'.
--   The caller must ensure the action is effectively pure, for example that
--   it only reads from the buffer.
unsafeWithBufferEnd :: CompactString a -> (Ptr Word8 -> IO b) -> b
unsafeWithBufferEnd str act = inlinePerformIO (withBufferEnd str act)
{-# INLINE unsafeWithBufferEnd #-}

-- | Allocate a buffer of the given number of bytes with the ByteString
--   @create@, let the action fill it, and wrap the result as a CompactString.
create :: Int -> (Ptr Word8 -> IO ()) -> IO (CompactString a)
create len fill = do
    bytes <- B.create len fill
    return (CS bytes)
{-# INLINE create #-}

-- -----------------------------------------------------------------------------
--
-- Utilities : characters
--

#if !defined(__GLASGOW_HASKELL__)
-- | Fallback for compilers other than GHC: simply 'chr' from Data.Char.
unsafeChr :: Int -> Char
unsafeChr code = chr code
#endif

-- | Checked variant of chr combined with return.
--   It validates a decoded code point and returns it paired with the given
--   length. Surrogate code points (0xD800 to 0xDFFF) and values above
--   0x10FFFF are rejected with 'failMessage'. The surrogate check is done
--   here because at least GHC does not perform it itself.
returnChr :: Int -> Word32 -> IO (Int, Char)
returnChr len code
 | isSurrogate = failMessage "decode" "Surrogate character"
 | tooLarge    = failMessage "decode" "Character out of range"
 | otherwise   = return (len, unsafeChr (fromIntegral code))
 where
   isSurrogate = 0xD800 <= code && code <= 0xDFFF
   tooLarge    = code > 0x10FFFF

-- -----------------------------------------------------------------------------
--
-- Utilities : Type safety/inference
--

-- | Pointer offset that keeps the pointer type.
--   'Foreign.Ptr.plusPtr' may change the pointee type; fixing it here helps
--   type inference at call sites.
plusPtr :: Ptr a -> Int -> Ptr a
plusPtr ptr off = Foreign.Ptr.plusPtr ptr off

-- | 'Foreign.Storable.peekByteOff' with the pointer type tied to the result
--   type, for better type inference.
peekByteOff :: Storable a => Ptr a -> Int -> IO a
peekByteOff ptr off = Foreign.Storable.peekByteOff ptr off

-- | 'Foreign.Storable.pokeByteOff' with the pointer type tied to the stored
--   value type, for better type inference.
pokeByteOff :: Storable a => Ptr a -> Int -> a -> IO ()
pokeByteOff ptr off val = Foreign.Storable.pokeByteOff ptr off val

-- | The encoding of a CompactString as a 'Proxy', used to select the
--   'Encoding' instance. The result is 'undefined' ('Proxy' has no values),
--   so only its type may be used and the proxy must never be forced.
encoding :: CompactString a -> Proxy a
encoding = undefined

-- -----------------------------------------------------------------------------
--
-- Utilities : Error handling
--
-- Common up near-identical calls to `error' to reduce the number of
-- constant strings created when compiled:
--

-- | Fail in IO with a message naming the module and function.
--   The message is the prefix @Data.CompactString.@, then the function name,
--   a colon and a space, then the message text.
failMessage :: String -> String -> IO a
failMessage fun msg = fail (concat ["Data.CompactString.", fun, ": ", msg])
{-# NOINLINE failMessage #-}

-- | Raise an 'error' whose message names the module and function.
--   The message is the prefix @Data.CompactString.@, then the function name,
--   a colon and a space, then the message text.
moduleError :: String -> String -> a
moduleError fun msg = error (concat ["Data.CompactString.", fun, ": ", msg])
{-# NOINLINE moduleError #-}

-- | Raise the standard error for an operation that needs a non-empty
--   CompactString. The argument is the name of the calling function.
errorEmptyList :: String -> a
errorEmptyList fun = moduleError fun emptyMsg
  where emptyMsg = "empty CompactString"
{-# NOINLINE errorEmptyList #-}

-- | Pure wrapper around 'unsafeTryIO', run with 'unsafePerformIO'.
--   A user error raised by @fail@ in the IO action becomes @fail@ in the
--   result monad; a successful result is wrapped with @return@.
--   Presumably only sound when the action has no observable side effects.
unsafeTry :: MonadPlus m => IO a -> m a
unsafeTry action = unsafePerformIO $ unsafeTryIO action

-- | Run an IO action and catch the user errors selected by 'userErrors',
--   such as those raised by @fail@ in IO.
--   A caught error is re-raised as @fail@ in the monad @m@, and a
--   successful result is wrapped with @return@. Exceptions not selected by
--   'userErrors' propagate unchanged.
unsafeTryIO :: MonadPlus m => IO a -> IO (m a)
unsafeTryIO action = catchJust userErrors (liftM return action) (return . fail)