-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
-- NOTE(review): TupleSections does not appear to be used in this module;
-- confirm before removing.

-- | Helpers for assembling ops: turning a node into typed results
-- ('buildOp' / 'BuildResult', 'pureOp' / 'PureResult'), checking that
-- tensors sharing a number attribute agree in length ('eqLengthGuard'), and
-- collecting the 'Output's used as op inputs ('BuildInputs').
module TensorFlow.BuildOp
    ( BuildResult(..)
    , buildOp
    , PureResult(..)
    , pureOp
    , eqLengthGuard
    , BuildInputs(..)
    , OpParams
    )
  where

import Control.Monad (liftM2, replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, evalState, get, put)
import Data.Int (Int64)

import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types

-- | Bookkeeping threaded through the construction of an op's results.
--
-- The first field is the index of the next output to hand out; it is
-- incremented each time a single output is claimed. The second field holds
-- the element counts for the list-valued parts of the result, consumed one
-- per list in left-to-right order. 'buildOp' and 'pureOp' both start from
-- @ResultState 0 sizes@.
data ResultState = ResultState !OutputIx [Int64] deriving Show

-- | Monad in which 'BuildResult' instances run: the reader environment is
-- the 'NodeName' of the node whose outputs are being wrapped, and the state
-- is the 'ResultState' counter.
type Result = ReaderT NodeName (State ResultState)

-- | Class of types that can be used as op outputs.
class BuildResult a where
    -- | Build a value of this type for the node named in the environment.
    -- Instances claim output indices and list counts from the 'ResultState';
    -- compound instances (tuples, lists, tensor lists) do so left to right.
    buildResult :: Result a

instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where
    buildResult = (,) <$> buildResult <*> buildResult

instance (BuildResult a1, BuildResult a2, BuildResult a3) => BuildResult (a1, a2, a3) where
    buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult

instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4)
         => BuildResult (a1, a2, a3, a4) where
    buildResult = (,,,) <$> buildResult <*> buildResult <*> buildResult <*> buildResult

instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5)
         => BuildResult (a1, a2, a3, a4, a5) where
    buildResult = (,,,,) <$> buildResult
                         <*> buildResult
                         <*> buildResult
                         <*> buildResult
                         <*> buildResult

instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         )
         => BuildResult (a1, a2, a3, a4, a5, a6) where
    buildResult = (,,,,,)
               <$> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult

instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         , BuildResult a7
         )
         => BuildResult (a1, a2, a3, a4, a5, a6, a7) where
    buildResult = (,,,,,,)
               <$> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult

instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         , BuildResult a7
         , BuildResult a8
         )
         => BuildResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    buildResult = (,,,,,,,)
               <$> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult
               <*> buildResult

-- | Claim the next output index from the 'ResultState' and advance the
-- counter by one. Shared by 'recordResult' and the 'PureResult' instance for
-- 'Tensor' so both number outputs identically. Polymorphic in the reader
-- environment so it works in both the 'BuildResult' and 'PureResult' monads.
nextResultIx :: ReaderT r (State ResultState) OutputIx
nextResultIx = do
    ResultState i ns <- get
    put $! ResultState (i + 1) ns
    return i

-- | Pop the element count for the next list-valued result from the
-- 'ResultState'. Calls 'error' with the given message if no counts remain.
-- Shared by the list instances of 'BuildResult' and 'PureResult'.
popResultCount :: String -> ReaderT r (State ResultState) Int64
popResultCount msg = do
    ResultState i ns <- get
    case ns of
        [] -> error msg
        n : rest -> do
            put $! ResultState i rest
            return n

-- | Produce the next 'Output' of the node named in the environment,
-- advancing the output index.
recordResult :: Result Output
recordResult = do
    o <- ask
    i <- nextResultIx
    return $! output i o

instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where
    buildResult = Tensor . pure <$> recordResult

-- | A control node wraps the node name itself and claims no output index.
instance BuildResult ControlNode where
    buildResult = ControlNode <$> ask

instance (TensorKind v, Rendered (Tensor v), TensorTypes as)
         => BuildResult (TensorList v as) where
    buildResult = loop (tensorTypes :: TensorTypeList as)
      where
        -- Walk the type-level list of element types, building one tensor
        -- per entry, in order.
        loop :: TensorTypeList bs -> Result (TensorList v bs)
        loop Nil = return Nil
        loop (TensorTypeProxy :/ ls) = do
            t <- buildResult
            ts <- loop ls
            return (t :/ ts)

-- | A list result consumes one count from the 'ResultState' and builds that
-- many elements.
instance BuildResult a => BuildResult [a] where
    buildResult = do
        n <- popResultCount $ "Ran out of counts in buildResult. " ++
                              "Likely misuse of buildOp."
        replicateM (fromIntegral n) buildResult

