-- | Hash tables, implemented as a structure similar to @Map hash (Map key value)@.
--
-- What this data structure can also give you is a unique value (a @(hash,Int)@ pair)
-- for each key, even during building the table: It is guaranteed to be unique 
-- in the past and future lifetime of a single hashtable (that is, one realization 
-- of the world-line), among all the keys appearing in that history.
--
-- Set operations (union, intersection) clearly break this principle; this is
-- resolved by declaring these operations to be /left-biased/, in the sense that
-- they retain the unique values of the left table (so @union t1 t2@ belongs
-- to @t1@'s world-line, but not to @t2@'s one).
--
-- If a key is first removed then added back again, it will get a new value.
--
-- To be Haskell98 compatible (no multi-param type classes), when constructing 
-- a new hash table, we have to supply the function computing (or just fetching, if
-- it is cached) the hash value. This function is then stored in the data type.
--

{-# LANGUAGE CPP #-}
module Data.Generics.Fixplate.Util.Hash.Table
  ( HashTable , Bucket(..) , Leaf(..)
  , getHashValue , unHashTable
    -- * Construction and deconstruction
  , empty , singleton
  , fromList , toList 
  , null 
  , bag
    -- * Membership
  , lookup , member 
    -- * Insertion / deletion
  , insert , insertWith
  , delete
    -- * Union
  , union , unionWith         
  , unionsWith , unionsWith'
    -- * Intersection
  , intersection, intersectionWith 
  , intersectionsWith , intersectionsWith'
    -- * Difference
  , difference , differenceWith
    -- * Unique indices
  , getUniqueIndex
  , keysWith
  , mapWithUniqueIndices
#ifdef WITH_QUICKCHECK
    -- * Tests
  , runtests_HashTable  
  , prop_insert       , prop_delete
  , prop_insertDelete , prop_deleteInsert
  , prop_insertInsert , prop_deleteDelete
  , prop_fromListToList
  , prop_intersection , prop_intersectionWith
  , prop_union        , prop_unionWith
  , prop_difference   , prop_differenceWith
  , prop_uniqueValues
#endif
  ) 
  where

--------------------------------------------------------------------------------

import Prelude hiding ( lookup , null )

import Data.List ( foldl' )

import qualified Data.Map as Map ; import Data.Map (Map) 
-- import qualified Data.Set as Set ; import Data.Set (Set) 

#ifdef WITH_QUICKCHECK
import Test.QuickCheck
import Test.QuickCheck.Modifiers
import Data.Generics.Fixplate.Misc
import Data.List ( sort , group , nubBy , nub , (\\) , foldl' , scanl )
import Control.Monad
import Control.Applicative ( (<$>) )
import Debug.Trace
#endif

--------------------------------------------------------------------------------
-- helper functions

-- | Strict insert-or-combine for a plain 'Map'.
--
-- If the key is absent, the entry becomes @create arg@; if it is present with
-- old value @old@, the entry becomes @combine arg old@. The new argument, the
-- old value and the resulting value are each forced to weak head normal form.
mapInsertWith :: Ord k => (a -> v) -> (a -> v -> v) -> k -> a -> Map k v ->  Map k v
mapInsertWith create combine key arg = arg `seq` Map.alter update key
  where
    update slot = case slot of
      Nothing  -> forced (create arg)
      Just old -> old `seq` forced (combine arg old)

    -- wrap in 'Just' only after the payload has been evaluated
    forced val = val `seq` Just val

{-
mapIsSingleton :: Map k v -> Maybe (k,v)
mapIsSingleton table = if Map.size table == 1 
  then let [(k,v)] = Map.toList table in Just (k,v)
  else Nothing

mapIsSingleton_ :: Map k v -> Maybe v
mapIsSingleton_ table = if Map.size table == 1 
  then let [(_,v)] = Map.toList table in Just v
  else Nothing
-}

--------------------------------------------------------------------------------
-- buckets

-- | An entry stored inside a 'Bucket': the index assigned to the key within
-- that bucket (strict, unpacked) together with the (lazy) value.
data Leaf v     = Leaf   {-# UNPACK #-} !Int v                    -- the index of the key, plus a value
-- | All entries sharing one hash value: a counter holding the next free index
-- for this bucket, plus the key-to-'Leaf' map. Both fields are strict.
data Bucket k v = Bucket {-# UNPACK #-} !Int !(Map k (Leaf v))    -- the next free index, plus the elements in the bucket

-- | Extracts the value stored in a leaf, discarding its index.
fromLeaf :: Leaf v -> v
fromLeaf leaf = case leaf of
  Leaf _ payload -> payload

-- | A bucket without entries whose next free index is 0.
emptyBucket :: Bucket k v
emptyBucket = Bucket 0 noEntries
  where
    noEntries = Map.empty

-- | A bucket holding exactly one entry. The entry gets index 0, so the next
-- free index of the bucket is 1.
bucketSingleton :: k -> v -> Bucket k v
bucketSingleton key val = Bucket 1 entries
  where
    entries = Map.singleton key (Leaf 0 val)

-- | Inserts a key into a bucket, replacing the value if the key is already
-- there. The combining function receives the /new/ value first and the old
-- one second, hence it must select its first argument.
bucketInsert :: Ord k => k -> v -> Bucket k v -> Bucket k v
bucketInsert = bucketInsertWith id keepNew
  where
    keepNew new _old = new

-- | Insert-or-combine inside a single bucket.
--
-- A key that is not yet present is stored as @create arg@ and receives the
-- bucket's current next free index. A key that is already present keeps its
-- index and its value becomes @combine arg old@. Note that the bucket's
-- counter is advanced by one in both cases.
--
-- The argument and the old value are forced; the resulting leaf is evaluated
-- to weak head normal form before being stored.
bucketInsertWith :: Ord k => (a -> v) -> (a -> v -> v) -> k -> a -> Bucket k v -> Bucket k v
bucketInsertWith create combine key arg (Bucket next entries) =
  arg `seq` Bucket (next + 1) (Map.alter upsert key entries)
  where
    upsert slot = case slot of
      Nothing             -> stored (Leaf next (create arg))
      Just (Leaf idx old) -> old `seq` stored (Leaf idx (combine arg old))

    -- evaluate the leaf constructor before wrapping it
    stored leaf = leaf `seq` Just leaf

{-
bucketIsSingleton :: Bucket k v -> Maybe (k,v)
bucketIsSingleton (Bucket _ table) = if Map.size table == 1 
  then let [(k,Leaf _ v)] = Map.toList table in Just (k,v)
  else Nothing

bucketIsSingleton_ :: Bucket k v -> Maybe v
bucketIsSingleton_ (Bucket _ table) = if Map.size table == 1 
  then let [Leaf _ v] = Map.elems table in Just v
  else Nothing
-}

--------------------------------------------------------------------------------

-- | A hash table: the hash function supplied at construction time, together
-- with a map from hash values to the buckets of keys sharing that hash.
data HashTable hash k v = HashTable 
  { getHashValue :: k -> hash               -- ^ the function computing the hash of a key
  , unHashTable  :: Map hash (Bucket k v)   -- ^ the buckets, indexed by hash value
  }

-- | The empty table using the given hash function.
empty :: (Ord hash, Ord k) => (k -> hash) -> HashTable hash k v
empty gethash = HashTable
  { getHashValue = gethash
  , unHashTable  = Map.empty
  }

-- | A table containing a single key-value pair, using the given hash function.
singleton :: (Ord hash, Ord k) => (k -> hash) -> k -> v -> HashTable hash k v
singleton gethash k v = HashTable gethash buckets
  where
    buckets = Map.singleton (gethash k) (bucketSingleton k v)

-- | Builds a table by inserting the pairs one by one, from left to right, into
-- an empty table; for duplicate keys the later value therefore wins.
fromList :: (Ord hash, Ord k) => (k -> hash) -> [(k,v)] -> HashTable hash k v
fromList gethash pairs = foldl' add (empty gethash) pairs
  where
    add table (k, v) = insert k v table

-- | All key-value pairs of the table.
--
-- Note that the returned list is ordered by hash (and by key only within a
-- single hash), /not/ by keys like 'Data.Map'!
toList :: Ord k => HashTable hash k v -> [(k,v)]
toList (HashTable _ table) = concatMap entries (Map.elems table)
  where
    entries (Bucket _ sub) = [ (k, v) | (k, Leaf _ v) <- Map.toList sub ]

-- | 'True' iff the table has no key-value pairs. Buckets that are present but
-- contain no entries do not count.
null :: (Ord hash, Ord k) => HashTable hash k v -> Bool
null = noElements . toList
  where
    noElements []    = True
    noElements (_:_) = False

-- | Applies the user function to every key together with its associated
-- unique value (the hash and the index within the bucket), in hash order.
keysWith :: Ord k => (k -> hash -> Int -> a) -> HashTable hash k v -> [a]
keysWith f (HashTable _ table) = concatMap perBucket (Map.toList table)
  where
    perBucket (hash, Bucket _ sub) = [ f k hash j | (k, Leaf j _) <- Map.toList sub ]

--------------------------------------------------------------------------------

-- | Looks up the value stored for a key: first the bucket for the key's hash
-- is located, then the key is searched inside that bucket.
lookup :: (Ord hash, Ord k) => k -> HashTable hash k v -> Maybe v
lookup key (HashTable gethash table) =
  Map.lookup (gethash key) table >>= inBucket
  where
    inBucket (Bucket _ sub) = Map.lookup key sub >>= \(Leaf _ v) -> Just v

-- | Looks up the unique index of a key, handing the @(hash,Int)@ pair to the
-- user-supplied function. Returns 'Nothing' if the key is not in the table.
--
-- If the user-supplied function is /injective/, then the result is guaranteed
-- to be uniquely associated to the given key in the past and future history of
-- this table (but of course not unique among different future histories).
--
getUniqueIndex :: (Ord hash, Ord k) => (hash -> Int -> a) -> k -> HashTable hash k v -> Maybe a
getUniqueIndex f key (HashTable gethash table) =
  Map.lookup h table >>= inBucket
  where
    h = gethash key
    inBucket (Bucket _ sub) = Map.lookup key sub >>= \(Leaf j _) -> Just (f h j)

-- | 'True' iff 'lookup' finds a value for the key.
member :: (Ord hash, Ord k) => k -> HashTable hash k v -> Bool
member key = maybe False (const True) . lookup key

--------------------------------------------------------------------------------

-- | Inserts a key-value pair. A key that is already present keeps its unique
-- index and gets the new value; a new key gets a fresh index in its bucket.
-- The value is forced to weak head normal form.
insert :: (Ord hash, Ord k) => k -> v -> HashTable hash k v -> HashTable hash k v
insert k v (HashTable gethash table) =
  HashTable gethash (mapInsertWith freshBucket intoBucket (gethash k) v table)
  where
    freshBucket = bucketSingleton k   -- no bucket for this hash yet
    intoBucket  = bucketInsert    k   -- the hash already has a bucket

-- | Insert-or-combine: a key that is not yet present is stored with value
-- @ff x@; for a key that is already present with value @old@, the value becomes
-- @gg x old@ and the key keeps its unique index.
insertWith :: (Ord hash, Ord k) => (a -> v) -> (a -> v -> v) -> k -> a -> HashTable hash k v -> HashTable hash k v
insertWith ff gg k x (HashTable gethash table) =
  HashTable gethash (mapInsertWith freshBucket intoBucket (gethash k) x table)
  where
    freshBucket = bucketSingleton k . ff
    intoBucket  = bucketInsertWith ff gg k

-- | Removes a key. The bucket itself (possibly without entries) and its next
-- free index are kept, so a key inserted later cannot reuse an old index.
delete :: (Ord hash, Ord k) => k -> HashTable hash k v -> HashTable hash k v
delete k (HashTable gethash table) =
  HashTable gethash (Map.alter (>>= prune) (gethash k) table)
  where
    prune (Bucket n sub) = Just (Bucket n (Map.delete k sub))

--------------------------------------------------------------------------------
-- union

-- | Left-biased union: for keys present in both tables the left value is kept.
--
-- > union == unionWith const
union :: (Ord hash, Ord k) => HashTable hash k a -> HashTable hash k a -> HashTable hash k a
union = unionWith preferLeft
  where
    preferLeft left _right = left

-- | Union with a combining function for the values of keys present in both
-- tables.
--
-- This is unsafe in the sense that the two @getHash@ functions (supplied when
-- the hash tables were created) must agree; the result uses the one of the left
-- table. The same applies for all the set operations.
--
-- It is also left-biased in the sense that the unique indices from the left
-- hashtable are retained, while, within buckets present on both sides, the
-- unique indices from the right hashtable are /changed/ (shifted by the left
-- bucket's next free index).
unionWith :: (Ord hash, Ord k) => (v -> v -> v) -> HashTable hash k v -> HashTable hash k v -> HashTable hash k v 
unionWith g (HashTable gethash left) (HashTable _ right) =
  HashTable gethash (Map.unionWith mergeBuckets left right)
  where
    mergeBuckets (Bucket nl subl) (Bucket nr subr) =
      Bucket (nl + nr) (Map.unionWith mergeLeaves subl shifted)
      where
        -- move the right-hand indices past everything the left bucket handed out
        shifted = Map.map shift subr
        shift (Leaf j y) = Leaf (nl + j) y

    mergeLeaves (Leaf i x) (Leaf _ y) = Leaf i (g x y)

-- | Union of a list of tables, folded from the left with 'unionWith'.
--
-- This is unsafe both in the sense described at 'unionWith' and also in that it
-- does not accept the empty list (there is no hash function to build an empty
-- table from); it calls 'error' in that case.
-- The result belongs to the world-line of the first table.
unionsWith :: (Ord hash, Ord k) => (v -> v -> v) -> [HashTable hash k v] -> HashTable hash k v 
unionsWith _ []       = error "HashTable/unionsWith: empty list"
unionsWith _ [single] = single
unionsWith g many     = foldl1 (unionWith g) many

-- | Like 'unionsWith', but this one accepts the empty list, thanks to the
-- extra hash function argument. The empty input creates a new world-line;
-- otherwise the result belongs to the world-line of the first table.
unionsWith' :: (Ord hash, Ord k) => (k -> hash) -> (v -> v -> v) -> [HashTable hash k v] -> HashTable hash k v 
unionsWith' gethash _ []       = empty gethash
unionsWith' _       _ [single] = single
unionsWith' _       g many     = foldl1 (unionWith g) many

--------------------------------------------------------------------------------
-- intersection

-- | Intersection keeping the values (and unique indices) of the left table.
--
-- > intersection == intersectionWith const
intersection :: (Ord hash, Ord k) => HashTable hash k a -> HashTable hash k b -> HashTable hash k a
intersection = intersectionWith preferLeft
  where
    preferLeft left _right = left
 
-- | Intersection with a combining function; the unique indices of the left
-- table are retained.
--
-- NOTE the `Map.union` and `Map.difference` here!
-- This is necessary so that the world-line property remains true: if there is
-- a hash present in the left table but not in the right table, then we have to
-- put an empty bucket in the resulting table (while retaining the next unique
-- index value). Unfortunately "Data.Map" does not have a flexible enough set
-- operation to be used here...
intersectionWith :: (Ord hash, Ord k) => (a -> b -> c) -> HashTable hash k a -> HashTable hash k b -> HashTable hash k c
intersectionWith g (HashTable gethash left) (HashTable _ right) =
  HashTable gethash (Map.union leftOnly common) {- disjoint union -}
  where
    -- hashes present on both sides: intersect the buckets
    common   = Map.intersectionWith meet left right
    -- hashes present only on the left: keep the counter, drop the entries
    leftOnly = Map.map drain (Map.difference left right)

    meet (Bucket n subl) (Bucket _ subr) = Bucket n (Map.intersectionWith pair subl subr)
    pair (Leaf i x) (Leaf _ y) = Leaf i (g x y)

    drain (Bucket n _) = Bucket n Map.empty

-- | Intersection of a list of tables, folded from the left with
-- 'intersectionWith'; a one-element list is returned unchanged.
--
-- Like 'intersectionWith', this assumes that the hash functions of all the
-- tables agree. It does not accept the empty list and calls 'error' in that
-- case; see @intersectionsWith'@ for a variant that does.
intersectionsWith :: (Ord hash, Ord k) => (v -> v -> v) -> [HashTable hash k v] -> HashTable hash k v 
intersectionsWith g tables = case tables of
  [x]    -> x
  -- the message must name this function, not 'intersectionWith'
  []     -> error "HashTable/intersectionsWith: empty list"
  xs     -> foldl1 (intersectionWith g) xs

-- | Like 'intersectionsWith', but accepts the empty list, for which it returns
-- the empty table built from the given hash function.
intersectionsWith' :: (Ord hash, Ord k) => (k -> hash) -> (v -> v -> v) -> [HashTable hash k v] -> HashTable hash k v 
intersectionsWith' gethash _ []       = empty gethash
intersectionsWith' _       _ [single] = single
intersectionsWith' _       g many     = foldl1 (intersectionWith g) many

--------------------------------------------------------------------------------
-- difference

-- | Removes from the left table every key that is present in the right table.
difference :: (Ord hash, Ord k) => HashTable hash k a -> HashTable hash k b -> HashTable hash k a
difference = differenceWith dropCommon
  where
    dropCommon _ _ = Nothing

-- | Difference with a decision function: for a key present in both tables the
-- function is applied to the two values; 'Nothing' removes the key, @Just z@
-- keeps it with value @z@ and its original unique index. Buckets of the left
-- table (and their next free indices) are always retained.
differenceWith :: (Ord hash, Ord k) => (a -> b -> Maybe a) -> HashTable hash k a -> HashTable hash k b -> HashTable hash k a
differenceWith g (HashTable gethash left) (HashTable _ right) =
  HashTable gethash (Map.differenceWith subtractBucket left right)
  where
    -- always 'Just': the bucket must survive to preserve its counter
    subtractBucket (Bucket n subl) (Bucket _ subr) =
      Just (Bucket n (Map.differenceWith subtractLeaf subl subr))

    subtractLeaf (Leaf i x) (Leaf _ y) = fmap (Leaf i) (g x y)
  
--------------------------------------------------------------------------------

-- | Creates a multi-set from a list: each key is mapped to the number of times
-- it occurs in the list.
bag :: (Ord hash, Ord k) => (k -> hash) -> [k] -> HashTable hash k Int
bag gethash keys = foldl' count (empty gethash) keys
  where
    count table k = insertWith id (+) k 1 table

--------------------------------------------------------------------------------

-- | Maps over the values; the user function also receives the unique index
-- (hash and index within the bucket) of the corresponding key. Keys, indices
-- and bucket counters are left unchanged.
mapWithUniqueIndices :: (Ord hash, Ord k) => (hash -> Int -> a -> b) -> HashTable hash k a -> HashTable hash k b
mapWithUniqueIndices user (HashTable gethash table) =
  HashTable gethash (Map.mapWithKey relabel table)
  where
    relabel hash (Bucket n sub) = Bucket n (Map.map (apply hash) sub)
    apply hash (Leaf j x) = Leaf j (user hash j x)

--------------------------------------------------------------------------------
#ifdef WITH_QUICKCHECK
-- * tests

-- | Runs every property of this module with 'quickCheck', in a fixed order;
-- the unique-values property is run five times.
runtests_HashTable :: IO ()
runtests_HashTable = sequence_
  [ quickCheck prop_insert
  , quickCheck prop_delete
  , quickCheck prop_insertDelete
  , quickCheck prop_deleteInsert
  , quickCheck prop_insertInsert
  , quickCheck prop_deleteDelete
  , quickCheck prop_fromListToList
  , quickCheck prop_intersection
  , quickCheck prop_intersectionWith
  , quickCheck prop_union
  , quickCheck prop_unionWith
  , quickCheck prop_difference
  , quickCheck prop_differenceWith
  , replicateM_ 5 (quickCheck prop_uniqueValues)
  ]
--  quickCheck prop_
--  quickCheck prop_

-------------------------

-- | Traces the shown first argument (framed by dashes) and returns the second
-- argument unchanged.
debug :: Show a => a -> b -> b
debug label result = trace (concat [ "-- " , show label , " --" ]) result

-- | Key type used by the tests (a wrapped 'Int').
newtype Key = Key Int deriving (Eq,Ord,Show)

-- | Shows only the key-value pairs (in 'toList' order), not the hash function
-- or the unique indices.
instance (Ord k, Ord hash, Show k, Show v) => Show (HashTable hash k v) where
  show table = concat [ "HashTable<< " , show (toList table) , " >>" ]

-- | Keys are drawn from the range 0..255, so collisions between random keys
-- are frequent.
instance Arbitrary Key where
  arbitrary = choose (0, 255) >>= \n -> return (Key n)

-- | Hash type used by the tests (a wrapped 'Int'); see 'calcHash'.
newtype Hash = Hash Int deriving (Eq,Ord,Show)

-- | The test hash function: the key modulo 17, so many keys share a hash.
calcHash :: Key -> Hash 
calcHash (Key k) = Hash (k `mod` 17)

-- | Wrapper around a test hash table, used to attach an 'Arbitrary' instance.
newtype Table v = Table (HashTable Hash Key v) deriving Show 

-- | A random table is built with 'fromList' (and 'calcHash') from a random
-- list of key-value pairs.
instance Arbitrary v => Arbitrary (Table v) where
  arbitrary = arbitrary >>= \pairs -> return (Table (fromList calcHash pairs))

-- | Wrapper around a test hash table whose 'Arbitrary' instance builds the
-- table from a non-empty list of pairs.
newtype NonEmptyTable v = NonEmptyTable (HashTable Hash Key v) deriving Show 

-- | A random table built with 'fromList' (and 'calcHash') from a random
-- /non-empty/ list of key-value pairs.
instance Arbitrary v => Arbitrary (NonEmptyTable v) where
  arbitrary = arbitrary >>= \(NonEmpty pairs) ->
    return (NonEmptyTable (fromList calcHash pairs))

-- | A test hash table together with a key-value pair; the 'Arbitrary' instance
-- picks the pair from the table's own contents.
data Pointed v = Pointed (HashTable Hash Key v) (Key,v) deriving Show

-- | Generates a non-empty table and selects one of its key-value pairs at a
-- random position of its 'toList'.
instance Arbitrary v => Arbitrary (Pointed v) where
  arbitrary = do
    NonEmptyTable table <- arbitrary
    let entries = toList table
    -- the table was built from a non-empty list, so the range is never empty
    idx <- choose (0, length entries - 1)
    return (Pointed table (entries !! idx))

-- | The key-value pairs of a test table in sorted order, for comparison with
-- the output of @Map.toList@.
sortedToList :: Ord a => HashTable Hash Key a -> [(Key,a)]
sortedToList table = sort (toList table)

-------------------------

-- | A single operation applied to a table during a random history; see 'step'
-- for how each constructor is interpreted.
data Step v
  = Insert     Key v        -- plain insertion
  | InsertWith Key v        -- insertion with a combining function
  | Delete     Key          -- deletion of a key
  | Union      (Table v)    -- union with another table (on the right)
  | Intersect  (Table v)    -- intersection with another table (on the right)
  | Difference (Table v)    -- difference with another table (on the right)
  deriving Show

-- | Random steps; the weights favour insertions and deletions over the set
-- operations.
instance Arbitrary v => Arbitrary (Step v) where
  arbitrary = frequency
    [ ( 10 , liftM2 Insert     arbitrary arbitrary )
    , (  5 , liftM2 InsertWith arbitrary arbitrary )
    , ( 10 , liftM  Delete     arbitrary )
    , (  3 , liftM  Union      arbitrary )
    , (  2 , liftM  Difference arbitrary )
    , (  1 , liftM  Intersect  arbitrary )
    ]

-- | A 'Step' whose 'Arbitrary' instance only generates 'Insert', 'InsertWith'
-- and 'Union' steps.
newtype NoDeleteStep v = NoDeleteStep (Step v)

-- | Random steps restricted to operations that never remove a key.
instance Arbitrary v => Arbitrary (NoDeleteStep v) where
  arbitrary = fmap NoDeleteStep $ frequency
    [ ( 10 , liftM2 Insert     arbitrary arbitrary )
    , (  5 , liftM2 InsertWith arbitrary arbitrary )
    , (  3 , liftM  Union      arbitrary )
    ]

-- | Applies a single 'Step' to a table. The function argument is the combining
-- function used for 'InsertWith'; for the set operations the current table is
-- always the left operand, so the result stays on its world-line.
step :: (v -> v -> v) -> Step v -> HashTable Hash Key v -> HashTable Hash Key v 
step _ (Insert     k v)       old = insert          k v old
step f (InsertWith k v)       old = insertWith id f k v old
step _ (Delete     k)         old = delete          k   old
step _ (Union      (Table t)) old = union        old t
step _ (Intersect  (Table t)) old = intersection old t
step _ (Difference (Table t)) old = difference   old t

-- | A sequence of steps, applied from left to right by 'runHistory'.
type History v = [Step v]

-- | Applies the steps one after the other, starting from the initial table,
-- and returns every intermediate table (the initial one included).
runHistory :: (v -> v -> v) -> History v -> HashTable Hash Key v -> [HashTable Hash Key v]
runHistory f steps ini = scanl advance ini steps
  where
    advance table s = step f s table

-- | A unique index as reported by 'keysWith': the hash plus the index within
-- the bucket.
data U = U Hash Int deriving (Eq,Ord,Show)

-------------------------

-- | After inserting a key, looking it up yields the inserted value.
prop_insert :: Key -> Char -> Table Char -> Bool
prop_insert k v (Table table) = found == Just v
  where
    found = lookup k (insert k v table)

-- | After deleting a key that was present, looking it up yields nothing.
prop_delete :: Pointed Char -> Bool
prop_delete (Pointed table (k,_)) = not (member k (delete k table))

-- | Inserting the same pair twice gives the same contents as inserting it once.
prop_insertInsert :: Key -> Char -> Table Char -> Bool
prop_insertInsert k v (Table table) = toList once == toList twice
  where
    once  = insert k v table
    twice = insert k v once

-- | Deleting the same key twice gives the same contents as deleting it once.
prop_deleteDelete :: Pointed Char -> Bool
prop_deleteDelete (Pointed table (k,_)) = toList once == toList twice
  where
    once  = delete k table
    twice = delete k once

-- | A key that is inserted and then deleted is no longer found.
prop_insertDelete :: Key -> Char -> Table Char -> Bool
prop_insertDelete k v (Table table) = not (member k (delete k (insert k v table)))

-- | A key that is deleted and then re-inserted is found with its value.
prop_deleteInsert :: Pointed Char -> Bool
prop_deleteInsert (Pointed table (k,v)) = found == Just v
  where
    found = lookup k (insert k v (delete k table))

-- | 'fromList' followed by (sorted) 'toList' agrees with "Data.Map".
prop_fromListToList :: [(Key,Char)] -> Bool
prop_fromListToList xs = viaTable == viaMap
  where
    viaTable = sortedToList (fromList calcHash xs)
    viaMap   = Map.toList (Map.fromList xs)

-- | 'intersection' agrees with @Map.intersection@.
prop_intersection :: [(Key,Char)] -> [(Key,Bool)] -> Bool
prop_intersection xs ys = viaTable == viaMap
  where
    viaTable = sortedToList (intersection (fromList calcHash xs) (fromList calcHash ys))
    viaMap   = Map.toList (Map.intersection (Map.fromList xs) (Map.fromList ys))

-- | 'intersectionWith' agrees with @Map.intersectionWith@ (combining with
-- list cons).
prop_intersectionWith :: [(Key,Char)] -> [(Key,String)] -> Bool
prop_intersectionWith xs ys = viaTable == viaMap
  where
    viaTable = sortedToList (intersectionWith (:) (fromList calcHash xs) (fromList calcHash ys))
    viaMap   = Map.toList (Map.intersectionWith (:) (Map.fromList xs) (Map.fromList ys))

-- | 'union' agrees with @Map.union@.
prop_union :: [(Key,Char)] -> [(Key,Char)] -> Bool
prop_union xs ys = viaTable == viaMap
  where
    viaTable = sortedToList (union (fromList calcHash xs) (fromList calcHash ys))
    viaMap   = Map.toList (Map.union (Map.fromList xs) (Map.fromList ys))

-- | 'unionWith' agrees with @Map.unionWith@ (combining with list append).
prop_unionWith :: [(Key,String)] -> [(Key,String)] -> Bool
prop_unionWith xs ys = viaTable == viaMap
  where
    viaTable = sortedToList (unionWith (++) (fromList calcHash xs) (fromList calcHash ys))
    viaMap   = Map.toList (Map.unionWith (++) (Map.fromList xs) (Map.fromList ys))

-- | 'difference' agrees with @Map.difference@.
prop_difference :: [(Key,Char)] -> [(Key,Bool)] -> Bool
prop_difference xs ys = viaTable == viaMap
  where
    viaTable = sortedToList (difference (fromList calcHash xs) (fromList calcHash ys))
    viaMap   = Map.toList (Map.difference (Map.fromList xs) (Map.fromList ys))

-- | 'differenceWith' agrees with @Map.differenceWith@; a common key is kept
-- exactly when the right-hand value is 'True'.
prop_differenceWith :: [(Key,Char)] -> [(Key,Bool)] -> Bool
prop_differenceWith xs ys = viaTable == viaMap
  where
    viaTable = sortedToList (differenceWith keepIf (fromList calcHash xs) (fromList calcHash ys))
    viaMap   = Map.toList (Map.differenceWith keepIf (Map.fromList xs) (Map.fromList ys))

    keepIf x keep
      | keep      = Just x
      | otherwise = Nothing

-- | We try to test whether the unique values are really unique and really
-- constant during a worldline: a random history is run from a random initial
-- table, and the unique indices of all the intermediate tables are inspected.
prop_uniqueValues :: History Float -> Table Float -> Bool
prop_uniqueValues history (Table initial) = areUnique && areInjective {- && ... -}
  where
    worldline :: [HashTable Hash Key Float]
    worldline = runHistory (-) history initial

    -- for every point in time, the unique values paired with their keys
    lists :: [[(U,Key)]]
    lists = map (keysWith tag) worldline
    tag k h j = (U h j, k)

    -- at each point in time, a single value must appear only once in the table
    areUnique = all isUnique lists
    isUnique uks = sort us == sort (nub us)
      where
        us = map fst uks

    -- taking the whole worldline, it must be true that to a single unique value
    -- there is only a single key associated (the opposite is not true, since a
    -- key can be deleted then reinserted, gaining a new value)
    areInjective = all test (groupSortOn fst (concat lists))

    test :: [(U,Key)] -> Bool
    test xs = (length (groupSortOn fst xs) == 1)     -- this is redundant, but hey, we are also testing the test :)
           && (length (groupSortOn snd xs) == 1)
 

#endif
--------------------------------------------------------------------------------