-- | Code motion for ASTs: 'codeMotion' repeatedly picks a sub-expression that
-- may be lifted (see the 'CodeMotionInterface' passed in), binds it once with
-- a \"let\" construct and replaces its occurrences with a fresh variable.
module Language.Syntactic.Functional.Sharing
    ( -- * Interface
      InjDict (..)
    , CodeMotionInterface (..)
    , defaultInterface
    , defaultInterfaceDecor
      -- * Code motion
    , codeMotion
    ) where
import Control.Monad.State
import Data.Maybe (isNothing)
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Typeable
import Data.Constraint (Dict (..))
import Language.Syntactic
import Language.Syntactic.Functional
-- | Dictionary of the binding symbols needed to share an expression of type
-- @a@ inside an expression of type @b@.
data InjDict sym a b = InjDict
    { injVariable :: Name -> sym (Full a)
        -- ^ Variable symbol with the given name; it has the type @a@ of the
        -- shared expression
    , injLambda   :: Name -> sym (b :-> Full (a -> b))
        -- ^ Lambda symbol binding the given name over a body of type @b@
    , injLet      :: sym (a :-> (a -> b) :-> Full b)
        -- ^ \"Let\" symbol: takes the shared expression and a function from
        -- its value to the result
    }
-- | The domain-specific operations that the code motion algorithm needs.
data CodeMotionInterface sym = Interface
    { mkInjDict   :: forall a b . ASTF sym a -> ASTF sym b -> Maybe (InjDict sym a b)
        -- ^ Try to build an 'InjDict'. It is called with the candidate
        -- expression to share as first argument and the whole expression in
        -- which it would be shared as second argument. 'Nothing' means the
        -- candidate is not shared.
        --
        -- Because the candidate is available, an implementation can carry
        -- information from it over to the introduced variable.
    , castExprCM  :: forall a b . ASTF sym a -> ASTF sym b -> Maybe (ASTF sym b)
        -- ^ Try to cast the first argument to the type of the second. During
        -- substitution the first argument is the replacement and the second
        -- is the node about to be replaced; on success the result is used in
        -- place of that node.
        -- NOTE(review): the result is presumably expected to be the first
        -- argument itself, only at a different type -- confirm.
    , hoistOver   :: forall c. ASTF sym c -> Bool
        -- ^ Whether the search for sharing candidates may continue into the
        -- sub-expressions of the given expression
    }
-- | 'CodeMotionInterface' for domains of the form @'Typed' sym@, where @sym@
-- contains both the given @binding@ symbols and 'Let'.
--
-- Casting is delegated to 'castExpr'. An 'InjDict' is only produced when the
-- sharing predicate accepts the pair of expressions.
defaultInterface :: forall binding sym symT
    .  ( binding :<: sym
       , Let     :<: sym
       , symT ~ Typed sym
       )
    => (forall a .   Typeable a => Name -> binding (Full a))
         -- ^ Variable constructor
    -> (forall a b . Typeable a => Name -> binding (b :-> Full (a -> b)))
         -- ^ Lambda constructor
    -> (forall a b . ASTF symT a -> ASTF symT b -> Bool)
         -- ^ May the first expression be shared inside the second
         -- expression?
    -> (forall a . ASTF symT a -> Bool)
         -- ^ Used directly as the 'hoistOver' field
    -> CodeMotionInterface symT
defaultInterface var lam sharable canHoistOver = Interface
    { mkInjDict  = mkDict
    , castExprCM = castExpr
    , hoistOver  = canHoistOver
    }
  where
    mkDict :: ASTF symT a -> ASTF symT b -> Maybe (InjDict symT a b)
    mkDict shared whole
      | sharable shared whole =
          -- Matching on the 'Typed' wrapper of both root symbols brings their
          -- 'Typeable' evidence into scope; it is needed to re-wrap the new
          -- binding symbols in 'Typed'.
          simpleMatch
            (\(Typed _) _ -> simpleMatch
              (\(Typed _) _ -> Just InjDict
                  { injVariable = Typed . inj . var
                  , injLambda   = Typed . inj . lam
                    -- NOTE(review): the meaning of the empty string passed to
                    -- 'Let' is not visible here -- confirm.
                  , injLet      = Typed (inj (Let ""))
                  }
              ) whole
            ) shared
      | otherwise = Nothing
