{-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE CPP #-} #if MIN_VERSION_base(4,14,0) {-# LANGUAGE StandaloneKindSignatures #-} #endif ----------------------------------------------------------------------------- {-| Module : Math.Tensor.Safe.TH Description : Type families and singletons for generalized types. Copyright : (c) Nils Alex, 2020 License : MIT Maintainer : nils.alex@fau.de Type families and singletons for generalized types. For documentation see re-exports in "Math.Tensor.Safe". -} ----------------------------------------------------------------------------- module Math.Tensor.Safe.TH where import Data.Singletons.Prelude import Data.Singletons.Prelude.Enum import Data.Singletons.Prelude.List.NonEmpty hiding (sLength) import Data.Singletons.Prelude.Ord import Data.Singletons.TH import Data.Singletons.TypeLits import Data.List.NonEmpty (NonEmpty((:|)),sort,sortBy,(<|)) import GHC.Generics (Generic,Generic1) import Control.DeepSeq (NFData,NFData1) $(singletons [d| data N where Z :: N S :: N -> N fromNat :: Nat -> N fromNat n = case n == 0 of True -> Z False -> S $ fromNat (pred n) deriving instance Generic N deriving instance NFData N deriving instance Show N deriving instance Eq N instance Ord N where Z <= _ = True (S _) <= Z = False (S n) <= (S m) = n <= m instance Num N where Z + n = n (S n) + m = S $ n + m n - Z = n Z - S _ = error "cannot subtract (S n) from Z!" (S n) - (S m) = n - m negate Z = Z negate _ = error "cannot negate (S n)!" Z * _ = Z (S n) * m = m + n * m abs n = n signum n = n fromInteger n | n == 0 = Z | otherwise = S $ fromInteger (n-1) data VSpace a b = VSpace { vId :: a, vDim :: b } deriving (Show, Ord, Eq, Generic, NFData, Generic1, NFData1) data Ix a = ICon a | ICov a deriving (Show, Ord, Eq, Generic, NFData, Generic1, NFData1) ixCompare :: Ord a => Ix a -> Ix a -> Ordering ixCompare (ICon a) (ICon b) = compare a b ixCompare (ICon a) (ICov b) = case compare a b of LT -> LT EQ -> LT GT -> GT ixCompare (ICov a) (ICon b) = case compare a b of LT -> LT EQ -> GT GT -> GT ixCompare (ICov a) (ICov b) = compare a b data IList a = ConCov (NonEmpty a) (NonEmpty a) | Cov (NonEmpty a) | Con (NonEmpty a) deriving (Show, Ord, Eq, Generic, NFData, Generic1, NFData1) type GRank s n = [(VSpace s n, IList s)] type Rank = GRank Symbol Nat isAscending :: Ord a => [a] -> Bool isAscending [] = True isAscending [_] = True isAscending (x:y:xs) = x < y && isAscending (y:xs) isAscendingNE :: Ord a => NonEmpty a -> Bool isAscendingNE (x :| xs) = isAscending (x:xs) isAscendingI :: Ord a => IList a -> Bool isAscendingI (ConCov x y) = isAscendingNE x && isAscendingNE y isAscendingI (Con x) = isAscendingNE x isAscendingI (Cov y) = isAscendingNE y isLengthNE :: NonEmpty a -> Nat -> Bool isLengthNE (_ :| []) l = l == 1 isLengthNE (_ :| (x:xs)) l = isLengthNE (x :| xs) (pred l) lengthNE :: NonEmpty a -> N lengthNE (_ :| []) = S Z lengthNE (_ :| (x:xs)) = S Z + lengthNE (x :| xs) lengthIL :: IList a -> N lengthIL (ConCov xs ys) = lengthNE xs + lengthNE ys lengthIL (Con xs) = lengthNE xs lengthIL (Cov ys) = lengthNE ys lengthR :: GRank s n -> N lengthR [] = Z lengthR ((_,x):xs) = lengthIL x + lengthR xs sane :: (Ord a, Ord b) => [(VSpace a b, IList a)] -> Bool sane [] = True sane [(_, is)] = isAscendingI is sane ((v, is):(v', is'):xs) = v < v' && isAscendingI is && sane ((v',is'):xs) headR :: Ord s => GRank s n -> (VSpace s n, Ix s) headR ((v, l):_) = (v, case l of ConCov (a :| _) (b :| _) -> case compare a b of LT -> ICon a EQ -> ICon a GT -> ICov b Con (a :| _) -> ICon a Cov (a :| _) -> ICov a) headR [] = error "headR of empty list" tailR :: Ord s => GRank s n -> GRank s n tailR ((v, l):ls) = let l' = case l of ConCov (a :| []) (b :| []) -> case compare a b of LT -> Just $ Cov (b :| []) EQ -> Just $ Cov (b :| []) GT -> Just $ Con (a :| []) ConCov (a :| (a':as)) (b :| []) -> case compare a b of LT -> Just $ ConCov (a' :| as) (b :| []) EQ -> Just $ ConCov (a' :| as) (b :| []) GT -> Just $ Con (a :| (a':as)) ConCov (a :| []) (b :| (b':bs)) -> case compare a b of LT -> Just $ Cov (b :| (b':bs)) EQ -> Just $ Cov (b :| (b':bs)) GT -> Just $ ConCov (a :| []) (b' :| bs) ConCov (a :| (a':as)) (b :| (b':bs)) -> case compare a b of LT -> Just $ ConCov (a' :| as) (b :| (b':bs)) EQ -> Just $ ConCov (a' :| as) (b :| (b':bs)) GT -> Just $ ConCov (a :| (a':as)) (b' :| bs) Con (_ :| []) -> Nothing Con (_ :| (a':as)) -> Just $ Con (a' :| as) Cov (_ :| []) -> Nothing Cov (_ :| (a':as)) -> Just $ Cov (a' :| as) in case l' of Just l'' -> (v, l''):ls Nothing -> ls tailR [] = error "tailR of empty list" mergeR :: (Ord s, Ord n) => GRank s n -> GRank s n -> Maybe (GRank s n) mergeR [] ys = Just ys mergeR xs [] = Just xs mergeR ((xv,xl):xs) ((yv,yl):ys) = case compare xv yv of LT -> ((xv,xl) :) <$> mergeR xs ((yv,yl):ys) EQ -> do xl' <- mergeIL xl yl xs' <- mergeR xs ys Just $ (xv, xl') : xs' GT -> ((yv,yl) :) <$> mergeR ((xv,xl):xs) ys mergeIL :: Ord a => IList a -> IList a -> Maybe (IList a) mergeIL (ConCov xs ys) (ConCov xs' ys') = do xs'' <- mergeNE xs xs' ys'' <- mergeNE ys ys' Just $ ConCov xs'' ys'' mergeIL (ConCov xs ys) (Con xs') = do xs'' <- mergeNE xs xs' Just $ ConCov xs'' ys mergeIL (ConCov xs ys) (Cov ys') = do ys'' <- mergeNE ys ys' Just $ ConCov xs ys'' mergeIL (Con xs) (ConCov xs' ys) = do xs'' <- mergeNE xs xs' Just $ ConCov xs'' ys mergeIL (Con xs) (Con xs') = Con <$> mergeNE xs xs' mergeIL (Con xs) (Cov ys) = Just $ ConCov xs ys mergeIL (Cov ys) (ConCov xs ys') = do ys'' <- mergeNE ys ys' Just $ ConCov xs ys'' mergeIL (Cov ys) (Con xs) = Just $ ConCov xs ys mergeIL (Cov ys) (Cov ys') = Cov <$> mergeNE ys ys' merge :: Ord a => [a] -> [a] -> Maybe [a] merge [] ys = Just ys merge xs [] = Just xs merge (x:xs) (y:ys) = case compare x y of LT -> (x :) <$> merge xs (y:ys) EQ -> Nothing GT -> (y :) <$> merge (x:xs) ys mergeNE :: Ord a => NonEmpty a -> NonEmpty a -> Maybe (NonEmpty a) mergeNE (x :| xs) (y :| ys) = case compare x y of LT -> (x :|) <$> merge xs (y:ys) EQ -> Nothing GT -> (y :|) <$> merge (x:xs) ys contractR :: Ord s => GRank s n -> GRank s n contractR [] = [] contractR ((v, is):xs) = case contractI is of Nothing -> contractR xs Just is' -> (v, is') : contractR xs prepICon :: a -> IList a -> IList a prepICon a (ConCov (x:|xs) y) = ConCov (a:|(x:xs)) y prepICon a (Con (x:|xs)) = Con (a:|(x:xs)) prepICon a (Cov y) = ConCov (a:|[]) y prepICov :: a -> IList a -> IList a prepICov a (ConCov x (y:|ys)) = ConCov x (a:|(y:ys)) prepICov a (Con x) = ConCov x (a:|[]) prepICov a (Cov (y:|ys)) = Cov (a:|(y:ys)) contractI :: Ord a => IList a -> Maybe (IList a) contractI (ConCov (x:|xs) (y:|ys)) = case compare x y of EQ -> case xs of [] -> case ys of [] -> Nothing (y':ys') -> Just $ Cov (y' :| ys') (x':xs') -> case ys of [] -> Just $ Con (x' :| xs') (y':ys') -> contractI $ ConCov (x':|xs') (y':|ys') LT -> case xs of [] -> Just $ ConCov (x:|xs) (y:|ys) (x':xs') -> case contractI $ ConCov (x':|xs') (y:|ys) of Nothing -> Just $ Con (x:|[]) Just i -> Just $ prepICon x i GT -> case ys of [] -> Just $ ConCov (x:|xs) (y:|ys) (y':ys') -> case contractI $ ConCov (x:|xs) (y':|ys') of Nothing -> Just $ Cov (y:|[]) Just i -> Just $ prepICov y i contractI (Con x) = Just $ Con x contractI (Cov x) = Just $ Cov x subsetNE :: Ord a => NonEmpty a -> NonEmpty a -> Bool subsetNE (x :| []) ys = x `elemNE` ys subsetNE (x :| (x':xs)) ys = (x `elemNE` ys) && ((x' :| xs) `subsetNE` ys) elemNE :: Ord a => a -> NonEmpty a -> Bool elemNE a (x :| []) = a == x elemNE a (x :| (x':xs)) = case compare a x of LT -> False EQ -> True GT -> elemNE a (x' :| xs) canTransposeCon :: (Ord s, Ord n) => VSpace s n -> s -> s -> GRank s n -> Bool canTransposeCon _ _ _ [] = False canTransposeCon v a b ((v',il):r) = case compare v v' of LT -> False GT -> canTransposeCon v a b r EQ -> case il of Cov _ -> canTransposeCon v a b r Con cs -> case elemNE a cs of True -> case elemNE b cs of True -> True False -> False False -> case elemNE b cs of True -> False False -> canTransposeCon v a b r ConCov cs _ -> case elemNE a cs of True -> case elemNE b cs of True -> True False -> False False -> case elemNE b cs of True -> False False -> canTransposeCon v a b r canTransposeCov :: (Ord s, Ord n) => VSpace s n -> s -> s -> GRank s n -> Bool canTransposeCov _ _ _ [] = False canTransposeCov v a b ((v',il):r) = case compare v v' of LT -> False GT -> canTransposeCov v a b r EQ -> case il of Con _ -> canTransposeCov v a b r Cov cs -> case elemNE a cs of True -> case elemNE b cs of True -> True False -> False False -> case elemNE b cs of True -> False False -> canTransposeCov v a b r ConCov _ cs -> case elemNE a cs of True -> case elemNE b cs of True -> True False -> False False -> case elemNE b cs of True -> False False -> canTransposeCov v a b r canTranspose :: (Ord s, Ord n) => VSpace s n -> Ix s -> Ix s -> GRank s n -> Bool canTranspose v (ICon a) (ICon b) r = case compare a b of LT -> canTransposeCon v a b r EQ -> True GT -> canTransposeCov v b a r canTranspose v (ICov a) (ICov b) r = case compare a b of LT -> canTransposeCov v a b r EQ -> True GT -> canTransposeCov v b a r canTranspose _ (ICov _) (ICon _) _ = False canTranspose _ (ICon _) (ICov _) _ = False removeUntil :: Ord s => Ix s -> GRank s n -> GRank s n removeUntil i r = go i r where go i' r' | snd (headR r') == i' = tailR r' | otherwise = go i $ tailR r' data TransRule a = TransCon (NonEmpty a) (NonEmpty a) | TransCov (NonEmpty a) (NonEmpty a) deriving (Show, Eq, Generic, NFData, Generic1, NFData1) saneTransRule :: Ord a => TransRule a -> Bool saneTransRule tl = case tl of TransCon sources targets -> isAscendingNE sources && sort targets == sources TransCov sources targets -> isAscendingNE sources && sort targets == sources canTransposeMult :: (Ord s, Ord n) => VSpace s n -> TransRule s -> GRank s n -> Bool canTransposeMult vs tl r = case transpositions vs tl r of Just _ -> True Nothing -> False transpositions :: (Ord s, Ord n) => VSpace s n -> TransRule s -> GRank s n -> Maybe [(N,N)] transpositions _ _ [] = Nothing transpositions vs tl ((vs',il):r) = case saneTransRule tl of False -> Nothing True -> case compare vs vs' of LT -> Nothing GT -> transpositions vs tl r EQ -> case il of Con xs -> case tl of TransCon sources targets -> transpositions' sources targets (fmap Just xs) TransCov _ _ -> Nothing Cov xs -> case tl of TransCov sources targets -> transpositions' sources targets (fmap Just xs) TransCon _ _ -> Nothing ConCov xsCon xsCov -> case tl of TransCon sources targets -> transpositions' sources targets (xsCon `zipCon` xsCov) TransCov sources targets -> transpositions' sources targets (xsCon `zipCov` xsCov) zipCon :: Ord a => NonEmpty a -> NonEmpty a -> NonEmpty (Maybe a) zipCon (x :| xs) (y :| ys) = case ICon x `ixCompare` ICov y of LT -> case xs of [] -> Just x :| fmap (const Nothing) (y:ys) (x':xs') -> Just x <| zipCon (x' :| xs') (y :| ys) GT -> case ys of [] -> Nothing :| fmap Just (x : xs) (y':ys') -> Nothing <| zipCon (x :| xs) (y' :| ys') zipCov :: Ord a => NonEmpty a -> NonEmpty a -> NonEmpty (Maybe a) zipCov (x :| xs) (y :| ys) = case ICon x `ixCompare` ICov y of LT -> case xs of [] -> Nothing :| fmap Just (y:ys) (x':xs') -> Nothing <| zipCov (x' :| xs') (y :| ys) GT -> case ys of [] -> Just y :| fmap (const Nothing) (x : xs) (y':ys') -> Just y <| zipCov (x :| xs) (y' :| ys') transpositions' :: Eq a => NonEmpty a -> NonEmpty a -> NonEmpty (Maybe a) -> Maybe [(N,N)] transpositions' sources targets xs = do ss <- mapM (`find` xs') sources ts <- mapM (`find` xs') targets zip' ss ts where xs' = go' Z xs go' :: N -> NonEmpty a -> NonEmpty (N,a) go' n (y :| []) = (n,y) :| [] go' n (y :| (y':ys')) = (n,y) <| go' (S n) (y' :| ys') find :: Eq a => a -> NonEmpty (N, Maybe a) -> Maybe N find _ ((_,Nothing) :| []) = Nothing find a ((_,Nothing) :| (y':ys')) = find a (y' :| ys') find a ((n,Just y) :| ys) = case a == y of True -> Just n False -> case ys of [] -> Nothing y':ys' -> find a (y' :| ys') zip' :: NonEmpty a -> NonEmpty b -> Maybe [(a,b)] zip' (a :| []) (b :| []) = Just [(a,b)] zip' (_ :| (_:_)) (_ :| []) = Nothing zip' (_ :| []) (_ :| (_:_)) = Nothing zip' (y:|(y':ys')) (z:|(z':zs')) = ((y,z):) <$> zip' (y':|ys') (z':|zs') type RelabelRule s = NonEmpty (s,s) saneRelabelRule :: Ord a => NonEmpty (a,a) -> Bool saneRelabelRule xs = isAscendingNE xs && isAscendingNE xs' where xs' = sort $ fmap (\(a,b) -> (b,a)) xs relabelNE :: Ord a => NonEmpty (a,a) -> NonEmpty a -> Maybe (NonEmpty (a,a)) relabelNE = go where go :: Ord a => NonEmpty (a,a) -> NonEmpty a -> Maybe (NonEmpty (a,a)) go ((source,target) :| ms) (x :| xs) = case source `compare` x of LT -> case ms of [] -> Just $ (\a -> (a,a)) <$> (x :| xs) m':ms' -> go (m' :| ms') (x :| xs) EQ -> case ms of [] -> Just $ ((target,source) :|) $ fmap (\a -> (a,a)) xs m':ms' -> case xs of [] -> Just $ (target,source) :| [] x':xs' -> ((target,source) <|) <$> go (m' :| ms') (x' :| xs') GT -> case xs of [] -> Just $ (x,x) :| [] x':xs' -> ((x,x) <|) <$> go ((source,target) :| ms) (x' :| xs') relabelR :: (Ord s, Ord n) => VSpace s n -> RelabelRule s -> GRank s n -> Maybe (GRank s n) relabelR _ _ [] = Nothing relabelR vs rls ((vs',il):r) = case vs `compare` vs' of LT -> Nothing EQ -> (\il' -> (vs',il'):r) <$> relabelIL rls il GT -> ((vs',il) :) <$> relabelR vs rls r relabelIL :: Ord a => NonEmpty (a,a) -> IList a -> Maybe (IList a) relabelIL rl is = case relabelIL' rl is of Nothing -> Nothing Just (Con is') -> Just $ Con $ fst <$> is' Just (Cov is') -> Just $ Cov $ fst <$> is' Just (ConCov is' js') -> Just $ ConCov (fst <$> is') (fst <$> js') relabelIL' :: Ord a => NonEmpty (a,a) -> IList a -> Maybe (IList (a,a)) relabelIL' rl (Con is) = do is' <- Con . sort <$> relabelNE rl is case isAscendingI is' of True -> return is' False -> Nothing relabelIL' rl (Cov is) = do is' <- Cov . sort <$> relabelNE rl is case isAscendingI is' of True -> return is' False -> Nothing relabelIL' rl (ConCov is js) = do is' <- sort <$> relabelNE rl is js' <- sort <$> relabelNE rl js let l' = ConCov is' js' case isAscendingI l' of True -> return l' False -> Nothing relabelTranspositions :: Ord a => NonEmpty (a,a) -> IList a -> Maybe [(N,N)] relabelTranspositions rl is = case relabelIL' rl is of Nothing -> Nothing Just (Con is') -> Just $ relabelTranspositions' is' Just (Cov is') -> Just $ relabelTranspositions' is' Just (ConCov is' js') -> Just $ relabelTranspositions' $ is' `zipConCov` js' zipConCov :: Ord a => NonEmpty a -> NonEmpty a -> NonEmpty a zipConCov = go where go :: Ord a => NonEmpty a -> NonEmpty a -> NonEmpty a go (i :| is) (j :| js) = case i `compare` j of LT -> case is of [] -> i <| (j:|js) i':is' -> i <| go (i' :| is') (j :| js) EQ -> case is of [] -> i <| (j:|js) i':is' -> i <| go (i' :| is') (j :| js) GT -> case js of [] -> j <| (i:|is) j':js' -> j <| go (i :| is) (j' :| js') relabelTranspositions' :: Ord a => NonEmpty (a,a) -> [(N,N)] relabelTranspositions' is = go'' is''' where is' = go Z is is'' = sortBy (\a b -> snd a `compare` snd b) is' is''' = go' Z is'' --is'''' = sort is''' go :: N -> NonEmpty (a,b) -> NonEmpty (N,b) go n ((_,y) :| []) = (n,y) :| [] go n ((_,y) :| (j:js)) = (n,y) <| go (S n) (j :| js) go' :: N -> NonEmpty (a,b) -> NonEmpty (a,N) go' n ((x,_) :| []) = (x,n) :| [] go' n ((x,_) :| (j:js)) = (x,n) <| go' (S n) (j :| js) go'' :: NonEmpty (a,a) -> [(a,a)] go'' ((x1,x2) :| []) = [(x2,x1)] go'' ((x1,x2) :| (y:ys)) = (x2,x1) : go'' (y :| ys) |]) toInt :: N -> Int toInt Z = 0 toInt (S n) = 1 + toInt n