-------------------------------------------------------------------------------
-- |
-- Module:      Crypto.PHKDF.Primitives.Subtle
-- Copyright:   (c) 2024 Auth Global
-- License:     Apache2
--
-------------------------------------------------------------------------------
module Crypto.PHKDF.Primitives.Subtle
  ( PhkdfCtx(..)
  , phkdfCtx_byteLen
  , phkdfCtx_unsafeFeed
  , PhkdfSlowCtx(..)
  , phkdfSlowCtx_lift
  , PhkdfGen(..)
  ) where

import           Prelude hiding (null)
import           Crypto.Sha256 as Sha256
import           Crypto.PHKDF.HMAC (HmacKeyLike)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import           Data.Foldable(foldl', null)
import           Data.Word

-- This is a proof of concept: ideally all length bookkeeping would rely on the counter
-- inside the SHA-256 context.

-- TODO: Should phkdfCtx_length count bytes or bits? Double-check how the SHA-256 internal
-- counter works, decide how this should behave, and then export it from the Primitives module.
-- For truly bulletproof code we would probably need to return (Maybe Ctx), so that we never
-- overflow SHA-256's internal counter. That would be somewhat at odds with the cryptohash-style
-- interface being mimicked here, not to mention the cryptohash implementation this module
-- depends upon.

-- Note that phkdfCtx_length is offset from the SHA-256 internal counter, but the offset is
-- always exactly 64 bytes. Since the internals of this module only care about the internal
-- counter modulo 64, the offset does not matter. However, we should probably export the
-- SHA-256 counter itself.

-- | A PHKDF context: a running SHA-256 state paired with an HMAC key-like
-- value that is carried along unchanged by the feeding operations in this
-- module.
data PhkdfCtx = PhkdfCtx
  { phkdfCtx_state       :: !Sha256Ctx
    -- ^ The SHA-256 state into which input is fed.
  , phkdfCtx_hmacKeyLike :: !HmacKeyLike
    -- ^ HMAC key material kept alongside the state.
    -- NOTE(review): presumably used to finalize the HMAC later; confirm.
  }

-- | The byte count reported by the wrapped SHA-256 state
-- (via 'sha256_byteCount').
--
-- NOTE(review): whether this includes any prefix absorbed when the context
-- was initialized (e.g. a key block) is not visible here; confirm.
phkdfCtx_byteLen :: PhkdfCtx -> Word64
phkdfCtx_byteLen ctx = sha256_byteCount (phkdfCtx_state ctx)

-- | A strict pair of a 64-bit word and a SHA-256 state.
--
-- NOTE(review): 'P' is not exported and is not referenced anywhere else in
-- this module as shown; it may be a leftover. Confirm before removing.
data P
  = P
      !Word64
      !Sha256Ctx

-- | Feed a container of 'ByteString's into the context's SHA-256 state using
-- 'sha256_feeds'.
--
-- An empty container returns the context unchanged. Only 'phkdfCtx_state' is
-- updated; 'phkdfCtx_hmacKeyLike' is preserved.
phkdfCtx_unsafeFeed :: Foldable f => f ByteString -> PhkdfCtx -> PhkdfCtx
phkdfCtx_unsafeFeed strs ctx
  | null strs = ctx
  | otherwise = ctx { phkdfCtx_state = fedState }
  where
    -- Only demanded in the non-empty branch.
    fedState = sha256_feeds strs (phkdfCtx_state ctx)

-- | A 'PhkdfCtx' bundled with a 32-bit counter and a tag string.
data PhkdfSlowCtx = PhkdfSlowCtx
  { phkdfSlowCtx_phkdfCtx :: !PhkdfCtx
    -- ^ The wrapped PHKDF context.
  , phkdfSlowCtx_counter  :: !Word32
    -- ^ A 32-bit counter.
    -- NOTE(review): presumably the slow-phase iteration counter; confirm.
  , phkdfSlowCtx_tag      :: !ByteString
    -- ^ A tag string.
    -- NOTE(review): presumably a domain-separation tag; confirm.
  }

-- | Apply a 'PhkdfCtx' transformation to the context wrapped inside a
-- 'PhkdfSlowCtx', leaving the counter and tag untouched.
phkdfSlowCtx_lift :: (PhkdfCtx -> PhkdfCtx) -> PhkdfSlowCtx -> PhkdfSlowCtx
phkdfSlowCtx_lift f (PhkdfSlowCtx inner counter tag) =
  PhkdfSlowCtx (f inner) counter tag

-- | Generator state: HMAC key material, an extension tag, a 32-bit counter,
-- a 'ByteString' state, and an optional SHA-256 context.
data PhkdfGen = PhkdfGen
  { phkdfGen_hmacKeyLike :: !HmacKeyLike
    -- ^ HMAC key material.
  , phkdfGen_extTag      :: !ByteString
    -- ^ Extension tag.
    -- NOTE(review): semantics not visible in this module; confirm.
  , phkdfGen_counter     :: !Word32
    -- ^ A 32-bit counter.
    -- NOTE(review): presumably the output block counter; confirm.
  , phkdfGen_state       :: !ByteString
    -- ^ Byte-string state.
    -- NOTE(review): presumably the previous output block; confirm.
  , phkdfGen_initCtx     :: !(Maybe Sha256Ctx)
    -- ^ Optional SHA-256 context; when and why it is 'Nothing' is not
    -- visible here.
  }