-- | 'CodeMotionInterface' for decorated domains of the form @sym ':&:' info@.
-- The @info@ decoration of the root symbols is used both to decorate the
-- introduced binding symbols and to witness type casts.
defaultInterfaceDecor :: forall binding sym symI info
    .  ( binding :<: sym
       , Let     :<: sym
       , symI ~ (sym :&: info)
       )
    => (forall a b . info a -> info b -> Maybe (Dict (a ~ b)))
         -- ^ Type equality test on decorations
    -> (forall a b . info a -> info b -> info (a -> b))
         -- ^ Decoration for a function, given the decorations of its
         -- argument and result
    -> (forall a . info a -> Name -> binding (Full a))
         -- ^ Variable constructor
    -> (forall a b . info a -> info b -> Name -> binding (b :-> Full (a -> b)))
         -- ^ Lambda constructor
    -> (forall a b . ASTF symI a -> ASTF symI b -> Bool)
         -- ^ May the first expression be shared inside the second
         -- expression?
    -> (forall a . ASTF symI a -> Bool)
         -- ^ Used directly as the 'hoistOver' field
    -> CodeMotionInterface symI
defaultInterfaceDecor teq mkFunInfo var lam sharable canHoistOver = Interface
    { mkInjDict  = mkDict
    , castExprCM = castViaInfo
    , hoistOver  = canHoistOver
    }
  where
    mkDict :: ASTF symI a -> ASTF symI b -> Maybe (InjDict symI a b)
    mkDict shared whole
      | sharable shared whole =
          simpleMatch
            (\(_ :&: sharedInfo) _ -> simpleMatch
              (\(_ :&: wholeInfo) _ -> Just InjDict
                  { -- The variable gets the decoration of the shared
                    -- expression.
                    injVariable = \v -> inj (var sharedInfo v) :&: sharedInfo
                  , injLambda   = \v ->
                      inj (lam sharedInfo wholeInfo v)
                        :&: mkFunInfo sharedInfo wholeInfo
                    -- The let expression gets the decoration of the whole
                    -- expression.
                    -- NOTE(review): the meaning of the empty string passed to
                    -- 'Let' is not visible here -- confirm.
                  , injLet      = inj (Let "") :&: wholeInfo
                  }
              ) whole
            ) shared
      | otherwise = Nothing

    -- Return the first expression at the type of the second one, provided
    -- the decorations of their root symbols are type-equal according to
    -- 'teq'.
    castViaInfo :: ASTF symI a -> ASTF symI b -> Maybe (ASTF symI b)
    castViaInfo expr witness =
        simpleMatch
          (\(_ :&: exprInfo) _ -> simpleMatch
            (\(_ :&: witnessInfo) _ -> case teq exprInfo witnessInfo of
              Just Dict -> Just expr
              Nothing   -> Nothing
            ) witness
          ) expr
-- | Replace every occurrence (up to 'alphaEq') of one sub-expression by
-- another. No renaming is performed, so the caller has to make sure that the
-- replacement cannot be captured by a binder in the whole expression.
--
-- The traversal does not enter a lambda that binds a variable occurring free
-- in the expression being replaced.
substitute :: forall sym a b
    .  (Equality sym, BindingDomain sym)
    => CodeMotionInterface sym
    -> ASTF sym a  -- ^ Sub-expression to be replaced
    -> ASTF sym a  -- ^ Replacement
    -> ASTF sym b  -- ^ Whole expression
    -> ASTF sym b
substitute iface old new whole = replace whole
  where
    freeInOld = freeVars old

    -- Handle a complete (fully applied) expression: either swap it for the
    -- replacement or look further inside.
    replace :: ASTF sym c -> ASTF sym c
    replace e
      | Just new' <- castExprCM iface new e, alphaEq old e = new'
      | otherwise = inside e

    -- Walk the application spine of a possibly partial application.
    inside :: AST sym c -> AST sym c
    inside node = case node of
        lam :$ _
          | Just v <- prLam lam
          , v `Set.member` freeInOld -> node
        f :$ arg -> inside f :$ replace arg
        _        -> node
  
  
  
  
  
-- | Count how many times a sub-expression occurs (up to 'alphaEq') inside
-- another expression.
count :: forall sym a b
    .  (Equality sym, BindingDomain sym)
    => ASTF sym a  -- ^ Expression to look for
    -> ASTF sym b  -- ^ Expression to search in
    -> Int
count a b = occurrences b
  where
    freeInA = freeVars a

    -- A complete expression is either a match itself or is searched further.
    occurrences :: ASTF sym c -> Int
    occurrences e = if alphaEq a e then 1 else inside e

    inside :: AST sym sig -> Int
    inside node = case node of
        -- Nothing is counted under a lambda binding a variable that is free
        -- in `a`: an equal-looking expression there refers to a different
        -- variable. For example, with `x` free at the top level, in
        --
        --     (\x -> f x) (f x)
        --
        -- only the right-hand `f x` is an occurrence of the outer `f x`.
        lam :$ _
          | Just v <- prLam lam
          , v `Set.member` freeInA -> 0
        f :$ arg -> inside f + occurrences arg
        _        -> 0
-- | Environment carried down the expression while searching for a
-- sub-expression to share (see 'choose').
data Env sym = Env
    { inLambda :: Bool  -- ^ Whether the current position is under a lambda
    , counter  :: EF (AST sym) -> Int
        -- ^ Number of occurrences of the given expression in the whole
        -- expression being searched
    , dependencies :: Set Name
        -- ^ Variables bound by the lambdas passed on the way down; an
        -- expression with any of these free is not lifted
    }
