module Spark.Core.Internal.FunctionsInternals(
DynColPackable,
StaticColPackable2,
NameTuple(..),
TupleEquivalence(..),
asCol,
asCol',
pack1,
pack,
pack',
struct',
struct,
checkOrigin,
projectColFunction,
projectColFunction',
projectColFunction2',
colOpNoBroadcast
) where
import Control.Arrow
import Data.Aeson(toJSON)
import qualified Data.Vector as V
import qualified Data.Map.Strict as M
import qualified Data.List.NonEmpty as N
import qualified Data.Text as T
import Formatting
import Spark.Core.Internal.ColumnStructures
import Spark.Core.Internal.ColumnFunctions
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.Utilities
import Spark.Core.Internal.TypesFunctions
import Spark.Core.Internal.LocalDataFunctions
import Spark.Core.Internal.TypesStructures
import Spark.Core.Internal.Projections
import Spark.Core.Internal.OpStructures
import Spark.Core.Internal.TypesGenerics(SQLTypeable, buildType)
import Spark.Core.StructuresInternal
import Spark.Core.Try
-- | Values that can be packed into a dynamically-typed column ('DynColumn').
--
-- Instances in this module cover 'DynColumn' itself, typed columns, lists of
-- packable values and pairs of packable values.
class DynColPackable a where
  -- | Convert the value into a 'DynColumn'. Judging from the instances, which
  -- build results with 'pure' and 'struct'', this may carry a failure.
  _packAsColumn :: a -> DynColumn
-- | Values that can be packed into a statically-typed column @Column ref b@.
--
-- The functional dependency @a -> ref@ means the packed value determines the
-- @ref@ parameter of the resulting column. The element type @b@ is chosen by
-- the instance; for tuples it is tied to the tuple via 'TupleEquivalence'.
class StaticColPackable2 ref a b | a -> ref where
  -- | Convert the value into a typed column.
  _staticPackAsColumn2 :: a -> Column ref b
-- | The field names of a tuple-like type @to@, as plain strings.
--
-- @to@ is a phantom parameter: it only ties the list of names to a type at
-- the type level (see 'tupleFieldNames').
data NameTuple to = NameTuple [String]
-- | Relates a type @to@ to the tuple type @tup@ it is equivalent to. It also
-- provides the field names used when building a struct column for @to@.
--
-- The functional dependency @to -> tup@ makes the tuple type determined by @to@.
class TupleEquivalence to tup | to -> tup where
  -- | Names of the struct fields, in tuple order (e.g. @["_1", "_2"]@ for pairs).
  tupleFieldNames :: NameTuple to
-- | View a whole dataset as a single column whose origin is that dataset.
--
-- The column carries the dataset's own element type and refers to the empty
-- field path.
asCol :: Dataset a -> Column a a
asCol ds = iEmptyCol ds (unsafeCastType (nodeType ds)) rootPath
  where
    -- No path components: the column designates the row itself.
    rootPath = FieldPath V.empty
-- | Untyped counterpart of 'asCol': view a dataframe as a dynamic column.
--
-- A failure already present in the dataframe is propagated unchanged.
asCol' :: DataFrame -> DynColumn
asCol' df = fmap (iUntypedColData . asCol) df
-- | Materialize a column as a dataset (delegates to '_pack1').
pack1 :: Column ref a -> Dataset a
pack1 col = _pack1 col
-- | Pack any 'DynColPackable' value into an untyped dataframe.
--
-- A failure while packing the column is propagated in the result.
pack' :: (DynColPackable a) => a -> DataFrame
pack' = fmap pack1 . _packAsColumn
-- | Pack a statically-packable value (a column, or a tuple of columns) into a
-- typed dataset.
pack :: forall ref a b. (StaticColPackable2 ref a b) => a -> Dataset b
pack z = pack1 (_staticPackAsColumn2 z :: Column ref b)
-- | Build a struct column out of several dynamic columns.
--
-- Each field is named after its source column ('colFieldName'). Fails if any
-- input column already failed, or if '_buildStruct' rejects the columns
-- (e.g. they do not share a single origin).
struct' :: [DynColumn] -> DynColumn
struct' cols = sequence cols >>= _buildStruct . map (\c -> (colFieldName c, c))
-- | Typed counterpart of 'struct'': pack a column, or a tuple of columns, into
-- a single typed column.
struct :: forall ref a b. (StaticColPackable2 ref a b) => a -> Column ref b
struct packable = _staticPackAsColumn2 packable
-- | Check that none of the columns failed and that they all share one origin.
--
-- On success the untyped columns are returned unchanged.
checkOrigin :: [DynColumn] -> Try [UntypedColumnData]
checkOrigin cols = sequence cols >>= _checkOrigin
-- | Lift a typed column function to a function on typed local data.
--
-- The function is run through 'projectColFunctionUntyped'. On the way in, the
-- untyped placeholder column is cast to @x@; the result is cast back to @y@.
-- Any failure (casting or transform construction) is raised via 'forceRight'.
projectColFunction :: forall x y.
  (HasCallStack, SQLTypeable y, SQLTypeable x) =>
  (forall ref. Column ref x -> Column ref y) -> LocalData x -> LocalData y
