{- -----------------------------------------------------------------------------
Copyright 2020-2021 Kevin P. Barry

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
----------------------------------------------------------------------------- -}

-- Author: Kevin P. Barry [ta0kira@gmail.com]

{-# LANGUAGE Safe #-}
{-# LANGUAGE TypeFamilies #-}

module Base.MergeTree (
  MergeTree,
  matchOnlyLeaf,
  mergeAllM,
  mergeAnyM,
  mergeLeaf,
  pairMergeTree,
  reduceMergeTree,
) where

import Data.List (intercalate)

import Base.CompilerError
import Base.Mergeable


-- | A tree whose interior nodes are 'MergeAny' or 'MergeAll' merges and whose
-- leaves hold values. The constructors are not exported; trees are built with
-- 'mergeLeaf', 'mergeAny' and 'mergeAll'.
data MergeTree a
  = MergeAny [MergeTree a]  -- ^ Node produced by 'mergeAny'.
  | MergeAll [MergeTree a]  -- ^ Node produced by 'mergeAll'.
  | MergeLeaf a             -- ^ A single value, produced by 'mergeLeaf'.
  deriving Eq

-- | Wraps a single value as a leaf of a 'MergeTree'.
mergeLeaf :: a -> MergeTree a
mergeLeaf x = MergeLeaf x

-- | Renders the tree as nested @mergeAny [...]@, @mergeAll [...]@ and
-- @mergeLeaf x@ text, with sibling renderings joined by "," (no spaces).
instance Show a => Show (MergeTree a) where
  show = reduceMergeCommon (bracketed "mergeAny [") (bracketed "mergeAll [") leaf where
    -- Prefix, comma-separated children, closing bracket.
    bracketed open parts = open ++ intercalate "," parts ++ "]"
    leaf x = "mergeLeaf " ++ show x

-- | Converts a 'MergeTree' into any 'Mergeable' type by mapping each leaf
-- with the supplied function and rebuilding branches with 'mergeAny' and
-- 'mergeAll'.
instance PreserveMerge (MergeTree a) where
  type T (MergeTree a) = a
  convertMerge convertLeaf tree =
    reduceMergeCommon mergeAny mergeAll convertLeaf tree

-- | Reduces any 'PreserveMerge' value. The value is first rebuilt as a
-- 'MergeTree'; then 'MergeAny' nodes are combined with @anyOp@, 'MergeAll'
-- nodes with @allOp@, and leaves are mapped with @leafOp@.
reduceMergeTree :: PreserveMerge a => ([b] -> b) -> ([b] -> b) -> (T a -> b) -> a -> b
reduceMergeTree anyOp allOp leafOp value =
  reduceMergeCommon anyOp allOp leafOp (toMergeTree value)

-- | Rebuilds a 'PreserveMerge' value as a 'MergeTree', turning each of its
-- leaves into a 'mergeLeaf'.
toMergeTree :: PreserveMerge a => a -> MergeTree (T a)
toMergeTree value = convertMerge mergeLeaf value

-- | Structural fold over a 'MergeTree'. Every child is reduced recursively.
-- The reduced children of a 'MergeAny' node are combined with @anyOp@, and
-- those of a 'MergeAll' node with @allOp@. A leaf is mapped with @leafOp@.
reduceMergeCommon :: ([b] -> b) -> ([b] -> b) -> (a -> b) -> MergeTree a -> b
reduceMergeCommon anyOp allOp leafOp = go where
  go tree = case tree of
    MergeAny children -> anyOp (map go children)
    MergeAll children -> allOp (map go children)
    MergeLeaf x       -> leafOp x

-- | Pairs two 'PreserveMerge' values structurally. @leafOp@ is applied to
-- pairs of leaves, and the results are combined with @anyOp@ / @allOp@.
--
-- Both arguments are first rebuilt with 'toMergeTree'. Cases are tried in
-- this order:
--
--   * leaf vs. leaf: @leafOp@ is applied directly.
--   * 'MergeAll' vs. 'MergeAny': a single @anyOp@ is taken over three lists,
--     in order:
--
--       1. every pairing of the left leaf children with the right leaf
--          children;
--       2. each non-leaf left child paired with the whole right side;
--       3. the whole left side paired with each non-leaf right child.
--
--   * otherwise @allOp@ is expanded before @anyOp@. A left 'MergeAny' or a
--     right 'MergeAll' is split with @allOp@. Only after that is a left
--     'MergeAll' or a right 'MergeAny' split with @anyOp@.
pairMergeTree :: (PreserveMerge a, PreserveMerge b) =>
  ([c] -> c) -> ([c] -> c) -> (T a -> T b -> c) -> a -> b -> c
pairMergeTree anyOp allOp leafOp x y = pair (toMergeTree x) (toMergeTree y) where
  pair (MergeLeaf l) (MergeLeaf r) = leafOp l r
  pair left@(MergeAll ls) right@(MergeAny rs) =
    anyOp (leafPairs ++ leftExpanded ++ rightExpanded) where
      (lNodes, lLeaves) = separateLeaves ls
      (rNodes, rLeaves) = separateLeaves rs
      -- Leaf children can be paired directly, expanding either side first.
      leafPairs = [leafOp l r | l <- lLeaves, r <- rLeaves]
      -- Non-leaf children must be paired against the entire other side.
      leftExpanded  = [pair node right | node <- lNodes]
      rightExpanded = [pair left node  | node <- rNodes]
  -- NOTE: allOp is expanded before anyOp. Two minBound values (empty
  -- MergeAny) or two maxBound values (empty MergeAll) therefore reduce to
  -- allOp [] rather than anyOp []. This is what allows
  -- pairMergeTree mergeAny mergeAll (==) to be a partial order.
  pair (MergeAny ls) right = allOp [pair l right | l <- ls]
  pair left (MergeAll rs)  = allOp [pair left r | r <- rs]
  pair (MergeAll ls) right = anyOp [pair l right | l <- ls]
  pair left (MergeAny rs)  = anyOp [pair left r | r <- rs]
  -- Splits children into non-leaf subtrees and leaf values, keeping order.
  -- The signature keeps it polymorphic; it is used on both sides' trees.
  separateLeaves :: [MergeTree t] -> ([MergeTree t], [t])
  separateLeaves children =
    ([node | node <- children, not (isLeaf node)], [v | MergeLeaf v <- children])
  isLeaf (MergeLeaf _) = True
  isLeaf _             = False

-- | Maps every leaf value. Branches are rebuilt with 'mergeAny' and
-- 'mergeAll'.
instance Functor MergeTree where
  fmap f = reduceMergeCommon mergeAny mergeAll (\x -> mergeLeaf (f x))

-- | @fs <*> xs@ replaces each leaf function @g@ of @fs@ with @fmap g xs@.
-- The branches of @fs@ are re-merged around the results.
instance Applicative MergeTree where
  pure x = mergeLeaf x
  fs <*> xs = reduceMergeCommon mergeAny mergeAll (\g -> fmap g xs) fs

-- | @tree >>= k@ substitutes the tree @k x@ for every leaf @x@ of @tree@.
-- Branches are re-merged with 'mergeAny' and 'mergeAll'. 'return' is left
-- at its default, which is 'pure'.
instance Monad MergeTree where
  tree >>= k = reduceMergeCommon mergeAny mergeAll k tree

-- | Folds over the leaf values in left-to-right order. The tree is first
-- flattened into a plain list of leaves.
instance Foldable MergeTree where
  foldr step z tree = foldr step z (leaves tree) where
    leaves = reduceMergeCommon concat concat (\x -> [x])

-- | Runs @f@ on every leaf from left to right. The tree is rebuilt from the
-- results with 'mergeLeaf', 'mergeAny' and 'mergeAll'.
instance Traversable MergeTree where
  traverse f = reduceMergeCommon (rebuild mergeAny) (rebuild mergeAll) leafOp where
    -- Sequence the children's effects in order, then re-merge the results.
    -- 'sequenceA' on a list equals the previous hand-written
    -- foldr (<*>) (pure []) . map (fmap (:)), by the Applicative laws.
    rebuild merge = fmap merge . sequenceA
    leafOp = fmap mergeLeaf . f

-- | Merging flattens directly nested merges of the same kind. A merge that
-- ends up with exactly one tree returns that tree unwrapped. Merging zero
-- trees yields an empty 'MergeAny' or 'MergeAll' node.
instance Mergeable (MergeTree a) where
  mergeAny = flattenMergeTree anyChildren MergeAny where
    anyChildren (MergeAny xs) = xs
    anyChildren x             = [x]
  mergeAll = flattenMergeTree allChildren MergeAll where
    allChildren (MergeAll xs) = xs
    allChildren x             = [x]

-- | Shared implementation of 'mergeAny' and 'mergeAll' for 'MergeTree'.
--
-- @children@ splices in the children of a node of the kind being built and
-- leaves any other node as a singleton. @wrap@ is the constructor applied
-- unless exactly one tree remains.
flattenMergeTree :: Foldable f =>
  (MergeTree a -> [MergeTree a]) -> ([MergeTree a] -> MergeTree a) ->
  f (MergeTree a) -> MergeTree a
flattenMergeTree children wrap = collapse . concatMap children where
  collapse [x] = x
  collapse xs  = wrap xs

-- | 'minBound' is an empty 'mergeAny' and 'maxBound' is an empty 'mergeAll'.
instance Bounded (MergeTree a) where
  minBound = mergeAny []
  maxBound = mergeAll []

-- | Runs 'collectFirstM_' on the actions, then merges the results collected
-- by 'collectAnyM' with 'mergeAny'.
--
-- NOTE(review): presumably 'collectFirstM_' fails with the collected errors
-- when no action succeeds, so an empty merge only results from an empty
-- list. TODO confirm against Base.CompilerError.
mergeAnyM :: (PreserveMerge a, CollectErrorsM m) => [m a] -> m a
mergeAnyM xs = collectFirstM_ xs >> (mergeAny <$> collectAnyM xs)

-- | Collects the results of all actions with 'collectAllM' and merges them
-- with 'mergeAll'.
mergeAllM :: (PreserveMerge a, CollectErrorsM m) => [m a] -> m a
mergeAllM xs = mergeAll <$> collectAllM xs

-- | Returns the leaf value when the argument reduces to a single leaf. If the
-- root is a 'mergeAny' or 'mergeAll' node, the result is 'emptyErrorM'.
matchOnlyLeaf :: (PreserveMerge a, CollectErrorsM m) => a -> m (T a)
matchOnlyLeaf = reduceMergeTree noMatch noMatch pure where
  noMatch _ = emptyErrorM