{-|
  Copyright   :  (C) 2015-2016, University of Twente,
                     2016     , Myrtle Software Ltd
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  Christiaan Baaij

  Reductions of primitives

  Currently, it contains reductions for:

    * Clash.Sized.Vector.map
    * Clash.Sized.Vector.zipWith
    * Clash.Sized.Vector.traverse#
    * Clash.Sized.Vector.foldr
    * Clash.Sized.Vector.fold
    * Clash.Sized.Vector.dfold
    * Clash.Sized.Vector.(++)
    * Clash.Sized.Vector.head
    * Clash.Sized.Vector.tail
    * Clash.Sized.Vector.last
    * Clash.Sized.Vector.init
    * Clash.Sized.Vector.unconcatBitVector#
    * Clash.Sized.Vector.replicate
    * Clash.Sized.Vector.imap
    * Clash.Sized.Vector.dtfold
    * Clash.Sized.RTree.tfold

  Partially handles:

    * Clash.Sized.Vector.unconcat
    * Clash.Sized.Vector.transpose
-}

{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE ViewPatterns      #-}

module Clash.Normalize.PrimitiveReductions where

import qualified Control.Lens                     as Lens
import qualified Data.HashMap.Lazy                as HashMap
import qualified Data.Maybe                       as Maybe
import           Unbound.Generics.LocallyNameless (bind, embed, rec, rebind)

import           Clash.Core.DataCon               (DataCon, dataConInstArgTys)
import           Clash.Core.Literal               (Literal (..))
import           Clash.Core.Name
import           Clash.Core.Pretty                (showDoc)
import           Clash.Core.Term                  (Term (..), Pat (..))
import           Clash.Core.Type                  (LitTy (..), Type (..),
                                                   TypeView (..), coreView,
                                                   mkFunTy, mkTyConApp,
                                                   splitFunForallTy, tyView,
                                                   undefinedTy)
import           Clash.Core.TyCon                 (TyConName, tyConDataCons)
import           Clash.Core.TysPrim               (integerPrimTy, typeNatKind)
import           Clash.Core.Util                  (appendToVec, extractElems,
                                                   extractTElems, idToVar,
                                                   mkApps, mkRTree, mkVec,
                                                   termType)
import           Clash.Core.Var                   (Var (..))

import           Clash.Normalize.Types
import           Clash.Rewrite.Types
import           Clash.Rewrite.Util
import           Clash.Util

-- | Replace an application of the @Clash.Sized.Vector.zipWith@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.zipWith@
reduceZipWith :: Integer  -- ^ Length of the vector(s)
              -> Type -- ^ Type of the lhs of the function
              -> Type -- ^ Type of the rhs of the function
              -> Type -- ^ Type of the result of the function
              -> Term -- ^ The zipWith'd functions
              -> Term -- ^ The 1st vector argument
              -> Term -- ^ The 2nd vector argument
              -> NormalizeSession Term
reduceZipWith n lhsElTy rhsElTy resElTy fun lhsArg rhsArg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm lhsArg
    go tcm ty
  where
    -- Look through type synonyms / newtypes until a type constructor
    -- application is exposed.
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc)     <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let (varsL,elemsL) = second concat . unzip
                           $ extractElems consCon lhsElTy 'L' n lhsArg
            (varsR,elemsR) = second concat . unzip
                           $ extractElems consCon rhsElTy 'R' n rhsArg
            funApps        = zipWith (\l r -> mkApps fun [Left l,Left r]) varsL varsR
            lbody          = mkVec nilCon consCon resElTy n funApps
            -- The last binder of each extraction is dropped via 'init'.
            lb             = Letrec (bind (rec (init elemsL ++ init elemsR)) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceZipWith: argument does not have a vector type: " ++ showDoc ty

-- | Replace an application of the @Clash.Sized.Vector.map@ primitive on vectors
-- of a known length @n@, by the fully unrolled recursive "definition" of
-- @Clash.Sized.Vector.map@
reduceMap :: Integer  -- ^ Length of the vector
          -> Type -- ^ Argument type of the function
          -> Type -- ^ Result type of the function
          -> Term -- ^ The map'd function
          -> Term -- ^ The map'd over vector
          -> NormalizeSession Term
reduceMap n argElTy resElTy fun arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc)     <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let (vars,elems) = second concat . unzip
                         $ extractElems consCon argElTy 'A' n arg
            funApps      = map (fun `App`) vars
            lbody        = mkVec nilCon consCon resElTy n funApps
            lb           = Letrec (bind (rec (init elems)) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceMap: argument does not have a vector type: " ++ showDoc ty

-- | Replace an application of the @Clash.Sized.Vector.imap@ primitive on vectors
-- of a known length @n@, by the fully unrolled recursive "definition" of
-- @Clash.Sized.Vector.imap@
reduceImap :: Integer  -- ^ Length of the vector
           -> Type -- ^ Argument type of the function
           -> Type -- ^ Result type of the function
           -> Term -- ^ The imap'd function
           -> Term -- ^ The imap'd over vector
           -> NormalizeSession Term
reduceImap n argElTy resElTy fun arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc)     <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = do
        let (vars,elems) = second concat . unzip
                         $ extractElems consCon argElTy 'I' n arg
        -- The first argument type of the imap'd function is the index type.
        (Right idxTy:_,_) <- splitFunForallTy <$> termType tcm fun
        let (TyConApp idxTcNm _) = tyView idxTy
            nTv                  = string2InternalName "n"
            -- fromInteger# :: KnownNat n => Integer -> Index n
            idxFromIntegerTy     = ForAllTy (bind (TyVar nTv (embed typeNatKind))
                                                  (foldr mkFunTy
                                                         (mkTyConApp idxTcNm
                                                                     [VarTy typeNatKind nTv])
                                                         [integerPrimTy,integerPrimTy]))
            idxFromInteger       = Prim "Clash.Sized.Internal.Index.fromInteger#"
                                        idxFromIntegerTy
            -- One index literal per element position, 0 up to n-1.
            idxs                 = map (App (App (TyApp idxFromInteger (LitTy (NumTy n)))
                                                 (Literal (IntegerLiteral (toInteger n))))
                                         . Literal . IntegerLiteral . toInteger)
                                       [0..(n-1)]
            funApps              = zipWith (\i v -> App (App fun i) v) idxs vars
            lbody                = mkVec nilCon consCon resElTy n funApps
            lb                   = Letrec (bind (rec (init elems)) lbody)
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceImap: argument does not have a vector type: " ++ showDoc ty

