```{-# LANGUAGE BangPatterns #-}
-- | Functions for finding /lowest common ancestors/ in binary trees in /O(1)/ time, with /O(n)/ preprocessing.
module Data.RangeMin.LCA.Binary (Index, BinTree(..), quickLCABinary, lcaBinary) where

import Data.RangeMin

import Data.RangeMin.Common.Vector.Utils

import Data.RangeMin.Common.ST
import Data.RangeMin.LCA.IndexM
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector as V

-- | A generic binary tree: either empty, or a node carrying a value of type
-- @a@ together with its left and right subtrees.
data BinTree a
  = Tip
    -- ^ The empty tree.
  | BinTree a (BinTree a) (BinTree a)
    -- ^ A node: its value, then its left subtree, then its right subtree.

-- | Peels off one layer of a 'BinTree': 'Nothing' for an empty tree,
-- otherwise the node's value together with its left and right subtrees.
unfoldBin :: BinTree a -> Maybe (a, BinTree a, BinTree a)
unfoldBin node = case node of
  Tip                  -> Nothing
  BinTree val lft rgt  -> Just (val, lft, rgt)

-- | Depth of a node, counted in levels below the root.
type Depth = Int

-- | One visited node of a depth-annotated traversal.  The depth and index are
-- strict, unpacked fields; the payload is kept lazy.
data Trav a
  = D {-# UNPACK #-} !Depth  -- node depth
      {-# UNPACK #-} !Index  -- node index
      a                      -- node payload

inorderBin :: BinTree a -> (BinTree (Index, a), Int)
inorderBin = execIndexM . trav where
trav Tip = return Tip
trav (BinTree x l r) = do
l' <- trav l
i <- getIndex
r' <- trav r
return (BinTree (i, x) l' r')

-- | Inorder traversal that applies @f@ to each node's depth (the root is at
-- depth 0, each level below adds 1) and value.  The results, in inorder, are
-- placed in front of the supplied tail list.
{-# INLINE inorderD #-}
inorderD :: (Int -> a -> b) -> BinTree a -> [b] -> [b]
inorderD f = walk 0
  where
    walk !depth node rest = maybe rest visit (unfoldBin node)
      where
        below = walk (depth + 1)
        visit (val, lft, rgt) = below lft (f depth val : below rgt rest)

-- | Indexes a binary tree inorder and builds an LCA query over it.
--
-- The result triple holds, in order:
--
--   * the count returned by the inorder labelling (the number of nodes),
--
--   * the tree with every value paired with its inorder index, and
--
--   * a query that maps two inorder indices to the @(index, value)@ pair of
--     their lowest common ancestor.
quickLCABinary :: BinTree a -> (Int, BinTree (Index, a), Index -> Index -> (Index, a))
quickLCABinary tree = assemble (inorderBin tree)
  where
    -- Matching the pair here forces it just like the original 'case' did.
    assemble (labelled, count) = (count, labelled, lcaBinary count fst labelled)

-- | Similar to 'LCA.lowestCommonAncestor', but optimized for binary trees.  This method can reasonably
-- be expected to run twice as fast as 'lowestCommonAncestor'.
--
-- Preprocessing flattens the tree inorder into three arrays, all indexed by
-- inorder position: the depth of each node (@depths@) and each node's value
-- (@vals@).  A third array, @ixs@, maps each node's own index (as given by
-- @ix@) back to its inorder position.  Within an inorder sequence, the lowest
-- common ancestor of two nodes is the shallowest node lying between them.  A
-- query is therefore one range-minimum lookup over @depths@.
--
-- NOTE(review): assumes @n@ equals the number of nodes and that @ix@ assigns
-- every node a distinct index in @[0, n)@; nothing here checks that.
lcaBinary
  :: Int           -- ^ @n@: number of nodes in the tree
  -> (a -> Index)  -- ^ @ix@: extracts a node's index from its value
  -> BinTree a     -- ^ the tree to preprocess
  -> Index         -- ^ index of the first query node
  -> Index         -- ^ index of the second query node
  -> a             -- ^ value stored at their lowest common ancestor
lcaBinary !n ix tree = lca
  where
    depths :: PV.Vector Int
    ixs :: UV.Vector Int
    -- vals :: V.Vector a
    -- (the signature above would need ScopedTypeVariables to mention @a@)
    !(!depths, !ixs, !vals) = inlineRunST $ do
      !depthsM <- new n
      !ixsM <- new n
      !valsM <- new n
      -- i is the inorder position; j is the node's own index.
      let go !i (D d j x : ts) = do
            write depthsM i d
            write ixsM j i
            write valsM i x
            go (i + 1) ts
          go _ [] = return ()
      go 0 (inorderD (\ d x -> D d (ix x) x) tree [])
      liftM3 (,,) (unsafeFreeze depthsM) (unsafeFreeze ixsM)
        ((unsafeFreeze :: V.MVector s a -> ST s (V.Vector a)) valsM)
    -- Range-min over depths: given a start position and a length, returns the
    -- position of the minimum depth in that range.
    rM :: Int -> Int -> Int
    !rM = intRangeMin depths
    {-# NOINLINE lca #-}
    lca !i !j =
      let iIx = ixs ! i
          jIx = ixs ! j
      in vals ! case compare iIx jIx of
           EQ -> iIx
           -- The inclusive range [lo, hi] has length hi - lo + 1.
           LT -> rM iIx (jIx - iIx + 1)
           GT -> rM jIx (iIx - jIx + 1)
```