{-# LANGUAGE CPP                   #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RebindableSyntax      #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE ViewPatterns          #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE UndecidableInstances  #-}
#endif
-- |
-- Module      : Data.Array.Accelerate.Data.Semigroup
-- Copyright   : [2018] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Semigroup instances for Accelerate
--
-- @since 1.2.0.0
--

module Data.Array.Accelerate.Data.Semigroup (

  Semigroup(..),

  Min(..),
  Max(..),

) where

import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Classes.Bounded
import Data.Array.Accelerate.Classes.Eq
import Data.Array.Accelerate.Classes.Num
import Data.Array.Accelerate.Classes.Ord
import Data.Array.Accelerate.Lift
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Type

import Data.Function
import Data.Monoid                                        ( Monoid(..) )
import Data.Semigroup
import Prelude                                            ( undefined )
import qualified Prelude                                  as P


-- Min
-- ---

-- | A @Min a@ value is represented as a one-component tuple wrapping the
-- representation of @a@.
--
type instance EltRepr (Min a) = ((), EltRepr a)

-- | Conversion between @Min a@ and its one-component tuple representation.
--
instance Elt a => Elt (Min a) where
  eltType _        = TypeRpair TypeRunit (eltType (undefined::a))
  toElt ((),x)     = Min (toElt x)
  fromElt (Min x)  = ((), fromElt x)

-- | Product view of @Min a@: a single field holding the wrapped value.
--
instance Elt a => IsProduct Elt (Min a) where
  type ProdRepr (Min a) = ((), a)
  toProd _ ((),a)    = Min a
  fromProd _ (Min a) = ((),a)
  prod _ _           = ProdRsnoc ProdRunit

-- | Lift a @Min@ of a liftable value into an embedded one-component tuple.
--
instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Min a) where
  type Plain (Min a) = Min (Plain a)
  lift (Min a)       = Exp $ Tuple $ NilTup `SnocTup` lift a

-- | Project the single component out of an embedded @Min@ and rewrap it in
-- the host-level constructor.
--
instance Elt a => Unlift Exp (Min (Exp a)) where
  unlift t = Min . Exp $ ZeroTupIdx `Prj` t

-- | The bounds are those of the wrapped type.
--
instance Bounded a => P.Bounded (Exp (Min a)) where
  minBound = lift $ Min (minBound :: Exp a)
  maxBound = lift $ Min (maxBound :: Exp a)

-- | Arithmetic is delegated to the @Num@ instance of @Min (Exp a)@ by
-- unlifting the arguments and lifting the result.
--
instance Num a => P.Num (Exp (Min a)) where
  (+)           = lift2 ((+) :: Min (Exp a) -> Min (Exp a) -> Min (Exp a))
  (-)           = lift2 ((-) :: Min (Exp a) -> Min (Exp a) -> Min (Exp a))
  (*)           = lift2 ((*) :: Min (Exp a) -> Min (Exp a) -> Min (Exp a))
  negate        = lift1 (negate :: Min (Exp a) -> Min (Exp a))
  signum        = lift1 (signum :: Min (Exp a) -> Min (Exp a))
  -- Must lift abs itself; lifting signum here would return the sign of the
  -- argument rather than its magnitude.
  abs           = lift1 (abs :: Min (Exp a) -> Min (Exp a))
  fromInteger x = lift (P.fromInteger x :: Min (Exp a))

-- | Equality compares the wrapped values.
--
instance Eq a => Eq (Min a) where
  (==) = lift2 ((==) `on` getMin)
  (/=) = lift2 ((/=) `on` getMin)

-- | Ordering compares the wrapped values; 'min' and 'max' rewrap the chosen
-- value in @Min@.
--
instance Ord a => Ord (Min a) where
  (<)     = lift2 ((<) `on` getMin)
  (>)     = lift2 ((>) `on` getMin)
  (<=)    = lift2 ((<=) `on` getMin)
  (>=)    = lift2 ((>=) `on` getMin)
  min x y = lift . Min $ lift2 (min `on` getMin) x y
  max x y = lift . Min $ lift2 (max `on` getMin) x y

-- | Combining two values keeps the smaller of the wrapped values.
--
instance Ord a => Semigroup (Exp (Min a)) where
  x <> y = lift . Min $ lift2 (min `on` getMin) x y
  stimes = stimesIdempotent

-- | The identity element is 'maxBound'.
--
instance (Ord a, Bounded a) => Monoid (Exp (Min a)) where
  mempty  = maxBound
  mappend = (<>)


-- Max
-- ---

-- | A @Max a@ value is represented as a one-component tuple wrapping the
-- representation of @a@.
--
type instance EltRepr (Max a) = ((), EltRepr a)

-- | Conversion between @Max a@ and its one-component tuple representation.
--
instance Elt a => Elt (Max a) where
  eltType _        = TypeRpair TypeRunit (eltType (undefined::a))
  toElt ((),x)     = Max (toElt x)
  fromElt (Max x)  = ((), fromElt x)

-- | Product view of @Max a@: a single field holding the wrapped value.
--
instance Elt a => IsProduct Elt (Max a) where
  type ProdRepr (Max a) = ((), a)
  toProd _ ((),a)    = Max a
  fromProd _ (Max a) = ((),a)
  prod _ _           = ProdRsnoc ProdRunit

-- | Lift a @Max@ of a liftable value into an embedded one-component tuple.
--
instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Max a) where
  type Plain (Max a) = Max (Plain a)
  lift (Max a)       = Exp $ Tuple $ NilTup `SnocTup` lift a

