-- | Detection of flow primitives in types.
--
--   Rewrites types whose names are 'FatName's (a GHC name paired with a DDC
--   name) into types over plain DDC Core Flow names. Along the way it:
--
--   * maps recognised type constructors (Bool, Int, Int#, Word8..Word64,
--     Float, Double, Vector, Series, Sel1 and the tuple constructors) onto
--     the corresponding DDC Core Flow primitives,
--
--   * gives kind Rate to forall-bound variables that have been marked as
--     rate variables (i.e. used as rate arguments of Series or Sel1), and
--
--   * collapses names starting with @*@, @#@ and @Constraint@ to the Data
--     kind.
module Data.Array.Repa.Plugin.ToDDC.Detect.Type where
import DDC.Core.Compounds
import DDC.Core.Exp
import DDC.Core.Flow
import DDC.Core.Flow.Compounds
import Data.Array.Repa.Plugin.FatName
import Data.Array.Repa.Plugin.ToDDC.Detect.Base
import Control.Monad.State.Strict
import qualified DDC.Type.Sum   as Sum
import Data.List
import qualified Kind           as G
import qualified TyCon          as G
import qualified Var            as G


-- Bind -----------------------------------------------------------------------
-- Drop the GHC half of a binder's name (passing both halves to 'collect')
-- and detect inside the binder's type.
instance Detect Bind where
    detect b = case b of
        BName (FatName g d) t1 -> do
            collect d g
            t1' <- detect t1
            return $ BName d t1'

        BAnon t -> liftM BAnon (detect t)
        BNone t -> liftM BNone (detect t)