-- | Replace an application of the @Clash.Sized.Vector.traverse#@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.traverse#@
reduceTraverse :: Integer  -- ^ Length of the vector
               -> Type -- ^ Element type of the argument vector
               -> Type -- ^ The type of the applicative
               -> Type -- ^ Element type of the result vector
               -> Term -- ^ The @Applicative@ dictionary
               -> Term -- ^ The function to traverse with
               -> Term -- ^ The argument vector
               -> NormalizeSession Term
reduceTraverse n aTy fTy bTy dict fun arg = do
    tcm <- Lens.view tcCache
    (TyConApp apDictTcNm _) <- tyView <$> termType tcm dict
    ty  <- termType tcm arg
    go tcm apDictTcNm ty
  where
    go tcm apDictTcNm (coreView tcm -> Just ty') = go tcm apDictTcNm ty'
    go tcm apDictTcNm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc)     <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let (Just apDictTc)    = HashMap.lookup (nameOcc apDictTcNm) tcm
            [apDictCon]        = tyConDataCons apDictTc
            (Just apDictIdTys) = dataConInstArgTys apDictCon [fTy]
            apDictIds          = zipWith Id (map string2InternalName
                                                 ["functorDict"
                                                 ,"pure"
                                                 ,"ap"
                                                 ,"apConstL"
                                                 ,"apConstR"])
                                            (map embed apDictIdTys)

            (TyConApp funcDictTcNm _) = tyView (head apDictIdTys)
            (Just funcDictTc)         = HashMap.lookup (nameOcc funcDictTcNm) tcm
            [funcDictCon]             = tyConDataCons funcDictTc
            (Just funcDictIdTys)      = dataConInstArgTys funcDictCon [fTy]
            funcDicIds                = zipWith Id (map string2InternalName
                                                        ["fmap"
                                                        ,"fmapConst"])
                                                   (map embed funcDictIdTys)

            apPat    = DataPat (embed apDictCon)   (rebind [] apDictIds)
            fnPat    = DataPat (embed funcDictCon) (rebind [] funcDicIds)

            -- Extract the 'pure' function from the Applicative dictionary
            pureTy = apDictIdTys!!1
            pureTm = Case dict pureTy [bind apPat (Var pureTy (string2InternalName "pure"))]

            -- Extract the '<*>' function from the Applicative dictionary
            apTy   = apDictIdTys!!2
            apTm   = Case dict apTy [bind apPat (Var apTy (string2InternalName "ap"))]

            -- Extract the Functor dictionary from the Applicative dictionary
            funcTy = (head apDictIdTys)
            funcTm = Case dict funcTy
                               [bind apPat (Var funcTy (string2InternalName "functorDict"))]

            -- Extract the 'fmap' function from the Functor dictionary
            fmapTy = (head funcDictIdTys)
            fmapTm = Case (Var funcTy (string2InternalName "functorDict")) fmapTy
                          [bind fnPat (Var fmapTy (string2InternalName "fmap"))]

            (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'T' n arg

            funApps = map (fun `App`) vars

            lbody   = mkTravVec vecTcNm nilCon consCon (idToVar (apDictIds!!1))
                                                       (idToVar (apDictIds!!2))
                                                       (idToVar (funcDicIds!!0))
                                                       bTy n funApps

            -- Bind the extracted dictionary members next to the vector
            -- element projections.
            lb      = Letrec (bind (rec ([((apDictIds!!0),embed funcTm)
                                         ,((apDictIds!!1),embed pureTm)
                                         ,((apDictIds!!2),embed apTm)
                                         ,((funcDicIds!!0),embed fmapTm)
                                         ] ++ init elems)) lbody)
        in  changed lb
    go _ _ ty = error $ $(curLoc) ++ "reduceTraverse: argument does not have a vector type: " ++ showDoc ty

-- | Create the traversable vector
--
-- e.g. for a length '2' input vector, we get
--
-- > (:>) <$> x0 <*> ((:>) <$> x1 <*> pure Nil)
mkTravVec :: TyConName -- ^ Vec tcon
          -> DataCon   -- ^ Nil con
          -> DataCon   -- ^ Cons con
          -> Term      -- ^ 'pure' term
          -> Term      -- ^ '<*>' term
          -> Term      -- ^ 'fmap' term
          -> Type      -- ^ 'b' ty
          -> Integer   -- ^ Length of the vector
          -> [Term]    -- ^ Elements of the vector
          -> Term
mkTravVec vecTc nilCon consCon pureTm apTm fmapTm bTy = go
  where
    go :: Integer -> [Term] -> Term
    -- No elements left: @pure Nil@
    go _ [] = mkApps pureTm [Right (mkTyConApp vecTc [LitTy (NumTy 0),bTy])
                            ,Left  (mkApps (Data nilCon)
                                           [Right (LitTy (NumTy 0))
                                           ,Right bTy
                                           ,Left  (Prim "_CO_" nilCoTy)])]

    -- @(:>) \<$\> x \<*\> rest@
    go n (x:xs) = mkApps apTm
      [Right (mkTyConApp vecTc [LitTy (NumTy (n-1)),bTy])
      ,Right (mkTyConApp vecTc [LitTy (NumTy n),bTy])
      ,Left (mkApps fmapTm
                    [Right bTy
                    ,Right (mkFunTy (mkTyConApp vecTc [LitTy (NumTy (n-1)),bTy])
                                    (mkTyConApp vecTc [LitTy (NumTy n),bTy]))
                    ,Left (mkApps (Data consCon)
                                  [Right (LitTy (NumTy n))
                                  ,Right bTy
                                  ,Right (LitTy (NumTy (n-1)))
                                  ,Left (Prim "_CO_" (consCoTy n))
                                  ])
                    ,Left x])
      ,Left (go (n-1) xs)]

    -- Type of the first constructor field of an instantiated Nil / Cons,
    -- used as the type of the "_CO_" primitive above.
    nilCoTy = head (Maybe.fromJust (dataConInstArgTys nilCon [(LitTy (NumTy 0))
                                                             ,bTy]))

    consCoTy n = head (Maybe.fromJust (dataConInstArgTys consCon
                                                         [(LitTy (NumTy n))
                                                         ,bTy
                                                         ,(LitTy (NumTy (n-1)))]))

-- | Replace an application of the @Clash.Sized.Vector.foldr@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.foldr@
reduceFoldr :: Integer  -- ^ Length of the vector
            -> Type -- ^ Element type of the argument vector
            -> Term -- ^ The function to fold with
            -> Term -- ^ The starting value
            -> Term -- ^ The argument vector
            -> NormalizeSession Term
reduceFoldr n aTy fun start arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon]  <- tyConDataCons vecTc
      = let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'G' n arg
            lbody        = foldr (\l r -> mkApps fun [Left l,Left r]) start vars
            lb           = Letrec (bind (rec (init elems)) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceFoldr: argument does not have a vector type: " ++ showDoc ty

-- | Replace an application of the @Clash.Sized.Vector.fold@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.fold@
--
-- The elements are combined as a balanced tree: the list of elements is split
-- in half recursively until single elements remain.
reduceFold :: Integer  -- ^ Length of the vector
           -> Type -- ^ Element type of the argument vector
           -> Term -- ^ The function to fold with
           -> Term -- ^ The argument vector
           -> NormalizeSession Term
reduceFold n aTy fun arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon]  <- tyConDataCons vecTc
      = let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'F' n arg
            lbody        = foldV vars
            lb           = Letrec (bind (rec (init elems)) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceFold: argument does not have a vector type: " ++ showDoc ty

    -- An empty list has no sensible result. Without this equation the
    -- splitting case below would split @[]@ into @([],[])@ and recurse
    -- forever, so fail loudly instead.
    foldV []  = error $ $(curLoc) ++ "reduceFold: cannot fold an empty vector"
    foldV [a] = a
    foldV as  = let (l,r) = splitAt (length as `div` 2) as
                    lF    = foldV l
                    rF    = foldV r
                in  mkApps fun [Left lF, Left rF]

-- | Replace an application of the @Clash.Sized.Vector.dfold@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.dfold@
reduceDFold :: Integer  -- ^ Length of the vector
            -> Type -- ^ Element type of the argument vector
            -> Term -- ^ Function to fold with
            -> Term -- ^ Starting value
            -> Term -- ^ The vector to fold
            -> NormalizeSession Term
reduceDFold n aTy fun start arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon]  <- tyConDataCons vecTc
      = do
        let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'D' n arg
        -- The second argument of the folding function is used to find the
        -- data constructor from which the size witnesses are built.
        (_ltv:Right snTy:_,_) <- splitFunForallTy <$> termType tcm fun
        let (TyConApp snatTcNm _) = tyView snTy
            (Just snatTc)         = HashMap.lookup (nameOcc snatTcNm) tcm
            [snatDc]              = tyConDataCons snatTc
            lbody                 = doFold (buildSNat snatDc) (n-1) vars
            lb                    = Letrec (bind (rec (init elems)) lbody)
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceDFold: argument does not have a vector type: " ++ showDoc ty

    -- Each step passes the current index @k@, both as a type-level literal
    -- and as a term built by @snDc@, then counts down for the rest.
    doFold _    _ []     = start
    doFold snDc k (x:xs) = mkApps fun
                                 [Right (LitTy (NumTy k))
                                 ,Left (snDc k)
                                 ,Left x
                                 ,Left (doFold snDc (k-1) xs)
                                 ]