projectColFunction f o =
  forceRight (projectColFunctionUntyped (liftTyped =<<) (untypedLocalData o) >>= castType outType)
  where
    inType = buildType :: SQLType x
    outType = buildType :: SQLType y
    -- Give the untyped column its static type, apply f, then erase the type again.
    liftTyped :: UntypedColumnData -> Try UntypedColumnData
    liftTyped uc = dropColType . f <$> castTypeCol inType uc
-- | Apply a dynamic column function to a local value (observable).
--
-- A placeholder dataset with the observable's type provides the input column.
-- The resulting column operation is converted with an empty substitution map,
-- so it must not reference broadcast observables. The output is a
-- structured-transform local value whose parent is the input observable.
projectColFunctionUntyped ::
  (DynColumn -> DynColumn) -> UntypedLocalData -> LocalFrame
projectColFunctionUntyped f obs = do
  let inputType = unSQLType (nodeType obs)
      placeholder = emptyDataset (NodeDistributedLit inputType V.empty) (SQLType inputType)
  resultCol <- f (pure (dropColType (asCol placeholder)))
  finalOp <- _replaceObservables M.empty (colOp resultCol)
  let outputType = SQLType (unSQLType (colType resultCol))
  pure (emptyLocalData (NodeStructuredTransform finalOp) outputType `parents` [untyped obs])
-- | Like 'projectColFunctionUntyped', but over a possibly-failed local frame.
-- An input failure is propagated unchanged.
projectColFunction' ::
  (DynColumn -> DynColumn) ->
  LocalFrame -> LocalFrame
projectColFunction' f = (>>= projectColFunctionUntyped f)
-- | Apply a two-argument dynamic column function to two local frames.
--
-- Both observables are packed into a single tuple observable. The function
-- then receives that tuple's fields @_1@ and @_2@ as its two arguments.
-- A failure in either input is propagated.
projectColFunction2' ::
  (DynColumn -> DynColumn -> DynColumn) ->
  LocalFrame ->
  LocalFrame ->
  LocalFrame
projectColFunction2' f o1' o2' = do
  lhs <- o1'
  rhs <- o2'
  projectColFunctionUntyped onPair (iPackTupleObs (lhs N.:| [rhs]))
  where
    onPair :: DynColumn -> DynColumn
    onPair pair = f (pair /- "_1") (pair /- "_2")
-- | Convert a generalized column operation into a plain one.
--
-- Uses an empty substitution map, so this fails if the operation still
-- references a broadcast observable.
colOpNoBroadcast :: GeneralizedColOp -> Try ColOp
colOpNoBroadcast op = _replaceObservables M.empty op
-- | Succeed with the columns unchanged if they all share a single origin.
--
-- An empty list is trivially accepted.
_checkOrigin :: [UntypedColumnData] -> Try [UntypedColumnData]
_checkOrigin cols
  | null cols = pure []
  | otherwise = case _columnOrigin cols of
      [_] -> pure cols
      origins -> tryError $ sformat ("Too many distinct origins: "%sh) origins
-- | A list is packed as a struct of its packed elements.
instance forall x. (DynColPackable x) => DynColPackable [x] where
  _packAsColumn xs = struct' (map _packAsColumn xs)
-- | A dynamic column is already packed.
instance DynColPackable DynColumn where
  _packAsColumn dc = dc
-- | A typed column is packed by erasing its type; this never fails.
instance forall ref a. DynColPackable (Column ref a) where
  _packAsColumn col = pure (iUntypedColData col)
-- | A pair is packed as a struct of its two packed components.
instance forall z1 z2. (DynColPackable z1, DynColPackable z2) => DynColPackable (z1, z2) where
  _packAsColumn pair = case pair of
    (left, right) -> struct' [_packAsColumn left, _packAsColumn right]
-- | A typed column packs to itself.
instance forall ref a. StaticColPackable2 ref (Column ref a) a where
  _staticPackAsColumn2 col = col
-- | A pair is equivalent to itself, with fields named @_1@ and @_2@.
instance forall a1 a2. TupleEquivalence (a1, a2) (a1, a2) where
  tupleFieldNames = NameTuple pairFieldNames
    where
      pairFieldNames :: [String]
      pairFieldNames = ["_1", "_2"]
