{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE ConstraintKinds            #-}
module Raaz.Hash.Sha.Util
       ( shaImplementation
       , length64Write
       , length128Write
       , Compressor
       ) where
import Data.Monoid                  ( (<>)      )
import Data.Word
import Foreign.Ptr                  ( Ptr       )
import Foreign.Storable             ( Storable  )
import Raaz.Core
import Raaz.Core.Transfer
import Raaz.Hash.Internal
-- | Constraint bundle required by most SHA helpers in this module: @h@
-- must be a 'Primitive', be 'Storable', and its 'HashMemory' must be an
-- instance of 'Memory'.
type IsSha h    = (Primitive h, Storable h, Memory (HashMemory h))
-- | The monad in which SHA compression and finalisation run: 'MT'
-- over the hash memory ('HashMemory') of @h@.
type ShaMonad h = MT (HashMemory h)
-- | Write actions ('WriteM') executed in 'ShaMonad'; used in this module
-- to build the message padding.
type ShaWrite h = WriteM (ShaMonad h)
-- | Encodes the total message length, given in bits as a 'BITS' 'Word64',
-- into the length field placed at the end of the SHA padding.
type LengthWrite h = BITS Word64 -> ShaWrite h
-- | 'LengthWrite' for SHA variants whose padding ends in a 64-bit length
-- field: the bit count is emitted as one big-endian 'Word64'.
length64Write :: LengthWrite h
length64Write bits = case bits of
  BITS w -> write (bigEndian w)
-- | 'LengthWrite' for SHA variants whose padding ends in a 128-bit length
-- field. The input length is only a 'Word64', so the upper 64 bits are
-- always zero (a zero word has the same bytes in either endianness). The
-- lower 64 bits are then written big-endian by 'length64Write'.
length128Write :: LengthWrite h
length128Write bits = writeStorable highWord <> length64Write bits
  where highWord = 0 :: Word64
-- | A low-level SHA compression routine working on raw pointers in 'IO'.
type Compressor h = Pointer  -- ^ buffer holding the message blocks
                  -> Int     -- ^ number of blocks to process
                  -> Ptr h   -- ^ pointer to the hash state; 'liftCompressor'
                             --   supplies the 'hashCell' pointer here
                  -> IO ()
-- | An action on a buffer of message data, run in 'ShaMonad'. @bufSize@
-- is the unit of the buffer length: 'BLOCKS' for whole-block compression
-- ('shaBlocks') and 'BYTES' for the final chunk ('shaFinal').
type ShaBufferAction bufSize h = Pointer       -- ^ the buffer
                               -> bufSize      -- ^ amount of data in it
                               -> ShaMonad h ()
-- | Run a raw 'Compressor' inside 'ShaMonad'.
--
-- The 'BLOCKS' count is converted to an 'Int' with 'fromEnum'. The
-- hash-state pointer comes from the 'hashCell' sub-memory via
-- 'withCellPointer'.
liftCompressor :: IsSha h => Compressor h -> ShaBufferAction (BLOCKS h) h
liftCompressor comp ptr nblocks =
  onSubMemory hashCell $ withCellPointer $ comp ptr (fromEnum nblocks)
-- | Compress @nblocks@ whole blocks with the given action, then record
-- them in the running length with 'updateLength'.
shaBlocks :: Primitive h
          => ShaBufferAction (BLOCKS h) h  -- ^ block compressor
          -> ShaBufferAction (BLOCKS h) h
shaBlocks comp ptr nblocks = do
  comp ptr nblocks
  updateLength nblocks
-- | Finalise the hash on the last chunk of the message, given in bytes.
--
-- Steps:
--
--   1. account for the final @msgLen@ bytes via 'updateLength';
--   2. read back the total length (in bits) with 'extractLength';
--   3. build the SHA padding with 'shaPad', ending in the length field
--      produced by @lenW@;
--   4. write the padding into the buffer after the message data, then
--      compress every block the padded write covers.
--
-- NOTE(review): 'unsafeWrite' presumably does no bounds checking, so the
-- buffer behind @ptr@ must have room for the padding. Confirm this
-- against the 'HashI' contract for 'compressFinal'.
shaFinal :: (Primitive h, Storable h)
         => ShaBufferAction (BLOCKS h) h   -- ^ block compressor
         -> LengthWrite h                  -- ^ encoder for the length field
         -> ShaBufferAction (BYTES Int) h
shaFinal comp lenW ptr msgLen = do
  updateLength msgLen
  totalBits <- extractLength
  -- 'shaPad' only hands its first argument to 'blocksOf' to pick the
  -- block size; the concrete @h@ is fixed by the type of @lenW@. The
  -- 'undefined' is therefore presumably never forced.
  let pad    = shaPad undefined msgLen (lenW totalBits)
      -- The padding is glued to a block boundary, so rounding down with
      -- 'atMost' should yield the exact number of blocks written.
      blocks = atMost (bytesToWrite pad)
  unsafeWrite pad ptr
  comp ptr blocks
-- | Build the SHA padding for a final chunk of @msgLen@ bytes.
--
-- The write proceeds in three parts:
--
--   * skip over the message bytes already in the buffer;
--   * write the @0x80@ marker byte;
--   * hand the header and the supplied length write to 'glueWrites',
--     with fill byte @0@ and a one-block boundary. This presumably
--     zero-pads so the combined write ends on a block boundary.
--
-- The @h@ argument is only passed to 'blocksOf' to select the block size
-- ('shaFinal' passes 'undefined' for it).
shaPad :: IsSha h
       => h
       -> BYTES Int    -- ^ number of message bytes in the buffer
       -> ShaWrite h   -- ^ length-field write placed at the very end
       -> ShaWrite h
shaPad h msgLen lengthWrite = glueWrites 0 oneBlock header lengthWrite
  where header   = skipWrite msgLen <> writeStorable (0x80 :: Word8)
        oneBlock = blocksOf 1 h
-- | Assemble a 'HashI' implementation for a SHA variant from a raw
-- 'Compressor' and the matching 'LengthWrite'.
--
-- Both 'compress' and 'compressFinal' share the same lifted compressor.
-- The buffer start alignment is set to 'wordAlignment'.
shaImplementation :: IsSha h
                  => String          -- ^ implementation name
                  -> String          -- ^ human-readable description
                  -> Compressor  h   -- ^ raw block compressor
                  -> LengthWrite h   -- ^ length-field encoder
                  -> HashI h (HashMemory h)
shaImplementation nam des comp lenW =
  let blockComp = liftCompressor comp
  in  HashI { hashIName              = nam
            , hashIDescription       = des
            , compress               = shaBlocks blockComp
            , compressFinal          = shaFinal blockComp lenW
            , compressStartAlignment = wordAlignment
            }
{-# INLINE shaImplementation #-}