-- | Replace an application of the @Clash.Sized.Vector.head@ primitive on
-- vectors of a known length @n@, by a projection of the first element of a
-- vector.
reduceHead :: Integer  -- ^ Length of the vector
           -> Type -- ^ Element type of the vector
           -> Term -- ^ The argument vector
           -> NormalizeSession Term
reduceHead n aTy vArg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm vArg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon]  <- tyConDataCons vecTc
      = let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'H' n vArg
            -- Only the very first projection is needed.
            lb           = Letrec (bind (rec [head elems]) (head vars))
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceHead: argument does not have a vector type: " ++ showDoc ty

-- | Replace an application of the @Clash.Sized.Vector.tail@ primitive on
-- vectors of a known length @n@, by a projection of the tail of a
-- vector.
reduceTail :: Integer  -- ^ Length of the vector
           -> Type -- ^ Element type of the vector
           -> Term -- ^ The argument vector
           -> NormalizeSession Term
reduceTail n aTy vArg =
    Lens.view tcCache >>= \tcm ->
      termType tcm vArg >>= project tcm
  where
    -- Unfold the type until a type constructor application shows up.
    project tcm (coreView tcm -> Just unfolded) = project tcm unfolded
    project tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc  <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon] <- tyConDataCons vecTc
      = let binders               = concatMap snd
                                      (extractElems consCon aTy 'L' n vArg)
            -- The second binder of the flattened list is the one projected.
            tailBind@(tailId,_)   = binders !! 1
        in  changed (Letrec (bind (rec [tailBind]) (idToVar tailId)))
    project _ otherTy =
      error $ $(curLoc) ++ "reduceTail: argument does not have a vector type: " ++ showDoc otherTy