-- | Project the single component out of an embedded @Max@ and rewrap it in
-- the host-level constructor.
--
instance Elt a => Unlift Exp (Max (Exp a)) where
  unlift t = Max . Exp $ ZeroTupIdx `Prj` t

-- | The bounds are those of the wrapped type.
--
instance Bounded a => P.Bounded (Exp (Max a)) where
  minBound = lift $ Max (minBound :: Exp a)
  maxBound = lift $ Max (maxBound :: Exp a)

-- | Arithmetic is delegated to the @Num@ instance of @Max (Exp a)@ by
-- unlifting the arguments and lifting the result.
--
instance Num a => P.Num (Exp (Max a)) where
  (+)           = lift2 ((+) :: Max (Exp a) -> Max (Exp a) -> Max (Exp a))
  (-)           = lift2 ((-) :: Max (Exp a) -> Max (Exp a) -> Max (Exp a))
  (*)           = lift2 ((*) :: Max (Exp a) -> Max (Exp a) -> Max (Exp a))
  negate        = lift1 (negate :: Max (Exp a) -> Max (Exp a))
  signum        = lift1 (signum :: Max (Exp a) -> Max (Exp a))
  -- Must lift abs itself; lifting signum here would return the sign of the
  -- argument rather than its magnitude.
  abs           = lift1 (abs :: Max (Exp a) -> Max (Exp a))
  fromInteger x = lift (P.fromInteger x :: Max (Exp a))

-- | Equality compares the wrapped values.
--
instance Eq a => Eq (Max a) where
  (==) = lift2 ((==) `on` getMax)
  (/=) = lift2 ((/=) `on` getMax)

-- | Ordering compares the wrapped values; 'min' and 'max' rewrap the chosen
-- value in @Max@.
--
instance Ord a => Ord (Max a) where
  (<)     = lift2 ((<) `on` getMax)
  (>)     = lift2 ((>) `on` getMax)
  (<=)    = lift2 ((<=) `on` getMax)
  (>=)    = lift2 ((>=) `on` getMax)
  min x y = lift . Max $ lift2 (min `on` getMax) x y
  max x y = lift . Max $ lift2 (max `on` getMax) x y

-- | Combining two values keeps the larger of the wrapped values.
--
instance Ord a => Semigroup (Exp (Max a)) where
  x <> y = lift . Max $ lift2 (max `on` getMax) x y
  stimes = stimesIdempotent

-- | The identity element is 'minBound'.
--
instance (Ord a, Bounded a) => Monoid (Exp (Max a)) where
  mempty  = minBound
  mappend = (<>)


-- Instances for unit and tuples
-- -----------------------------

-- | Every operation on embedded unit yields the constant unit value.
--
instance Semigroup (Exp ()) where
  _ <> _     = constant ()
  sconcat _  = constant ()
  stimes _ _ = constant ()

-- | Pairs combine component-wise; 'stimes' is applied to each component.
--
instance (Elt a, Elt b, Semigroup (Exp a), Semigroup (Exp b)) => Semigroup (Exp (a,b)) where
  (<>) = lift2 ((<>) :: (Exp a, Exp b) -> (Exp a, Exp b) -> (Exp a, Exp b))
  stimes n (unlift -> (a,b) :: (Exp a, Exp b)) = lift (stimes n a, stimes n b)

-- | Triples combine component-wise; 'stimes' is applied to each component.
--
instance (Elt a, Elt b, Elt c, Semigroup (Exp a), Semigroup (Exp b), Semigroup (Exp c)) => Semigroup (Exp (a,b,c)) where
  (<>) = lift2 ((<>) :: (Exp a, Exp b, Exp c) -> (Exp a, Exp b, Exp c) -> (Exp a, Exp b, Exp c))
  stimes n (unlift -> (a,b,c) :: (Exp a, Exp b, Exp c)) = lift (stimes n a, stimes n b, stimes n c)

-- | 4-tuples combine component-wise; 'stimes' is applied to each component.
--
instance (Elt a, Elt b, Elt c, Elt d, Semigroup (Exp a), Semigroup (Exp b), Semigroup (Exp c), Semigroup (Exp d)) => Semigroup (Exp (a,b,c,d)) where
  (<>) = lift2 ((<>) :: (Exp a, Exp b, Exp c, Exp d) -> (Exp a, Exp b, Exp c, Exp d) -> (Exp a, Exp b, Exp c, Exp d))
  stimes n (unlift -> (a,b,c,d) :: (Exp a, Exp b, Exp c, Exp d)) = lift (stimes n a, stimes n b, stimes n c, stimes n d)

-- | 5-tuples combine component-wise; 'stimes' is applied to each component.
--
instance (Elt a, Elt b, Elt c, Elt d, Elt e, Semigroup (Exp a), Semigroup (Exp b), Semigroup (Exp c), Semigroup (Exp d), Semigroup (Exp e)) => Semigroup (Exp (a,b,c,d,e)) where
  (<>) = lift2 ((<>) :: (Exp a, Exp b, Exp c, Exp d, Exp e) -> (Exp a, Exp b, Exp c, Exp d, Exp e) -> (Exp a, Exp b, Exp c, Exp d, Exp e))
  stimes n (unlift -> (a,b,c,d,e) :: (Exp a, Exp b, Exp c, Exp d, Exp e)) = lift (stimes n a, stimes n b, stimes n c, stimes n d, stimes n e)