{-# LANGUAGE OverloadedStrings #-}
-- | Utility functions for SD-JWT operations (low-level).
--
-- This module provides base64url encoding/decoding, salt generation,
-- and text/ByteString conversions used throughout the SD-JWT library.
--
-- == Usage
--
-- This module contains low-level utilities that are typically used internally
-- by other SD-JWT modules. Most users should use the higher-level APIs in:
--
-- * 'SDJWT.Issuer' - For issuers
-- * 'SDJWT.Holder' - For holders  
-- * 'SDJWT.Verifier' - For verifiers
--
-- These utilities may be useful for:
--
-- * Advanced use cases requiring custom implementations
-- * Library developers building on top of SD-JWT
-- * Testing and debugging
--
module SDJWT.Internal.Utils
  ( base64urlEncode
  , base64urlDecode
  , textToByteString
  , byteStringToText
  , hashToBytes
  , splitJSONPointer
  , unescapeJSONPointer
  , constantTimeEq
  , generateSalt  -- Internal use only, not part of public API
  , groupPathsByFirstSegment
  ) where

import qualified Data.ByteString.Base64.URL as Base64
import qualified Data.ByteString as BS
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Crypto.Random as RNG
import qualified Crypto.Hash as Hash
import qualified Data.ByteArray as BA
import qualified Data.Map.Strict as Map
import Control.Monad.IO.Class (MonadIO, liftIO)
import SDJWT.Internal.Types (HashAlgorithm(..))

-- | Encode raw bytes as unpadded base64url text.
--
-- The URL-safe alphabet of RFC 4648 Section 5 is used and the trailing
-- @=@ padding characters are omitted.
--
-- >>> base64urlEncode "Hello, World!"
-- "SGVsbG8sIFdvcmxkIQ"
base64urlEncode :: BS.ByteString -> T.Text
base64urlEncode rawBytes = TE.decodeUtf8 (Base64.encodeUnpadded rawBytes)

-- | Base64url decode a Text (handles padding).
--
-- Decodes base64url-encoded 'T.Text' back to a 'BS.ByteString'. Both padded
-- (trailing @=@) and unpadded input are accepted: the padding-agnostic
-- decoder is used rather than the unpadded-only one, which would reject any
-- input that carries padding.
--
-- Returns 'Left' with the decoder's error message if decoding fails. The
-- message is passed through as-is (not run through 'show', which would wrap
-- it in quotes and escape it).
base64urlDecode :: T.Text -> Either T.Text BS.ByteString
base64urlDecode encoded =
  either (Left . T.pack) Right (Base64.decode (TE.encodeUtf8 encoded))

-- | Produce a fresh random salt for a disclosure.
--
-- The salt is 16 bytes (128 bits) long, the size recommended by RFC 9901,
-- so that disclosure digests cannot be guessed or brute-forced.
--
-- The bytes are obtained from 'RNG.getRandomBytes' run in 'IO' and lifted
-- into the caller's monad.
generateSalt :: MonadIO m => m BS.ByteString
generateSalt = liftIO (RNG.getRandomBytes saltLength)
  where
    -- 16 bytes = 128 bits.
    saltLength :: Int
    saltLength = 16

-- | Encode 'T.Text' as a UTF-8 'BS.ByteString'.
--
-- Thin convenience wrapper around 'TE.encodeUtf8'.
textToByteString :: T.Text -> BS.ByteString
textToByteString txt = TE.encodeUtf8 txt

-- | Decode a UTF-8 'BS.ByteString' into 'T.Text'.
--
-- Thin convenience wrapper around 'TE.decodeUtf8'.
--
-- Note: this is partial - it throws an exception when the input is not
-- valid UTF-8. For safe decoding, use 'Data.Text.Encoding.decodeUtf8''
-- instead.
byteStringToText :: BS.ByteString -> T.Text
byteStringToText bytes = TE.decodeUtf8 bytes

-- | Compute the digest of a 'BS.ByteString' with the chosen algorithm.
--
-- The 'HashAlgorithm' selects SHA-256, SHA-384 or SHA-512; the resulting
-- digest is returned as raw bytes (not hex- or base64-encoded).
hashToBytes :: HashAlgorithm -> BS.ByteString -> BS.ByteString
hashToBytes algorithm input =
  case algorithm of
    SHA256 -> BA.convert (Hash.hashWith Hash.SHA256 input)
    SHA384 -> BA.convert (Hash.hashWith Hash.SHA384 input)
    SHA512 -> BA.convert (Hash.hashWith Hash.SHA512 input)

-- | Split JSON Pointer path by "/", respecting escapes (RFC 6901).
--
-- The path is first cut at every unescaped slash, then each segment is
-- unescaped:
--
-- - "~1" represents a literal forward slash "/"
-- - "~0" represents a literal tilde "~"
--
-- A tilde that is not followed by @0@ or @1@ is kept unchanged.
--
-- Examples:
--
-- - "a\/b" → ["a", "b"]
-- - "a~1b" → ["a\/b"] (escaped slash)
-- - "a~0b" → ["a~b"] (escaped tilde)
-- - "a~1\/b" → ["a\/", "b"] (escaped slash becomes "\/", then "\/" is separator)
-- - "~01" → ["~1"] (escaped tilde followed by a plain "1")
--
-- Note: This function is designed for relative JSON Pointer paths (without leading "/").
-- Leading slashes are stripped, trailing slashes don't create empty segments,
-- and consecutive slashes are collapsed.
splitJSONPointer :: T.Text -> [T.Text]
splitJSONPointer path =
  [ unescapeSegment rawSegment
  | rawSegment <- T.splitOn "/" path
  , -- Empty raw segments come from leading, trailing or doubled slashes.
    not (T.null rawSegment)
  ]
  where
    -- Neither escape sequence contains '/', so splitting before unescaping
    -- is safe. "~1" must be decoded before "~0" (RFC 6901, Section 4);
    -- the other order would turn "~01" into "/" instead of "~1".
    unescapeSegment :: T.Text -> T.Text
    unescapeSegment = T.replace "~0" "~" . T.replace "~1" "/"

-- | Unescape JSON Pointer segment (RFC 6901).
--
-- Converts escape sequences back to literal characters:
--
-- - "~1" → "/"
-- - "~0" → "~"
--
-- Note: Order matters - "~1" must be replaced before "~0" (RFC 6901,
-- Section 4). Replacing "~0" first would turn "~01" into "~1" and then
-- wrongly into "/"; the correct result is "~1".
--
-- Function composition applies the right-hand function first, so the "~1"
-- replacement is deliberately written on the right.
unescapeJSONPointer :: T.Text -> T.Text
unescapeJSONPointer = T.replace "~0" "~" . T.replace "~1" "/"

-- | Compare two ByteStrings without leaking where they differ.
--
-- Inputs of different length are reported as unequal immediately, so only
-- the lengths (never the contents) can influence that early exit.
-- Equal-length inputs are compared with 'BA.constEq' from "Data.ByteArray",
-- which is documented as a constant-time comparison.
--
-- SECURITY: Use this function when comparing cryptographic values like digests,
-- hashes, or other sensitive data that could be exploited via timing attacks.
constantTimeEq :: BS.ByteString -> BS.ByteString -> Bool
constantTimeEq lhs rhs = sameLength && BA.constEq lhs rhs
  where
    sameLength = BS.length lhs == BS.length rhs

-- | Group paths by their first segment.
--
-- This is a common pattern for processing nested JSON Pointer paths.
-- Each non-empty path is filed under its first segment, with the rest of
-- the path as the stored value. Within a group the remainders keep the
-- order in which the paths appeared in the input.
--
-- An empty path only makes sure the empty-string key exists; it adds no
-- remainder to that group.
--
-- Example:
--
-- > groupPathsByFirstSegment [["a", "b"], ["a", "c"], ["x"]]
-- >   = Map.fromList [("a", [["b"], ["c"]]), ("x", [[]])]
groupPathsByFirstSegment :: [[T.Text]] -> Map.Map T.Text [[T.Text]]
groupPathsByFirstSegment nestedPaths =
  -- 'Map.fromListWith' combines as @new ++ old@, i.e. it prepends later
  -- entries. Feeding the pairs in reverse therefore yields groups in input
  -- order while every append stays a cheap singleton prepend.
  Map.fromListWith (++) (reverse (map keyed nestedPaths))
  where
    -- Pair a path with its grouping key and the remainder to store.
    keyed :: [T.Text] -> (T.Text, [[T.Text]])
    keyed [] = ("", [])
    keyed (first : rest) = (first, [rest])