-- | Replace an application of the @Clash.Sized.Vector.last@ primitive on
-- vectors of a known length @n@, by a projection of the last element of a
-- vector.
reduceLast :: Integer  -- ^ Length of the vector
           -> Type -- ^ Element type of the vector
           -> Term -- ^ The argument vector
           -> NormalizeSession Term
reduceLast n aTy vArg =
    Lens.view tcCache >>= \tcm ->
      termType tcm vArg >>= project tcm
  where
    project tcm (coreView tcm -> Just unfolded) = project tcm unfolded
    project tcm (tyView -> TyConApp vecTcNm _)
      | Just vecTc  <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon] <- tyConDataCons vecTc
      = let binderGroups = map snd (extractElems consCon aTy 'L' n vArg)
            -- First binder of the final group; only forced when n is not 0.
            (lastId,_)   = head (last binderGroups)
            undefTm      = mkApps (Prim "Clash.Transformations.undefined" undefinedTy)
                                  [Right aTy]
            projTm       = Letrec (bind (rec (init (concat binderGroups)))
                                        (idToVar lastId))
        in  changed (if n == 0 then undefTm else projTm)
    project _ otherTy =
      error $ $(curLoc) ++ "reduceLast: argument does not have a vector type: " ++ showDoc otherTy

-- | Replace an application of the @Clash.Sized.Vector.init@ primitive on
-- vectors of a known length @n@, by a projection of the init of a
-- vector.
reduceInit :: Integer  -- ^ Length of the vector
           -> Type -- ^ Element type of the vector
           -> Term -- ^ The argument vector
           -> NormalizeSession Term
