module MathFlow.Core where
import GHC.TypeLits
import Data.Singletons
import Data.Singletons.TH
import Data.Promotion.Prelude
-- | Type-level check that sub-sampling a tensor of shape @m@ by the per-axis
-- factors @f@ yields shape @n@.
--
-- Equations are tried from top to bottom:
--
-- * a factor of @1@ moves on to the next axis without comparing the two
--   sizes of the current axis;
-- * any other factor requires @result * factor == source@ on that axis;
-- * three empty lists succeed;
-- * every other combination (for example lists of different lengths) is
--   'False'.
type family IsSubSamp (f :: [Nat]) (m :: [Nat]) (n :: [Nat]) :: Bool where
  IsSubSamp (1 : factors) (source : sources) (result : results) =
    IsSubSamp factors sources results
  IsSubSamp (factor : factors) (source : sources) (result : results) =
    ((result * factor) :== source) :&& IsSubSamp factors sources results
  IsSubSamp '[] '[] '[] = 'True
  IsSubSamp _ _ _ = 'False
-- | Type-level shape check for a matrix product: @m@ is the shape of the left
-- operand, @o@ the shape of the right operand and @n@ the shape of the result.
--
-- All four conditions must hold:
--
-- 1. the last axis of the result equals the last axis of the right operand;
-- 2. the last axis of the left operand equals the second-to-last axis of the
--    right operand;
-- 3. result and left operand agree on every axis except the last one;
-- 4. result and right operand agree on every axis except the last two.
type family IsMatMul (m :: [Nat]) (o :: [Nat]) (n :: [Nat]) :: Bool where
  IsMatMul lhs rhs out =
    (Last out :== Last rhs)
      :&& (Last lhs :== Head (Tail (Reverse rhs)))
      :&& (Tail (Reverse out) :== Tail (Reverse lhs))
      :&& (Tail (Tail (Reverse out)) :== Tail (Tail (Reverse rhs)))
-- | Type-level shape check for concatenation: @m@ and @o@ are the operand
-- shapes and @n@ is the result shape.
--
-- The three shapes are walked in lock step. On each axis either all three
-- sizes are equal, or the result size is the sum of the two operand sizes.
-- Shapes of different lengths are rejected.
type family IsConcat (m :: [Nat]) (o :: [Nat]) (n :: [Nat]) :: Bool where
  IsConcat (x : xs) (y : ys) (z : zs) =
    (((x :== y) :&& (x :== z)) :|| ((x + y) :== z)) :&& IsConcat xs ys zs
  IsConcat '[] '[] '[] = 'True
  IsConcat _ _ _ = 'False
-- | Type-level check used to validate a reshape from shape @m@ to shape @n@.
--
-- When both shapes are non-empty, their leading sizes must be equal and the
-- products of the remaining sizes must be equal. The second equation is only
-- reached when the first one cannot apply (at least one shape is empty); it
-- compares the products of the two whole shapes.
--
-- NOTE(review): the first equation is stricter than "same product" (the
-- leading axis may not change) -- confirm this is intended.
type family IsSameProduct (m :: [Nat]) (n :: [Nat]) :: Bool where
  IsSameProduct (x : xs) (y : ys) = (x :== y) :&& (Product xs :== Product ys)
  IsSameProduct xs ys = Product xs :== Product ys
