-- Copyright (c) 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree.


{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE NoRebindableSyntax #-}

module Duckling.Types.Stash where

import qualified Data.IntMap.Strict as IntMap
import Data.IntMap.Strict (IntMap)
import qualified Data.HashSet as HashSet
import Data.HashSet (HashSet)
import Data.Maybe
import Prelude

import Duckling.Types

-- | A collection of 'Node's bucketed by an 'Int' key.
--
-- Each bucket is a 'HashSet', so a given node is stored at most once per key.
-- Within this module, 'fromList' uses the start of a node's 'nodeRange' as
-- the key.
newtype Stash = Stash { getSet :: IntMap (HashSet Node) }

-- | Keep only the nodes that satisfy the predicate.
--
-- The predicate is applied inside every bucket. A bucket that ends up empty
-- is kept in the map as an empty set; its key is not removed.
filter :: (Node -> Bool) -> Stash -> Stash
filter p (Stash buckets) = Stash (IntMap.map (HashSet.filter p) buckets)

-- | Flatten the stash into a list of nodes.
--
-- Buckets are visited in the order produced by 'IntMap.elems' (ascending
-- key). The order of nodes within a single bucket is whatever
-- 'HashSet.toList' yields.
toPosOrderedList :: Stash -> [Node]
toPosOrderedList (Stash buckets) =
  concatMap HashSet.toList $ IntMap.elems buckets

-- | Like 'toPosOrderedList', but only for buckets whose key is at least
-- @pos@.
--
-- The bucket stored exactly at @pos@ (if any) comes first, followed by the
-- buckets with larger keys in ascending key order.
toPosOrderedListFrom :: Stash -> Int -> [Node]
toPosOrderedListFrom (Stash buckets) pos =
  concatMap HashSet.toList $ maybeToList equal ++ IntMap.elems bigger
  where
  -- This is where we take advantage of the order: 'IntMap.splitLookup'
  -- partitions the map around @pos@, so the buckets below @pos@ are
  -- discarded without being traversed.
  (_smaller, equal, bigger) = IntMap.splitLookup pos buckets

-- | The stash with no buckets and therefore no nodes.
empty :: Stash
empty = Stash IntMap.empty

-- | Build a stash from a list of nodes.
--
-- Each node is keyed by the first field of its 'nodeRange'; nodes sharing a
-- key are merged into one bucket with 'HashSet.union', so duplicates
-- collapse.
fromList :: [Node] -> Stash
fromList ns = Stash (IntMap.fromListWith HashSet.union $ map mkKV ns)
  where
  -- Pair a node with its key and a singleton bucket containing it.
  mkKV :: Node -> (Int, HashSet Node)
  mkKV n@Node{nodeRange = Range start _} = (start, HashSet.singleton n)

-- | Merge two stashes.
--
-- Keys present in only one stash are kept as they are; for keys present in
-- both, the two buckets are combined with 'HashSet.union'.
union :: Stash -> Stash -> Stash
union (Stash set1) (Stash set2) =
  Stash (IntMap.unionWith HashSet.union set1 set2)

-- | Checks whether two stashes hold the same number of Nodes at each
-- position, i.e. they have the same keys and, for every key, buckets of the
-- same size. The contents of the buckets are not compared.
--
-- Used to detect a fixpoint, because the Stashes are only growing.
--
-- Not proud of this: the algorithm shouldn't rely on it as the termination
-- condition, it should know when it has stopped adding tokens.
sizeEqual :: Stash -> Stash -> Bool
sizeEqual (Stash set1) (Stash set2) = bucketSizes set1 == bucketSizes set2
  where
  -- Ascending (key, bucket size) pairs. The list is produced lazily, so the
  -- comparison above stops at the first key or size that differs, or as soon
  -- as one side runs out of buckets.
  bucketSizes :: IntMap (HashSet Node) -> [(Int, Int)]
  bucketSizes m = [(k, HashSet.size h) | (k, h) <- IntMap.toAscList m]

-- | 'True' when the stash has no buckets at all.
--
-- Only the outer map is inspected: a stash whose buckets are all empty sets
-- (as 'filter' can leave behind) is not reported as null.
null :: Stash -> Bool
null (Stash set) = IntMap.null set