reduceInit n aTy vArg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm vArg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc)     <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let (_,elems) = unzip
                      $ extractElems consCon aTy 'L' n vArg
        in case n of
             0 -> changed (mkApps (Prim "Clash.Transformations.undefined" undefinedTy) [Right aTy])
             -- The init of a one-element vector is the empty vector.
             1 -> changed (mkVec nilCon consCon aTy 0 [])
             _ -> let el = init elems
                      iv = mkVec nilCon consCon aTy (n-1) (map (idToVar . fst . head) el)
                      lb = rec (init (concat el))
                  in  changed (Letrec (bind lb iv))
    go _ ty = error $ $(curLoc) ++ "reduceInit: argument does not have a vector type: " ++ showDoc ty

-- | Replace an application of the @Clash.Sized.Vector.(++)@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.(++)@
reduceAppend :: Integer  -- ^ Length of the LHS arg
             -> Integer  -- ^ Length of the RHS arg
             -> Type -- ^ Element type of the vectors
             -> Term -- ^ The LHS argument
             -> Term -- ^ The RHS argument
             -> NormalizeSession Term
reduceAppend n m aTy lArg rArg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm lArg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon]  <- tyConDataCons vecTc
      = let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'C' n lArg
            -- Only the LHS is taken apart; its elements are put in front of
            -- the untouched RHS.
            lbody        = appendToVec consCon aTy rArg (n+m) vars
            lb           = Letrec (bind (rec (init elems)) lbody)
        in  changed lb
    go _ ty = error $ $(curLoc) ++ "reduceAppend: argument does not have a vector type: " ++ showDoc ty