-- | Symbolic tensor expression whose shape @n@ is tracked at the type level.
--
-- Type parameters:
--
-- * @n@ - the shape, a type-level list of naturals;
-- * @t@ - the scalar type wrapped by 'TScalar';
-- * @a@ - the payload type wrapped by the 'Tensor' constructor.
--
-- The constructors only describe an expression tree. Constructors that change
-- the shape carry the corresponding shape check as a type equality constraint.
data Tensor (n::[Nat]) t a =
    -- A scalar of the numeric type t.
    (Num t) => TScalar t 
    -- A leaf wrapping a payload value of type a.
  | Tensor a 
    -- Arithmetic nodes: operands and result all share the shape n.
  | TAdd (Tensor n t a) (Tensor n t a) 
  | TSub (Tensor n t a) (Tensor n t a) 
  | TMul (Tensor n t a) (Tensor n t a) 
  | TAbs (Tensor n t a) 
  | TSign (Tensor n t a) 
    -- The operand's shape is n without its first axis.
    -- NOTE(review): presumably "replicate along a new leading axis" -- confirm.
  | TRep (Tensor (Tail n) t a) 
    -- The operand's shape is n reversed.
    -- NOTE(review): presumably a transpose -- confirm.
  | TTr (Tensor (Reverse n) t a) 
    -- Matrix product; operand and result shapes must satisfy IsMatMul.
  | forall o m. (SingI o,SingI m,SingI n,IsMatMul m o n ~ 'True) => TMatMul (Tensor m t a) (Tensor o t a) 
    -- Concatenation; operand and result shapes must satisfy IsConcat.
  | forall o m. (SingI o,SingI m,SingI n,IsConcat m o n ~ 'True) => TConcat (Tensor m t a) (Tensor o t a) 
    -- Reshape; source and result shapes must satisfy IsSameProduct.
  | forall m. (SingI m,IsSameProduct m n ~ 'True) => TReshape (Tensor m t a) 
    -- The shape constraints are written inline here: they are the same as the
    -- first three conditions of IsMatMul. Unlike TMatMul, no SingI n is required.
  | forall o m.
    (SingI o,SingI m,
     Last n ~ Last o,
     Last m ~ Head (Tail (Reverse o)),
     (Tail (Reverse n)) ~ (Tail (Reverse m))
    ) =>
    TConv2d (Tensor m t a) (Tensor o t a) 
    -- The singleton carries the per-axis factors f; shapes must satisfy IsSubSamp.
  | forall f m. (SingI f, SingI m,IsSubSamp f m n ~ 'True) => TMaxPool (Sing f) (Tensor m t a) 
    -- Unary nodes whose operand and result share the shape n.
  | TSoftMax (Tensor n t a)
  | TReLu (Tensor n t a)
  | TNorm (Tensor n t a)
    -- Same fields and shape constraint as TMaxPool.
  | forall f m. (SingI f,SingI m,IsSubSamp f m n ~ 'True) => TSubSamp (Sing f) (Tensor m t a) 
    -- Pairs a tensor with a second tensor of unrelated shape and scalar type;
    -- the result keeps the shape and scalar type of the first.
  | forall m t2. TApp (Tensor n t a) (Tensor m t2 a)
    -- A String together with a tensor of the same shape as the result.
    -- NOTE(review): the String is presumably a function name -- confirm.
  | TFunc String (Tensor n t a)
    -- A leaf holding only a String.
  | TSym String
    -- A String paired with one value: a tensor, String, Integer, Float, Double
    -- or shape singleton respectively.
    -- NOTE(review): these look like named arguments (name, value) -- confirm.
  | TArgT String (Tensor n t a)
  | TArgS String String
  | TArgI String Integer
  | TArgF String Float
  | TArgD String Double
  | forall f. (SingI f) => TArgSing String (Sing (f::[Nat]))
    -- A String attached to a tensor without changing its shape.
  | TLabel String (Tensor n t a) 
-- | Combine two tensor expressions into a 'TApp' node. The right operand may
-- have any shape and scalar type; the result keeps the shape and scalar type
-- of the left operand.
(<+>) :: forall n t a m t2. Tensor n t a -> Tensor m t2 a -> Tensor n t a
lhs <+> rhs = TApp lhs rhs
infixr 4 <+>
-- | Arithmetic on tensor expressions builds expression nodes rather than
-- computing values: each operator maps to the constructor of the same meaning.
--
-- 'negate' is not defined here, so the class default (@0 - x@) is used; it
-- therefore produces a 'TSub' node whose left operand is the scalar @0@.
instance (Num t) => Num (Tensor n t a) where
  (+) = TAdd
  (-) = TSub
  (*) = TMul
  abs = TAbs
  signum = TSign
  -- Integer literals become scalar nodes of the underlying scalar type.
  fromInteger = TScalar . fromInteger
-- | Types whose dimensions can be reported as a list of integers.
class Dimension a where
  -- | The dimensions of the given value as a list of 'Integer's.
  dim :: a -> [Integer]
-- | The dimensions of a tensor expression are read off its type-level shape;
-- the expression itself is never inspected.
instance (SingI n) => Dimension (Tensor n t a) where
  dim = dim . shapeOf
    where
      -- Recover the shape singleton from the tensor's type index.
      shapeOf :: (SingI n) => Tensor n t a -> Sing n
      shapeOf _ = sing
-- | A shape singleton reports its dimensions by demoting the type-level list
-- to a term-level list.
instance Dimension (Sing (n::[Nat])) where
  dim = fromSing
-- | Wrap a payload value as a tensor leaf. The singleton argument is used
-- only to fix the shape @n@ of the result; its value is ignored.
toValue :: forall n t a. Sing (n::[Nat]) -> a -> Tensor n t a
toValue _shape = Tensor
-- | Matrix product of two tensor expressions, built as a 'TMatMul' node.
-- The operand shapes @m@ and @o@ and the result shape @n@ must satisfy
-- 'IsMatMul'.
(%*) :: forall o m n t a. (SingI o,SingI m,SingI n,IsMatMul m o n ~ 'True)
     => Tensor m t a -> Tensor o t a -> Tensor n t a
(%*) = TMatMul
-- | Attach a label to a tensor expression by wrapping it in a 'TLabel' node.
-- The shape is unchanged.
(<--) :: SingI n => String -> Tensor n t a  -> Tensor n t a
label <-- tensor = TLabel label tensor
-- | Interface for consuming tensor expressions whose payload type is @a@.
-- Only the method signatures are declared here; their behaviour is defined
-- by the instances.
class FromTensor a where
  -- | Produce a value of the payload type from a tensor expression.
  fromTensor :: Tensor n t a -> a
  -- | Produce a 'String' from a tensor expression.
  toString :: Tensor n t a -> String
  -- | Perform an 'IO' action for a tensor expression.
  -- NOTE(review): the (Int, String, String) result looks like
  -- (exit code, stdout, stderr) of an external process -- confirm against
  -- the instances.
  run :: Tensor n t a -> IO (Int,String,String)