Recursive SQL queries, still need to test before use

This commit is contained in:
fr33domlover 2016-06-12 22:37:52 +00:00
parent 55945e30f9
commit 76a627385c
4 changed files with 834 additions and 0 deletions

View file

@ -0,0 +1,255 @@
{- This file is part of Vervis.
-
- Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
-
- Copying is an act of love. Please copy, reuse and share.
-
- The author(s) have dedicated all copyright and related and neighboring
- rights to this software to the public domain worldwide. This software is
- distributed without any warranty.
-
- You should have received a copy of the CC0 Public Domain Dedication along
- with this software. If not, see
- <http://creativecommons.org/publicdomain/zero/1.0/>.
-}
{- The code is based on PersistQuery. Actually, most of the difference is
- slightly different names and 3 additional function parameters.
-}
-- | Recursive queries are performed by taking the output of a recursion step,
-- possibly modifying it, and using the result as the input for the next
-- recursion step.
--
-- This module currently provides a single way to perform that recursive step:
-- Match between the /id/ column and some other column which has the same type.
--
-- For example, suppose we have a `Message` type with a `messageParent` field.
-- For given messages `a` and `b`, if `messageParent b == Just a` then `b` is a
-- reply to `a`. Therefore, all the replies to a given message point to it
-- using the `messageParent` field. And there can be replies to replies and so
-- on, creating a tree of messages.
--
-- > Message
-- > author PersonId
-- > content Text
-- > parent MessageId Maybe
--
-- If we start with a single message and follow the `messageParent` values
-- recursively, we'll be able to get a list (or a tree) of the __ancestors__ of
-- the message. Our message /a/ may be a reply to some other message /b/, and
-- /b/ may be a reply to message /c/ and so on. Eventually, if there are no
-- cycles and it's really a tree structure, we'll reach the root message, which
-- has no parent.
--
-- But there's another way to recurse. What if we wanted to find the replies
-- for a given message? And the replies of the replies, and so on? In other
-- words, the __decendants__ of a given message. Suppose we start with a
-- message /a/. We get a list of the replies of /a/, i.e. message whose parent
-- is `Just a`. Then we find the replies of those messages, i.e. the replies of
-- the replies of /a/. And so on, recursively, until we can't find more replies
-- and then we stop.
--
-- Therefore we can perform the recursion in one of two directions:
--
-- - __Outwards__, i.e. follow from a message to its parents. More generally,
-- given a persistent entity type `Foobar`, follow recursively using a
-- specific field of it, whose type is `FoobarId`. It's called "outwards"
-- because it's like following out-edges of a graph node, i.e. arrows
-- pointing from a node towards other nodes.
-- - __Inwards__, i.e. find the children (i.e. replies) of a message, and then
-- their children, and so on. More generally, given a persistent entity type
-- `Foobar`, find other values referring to it using a specific field, whose
-- type is `FoobarId`, and recursive find such values for the results we get
-- and so on. It's called "inwards" because it's like following in-edges of a
-- graph node, i.e. arrows pointing from other nodes towards that node.
--
-- The 'RecursionDirection' type is used for specifying the direction.
--
-- When you follow all the children of an entity recursively, or all of its
-- parents, we call the result you get the __transitive closure__ of the
-- specific field you used. You can further specify the direction, i.e.
-- __outward transitive closure__ or __inward transitive closure__. For
-- examples, if you follow a message's parents recursively as in the example
-- above, you get an outward transitive closure on the /parent/ field.
--
-- Note that the definition used here is __not__ the same as the mathematical
-- definition. When you perform a recursive query without filters, you get not
-- only the ancestors (or the decendants) of an entity, but also the root
-- entity itself. In other words, even though a message is not a reply of
-- itself, you'll still get it in the query result. If you want to get just the
-- ancestors (or decendants), i.e. the actual transitive closer of the "is
-- reply of" relation in the mathematical sense, use a filter to omit the root
-- message based on the ID, i.e. @[MessageParent /= msgid]@.
--
-- Therefore, when the term "transitive closure" is used below, it means not
-- just the ancestors (or decendants), but also the origin entity too.
module Database.Persist.Local.Class.PersistQueryRecursive
( RecursionDirection (..)
, PersistQueryRecursive (..)
, selectRecursivelySource
, selectRecursivelyKeys
, selectRecursivelyList
, selectRecursivelyKeysList
)
where
import Prelude
import Control.Monad.IO.Class
import Control.Monad.Reader (MonadReader)
import Control.Monad.Trans.Reader (ReaderT)
import Control.Monad.Trans.Resource (MonadResource, release)
import Data.Acquire (Acquire, allocateAcquire, with)
import Database.Persist.Class
import Database.Persist.Types
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL
data RecursionDirection
= RecOut
| RecIn
deriving (Eq, Show)
-- | Backends supporting recursive conditional operations.
class PersistQuery backend => PersistQueryRecursive backend where
-- | Update individual fields on any record in the transitive closure and
-- matching the given criterion.
updateRecursivelyWhere
:: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> [Update val]
-> ReaderT backend m ()
-- | Delete all records in the transitive closure which match the given
-- criterion.
deleteRecursivelyWhere
:: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> ReaderT backend m ()
-- | Get all records in the transitive closure, which match the given
-- criterion, in the specified order. Returns also the identifiers.
selectRecursivelySourceRes
:: ( PersistEntity val
, PersistEntityBackend val ~ backend
, MonadIO m1
, MonadIO m2
)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> [SelectOpt val]
-> ReaderT backend m1 (Acquire (C.Source m2 (Entity val)))
-- | Get the 'Key's of all records in the transitive closure, which match
-- the given criterion.
selectRecursivelyKeysRes
:: ( PersistEntity val
, PersistEntityBackend val ~ backend
, MonadIO m1
, MonadIO m2
)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> [SelectOpt val]
-> ReaderT backend m1 (Acquire (C.Source m2 (Key val)))
-- | The total number of records in the transitive closure which fulfill
-- the given criterion.
countRecursively
:: (MonadIO m, PersistEntity val, backend ~ PersistEntityBackend val)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> ReaderT backend m Int
-- | Get all records in the transitive closure, which match the given
-- criterion, in the specified order. Returns also the identifiers.
selectRecursivelySource
:: ( PersistQueryRecursive backend
, MonadResource m
, PersistEntity val
, PersistEntityBackend val ~ backend
, MonadReader env m
, HasPersistBackend env backend
)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> [SelectOpt val]
-> C.Source m (Entity val)
selectRecursivelySource dir field root filts opts = do
srcRes <-
liftPersist $ selectRecursivelySourceRes dir field root filts opts
(releaseKey, src) <- allocateAcquire srcRes
src
release releaseKey
-- | Get the 'Key's of all records in the transitive closure, which match the
-- given criterion.
selectRecursivelyKeys
:: ( PersistQueryRecursive backend
, MonadResource m
, PersistEntity val
, backend ~ PersistEntityBackend val
, MonadReader env m
, HasPersistBackend env backend
)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> [SelectOpt val]
-> C.Source m (Key val)
selectRecursivelyKeys dir field root filts opts = do
srcRes <- liftPersist $ selectRecursivelyKeysRes dir field root filts opts
(releaseKey, src) <- allocateAcquire srcRes
src
release releaseKey
-- | Call 'selectRecursivelySource' but return the result as a list.
selectRecursivelyList
:: ( PersistQueryRecursive backend
, MonadIO m
, PersistEntity val
, PersistEntityBackend val ~ backend
)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> [SelectOpt val]
-> ReaderT backend m [Entity val]
selectRecursivelyList dir field root filts opts = do
srcRes <- selectRecursivelySourceRes dir field root filts opts
liftIO $ with srcRes (C.$$ CL.consume)
-- | Call 'selectRecursivelyKeys' but return the result as a list.
selectRecursivelyKeysList
:: ( PersistQueryRecursive backend
, MonadIO m
, PersistEntity val
, PersistEntityBackend val ~ backend
)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> [SelectOpt val]
-> ReaderT backend m [Key val]
selectRecursivelyKeysList dir field root filts opts = do
srcRes <- selectRecursivelyKeysRes dir field root filts opts
liftIO $ with srcRes (C.$$ CL.consume)

