module Data.Array.Repa.Plugin.ToDDC.Detect
        (detectModule)
where
import Data.Array.Repa.Plugin.FatName
import Data.Array.Repa.Plugin.ToDDC.Detect.Base
import Data.Array.Repa.Plugin.ToDDC.Detect.Type  ()

import DDC.Core.Module
import DDC.Core.Collect
import DDC.Type.Env
import DDC.Core.Flow
import DDC.Core.Flow.Exp
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Compounds
import DDC.Core.Transform.Annotate
import DDC.Core.Transform.Deannotate

import Control.Monad.State.Strict

import qualified Data.Map       as Map
import Data.Map                 (Map)
import qualified Data.Set       as Set
import Data.List


-- | Run the detector over a whole module, starting from 'zeroState'.
--   The result pairs the converted module with the name map held in the
--   final detector state ('stateNames').
detectModule 
        :: Module  () FatName 
        -> (Module () Name, Map Name GhcName)

detectModule mm
 = (mm', stateNames finalState)
 where  (mm', finalState) = runState (detect mm) zeroState


-- Module ---------------------------------------------------------------------
instance Detect (Module ()) where
 -- Convert the module body and both import maps. Only those type imports
 -- whose names occur free in the converted body are kept, and the export
 -- maps of the resulting module are left empty.
 detect mm
  = do  body'        <- fmap (annotate ())
                     $  detect (deannotate (const Nothing) (moduleBody mm))
        kindImports  <- detectMap (moduleImportKinds mm)
        typeImports  <- detectMap (moduleImportTypes mm)

        -- An imported type is only worth keeping if the body refers to it.
        let freeVars      = freeX empty body'
            isUsed name _ = Set.member (UName name) freeVars

        return ModuleCore
                { moduleName            = moduleName mm
                , moduleExportKinds     = Map.empty
                , moduleExportTypes     = Map.empty
                , moduleImportKinds     = kindImports
                , moduleImportTypes     = Map.filterWithKey isUsed typeImports
                , moduleBody            = body' }


-- | Convert an import map keyed by 'FatName's into one keyed by plain names.
--   For each entry the key and the qualified name are unwrapped to the
--   second component of their 'FatName', and the attached type is run
--   through 'detect'.
detectMap  :: Map FatName (QualName FatName, Type FatName)
           -> State DetectS (Map Name (QualName Name, Type Name))
detectMap m
 = liftM Map.fromList $ mapM convert (Map.toList m)
 where
  convert (FatName _ key, (QualName modName (FatName _ name), ty))
   = do ty'     <- detect ty
        return  (key, (QualName modName name, ty'))


-- DaCon ----------------------------------------------------------------------
instance Detect DaCon where
 -- Convert the constructor name first and then its type; the third
 -- field (`isAlg`) is carried over unchanged.
 detect (DaCon name ty isAlg)
  = liftM2 rebuild (detect name) (detect ty)
  where rebuild name' ty' = DaCon name' ty' isAlg


instance Detect DaConName where
 -- The unit constructor needs no conversion.
 detect DaConUnit
  = return DaConUnit

 -- Every named constructor is recorded with 'collect' under its original
 -- name; a few well-known names are then replaced by built-in ones.
 detect (DaConNamed (FatName ghcName ddcName))
  = do  collect ddcName ghcName
        return $ DaConNamed (rename ddcName)
  where
   -- Booleans
   rename (NameCon str)
    | isPrefixOf "True_"  str   = NameLitBool True
    | isPrefixOf "False_" str   = NameLitBool False

   -- The pair constructor arrives as a NameVar.
   -- TODO This should have been a NameCon
   rename (NameVar str)
    | isPrefixOf "(,)_"   str   = NameDaConFlow (DaConFlowTuple 2)

   -- Anything else keeps its name.
   rename other                 = other


-- Exp ------------------------------------------------------------------------
instance Detect (Exp a) where
 -- Convert an expression, rewriting recognised applications into flow
 -- operators, tuple constructors and arithmetic primops. The guards are
 -- tried top to bottom; anything not recognised falls through to the
 -- structural traversal at the end.
 --
 -- Recognition is by prefix of the variable name (e.g. "fold_") together
 -- with the exact number of arguments of the application. Arguments bound
 -- to `_xD*` are dropped from the rewritten application
 -- (NOTE(review): presumably type-class dictionaries -- confirm).
 detect xx
  | XAnnot a x          <- xx
  = liftM (XAnnot a) $ detect x

  -- Set kind of detected rate variables to Rate.
  | XLam b x            <- xx
  = do  b'      <- detect b
        x'      <- detect x
        case b' of
         BName n _
          -> do rateVar <- isRateVar n
                if rateVar 
                 then return $ XLAM (BName n kRate) x'
                 else return $ XLam b' x'

         -- A binder without a name cannot be looked up with isRateVar,
         -- so it is kept as an ordinary value abstraction instead of
         -- aborting the whole conversion.
         _ -> return $ XLam b' x'

  -- Detect vectorOfSeries
  | XApp{}                              <- xx
  , Just  (XVar u,     [xTK, xTA, _xD, xS]) 
                                        <- takeXApps xx
  , UName (FatName _ (NameVar v))       <- u
  , isPrefixOf "toVector_" v
  = flowOp OpFlowVectorOfSeries [xTK, xTA, xS]

  -- Detect folds.
  | XApp{}                              <- xx
  , Just  (XVar uFold, [xTK, xTA, xTB, _xD, xF, xZ, xS])    
                                        <- takeXApps xx
  , UName (FatName _ (NameVar vFold))   <- uFold
  , isPrefixOf "fold_" vFold
  = flowOp OpFlowFold [xTK, xTA, xTB, xF, xZ, xS]

  -- foldIndex
  | XApp{}                              <- xx
  , Just  (XVar uFold, [xTK, xTA, xTB, _xD, xF, xZ, xS])    
                                        <- takeXApps xx
  , UName (FatName _ (NameVar vFold))   <- uFold
  , isPrefixOf "foldIndex_" vFold
  = flowOp OpFlowFoldIndex [xTK, xTA, xTB, xF, xZ, xS]

  -- Detect maps
  | XApp{}                              <- xx
  , Just  (XVar uMap,  [xTK, xTA, xTB, _xD1, _xD2, xF, xS ])
                                        <- takeXApps xx
  , UName (FatName _ (NameVar vMap))    <- uMap
  , isPrefixOf "map_" vMap
  = flowOp (OpFlowMap 1) [xTK, xTA, xTB, xF, xS]

  -- TODO mapN
  | XApp{}                              <- xx
  , Just  (XVar uMap,  [xTK, xTA, xTB, xTC, _xD1, _xD2, _xD3, xF, xS1, xS2 ])
                                        <- takeXApps xx
  , UName (FatName _ (NameVar vMap))    <- uMap
  , isPrefixOf "map2_" vMap
  = flowOp (OpFlowMap 2) [xTK, xTA, xTB, xTC, xF, xS1, xS2]

  -- Detect packs
  | XApp{}                              <- xx
  , Just  (XVar uPack,  [xTK1, xTK2, xTA, _xD1, xSel, xF])
                                        <- takeXApps xx
  , UName (FatName _ (NameVar vPack))   <- uPack
  , isPrefixOf "pack_" vPack
  = flowOp OpFlowPack [xTK1, xTK2, xTA, xSel, xF]

  -- Detect mkSels
  | XApp{}                              <- xx
  , Just  (XVar u,    [xTK, xTA, xFlags, xWorker])
                                        <- takeXApps xx
  , UName (FatName _ (NameVar v))       <- u
  , isPrefixOf "mkSel1_" v
  = flowOp (OpFlowMkSel 1) [xTK, xTA, xFlags, xWorker]

  -- Detect n-tuples
  --   The tuple size is taken to be half the number of arguments, and the
  --   variable name must start with the matching "(,..,)_" prefix.
  | XApp{}                              <- xx
  , Just  (XVar uTuple,  args)          <- takeXApps xx
  , UName (FatName _ (NameVar vTuple))  <- uTuple

  , size                                <- length args `div` 2
  , commas                              <- replicate (size-1) ','
  , prefix                              <- "(" ++ commas ++ ")_"

  , size > 1
  , isPrefixOf prefix vTuple
  = do  args'   <- mapM detect args
        let tuple = DaConFlowTuple size
            ty    = typeDaConFlow tuple
        return  $ xApps (XCon $ mkDaConAlg (NameDaConFlow tuple) ty)
                        args'


  -- Inject type arguments for arithmetic ops.
  --   In the Core code, arithmetic operations are expressed as monomorphic
  --   dictionary methods, which we convert to polytypic DDC primops.
  | XVar (UName (FatName nG (NameVar str)))    <- xx
  , Just (nD', tArg, tPrim)  <- matchPrimArith str
  = do  collect nD' nG
        return  $ xApps (XVar (UPrim nD' tPrim)) [XType tArg]


  -- Strip boxing constructors from literal values.
  | XApp (XVar (UName (FatName _ (NameCon str1)))) x2 <- xx
  , isPrefixOf "I#_" str1
  = detect x2

  
  -- Boilerplate traversal.
  --   XAnnot and XLam are always caught by the guards above; their
  --   alternatives here only keep the case expression exhaustive.
  | otherwise
  = case xx of
        XAnnot a x      -> liftM (XAnnot a) (detect x)
        XVar  u         -> liftM  XVar  (detect u)
        XCon  u         -> liftM  XCon  (detect u)
        XLAM  b x       -> liftM2 XLAM  (detect b)   (detect x)
        XLam  b x       -> liftM2 XLam  (detect b)   (detect x)
        XApp  x1 x2     -> liftM2 XApp  (detect x1)  (detect x2)
        XLet  lts x     -> liftM2 XLet  (detect lts) (detect x)
        XType t         -> liftM  XType (detect t)

        XCase x alts    -> liftM2 XCase (detect x)   (mapM detect alts)
        XCast{}         -> error "repa-plugin.detect: XCast not handled"
        XWitness{}      -> error "repa-plugin.detect: XWitness not handled"

  where
   -- Convert the given arguments in order, then apply the flow operator
   -- `op` (as a primitive variable carrying its own type) to the results.
   flowOp op args
    = do args'   <- mapM detect args
         return  $ xApps (XVar (UPrim (NameOpFlow op) (typeOpFlow op)))
                         args'


-- | Match the name of a monomorphic Int arithmetic or comparison method.
--   On a match, yields the primop name, the type argument to apply it to
--   (always 'tInt' here) and the type of the primop itself.
--   The table is searched in order and the first matching prefix wins.
matchPrimArith :: String -> Maybe (Name, Type Name, Type Name)
matchPrimArith str
 = case [ op | (prefix, op) <- table, isPrefixOf prefix str ] of
        op : _  -> Just (NamePrimArith op, tInt, typePrimArith op)
        []      -> Nothing
 where
  table
   = [ -- Num
       ("$fNumInt_$c+_",         PrimArithAdd)
     , ("$fNumInt_$c-_",         PrimArithSub)
     , ("$fNumInt_$c*_",         PrimArithMul)

       -- Integral
     , ("$fIntegralInt_$cdiv_",  PrimArithDiv)
     , ("$fIntegralInt_$crem_",  PrimArithRem)
     , ("$fIntegralInt_$cmod_",  PrimArithMod)

       -- Comparisons
     , ("eqInt_",                PrimArithEq)
     , ("gtInt_",                PrimArithGt)
     , ("ltInt_",                PrimArithLt) ]


--- Lets ----------------------------------------------------------------------
instance Detect (Lets a) where
 -- Plain and recursive lets are converted; the region forms are rejected
 -- with an error.
 detect (LLet b x)
  = liftM2 LLet (detect b) (detect x)

 -- All binders of a recursive group are converted before any of the
 -- bound expressions.
 detect (LRec bxs)
  = do  bs'     <- mapM (detect . fst) bxs
        xs'     <- mapM (detect . snd) bxs
        return  (LRec (zip bs' xs'))

 detect LLetRegions{}
  = error "repa-plugin.detect: LLetRegions not handled"

 detect LWithRegion{}
  = error "repa-plugin.detect: LWithRegions not handled"


--- Alt  ----------------------------------------------------------------------
instance Detect (Alt a) where
 -- Convert the pattern first, then the right-hand side of the alternative.
 detect (AAlt pat body)
  = do  pat'    <- detect pat
        body'   <- detect body
        return  $ AAlt pat' body'

instance Detect Pat where
 -- The default pattern has nothing to convert.
 detect PDefault
  = return PDefault

 -- Convert the data constructor, then each of the pattern binders in order.
 detect (PData dc bs)
  = do  dc'     <- detect dc
        bs'     <- mapM detect bs
        return  $ PData dc' bs'