{-# language ScopedTypeVariables #-}
{-# language PatternGuards #-}

module OBDD.Operation 

( (&&), (||), not, and, or
, ite, bool, implies, equiv, xor
, unary, binary
, instantiate
, exists, exists_many
, fold, foldM
, full_fold, full_foldM
)

where

import OBDD.Data
import OBDD.Make

import qualified Data.List ( sortBy)
import Data.Function (on)

import Data.Set ( Set )
import qualified Data.Set as S

-- import Data.List ( foldl' )
-- don't use, see below

import Prelude hiding ( (&&), (||), and, or, not, bool )
import qualified Prelude

infixr 3 &&

-- | Conjunction of two decision diagrams: 'Prelude.&&' lifted to
-- 'OBDD's through 'symmetric_binary'.
(&&) :: Ord v => OBDD v -> OBDD v -> OBDD v
x && y = symmetric_binary (Prelude.&&) x y

infixr 2 ||

-- | Disjunction of two decision diagrams: 'Prelude.||' lifted to
-- 'OBDD's through 'symmetric_binary'.
(||) :: Ord v => OBDD v -> OBDD v -> OBDD v
x || y = symmetric_binary (Prelude.||) x y

-- | Selection between two diagrams, with the argument order of
-- @Data.Bool.bool@: the result agrees with @f@ where @p@ is false
-- and with @t@ where @p@ is true.
bool :: Ord v => OBDD v -> OBDD v -> OBDD v -> OBDD v
bool f t p = whenFalse || whenTrue
  where
    -- contribution of the assignments that falsify the selector
    whenFalse = f && not p
    -- contribution of the assignments that satisfy the selector
    whenTrue  = t && p

-- | If-then-else on diagrams: condition first, then the \"then\"
-- diagram, then the \"else\" diagram.  This is 'bool' with its
-- arguments reordered.
ite :: Ord v => OBDD v -> OBDD v -> OBDD v -> OBDD v
ite condition thenPart elsePart = bool elsePart thenPart condition

-- | Logical equivalence of two diagrams: equality on 'Bool' lifted
-- through 'symmetric_binary'.
equiv :: Ord v => OBDD v -> OBDD v -> OBDD v
equiv x y = symmetric_binary (==) x y

-- | Exclusive or of two diagrams: inequality on 'Bool' lifted
-- through 'symmetric_binary'.
xor :: Ord v => OBDD v -> OBDD v -> OBDD v
xor x y = symmetric_binary (/=) x y

-- | Implication of two diagrams.  Since @False < True@ on 'Bool',
-- @p <= q@ is false exactly when @p@ is true and @q@ is false, so
-- '<=' is lifted as the (asymmetric) implication operator.
implies :: Ord v => OBDD v -> OBDD v -> OBDD v
implies premise conclusion = binary (<=) premise conclusion

-- | Conjunction of a list of diagrams; the empty list yields
-- @constant True@.  Operands are combined by 'fold_by_size'.
--
-- An earlier version was @foldr ( && ) ( constant True )@.
-- Writing @foldl@ or @foldl'@ there
-- makes performance MUCH worse!
and :: Ord v => [ OBDD v ] -> OBDD v
and fs = fold_by_size (constant True) (&&) fs

-- | Disjunction of a list of diagrams; the empty list yields
-- @constant False@.  Operands are combined by 'fold_by_size'.
--
-- An earlier version was @foldr ( || ) ( constant False )@.
or :: Ord v => [ OBDD v ] -> OBDD v
or fs = fold_by_size (constant False) (||) fs

-- | Reduce a list of diagrams with a binary operation, always
-- combining the two smallest (by 'size') first.
--
-- The list is sorted by ascending 'size'; the first two elements are
-- combined with @reduce@ and the result is put back at its sorted
-- position.  An empty list yields @base@; a singleton list yields its
-- only element unchanged (@base@ is not combined in).
fold_by_size base reduce fs = go (Data.List.sortBy (compare `on` size) fs)
  where
    go []                      = base
    go [ only ]                = only
    go (first : second : rest) = go (place (reduce first second) rest)

    -- insert @f@ before the first element that is at least as large
    place f [] = [ f ]
    place f gs@(g : hs)
        | size f > size g = g : place f hs
        | otherwise       = f : gs


-- | Negation of a diagram, computed by applying 'Prelude.not' at the
-- leaves via 'unary'.
--
-- FIXME this is a silly implementation. Negation should be done
-- by switching values in Leaves (?)
not :: Ord v => OBDD v -> OBDD v
not x = unary Prelude.not x

-- | Lift a function on 'Bool' to diagrams: the diagram is rebuilt
-- node by node, with @op@ applied to the value of every leaf and
-- every branch re-registered with its rebuilt children.
unary :: Ord v
      => ( Bool -> Bool )
      -> OBDD v -> OBDD v
unary op x = make (rebuild x)
  where
    -- The key handed to 'cached' is a pair built from the node's own
    -- index twice (presumably because the cache is keyed by pairs of
    -- indices, as used by the binary operations -- confirm in OBDD.Make).
    rebuild node = cached ( top node, top node ) $ case access node of
        Leaf p ->
            register (Leaf (op p))
        Branch v l r -> do
            l' <- rebuild l
            r' <- rebuild r
            register (Branch v l' r')


-- | Tag passed to 'binary_' by 'binary' ('Asymmetric') and
-- 'symmetric_binary' ('Symmetric') to say whether the lifted operator
-- may have its arguments swapped.
data Symmetricity
    = Asymmetric
    | Symmetric
    deriving Show

-- | Lift an arbitrary binary function on 'Bool' to diagrams.
-- The operator is tagged 'Asymmetric', i.e. no commutativity is assumed.
binary :: Ord v
      => ( Bool -> Bool -> Bool )
      -> OBDD v -> OBDD v -> OBDD v
binary op x y = binary_ Asymmetric op x y
      
-- | Lift a binary function on 'Bool' to diagrams, tagging it as
-- 'Symmetric'.  Callers are expected to pass only commutative operators.
symmetric_binary :: Ord v
      => ( Bool -> Bool -> Bool )
      -> OBDD v -> OBDD v -> OBDD v
symmetric_binary op x y = binary_ Symmetric op x y


-- | Worker behind 'binary' and 'symmetric_binary': combine two diagrams
-- by applying @op@ to pairs of leaf values and rebuilding the branch
-- structure above them.
--
-- The two diagrams are descended simultaneously.  At every step the
-- current nodes are compared with 'comp'; the recursion branches on the
-- node that 'comp' ranks greater (a leaf ranks below any branch), or on
-- both nodes when they carry the same variable.
--
-- NOTE: @sym@ is currently not consulted; the only clause that used it
-- is commented out below.
binary_ :: Ord v
      => Symmetricity
      -> ( Bool -> Bool -> Bool )
      -> OBDD v -> OBDD v -> OBDD v

binary_ sym op x y = make $ do
    let -- register = checked_register -- for testing
        -- Disabled: for symmetric operators, swap the arguments so that
        -- the one with the smaller top index comes first.
        -- handle x y | Symmetric <- sym, top x > top y = handle y x
        -- Each call is wrapped in 'cached' with the pair of top indices
        -- as key (presumably memoising the result per pair -- see OBDD.Make).
        handle x y = cached (top x, top y) $ case ( access x , access y ) of
                    -- both diagrams are constants: apply the operator directly
                    ( Leaf p , Leaf q ) -> register $ Leaf $ op p q
                    ( ax, ay ) -> case comp ax ay of
                        -- ay ranks higher, so ay is a Branch here (either ax
                        -- is a Leaf or its variable is smaller): split on ay
                        -- only.  The lazy pattern below relies on that.
                        LT -> do
                            let Branch v l r = ay
                            l' <- handle x l
                            r' <- handle x r
                            register $ Branch v l' r'
                        -- mirror image: ax ranks higher, split on ax only
                        GT -> do
                            let Branch v l r = ax
                            l' <- handle l y
                            r' <- handle r y
                            register $ Branch v l' r'
                        -- The Leaf/Leaf case was handled above, so both
                        -- nodes are branches whose variables compare equal:
                        -- descend into both.  The 'error' is a sanity check
                        -- that compare's EQ agrees with (==).
                        EQ -> do
                            let Branch v1 l1 r1 = ax
                                Branch v2 l2 r2 = ay
                                v = if v1 == v2 then v1 else error "OBDD.Operation.handle"
                            l' <- handle l1 l2
                            r' <- handle r1 r2
                            register $ Branch v l' r'
    handle x y

-- | Remove a set of variables existentially, by applying 'exists'
-- once per variable in the set.
-- TODO: needs better implementation
exists_many :: Ord v
            => Set v
            -> OBDD v
            -> OBDD v
exists_many vars x = S.foldr exists x vars

-- | Remove one variable existentially: the disjunction of the diagram
-- with the variable fixed to 'False' and with it fixed to 'True'.
-- TODO: needs better implementation
exists :: Ord v
      => v
      -> OBDD v -> OBDD v
exists var x = withFalse || withTrue
  where
    withFalse = instantiate var False x
    withTrue  = instantiate var True x

-- | Replace a variable by a constant value.
--
-- The diagram is rebuilt node by node.  A branch on @var@ is dropped
-- and replaced by its rebuilt right child when @val@ is 'True', or its
-- rebuilt left child when @val@ is 'False'.  Every other branch is
-- re-registered with rebuilt children; leaves are re-registered as is.
instantiate :: Ord v =>
            v -> Bool
            -> OBDD v
            -> OBDD v
instantiate var val x = make (go x)
  where
    go node = cached ( top node, top node ) $ case access node of
        Leaf p -> register (Leaf p)
        Branch v l r
            -- the instantiated variable: continue in the chosen child only
            | v == var  -> go (if val then r else l)
            | otherwise -> do
                l' <- go l
                r' <- go r
                register (Branch v l' r')

-- | Order two nodes for the simultaneous descent in 'binary_'.
-- Two leaves are equal, a leaf is less than any branch, and two
-- branches are ordered by their variables.
comp ( Leaf {} )       ( Leaf {} )       = EQ
comp ( Branch {} )     ( Leaf {} )       = GT
comp ( Leaf {} )       ( Branch {} )     = LT
comp ( Branch v1 _ _ ) ( Branch v2 _ _ ) = compare v1 v2