module Data.Array.Accelerate.Analysis.Match (
MatchAcc,
(:=:)(..),
matchOpenAcc, matchPreOpenAcc,
matchOpenAfun, matchPreOpenAfun,
matchOpenExp, matchPreOpenExp,
matchOpenFun, matchPreOpenFun,
matchPrimFun, matchPrimFun',
matchIdx, matchTupleType,
matchIntegralType, matchFloatingType, matchNumType, matchScalarType,
HashAcc,
hashPreOpenAcc, hashOpenAcc,
hashPreOpenExp, hashOpenExp,
hashPreOpenFun,
) where
import Prelude hiding ( exp )
import Data.Maybe
import Data.Typeable
import Data.Hashable
import System.Mem.StableName
import System.IO.Unsafe ( unsafePerformIO )
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) )
import Data.Array.Accelerate.Tuple hiding ( Tuple )
import qualified Data.Array.Accelerate.Tuple as Tuple
-- | Witness that two types are the same.
--
-- The only constructor, 'REFL', has type @a :=: a@, so pattern matching on it
-- brings the equality of the two type indices into scope for the type checker.
data (:=:) a b where
  REFL :: a :=: a

-- Stand-alone deriving is used because the declaration is written in GADT
-- syntax.
deriving instance Show (a :=: b)
-- | The type of a function that compares two array terms living in the same
-- array environment. On success it returns a proof that the two result types
-- are equal; otherwise 'Nothing'.
type MatchAcc acc = forall aenv a b. acc aenv a -> acc aenv b -> Maybe (a :=: b)
-- | Compare two 'OpenAcc' terms by unwrapping them and delegating to
-- 'matchPreOpenAcc', passing 'matchOpenAcc' and 'hashOpenAcc' as the matcher
-- and hasher for nested array terms.
matchOpenAcc :: OpenAcc aenv s -> OpenAcc aenv t -> Maybe (s :=: t)
matchOpenAcc (OpenAcc lhs) (OpenAcc rhs)
  = matchPreOpenAcc matchOpenAcc hashOpenAcc lhs rhs