-- | Replace an application of the @Clash.Sized.Vector.unconcat@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.unconcat@
--
-- Only the case where the inner vectors have length @0@ is implemented; any
-- other inner length is an 'error'.
reduceUnconcat :: Integer  -- ^ Length of the result vector
               -> Integer  -- ^ Length of the elements of the result vector
               -> Type -- ^ Element type
               -> Term -- ^ Argument vector
               -> NormalizeSession Term
reduceUnconcat n 0 aTy arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc)     <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let nilVec     = mkVec nilCon consCon aTy 0 []
            innerVecTy = mkTyConApp vecTcNm [LitTy (NumTy 0), aTy]
            -- A vector of @n@ empty vectors.
            retVec     = mkVec nilCon consCon innerVecTy n (replicate (fromInteger n) nilVec)
        in  changed retVec
    go _ ty = error $ $(curLoc) ++ "reduceUnconcat: argument does not have a vector type: " ++ showDoc ty

reduceUnconcat _ _ _ _ = error $ $(curLoc) ++ "reduceUnconcat: unimplemented"

-- | Replace an application of the @Clash.Sized.Vector.transpose@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.transpose@
--
-- Only the case where the inner vectors have length @0@ is implemented; any
-- other inner length is an 'error'.
reduceTranspose :: Integer  -- ^ Length of the result vector
                -> Integer  -- ^ Length of the elements of the result vector
                -> Type -- ^ Element type
                -> Term -- ^ Argument vector
                -> NormalizeSession Term
reduceTranspose n 0 aTy arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc)     <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let nilVec     = mkVec nilCon consCon aTy 0 []
            innerVecTy = mkTyConApp vecTcNm [LitTy (NumTy 0), aTy]
            -- A vector of @n@ empty vectors.
            retVec     = mkVec nilCon consCon innerVecTy n (replicate (fromInteger n) nilVec)
        in  changed retVec
    go _ ty = error $ $(curLoc) ++ "reduceTranspose: argument does not have a vector type: " ++ showDoc ty

reduceTranspose _ _ _ _ = error $ $(curLoc) ++ "reduceTranspose: unimplemented"

-- | Build a vector of length @n@ in which every element is the given term.
reduceReplicate :: Integer  -- ^ Length of the vector
                -> Type -- ^ Element type
                -> Type -- ^ Type inspected to find the vector constructors
                        -- (presumably the result type; confirm at call site)
                -> Term -- ^ Element
                -> NormalizeSession Term
reduceReplicate n aTy eTy arg = do
    tcm <- Lens.view tcCache
    go tcm eTy
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc)     <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [nilCon,consCon] <- tyConDataCons vecTc
      = let retVec = mkVec nilCon consCon aTy n (replicate (fromInteger n) arg)
        in  changed retVec
    go _ ty = error $ $(curLoc) ++ "reduceReplicate: argument does not have a vector type: " ++ showDoc ty

-- | Replace an application of the @Clash.Sized.Vector.dtfold@ primitive on
-- vectors of a known length @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.Vector.dtfold@
reduceDTFold :: Integer  -- ^ Length of the vector
             -> Type     -- ^ Element type of the argument vector
             -> Term     -- ^ Function to convert elements with
             -> Term     -- ^ Function to combine branches with
             -> Term     -- ^ The vector to fold
             -> NormalizeSession Term
reduceDTFold n aTy lrFun brFun arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp vecTcNm _)
      | (Just vecTc) <- HashMap.lookup (nameOcc vecTcNm) tcm
      , [_,consCon]  <- tyConDataCons vecTc
      = do
        -- Note that 2^n elements are extracted here.
        let (vars,elems) = second concat . unzip
                         $ extractElems consCon aTy 'T' (2^n) arg
        (_ltv:Right snTy:_,_) <- splitFunForallTy <$> termType tcm brFun
        let (TyConApp snatTcNm _) = tyView snTy
            (Just snatTc)         = HashMap.lookup (nameOcc snatTcNm) tcm
            [snatDc]              = tyConDataCons snatTc
            lbody                 = doFold (buildSNat snatDc) (n-1) vars
            lb                    = Letrec (bind (rec (init elems)) lbody)
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceDTFold: argument does not have a vector type: " ++ showDoc ty

    -- Single elements are converted with @lrFun@; longer lists are split
    -- after the first @2^k@ elements and the halves combined with @brFun@.
    doFold :: (Integer -> Term) -> Integer -> [Term] -> Term
    doFold _    _ [x] = mkApps lrFun [Left x]
    doFold snDc k xs  =
      let (xsL,xsR) = splitAt (2^k) xs
          k'        = k-1
          eL        = doFold snDc k' xsL
          eR        = doFold snDc k' xsR
      in  mkApps brFun [Right (LitTy (NumTy k))
                       ,Left  (snDc k)
                       ,Left  eL
                       ,Left  eR
                       ]