-- | A pair of statically-packable values becomes a two-field struct column.
-- The field names come from the 'TupleEquivalence' instance of the result type.
instance forall ref b a1 a2 z1 z2. (
    TupleEquivalence b (a1, a2),
    StaticColPackable2 ref z1 a1,
    StaticColPackable2 ref z2 a2) =>
    StaticColPackable2 ref (z1, z2) b where
  _staticPackAsColumn2 (c1, c2) =
    _unsafeBuildStruct [u1, u2] (tupleFieldNames :: NameTuple b)
    where
      -- Pack each component at its own static type, then erase the types.
      u1 = iUntypedColData (_staticPackAsColumn2 c1 :: Column ref a1)
      u2 = iUntypedColData (_staticPackAsColumn2 c2 :: Column ref a2)
-- | A triple of statically-packable values becomes a three-field struct column.
-- The field names come from the 'TupleEquivalence' instance of the result type.
instance forall ref b a1 a2 a3 z1 z2 z3. (
    TupleEquivalence b (a1, a2, a3),
    StaticColPackable2 ref z1 a1,
    StaticColPackable2 ref z2 a2,
    StaticColPackable2 ref z3 a3) =>
    StaticColPackable2 ref (z1, z2, z3) b where
  _staticPackAsColumn2 (c1, c2, c3) =
    _unsafeBuildStruct [u1, u2, u3] (tupleFieldNames :: NameTuple b)
    where
      -- Pack each component at its own static type, then erase the types.
      u1 = iUntypedColData (_staticPackAsColumn2 c1 :: Column ref a1)
      u2 = iUntypedColData (_staticPackAsColumn2 c2 :: Column ref a2)
      u3 = iUntypedColData (_staticPackAsColumn2 c3 :: Column ref a3)
-- | Build a typed struct column from untyped columns and their field names.
--
-- "Unsafe" because it calls 'failure' when the numbers of columns and names
-- differ, and 'forceRight' on the result of '_buildStruct'.
_unsafeBuildStruct :: [UntypedColumnData] -> NameTuple x -> Column ref x
_unsafeBuildStruct cols (NameTuple names)
  | length cols /= length names =
      failure $ sformat ("The number of columns and names differs:"%sh%" and "%sh) cols names
  | otherwise =
      -- The record update is not a no-op. It is a type-changing update that
      -- lets the untyped struct be returned at type @Column ref x@.
      let built = forceRight (_buildStruct (zip (unsafeFieldName . T.pack <$> names) cols))
      in built { _cOp = _cOp built }
-- | Build a tuple-like struct column from untyped columns.
--
-- Fields are named @_1@, @_2@, ... in order. This matches 'tupleFieldNames'
-- and the @_1@/@_2@ accessors used elsewhere in this module. Fails under the
-- same conditions as '_buildStruct'.
_buildTuple :: [UntypedColumnData] -> Try UntypedColumnData
_buildTuple l = _buildStruct (zip names l) where
  -- Exactly one 1-based name per column, so that 'zip' keeps every column.
  names = unsafeFieldName . ("_" <>) . show' <$> [1 .. length l]
-- | Build a struct column from named untyped columns.
--
-- Each struct field's operation is the corresponding column's operation. The
-- struct type is derived from the columns' types via 'structTypeFromFields'.
-- Fails in three cases:
--
-- * the struct type cannot be built;
-- * no columns are given;
-- * the columns do not all come from the same dataset.
_buildStruct :: [(FieldName, UntypedColumnData)] -> Try UntypedColumnData
_buildStruct cols = do
  let fields = GenColStruct $ (uncurry GeneralizedTransField . (fst &&& colOp . snd)) <$> V.fromList cols
  st <- structTypeFromFields $ (fst &&& unSQLType . colType . snd) <$> cols
  let name = structName st
  case _columnOrigin (snd <$> cols) of
    [ds] ->
      pure ColumnData {
        _cOrigin = ds,
        _cType = StrictType (Struct st),
        _cOp = fields,
        _cReferingPath = Just $ unsafeFieldName name
      }
    -- Without this case an empty input would be reported as
    -- "Too many distinct origins: []", which is misleading.
    [] -> tryError $ sformat "_buildStruct: cannot build a struct from an empty list of columns"
    l -> tryError $ sformat ("_buildStruct: Too many distinct origins: "%sh) l