-- | Structural comparison of two array computations over the same array
-- environment.
--
-- The first argument compares nested array terms and the second hashes them;
-- the hasher is only passed on to the expression and function matchers.
-- Returns @Just REFL@ when both terms use the same constructor and all
-- compared components match, and 'Nothing' otherwise.
--
-- The order of the pattern guards matters: each successful @Just REFL@ match
-- refines the types that later guards rely on.
matchPreOpenAcc
:: forall acc aenv s t.
MatchAcc acc
-> HashAcc acc
-> PreOpenAcc acc aenv s
-> PreOpenAcc acc aenv t
-> Maybe (s :=: t)
matchPreOpenAcc matchAcc hashAcc = match
where
-- Scalar function and expression matchers, specialised to this @acc@.
matchFun :: PreOpenFun acc env aenv u -> PreOpenFun acc env aenv v -> Maybe (u :=: v)
matchFun = matchPreOpenFun matchAcc hashAcc
matchExp :: PreOpenExp acc env aenv u -> PreOpenExp acc env aenv v -> Maybe (u :=: v)
matchExp = matchPreOpenExp matchAcc hashAcc
match :: PreOpenAcc acc aenv s -> PreOpenAcc acc aenv t -> Maybe (s :=: t)
-- Let bindings: the bound terms must match before the bodies are compared.
match (Alet x1 a1) (Alet x2 a2)
| Just REFL <- matchAcc x1 x2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Avar v1) (Avar v2)
= matchIdx v1 v2
-- The tuple representations are matched component-wise; 'gcast' then
-- transfers the equality to the result types.
match (Atuple t1) (Atuple t2)
| Just REFL <- matchAtuple matchAcc t1 t2
= gcast REFL
match (Aprj ix1 t1) (Aprj ix2 t2)
| Just REFL <- matchAcc t1 t2
, Just REFL <- matchTupleIdx ix1 ix2
= Just REFL
match (Apply f1 a1) (Apply f2 a2)
| Just REFL <- matchPreOpenAfun matchAcc f1 f2
, Just REFL <- matchAcc a1 a2
= Just REFL
-- Foreign functions are compared by the hash of their stable names; the
-- second constructor field is ignored.
-- NOTE(review): 'hashStableName' values are compared rather than the stable
-- names themselves; confirm that hash collisions are acceptable here.
match (Aforeign ff1 _ a1) (Aforeign ff2 _ a2)
| Just REFL <- matchAcc a1 a2,
unsafePerformIO $ do
sn1 <- makeStableName ff1
sn2 <- makeStableName ff2
return $! hashStableName sn1 == hashStableName sn2
= gcast REFL
match (Acond p1 t1 e1) (Acond p2 t2 e2)
| Just REFL <- matchExp p1 p2
, Just REFL <- matchAcc t1 t2
, Just REFL <- matchAcc e1 e2
= Just REFL
-- Embedded arrays are compared by 'matchArrays', driven by the array
-- representations of the two result types.
match (Use a1) (Use a2)
| Just REFL <- matchArrays (arrays (undefined::s)) (arrays (undefined::t)) a1 a2
= gcast REFL
match (Unit e1) (Unit e2)
| Just REFL <- matchExp e1 e2
= Just REFL
match (Reshape sh1 a1) (Reshape sh2 a2)
| Just REFL <- matchExp sh1 sh2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Generate sh1 f1) (Generate sh2 f2)
| Just REFL <- matchExp sh1 sh2
, Just REFL <- matchFun f1 f2
= Just REFL
match (Transform sh1 ix1 f1 a1) (Transform sh2 ix2 f2 a2)
| Just REFL <- matchExp sh1 sh2
, Just REFL <- matchFun ix1 ix2
, Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc a1 a2
= Just REFL
-- For Replicate and Slice the first constructor field is not inspected;
-- only the index expression and the source array are compared.
match (Replicate _ ix1 a1) (Replicate _ ix2 a2)
| Just REFL <- matchExp ix1 ix2
, Just REFL <- matchAcc a1 a2
= gcast REFL
match (Slice _ a1 ix1) (Slice _ a2 ix2)
| Just REFL <- matchAcc a1 a2
, Just REFL <- matchExp ix1 ix2
= gcast REFL
match (Map f1 a1) (Map f2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (ZipWith f1 a1 b1) (ZipWith f2 a2 b2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc a1 a2
, Just REFL <- matchAcc b1 b2
= Just REFL
-- Reductions and scans: combining function, optional seed expression, then
-- the array argument(s).
match (Fold f1 z1 a1) (Fold f2 z2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchExp z1 z2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Fold1 f1 a1) (Fold1 f2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (FoldSeg f1 z1 a1 s1) (FoldSeg f2 z2 a2 s2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchExp z1 z2
, Just REFL <- matchAcc a1 a2
, Just REFL <- matchAcc s1 s2
= Just REFL
match (Fold1Seg f1 a1 s1) (Fold1Seg f2 a2 s2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc a1 a2
, Just REFL <- matchAcc s1 s2
= Just REFL
match (Scanl f1 z1 a1) (Scanl f2 z2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchExp z1 z2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Scanl' f1 z1 a1) (Scanl' f2 z2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchExp z1 z2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Scanl1 f1 a1) (Scanl1 f2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Scanr f1 z1 a1) (Scanr f2 z2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchExp z1 z2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Scanr' f1 z1 a1) (Scanr' f2 z2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchExp z1 z2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Scanr1 f1 a1) (Scanr1 f2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Permute f1 d1 p1 a1) (Permute f2 d2 p2 a2)
| Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc d1 d2
, Just REFL <- matchFun p1 p2
, Just REFL <- matchAcc a1 a2
= Just REFL
match (Backpermute sh1 ix1 a1) (Backpermute sh2 ix2 a2)
| Just REFL <- matchExp sh1 sh2
, Just REFL <- matchFun ix1 ix2
, Just REFL <- matchAcc a1 a2
= Just REFL
-- Stencils: the pattern type annotations name the element types so that
-- 'eltType' can supply the representation needed by 'matchBoundary'. The
-- boundaries are compared only after the arrays have matched.
match (Stencil f1 b1 (a1 :: acc aenv (Array sh1 e1)))
(Stencil f2 b2 (a2 :: acc aenv (Array sh2 e2)))
| Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc a1 a2
, matchBoundary (eltType (undefined::e1)) b1 b2
= Just REFL
-- Naming in this clause: b1/a1 and b2/a2 are the two boundary/array pairs
-- of the FIRST term; the primed names belong to the second term.
match (Stencil2 f1 b1 (a1 :: acc aenv (Array sh1 e1 )) b2 (a2 :: acc aenv (Array sh2 e2 )))
(Stencil2 f2 b1' (a1' :: acc aenv (Array sh1' e1')) b2' (a2':: acc aenv (Array sh2' e2')))
| Just REFL <- matchFun f1 f2
, Just REFL <- matchAcc a1 a1'
, Just REFL <- matchAcc a2 a2'
, matchBoundary (eltType (undefined::e1)) b1 b1'
, matchBoundary (eltType (undefined::e2)) b2 b2'
= Just REFL
-- Different constructors, or a failed guard above.
match _ _
= Nothing
-- | Compare two tuples of array terms component by component, using the
-- supplied matcher for the individual components. Both tuples must have the
-- same number of components for the comparison to succeed.
matchAtuple
    :: MatchAcc acc
    -> Atuple (acc aenv) s
    -> Atuple (acc aenv) t
    -> Maybe (s :=: t)
matchAtuple _ NilAtup NilAtup = Just REFL
matchAtuple eqAcc (SnocAtup restL lastL) (SnocAtup restR lastR)
  | Just REFL <- matchAtuple eqAcc restL restR   -- leading components first
  , Just REFL <- eqAcc lastL lastR
  = Just REFL
matchAtuple _ _ _ = Nothing
-- | Compare two array functions over 'OpenAcc' terms; a specialisation of
-- 'matchPreOpenAfun' that uses 'matchOpenAcc' for the function bodies.
matchOpenAfun :: OpenAfun aenv s -> OpenAfun aenv t -> Maybe (s :=: t)
matchOpenAfun lhs rhs = matchPreOpenAfun matchOpenAcc lhs rhs
-- | Compare two array-level functions. Bodies are compared with the supplied
-- matcher. For lambdas the types of the bound variables are compared first
-- (through 'gcast'), so that the two bodies live in the same environment
-- before the comparison recurses.
matchPreOpenAfun :: MatchAcc acc -> PreOpenAfun acc aenv s -> PreOpenAfun acc aenv t -> Maybe (s :=: t)
matchPreOpenAfun eqAcc (Abody lhs) (Abody rhs) = eqAcc lhs rhs
matchPreOpenAfun eqAcc (Alam lhs) (Alam rhs)
  | Just REFL <- sameBinder lhs rhs
  , Just REFL <- matchPreOpenAfun eqAcc lhs rhs
  = Just REFL
  where
    -- Only the types of the arguments are used: the equality of the
    -- innermost environment entries is obtained from their 'Typeable'
    -- evidence.
    sameBinder
        :: (Arrays a, Arrays b)
        => PreOpenAfun acc (aenv, a) f
        -> PreOpenAfun acc (aenv, b) g
        -> Maybe (a :=: b)
    sameBinder _ _ = gcast REFL
matchPreOpenAfun _ _ _ = Nothing
-- | Decide whether two stencil boundary conditions are the same. Constant
-- boundaries are equal when their values are equal according to
-- 'matchConst'; the remaining boundary modes are equal only to themselves.
matchBoundary :: TupleType e -> Boundary e -> Boundary e -> Bool
matchBoundary ty lhs rhs =
  case (lhs, rhs) of
    (Constant s, Constant t) -> matchConst ty s t
    (Wrap,       Wrap)       -> True
    (Clamp,      Clamp)      -> True
    (Mirror,     Mirror)     -> True
    _                        -> False
-- | Compare two collections of concrete arrays, walking their structural
-- representations in lock step. Individual arrays are treated as equal when
-- the stable names of their payloads hash to the same value; no array
-- elements are inspected.
matchArrays :: ArraysR s -> ArraysR t -> s -> t -> Maybe (s :=: t)
matchArrays ArraysRunit ArraysRunit () () = Just REFL
matchArrays (ArraysRpair reprL1 reprL2) (ArraysRpair reprR1 reprR2) (arrsL1, arrsL2) (arrsR1, arrsR2)
  | Just REFL <- matchArrays reprL1 reprR1 arrsL1 arrsR1
  , Just REFL <- matchArrays reprL2 reprR2 arrsL2 arrsR2
  = Just REFL
matchArrays ArraysRarray ArraysRarray (Array _ payloadL) (Array _ payloadR)
  | samePayload = gcast REFL
  where
    -- Identity test on the array payloads (the array extents are ignored).
    -- NOTE(review): this compares 'hashStableName' results, not the stable
    -- names themselves; confirm that hash collisions are acceptable here.
    samePayload :: Bool
    samePayload = unsafePerformIO $ do
      nameL <- makeStableName payloadL
      nameR <- makeStableName payloadR
      return $! hashStableName nameL == hashStableName nameR
matchArrays _ _ _ _ = Nothing
-- | Compare two scalar expressions whose embedded array terms are 'OpenAcc';
-- a specialisation of 'matchPreOpenExp'.
matchOpenExp :: OpenExp env aenv s -> OpenExp env aenv t -> Maybe (s :=: t)
matchOpenExp lhs rhs = matchPreOpenExp matchOpenAcc hashOpenAcc lhs rhs
-- | Structural comparison of two scalar expressions over the same scalar and
-- array environments.
--
-- The first argument compares embedded array terms. The second hashes them
-- and is used (through 'commutes') to put the operands of commutative
-- primitives into a canonical order before they are compared. Returns
-- @Just REFL@ when both expressions use the same constructor and all compared
-- components match, and 'Nothing' otherwise.
--
-- The order of the pattern guards matters: each successful @Just REFL@ match
-- refines the types that later guards rely on.
matchPreOpenExp
:: forall acc env aenv s t.
MatchAcc acc
-> HashAcc acc
-> PreOpenExp acc env aenv s
-> PreOpenExp acc env aenv t
-> Maybe (s :=: t)
matchPreOpenExp matchAcc hashAcc = match
where
-- The worker is polymorphic in both environments because it recurses under
-- binders (e.g. the body of a Let).
match :: forall env' aenv' s' t'. PreOpenExp acc env' aenv' s' -> PreOpenExp acc env' aenv' t' -> Maybe (s' :=: t')
match (Let x1 e1) (Let x2 e2)
| Just REFL <- match x1 x2
, Just REFL <- match e1 e2
= Just REFL
match (Var v1) (Var v2)
= matchIdx v1 v2
-- Foreign functions are compared by the hash of their stable names; the
-- second constructor field is ignored.
-- NOTE(review): 'hashStableName' values are compared rather than the stable
-- names themselves; confirm that hash collisions are acceptable here.
match (Foreign ff1 _ e1) (Foreign ff2 _ e2)
| Just REFL <- match e1 e2
, unsafePerformIO $ do
sn1 <- makeStableName ff1
sn2 <- makeStableName ff2
return $! hashStableName sn1 == hashStableName sn2
= gcast REFL
-- Constants: the representation types must agree before the values can be
-- compared; 'gcast' then relates the surface types.
match (Const c1) (Const c2)
| Just REFL <- matchTupleType (eltType (undefined::s')) (eltType (undefined::t'))
, matchConst (eltType (undefined::s')) c1 c2
= gcast REFL
match (Tuple t1) (Tuple t2)
| Just REFL <- matchTuple matchAcc hashAcc t1 t2
= gcast REFL
match (Prj ix1 t1) (Prj ix2 t2)
| Just REFL <- match t1 t2
, Just REFL <- matchTupleIdx ix1 ix2
= Just REFL
match IndexAny IndexAny
= gcast REFL
match IndexNil IndexNil
= Just REFL
match (IndexCons sl1 a1) (IndexCons sl2 a2)
| Just REFL <- match sl1 sl2
, Just REFL <- match a1 a2
= Just REFL
match (IndexHead sl1) (IndexHead sl2)
| Just REFL <- match sl1 sl2
= Just REFL
match (IndexTail sl1) (IndexTail sl2)
| Just REFL <- match sl1 sl2
= Just REFL
-- The slice descriptors are compared last, after the two argument
-- expressions have matched.
match (IndexSlice sliceIndex1 ix1 sh1) (IndexSlice sliceIndex2 ix2 sh2)
| Just REFL <- match ix1 ix2
, Just REFL <- match sh1 sh2
, Just REFL <- matchSliceRestrict sliceIndex1 sliceIndex2
= gcast REFL
match (IndexFull sliceIndex1 ix1 sl1) (IndexFull sliceIndex2 ix2 sl2)
| Just REFL <- match ix1 ix2
, Just REFL <- match sl1 sl2
, Just REFL <- matchSliceExtend sliceIndex1 sliceIndex2
= gcast REFL
match (ToIndex sh1 i1) (ToIndex sh2 i2)
| Just REFL <- match sh1 sh2
, Just REFL <- match i1 i2
= Just REFL
match (FromIndex sh1 i1) (FromIndex sh2 i2)
| Just REFL <- match i1 i2
, Just REFL <- match sh1 sh2
= Just REFL
match (Cond p1 t1 e1) (Cond p2 t2 e2)
| Just REFL <- match p1 p2
, Just REFL <- match t1 t2
, Just REFL <- match e1 e2
= Just REFL
-- The initial values (x1, x2) are compared before the loop bodies
-- (f1, f2).
match (Iterate n1 f1 x1) (Iterate n2 f2 x2)
| Just REFL <- match n1 n2
, Just REFL <- match x1 x2
, Just REFL <- match f1 f2
= Just REFL
match (PrimConst c1) (PrimConst c2)
= matchPrimConst c1 c2
-- Primitive application. First attempt: when 'commutes' accepts both
-- primitives, compare the reordered operands it returns. If that attempt
-- fails at any point, fall back to comparing the operands as written.
match (PrimApp f1 x1) (PrimApp f2 x2)
| Just x1' <- commutes hashAcc f1 x1
, Just x2' <- commutes hashAcc f2 x2
, Just REFL <- match x1' x2'
, Just REFL <- matchPrimFun f1 f2
= Just REFL
| Just REFL <- match x1 x2
, Just REFL <- matchPrimFun f1 f2
= Just REFL
match (Index a1 x1) (Index a2 x2)
| Just REFL <- matchAcc a1 a2
, Just REFL <- match x1 x2
= Just REFL
match (LinearIndex a1 x1) (LinearIndex a2 x2)
| Just REFL <- matchAcc a1 a2
, Just REFL <- match x1 x2
= Just REFL
match (Shape a1) (Shape a2)
| Just REFL <- matchAcc a1 a2
= Just REFL
match (ShapeSize sh1) (ShapeSize sh2)
| Just REFL <- match sh1 sh2
= Just REFL
match (Intersect sa1 sb1) (Intersect sa2 sb2)
| Just REFL <- match sa1 sa2
, Just REFL <- match sb1 sb2
= Just REFL
-- Different constructors, or a failed guard above.
match _ _
= Nothing
-- | Compare two scalar functions whose embedded array terms are 'OpenAcc';
-- a specialisation of 'matchPreOpenFun'.
matchOpenFun :: OpenFun env aenv s -> OpenFun env aenv t -> Maybe (s :=: t)
matchOpenFun lhs rhs = matchPreOpenFun matchOpenAcc hashOpenAcc lhs rhs
-- | Compare two scalar functions. Bodies are compared as expressions with
-- 'matchPreOpenExp'. For lambdas the types of the bound variables are
-- compared first (through 'gcast'), so that the two bodies live in the same
-- environment before the comparison recurses.
matchPreOpenFun
    :: MatchAcc acc
    -> HashAcc acc
    -> PreOpenFun acc env aenv s
    -> PreOpenFun acc env aenv t
    -> Maybe (s :=: t)
matchPreOpenFun eqAcc hashAcc (Body lhs) (Body rhs) = matchPreOpenExp eqAcc hashAcc lhs rhs
matchPreOpenFun eqAcc hashAcc (Lam lhs) (Lam rhs)
  | Just REFL <- sameBinder lhs rhs
  , Just REFL <- matchPreOpenFun eqAcc hashAcc lhs rhs
  = Just REFL
  where
    -- Only the types of the arguments are used: the equality of the
    -- innermost environment entries is obtained from their 'Typeable'
    -- evidence.
    sameBinder
        :: (Elt a, Elt b)
        => PreOpenFun acc (env, a) aenv f
        -> PreOpenFun acc (env, b) aenv g
        -> Maybe (a :=: b)
    sameBinder _ _ = gcast REFL
matchPreOpenFun _ _ _ _ = Nothing
-- | Value equality of two constants in tuple representation, directed by the
-- representation type: pairs are compared component-wise (left component
-- first, stopping at the first difference) and scalars with 'evalEq'.
matchConst :: TupleType a -> a -> a -> Bool
matchConst (PairTuple tyL tyR) (xL, xR) (yL, yR)
  | matchConst tyL xL yL = matchConst tyR xR yR
  | otherwise            = False
matchConst (SingleTuple ty) x y = evalEq ty (x, y)
matchConst UnitTuple () () = True
-- | Equality on a pair of scalar values. The scalar type witness is used to
-- look up the class dictionary that brings the required 'Eq' instance into
-- scope.
evalEq :: ScalarType a -> (a, a) -> Bool
evalEq ty =
  case ty of
    NumScalarType (IntegralNumType ity) | IntegralDict <- integralDict ity -> uncurry (==)
    NumScalarType (FloatingNumType fty) | FloatingDict <- floatingDict fty -> uncurry (==)
    NonNumScalarType nty                | NonNumDict   <- nonNumDict nty   -> uncurry (==)
-- | Compare two environment indices. They match exactly when they consist of
-- the same number of 'SuccIdx' steps, in which case they also select the same
-- type.
matchIdx :: Idx env s -> Idx env t -> Maybe (s :=: t)
matchIdx (SuccIdx lhs) (SuccIdx rhs) = matchIdx lhs rhs
matchIdx ZeroIdx       ZeroIdx       = Just REFL
matchIdx _             _             = Nothing
-- | Compare two projection indices into the same tuple type. They match
-- exactly when they consist of the same number of 'SuccTupIdx' steps.
matchTupleIdx :: TupleIdx tup s -> TupleIdx tup t -> Maybe (s :=: t)
matchTupleIdx lhs rhs =
  case (lhs, rhs) of
    (ZeroTupIdx,   ZeroTupIdx)   -> Just REFL
    (SuccTupIdx l, SuccTupIdx r) -> matchTupleIdx l r
    _                            -> Nothing
-- | Compare two tuples of scalar expressions component by component. Both
-- tuples must have the same number of components for the comparison to
-- succeed.
matchTuple
    :: MatchAcc acc
    -> HashAcc acc
    -> Tuple.Tuple (PreOpenExp acc env aenv) s
    -> Tuple.Tuple (PreOpenExp acc env aenv) t
    -> Maybe (s :=: t)
matchTuple eqAcc hashAcc (SnocTup restL lastL) (SnocTup restR lastR)
  | Just REFL <- matchTuple eqAcc hashAcc restL restR   -- leading components first
  , Just REFL <- matchPreOpenExp eqAcc hashAcc lastL lastR
  = Just REFL
matchTuple _ _ NilTup NilTup = Just REFL
matchTuple _ _ _      _      = Nothing
-- | Compare two slice descriptors that agree on their first and last type
-- parameters, and establish that their second type parameters agree as well.
-- The descriptors match when they are built from the same sequence of
-- constructors.
matchSliceRestrict
    :: SliceIndex slix s co  sh
    -> SliceIndex slix t co' sh
    -> Maybe (s :=: t)
matchSliceRestrict SliceNil SliceNil = Just REFL
matchSliceRestrict (SliceAll lhs) (SliceAll rhs) =
  case matchSliceRestrict lhs rhs of
    Just REFL -> Just REFL    -- re-wrapped at the enlarged type
    Nothing   -> Nothing
matchSliceRestrict (SliceFixed lhs) (SliceFixed rhs) =
  case matchSliceRestrict lhs rhs of
    Just REFL -> Just REFL
    Nothing   -> Nothing
matchSliceRestrict _ _ = Nothing
-- | Compare two slice descriptors that agree on their first two type
-- parameters, and establish that their last type parameters agree as well.
-- The descriptors match when they are built from the same sequence of
-- constructors.
matchSliceExtend
    :: SliceIndex slix sl co  s
    -> SliceIndex slix sl co' t
    -> Maybe (s :=: t)
matchSliceExtend SliceNil SliceNil = Just REFL
matchSliceExtend (SliceAll lhs) (SliceAll rhs) =
  case matchSliceExtend lhs rhs of
    Just REFL -> Just REFL    -- re-wrapped at the enlarged type
    Nothing   -> Nothing
matchSliceExtend (SliceFixed lhs) (SliceFixed rhs) =
  case matchSliceExtend lhs rhs of
    Just REFL -> Just REFL
    Nothing   -> Nothing
matchSliceExtend _ _ = Nothing
-- | Compare two primitive constants: they match when they are the same
-- constant and their type witnesses match.
matchPrimConst :: (Elt s, Elt t) => PrimConst s -> PrimConst t -> Maybe (s :=: t)
matchPrimConst lhs rhs =
  case (lhs, rhs) of
    (PrimMinBound a, PrimMinBound b) -> matchBoundedType a b
    (PrimMaxBound a, PrimMaxBound b) -> matchBoundedType a b
    (PrimPi a,       PrimPi b)       -> matchFloatingType a b
    _                                -> Nothing
-- | Compare two primitive functions that are already known to take the same
-- argument type, and establish that their result types agree.
--
-- For most primitives the result type is determined by the (shared) argument
-- type, so using the same constructor is sufficient. The conversion
-- primitives carry a separate witness for the result type, which has to be
-- compared explicitly.
matchPrimFun :: (Elt s, Elt t) => PrimFun (a -> s) -> PrimFun (a -> t) -> Maybe (s :=: t)
matchPrimFun lhs rhs =
  case (lhs, rhs) of
    -- arithmetic
    (PrimAdd _,           PrimAdd _)           -> Just REFL
    (PrimSub _,           PrimSub _)           -> Just REFL
    (PrimMul _,           PrimMul _)           -> Just REFL
    (PrimNeg _,           PrimNeg _)           -> Just REFL
    (PrimAbs _,           PrimAbs _)           -> Just REFL
    (PrimSig _,           PrimSig _)           -> Just REFL
    (PrimQuot _,          PrimQuot _)          -> Just REFL
    (PrimRem _,           PrimRem _)           -> Just REFL
    (PrimIDiv _,          PrimIDiv _)          -> Just REFL
    (PrimMod _,           PrimMod _)           -> Just REFL
    -- bitwise
    (PrimBAnd _,          PrimBAnd _)          -> Just REFL
    (PrimBOr _,           PrimBOr _)           -> Just REFL
    (PrimBXor _,          PrimBXor _)          -> Just REFL
    (PrimBNot _,          PrimBNot _)          -> Just REFL
    (PrimBShiftL _,       PrimBShiftL _)       -> Just REFL
    (PrimBShiftR _,       PrimBShiftR _)       -> Just REFL
    (PrimBRotateL _,      PrimBRotateL _)      -> Just REFL
    (PrimBRotateR _,      PrimBRotateR _)      -> Just REFL
    -- floating point
    (PrimFDiv _,          PrimFDiv _)          -> Just REFL
    (PrimRecip _,         PrimRecip _)         -> Just REFL
    (PrimSin _,           PrimSin _)           -> Just REFL
    (PrimCos _,           PrimCos _)           -> Just REFL
    (PrimTan _,           PrimTan _)           -> Just REFL
    (PrimAsin _,          PrimAsin _)          -> Just REFL
    (PrimAcos _,          PrimAcos _)          -> Just REFL
    (PrimAtan _,          PrimAtan _)          -> Just REFL
    (PrimAsinh _,         PrimAsinh _)         -> Just REFL
    (PrimAcosh _,         PrimAcosh _)         -> Just REFL
    (PrimAtanh _,         PrimAtanh _)         -> Just REFL
    (PrimExpFloating _,   PrimExpFloating _)   -> Just REFL
    (PrimSqrt _,          PrimSqrt _)          -> Just REFL
    (PrimLog _,           PrimLog _)           -> Just REFL
    (PrimFPow _,          PrimFPow _)          -> Just REFL
    (PrimLogBase _,       PrimLogBase _)       -> Just REFL
    (PrimAtan2 _,         PrimAtan2 _)         -> Just REFL
    -- conversions: compare the witnesses of the result types
    (PrimTruncate _ x,    PrimTruncate _ y)    -> matchIntegralType x y
    (PrimRound _ x,       PrimRound _ y)       -> matchIntegralType x y
    (PrimFloor _ x,       PrimFloor _ y)       -> matchIntegralType x y
    (PrimCeiling _ x,     PrimCeiling _ y)     -> matchIntegralType x y
    -- relational
    (PrimLt _,            PrimLt _)            -> Just REFL
    (PrimGt _,            PrimGt _)            -> Just REFL
    (PrimLtEq _,          PrimLtEq _)          -> Just REFL
    (PrimGtEq _,          PrimGtEq _)          -> Just REFL
    (PrimEq _,            PrimEq _)            -> Just REFL
    (PrimNEq _,           PrimNEq _)           -> Just REFL
    (PrimMax _,           PrimMax _)           -> Just REFL
    (PrimMin _,           PrimMin _)           -> Just REFL
    (PrimFromIntegral _ x, PrimFromIntegral _ y) -> matchNumType x y
    -- logical and character operations
    (PrimLAnd,            PrimLAnd)            -> Just REFL
    (PrimLOr,             PrimLOr)             -> Just REFL
    (PrimLNot,            PrimLNot)            -> Just REFL
    (PrimOrd,             PrimOrd)             -> Just REFL
    (PrimChr,             PrimChr)             -> Just REFL
    (PrimBoolToInt,       PrimBoolToInt)       -> Just REFL
    _                                          -> Nothing
-- | Compare two primitive functions that are already known to produce the
-- same result type, and establish that their argument types agree.
--
-- For most primitives the argument type is determined by the (shared) result
-- type, so using the same constructor is sufficient. The conversion
-- primitives and the comparison operators do not have this property, so the
-- witness of their argument type has to be compared explicitly.
matchPrimFun' :: (Elt s, Elt t) => PrimFun (s -> a) -> PrimFun (t -> a) -> Maybe (s :=: t)
matchPrimFun' lhs rhs =
  case (lhs, rhs) of
    -- arithmetic
    (PrimAdd _,           PrimAdd _)           -> Just REFL
    (PrimSub _,           PrimSub _)           -> Just REFL
    (PrimMul _,           PrimMul _)           -> Just REFL
    (PrimNeg _,           PrimNeg _)           -> Just REFL
    (PrimAbs _,           PrimAbs _)           -> Just REFL
    (PrimSig _,           PrimSig _)           -> Just REFL
    (PrimQuot _,          PrimQuot _)          -> Just REFL
    (PrimRem _,           PrimRem _)           -> Just REFL
    (PrimIDiv _,          PrimIDiv _)          -> Just REFL
    (PrimMod _,           PrimMod _)           -> Just REFL
    -- bitwise
    (PrimBAnd _,          PrimBAnd _)          -> Just REFL
    (PrimBOr _,           PrimBOr _)           -> Just REFL
    (PrimBXor _,          PrimBXor _)          -> Just REFL
    (PrimBNot _,          PrimBNot _)          -> Just REFL
    (PrimBShiftL _,       PrimBShiftL _)       -> Just REFL
    (PrimBShiftR _,       PrimBShiftR _)       -> Just REFL
    (PrimBRotateL _,      PrimBRotateL _)      -> Just REFL
    (PrimBRotateR _,      PrimBRotateR _)      -> Just REFL
    -- floating point
    (PrimFDiv _,          PrimFDiv _)          -> Just REFL
    (PrimRecip _,         PrimRecip _)         -> Just REFL
    (PrimSin _,           PrimSin _)           -> Just REFL
    (PrimCos _,           PrimCos _)           -> Just REFL
    (PrimTan _,           PrimTan _)           -> Just REFL
    (PrimAsin _,          PrimAsin _)          -> Just REFL
    (PrimAcos _,          PrimAcos _)          -> Just REFL
    (PrimAtan _,          PrimAtan _)          -> Just REFL
    (PrimAsinh _,         PrimAsinh _)         -> Just REFL
    (PrimAcosh _,         PrimAcosh _)         -> Just REFL
    (PrimAtanh _,         PrimAtanh _)         -> Just REFL
    (PrimExpFloating _,   PrimExpFloating _)   -> Just REFL
    (PrimSqrt _,          PrimSqrt _)          -> Just REFL
    (PrimLog _,           PrimLog _)           -> Just REFL
    (PrimFPow _,          PrimFPow _)          -> Just REFL
    (PrimLogBase _,       PrimLogBase _)       -> Just REFL
    (PrimAtan2 _,         PrimAtan2 _)         -> Just REFL
    -- conversions: compare the witnesses of the argument types
    (PrimTruncate x _,    PrimTruncate y _)    -> matchFloatingType x y
    (PrimRound x _,       PrimRound y _)       -> matchFloatingType x y
    (PrimFloor x _,       PrimFloor y _)       -> matchFloatingType x y
    (PrimCeiling x _,     PrimCeiling y _)     -> matchFloatingType x y
    (PrimMax _,           PrimMax _)           -> Just REFL
    (PrimMin _,           PrimMin _)           -> Just REFL
    (PrimFromIntegral x _, PrimFromIntegral y _) -> matchIntegralType x y
    -- logical and character operations
    (PrimLAnd,            PrimLAnd)            -> Just REFL
    (PrimLOr,             PrimLOr)             -> Just REFL
    (PrimLNot,            PrimLNot)            -> Just REFL
    (PrimOrd,             PrimOrd)             -> Just REFL
    (PrimChr,             PrimChr)             -> Just REFL
    (PrimBoolToInt,       PrimBoolToInt)       -> Just REFL
    -- comparisons: the scalar types of the operands must be matched; a
    -- failed guard falls through to the final alternative
    (PrimLt x,   PrimLt y)   | Just REFL <- matchScalarType x y -> Just REFL
    (PrimGt x,   PrimGt y)   | Just REFL <- matchScalarType x y -> Just REFL
    (PrimLtEq x, PrimLtEq y) | Just REFL <- matchScalarType x y -> Just REFL
    (PrimGtEq x, PrimGtEq y) | Just REFL <- matchScalarType x y -> Just REFL
    (PrimEq x,   PrimEq y)   | Just REFL <- matchScalarType x y -> Just REFL
    (PrimNEq x,  PrimNEq y)  | Just REFL <- matchScalarType x y -> Just REFL
    _                                          -> Nothing
-- | Compare two tuple-representation type witnesses, establishing that they
-- describe the same type when they have the same structure and the same
-- scalar types at the leaves.
matchTupleType :: TupleType s -> TupleType t -> Maybe (s :=: t)
matchTupleType (PairTuple lhsA lhsB) (PairTuple rhsA rhsB)
  | Just REFL <- matchTupleType lhsA rhsA
  , Just REFL <- matchTupleType lhsB rhsB
  = Just REFL
matchTupleType (SingleTuple lhs) (SingleTuple rhs) = matchScalarType lhs rhs
matchTupleType UnitTuple         UnitTuple         = Just REFL
matchTupleType _                 _                 = Nothing
-- | Compare two scalar type witnesses by dispatching to the matcher for
-- numeric or non-numeric types; witnesses from different classes never match.
matchScalarType :: ScalarType s -> ScalarType t -> Maybe (s :=: t)
matchScalarType lhs rhs =
  case (lhs, rhs) of
    (NumScalarType a,    NumScalarType b)    -> matchNumType a b
    (NonNumScalarType a, NonNumScalarType b) -> matchNonNumType a b
    _                                        -> Nothing
-- | Compare two numeric type witnesses by dispatching to the matcher for
-- integral or floating types; witnesses from different classes never match.
matchNumType :: NumType s -> NumType t -> Maybe (s :=: t)
matchNumType lhs rhs =
  case (lhs, rhs) of
    (IntegralNumType a, IntegralNumType b) -> matchIntegralType a b
    (FloatingNumType a, FloatingNumType b) -> matchFloatingType a b
    _                                      -> Nothing
-- | Compare two bounded type witnesses by dispatching to the matcher for
-- integral or non-numeric types; witnesses from different classes never
-- match.
matchBoundedType :: BoundedType s -> BoundedType t -> Maybe (s :=: t)
matchBoundedType lhs rhs =
  case (lhs, rhs) of
    (IntegralBoundedType a, IntegralBoundedType b) -> matchIntegralType a b
    (NonNumBoundedType a,   NonNumBoundedType b)   -> matchNonNumType a b
    _                                              -> Nothing
-- | Compare two integral type witnesses. They match exactly when both are
-- built with the same constructor; the dictionaries they carry are ignored.
matchIntegralType :: IntegralType s -> IntegralType t -> Maybe (s :=: t)
matchIntegralType lhs rhs =
  case (lhs, rhs) of
    (TypeInt _,     TypeInt _)     -> Just REFL
    (TypeInt8 _,    TypeInt8 _)    -> Just REFL
    (TypeInt16 _,   TypeInt16 _)   -> Just REFL
    (TypeInt32 _,   TypeInt32 _)   -> Just REFL
    (TypeInt64 _,   TypeInt64 _)   -> Just REFL
    (TypeWord _,    TypeWord _)    -> Just REFL
    (TypeWord8 _,   TypeWord8 _)   -> Just REFL
    (TypeWord16 _,  TypeWord16 _)  -> Just REFL
    (TypeWord32 _,  TypeWord32 _)  -> Just REFL
    (TypeWord64 _,  TypeWord64 _)  -> Just REFL
    (TypeCShort _,  TypeCShort _)  -> Just REFL
    (TypeCUShort _, TypeCUShort _) -> Just REFL
    (TypeCInt _,    TypeCInt _)    -> Just REFL
    (TypeCUInt _,   TypeCUInt _)   -> Just REFL
    (TypeCLong _,   TypeCLong _)   -> Just REFL
    (TypeCULong _,  TypeCULong _)  -> Just REFL
    (TypeCLLong _,  TypeCLLong _)  -> Just REFL
    (TypeCULLong _, TypeCULLong _) -> Just REFL
    _                              -> Nothing
-- | Compare two floating-point type witnesses. They match exactly when both
-- are built with the same constructor; the dictionaries they carry are
-- ignored.
matchFloatingType :: FloatingType s -> FloatingType t -> Maybe (s :=: t)
matchFloatingType lhs rhs =
  case (lhs, rhs) of
    (TypeFloat _,   TypeFloat _)   -> Just REFL
    (TypeDouble _,  TypeDouble _)  -> Just REFL
    (TypeCFloat _,  TypeCFloat _)  -> Just REFL
    (TypeCDouble _, TypeCDouble _) -> Just REFL
    _                              -> Nothing
-- | Compare two non-numeric type witnesses. They match exactly when both are
-- built with the same constructor; the dictionaries they carry are ignored.
matchNonNumType :: NonNumType s -> NonNumType t -> Maybe (s :=: t)
matchNonNumType lhs rhs =
  case (lhs, rhs) of
    (TypeBool _,   TypeBool _)   -> Just REFL
    (TypeChar _,   TypeChar _)   -> Just REFL
    (TypeCChar _,  TypeCChar _)  -> Just REFL
    (TypeCSChar _, TypeCSChar _) -> Just REFL
    (TypeCUChar _, TypeCUChar _) -> Just REFL
    _                            -> Nothing
-- | Canonicalise the operands of the primitives listed below.
--
-- For each listed primitive the result is @Just@ the argument expression,
-- with the two operands swapped when the argument is a literal pair whose
-- first component hashes strictly greater than its second. For every other
-- primitive the result is 'Nothing'.
--
-- Both the matcher and the hasher use this so that, for example, @a + b@ and
-- @b + a@ are treated alike.
commutes
    :: forall acc env aenv a r.
       HashAcc acc
    -> PrimFun (a -> r)
    -> PreOpenExp acc env aenv a
    -> Maybe (PreOpenExp acc env aenv a)
commutes hashAcc prim args =
  case prim of
    -- Each alternative refines the argument type to a pair, which is what
    -- allows 'canonical' to be applied; they cannot be merged.
    PrimAdd _  -> Just (canonical args)
    PrimMul _  -> Just (canonical args)
    PrimBAnd _ -> Just (canonical args)
    PrimBOr _  -> Just (canonical args)
    PrimBXor _ -> Just (canonical args)
    PrimEq _   -> Just (canonical args)
    PrimNEq _  -> Just (canonical args)
    PrimMax _  -> Just (canonical args)
    PrimMin _  -> Just (canonical args)
    PrimLAnd   -> Just (canonical args)
    PrimLOr    -> Just (canonical args)
    _          -> Nothing
  where
    -- Order the two components of a literal pair by ascending hash; any
    -- other expression is returned unchanged.
    canonical :: PreOpenExp acc env aenv (b, b) -> PreOpenExp acc env aenv (b, b)
    canonical operands =
      case operands of
        Tuple (SnocTup (SnocTup NilTup x) y)
          | hashPreOpenExp hashAcc x > hashPreOpenExp hashAcc y
          -> Tuple (SnocTup (SnocTup NilTup y) x)
        _ -> operands
-- | Hash an environment index via its integer position ('idxToInt').
hashIdx :: Idx env t -> Int
hashIdx ix = hash (idxToInt ix)
-- | Hash a tuple projection index via its integer position
-- ('tupleIdxToInt').
hashTupleIdx :: TupleIdx tup e -> Int
hashTupleIdx ix = hash (tupleIdxToInt ix)
-- | The type of a function that hashes an array term, whatever its array
-- environment and result type.
type HashAcc acc = forall aenv arrs. acc aenv arrs -> Int
-- | Hash an 'OpenAcc' term by unwrapping it and delegating to
-- 'hashPreOpenAcc', with 'hashOpenAcc' itself as the hasher for nested array
-- terms.
hashOpenAcc :: OpenAcc aenv arrs -> Int
hashOpenAcc (OpenAcc term) = hashPreOpenAcc hashOpenAcc term
-- | Hash an array computation. The hash is seeded with the constructor name
-- and then salted with the hashes of the components; nested array terms are
-- hashed with the supplied function.
--
-- The three local helpers are used as infix operators and therefore have the
-- default fixity, whereas 'hashWithSalt' keeps the fixity declared by its
-- defining module; the grouping of the chains below depends on this, so they
-- must not be regrouped.
hashPreOpenAcc :: forall acc aenv arrs. HashAcc acc -> PreOpenAcc acc aenv arrs -> Int
hashPreOpenAcc hashAcc pacc =
  case pacc of
    Alet bnd body          -> hash "Alet"        `saltAcc` bnd `saltAcc` body
    Avar v                 -> hash "Avar"        `hashWithSalt` hashIdx v
    Atuple t               -> hash "Atuple"      `hashWithSalt` hashAtuple hashAcc t
    Aprj ix a              -> hash "Aprj"        `hashWithSalt` hashTupleIdx ix `saltAcc` a
    Apply f a              -> hash "Apply"       `hashWithSalt` hashAfun hashAcc f `saltAcc` a
    -- the foreign function itself does not contribute to the hash
    Aforeign _ f a         -> hash "Aforeign"    `hashWithSalt` hashAfun hashAcc f `saltAcc` a
    Use a                  -> hash "Use"         `hashWithSalt` hashArrays (arrays (undefined::arrs)) a
    Unit e                 -> hash "Unit"        `saltExp` e
    Generate e f           -> hash "Generate"    `saltExp` e `saltFun` f
    Acond e a1 a2          -> hash "Acond"       `saltExp` e `saltAcc` a1 `saltAcc` a2
    Reshape sh a           -> hash "Reshape"     `saltExp` sh `saltAcc` a
    Transform sh f1 f2 a   -> hash "Transform"   `saltExp` sh `saltFun` f1 `saltFun` f2 `saltAcc` a
    Replicate spec ix a    -> hash "Replicate"   `saltExp` ix `saltAcc` a `hashWithSalt` show spec
    Slice spec a ix        -> hash "Slice"       `saltExp` ix `saltAcc` a `hashWithSalt` show spec
    Map f a                -> hash "Map"         `saltFun` f `saltAcc` a
    ZipWith f a1 a2        -> hash "ZipWith"     `saltFun` f `saltAcc` a1 `saltAcc` a2
    Fold f e a             -> hash "Fold"        `saltFun` f `saltExp` e `saltAcc` a
    Fold1 f a              -> hash "Fold1"       `saltFun` f `saltAcc` a
    FoldSeg f e a s        -> hash "FoldSeg"     `saltFun` f `saltExp` e `saltAcc` a `saltAcc` s
    Fold1Seg f a s         -> hash "Fold1Seg"    `saltFun` f `saltAcc` a `saltAcc` s
    Scanl f e a            -> hash "Scanl"       `saltFun` f `saltExp` e `saltAcc` a
    Scanl' f e a           -> hash "Scanl'"      `saltFun` f `saltExp` e `saltAcc` a
    Scanl1 f a             -> hash "Scanl1"      `saltFun` f `saltAcc` a
    Scanr f e a            -> hash "Scanr"       `saltFun` f `saltExp` e `saltAcc` a
    Scanr' f e a           -> hash "Scanr'"      `saltFun` f `saltExp` e `saltAcc` a
    Scanr1 f a             -> hash "Scanr1"      `saltFun` f `saltAcc` a
    Backpermute sh f a     -> hash "Backpermute" `saltFun` f `saltExp` sh `saltAcc` a
    Permute f1 a1 f2 a2    -> hash "Permute"     `saltFun` f1 `saltAcc` a1 `saltFun` f2 `saltAcc` a2
    Stencil f b a          -> hash "Stencil"     `saltFun` f `saltAcc` a `hashWithSalt` hashBoundary a b
    Stencil2 f b1 a1 b2 a2 -> hash "Stencil2"    `saltFun` f `saltAcc` a1 `saltAcc` a2 `hashWithSalt` hashBoundary a1 b1 `hashWithSalt` hashBoundary a2 b2
  where
    -- Mix the hash of an array term into a running salt.
    saltAcc :: Int -> acc aenv' a -> Int
    saltAcc salt a = hashWithSalt salt (hashAcc a)

    -- Mix the hash of a scalar expression into a running salt.
    saltExp :: Int -> PreOpenExp acc env' aenv' e -> Int
    saltExp salt e = hashWithSalt salt (hashPreOpenExp hashAcc e)

    -- Mix the hash of a scalar function into a running salt.
    saltFun :: Int -> PreOpenFun acc env' aenv' f -> Int
    saltFun salt f = hashWithSalt salt (hashPreOpenFun hashAcc f)
-- | Hash a collection of concrete arrays, walking the structural
-- representation of the collection.
--
-- An individual array is hashed by the stable name of its /payload/, not of
-- the enclosing 'Array' value. This is the same object that 'matchArrays'
-- uses to decide whether two arrays are identical, so two arrays that match
-- are guaranteed to produce the same hash, even if they are distinct 'Array'
-- wrappers around one payload. No array elements are inspected.
--
-- 'unsafePerformIO' is needed only because creating a stable name is an 'IO'
-- action.
hashArrays :: ArraysR a -> a -> Int
hashArrays ArraysRunit         ()           = hash ()
hashArrays (ArraysRpair r1 r2) (a1, a2)     = hash ( hashArrays r1 a1, hashArrays r2 a2 )
hashArrays ArraysRarray        (Array _ ad) = unsafePerformIO $! hashStableName `fmap` makeStableName ad
-- | Hash a tuple of array terms: the constructor name is salted with the
-- hash of the leading components and then with the hash of the last
-- component.
hashAtuple :: HashAcc acc -> Tuple.Atuple (acc aenv) a -> Int
hashAtuple hashAcc tup =
  case tup of
    NilAtup            -> hash "NilAtup"
    SnocAtup rest last' ->
      hashWithSalt (hashWithSalt (hash "SnocAtup") (hashAtuple hashAcc rest)) (hashAcc last')
-- | Hash an array-level function: one salt per lambda, and the hash of the
-- body at the bottom.
hashAfun :: HashAcc acc -> PreOpenAfun acc aenv f -> Int
hashAfun hashAcc afun =
  case afun of
    Alam inner -> hashWithSalt (hash "Alam")  (hashAfun hashAcc inner)
    Abody body -> hashWithSalt (hash "Abody") (hashAcc body)
-- | Hash a stencil boundary condition. The array argument is never
-- inspected; it only determines the element type, which is needed to convert
-- a constant boundary value with 'toElt' so that it can be shown and hashed.
hashBoundary :: forall acc aenv sh e. Elt e => acc aenv (Array sh e) -> Boundary (EltRepr e) -> Int
hashBoundary _ boundary =
  case boundary of
    Wrap       -> hash "Wrap"
    Clamp      -> hash "Clamp"
    Mirror     -> hash "Mirror"
    Constant c -> hash "Constant" `hashWithSalt` show (toElt c :: e)
-- | Hash a scalar expression whose embedded array terms are 'OpenAcc'; a
-- specialisation of 'hashPreOpenExp'.
hashOpenExp :: OpenExp env aenv exp -> Int
hashOpenExp expr = hashPreOpenExp hashOpenAcc expr
-- | Hash a scalar expression. The hash is seeded with the constructor name
-- and then salted with the hashes of the components; embedded array terms are
-- hashed with the supplied function.
--
-- The two local helpers are used as infix operators and therefore have the
-- default fixity, whereas 'hashWithSalt' keeps the fixity declared by its
-- defining module; the grouping of the chains below depends on this, so they
-- must not be regrouped.
hashPreOpenExp :: forall acc env aenv exp. HashAcc acc -> PreOpenExp acc env aenv exp -> Int
hashPreOpenExp hashAcc expr =
  case expr of
    Let bnd body          -> hash "Let"         `saltExp` bnd `saltExp` body
    Var ix                -> hash "Var"         `hashWithSalt` hashIdx ix
    Const c               -> hash "Const"       `hashWithSalt` show (toElt c :: exp)
    Tuple t               -> hash "Tuple"       `hashWithSalt` hashTuple hashAcc t
    Prj i e               -> hash "Prj"         `hashWithSalt` hashTupleIdx i `saltExp` e
    IndexAny              -> hash "IndexAny"
    IndexNil              -> hash "IndexNil"
    IndexCons sl a        -> hash "IndexCons"   `saltExp` sl `saltExp` a
    IndexHead sl          -> hash "IndexHead"   `saltExp` sl
    IndexTail sl          -> hash "IndexTail"   `saltExp` sl
    IndexSlice spec ix sh -> hash "IndexSlice"  `saltExp` ix `saltExp` sh `hashWithSalt` show spec
    IndexFull spec ix sl  -> hash "IndexFull"   `saltExp` ix `saltExp` sl `hashWithSalt` show spec
    ToIndex sh i          -> hash "ToIndex"     `saltExp` sh `saltExp` i
    FromIndex sh i        -> hash "FromIndex"   `saltExp` sh `saltExp` i
    Cond c t e            -> hash "Cond"        `saltExp` c `saltExp` t `saltExp` e
    Iterate n f x         -> hash "Iterate"     `saltExp` n `saltExp` f `saltExp` x
    -- operands of the primitives accepted by 'commutes' are put into
    -- canonical order first, so that swapped operands hash identically
    PrimApp f x           -> hash "PrimApp"     `hashWithSalt` hashPrimFun f `saltExp` fromMaybe x (commutes hashAcc f x)
    PrimConst c           -> hash "PrimConst"   `hashWithSalt` hashPrimConst c
    Index a ix            -> hash "Index"       `saltAcc` a `saltExp` ix
    LinearIndex a ix      -> hash "LinearIndex" `saltAcc` a `saltExp` ix
    Shape a               -> hash "Shape"       `saltAcc` a
    ShapeSize sh          -> hash "ShapeSize"   `saltExp` sh
    Intersect sa sb       -> hash "Intersect"   `saltExp` sa `saltExp` sb
    -- the foreign function itself does not contribute to the hash
    Foreign _ f e         -> hash "Foreign"     `hashWithSalt` hashPreOpenFun hashAcc f `saltExp` e
  where
    -- Mix the hash of an array term into a running salt.
    saltAcc :: Int -> acc aenv' a -> Int
    saltAcc salt a = hashWithSalt salt (hashAcc a)

    -- Mix the hash of a scalar expression into a running salt.
    saltExp :: Int -> PreOpenExp acc env' aenv' e -> Int
    saltExp salt e = hashWithSalt salt (hashPreOpenExp hashAcc e)
-- | Hash a scalar function: one salt per lambda, and the hash of the body
-- expression at the bottom.
hashPreOpenFun :: HashAcc acc -> PreOpenFun acc env aenv f -> Int
hashPreOpenFun hashAcc fun =
  case fun of
    Lam inner -> hashWithSalt (hash "Lam")  (hashPreOpenFun hashAcc inner)
    Body body -> hashWithSalt (hash "Body") (hashPreOpenExp hashAcc body)
-- | Hash a tuple of scalar expressions: the constructor name is salted with
-- the hash of the leading components and then with the hash of the last
-- component.
hashTuple :: HashAcc acc -> Tuple.Tuple (PreOpenExp acc env aenv) e -> Int
hashTuple hashAcc tup =
  case tup of
    NilTup            -> hash "NilTup"
    SnocTup rest last' ->
      hashWithSalt (hashWithSalt (hash "SnocTup") (hashTuple hashAcc rest)) (hashPreOpenExp hashAcc last')
-- | Hash a primitive constant by the name of its constructor; the type
-- witness it carries does not contribute.
hashPrimConst :: PrimConst c -> Int
hashPrimConst prim = hash (tag prim)
  where
    tag :: PrimConst t -> String
    tag (PrimMinBound _) = "PrimMinBound"
    tag (PrimMaxBound _) = "PrimMaxBound"
    tag (PrimPi _)       = "PrimPi"
-- | Hash a primitive function by the name of its constructor; the type
-- witnesses it carries do not contribute.
hashPrimFun :: PrimFun f -> Int
hashPrimFun prim = hash (tag prim)
  where
    tag :: PrimFun g -> String
    -- arithmetic
    tag (PrimAdd _)            = "PrimAdd"
    tag (PrimSub _)            = "PrimSub"
    tag (PrimMul _)            = "PrimMul"
    tag (PrimNeg _)            = "PrimNeg"
    tag (PrimAbs _)            = "PrimAbs"
    tag (PrimSig _)            = "PrimSig"
    tag (PrimQuot _)           = "PrimQuot"
    tag (PrimRem _)            = "PrimRem"
    tag (PrimIDiv _)           = "PrimIDiv"
    tag (PrimMod _)            = "PrimMod"
    -- bitwise
    tag (PrimBAnd _)           = "PrimBAnd"
    tag (PrimBOr _)            = "PrimBOr"
    tag (PrimBXor _)           = "PrimBXor"
    tag (PrimBNot _)           = "PrimBNot"
    tag (PrimBShiftL _)        = "PrimBShiftL"
    tag (PrimBShiftR _)        = "PrimBShiftR"
    tag (PrimBRotateL _)       = "PrimBRotateL"
    tag (PrimBRotateR _)       = "PrimBRotateR"
    -- floating point
    tag (PrimFDiv _)           = "PrimFDiv"
    tag (PrimRecip _)          = "PrimRecip"
    tag (PrimSin _)            = "PrimSin"
    tag (PrimCos _)            = "PrimCos"
    tag (PrimTan _)            = "PrimTan"
    tag (PrimAsin _)           = "PrimAsin"
    tag (PrimAcos _)           = "PrimAcos"
    tag (PrimAtan _)           = "PrimAtan"
    tag (PrimAsinh _)          = "PrimAsinh"
    tag (PrimAcosh _)          = "PrimAcosh"
    tag (PrimAtanh _)          = "PrimAtanh"
    tag (PrimExpFloating _)    = "PrimExpFloating"
    tag (PrimSqrt _)           = "PrimSqrt"
    tag (PrimLog _)            = "PrimLog"
    tag (PrimFPow _)           = "PrimFPow"
    tag (PrimLogBase _)        = "PrimLogBase"
    tag (PrimAtan2 _)          = "PrimAtan2"
    -- conversions
    tag (PrimTruncate _ _)     = "PrimTruncate"
    tag (PrimRound _ _)        = "PrimRound"
    tag (PrimFloor _ _)        = "PrimFloor"
    tag (PrimCeiling _ _)      = "PrimCeiling"
    -- relational
    tag (PrimLt _)             = "PrimLt"
    tag (PrimGt _)             = "PrimGt"
    tag (PrimLtEq _)           = "PrimLtEq"
    tag (PrimGtEq _)           = "PrimGtEq"
    tag (PrimEq _)             = "PrimEq"
    tag (PrimNEq _)            = "PrimNEq"
    tag (PrimMax _)            = "PrimMax"
    tag (PrimMin _)            = "PrimMin"
    tag (PrimFromIntegral _ _) = "PrimFromIntegral"
    -- logical and character operations
    tag PrimLAnd               = "PrimLAnd"
    tag PrimLOr                = "PrimLOr"
    tag PrimLNot               = "PrimLNot"
    tag PrimOrd                = "PrimOrd"
    tag PrimChr                = "PrimChr"
    tag PrimBoolToInt          = "PrimBoolToInt"