-- | Replace an application of the @Clash.Sized.RTree.tdfold@ primitive on
-- trees of a known depth @n@, by the fully unrolled recursive "definition"
-- of @Clash.Sized.RTree.tdfold@
reduceTFold :: Integer -- ^ Depth of the tree
            -> Type    -- ^ Element type of the argument tree
            -> Term    -- ^ Function to convert elements with
            -> Term    -- ^ Function to combine branches with
            -> Term    -- ^ The tree to fold
            -> NormalizeSession Term
reduceTFold n aTy lrFun brFun arg = do
    tcm <- Lens.view tcCache
    ty  <- termType tcm arg
    go tcm ty
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp treeTcNm _)
      | (Just treeTc)  <- HashMap.lookup (nameOcc treeTcNm) tcm
      , [lrCon,brCon]  <- tyConDataCons treeTc
      = do
        let (vars,elems) = extractTElems lrCon brCon aTy 'T' n arg
        (_ltv:Right snTy:_,_) <- splitFunForallTy <$> termType tcm brFun
        let (TyConApp snatTcNm _) = tyView snTy
            (Just snatTc)         = HashMap.lookup (nameOcc snatTcNm) tcm
            [snatDc]              = tyConDataCons snatTc
            lbody                 = doFold (buildSNat snatDc) (n-1) vars
            lb                    = Letrec (bind (rec elems) lbody)
        changed lb
    go _ ty = error $ $(curLoc) ++ "reduceTFold: argument does not have a tree type: " ++ showDoc ty

    -- Single elements are converted with @lrFun@; longer lists are halved
    -- and the halves combined with @brFun@.
    doFold _    _ [x] = mkApps lrFun [Left x]
    doFold snDc k xs  =
      let (xsL,xsR) = splitAt (length xs `div` 2) xs
          k'        = k-1
          eL        = doFold snDc k' xsL
          eR        = doFold snDc k' xsR
      in  mkApps brFun [Right (LitTy (NumTy k))
                       ,Left  (snDc k)
                       ,Left  eL
                       ,Left  eR
                       ]

-- | Build a tree of depth @n@ in which each of the @2^n@ elements is the
-- given term.
reduceTReplicate :: Integer -- ^ Depth of the tree
                 -> Type    -- ^ Element type
                 -> Type    -- ^ Result type
                 -> Term    -- ^ Element
                 -> NormalizeSession Term
reduceTReplicate n aTy eTy arg = do
    tcm <- Lens.view tcCache
    go tcm eTy
  where
    go tcm (coreView tcm -> Just ty') = go tcm ty'
    go tcm (tyView -> TyConApp treeTcNm _)
      | (Just treeTc) <- HashMap.lookup (nameOcc treeTcNm) tcm
      , [lrCon,brCon] <- tyConDataCons treeTc
      = let retVec = mkRTree lrCon brCon aTy n (replicate (2^n) arg)
        in  changed retVec
    -- The type inspected here is a tree type, so say so in the message.
    go _ ty = error $ $(curLoc) ++ "reduceTReplicate: argument does not have a tree type: " ++ showDoc ty

-- | Apply the given data constructor to the type-level literal @i@ and to a
-- term-level literal carrying the same number. Which literal constructor is
-- used for the term depends on the GHC version.
buildSNat :: DataCon -> Integer -> Term
buildSNat snatDc i =
  mkApps (Data snatDc)
         [Right (LitTy (NumTy i))
#if MIN_VERSION_ghc(8,2,0)
         ,Left (Literal (NaturalLiteral (toInteger i)))
#else
         ,Left (Literal (IntegerLiteral (toInteger i)))
#endif
         ]