-- Bound ----------------------------------------------------------------------
-- Map recognised names onto DDC Core Flow primitives (as 'UPrim' bounds that
-- carry their DDC kind); any other name just loses its GHC half.
instance Detect Bound where
    detect u = case u of
        UName n@(FatName g d)
            -- Primitive type constructors.
            | Just g' <- matchPrim "Bool_" n
            -> makePrim g' (NamePrimTyCon PrimTyConBool) kData

            | Just g' <- matchPrim "Int_" n
            -> makePrim g' (NamePrimTyCon PrimTyConInt) kData

            -- The unboxed Int# maps to the same primitive as Int.
            | Just g' <- matchPrim "Int#_" n
            -> makePrim g' (NamePrimTyCon PrimTyConInt) kData

            | Just g' <- matchPrim "Word8_" n
            -> makePrim g' (NamePrimTyCon (PrimTyConWord 8)) kData

            | Just g' <- matchPrim "Word16_" n
            -> makePrim g' (NamePrimTyCon (PrimTyConWord 16)) kData

            | Just g' <- matchPrim "Word32_" n
            -> makePrim g' (NamePrimTyCon (PrimTyConWord 32)) kData

            | Just g' <- matchPrim "Word64_" n
            -> makePrim g' (NamePrimTyCon (PrimTyConWord 64)) kData

            | Just g' <- matchPrim "Float_" n
            -> makePrim g' (NamePrimTyCon (PrimTyConFloat 32)) kData

            | Just g' <- matchPrim "Double_" n
            -> makePrim g' (NamePrimTyCon (PrimTyConFloat 64)) kData

            -- Vectors, series and selectors.
            --
            -- Names whose GHC kind returns Constraint are skipped, so only a
            -- data-kinded "Vector" becomes the flow Vector.
            -- NOTE(review): presumably this keeps a type class whose name
            -- starts with "Vector" from being mistaken for the data type.
            | Just g' <- matchPrim "Vector_" n
            , not $ returnsConstraintKind g'
            -> makePrim g' (NameTyConFlow TyConFlowVector)
                           (kData `kFun` kData)

            -- Series are indexed by a rate and an element type.
            | Just g' <- matchPrim "Series_" n
            -> makePrim g' (NameTyConFlow TyConFlowSeries)
                           (kRate `kFun` kData `kFun` kData)

            -- Selectors relate two rates.
            | Just g' <- matchPrim "Sel1_" n
            -> makePrim g' (NameTyConFlow (TyConFlowSel 1))
                           (kRate `kFun` kRate `kFun` kData)

            -- N-tuples: (,)_ (,,)_ etc. The arity is one more than the number
            -- of commas. At least one comma is required: unit "()" is not a
            -- 1-tuple of kind Data -> Data, so it is left to the ordinary
            -- case below.
            | Just (str, g')          <- stringPrim n
            , '(' : rest              <- str
            , (commas, afterCommas)   <- span (== ',') rest
            , not (null commas)
            , ")_" `isPrefixOf` afterCommas
            , size                    <- length commas + 1
            -> do
                let k = foldr kFun kData (replicate size kData)
                makePrim g' (NameTyConFlow (TyConFlowTuple size)) k

            -- Anything else is an ordinary name.
            | otherwise
            -> do
                collect d g
                return $ UName d

        UIx ix -> return $ UIx ix

        UPrim (FatName g d) t -> do
            collect d g
            t' <- detect t
            return $ UPrim d t'


-- | If the name is a variable or constructor whose string starts with the
--   given prefix, return its GHC half.
matchPrim :: String -> FatName -> Maybe GhcName
matchPrim prefix n
    | Just (str, g) <- stringPrim n
    , prefix `isPrefixOf` str
    = Just g

    | otherwise
    = Nothing


-- | The string of a variable or constructor name, paired with its GHC half.
--   Other kinds of DDC names give 'Nothing'.
stringPrim :: FatName -> Maybe (String, GhcName)
stringPrim (FatName g (NameVar str)) = Just (str, g)
stringPrim (FatName g (NameCon str)) = Just (str, g)
stringPrim _                         = Nothing


-- | Build a primitive bound with the given DDC name and kind, first passing
--   the DDC and GHC names to 'collect'.
makePrim g d k = do
    collect d g
    return $ UPrim d k


-- | Whether the GHC kind of a variable or type constructor ends in
--   Constraint, as judged by GHC's 'G.returnsConstraintKind'. Any other sort
--   of GHC name is treated as not constraint-kinded.
returnsConstraintKind :: GhcName -> Bool
returnsConstraintKind g = case g of
    GhcNameVar v    -> G.returnsConstraintKind $ G.varType v
    GhcNameTyCon tc -> G.returnsConstraintKind $ G.tyConKind tc
    _               -> False


-- TyCon ----------------------------------------------------------------------
-- Sort, kind, witness and special constructors carry no names and are
-- rebuilt unchanged. For a bound constructor, a recognised primitive's own
-- DDC kind replaces the (detected) kind attached to the constructor.
instance Detect TyCon where
    detect tc = case tc of
        TyConSort    tc' -> return $ TyConSort    tc'
        TyConKind    tc' -> return $ TyConKind    tc'
        TyConWitness tc' -> return $ TyConWitness tc'
        TyConSpec    tc' -> return $ TyConSpec    tc'

        TyConBound u k -> do
            u' <- detect u
            k' <- detect k
            case u' of
                UPrim _ kPrim -> return $ TyConBound u' kPrim
                _             -> return $ TyConBound u' k'


-- Type -----------------------------------------------------------------------
instance Detect Type where
    detect tt
        -- Detect rate variables being applied to Series type constructors.
        | TApp t1 t2 <- tt
        , [ TCon (TyConBound (UName (FatName _ (NameCon str))) _)
          , TVar (UName (FatName _ n))
          , _ ] <- takeTApps tt
        , "Series_" `isPrefixOf` str
        = do
            setRateVar n
            t1' <- detect t1
            t2' <- detect t2
            return $ TApp t1' t2'

        -- Detect rate variables being applied to Sel1 type constructors.
        | TApp t1 t2 <- tt
        , [ TCon (TyConBound (UName (FatName _ (NameCon str))) _)
          , TVar (UName (FatName _ n1))
          , TVar (UName (FatName _ n2)) ] <- takeTApps tt
        , "Sel1_" `isPrefixOf` str
        = do
            setRateVar n1
            setRateVar n2
            t1' <- detect t1
            t2' <- detect t2
            return $ TApp t1' t2'

        -- Set kind of detected rate variables to Rate.
        | TForall b t <- tt
        = do
            -- Detect the body first: that is where setRateVar is called for
            -- Series / Sel1 rate arguments, so isRateVar below can see them.
            t' <- detect t
            b' <- detect b
            case b' of
                BName n _ -> do
                    rateVar <- isRateVar n
                    return $ if rateVar
                        then TForall (BName n kRate) t'
                        else TForall b' t'

                -- Rate variables are only ever recorded from named type
                -- variables, so anonymous and dummy binders are never rates
                -- and are kept as they are (previously this was a crash).
                _ -> return $ TForall b' t'

        -- Convert all kindy things (*, # and Constraint) to kData.
        | TCon (TyConBound (UName n) _) <- tt
        , Just (str, _) <- stringPrim n
        , any (`isPrefixOf` str) ["*_", "#_", "Constraint_"]
        = return $ TCon (TyConKind KiConData)

        -- Boilerplate traversal.
        | otherwise
        = case tt of
            TVar u      -> liftM TVar (detect u)
            TCon c      -> liftM TCon (detect c)
            -- Not reached in practice: the guard above handles every TForall.
            TForall b t -> liftM2 TForall (detect b) (detect t)
            TApp t1 t2  -> liftM2 TApp (detect t1) (detect t2)
            TSum ts -> do
                k   <- detect $ Sum.kindOfSum ts
                ts' <- liftM (Sum.fromList k) $ mapM detect $ Sum.toList ts
                return $ TSum ts'