View file

@ -0,0 +1,239 @@
{- This file contains (slightly modified) copies of unexported functions from
- Database.Persist.Sql.Orphan.PersistQuery, which I need for my
- PersistQueryRecursive implementation. They're released under MIT.
-
- This should be a temporary situation. Either my code moves to persistent and
- the functions are reused there, or these functions become exported in
- persistent and then I can import them instead of holding copies.
-}
{-# LANGUAGE RankNTypes #-}
module Database.Persist.Local.Sql.Orphan.Common
( fieldName
, dummyFromFilts
, getFiltsValues
, updatePersistValue
, filterClause
, orderClause
)
where
import Prelude
import Data.List (inits, transpose)
import Data.Monoid ((<>))
import Data.Text (Text)
import Database.Persist
import Database.Persist.Sql
import Database.Persist.Sql.Util
import qualified Data.Text as T
fieldName
:: forall record typ.
(PersistEntity record , PersistEntityBackend record ~ SqlBackend)
=> EntityField record typ
-> DBName
fieldName f = fieldDB $ persistFieldDef f
dummyFromFilts :: [Filter v] -> Maybe v
dummyFromFilts _ = Nothing
getFiltsValues
:: forall val. (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
=> SqlBackend
-> [Filter val]
-> [PersistValue]
getFiltsValues conn = snd . filterClauseHelper False False conn OrNullNo
data OrNull = OrNullYes | OrNullNo
filterClauseHelper
:: (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
=> Bool -- ^ include table name?
-> Bool -- ^ include WHERE?
-> SqlBackend
-> OrNull
-> [Filter val]
-> (Text, [PersistValue])
filterClauseHelper includeTable includeWhere conn orNull filters =
( if not (T.null sql) && includeWhere
then " WHERE " <> sql
else sql
, vals
)
where
(sql, vals) = combineAND filters
combineAND = combine " AND "
combine s fs =
(T.intercalate s $ map wrapP a, mconcat b)
where
(a, b) = unzip $ map go fs
wrapP x = T.concat ["(", x, ")"]
go (BackendFilter _) = error "BackendFilter not expected"
go (FilterAnd []) = ("1=1", [])
go (FilterAnd fs) = combineAND fs
go (FilterOr []) = ("1=0", [])
go (FilterOr fs) = combine " OR " fs
go (Filter field value pfilter) =
let t = entityDef $ dummyFromFilts [Filter field value pfilter]
in case (isIdField field, entityPrimary t, allVals) of
(True, Just pdef, PersistList ys:_) ->
if length (compositeFields pdef) /= length ys
then error $ "wrong number of entries in compositeFields vs PersistList allVals=" ++ show allVals
else
case (allVals, pfilter, isCompFilter pfilter) of
([PersistList xs], Eq, _) ->
let sqlcl=T.intercalate " and " (map (\a -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "? ") (compositeFields pdef))
in (wrapSql sqlcl,xs)
([PersistList xs], Ne, _) ->
let sqlcl=T.intercalate " or " (map (\a -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "? ") (compositeFields pdef))
in (wrapSql sqlcl,xs)
(_, In, _) ->
let xxs = transpose (map fromPersistList allVals)
sqls=map (\(a,xs) -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "(" <> T.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (compositeFields pdef) xxs)
in (wrapSql (T.intercalate " and " (map wrapSql sqls)), concat xxs)
(_, NotIn, _) ->
let xxs = transpose (map fromPersistList allVals)
sqls=map (\(a,xs) -> connEscapeName conn (fieldDB a) <> showSqlFilter pfilter <> "(" <> T.intercalate "," (replicate (length xs) " ?") <> ") ") (zip (compositeFields pdef) xxs)
in (wrapSql (T.intercalate " or " (map wrapSql sqls)), concat xxs)
([PersistList xs], _, True) ->
let zs = tail (inits (compositeFields pdef))
sql1 = map (\b -> wrapSql (T.intercalate " and " (map (\(i,a) -> sql2 (i==length b) a) (zip [1..] b)))) zs
sql2 islast a = connEscapeName conn (fieldDB a) <> (if islast then showSqlFilter pfilter else showSqlFilter Eq) <> "? "
sqlcl = T.intercalate " or " sql1
in (wrapSql sqlcl, concat (tail (inits xs)))
(_, BackendSpecificFilter _, _) -> error "unhandled type BackendSpecificFilter for composite/non id primary keys"
_ -> error $ "unhandled type/filter for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList="++show allVals
(True, Just pdef, _) -> error $ "unhandled error for composite/non id primary keys pfilter=" ++ show pfilter ++ " persistList=" ++ show allVals ++ " pdef=" ++ show pdef
_ -> case (isNull, pfilter, varCount) of
(True, Eq, _) -> (name <> " IS NULL", [])
(True, Ne, _) -> (name <> " IS NOT NULL", [])
(False, Ne, _) -> (T.concat
[ "("
, name
, " IS NULL OR "
, name
, " <> "
, qmarks
, ")"
], notNullVals)
-- We use 1=2 (and below 1=1) to avoid using TRUE and FALSE, since
-- not all databases support those words directly.
(_, In, 0) -> ("1=2" <> orNullSuffix, [])
(False, In, _) -> (name <> " IN " <> qmarks <> orNullSuffix, allVals)
(True, In, _) -> (T.concat
[ "("
, name
, " IS NULL OR "
, name
, " IN "
, qmarks
, ")"
], notNullVals)
(_, NotIn, 0) -> ("1=1", [])
(False, NotIn, _) -> (T.concat
[ "("
, name
, " IS NULL OR "
, name
, " NOT IN "
, qmarks
, ")"
], notNullVals)
(True, NotIn, _) -> (T.concat
[ "("
, name
, " IS NOT NULL AND "
, name
, " NOT IN "
, qmarks
, ")"
], notNullVals)
_ -> (name <> showSqlFilter pfilter <> "?" <> orNullSuffix, allVals)
where
isCompFilter Lt = True
isCompFilter Le = True
isCompFilter Gt = True
isCompFilter Ge = True
isCompFilter _ = False
wrapSql sqlcl = "(" <> sqlcl <> ")"
fromPersistList (PersistList xs) = xs
fromPersistList other = error $ "expected PersistList but found " ++ show other
filterValueToPersistValues :: forall a. PersistField a => Either a [a] -> [PersistValue]
filterValueToPersistValues v = map toPersistValue $ either return id v
orNullSuffix =
case orNull of
OrNullYes -> mconcat [" OR ", name, " IS NULL"]
OrNullNo -> ""
isNull = any (== PersistNull) allVals
notNullVals = filter (/= PersistNull) allVals
allVals = filterValueToPersistValues value
tn = connEscapeName conn $ entityDB
$ entityDef $ dummyFromFilts [Filter field value pfilter]
name =
(if includeTable
then ((tn <> ".") <>)
else id)
$ connEscapeName conn $ fieldName field
qmarks = case value of
Left _ -> "?"
Right x ->
let x' = filter (/= PersistNull) $ map toPersistValue x
in "(" <> T.intercalate "," (map (const "?") x') <> ")"
varCount = case value of
Left _ -> 1
Right x -> length x
showSqlFilter Eq = "="
showSqlFilter Ne = "<>"
showSqlFilter Gt = ">"
showSqlFilter Lt = "<"
showSqlFilter Ge = ">="
showSqlFilter Le = "<="
showSqlFilter In = " IN "
showSqlFilter NotIn = " NOT IN "
showSqlFilter (BackendSpecificFilter s) = s
updatePersistValue :: Update v -> PersistValue
updatePersistValue (Update _ v _) = toPersistValue v
updatePersistValue _ = error "BackendUpdate not implemented"
filterClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
=> Bool -- ^ include table name?
-> SqlBackend
-> [Filter val]
-> Text
filterClause b c = fst . filterClauseHelper b True c OrNullNo
orderClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend)
=> Bool -- ^ include the table name
-> SqlBackend
-> SelectOpt val
-> Text
orderClause includeTable conn o =
case o of
Asc x -> name x
Desc x -> name x <> " DESC"
_ -> error "orderClause: expected Asc or Desc, not limit or offset"
where
dummyFromOrder :: SelectOpt a -> Maybe a
dummyFromOrder _ = Nothing
tn = connEscapeName conn $ entityDB $ entityDef $ dummyFromOrder o
name :: (PersistEntityBackend record ~ SqlBackend, PersistEntity record)
=> EntityField record typ -> Text
name x =
(if includeTable
then ((tn <> ".") <>)
else id)
$ connEscapeName conn $ fieldName x

View file

@ -0,0 +1,333 @@
{- This file is part of Vervis.
-
- Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
-
- Copying is an act of love. Please copy, reuse and share.
-
- The author(s) have dedicated all copyright and related and neighboring
- rights to this software to the public domain worldwide. This software is
- distributed without any warranty.
-
- You should have received a copy of the CC0 Public Domain Dedication along
- with this software. If not, see
- <http://creativecommons.org/publicdomain/zero/1.0/>.
-}
module Database.Persist.Local.Sql.Orphan.PersistQueryRecursive
( deleteRecursivelyWhereCount
, updateRecursivelyWhereCount
)
where
import Prelude
import Control.Monad (void)
import Control.Monad.IO.Class
import Control.Monad.Trans.Reader (ReaderT, ask)
import Control.Exception (throwIO)
import Data.ByteString.Char8 (readInteger)
import Data.Conduit (($=))
import Data.Foldable (find)
import Data.Int (Int64)
import Data.Maybe (isJust)
import Data.Monoid ((<>))
import Data.Text (Text)
import Database.Persist
import Database.Persist.Sql
import Database.Persist.Sql.Util
import qualified Data.Conduit.List as CL (head, mapM)
import qualified Data.Text as T (pack, unpack, intercalate)
import Database.Persist.Local.Class.PersistQueryRecursive
import Database.Persist.Local.Sql.Orphan.Common
instance PersistQueryRecursive SqlBackend where
updateRecursivelyWhere dir field root filts upds =
void $ updateRecursivelyWhereCount dir field root filts upds
deleteRecursivelyWhere dir field root filts =
void $ deleteRecursivelyWhereCount dir field root filts
selectRecursivelySourceRes dir field root filts opts = do
conn <- ask
let (sql, vals, parse) = sqlValsParse conn
srcRes <- rawQueryRes sql vals
return $ fmap ($= CL.mapM parse) srcRes
where
sqlValsParse conn = (sql, vals, parse)
where
(temp, isRoot, cols, qcols, sqlWith) =
withRecursive dir field root conn t (flip entityColumnNames)
(limit, offset, orders) = limitOffsetOrder opts
parse xs =
case parseEntityValues t xs of
Left s -> liftIO $ throwIO $ PersistMarshalError s
Right row -> return row
t = entityDef $ dummyFromFilts filts
wher =
if null filts
then ""
else filterClause False conn filts
ord =
case map (orderClause False conn) orders of
[] -> ""
ords -> " ORDER BY " <> T.intercalate "," ords
sql =
mappend sqlWith $
connLimitOffset conn (limit, offset) (not $ null orders) $
mconcat
[ "SELECT "
, cols
, " FROM "
, connEscapeName conn temp
, wher
, ord
]
vals = getFiltsValues conn $ isRoot : filts
selectRecursivelyKeysRes dir field root filts opts = do
conn <- ask
let (sql, vals, parse) = sqlValsParse conn
srcRes <- rawQueryRes sql vals
return $ fmap ($= CL.mapM parse) srcRes
where
sqlValsParse conn = (sql, vals, parse)
where
(temp, isRoot, cols, qcols, sqlWith) =
withRecursive dir field root conn t dbIdColumns
(limit, offset, orders) = limitOffsetOrder opts
parse xs = do
keyvals <-
case entityPrimary t of
Nothing ->
case xs of
[PersistInt64 x] ->
return [PersistInt64 x]
[PersistDouble x] ->
-- oracle returns Double
return [PersistInt64 $ truncate x]
_ ->
liftIO $ throwIO $ PersistMarshalError $
"Unexpected in selectKeys False: " <>
T.pack (show xs)
Just pdef ->
let pks = map fieldHaskell $ compositeFields pdef
keyvals =
map snd $
filter
(\ (a, _) ->
let ret = isJust (find (== a) pks)
in ret
) $
zip (map fieldHaskell $ entityFields t) xs
in return keyvals
case keyFromValues keyvals of
Right k -> return k
Left _ -> error "selectKeysImpl: keyFromValues failed"
t = entityDef $ dummyFromFilts filts
wher =
if null filts
then ""
else filterClause False conn filts
ord =
case map (orderClause False conn) orders of
[] -> ""
ords -> " ORDER BY " <> T.intercalate "," ords
sql =
mappend sqlWith $
connLimitOffset conn (limit, offset) (not $ null orders) $
mconcat
[ "SELECT "
, cols
, " FROM "
, connEscapeName conn temp
, wher
, ord
]
vals = getFiltsValues conn $ isRoot : filts
countRecursively dir field root filts = do
conn <- ask
let (sql, vals) = sqlAndVals conn
withRawQuery sql vals $ do
mm <- CL.head
case mm of
Just [PersistInt64 i] -> return $ fromIntegral i
Just [PersistDouble i] ->return $ fromIntegral (truncate i :: Int64) -- gb oracle
Just [PersistByteString i] -> case readInteger i of -- gb mssql
Just (ret,"") -> return $ fromIntegral ret
xs -> error $ "invalid number i["++show i++"] xs[" ++ show xs ++ "]"
Just xs -> error $ "count:invalid sql return xs["++show xs++"] sql["++show sql++"]"
Nothing -> error $ "count:invalid sql returned nothing sql["++show sql++"]"
where
sqlAndVals conn = (sql, vals)
where
(temp, isRoot, cols, qcols, sqlWith) =
withRecursive dir field root conn t dbIdColumns
t = entityDef $ dummyFromFilts filts
wher =
if null filts
then ""
else filterClause False conn filts
sql = mconcat
[ sqlWith
, "SELECT COUNT(*) FROM "
, connEscapeName conn temp
, wher
]
vals = getFiltsValues conn $ isRoot : filts
-- | Same as 'deleteRecursivelyWhere', but returns the number of rows affected.
deleteRecursivelyWhereCount
:: (PersistEntity val, MonadIO m, PersistEntityBackend val ~ SqlBackend)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> ReaderT SqlBackend m Int64
deleteRecursivelyWhereCount dir field root filts = do
conn <- ask
let (sql, vals) = sqlAndVals conn
rawExecuteCount sql vals
where
sqlAndVals conn = (sql, vals)
where
(temp, isRoot, cols, qcols, sqlWith) =
withRecursive dir field root conn t dbIdColumns
t = entityDef $ dummyFromFilts filts
wher = mconcat
[ if null filts
then " WHERE ( "
else filterClause False conn filts <> " AND ( "
, connEscapeName conn $ fieldDB $ entityId t
, " IN (SELECT "
, connEscapeName conn $ fieldDB $ entityId t
, " FROM "
, connEscapeName conn temp
, ") ) "
]
sql = mconcat
[ sqlWith
, "DELETE FROM "
, connEscapeName conn $ entityDB t
, wher
]
vals = getFiltsValues conn $ isRoot : filts
-- | Same as 'updateRecursivelyWhere', but returns the number of rows affected.
updateRecursivelyWhereCount
:: (PersistEntity val, MonadIO m, SqlBackend ~ PersistEntityBackend val)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> [Filter val]
-> [Update val]
-> ReaderT SqlBackend m Int64
updateRecursivelyWhereCount _ _ _ _ [] = return 0
updateRecursivelyWhereCount dir field root filts upds = do
conn <- ask
let (sql, vals) = sqlAndVals conn
rawExecuteCount sql vals
where
sqlAndVals conn = (sql, vals)
where
(temp, isRoot, cols, qcols, sqlWith) =
withRecursive dir field root conn t dbIdColumns
t = entityDef $ dummyFromFilts filts
go'' n Assign = n <> "=?"
go'' n Add = mconcat [n, "=", n, "+?"]
go'' n Subtract = mconcat [n, "=", n, "-?"]
go'' n Multiply = mconcat [n, "=", n, "*?"]
go'' n Divide = mconcat [n, "=", n, "/?"]
go'' _ (BackendSpecificUpdate up) =
error $ T.unpack $ "BackendSpecificUpdate " <> up <> " not supported"
go' (x, pu) = go'' (connEscapeName conn x) pu
go x = (updateField x, updateUpdate x)
updateField (Update f _ _) = fieldName f
updateField _ = error "BackendUpdate not implemented"
wher = mconcat
[ if null filts
then " WHERE ( "
else filterClause False conn filts <> " AND ( "
, connEscapeName conn $ fieldDB $ entityId t
, " IN (SELECT "
, connEscapeName conn $ fieldDB $ entityId t
, " FROM "
, connEscapeName conn temp
, ") ) "
]
sql = mconcat
[ sqlWith
, "UPDATE "
, connEscapeName conn $ entityDB t
, " SET "
, T.intercalate "," $ map (go' . go) upds
, wher
]
vals =
getFiltsValues conn [isRoot] ++
map updatePersistValue upds ++
getFiltsValues conn filts
withRecursive
:: (PersistEntity val, SqlBackend ~ PersistEntityBackend val)
=> RecursionDirection
-> EntityField val (Maybe (Key val))
-> Key val
-> SqlBackend
-> EntityDef
-> (SqlBackend -> EntityDef -> [Text])
-> (DBName, Filter val, Text, DBName -> Text, Text)
withRecursive dir field root conn t getcols =
let temp = DBName "temp_hierarchy_cte"
isRoot = persistIdField ==. root
cols = T.intercalate "," $ getcols conn t
qcols name =
T.intercalate ", " $
map ((connEscapeName conn name <>) . ("." <>)) $
getcols conn t
sql = mconcat
[ "WITH RECURSIVE "
, connEscapeName conn temp
, "("
, cols
, ") AS ( SELECT "
, cols
, " FROM "
, connEscapeName conn $ entityDB t
, filterClause False conn [isRoot]
--, " WHERE "
--, connEscapeName conn $ fieldDB $ entityId t
--, " = ?"
, " UNION SELECT "
, qcols temp
, " FROM "
, connEscapeName conn $ entityDB t
, ", "
, connEscapeName conn temp
, " WHERE "
, connEscapeName conn $ entityDB t
, "."
, connEscapeName conn $ fieldDB $ case dir of
RecOut -> persistFieldDef field
RecIn -> entityId t
, " = "
, connEscapeName conn temp
, "."
, connEscapeName conn $ fieldDB $ case dir of
RecOut -> entityId t
RecIn -> persistFieldDef field
, " ) "
]
in (temp, isRoot, cols, qcols, sql)

View file

@ -64,6 +64,9 @@ library
Database.Esqueleto.Local
Database.Persist.Class.Local
Database.Persist.Sql.Local
Database.Persist.Local.Class.PersistQueryRecursive
Database.Persist.Local.Sql.Orphan.Common
Database.Persist.Local.Sql.Orphan.PersistQueryRecursive
Development.DarcsRev
Formatting.CaseInsensitive
Network.SSH.Local
@ -219,6 +222,8 @@ library
, memory
, monad-control
, monad-logger
-- for Database.Persist.Local
, mtl
, pandoc
, pandoc-types
-- for PathPiece instance for CI, Web.PathPieces.Local
@ -227,6 +232,8 @@ library
, persistent-postgresql
, persistent-template
, process
-- for Database.Persist.Local
, resourcet
, safe
, shakespeare
, ssh