-- | Decide whether a sub-expression may be lifted out of its position in the
-- given environment.
--
-- An expression is rejected when it mentions a variable from 'dependencies'
-- or when it is itself a variable. Otherwise it is accepted if it sits under
-- a lambda or occurs more than once.
liftable :: BindingDomain sym => Env sym -> ASTF sym a -> Bool
liftable env a
    | not independent   = False
      -- NOTE(review): variables are never lifted, presumably because sharing
      -- one would only introduce another variable and the pass would not
      -- make progress -- confirm.
    | Just _ <- prVar a = False
    | inLambda env      = True
    | otherwise         = counter env (EF a) > 1
  where
    independent = Set.null (freeVars a `Set.intersection` dependencies env)
-- | A sub-expression selected for sharing inside an expression of type @a@,
-- packaged with the 'InjDict' needed to introduce the binding. The type of
-- the sub-expression itself is hidden.
data Chosen sym a
  where
    Chosen :: InjDict sym b a -> ASTF sym b -> Chosen sym a
-- | Pick a sub-expression of the given expression to share, if there is one.
-- The expression itself is never picked, only proper sub-expressions. The
-- function part of an application is searched before its argument.
choose :: forall sym a
    .  (Equality sym, BindingDomain sym)
    => CodeMotionInterface sym
    -> ASTF sym a
    -> Maybe (Chosen sym a)
choose iface a = searchBelow topEnv a
  where
    topEnv = Env
        { inLambda     = False
        , counter      = \(EF e) -> count e a
        , dependencies = Set.empty
        }

    -- Environment for the body of a lambda binding the given name.
    underLambda :: Name -> Env sym -> Env sym
    underLambda v env = env
        { inLambda     = True
        , dependencies = Set.insert v (dependencies env)
        }

    -- Consider the expression itself first; only if it is not chosen, and
    -- the interface allows hoisting over it, look at its sub-expressions.
    search :: Env sym -> ASTF sym b -> Maybe (Chosen sym a)
    search env e
        | liftable env e
        , Just dict <- mkInjDict iface e a = Just (Chosen dict e)
        | hoistOver iface e                = searchBelow env e
        | otherwise                        = Nothing

    -- Like 'search', but never selects the top expression it is given.
    searchBelow :: Env sym -> AST sym b -> Maybe (Chosen sym a)
    searchBelow env (Sym lam :$ body)
        | Just v <- prLam lam = search (underLambda v env) body
    searchBelow env (f :$ arg) = searchBelow env f `mplus` search env arg
    searchBelow _ _ = Nothing
-- | Monadic worker behind 'codeMotion'. The state is the next variable name
-- to hand out; it is incremented each time a sub-expression is shared.
codeMotionM :: forall sym m a
    .  ( Equality sym
       , BindingDomain sym
       , MonadState Name m
       )
    => CodeMotionInterface sym
    -> ASTF sym a
    -> m (ASTF sym a)
codeMotionM iface a = case choose iface a of
    Just (Chosen dict shared) -> share dict shared
    Nothing                   -> descend a
  where
    -- Bind the chosen sub-expression with a let, replace its occurrences in
    -- `a` by a new variable, and keep transforming both parts.
    share :: InjDict sym b a -> ASTF sym b -> m (ASTF sym a)
    share dict shared = do
        shared' <- codeMotionM iface shared
        fresh   <- get
        put (fresh + 1)
        let var = Sym (injVariable dict fresh)
        body <- codeMotionM iface (substitute iface shared var a)
        return
            $  Sym (injLet dict)
            :$ shared'
            :$ (Sym (injLambda dict fresh) :$ body)

    -- Nothing to share at this level: transform the arguments one by one,
    -- leaving the head symbol untouched.
    descend :: AST sym b -> m (AST sym b)
    descend (f :$ arg) = do
        f'   <- descend f
        arg' <- codeMotionM iface arg
        return (f' :$ arg')
    descend leaf = return leaf
-- | Share sub-expressions of the given expression by introducing let
-- bindings, as directed by the supplied 'CodeMotionInterface'.
--
-- New variable names start just above the largest name returned by 'allVars'
-- for the input (or above 0 if that set is empty).
--
-- NOTE(review): the type variable @m@ is quantified in the signature but not
-- used in it.
codeMotion :: forall sym m a
    .  ( Equality sym
       , BindingDomain sym
       )
    => CodeMotionInterface sym
    -> ASTF sym a
    -> ASTF sym a
codeMotion iface a = evalState (codeMotionM iface a) firstFresh
  where
    -- Inserting 0 keeps 'Set.findMax' total when there are no variables.
    firstFresh = succ (Set.findMax (Set.insert 0 (allVars a)))