{-# LANGUAGE DeriveGeneric #-}

-- | A Merkle hash trie over sets of digestible elements.
--
-- Inner nodes built by 'mkTrie' have 'degree' children; an element is routed
-- to a child by 'groupOf', which salts the element's digest with the current
-- depth. Every node carries a hash combining the hashes below it.
module Sync.MerkleTree.Trie where

import Prelude hiding (lookup)

import Control.Monad
import Control.Arrow hiding (arr, loop)
import Crypto.Hash
import Data.Array.IArray
import Data.Byteable
import Data.Set(Set)
import GHC.Generics
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
import qualified Data.List as L
import qualified Data.Bytes.Serial as SE
import qualified Data.Set as S
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Test.HUnit as H

-- | Raw digest bytes. 'show' renders them base16-encoded.
data Hash = Hash { unHash :: !BS.ByteString }
    deriving (Eq, Generic)

instance Show Hash where
    showsPrec _ x = showString (T.unpack (TE.decodeUtf8 (B16.encode (unHash x))))

instance SE.Serial Hash

-- | Abstract Merkle hash trie: each node stores the hash of its subtree.
data Trie a = Trie
    { t_hash :: !Hash
    , t_node :: !(TrieNode a)
    } deriving (Show, Eq)

-- | Either an inner node with an array of children, or a leaf ("leave")
-- holding a set of elements.
data TrieNode a
    = Node !(Array Int (Trie a))
    | Leave !(Set a)
    deriving (Show, Eq)

-- | Kind of a trie node, as reported in a 'Fingerprint'.
data NodeType = NodeType | LeaveType
    deriving (Eq, Generic)

instance SE.Serial NodeType

-- | Location of a node in the Merkle hash trie.
data TrieLocation = TrieLocation
    { tl_level :: !Int -- ^ Depth of the node; must be nonnegative.
    , tl_index :: !Int -- ^ Must be nonnegative and smaller than @degree ^ tl_level@.
    } deriving (Generic)

instance SE.Serial TrieLocation

-- | Number of children of every inner node built by 'mkTrie'.
degree :: Int
degree = 64

-- | Elements that can be stored in a trie. The digest feeds both the leaf
-- hash ('mkLeave') and the child selection ('groupOf').
class HasDigest a where
    digest :: a -> Digest MD5

-- | Fingerprint of a Merkle hash trie node.
--
-- While synchronizing, we assume that the trees below two nodes are identical
-- if their fingerprints are identical.
data Fingerprint = Fingerprint
    { f_hash :: !Hash
    , f_nodeType :: !NodeType
    } deriving (Eq, Generic)

instance SE.Serial Fingerprint

-- | The hash and node kind of the given trie's root.
toFingerprint :: Trie a -> Fingerprint
toFingerprint (Trie h node) = Fingerprint h nodeType
  where
    nodeType = case node of
        Node _ -> NodeType
        Leave _ -> LeaveType

-- | Creates a Merkle hash trie for a list of elements.
--
-- The 'Int' argument is the depth of the trie being built (callers start at
-- 0); it selects the hash function 'groupOf' uses at each level. Elements are
-- deduplicated first (as leaves only ever store a 'Set'). Without this, a
-- list containing 'degree' or more equal elements would never shrink below
-- 'degree' and recursion would not terminate.
mkTrie :: (Ord a, HasDigest a) => Int -> [a] -> Trie a
mkTrie level0 = build level0 . S.fromList
  where
    build i s
        | S.size s < degree = mkLeave (S.toAscList s)
        | otherwise =
            mkNode $ fmap (build (i + 1)) $
                accumArray (flip S.insert) S.empty (0, degree - 1) $
                    map (groupOf i &&& id) (S.toList s)

-- | An inner node whose hash combines its children's hashes in index order.
mkNode :: (Array Int (Trie a)) -> Trie a
mkNode arr = Trie
    { t_hash = combineHash $ map t_hash $ elems arr
    , t_node = Node arr
    }

-- | MD5 of a strict 'BS.ByteString'.
hashMD5 :: BS.ByteString -> Digest MD5
hashMD5 = hash

-- | MD5 of the concatenation of the given hashes.
combineHash :: [Hash] -> Hash
combineHash = Hash . toBytes . hashMD5 . BS.concat . map unHash

-- | The function @groupOf i@ returns a value between 0 and @degree - 1@ for a
-- digest, with the property that @groupOf@ forms an approximate universal hash
-- family.
--
-- Only the first byte of the salted MD5 is used, so at most 256 distinct
-- groups are reachable.
groupOf :: (HasDigest a) => Int -> a -> Int
groupOf i x = fromIntegral firstByte `mod` degree
  where
    -- Salt the element's digest with the level (truncated to a single byte)
    -- so that each level of the trie splits elements differently.
    salted = hashMD5 $ BS.concat [BS.pack [fromIntegral i], toBytes $ digest x]
    firstByte = case BS.uncons (toBytes salted) of
        Just (b, _) -> b
        Nothing -> error "Sync.MerkleTree.Trie.groupOf: impossible, empty MD5 digest"

-- | A leaf holding the given elements (deduplicated).
--
-- The hash is computed over the set's ascending elements, so it depends only
-- on the stored 'Set' and not on duplicates or order in the input list.
mkLeave :: (HasDigest a, Ord a) => [a] -> Trie a
mkLeave ls = Trie
    { t_hash = combineHash $ map (Hash . toBytes . digest) $ S.toAscList s
    , t_node = Leave s
    }
  where
    s = S.fromList ls

-- | Finds the subtree at the given location.
--
-- The index is read as base-'degree' digits, most significant digit first:
-- the top digit picks the root's child, the remainder is resolved in that
-- child. Fails (via 'fail') for negative or out-of-range locations, and when
-- the path reaches a leaf before the requested level.
--
-- NOTE(review): @fail@ is used under a plain 'Monad' constraint. This relies
-- on pre-8.8 GHC, where @fail@ was a 'Monad' method; newer GHC would require
-- @MonadFail m@. Left unchanged to avoid breaking callers that instantiate
-- @m@ to a monad without a MonadFail instance — confirm the targeted GHC.
-- NOTE(review): @degree ^ l@ overflows 'Int' for large @l@ (l >= 11 with a
-- 64-bit Int, where it wraps to 0), so such locations are rejected as
-- "illegal index pair".
lookup :: (Monad m) => Trie a -> TrieLocation -> m (Trie a)
lookup trie (TrieLocation { tl_level = l, tl_index = i })
    | l < 0 || i < 0 || i >= degree ^ l = fail "illegal index pair"
    | l > 0
    , (g, i') <- i `quotRem` (degree ^ (l - 1))
    , Node arr <- t_node trie
    = lookup (arr ! g) (TrieLocation { tl_level = l - 1, tl_index = i' })
    | l == 0 = return trie
    | otherwise = fail "index pair to deep"

-- | Fingerprint of the subtree at the given location.
queryHash :: (Monad m) => Trie a -> TrieLocation -> m Fingerprint
queryHash trie = liftM toFingerprint . lookup trie

-- | All elements stored below the given location.
querySet :: (Ord a, Monad m) => Trie a -> TrieLocation -> m (Set a)
querySet trie = liftM getAll . lookup trie

-- | All elements stored in a trie.
getAll :: (Ord a) => Trie a -> Set a
getAll (Trie _ node) = case node of
    Node arr -> S.unions $ map getAll $ elems arr
    Leave s -> s

-- | Location of the root node.
rootLocation :: TrieLocation
rootLocation = TrieLocation
    { tl_level = 0
    , tl_index = 0
    }

-- | Pairs each child of an inner node with its location, given the parent's
-- location. Inverse of the index decomposition in 'lookup'.
-- Assumes @arr@ has indices @0 .. degree - 1@ (as built by 'mkTrie').
expand :: TrieLocation -> (Array Int (Trie a)) -> [(TrieLocation, Trie a)]
expand loc arr = map go [0..(degree - 1)]
  where
    go i =
        ( TrieLocation
            { tl_level = tl_level loc + 1
            , tl_index = degree * tl_index loc + i
            }
        , arr ! i
        )

-- | Test-only element type whose digest is the MD5 of its UTF-8 text.
newtype TestDigest = TestDigest { unTestDigest :: T.Text }
    deriving (Eq, Ord, Show)

instance HasDigest TestDigest where
    digest = hashMD5 . TE.encodeUtf8 . unTestDigest

tests :: H.Test
tests = H.TestList
    [ H.TestLabel "trieLookup" $
        Nothing H.~=? lookup t (TrieLocation { tl_level = -1, tl_index = 0 })
    , H.TestLabel "trieLookupTooDeep" $
        Nothing H.~=? lookup t (TrieLocation { tl_level = 4, tl_index = 0 })
    , H.TestLabel "trieLookupRoot" $
        Just t H.~=? lookup t rootLocation
    , H.TestLabel "trieGetAll" $
        S.fromList xs H.~=? getAll t
    -- Duplicates must neither hang mkTrie nor change the resulting hash.
    , H.TestLabel "trieDuplicates" $
        mkTrie 0 [dup] H.~=? mkTrie 0 (replicate (2 * degree) dup)
    ]
  where
    xs = map (TestDigest . T.pack . show) [0..(13 + 2 * degree * degree)]
    t = mkTrie 0 xs
    dup = TestDigest (T.pack "duplicate")