{-# LANGUAGE CPP, DeriveDataTypeable, ExistentialQuantification, 
NoImplicitPrelude, OverloadedStrings, 
TupleSections #-}
{-# OPTIONS -Wall #-}

module Language.Paraiso.Generator.PlanTrans (
  translate,
  commonLibraries
  ) where


import qualified Algebra.Additive                    as Additive
import           Data.Char
import           Data.Dynamic
import qualified Data.Graph.Inductive                as FGL
import           Data.List (sortBy)
import qualified Data.Foldable                       as F
import           Data.Maybe
import qualified Data.Set                            as Set
import           Data.Tensor.TypeLevel
import qualified Data.Tensor.TypeLevel.Axis          as Axis
import qualified Data.Text                           as T
import qualified Data.Traversable                    as F
import qualified Data.Vector                         as V
import qualified Language.Paraiso.Annotation         as Anot
import qualified Language.Paraiso.Annotation.Boundary as Boundary
import qualified Language.Paraiso.Annotation.SyncThreads as Sync
import qualified Language.Paraiso.Generator.Claris   as C
import           Language.Paraiso.Generator.ClarisTrans (typeRepDB)
import qualified Language.Paraiso.Generator.Native   as Native
import qualified Language.Paraiso.Generator.Plan     as Plan
import qualified Language.Paraiso.OM                 as OM
import qualified Language.Paraiso.OM.Arithmetic      as Arith
import qualified Language.Paraiso.OM.DynValue        as DVal
import qualified Language.Paraiso.OM.Graph           as OM
import qualified Language.Paraiso.OM.Realm           as Realm
import qualified Language.Paraiso.Optimization.Graph as Opt
import           Language.Paraiso.Name
import           Language.Paraiso.Prelude hiding (Boolean(..))
import           NumericPrelude hiding ((++))

-- | The annotation type carried by every node of the dataflow graphs that
--   this generator consumes (it is the annotation parameter of the
--   'Plan.Plan', 'OM.Kernel' and 'Plan.SubKernelRef' values used below).
type AnAn = Anot.Annotation

-- | The inputs that most code-generation helpers in this module need:
--   the native setup (target language, local mesh size, boundary
--   conditions, CUDA grid size, ...) and the plan being translated.
--   Bundling them into one value keeps helper signatures short.
data Env v g = Env (Native.Setup v g) (Plan.Plan v g AnAn)

-- | Translate a plan into a Claris program.
--
--   The top level of the generated program is, in order:
--
--   * the standard headers of the target language (for CUDA also the
--     files listed in 'commonLibraries'),
--   * the runtime helper library ('library'),
--   * a comment recording the plan's lower and upper margins,
--   * free-standing helper functions produced for sub-kernels (only the
--     CUDA back end produces any), and
--   * one class named after the plan, containing a private member per
--     storage, a constructor, size/margin accessors, static-value
--     accessors, one member function per sub-kernel and one per kernel.
translate :: Opt.Ready v g => Native.Setup v g -> Plan.Plan v g AnAn -> C.Program
translate setup plan = 
  C.Program 
  { C.progName = name plan,
    C.topLevel = 
      stlHeaders ++ 
      library env ++ 
      comments ++
      subHelperFuncs ++ 
      [ C.ClassDef $ C.Class (name plan) $
        map fst storageVars ++ 
        constructorDef ++ 
        accessorsForSize env ++
        accessorsForVars env ++
        subMemberFuncs ++ memberFuncs 
      ]
  }
  where
    env = Env setup plan
    -- a single comment statement recording the margins of the whole plan
    comments = (:[]) $ C.Comment $ T.unlines [ 
      "",
      "lowerMargin = " ++ showT (Plan.lowerMargin plan),
      "upperMargin = " ++ showT (Plan.upperMargin plan)
      ]

    -- The class constructor.  Every storage member whose realm is Array
    -- gets a member initializer passing it the result of the
    -- total-memory-size accessor ('omFuncMemorySizeTotal'); all other
    -- members get no initializer.
    constructorDef = 
      (:[]) $
      C.MemberFunc C.Public True $
      (C.function  C.ConstructorType (name plan))
      { C.funcMemberInitializer = -- allocate memory for storage members
           concat $
           flip map storageVars $ \(memb, stRef) -> 
             case (memb, Realm.realm $ Plan.storageDynValue stRef) of
               (C.MemberVar _ var, Realm.Array)
                 -> [C.FuncCallUsr (name var) [C.FuncCallUsr (name (omFuncMemorySizeTotal env)) []]]
               _ -> []
      }

    -- one public member function per kernel, given the kernel's index
    memberFuncs = 
      V.toList $ 
      V.imap (\idx ker -> makeKernelFunc env idx ker) $ 
      Plan.kernels plan

    -- makeSubFunc yields (member functions, top-level helper statements)
    -- for each sub-kernel; split the pairs and flatten each side.
    subHelperFuncs = concat $ V.toList $ V.map snd $ subKernelFuncs   
    subMemberFuncs = concat $ V.toList $ V.map fst $ subKernelFuncs   
    subKernelFuncs = V.map (makeSubFunc env) $ Plan.subKernels plan


    -- include directives (chevron / double-quote form), wrapped in
    -- C.Exclusive C.HeaderFile
    include   = C.Exclusive C.HeaderFile . C.StmtPrpr . C.PrprInclude C.Chevron
    include'' = C.Exclusive C.HeaderFile . C.StmtPrpr . C.PrprInclude C.Quotation2
    stlHeaders = case Native.language setup of
      Native.CPlusPlus -> map include ["algorithm", "cmath", "vector"]
      Native.CUDA      -> map include ["thrust/device_vector.h", "thrust/host_vector.h", "thrust/functional.h", "thrust/extrema.h", "thrust/reduce.h"] ++ 
                          map (include'' . T.pack . fst) commonLibraries

    -- one private member variable per storage, paired with the storage
    -- reference it was made from
    storageVars = 
      V.toList $
      V.map storageRefToMember $
      Plan.storages plan
    storageRefToMember stRef =  
      (,stRef) $
      C.MemberVar C.Private $ 
      C.Var 
        (mkCppType env $ Plan.storageType stRef) 
        (name stRef) 

-- | Generate public member functions for accessing the static variables.
--
--   For every storage that holds a static value (those for which
--   'accessorName' returns 'Just'), this emits a function named after the
--   static value that takes no arguments and returns a reference to the
--   whole storage member.  For Array-realm statics it additionally emits
--   an overload that takes one index per axis (@i0@, @i1@, ...) and
--   returns a reference to a single element.
accessorsForVars :: Opt.Ready v g => Env v g -> [C.MemberDef]
accessorsForVars env@(Env setup plan) =
  map (C.MemberFunc C.Public True) $ do
    -- list monad
    stRef <- (V.toList $ Plan.storages plan)
    
    -- we have accessors only for static values.
    -- (a failed pattern match on Nothing skips the storage in the list monad)
    Just na <- [accessorName stRef]
    
    -- for all static values we have simple accessors that returns reference to it.
    let typeSimple = C.RefOf $ mkCppType env $ Plan.storageType stRef
    let bodySimple = [C.StmtReturn $ mkVarExpr $ nameText stRef]
    let argsSimple = []
        
    -- we also have elemental accessors, but only for Arrays.    
    let typeElemMaybe = case Plan.storageDynValue stRef of 
          DVal.DynValue Realm.Scalar _ -> []          
          DVal.DynValue Realm.Array  c -> [C.RefOf $ C.UnitType c]

    -- one index argument per axis, typed like that axis' local size
    let argsElem = compose (\ax@(Axis i) -> 
                             C.Var (C.typeOf (Native.localSize setup ! ax))
                                   (mkName $ "i" ++ showT i))
        -- the element at the flat address
        --   m0 + size0 * (m1 + size1 * (m2 + ...))
        -- where m_k is the margined index and size_k the memory size of
        -- axis k, i.e. axis 0 varies fastest.
        bodyElem = (:[]) $
                   C.StmtReturn $ C.ArrayAccess (mkVarExpr $ nameText stRef) $
                   productedArgs ! Axis 0
        -- productedArgs ! k is the partial address for axes k, k+1, ...;
        -- the last axis is the one whose successor wraps to Axis 0.
        productedArgs = compose 
          (\i -> 
            if (Axis.next i == Axis 0) 
            then (marginedArgs!i) 
            else  (marginedArgs!i) +
                  (C.FuncCallUsr (name $ omFuncMemorySize env ! i) []) *
                                (productedArgs!(Axis.next i)))
        -- each index argument shifted by its axis' lower margin
        marginedArgs = compose (\i -> 
                                  C.FuncCallUsr (name $ omFuncLowerMargin env ! i) [] +
                                  C.VarExpr (argsElem ! i))
                   
    -- (return type, body, arguments) for every accessor of this storage
    let materials = (typeSimple, bodySimple, argsSimple) : fmap (,bodyElem,F.toList argsElem) typeElemMaybe
        
    (typ, bod, arg) <- materials
    
    return (C.function typ na){C.funcBody = bod, C.funcArgs = arg}    

-- | The name of the public accessor for a storage.
--
--   A static storage is accessed under the name of the static value it
--   holds; a manifest storage (the value of a node inside a kernel) has no
--   accessor, so the result is 'Nothing'.
accessorName :: Plan.StorageRef v g a -> Maybe Name
accessorName x = case Plan.storageIdx x of
    Plan.StaticRef i    -> Just $ name $ OM.staticValues (Plan.setup $ Plan.parent x) V.! i
    Plan.ManifestRef {} -> Nothing



-- | Generate public member functions that report the extents of the mesh.
--
--   First come the two totals (local size and memory size), then the
--   per-axis local sizes, memory sizes, lower margins and upper margins.
accessorsForSize :: Opt.Ready v g => Env v g -> [C.MemberDef]
accessorsForSize env =
  [ C.MemberFunc C.Public True fn | fn <- totals ++ perAxis ]
  where
    totals  = [omFuncLocalSizeTotal env, omFuncMemorySizeTotal env]
    perAxis = concatMap F.toList
      [ omFuncLocalSize env
      , omFuncMemorySize env
      , omFuncLowerMargin env
      , omFuncUpperMargin env
      ]

-- | Accessors for the local (per-node) mesh size: @om_size@ returns the
--   product of all components, and @om_size_0@, @om_size_1@, ... return
--   the individual components.
omFuncLocalSizeTotal :: Opt.Ready v g => Env v g -> C.Function
omFuncLocalSize      :: Opt.Ready v g => Env v g -> v C.Function
omFuncLocalSizeTotal = fst (makeOmSizeFuncSet "om_size" envLocalSize)
omFuncLocalSize      = snd (makeOmSizeFuncSet "om_size" envLocalSize)

-- | Read the native local mesh size out of the environment.
envLocalSize :: Env v g -> v g
envLocalSize (Env setup _) = Native.localSize setup

-- | Accessors for the allocated mesh size, i.e. the local size widened by
--   the plan's lower and upper margins: @om_memory_size@ returns the
--   product of all components, and @om_memory_size_0@, ... the components.
omFuncMemorySizeTotal :: Opt.Ready v g => Env v g -> C.Function
omFuncMemorySize      :: Opt.Ready v g => Env v g -> v C.Function
omFuncMemorySizeTotal = fst (makeOmSizeFuncSet "om_memory_size" envMemorySize)
omFuncMemorySize      = snd (makeOmSizeFuncSet "om_memory_size" envMemorySize)

-- | Local mesh size plus the plan's lower and upper margins.
envMemorySize :: Opt.Ready v g => Env v g -> v g
envMemorySize (Env setup plan) =
  Native.localSize setup + Plan.lowerMargin plan + Plan.upperMargin plan

-- | Per-axis accessors for the plan's lower margin
--   (@om_lower_margin_0@, @om_lower_margin_1@, ...).
omFuncLowerMargin     :: Opt.Ready v g => Env v g -> v C.Function
omFuncLowerMargin     = marginFuncsFor "om_lower_margin" Plan.lowerMargin

-- | Per-axis accessors for the plan's upper margin
--   (@om_upper_margin_0@, @om_upper_margin_1@, ...).
omFuncUpperMargin     :: Opt.Ready v g => Env v g -> v C.Function
omFuncUpperMargin     = marginFuncsFor "om_upper_margin" Plan.upperMargin

-- | Per-axis accessors for a margin vector selected from the plan.
marginFuncsFor :: Opt.Ready v g
               => Text                          -- ^ accessor name prefix
               -> (Plan.Plan v g AnAn -> v g)   -- ^ which margin to report
               -> Env v g -> v C.Function
marginFuncsFor header pick = snd $ makeOmSizeFuncSet header (\(Env _ plan) -> pick plan)



-- | Build the accessor functions for one size-like vector of the mesh.
--
--   Returns a pair of readers over the environment.  The first builds a
--   function named @header@ that returns the product of all components.
--   The second builds one function per axis, named @header_0@,
--   @header_1@, ..., that returns that axis' component.  Every generated
--   function takes no arguments, has return type 'gaugeType', and returns
--   the value computed at generation time.
makeOmSizeFuncSet :: Opt.Ready v g => 
                        Text              -- ^ The header text
                     -> (Env v g -> v g)  -- ^ How to read the size from the environment
                     -> (Env v g -> C.Function, Env v g -> v C.Function)
makeOmSizeFuncSet header sizeVecReader = (prodFunc, elemFuncs)
  where
    -- all helpers below run in the reader monad ((->) (Env v g))
    prodFunc = do
      sizeVec <- sizeVecReader
      trivialFunc header $ product $ F.toList sizeVec
    elemFuncs = do
      sizeVec <- sizeVecReader
      F.sequenceA $ compose (\i -> trivialFunc (header ++ "_" ++ showT (axisIndex i)) (sizeVec ! i)) 

    -- a function named @str@ whose whole body is @return ret;@
    trivialFunc str ret = do
      gt <- gaugeType
      return $ (C.function gt $ mkName str)
          { C.funcBody = 
            [ C.StmtReturn (C.toDyn ret)
            ]
          }

-- | The C type returned by the size and margin accessors: the type of the
--   axis-0 component of the native local size.
gaugeType :: Opt.Ready v g => Env v g -> C.TypeRep
gaugeType (Env setup _) = C.typeOf $ Native.localSize setup ! (Axis 0)


-- | Make the public member function that executes one whole kernel.
--
--   The body first calls every sub-kernel that belongs to kernel @kerIdx@,
--   in plan order.  Each call receives the storage variables bound to the
--   sub-kernel's input nodes followed by its output nodes.  The body then
--   emits one assignment per @Store@ instruction in the kernel's dataflow
--   graph, copying the manifest storage of the stored value into the
--   target static storage.
makeKernelFunc :: Opt.Ready v g => Env v g -> Int -> OM.Kernel v g AnAn -> C.MemberDef
makeKernelFunc env@(Env _ plan) kerIdx ker = C.MemberFunc C.Public False $ 
 (C.function tVoid (name ker)) 
 { C.funcBody = kernelCalls ++ storeInsts
 }
 where
   graph = OM.dataflow ker

   -- calls to this kernel's sub-kernels, in the order of the plan
   kernelCalls = 
     V.toList $
     V.map (\subker -> callSubKer subker $ V.map findVar $ 
                       Plan.inputIdxs subker V.++ Plan.outputIdxs subker) $
     V.filter ((== kerIdx) . Plan.kernelIdx) $
     Plan.subKernels plan
   callSubKer subker xs = 
     C.StmtExpr $
     C.FuncCallUsr (name subker) (V.toList xs)

   -- one "static = manifest" assignment per Store instruction
   storeInsts = 
     map swapStmt $
     concatMap (\(idx, nd) -> case nd of
               OM.NInst (OM.Store (OM.StaticIdx statIdx))_ -> [(idx, statIdx)] 
               _ -> []) $
     FGL.labNodes graph

   -- The stored value is the first predecessor of the Store instruction;
   -- its manifest storage is identified by (kernel index, node index).
   swapStmt (idx, statIdx) = 
     let preIdx = fromMaybe
           (error $ "store instruction has no predecessor: " ++ show (idx, statIdx))
           (listToMaybe $ FGL.pre graph idx)
     in
     case (filter ((Plan.StaticRef statIdx==) . Plan.storageIdx) $ V.toList $ Plan.storages plan,
           filter ((Plan.ManifestRef kerIdx preIdx==) . Plan.storageIdx) $ V.toList $ Plan.storages plan) of
       ([stRef],[maRef]) -> 
         C.StmtExpr $ C.Op2Infix "="
         (C.VarExpr $ C.Var (mkCppType env $ Plan.storageType stRef) (name stRef) )
         (C.VarExpr $ C.Var C.UnknownType (name maRef) )
       _ -> error $ "mismatch in storage phase: " ++ show (idx, statIdx) 

   -- The storage variable carrying the value of node @idx@.  This is either
   -- its own manifest storage, or the static storage it was loaded from
   -- when one of its predecessors is a Load instruction.
   findVar idx = 
     let 
       loadIdx = 
         listToMaybe $
         concat $
         map (\jdx -> 
               case FGL.lab graph jdx of
                 Just (OM.NInst (OM.Load (OM.StaticIdx statIdx))_)-> [Plan.StaticRef statIdx] 
                 _                                                -> []) $
         FGL.pre graph idx
       match stIdx
         | stIdx == Plan.ManifestRef kerIdx idx = True
         | Just stIdx == loadIdx                = True
         | otherwise                            = False
       stRef = fromMaybe
         (error $ "no storage found for node: " ++ show idx)
         (V.find (match . Plan.storageIdx) (Plan.storages plan))
     in C.VarExpr $ C.Var 
        (mkCppType env $ Plan.storageType stRef) 
        (name stRef) 


-- | Create a subKernel: a member function that performs a portion of
--   actual calculations. It may also generate some helper functions
--   called from the subKernel body.
--
--   The result is (member functions for the class, top-level helper
--   statements):
--
--   * C++, and CUDA with Scalar realm: one member function whose body is
--     produced by 'loopMaker'; no helpers.
--   * CUDA with Array realm: a helper function qualified with
--     'C.CudaGlobal', named after the sub-kernel with suffix @_inner@.
--     It takes raw-pointer arguments and runs the loop.  The member
--     function sets the helper's cache preference to
--     @cudaFuncCachePreferL1@ on its first invocation only, then launches
--     the helper with the configured grid and block dimensions.
makeSubFunc :: Opt.Ready v g 
            => Env v g 
            -> Plan.SubKernelRef v g AnAn 
            -> ([C.MemberDef], [C.Statement])
makeSubFunc env@(Env setup _) subker = 
  case Native.language setup of
    Native.CPlusPlus -> cppSolution
    Native.CUDA      -> cudaSolution
  where
    rlm = Realm.realm subker 

    -- comment recording the boundary widths of this sub-kernel
    marginComment =
      C.Comment $ T.unlines 
        [ "",
          "lowerMargin = " ++ showT (Plan.lowerBoundary subker),
          "upperMargin = " ++ showT (Plan.upperBoundary subker)
        ]

    -- argument list of the member function: inputs as const references,
    -- then outputs as references
    memberArgs =
      makeSubArg env True (Plan.labNodesIn subker) ++
      makeSubArg env False (Plan.labNodesOut subker)
    
    cudaHelperName = mkName $ nameText subker ++ "_inner"
    
    cudaBodys = 
      if rlm == Realm.Scalar 
      then []
      else
        (:[]) $
        C.FuncDef $
        (C.function 
         (C.QualifiedType [C.CudaGlobal] tVoid) 
         cudaHelperName)
        { C.funcArgs = 
           makeRawSubArg env True  (Plan.labNodesIn subker) ++
           makeRawSubArg env False (Plan.labNodesOut subker),
          C.funcBody = marginComment : loopMaker env rlm subker
        }
    
    (gridDim, blockDim) = Native.cudaGridSize setup
    
    cudaSolution = 
      (,cudaBodys) $
      (:[]) $
      C.MemberFunc C.Public False $ 
      (C.function tVoid (name subker))
      { C.funcArgs = memberArgs,
        C.funcBody = 
          if rlm == Realm.Scalar 
          then loopMaker env rlm subker
          else
            [ marginComment,
              -- @fst@ records whether the cache preference has already been
              -- set, so the configuration call happens on the first
              -- invocation only.
              C.RawStatement "{static bool fst = false;",
              C.StmtExpr $ C.FuncCallStd "if (!fst) cudaFuncSetCacheConfig" 
                 [mkVarExpr $ nameText cudaHelperName, mkVarExpr "cudaFuncCachePreferL1"],
              C.RawStatement "fst = true;}",
              C.StmtExpr $ C.CudaFuncCallUsr cudaHelperName (C.toDyn gridDim) (C.toDyn blockDim) $
              map takeRaw $ 
              (V.toList $ Plan.labNodesIn subker) ++ (V.toList $ Plan.labNodesOut subker)
            ] 
      }

    -- launch argument for one node: the variable itself for scalars, its
    -- @.raw()@ pointer for arrays
    takeRaw (idx, nd)= 
      case nd of
        OM.NValue typ _ -> 
          extractor typ $
          C.VarExpr $
          C.Var (mkCppType env typ) (nodeNameUniversal idx)
        _ -> error "NValue expected" 

    extractor typ = case Realm.realm typ of
      Realm.Array ->           
        flip C.MemberAccess (C.FuncCallStd "raw" []) 
      Realm.Scalar ->
        id

    cppSolution = 
      (,[]) $
      (:[]) $
      C.MemberFunc C.Public False $ 
      (C.function tVoid (name subker))
      { C.funcArgs = memberArgs,
        C.funcBody = 
          if rlm == Realm.Scalar 
          then loopMaker env rlm subker
          else marginComment : loopMaker env rlm subker
      }

    
    
-- | Build the argument list of a sub-kernel member function.
--
--   Every node must be a value node ('OM.NValue'); it becomes a reference
--   parameter named by 'nodeNameUniversal', additionally const-qualified
--   when @isConst@ is set.
makeSubArg :: Opt.Ready v g => Env v g -> Bool -> V.Vector (FGL.LNode (OM.Node v g AnAn)) -> [C.Var]
makeSubArg env isConst lnodes = map toParam (V.toList lnodes)
  where
    qualify
      | isConst   = C.Const . C.RefOf
      | otherwise = C.RefOf
    toParam (idx, nd) = case nd of
      OM.NValue typ _ -> C.Var (qualify $ mkCppType env typ) (nodeNameUniversal idx)
      _               -> error "NValue expected"

-- | Build a sub-kernel argument list that uses raw pointers for arrays
--   (the types come from 'mkCudaRawType'), e.g. for a CUDA helper.
--
--   Every node must be a value node; the type is wrapped in 'C.Const'
--   when @isConst@ is set.
makeRawSubArg :: Opt.Ready v g => Env v g -> Bool -> V.Vector (FGL.LNode (OM.Node v g AnAn)) -> [C.Var]
makeRawSubArg env isConst lnodes =
  [ case nd of
      OM.NValue typ _ -> C.Var (withConst $ mkCudaRawType env typ) (nodeNameUniversal idx)
      _               -> error "NValue expected"
  | (idx, nd) <- V.toList lnodes
  ]
  where
    withConst t
      | isConst   = C.Const t
      | otherwise = t




-- | Generate the statements that implement one sub-kernel.
--
--   * For Array realm, the statements are wrapped in one flat @for@ loop
--     over the sub-kernel's output region.  In C++ the counter @i@ runs
--     from 0 in steps of 1, preceded by @#pragma omp parallel for@.  In
--     CUDA it starts at @blockIdx.x * blockDim.x + threadIdx.x@ and
--     advances by @blockDim.x * gridDim.x@.  The first statement inside
--     the loop defines @addr_origin@, the flat memory address
--     corresponding to @i@.
--   * For Scalar realm, the statements are emitted without a loop.
--
--   Every value node of the sub-kernel is emitted once per cursor
--   (relative mesh offset) at which it is needed, in ascending node order.
--   Output nodes are evaluated at cursor zero and assigned into the output
--   array element (or scalar).  All other nodes become local variables
--   named by 'nodeNameCursored'.
loopMaker :: Opt.Ready v g => Env v g -> Realm.Realm -> Plan.SubKernelRef v g AnAn -> [C.Statement]
loopMaker env@(Env setup plan) realm subker = case realm of
  Realm.Array ->
    pragma ++
    [ C.StmtFor 
      (C.VarDefSub loopCounter loopBegin) 
      (C.Op2Infix "<"  (C.VarExpr loopCounter) loopEnd)
      (C.Op2Infix "+=" (C.VarExpr loopCounter) loopStride) $
      [C.VarDefSub addrCounter codecAddr] ++
      loopContent
    ]
  Realm.Scalar -> loopContent

  where
    -- the OpenMP pragma is emitted only for C++
    pragma = 
      if Native.language setup == Native.CPlusPlus 
      then [C.StmtPrpr $ C.PrprPragma "omp parallel for"]
      else []
    -- loop bounds; in both languages the loop covers the product of
    -- boundarySize
    (loopBegin, loopEnd, loopStride) = case Native.language setup of
      Native.CPlusPlus -> (intImm 0, C.toDyn (product boundarySize), intImm 1)
      Native.CUDA -> (loopBeginCuda, C.toDyn (product boundarySize), loopStrideCuda)      
    loopBeginCuda = mkVarExpr "blockIdx.x * blockDim.x + threadIdx.x"
    loopStrideCuda   = mkVarExpr "blockDim.x * gridDim.x"    
    
    loopCounter = C.Var tSizet (mkName "i")
    -- the orthotope for entire input
    memorySize   = F.toList $ Native.localSize setup + Plan.lowerMargin plan + Plan.upperMargin plan
    -- the region where we can make output.
    -- if we use open boundary, it's smaller than input
    -- (on each Open axis it is reduced by the sub-kernel's lower plus
    -- upper boundary width; Cyclic axes keep the full memory extent)
    boundarySize = F.toList $ Native.localSize setup + Plan.lowerMargin plan + Plan.upperMargin plan
     - (compose $ \ax -> case Native.boundary setup ! ax of
        Boundary.Open   -> Plan.lowerBoundary subker ! ax + Plan.upperBoundary subker ! ax
        Boundary.Cyclic -> Additive.zero)

    -- Decoding the flat counter i into per-axis coordinates of the output
    -- region (axis 0 varies fastest):
    --   codecDiv ! k = i / (product of boundarySize over axes < k)
    codecDiv = 
      [ if idx == 0 then (C.VarExpr loopCounter) else (C.VarExpr loopCounter) / (C.toDyn $ product $ take idx boundarySize) 
      | idx <- [0..length boundarySize-1]]
    --   codecMod ! k = codecDiv ! k mod boundarySize ! k (not for the last axis)
    codecMod = 
      [ if idx == length codecDiv-1 then x else x `mod`  (C.toDyn $ boundarySize !! idx)
      | (idx, x) <- zip [0..] codecDiv]
    --   codecModAdd ! k = codecMod ! k + lowerMargin of the plan on axis k
    codecModAdd = 
      [ x + (C.toDyn $ Plan.lowerMargin plan ! (Axis idx))
      | (idx, x) <- zip [0..] codecMod]
    -- Flat memory address of the current point:
    --   sum over k of codecModAdd ! k * (product of memorySize over axes < k).
    -- NOTE(review): when the output region equals the memory region this
    -- shortcut uses i directly, i.e. without the lower-margin shift that
    -- codecModAdd applies; presumably the margin is zero in that case —
    -- confirm.
    codecAddr = 
      if memorySize == boundarySize 
      then C.VarExpr loopCounter
      else foldl1 (+)
           [ x * (C.toDyn $ product $ take idx  memorySize)
           | (idx, x) <- zip [0..] codecModAdd]
    -- per-axis value of LoadIndex at the given cursor:
    --   codecMod ! k - (lowerMargin - lowerBoundary - cursor) ! k,
    -- wrapped modulo the memory extent on non-Open axes
    codecLoadIndex cursor =
      [ let
           bnd = Native.boundary setup
           n      = C.toDyn $ memorySize !! idx 
           protector x'
             | bnd ! Axis idx == Boundary.Open = x'
             | otherwise = (x'+n) `mod` n
        in protector $ x - C.toDyn((Plan.lowerMargin plan - Plan.lowerBoundary subker - cursor) ! (Axis idx) )
      | (idx, x) <- zip [0..] codecMod]
    -- per-axis value of LoadSize: the native local size
    codecLoadSize =
      [ C.toDyn  (Native.localSize setup ! (Axis idx) )
      | (idx, _) <- zip [0..] codecMod]

    -- Flat memory address of the point at offset @cursor@ from the current
    -- point.  If every axis is Open, this is addr_origin plus a constant
    -- computed at generation time.  Otherwise it is recomputed axis by axis,
    -- wrapping each coordinate modulo the memory extent on non-Open axes.
    codecCursor cursor 
      | F.all (==Boundary.Open) bnd = easySum
      | otherwise = normalSum                      
        where
          bnd = Native.boundary setup
          easySum = C.VarExpr addrCounter + C.toDyn hardCodeShift
          hardCodeShift = sum $
            [ cursor ! (Axis idx) * product (take idx memorySize)
            | (idx, _) <- zip [0..] memorySize]
          normalSum = foldl1 (+)
            [ let stride = C.toDyn $ product $ take idx  memorySize
                  n      = C.toDyn $ memorySize !! idx 
                  protector x'
                    | bnd ! Axis idx == Boundary.Open = x'
                    | otherwise = (x'+n) `mod` n
              in stride * protector (x + C.toDyn (cursor ! Axis idx))
            | (idx, x) <- zip [0..] codecModAdd]  

    addrCounter = C.Var tSizet (mkName "addr_origin")

    -- statements for every value node of the sub-kernel, in ascending
    -- node order
    loopContent = 
      concat $
      map buildExprs $
      filterVal $
      Set.toList allIdxSet

    -- one statement per cursor at which node idx is needed
    buildExprs (idx, val@(DVal.DynValue r c)) = 
      addSyncFunctions idx  $
      map (\cursor -> 
            lhs cursor
            (fst $ rhsAndRequest env idx cursor)
          ) $
      Set.toList $ lhsCursors V.! idx
      where
        -- outputs are assigned into the output variable (at addr_origin
        -- for arrays); other nodes define a cursor-specific local variable
        lhs cursor expr = 
          if Set.member idx outputIdxSet
          then C.StmtExpr $ flip (C.Op2Infix "=") expr $ case realm of
            Realm.Array ->
              (C.ArrayAccess (C.VarExpr $ C.Var (C.UnitType c) (nodeNameUniversal idx)) (C.VarExpr addrCounter)) 
            Realm.Scalar ->
              C.VarExpr $ C.Var (C.UnitType c) (nodeNameUniversal idx)
          else flip C.VarDefSub expr 
               (C.Var (C.UnitType c) $ nodeNameCursored env idx cursor)

    -- For CUDA Array sub-kernels, surround a node's statements with
    -- __syncthreads() calls according to its Sync.Timing annotations
    -- (Pre: before, Post: after).  Otherwise leave them unchanged.
    addSyncFunctions :: FGL.Node -> [C.Statement] -> [C.Statement] 
    addSyncFunctions idx = 
      if realm /= Realm.Array ||
         Native.language setup /= Native.CUDA 
         then id
         else foldl (.) id $ map adder anot
      where
        anot :: [Sync.Timing]
        anot =  (maybeToList $ FGL.lab graph idx) >>= (Anot.toList . OM.getA)
          
        adder :: Sync.Timing -> [C.Statement] -> [C.Statement]  
        adder Sync.Pre  = ([C.RawStatement "__syncthreads();"] ++ )
        adder Sync.Post = ( ++ [C.RawStatement "__syncthreads();"])

    -- the set of cursors at which each node is needed: nothing for nodes
    -- outside the sub-kernel, only the origin for outputs, otherwise
    -- whatever later nodes request
    -- lhsCursors :: (Opt.Ready v g) => V.Vector(Set.Set(v g))
    lhsCursors = V.generate idxSize f
      where 
        f idx
          | not (Set.member idx allIdxSet) = Set.empty
          | Set.member idx outputIdxSet    = Set.singleton Additive.zero
          | otherwise                      = lhsRequest idx

    -- Cursors at which node idx is requested by the nodes of this
    -- sub-kernel.  Only nodes with a larger index are considered; this
    -- presumably relies on the node numbering being topological — confirm.
    -- lhsRequests :: FGL.Node -> (Set.Set (v g))
    lhsRequest idx =
      Set.fromList $
      map snd $
      filter ((==idx) . fst) $
      concat $ 
      [snd $ rhsAndRequest env jdx cur| 
       jdx <- Set.toList allIdxSet, 
       jdx > idx,
       cur <- Set.toList $ lhsCursors V.! jdx
       ]

    -- The right-hand side expression of node idx at the given cursor,
    -- together with the (node, cursor) pairs that expression reads.
    -- Input nodes are read from the input variable (at the cursor's
    -- memory address for arrays); other nodes are rendered from the
    -- instruction that produces them.
    -- rhsAndRequest :: (Opt.Ready v g) => Env v g -> FGL.Node -> v g -> (C.Expr,[(Int, v g)])
    rhsAndRequest env' idx cursor = 
      let (idxInst,inst) = case preInst idx of
            found:_ -> found
            _       -> error $ "right hand side is not inst:" ++ show idx
          prepre = map fst $ sortBy (\x y -> compare (snd x) (snd y)) $ FGL.lpre graph idxInst
          isInput = Set.member idx inputIdxSet
          creatVar idx' = C.VarExpr $ C.Var C.UnknownType (nodeNameUniversal idx')
        in case inst of
      _ | isInput     -> case realm of
        Realm.Array -> (C.ArrayAccess (creatVar idx) (codecCursor cursor), [])
        Realm.Scalar -> (creatVar idx, [])
      OM.Imm dyn      -> (C.Imm dyn, [])
      OM.Arith op     -> (rhsArith env' op (map (nodeToRhs env' cursor) prepre),  
                      map (,cursor) prepre)
      OM.Shift v      -> case prepre of
        [pre1] -> (nodeToRhs env' cursor' pre1, [(pre1,cursor')]) where cursor' = cursor - v
        _      -> error $ "shift has not 1 pre!" ++ show idxInst ++  show prepre
      OM.LoadIndex ax -> (codecLoadIndex cursor !! axisIndex ax, [])
      OM.LoadSize  ax -> (codecLoadSize  !! axisIndex ax, [])
      OM.Reduce op    -> let fname = T.pack ("om_reduce_" ++ map toLower (show op)) in
        (C.FuncCallStd fname (map creatVar prepre), [])
      OM.Broadcast    -> let fname = "om_broadcast" in
        (C.FuncCallStd fname (map creatVar prepre), [])
      _               -> (C.CommentExpr ("TODO : " ++ showT inst) (C.toDyn (42::Int)), [])

    -- reference to the local variable holding node idx at the cursor
    nodeToRhs env' cursor idx = C.VarExpr $ C.Var C.UnknownType $ nodeNameCursored env' idx cursor


    -- neighbours of a node, split into value nodes and instruction nodes
    -- (preVal, sucVal and sucInst are currently unused)
    preVal  = filterVal  . FGL.pre graph
    preInst = filterInst . FGL.pre graph
    sucVal  = filterVal  . FGL.suc graph
    sucInst = filterInst . FGL.suc graph

    -- "shiwake" (Japanese: sorting out) labels each index with either its
    -- DynValue (value node) or its instruction (instruction node)
    filterVal  = concat . map (\(i,(xs,ys))-> map(i,)xs) . shiwake
    filterInst = concat . map (\(i,(xs,ys))-> map(i,)ys) . shiwake
    shiwake indices = 
      map (\idx -> (idx,) $ case FGL.lab graph idx of
              Just (OM.NValue dval _) -> ([dval], [])
              Just (OM.NInst  inst _) -> ([], [inst])
              Nothing                 -> error $ "not in graph:" ++ show idx) $
      indices 

    idxSize = FGL.noNodes graph

    -- node indices that belong to this sub-kernel
    allIdxSet = Set.unions [inputIdxSet, outputIdxSet, calcIdxSet]
    inputIdxSet  = Set.fromList $ V.toList $ Plan.inputIdxs  subker
    outputIdxSet = Set.fromList $ V.toList $ Plan.outputIdxs subker
    calcIdxSet = Set.fromList $ V.toList $ Plan.calcIdxs subker    

    graph = Plan.dataflow subker

-- | The C type of a value: the unit type for scalars, and the target
--   language's container type ('containerType') for arrays.
mkCppType :: Opt.Ready v g => Env v g -> DVal.DynValue -> C.TypeRep
mkCppType env (DVal.DynValue rlm c) = case rlm of
  Realm.Scalar -> C.UnitType c
  Realm.Array  -> containerType env c

-- | The array container for elements of type @c@: @std::vector@ for C++
--   and the @thrust::thrust_vector@ wrapper (see 'thrustVectorLib') for
--   CUDA.
containerType :: Env v g -> TypeRep -> C.TypeRep
containerType (Env setup _) c = C.TemplateType vectorTemplate [C.UnitType c]
  where
    vectorTemplate = case Native.language setup of
      Native.CPlusPlus -> "std::vector"
      Native.CUDA      -> "thrust::thrust_vector"


-- | The raw (pointer-based) C type of a value, used for example for the
--   arguments of a CUDA kernel: the unit type for scalars, and a pointer
--   to the element type for arrays ('containerRawType').
mkCudaRawType :: Opt.Ready v g => Env v g -> DVal.DynValue -> C.TypeRep
mkCudaRawType env (DVal.DynValue rlm elemType) = case rlm of
  Realm.Array  -> containerRawType env elemType
  Realm.Scalar -> C.UnitType elemType

-- | Raw pointer to elements of type @c@.  Both target languages currently
--   use the same representation.
containerRawType :: Env v g -> TypeRep -> C.TypeRep
containerRawType (Env setup _) c =
  case Native.language setup of
    Native.CUDA      -> elemPtr
    Native.CPlusPlus -> elemPtr
  where
    elemPtr = C.PtrOf (C.UnitType c)



-- | The name used for a graph node independent of any cursor: @a@
--   followed by the node index, e.g. @a12@.
nodeNameUniversal :: FGL.Node -> Name
nodeNameUniversal = mkName . ("a" ++) . showT



-- | The name of the local variable holding node @idx@ evaluated at
--   @cursor@: @a<idx>_<cursor>@, with the cursor rendered by
--   'cursorToText'.
nodeNameCursored :: Opt.Ready v g => Env v g ->  FGL.Node -> v g -> Name
nodeNameCursored env idx cursor =
  mkName $ T.concat ["a", showT idx, "_", cursorToText env cursor]

-- | Render a cursor as identifier-safe text.  Components are joined with
--   @_@; within each component digits are kept, @-@ becomes @m@, @.@
--   becomes @d@, and any other character becomes @k@.
cursorToText :: Opt.Ready v g => Env v g ->  v g -> T.Text
cursorToText _ cursor = F.foldl1 joinUnderscore pieces
  where
    pieces = compose (\ax -> T.map escapeChar (showT (cursor ! ax)))
    joinUnderscore l r = l ++ "_" ++ r
    escapeChar ch
      | isDigit ch = ch
      | otherwise  = case ch of
          '-' -> 'm'
          '.' -> 'd'
          _   -> 'k'



-- | Lift an 'Int' into an immediate Claris expression.
intImm :: Int -> C.Expr
intImm n = C.toDyn n

-- | The Claris type corresponding to a Haskell 'Int'.
tInt :: C.TypeRep
tInt = C.typeOf (undefined :: Int)

-- | The type used for flat loop counters and addresses.  NOTE: despite its
--   name this is the same type as 'tInt' (Haskell 'Int'), not an unsigned
--   @size_t@.
tSizet :: C.TypeRep
tSizet = tInt

-- | A variable reference of unknown type named by the given text.  It is
--   also used to splice raw expressions (such as CUDA thread-index
--   arithmetic) into the generated code.
mkVarExpr :: Text -> C.Expr
mkVarExpr txt = C.VarExpr (C.Var C.UnknownType (mkName txt))

-- | The Claris type of @()@, used as the return type of generated
--   procedures.
tVoid :: C.TypeRep
tVoid = C.typeOf ()

-- | @thrust::host_vector@ with 'tInt' elements.
tHostVecInt :: C.TypeRep
tHostVecInt = C.TemplateType "thrust::host_vector" [tInt]

-- | @thrust::device_vector@ with 'tInt' elements.
tDeviceVecInt :: C.TypeRep
tDeviceVecInt = C.TemplateType "thrust::device_vector" [tInt]

-- | Render an arithmetic operator applied to its argument expressions.
--
--   Operators with a direct C counterpart become infix/prefix operators,
--   and most others become calls to the standard math functions.  @max@,
--   @min@ and @abs@ are @std::@-qualified for C++ and left unqualified
--   for CUDA.  A cast becomes a C-style cast when 'typeRepDB' knows the
--   target type, and is dropped otherwise.  Any other operator/arity
--   combination falls back to calling the lower-cased operator name.
rhsArith :: Opt.Ready v g => Env v g -> Arith.Operator -> [C.Expr] -> C.Expr
rhsArith (Env setup _) op argExpr = case (op, argExpr) of
  (Arith.Identity, [x]) ->  x
  (Arith.Add    , [x,y]) -> C.Op2Infix "+" x y
  (Arith.Sub    , [x,y]) -> C.Op2Infix "-" x y
  (Arith.Neg    , [x]) -> C.Op1Prefix "-" x 
  (Arith.Mul    , [x,y]) -> C.Op2Infix "*" x y
  (Arith.Div    , [x,y]) -> C.Op2Infix "/" x y
  (Arith.Mod    , [x,y]) -> C.Op2Infix "%" x y  
  (Arith.Inv    , [x]) -> C.Op1Prefix "1/" x 
  (Arith.Not    , [x]) -> C.Op1Prefix "!" x   
  (Arith.And    , [x,y]) -> C.Op2Infix "&&" x y  
  (Arith.Or     , [x,y]) -> C.Op2Infix "||" x y  
  (Arith.EQ     , [x,y]) -> C.Op2Infix "==" x y  
  (Arith.NE     , [x,y]) -> C.Op2Infix "!=" x y    
  (Arith.LT     , [x,y]) -> C.Op2Infix "<" x y    
  (Arith.LE     , [x,y]) -> C.Op2Infix "<=" x y    
  (Arith.GT     , [x,y]) -> C.Op2Infix ">" x y    
  (Arith.GE     , [x,y]) -> C.Op2Infix ">=" x y    
  (Arith.Select , [x,y,z]) -> C.Op3Infix "?" ":" x y z   
  (Arith.Max    , [x,y])  -> C.FuncCallStd (nmsp "std::max" "max") [x,y]
  (Arith.Min    , [x,y])  -> C.FuncCallStd (nmsp "std::min" "min") [x,y]
  -- In C++ an unqualified @abs@ can resolve to the C library's
  -- @int abs(int)@ and silently truncate floating-point arguments;
  -- @std::abs@ selects the overload matching the argument type.
  (Arith.Abs    , [x])  -> C.FuncCallStd (nmsp "std::abs" "abs") [x]
  (Arith.Sqrt   , [x])  -> C.FuncCallStd "sqrt" [x]
  (Arith.Exp    , [x])  -> C.FuncCallStd "exp" [x]
  (Arith.Log    , [x])  -> C.FuncCallStd "log" [x]
  (Arith.Sin    , [x])  -> C.FuncCallStd "sin" [x]
  (Arith.Cos    , [x])  -> C.FuncCallStd "cos" [x]
  (Arith.Tan    , [x])  -> C.FuncCallStd "tan" [x]
  (Arith.Asin   , [x])  -> C.FuncCallStd "asin" [x]
  (Arith.Acos   , [x])  -> C.FuncCallStd "acos" [x]
  (Arith.Atan   , [x])  -> C.FuncCallStd "atan" [x]
  (Arith.Atan2  , [x,y])  -> C.FuncCallStd "atan2" [x,y]
  (Arith.Cast tr, [x])  -> case typeRepDB tr of
                                Just typeRepStr -> C.FuncCallStd ("(" ++ typeRepStr ++ ")") [x]
                                _               -> x
  _ -> C.FuncCallStd (T.map toLower $ showT op) argExpr
  where
    -- pick the C++ spelling or the CUDA spelling of a function name
    nmsp a b = case Native.language setup of
      Native.CPlusPlus -> a
      Native.CUDA      -> b

-- | The runtime helper functions emitted into the generated source file:
--   @om_broadcast@ and the reductions @om_reduce_sum@, @om_reduce_min@ and
--   @om_reduce_max@.  The C++ version works on @std::vector@; the CUDA
--   version works on @thrust::thrust_vector@ and reduces on the device.
library :: Opt.Ready v g => Env v g -> [C.Statement]
library (Env setup _) = (:[]) $ C.Exclusive C.SourceFile $ C.RawStatement $ lib
  where
    lib = case Native.language setup of
      Native.CPlusPlus -> cpuLib
      Native.CUDA      -> gpuLib
    -- use draft.cpp to generate library
    cpuLib = "\ntemplate <class T> T om_broadcast (const T& x) {\n  return x;\n}\ntemplate <class T> T om_reduce_sum (const std::vector<T> &xs) {\n  T ret = 0;\n  for (int i = 0; i < xs.size(); ++i) ret+=xs[i];\n  return ret;\n}\ntemplate <class T> T om_reduce_min (const std::vector<T> &xs) {\n  T ret = xs[0];\n  for (int i = 1; i < xs.size(); ++i) ret=std::min(ret,xs[i]);\n  return ret;\n}\ntemplate <class T> T om_reduce_max (const std::vector<T> &xs) {\n  T ret = xs[0];\n  for (int i = 1; i < xs.size(); ++i) ret=std::max(ret,xs[i]);\n  return ret;\n}\n\n"
    -- thrust::reduce accumulates in, and returns, the type of its init
    -- argument, so the init must be T(0) rather than the int literal 0;
    -- otherwise floating-point sums are truncated to int.
    gpuLib = "\ntemplate <class T> __device__ __host__ T om_broadcast (const T& x) {\n  return x;\n}\ntemplate <class T> T om_reduce_sum (const thrust::thrust_vector<T> &xs) {\n  return thrust::reduce(xs.device_vector().begin(), xs.device_vector().end(), T(0), thrust::plus<T>());\n}\ntemplate <class T> T om_reduce_min (const thrust::thrust_vector<T> &xs) {\n  return *(thrust::min_element(xs.device_vector().begin(), xs.device_vector().end()));\n}\ntemplate <class T> T om_reduce_max (const thrust::thrust_vector<T> &xs) {\n  return *(thrust::max_element(xs.device_vector().begin(), xs.device_vector().end()));\n}\n\n"


-- | Auxiliary files that generated CUDA code includes, as
--   (file name, file contents) pairs.
commonLibraries :: [(FilePath, Text)]
commonLibraries = [thrustWrapperHeader]
  where
    thrustWrapperHeader = ("om_thrust_vector.h", thrustVectorLib)

-- | Contents of @om_thrust_vector.h@.  It defines @thrust::thrust_vector@,
--   which holds both a host and a device copy of the data.  A flag
--   (@newer_@) records which copy is up to date, and the stale copy is
--   synchronized lazily before it is accessed.  The header also defines
--   @raw@ helpers returning plain pointers.
thrustVectorLib :: Text
thrustVectorLib = T.unlines
  [ "#pragma once"
  , ""
  , "#include <thrust/device_vector.h>"
  , "#include <thrust/host_vector.h>"
  , "#include <thrust/device_ptr.h>"
  , "#include <iostream>"
  , ""
  , "namespace thrust {"
  , ""
  , "template <class T>"
  , "class thrust_vector {"
  , "public:"
  , "  enum NewerFlag {"
  , "    kHost,"
  , "    kDevice,"
  , "    kBoth"
  , "  };"
  , "private:"
  , "  mutable thrust::host_vector<T> hv;"
  , "  mutable thrust::device_vector<T> dv;"
  , "  mutable NewerFlag newer_;"
  , "public:  "
  , ""
  , "  typedef typename thrust::host_vector<T>::size_type  size_type;"
  , "  typedef T value_type;"
  , ""
  , "  thrust_vector () : newer_(kBoth) {}"
  , "  thrust_vector (size_type n, const value_type &value = value_type()) :"
  , "    hv(n, value), dv(n, value), newer_(kBoth) {}"
    -- The copy inherits the source's freshness flag: claiming kBoth here
    -- would treat a stale host or device copy as current.
  , "  thrust_vector (const thrust_vector<T> &v) :"
  , "    hv(v.hv), dv(v.dv), newer_(v.newer_) {}"
  , ""
  , "  const size_type size() const { return hv.size(); }"
  , ""
  , "  void bring_device () const {"
  , "    if (kHost == newer_) {"
  , "      dv = hv;"
  , "      newer_ = kBoth;"
  , "    }"
  , "  }"
  , "  void bring_host () const {"
  , "    if (kDevice == newer_) {"
  , "      hv = dv;"
  , "      newer_ = kBoth;"
  , "    }"
  , "  }"
  , ""
  , "  thrust::host_vector<T> &host_vector() const {"
  , "    bring_host();"
  , "    return hv;"
  , "  }"
  , "  thrust::device_vector<T> &device_vector() const {"
  , "    bring_device();"
  , "    return dv;"
  , "  }"
  , "  "
  , "  const T& operator[](const size_type n) const {"
  , "    bring_host();"
  , "    return hv[n];"
  , "  }"
  , "  T& operator[](const size_type n) {"
  , "    bring_host();"
  , "    newer_ = kHost;"
  , "    return hv[n];"
  , "  }"
  , "  T* unsafe_raw_device () const {"
  , "    return thrust::raw_pointer_cast(&*dv.begin());"
  , "  }"
  , "  T* unsafe_raw_host () const {"
  , "    return &hv[0];"
  , "  }"
  , "  void unsafe_set_newer (NewerFlag n) {"
  , "    newer_ = n;"
  , "  }"
  , "  "
  , "  T* raw () const {"
  , "    bring_device();"
  , "    newer_ = kDevice;"
  , "    return thrust::raw_pointer_cast(&*dv.begin());"
  , "  }"
    -- Like raw(), the mutable device iterators may be written through, so
    -- the device copy must be marked as the newer one.
  , "  typename thrust::device_vector<T>::iterator device_begin () {"
  , "    bring_device();"
  , "    newer_ = kDevice;"
  , "    return dv.begin();"
  , "  }"
  , "  typename thrust::device_vector<T>::iterator device_end () {"
  , "    bring_device();"
  , "    newer_ = kDevice;"
  , "    return dv.end();"
  , "  }"
  , "};"
  , ""
  , "template<class T>"
  , "T* raw(thrust::device_vector<T> &dv) {"
  , "  return thrust::raw_pointer_cast(&*dv.begin());"
  , "}"
  , ""
  , "template<class T>"
  , "T* raw(thrust::host_vector<T> &hv) {"
  , "  return &hv[0];"
  , "}"
  , ""
  , "template<class T>"
  , "T* raw(const thrust::thrust_vector<T> &tv) {"
  , "  return tv.raw();"
  , "}"
  , ""
  , "}"
  ]