{-# LANGUAGE OverloadedStrings #-}

module Data.HashTree.Internal (
    Settings(..)
  , defaultSettings
  , MerkleHashTrees(..)
  , digest
  , info
  , currentHead
  , empty
  , fromList
  , toHashTree
  , add
  , InclusionProof(..)
  , defaultInclusionProof
  , generateInclusionProof
  , verifyInclusionProof
  , ConsistencyProof(..)
  , defaultConsistencyProof
  , TreeSize
  , Index
  , generateConsistencyProof
  , verifyConsistencyProof
  ) where

import Crypto.Hash (Digest, SHA256, HashAlgorithm, hash)
import Data.Bits (testBit, finiteBitSize, countLeadingZeros, (.&.), unsafeShiftR)
import Data.ByteArray (ByteArrayAccess)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.ByteString.Char8 ()
import Data.List (foldl')
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap

-- $setup
-- >>> :set -XOverloadedStrings

----------------------------------------------------------------

-- | Settings for Merkle Hash Trees.
--   The first parameter is the input data type.
--   The second one is the digest (hash algorithm) type.
--
-- To create this, use 'defaultSettings':
--
-- > defaultSettings { hash0 = ..., hash1 = ..., hash2 = ... }
data Settings inp ha = Settings {
    -- | The digest of a tree with no input elements
    --   (stored as the root of the size-0 tree).
    hash0 :: Digest ha
    -- | A hash function for one input element to calculate the leaf digest.
  , hash1 :: inp -> Digest ha
    -- | A hash function combining two child digests into the internal
    --   node digest (left child first, right child second).
  , hash2 :: Digest ha -> Digest ha -> Digest ha
  }

-- | SHA-256 of a strict 'ByteString'; the fixed result type pins down
--   which algorithm the polymorphic 'hash' uses.
sha256 :: ByteString -> Digest SHA256
sha256 input = hash input

-- | A default 'Settings' with 'ByteString' input and 'SHA256' digests.
--   Leaves are hashed as @SHA256(0x00 || x)@ and internal nodes as
--   @SHA256(0x01 || left || right)@; the empty tree hashes to
--   @SHA256("")@.  This can be used for CT (Certificate Transparency)
--   as defined in RFC 6962.
defaultSettings :: Settings ByteString SHA256
defaultSettings = Settings { hash0 = emptyDigest, hash1 = leafDigest, hash2 = nodeDigest }
  where
    emptyDigest = sha256 BS.empty
    -- 0x00 domain-separates leaf hashes from internal node hashes.
    leafDigest x = sha256 (BS.cons 0x00 x)
    -- 0x01 marks an internal node; children are concatenated left to right.
    nodeDigest l r = sha256 (BS.cons 0x01 (BA.convert l `BS.append` BA.convert r))

----------------------------------------------------------------

-- | The position of the target element, counted from 0.
type Index = Int

-- | The size of a hash tree (the number of elements added so far).
type TreeSize = Int

-- | The data type for Merkle Hash Trees.
--   The first parameter is the input data type.
--   The second one is the digest data type.
--
--   A value keeps the hash tree for every size from 0 up to the current
--   size, so digests and proofs can be produced for any earlier size.
data MerkleHashTrees inp ha = MerkleHashTrees {
    -- | The hash functions used to build every tree.
    settings  :: !(Settings inp ha)
    -- | Getting the log size
  , size      :: !TreeSize
    -- The key is the size of the HashTree stored under it:
    -- 0 for Empty
    -- 1 for a single Leaf (index 0)
    -- 'size' for the latest HashTree
    -- 'add' inserts one new entry per element and never removes older ones.
  , hashtrees :: !(IntMap (HashTree inp ha)) -- the Int key is TreeSize
    -- Maps each leaf digest to the index at which its element was added;
    -- 'add' consults it to ignore elements whose leaf digest is already present.
  , indices   :: !(Map (Digest ha) Index)
  }

-- | Getting the Merkle Tree Hash (root digest) of the tree with the given
--   size, or 'Nothing' when no tree of that size has been built.
digest :: TreeSize -> MerkleHashTrees inp ha -> Maybe (Digest ha)
digest tsiz = fmap value . IntMap.lookup tsiz . hashtrees

-- | Getting the hash tree stored for the current size, if present.
currentHead :: MerkleHashTrees inp ha -> Maybe (HashTree inp ha)
currentHead mht = IntMap.lookup (size mht) (hashtrees mht)

-- | Getting the root information of the Merkle Hash Tree.
--   A pair of the current size and the current Merkle Tree Hash is returned.
info :: MerkleHashTrees inp ha -> (TreeSize, Digest ha)
info mht = (siz, h)
  where
    siz = size mht
    -- 'empty' and 'add' always store a tree under the current size, so a
    -- missing entry means that invariant was broken.
    h = maybe (error "info: no hash tree stored for the current size") id (digest siz mht)

----------------------------------------------------------------

-- | A single Merkle hash tree.  Every constructor caches its digest in
--   its first field (see 'value'):
--
--   * 'Empty' holds 'hash0' and represents a tree with no elements.
--   * 'Leaf' holds the leaf digest ('hash1' of the element), the element's
--     index, and the element itself.
--   * 'Node' holds the node digest ('hash2' of the children's digests), the
--     indices of its leftmost and rightmost leaves (see 'idxl' / 'idxr'),
--     and its left and right subtrees.
--
-- NOTE(review): this type appears in the exported signatures of
-- 'currentHead' and 'toHashTree' but is not in the module's export list.
data HashTree inp ha =
    Empty !(Digest ha)
  | Leaf  !(Digest ha) !Index inp
  | Node  !(Digest ha) !Index !Index !(HashTree inp ha) !(HashTree inp ha)
  deriving (Eq, Show)

-- | Creating an empty 'MerkleHashTrees'.
--   Only the size-0 tree, whose digest is 'hash0', is stored initially.
--
-- >>> info $ empty defaultSettings
-- (0,e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855)
empty :: Settings inp ha -> MerkleHashTrees inp ha
empty set = MerkleHashTrees set 0 (IntMap.singleton 0 (Empty (hash0 set))) Map.empty

-- | The digest cached at the root of a 'HashTree'.
value :: HashTree inp ha -> Digest ha
value t = case t of
    Empty d        -> d
    Leaf d _ _     -> d
    Node d _ _ _ _ -> d

----------------------------------------------------------------

-- | The index of the leftmost leaf covered by a non-empty tree.
idxl :: HashTree inp ha -> Index
idxl t = case t of
    Leaf _ i _     -> i
    Node _ i _ _ _ -> i
    Empty _        -> error "idxl"

-- | The index of the rightmost leaf covered by a non-empty tree.
idxr :: HashTree inp ha -> Index
idxr t = case t of
    Leaf _ i _     -> i
    Node _ _ i _ _ -> i
    Empty _        -> error "idxr"

----------------------------------------------------------------

-- | Creating a Merkle Hash Tree from a list of elements. O(n log n)
--   The elements are added from left to right with 'add'.
--
-- >>> info $ fromList defaultSettings ["0","1","2"]
-- (3,725d5230db68f557470dc35f1d8865813acd7ebb07ad152774141decbae71327)
fromList :: (ByteArrayAccess inp, HashAlgorithm ha)
         => Settings inp ha -> [inp] -> MerkleHashTrees inp ha
fromList set xs = foldl' (\acc x -> add x acc) (empty set) xs

-- | Adding (appending) an element. O(log n)
--   An element whose leaf digest is already recorded is ignored and the
--   trees are returned unchanged.  Otherwise a new tree of size + 1 is
--   stored alongside the existing ones.
--
-- >>> info $ add "1" $ empty defaultSettings
-- (1,2215e8ac4e2b871c2a48189e79738c956c081e23ac2f2415bf77da199dfd920c)
add :: (ByteArrayAccess inp, HashAlgorithm ha)
     => inp -> MerkleHashTrees inp ha -> MerkleHashTrees inp ha
add a mht@(MerkleHashTrees set tsiz htdb idb)
  | Map.member leafHash idb = mht
  | otherwise               = maybe mht extend (IntMap.lookup tsiz htdb) -- lookup always succeeds
  where
    leafHash = hash1 set a
    newSize  = tsiz + 1
    -- The new element gets the next free index, which equals the old size.
    newLeaf  = Leaf leafHash tsiz a

    extend current = MerkleHashTrees
        set
        newSize
        (IntMap.insert newSize (grow current) htdb)
        (Map.insert leafHash tsiz idb)

    -- Append 'newLeaf' as the rightmost leaf: a complete (power-of-2 sized)
    -- subtree becomes the left child of a new node, otherwise the new leaf
    -- is pushed into the right subtree and the digests are recomputed.
    grow (Empty _) = newLeaf
    grow old@(Leaf h lo _) = Node (hash2 set h leafHash) lo tsiz old newLeaf
    grow old@(Node h lo hi l r)
      | isPowerOf2 (hi - lo + 1) = Node (hash2 set h leafHash) lo tsiz old newLeaf
      | otherwise =
          let r' = grow r
          in Node (hash2 set (value l) (value r')) lo tsiz l r'

----------------------------------------------------------------

-- | A simple algorithm to create a binary balanced tree. O(n log n)
--   This is just for testing: elements are numbered from 0, turned into
--   leaves, and paired level by level.
toHashTree :: (ByteArrayAccess inp, HashAlgorithm ha)
           => Settings inp ha -> [inp] -> HashTree inp ha
toHashTree set xs
  | null xs   = Empty (hash0 set) -- not used
  | otherwise = buildup set (zipWith (leaf set) xs [0 ..])

-- | Building the leaf for element @x@ placed at index @i@.
leaf :: (ByteArrayAccess inp, HashAlgorithm ha)
     => Settings inp ha -> inp -> Index -> HashTree inp ha
leaf set x i = Leaf leafHash i x
  where
    leafHash = hash1 set x

-- | Joining two adjacent subtrees under a new internal node that covers
--   the left tree's first index through the right tree's last index.
link :: (ByteArrayAccess inp, HashAlgorithm ha)
     => Settings inp ha -> HashTree inp ha -> HashTree inp ha -> HashTree inp ha
link set l r =
    let parentHash = hash2 set (value l) (value r)
    in Node parentHash (idxl l) (idxr r) l r

-- | Repeatedly pairing adjacent trees, level by level, until a single tree
--   remains.  An empty list yields the empty tree ('hash0').  Without that
--   case, 'pairing' of @[]@ returns @[]@ and the recursion never ends.
buildup :: (ByteArrayAccess inp, HashAlgorithm ha)
         => Settings inp ha -> [HashTree inp ha] -> HashTree inp ha
buildup set []   = Empty (hash0 set)
buildup _   [ht] = ht
buildup set hts  = buildup set (pairing set hts)

-- | Linking adjacent trees pairwise; an odd tree at the end is carried up
--   to the next level unchanged.
pairing :: (ByteArrayAccess inp, HashAlgorithm ha)
        => Settings inp ha -> [HashTree inp ha] -> [HashTree inp ha]
pairing set hts = case hts of
    t:u:rest -> link set t u : pairing set rest
    _        -> hts

----------------------------------------------------------------

-- | The type for inclusion proof (aka audit proof).
data InclusionProof ha = InclusionProof {
    -- | The index for the target.
    leafIndex :: !Index
    -- | The hash tree size.
  , treeSize  :: !TreeSize
    -- | A list of sibling digests for inclusion, ordered from the leaf level
    --   up to the root (the target leaf's sibling comes first).
  , inclusion :: ![Digest ha]
  } deriving (Eq, Show)

-- | The default value for 'InclusionProof' just to create a new value:
--   leaf index 0, tree size 1, and an empty path.
defaultInclusionProof :: InclusionProof ha
defaultInclusionProof = InclusionProof 0 1 []

-- | Generating 'InclusionProof' for the target at the server side.
--   Returns 'Nothing' when no tree of the given size exists, when the
--   target digest is unknown, or when the target was added at an index
--   not smaller than the given size.  The path is ordered from the leaf
--   level up to the root.
generateInclusionProof :: Digest ha -- ^ The target hash (leaf digest)
                       -> TreeSize  -- ^ The tree size
                       -> MerkleHashTrees inp ha
                       -> Maybe (InclusionProof ha)
generateInclusionProof h tsiz (MerkleHashTrees _ _ htdb idb) =
    case (IntMap.lookup tsiz htdb, Map.lookup h idb) of
      (Just ht, Just i) | i < tsiz -> Just (InclusionProof i tsiz (auditPath i ht []))
      _                            -> Nothing
  where
    -- Walk from the root towards leaf @m@, pushing the sibling digest of each
    -- level onto the accumulator, so the deepest sibling ends up first.
    auditPath m (Node _ _ _ l r) acc
      | m <= idxr l = auditPath m l (value r : acc)
      | otherwise   = auditPath m r (value l : acc)
    auditPath _ _ acc = acc

-- | Verifying 'InclusionProof' at the client side.
--
--   This follows the inclusion proof verification algorithm of RFC 9162,
--   section 2.1.3.2.  The proof is rejected when its leaf index is negative
--   or not smaller than its tree size, when the path has more elements than
--   the tree has levels, or when it does not reproduce the root digest.
--
-- >>> let target = "3"
-- >>> let mht = fromList defaultSettings ["0","1","2",target,"4","5","6"]
-- >>> let treeSize = 5
-- >>> let leafDigest = hash1 defaultSettings target
-- >>> let Just proof = generateInclusionProof leafDigest treeSize mht
-- >>> let Just rootDigest = digest treeSize mht
-- >>> verifyInclusionProof defaultSettings leafDigest rootDigest proof
-- True
verifyInclusionProof :: (ByteArrayAccess inp, HashAlgorithm ha)
                     => Settings inp ha
                     -> Digest ha         -- ^ The target hash (leaf digest)
                     -> Digest ha         -- ^ Merkle Tree Hash (root digest) for the tree size
                     -> InclusionProof ha -- ^ InclusionProof of the target
                     -> Bool
verifyInclusionProof set leafDigest rootDigest (InclusionProof idx tsiz pps)
  -- Indices are unsigned in RFC 9162.  A negative index must be rejected
  -- explicitly: with arithmetic shifts it would alias a valid index
  -- (e.g. -1 verifies like the last leaf of a power-of-two-sized tree).
  | idx < 0 || idx >= tsiz = False
  | otherwise              = verify (idx,tsiz - 1) leafDigest pps
  where
    -- (fn, sn) are the target index and the last index of the tree, shifted
    -- right as the walk climbs one level per path element.
    verify (_,sn) r []             = sn == 0 && r == rootDigest
    -- Path elements left over after reaching the root: malformed proof.
    verify (_,0)  _ _              = False
    verify fsn@(fn,sn) r (p:ps)
      -- p is a left sibling.  Levels where the target is the lone rightmost
      -- node have no sibling and are skipped by 'untilSet'.
      | fn `testBit` 0 || fn == sn = let r' = hash2 set p r
                                         fsn' = shiftR1 $ untilSet fsn
                                     in verify fsn' r' ps
      -- p is a right sibling.
      | otherwise                  = let r' = hash2 set r p
                                         fsn' = shiftR1 fsn
                                     in verify fsn' r' ps

----------------------------------------------------------------

-- | The type for consistency proof.
data ConsistencyProof ha = ConsistencyProof {
    -- | The first (older) hash tree size; not larger than the second one.
    firstTreeSize :: !TreeSize
    -- | The second (newer) hash tree size.
  , secondTreeSize :: !TreeSize
    -- | A list of digests for consistency.  When the first size is 0 this
    --   is the single root digest of the second tree.
  , consistency :: ![Digest ha]
  } deriving (Eq, Show)

-- | The default value for 'ConsistencyProof' just to create a new value:
--   from size 1 to size 2 with an empty path.
defaultConsistencyProof :: ConsistencyProof ha
defaultConsistencyProof = ConsistencyProof 1 2 []

-- | Generating 'ConsistencyProof' between two tree sizes at the server side.
--   Returns 'Nothing' if either size is negative, if the first size is
--   larger than the second, or if no tree of a requested size exists.
--   For a first size of 0 the path is the single root digest of the second
--   tree.  Otherwise the path mirrors the SUBPROOF definition of RFC 6962,
--   section 2.1.2.
generateConsistencyProof :: TreeSize -> TreeSize -> MerkleHashTrees inp ha -> Maybe (ConsistencyProof ha)
generateConsistencyProof m n (MerkleHashTrees _ _ htdb _)
  | m < 0 || n < 0 = Nothing
  | m > n          = Nothing
  | m == 0         = do
      htn <- IntMap.lookup n htdb
      return $ ConsistencyProof m n [value htn]
  | otherwise = do
      htm <- IntMap.lookup m htdb
      htn <- IntMap.lookup n htdb
      let digests = prove htm htn True
      return $ ConsistencyProof m n digests
  where
    -- prove htm htn flag: htm is the part of the old tree that starts at the
    -- same leaf as htn.  flag stays True while htm is still the whole old
    -- tree, whose root digest the verifier already has and need not be sent.
    -- Same leaf range: both subtrees are identical (when this guard fails,
    -- matching falls through to the next equations).
    prove htm htn flag
      | idxl htm == idxl htn && idxr htm == idxr htn
                   = if flag then [] else [value htm]
    -- A single old leaf lies inside htn's left subtree.
    prove htm@(Leaf _ _ _) (Node _ _ _ ln rn) flag
                   = prove htm ln flag ++ [value rn]
    prove htm@(Node _ midxl midxr lm rm) (Node _ nidxl nidxr ln rn) flag
      -- The old part fits in htn's left child, which covers k leaves.
      | sizm <= k  = prove htm ln flag ++ [value rn]
      -- Otherwise htm's left child covers the same k leaves as htn's left
      -- child; continue in the right subtrees.
      | otherwise  = prove rm rn False ++ [value lm]
      where
        sizm = midxr - midxl + 1
        sizn = nidxr - nidxl + 1
        k = maxPowerOf2 (sizn - 1) -- e.g. if 8, take 4.
    -- Not expected for trees built by 'add' with 0 < m <= n.
    prove _ _ _    = error "generateConsistencyProof:prove"

-- | Verifying 'ConsistencyProof' at the client side.
--
--   This follows the consistency proof verification algorithm of RFC 9162,
--   section 2.1.4.2.  Malformed proofs (a negative first size, a first size
--   larger than the second, an empty or over-long path, or a path that does
--   not cover every level of the second tree) yield 'False' rather than an
--   exception.  A first size of 0 is accepted with a single-element path
--   holding the second root digest, as produced by 'generateConsistencyProof'.
--
-- >>> let mht0 = fromList defaultSettings ["0","1","2","3"]
-- >>> let (m, digestM) = info mht0
-- >>> let mht1 = add "6" $ add "5" $ add "4" mht0
-- >>> let (n, digestN) = info mht1
-- >>> let Just proof = generateConsistencyProof m n mht1
-- >>> verifyConsistencyProof defaultSettings digestM digestN proof
-- True
verifyConsistencyProof :: (ByteArrayAccess inp, HashAlgorithm ha)
                       => Settings inp ha
                       -> Digest ha -- start
                       -> Digest ha -- end
                       -> ConsistencyProof ha
                       -> Bool
verifyConsistencyProof set firstHash secondHash (ConsistencyProof first second path)
  | first < 0 || first > second = False
  | first == 0      = case path of
      [c] -> secondHash == c
      _   -> False
  | first == second = null path && firstHash == secondHash
  -- RFC 9162 step 1: an empty path is rejected before firstHash is prepended.
  | null path       = False
  | otherwise       = case path' of
      []   -> False
      -- The first path element seeds both the first and second roots.
      c:cs -> verify (untilNotSet (first - 1, second - 1)) c c cs
  where
    -- When the first tree is complete, its root is the starting subtree and
    -- is not sent in the proof.
    path'
      | isPowerOf2 first = firstHash : path
      | otherwise        = path
    -- Both roots must be reproduced and the walk must reach the top of the
    -- second tree (sn == 0); otherwise the proof has too few levels.
    verify (_,sn) fr sr [] = sn == 0 && fr == firstHash && sr == secondHash
    -- More path elements than levels: malformed, so reject instead of crashing.
    verify (_,0) _ _ _    = False
    verify fsn@(fn,sn) fr sr (c:cs)
      | fn `testBit` 0 || fn == sn = let fr' = hash2 set c fr
                                         sr' = hash2 set c sr
                                         fsn'
                                          | not (fn `testBit` 0) = untilSet fsn
                                          | otherwise           = fsn
                                         fsn'' = shiftR1 fsn'
                                     in verify fsn'' fr' sr' cs
      | otherwise = let sr' = hash2 set sr c
                        fsn' = shiftR1 fsn
                    in verify fsn' fr sr' cs

----------------------------------------------------------------

-- | The number of bits needed to represent a non-negative 'Int'
--   (the position of its highest set bit plus one); 0 for 0.
width :: Int -> Int
width x =
    let totalBits    = finiteBitSize x
        leadingZeros = countLeadingZeros x
    in totalBits - leadingZeros

-- | Whether the argument is an exact power of two (1, 2, 4, 8, ...).
--   Zero and negative numbers are not powers of two.  The bit trick alone
--   would report True for 0 and for 'minBound'.
isPowerOf2 :: Int -> Bool
isPowerOf2 n = n > 0 && (n .&. (n - 1)) == 0

-- | The largest power of two not greater than @n@ (e.g. 7 -> 4, 8 -> 8).
--   Intended for positive @n@.
maxPowerOf2 :: Int -> Int
maxPowerOf2 n = 2 ^ highestBit n
  where
    -- Zero-based position of the highest set bit.
    highestBit x = finiteBitSize x - countLeadingZeros x - 1

-- | Shifting both components of the pair right by one bit.
shiftR1 :: (Int,Int) -> (Int,Int)
shiftR1 (fn, sn) = (unsafeShiftR fn 1, unsafeShiftR sn 1)

-- | Shifting both components right, one bit at a time, until the first
--   component has its lowest bit clear.  Does not terminate when the first
--   component is -1, since the arithmetic shift keeps it at -1.
untilNotSet :: (Int,Int) -> (Int,Int)
untilNotSet = until lowBitClear halve
  where
    lowBitClear (fn, _) = not (testBit fn 0)
    halve (fn, sn) = (unsafeShiftR fn 1, unsafeShiftR sn 1)

-- | Shifting both components right, one bit at a time, until the first
--   component is 0 or has its lowest bit set.
untilSet :: (Int,Int) -> (Int,Int)
untilSet = until done halve
  where
    done (fn, _) = fn == 0 || testBit fn 0
    halve (fn, sn) = (unsafeShiftR fn 1, unsafeShiftR sn 1)