-- | Add the given 'OpDef' via 'addNewOp' and wrap the resulting node's
-- outputs as a typed result.
--
-- @sizes@ lists, in left-to-right order, the number of elements of each
-- list-valued component of the result type.
buildOp :: BuildResult a => [Int64] -> OpDef -> Build a
buildOp sizes o = do
    n <- addNewOp o
    return $ flip evalState (ResultState 0 sizes) (runReaderT buildResult n)

-- | Returns 'True' if, for every @(numberAttrName, pairs)@ entry, all the
-- 'Int's in @pairs@ are identical (an empty @pairs@ list always passes).
-- Otherwise calls 'error' with a message naming the offending attribute and
-- listing its pairs.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
eqLengthGuard = all eachOk
  where
    eachOk (_, []) = True
    -- Equivalent to ((== 1) . length . nub . map snd) on a non-empty list,
    -- but stops at the first mismatch.
    eachOk (numberAttrName, pairs@((_, x) : zs)) =
        all ((== x) . snd) zs ||
            error ("number_attr " ++ numberAttrName ++
                   " contains tensors with different length " ++ show pairs)

-----------

-- | Class of types that can be used as outputs of a pure op (see 'pureOp').
-- Unlike 'BuildResult', the op's 'OpDef' is only available as a 'Build'
-- action in the reader environment, and no node is added up front.
class PureResult a where
    pureResult :: ReaderT (Build OpDef) (State ResultState) a

instance PureResult (Tensor Build a) where
    pureResult = do
        i <- nextResultIx
        makeOp <- ask
        -- The tensor defers op creation: its Build action re-runs makeOp and
        -- hands the OpDef to getOrAddOp to obtain the node.
        -- TODO: further unify with BuildResult (Tensor v); the index
        -- bookkeeping is already shared via nextResultIx.
        return $ Tensor $ do
            o <- makeOp
            output i <$> getOrAddOp o

instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where
    pureResult = (,) <$> pureResult <*> pureResult

instance (PureResult a1, PureResult a2, PureResult a3) => PureResult (a1, a2, a3) where
    pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult

instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4)
         => PureResult (a1, a2, a3, a4) where
    pureResult = (,,,) <$> pureResult <*> pureResult <*> pureResult <*> pureResult

instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5)
         => PureResult (a1, a2, a3, a4, a5) where
    pureResult = (,,,,) <$> pureResult
                        <*> pureResult
                        <*> pureResult
                        <*> pureResult
                        <*> pureResult

instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         )
         => PureResult (a1, a2, a3, a4, a5, a6) where
    pureResult = (,,,,,)
               <$> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult

instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         , PureResult a7
         )
         => PureResult (a1, a2, a3, a4, a5, a6, a7) where
    pureResult = (,,,,,,)
               <$> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult

instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         , PureResult a7
         , PureResult a8
         )
         => PureResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    pureResult = (,,,,,,,)
               <$> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult
               <*> pureResult

-- | A list result consumes one count from the 'ResultState' and builds that
-- many elements.
instance PureResult a => PureResult [a] where
    pureResult = do
        n <- popResultCount $ "Ran out of counts in pureResult. " ++
                              "Likely misuse of pureOp with output lists."
        replicateM (fromIntegral n) pureResult

instance TensorTypes as => PureResult (TensorList Build as) where
    pureResult = loop (tensorTypes :: TensorTypeList as)
      where
        -- Walk the type-level list of element types, building one tensor
        -- per entry, in order.
        loop :: TensorTypeList bs
             -> ReaderT (Build OpDef) (State ResultState) (TensorList Build bs)
        loop Nil = return Nil
        loop (TensorTypeProxy :/ ls) = do
            t <- pureResult
            ts <- loop ls
            return (t :/ ts)

-- | Build typed results for an op whose 'OpDef' is produced by a 'Build'
-- action. No node is added here; each resulting tensor adds or looks up the
-- op when its own 'Build' action runs. @sizes@ has the same meaning as for
-- 'buildOp'.
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
pureOp sizes o = flip evalState (ResultState 0 sizes) (runReaderT pureResult o)

-----

-- | Class of types that can be used as op arguments: 'buildInputs' yields
-- the 'Output's to use as the op's inputs.
class BuildInputs a where
    buildInputs :: a -> Build [Output]

-- | The inputs of a list are the concatenation of each element's inputs.
instance BuildInputs a => BuildInputs [a] where
    buildInputs = fmap concat . mapM buildInputs

instance BuildInputs (Tensor v a) where
    buildInputs (Tensor t) = do
        o <- toBuild t
        return [o]

instance BuildInputs (ListOf (Tensor v) as) where
    buildInputs Nil = return []
    buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)

----

-- | Parameters to build an op (for example, the node name or optional attributes).
-- TODO: be more type safe.
type OpParams = OpDef -> OpDef