-- | The distinct origin datasets of the given columns.
--
-- Returns one dataset per distinct origin node id, taken from any column of
-- that group. Pattern-matching on the group avoids the partial 'head'; a group
-- with no members simply contributes nothing.
_columnOrigin :: [UntypedColumnData] -> [UntypedDataset]
_columnOrigin cols =
  [ colOrigin representative
  | (_, representative : _) <- myGroupBy' (nodeId . colOrigin) cols
  ]
-- | Materialize a column as a dataset.
--
-- There are two cases:
--
-- * no broadcast observables: the column operation is applied directly to
--   the column's origin ('_packCol1');
-- * otherwise: the observables are first paired with the origin
--   ('_packCol1WithObs').
--
-- Failures are raised via 'forceRight'.
_pack1 :: (HasCallStack) => Column ref a -> Dataset a
_pack1 ucd = case N.nonEmpty (_collectObs gco) of
    Nothing -> _packCol1 ucd (forceRight (colOpNoBroadcast gco))
    Just obs -> forceRight (_packCol1WithObs ucd obs)
  where
    gco = colOp ucd
-- | Materialize a column that references broadcast observables.
--
-- The observables are packed into one tuple observable, which is paired with
-- the column's origin dataset ('broadcastPair'). Each observable is then
-- addressed by the path @_2@ followed by its tuple field name. The column
-- operation is rewritten with that mapping by '_replaceObservables'.
_packCol1WithObs :: Column ref a -> N.NonEmpty UntypedLocalData -> Try (Dataset a)
_packCol1WithObs c ulds = do
  let packed = iPackTupleObs ulds
      tupleType = structTypeTuple (unSQLType . nodeType <$> ulds)
      fieldNames = V.toList (structFieldName <$> structFields tupleType)
      obsPath fname = FieldPath (V.fromList [unsafeFieldName "_2", fname])
      pathByNode = M.fromList (zip (nodeId <$> N.toList ulds) (obsPath <$> fieldNames))
      joined = broadcastPair (colOrigin c) packed
  finalOp <- _replaceObservables pathByNode (colOp c)
  pure $ emptyDataset (NodeStructuredTransform finalOp) (colType c) `parents` [untyped joined]
-- | Turn a generalized column operation into a plain 'ColOp'.
--
-- The map gives, for each broadcast observable (by node id), the field path
-- where its value is found.
--
-- * Empty map: field extractions are kept as is.
-- * Non-empty map: extractions are prefixed with @_1@, presumably the
--   original row in the pairs built by '_packCol1WithObs'.
--
-- Fails on a broadcast observable that is missing from the map.
_replaceObservables :: M.Map NodeId FieldPath -> GeneralizedColOp -> Try ColOp
_replaceObservables m op = case op of
  GenColExtraction fp | M.null m -> pure (ColExtraction fp)
  GenColExtraction (FieldPath path) ->
    pure (ColExtraction (FieldPath (V.cons (unsafeFieldName "_1") path)))
  GenColLit dt c -> pure (ColLit dt (toJSON c))
  GenColFunction n args -> ColFunction n <$> traverse (_replaceObservables m) args
  GenColStruct fields -> ColStruct <$> traverse (_replaceField m) fields
  BroadcastColOp uld -> case M.lookup (nodeId uld) m of
    Just p -> pure (ColExtraction p)
    Nothing -> tryError $ "_replaceObservables: error: missing key " <> show' uld <> " in " <> show' m
-- | Apply '_replaceObservables' to the value of a struct field, keeping its name.
_replaceField :: M.Map NodeId FieldPath -> GeneralizedTransField -> Try TransformField
_replaceField m field = case field of
  GeneralizedTransField n v -> fmap (TransformField n) (_replaceObservables m v)
-- | Build the dataset for a column whose operation has no broadcast observables.
--
-- * Extraction with an empty field path: the origin dataset is cast to the
--   column's type.
-- * Any other operation: it becomes a structured transform whose parent is
--   the origin dataset.
_packCol1 :: Column ref a -> ColOp -> Dataset a
_packCol1 c op = case op of
  ColExtraction (FieldPath v) | V.null v ->
    forceRight $ castType (colType c) (colOrigin c)
  _ ->
    emptyDataset (NodeStructuredTransform op) (colType c)
      `parents` [untyped (colOrigin c)]
-- | All broadcast observables referenced by a generalized column operation.
--
-- The result is in left-to-right order and keeps duplicates. Only function
-- arguments and struct fields are searched recursively.
_collectObs :: GeneralizedColOp -> [UntypedLocalData]
_collectObs op = case op of
  BroadcastColOp uld -> [uld]
  GenColFunction _ args -> concatMap _collectObs (V.toList args)
  GenColStruct fields -> concatMap (_collectObs . gtfValue) (V.toList fields)
  _ -> []