{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module HaskellWorks.Data.BalancedParens.ParensSeq
  ( ParensSeq(..)
  , mempty
  , size
  , fromWord64s
  , fromPartialWord64s
  , toPartialWord64s
  , fromBools
  , toBools
  , splitAt
  , take
  , drop
  , firstChild
  , nextSibling
  , (<|), (><), (|>)
  ) where

import Data.Coerce
import Data.Foldable
import Data.Word
import HaskellWorks.Data.BalancedParens.Internal.ParensSeq (Elem (Elem), ParensSeq (ParensSeq), ParensSeqFt)
import HaskellWorks.Data.Bits.BitWise
import HaskellWorks.Data.FingerTree                        (ViewL (..), ViewR (..), (<|), (><), (|>))
import HaskellWorks.Data.Positioning
import Prelude                                             hiding (drop, max, min, splitAt, take)

import qualified Data.List                                           as L
import qualified HaskellWorks.Data.BalancedParens.Internal.ParensSeq as PS
import qualified HaskellWorks.Data.BalancedParens.Internal.Word      as W
import qualified HaskellWorks.Data.FingerTree                        as FT

-- | The empty parentheses sequence: a 'ParensSeq' wrapping an empty finger
-- tree.  Used as the starting accumulator by the builders in this module.
empty :: ParensSeq
empty = ParensSeq FT.empty

-- | Number of bits (parentheses) held in the sequence, taken from the
-- finger tree's 'PS.Measure' rather than by walking the elements.
size :: ParensSeq -> Count
size ps = PS.size (FT.measure (PS.parens ps) :: PS.Measure)

-- TODO Needs optimisation
-- | Build a sequence from whole 64-bit words.  Every word becomes one
-- element whose bit count is 64.
--
-- A strict left fold ('foldl'') is used so that the accumulator is forced
-- at each step instead of building a chain of thunks proportional to the
-- input length.  Only 'Foldable' is required (formerly 'Traversable', which
-- implies 'Foldable', so existing callers are unaffected).
fromWord64s :: Foldable f => f Word64 -> ParensSeq
fromWord64s = foldl' go empty
  where go :: ParensSeq -> Word64 -> ParensSeq
        go ps w = ParensSeq (PS.parens ps |> Elem w 64)

-- TODO Needs optimisation
-- | Build a sequence from @(word, bitCount)@ pairs.  Each pair becomes one
-- element holding @bitCount@ bits of @word@.
--
-- NOTE(review): @bitCount@ is presumably expected to be at most 64; it is
-- not validated here.
--
-- A strict left fold ('foldl'') keeps the accumulator evaluated.  Only
-- 'Foldable' is required (formerly 'Traversable', which implies 'Foldable').
fromPartialWord64s :: Foldable f => f (Word64, Count) -> ParensSeq
fromPartialWord64s = foldl' go empty
  where go :: ParensSeq -> (Word64, Count) -> ParensSeq
        go ps (w, n) = ParensSeq (PS.parens ps |> Elem w n)

-- | List every element of the sequence, left to right, as a
-- @(word, bitCount)@ pair.  The tree is consumed from its left end one
-- element at a time.
toPartialWord64s :: ParensSeq -> [(Word64, Count)]
toPartialWord64s ps = L.unfoldr step (coerce ps)
  where
    step :: ParensSeqFt -> Maybe ((Word64, Count), ParensSeqFt)
    step ft = case FT.viewl ft of
      FT.EmptyL        -> Nothing
      Elem w n :< rest -> Just ((w, n), rest)

-- | Build a sequence from a list of bits, packing them into 64-bit words.
--
-- Each bool is appended to the last element: bit @i@ of that element's word
-- is set from the @i@-th bool placed in it.  Once the last element holds 64
-- bits, a new single-bit element is started.
fromBools :: [Bool] -> ParensSeq
fromBools = foldl' snoc empty
  where
    snoc :: ParensSeq -> Bool -> ParensSeq
    snoc (ParensSeq ft) b =
      let bitVal = (if b then 1 else 0) :: Word64
      in ParensSeq $ case FT.viewr ft of
        FT.EmptyR -> FT.singleton (Elem bitVal 1)
        initFt :> Elem w n
          -- room left in the last word: OR the new bit in at position n
          | n < 64    -> initFt |> Elem (w .|. (bitVal .<. fromIntegral n)) (n + 1)
          -- last word is full: start a fresh one-bit element
          | otherwise -> ft |> Elem bitVal 1

-- | Unpack the sequence into its bits, in order.  This is 'toBoolsDiff'
-- applied to an empty tail.
toBools :: ParensSeq -> [Bool]
toBools = flip toBoolsDiff []

-- | Difference-list form of 'toBools': @toBoolsDiff ps rest@ yields the
-- bits of @ps@ followed by @rest@.
--
-- The per-element difference lists are chained with function composition
-- (@foldr (.) id@).  The previous @mconcat@ used the pointwise
-- @Monoid (a -> b)@ instance, i.e. @(f <> g) xs = f xs ++ g xs@.  That fed
-- @rest@ to every element separately and appended the results, so @rest@ was
-- repeated after each element and dropped entirely for an empty sequence.
-- Assuming 'W.partialToBoolsDiff' prepends its bits to its list argument (as
-- the @Diff@ name suggests), both forms agree when @rest@ is @[]@, so
-- 'toBools' is unaffected.
toBoolsDiff :: ParensSeq -> [Bool] -> [Bool]
toBoolsDiff ps = foldr (\chunk k -> go chunk . k) id (toPartialWord64s ps)
  where go :: (Word64, Count) -> [Bool] -> [Bool]
        go (w, n) = W.partialToBoolsDiff (fromIntegral n) w

-- | The part of the sequence after the first @n@ bits: the right half of
-- 'splitAt'.
drop :: Count -> ParensSeq -> ParensSeq
drop n = snd . splitAt n

-- | The first @n@ bits of the sequence: the left half of 'splitAt'.
take :: Count -> ParensSeq -> ParensSeq
take n = fst . splitAt n

-- | Split the sequence after its first @n@ bits.
--
-- 'FT.split' with 'PS.atSizeBelowZero' separates the elements before the
-- split point (@lt@) from the rest (@rt@).  @n'@ is how many bits of @rt@'s
-- first element still belong on the left.  The cases are:
--
-- * @rt@ is empty: everything goes left.
-- * @n' == 0@: the split lies on an element boundary and nothing is cut.
--   Previously a zero-length element was appended on the left, and the mask
--   below degenerated into a shift by 64.
-- * @n' >= nw@: the whole element belongs on the left.  Previously this
--   branch returned @(lt, rrt)@, silently dropping the element from both
--   halves.
-- * otherwise the element's word is cut.  The low @n'@ bits (masked by
--   shifting up then down by @64 - n'@) stay left.  The remaining
--   @nw - n'@ bits, shifted down by @n'@, start the right half.
splitAt :: Count -> ParensSeq -> (ParensSeq, ParensSeq)
splitAt n (ParensSeq parens) = case FT.viewl rt of
  FT.EmptyL -> (ParensSeq lt, ParensSeq FT.empty)
  e@(Elem w nw) :< rrt
    | n' == 0   -> (ParensSeq lt, ParensSeq rt)
    | n' >= nw  -> (ParensSeq (lt |> e), ParensSeq rrt)
    | otherwise -> ( ParensSeq (lt |> Elem ((w .<. u) .>. u) n')
                   , ParensSeq (Elem (w .>. n') (nw - n') <| rrt)
                   )
  where
    (lt, rt) = FT.split (PS.atSizeBelowZero n) parens
    n' :: Count
    n' = n - PS.size (FT.measure lt :: PS.Measure)
    u :: Count
    u  = 64 - n'

-- | Position of the first child of the node opened at position @n@.
--
-- The first @n - 1@ bits are dropped, so @n@ is presumably 1-based.  The
-- result is @Just (n + 1)@ when the bits at positions @n@ and @n + 1@ are
-- both set, and 'Nothing' otherwise.  The two bits may sit in the same
-- element (low two bits of its word) or straddle into the next element.
--
-- NOTE(review): a zero-length element at either of those two places yields
-- 'Nothing'.  With @n == 0@, @n - 1@ wraps around (Count appears to be
-- 'Word64'), which drops everything and also yields 'Nothing'.
firstChild :: ParensSeq -> Count -> Maybe Count
firstChild ps n
  | twoOpens  = Just (n + 1)
  | otherwise = Nothing
  where
    ParensSeq ft = drop (n - 1) ps

    -- Are the first two bits of ft both set?
    twoOpens :: Bool
    twoOpens = case FT.viewl ft of
      FT.EmptyL -> False
      Elem w nw :< rest
        | nw >= 2   -> w .&. 3 == 3
        | nw >= 1   -> w .&. 1 == 1 && headIsOpen rest
        | otherwise -> False

    -- Is the lowest bit of the first element set (and present)?
    headIsOpen :: ParensSeqFt -> Bool
    headIsOpen t = case FT.viewl t of
      Elem w nw :< _ -> nw >= 1 && w .&. 1 == 1
      FT.EmptyL      -> False

-- | Position of the next sibling of the node opened at position @n@.
--
-- The tree is walked in pieces:
--
-- * @lt0@ is the first @n - 1@ bits; the next bit must be set.
-- * @lt1@ is that single bit.
-- * @lt2@ runs up to where 'PS.atMinZero' fires; the next bit must be clear.
-- * @lt3@ is that single bit; the bit after it must be set.
--
-- The answer is one plus the sizes of @lt0@ through @lt3@.  'Nothing' is
-- returned as soon as any expected bit is absent or has the wrong value.
nextSibling :: ParensSeq -> Count -> Maybe Count
nextSibling (ParensSeq ps) n = do
  let (lt0, rt0) = PS.ftSplit (PS.atSizeBelowZero (n - 1)) ps
  expectHead 1 rt0
  let (lt1, rt1) = PS.ftSplit (PS.atSizeBelowZero 1) rt0
      (lt2, rt2) = PS.ftSplit PS.atMinZero rt1
  expectHead 0 rt2
  let (lt3, rt3) = PS.ftSplit (PS.atSizeBelowZero 1) rt2
  expectHead 1 rt3
  pure (1 + sum (map sizeOf [lt0, lt1, lt2, lt3]))
  where
    -- Bit count of a subtree, from its measure.
    sizeOf :: ParensSeqFt -> Count
    sizeOf t = PS.size (FT.measure t :: PS.Measure)

    -- Succeed only if the first element is non-empty and its lowest bit
    -- equals the expected value.
    expectHead :: Count -> ParensSeqFt -> Maybe ()
    expectHead bitVal t = case FT.viewl t of
      Elem w nw :< _ | nw >= 1 && w .&. 1 == bitVal -> Just ()
      _                                             -> Nothing