{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module BLAKE3.IO
  ( -- * Hashing
    init
  , update
  , finalize
  , Hasher
  , allocRetHasher
  , Digest
  , allocRetDigest
  -- ** Memory
  , Raw.HasherInternal
  , copyHasher
  , withHasherInternal
  -- * Keyed hashing
  , Key
  , key
  , allocRetKey
  , initKeyed
  -- * Key derivation
  , Context
  , context
  , initDerive
  -- * Constants
  , Raw.HASHER_ALIGNMENT
  , Raw.HASHER_SIZE
  , Raw.KEY_LEN
  , Raw.BLOCK_SIZE
  , Raw.DEFAULT_DIGEST_LEN
  )
  where

import Control.Monad (guard)
import qualified Data.ByteArray as BA
import qualified Data.ByteArray.Encoding as BA
import Data.Foldable
import qualified Data.Memory.PtrMethods as BA
import Data.Proxy
import Data.String
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
import Prelude hiding (init)

import qualified BLAKE3.Raw as Raw

--------------------------------------------------------------------------------

-- | Immutable BLAKE3 hashing state.
--
-- Obtain with 'BLAKE3.hasher' or 'BLAKE3.hasherKeyed'.
--
-- Internally a 'BA.ScrubbedBytes' buffer holding a 'Raw.HasherInternal'
-- (see 'allocRetHasher'). The 'BA.ByteArrayAccess' instance below is left
-- commented out; raw access goes through 'withHasherInternal'.
newtype Hasher = Hasher BA.ScrubbedBytes
  -- deriving newtype (BA.ByteArrayAccess)

-- | Allocate 'Hasher'.
--
-- Allocates a fresh buffer of 'Raw.HASHER_SIZE' bytes, runs the given action
-- on a pointer to it, and returns the action's result together with the
-- resulting 'Hasher'. The action is responsible for initializing the whole
-- buffer (for example with 'init', 'initKeyed' or 'initDerive').
--
-- The 'Hasher' is wiped and freed as soon as it becomes unused.
allocRetHasher
  :: forall a
  .  (Ptr Raw.HasherInternal -> IO a)  -- ^ Initialize 'Raw.HASHER_SIZE' bytes.
  -> IO (a, Hasher)
allocRetHasher g = do
  -- 'Raw.HASHER_SIZE' is a type-level 'Nat'; reflect it to a runtime 'Int'.
  let size :: Int
      size = fromIntegral (natVal (Proxy @Raw.HASHER_SIZE))
  (a, bs) <- BA.allocRet size g
  pure (a, Hasher bs)

-- | Run an action on the 'Raw.HasherInternal' stored inside a 'Hasher'.
--
-- The pointer refers to the buffer backing the given 'Hasher' (no copy is
-- made). Any writes through it therefore mutate this 'Hasher' in place, and
-- every holder of it observes the change. Because 'Hasher' is otherwise
-- treated as immutable, call 'copyHasher' first when the original value must
-- stay unchanged.
--
-- The pointer must not be used after the action returns.
withHasherInternal
  :: Hasher
  -> (Ptr Raw.HasherInternal -> IO a) -- ^ Read or write.
  -> IO a
withHasherInternal (Hasher x) = BA.withByteArray x

-- | Copy an immutable 'Hasher'.
--
-- Returns a new 'Hasher' backed by a freshly allocated buffer with the same
-- bytes. The copy can then be mutated through 'withHasherInternal' without
-- affecting the original.
copyHasher :: Hasher -> IO Hasher
copyHasher (Hasher x) =
  -- The copy callback runs on the new buffer; nothing needs to change there.
  Hasher <$> BA.copy x (\_ -> pure ())

--------------------------------------------------------------------------------

-- | Output from BLAKE3 algorithm, of @len@ bytes.
--
-- The default digest length for BLAKE3 is 'Raw.DEFAULT_DIGEST_LEN'.
--
-- Backed by 'BA.ScrubbedBytes'; obtain one with 'finalize' or
-- 'allocRetDigest'.
newtype Digest (len :: Nat) = Digest BA.ScrubbedBytes
  deriving newtype ( Eq -- ^ Constant time.
                   , BA.ByteArrayAccess
                   )

-- | Base 16 (hexadecimal).
--
-- Renders the digest bytes as hexadecimal text, two characters per byte.
instance Show (Digest len) where
  show (Digest x) = showBase16 x

-- | Allocate a 'Digest'.
--
-- Allocates a fresh buffer of @len@ bytes, runs the given action on a pointer
-- to it, and returns the action's result together with the resulting
-- 'Digest'. The action is responsible for writing all @len@ bytes.
--
-- The 'Digest' is wiped and freed as soon as it becomes unused.
allocRetDigest
  :: forall len a
  .  KnownNat len
  => (Ptr Word8 -> IO a)  -- ^ Initialize @len@ bytes.
  -> IO (a, Digest len)
allocRetDigest g = do
  -- Reflect the type-level digest length to a runtime 'Int'.
  let size :: Int
      size = fromIntegral (natVal (Proxy @len))
  (a, bs) <- BA.allocRet size g
  pure (a, Digest bs)

--------------------------------------------------------------------------------

-- | Key used for keyed hashing mode.
--
-- Obtain with 'BLAKE3.key'.
--
-- See 'BLAKE3.hashKeyed'.
--
-- Stored in 'BA.ScrubbedBytes'. Within this module it is only built by 'key'
-- and 'allocRetKey', both of which use exactly 'Raw.KEY_LEN' bytes.
newtype Key = Key BA.ScrubbedBytes
  deriving newtype ( Eq -- ^ Constant time.
                   , BA.ByteArrayAccess
                   )

-- | Base 16 (hexadecimal).
--
-- Renders the key bytes as hexadecimal text, two characters per byte.
instance Show Key where
  show (Key x) = showBase16 x

-- | 'Raw.KEY_LEN' reflected from the type level to a runtime 'Int'.
keyLen :: Int
keyLen = fromIntegral (natVal (Proxy @Raw.KEY_LEN))

-- | Obtain a 'Key' for use in BLAKE3 keyed hashing.
--
-- Returns 'Nothing' unless the input is exactly 'Raw.KEY_LEN' bytes long.
-- Otherwise the bytes are copied into a fresh 'BA.ScrubbedBytes'.
--
-- See 'BLAKE3.hashKeyed'.
key
  :: BA.ByteArrayAccess bin
  => bin -- ^ Key bytes. Must have length 'Raw.KEY_LEN'.
  -> Maybe Key
key bin
  | BA.length bin /= keyLen = Nothing
  | otherwise               = Just (Key (BA.convert bin))

-- | Allocate a 'Key'.
--
-- Allocates a fresh buffer of 'Raw.KEY_LEN' bytes, runs the given action on a
-- pointer to it, and returns the action's result together with the resulting
-- 'Key'. The action is responsible for writing all 'Raw.KEY_LEN' bytes.
--
-- The 'Key' is wiped and freed as soon as it becomes unused.
allocRetKey
  :: forall a
  . (Ptr Word8 -> IO a) -- ^ Initialize 'Raw.KEY_LEN' bytes.
  -> IO (a, Key)
allocRetKey g = do
  (a, bs) <- BA.allocRet keyLen g
  pure (a, Key bs)

--------------------------------------------------------------------------------

-- | Context for BLAKE3 key derivation. Obtain with 'context'.
--
-- Stored with a trailing NUL byte so it can be handed to the C API as-is
-- (see 'initDerive').
newtype Context = Context BA.Bytes -- ^ NUL-terminated 'Foreign.C.String.CString'.
  deriving newtype (Eq)

-- | Base 16 (hexadecimal).
--
-- Shows the context bytes without the trailing NUL terminator.
instance Show Context where
  show (Context x) =
    -- Every 'Context' built in this module ends in a NUL byte; drop it.
    showBase16 (BA.takeView x (BA.length x - 1))

-- | 'fromString' is a /partial/ function that fails if the given 'String'
-- contains 'Char's outside the range @['toEnum' 1 .. 'toEnum' 255]@.
-- Each 'Char' becomes one byte, and a NUL terminator is appended.
--
-- See 'context' for more details.
instance IsString Context where
  fromString s = case traverse charToWord8 s of
      Nothing -> error "Not a valid String for Context"
      Just w8s -> Context $! BA.pack (w8s <> [0])
    where
      -- Accept only code points that fit in a non-NUL byte; NUL is reserved
      -- for the terminator.
      charToWord8 :: Char -> Maybe Word8
      charToWord8 c = do
        let i = fromEnum c
        guard (i > 0 && i < 256)
        pure (fromIntegral i)

-- | Obtain a 'Context' for BLAKE3 key derivation.
--
-- The context should be hardcoded, globally unique, and
-- application-specific.
--
-- A good format for the context string is:
--
-- @
-- [application] [commit timestamp] [purpose]
-- @
--
-- For example:
--
-- @
-- example.com 2019-12-25 16:18:03 session tokens v1
-- @
--
-- The input bytes are copied and a NUL terminator is appended.
context
  :: BA.ByteArrayAccess bin
  => bin -- ^ If @bin@ contains null bytes, this function returns 'Nothing'.
  -> Maybe Context
context src
  | BA.any (0 ==) src = Nothing
  | otherwise         = Just (Context nulTerminated)
  where
    srcLen, dstLen :: Int
    srcLen = BA.length src
    -- One extra byte for the trailing NUL, so the buffer is a valid C string.
    dstLen = srcLen + 1

    -- Copy of @src@ followed by a single zero byte.
    nulTerminated :: BA.Bytes
    nulTerminated = BA.allocAndFreeze dstLen $ \pdst ->
      BA.withByteArray src $ \psrc -> do
        BA.memCopy pdst psrc srcLen
        pokeByteOff pdst srcLen (0 :: Word8)

--------------------------------------------------------------------------------

-- | Initialize a 'Raw.HasherInternal' in the default (unkeyed) hashing mode.
--
-- Thin wrapper around 'Raw.init'.
init
  :: Ptr Raw.HasherInternal -- ^ Will be mutated.
  -> IO ()
init = Raw.init

-- | Initialize a 'Raw.HasherInternal' in keyed mode.
--
-- Passes a pointer to the 'Key' bytes to 'Raw.init_keyed'.
initKeyed
  :: Ptr Raw.HasherInternal -- ^ Will be mutated.
  -> Key
  -> IO ()
initKeyed ph key0 =
  BA.withByteArray key0 $ \pkey ->
    Raw.init_keyed ph pkey

-- | Initialize a 'Raw.HasherInternal' in derivation mode.
--
-- The 'Context' bytes (which end in a NUL terminator) are passed to
-- 'Raw.init_derive_key' as a C string.
--
-- The input key material must be provided afterwards, using 'update'.
initDerive
  :: Ptr Raw.HasherInternal -- ^ Will be mutated.
  -> Context
  -> IO ()
initDerive ph (Context ctx) =
  BA.withByteArray ctx $ \pc ->
    Raw.init_derive_key ph pc

-- | Update 'Raw.HasherInternal' state with new data.
--
-- Each chunk is passed to 'Raw.update' in list order, together with its
-- length in bytes. An empty list performs no calls.
update
  :: forall bin
  .  BA.ByteArrayAccess bin
  => Ptr Raw.HasherInternal -- ^ Will be mutated.
  -> [bin]
  -> IO ()
update ph bins =
  for_ bins $ \bin ->
    BA.withByteArray bin $ \pbin ->
      Raw.update ph pbin (fromIntegral (BA.length bin))

-- | Finalize 'Raw.HasherInternal' state and obtain a digest.
--
-- Allocates a @len@-byte 'Digest' (see 'allocRetDigest') and has
-- 'Raw.finalize' write @len@ output bytes into it.
--
-- The 'Raw.HasherInternal' is mutated.
finalize
  :: forall len
  .  KnownNat len
  => Ptr Raw.HasherInternal -- ^ Will be mutated.
  -> IO (Digest len)
finalize ph =
  fmap snd $ allocRetDigest $ \pd ->
    -- Output length comes from the type-level @len@.
    Raw.finalize ph pd (fromIntegral (natVal (Proxy @len)))

--------------------------------------------------------------------------------

-- | Hex-encode the bytes of any byte array into a 'String'.
--
-- The intermediate hexadecimal encoding is held in a 'BA.ScrubbedBytes',
-- presumably so that this copy is wiped once unused.
showBase16 :: BA.ByteArrayAccess x => x -> String
showBase16 =
    fmap (toEnum . fromIntegral)          -- each hex byte is an ASCII 'Char'
  . BA.unpack @BA.ScrubbedBytes
  . BA.convertToBase BA.Base16