From dad1ed2e1fe306ee2c8e8cf82b6ac6fd1bfd5cda Mon Sep 17 00:00:00 2001
From: fr33domlover <fr33domlover@rel4tion.org>
Date: Fri, 29 Jul 2016 22:57:52 +0000
Subject: [PATCH] SQL: IN (1, 2, 3) instead of invalid ANY('[1, 2, 3]')

I thought SQL arrays were common and PersistList corresponded to SQL
array values. But that isn't the case. PersistList seems to be
serialized as a JSON list, and `filterClause` uses IN, not ANY. So I'm
doing the same thing here and using IN.

Note that I'm building the list myself using Text concatenation, not
using `filterClause`, because the latter takes a filter on an existing
`PersistEntity` while my filters often apply to temporary tables.
---
 src/Database/Persist/Local/Sql.hs           |  5 ++++
 src/Database/Persist/Sql/Graph/Connects.hs  | 30 +++++++++++++--------
 src/Database/Persist/Sql/Graph/Cyclic.hs    | 13 +++++----
 src/Database/Persist/Sql/Graph/Path.hs      | 23 +++++++++++-----
 src/Database/Persist/Sql/Graph/Reachable.hs | 11 +++++---
 5 files changed, 55 insertions(+), 27 deletions(-)

diff --git a/src/Database/Persist/Local/Sql.hs b/src/Database/Persist/Local/Sql.hs
index bbadf23..a816557 100644
--- a/src/Database/Persist/Local/Sql.hs
+++ b/src/Database/Persist/Local/Sql.hs
@@ -23,6 +23,7 @@ module Database.Persist.Local.Sql
     , destFieldFromProxy
     , sourceFieldFromProxy
     , (?:)
+    , (?++)
     , FollowDirection (..)
     )
 where
@@ -153,3 +154,7 @@ rawSqlWithGraph dir root parent child sub vals = do
 (?:) :: Maybe a -> [a] -> [a]
 (?:) = maybe id (:)
 infixr 5 ?:
+
+(?++) :: Maybe [a] -> [a] -> [a]
+(?++) = maybe id (++)
+infixr 5 ?++
diff --git a/src/Database/Persist/Sql/Graph/Connects.hs b/src/Database/Persist/Sql/Graph/Connects.hs
index f8e1465..8c8bb95 100644
--- a/src/Database/Persist/Sql/Graph/Connects.hs
+++ b/src/Database/Persist/Sql/Graph/Connects.hs
@@ -50,7 +50,7 @@ import Database.Persist
 import Database.Persist.Sql
 import Database.Persist.Sql.Util
 
-import qualified Data.Text as T (null, intercalate)
+import qualified Data.Text as T (empty, singleton, null, intercalate)
 
 import Database.Persist.Local.Class.PersistEntityGraph
 import Database.Persist.Local.Class.PersistQueryForest
@@ -152,12 +152,15 @@ xmconnectsm' follow filter msource mdest mlen proxy = do
                 , entityDB tNode ^* fieldDB (entityId tNode), ", "
                 , "ARRAY[", entityDB tNode ^* fieldDB (entityId tNode), "], "
                 , "FALSE"
+                , " FROM ", dbname $ entityDB tNode
                 , case msource of
-                    Nothing -> " FROM " <> dbname (entityDB tNode)
-                    Just _ -> mconcat
-                        [ " FROM ", dbname $ entityDB tNode
-                        , " WHERE ", entityDB tNode ^* fieldDB (entityId tNode)
-                        , " = ANY(?)"
+                    Nothing -> T.empty
+                    Just l -> mconcat
+                        [ " WHERE ", entityDB tNode ^* fieldDB (entityId tNode)
+                        , " IN ("
+                        , T.intercalate ", " $
+                            replicate (length l) (T.singleton '?')
+                        , ")"
                         ]
             , " UNION ALL "
             , case follow of
@@ -173,16 +176,21 @@ xmconnectsm' follow filter msource mdest mlen proxy = do
             , " ) SELECT 1 WHERE EXISTS ( SELECT ", temp ^* tpath
             , " FROM ", dbname temp
             , case mdest of
-                Nothing -> ""
-                Just _ -> " WHERE " <> temp ^* tid <> " = ANY(?)"
+                Nothing -> T.empty
+                Just l -> mconcat
+                    [ " WHERE ", temp ^* tid, " IN ("
+                    , T.intercalate ", " $
+                        replicate (length l) (T.singleton '?')
+                    , ")"
+                    ]
             , case mlen of
-                Nothing -> ""
+                Nothing -> T.empty
                 Just _ -> " AND array_length(" <> temp ^* tpath <> ", 1) <= ?"
             , " )"
             ]
         toP = fmap toPersistValue
-        toPL = fmap $ PersistList . map toPersistValue
-        vals = toPL msource ?: fvals ++ toPL mdest ?: toP mlen ?: []
+        toPL = fmap $ map toPersistValue
+        vals = toPL msource ?++ fvals ++ toPL mdest ?++ toP mlen ?: []
     rawSql sql vals
 
 connects
diff --git a/src/Database/Persist/Sql/Graph/Cyclic.hs b/src/Database/Persist/Sql/Graph/Cyclic.hs
index e8ae3ff..d901202 100644
--- a/src/Database/Persist/Sql/Graph/Cyclic.hs
+++ b/src/Database/Persist/Sql/Graph/Cyclic.hs
@@ -42,7 +42,7 @@ import Database.Persist
 import Database.Persist.Sql
 import Database.Persist.Sql.Util
 
-import qualified Data.Text as T (null, intercalate)
+import qualified Data.Text as T (singleton, null, intercalate)
 
 import Database.Persist.Local.Class.PersistEntityGraph
 import Database.Persist.Local.Class.PersistQueryForest
@@ -129,10 +129,13 @@ xcyclicn' follow filter minitials proxy = do
                     FollowForward  -> sqlStartFrom fwd
                     FollowBackward -> sqlStartFrom bwd
                     FollowBoth     -> " FROM " <> dbname (entityDB tNode)
-                Just initials -> mconcat
+                Just l -> mconcat
                     [ " FROM ", dbname $ entityDB tNode
                     , " WHERE ", entityDB tNode ^* fieldDB (entityId tNode)
-                    , " = ANY(?)"
+                    , " IN ("
+                    , T.intercalate ", " $
+                        replicate (length l) (T.singleton '?')
+                    , ")"
                     ]
             ]
 
@@ -199,8 +202,8 @@ xcyclicn' follow filter minitials proxy = do
                         , ") LIMIT 1"
                         ]
             ]
-        msval = PersistList . map toPersistValue <$> minitials
-        vals = maybe id (:) msval fvals
+        toPL = fmap $ map toPersistValue
+        vals = toPL minitials ?++ fvals
     rawSql sql vals
 
 -- $cyclic
diff --git a/src/Database/Persist/Sql/Graph/Path.hs b/src/Database/Persist/Sql/Graph/Path.hs
index 4d8e8d6..ec226d0 100644
--- a/src/Database/Persist/Sql/Graph/Path.hs
+++ b/src/Database/Persist/Sql/Graph/Path.hs
@@ -50,7 +50,7 @@ import Database.Persist
 import Database.Persist.Sql
 import Database.Persist.Sql.Util
 
-import qualified Data.Text as T (null, intercalate)
+import qualified Data.Text as T (empty, singleton, null, intercalate)
 
 import Database.Persist.Local.Class.PersistEntityGraph
 import Database.Persist.Local.Class.PersistQueryForest
@@ -154,10 +154,13 @@ xmpathm' follow filter msource mdest mlen mlim proxy = do
                 , "FALSE"
                 , case msource of
                     Nothing -> " FROM " <> dbname (entityDB tNode)
-                    Just _ -> mconcat
+                    Just l -> mconcat
                         [ " FROM ", dbname $ entityDB tNode
                         , " WHERE ", entityDB tNode ^* fieldDB (entityId tNode)
-                        , " = ANY(?)"
+                        , " IN ("
+                        , T.intercalate ", " $
+                            replicate (length l) (T.singleton '?')
+                        , ")"
                         ]
             , " UNION ALL "
             , case follow of
@@ -173,8 +176,14 @@ xmpathm' follow filter msource mdest mlen mlim proxy = do
             , " ) SELECT ", temp ^* tpath
             , " FROM ", dbname temp
             , case mdest of
-                Nothing -> ""
-                Just _ -> " WHERE " <> temp ^* tid <> " = ANY(?)"
+                Nothing -> T.empty
+                Just l -> mconcat
+                    [ " WHERE ", temp ^* tid
+                    , " IN ("
+                    , T.intercalate ", " $
+                        replicate (length l) (T.singleton '?')
+                    , ")"
+                    ]
             , case mlen of
                 Nothing -> ""
                 Just _ -> " AND array_length(" <> temp ^* tpath <> ", 1) <= ?"
@@ -184,9 +193,9 @@ xmpathm' follow filter msource mdest mlen mlim proxy = do
                 Just _ -> " LIMIT ?"
             ]
         toP = fmap toPersistValue
-        toPL = fmap $ PersistList . map toPersistValue
+        toPL = fmap $ map toPersistValue
         vals =
-            toPL msource ?: fvals ++ toPL mdest ?: toP mlen ?: toP mlim ?: []
+            toPL msource ?++ fvals ++ toPL mdest ?++ toP mlen ?: toP mlim ?: []
     rawSql sql vals
 
 path
diff --git a/src/Database/Persist/Sql/Graph/Reachable.hs b/src/Database/Persist/Sql/Graph/Reachable.hs
index ffefb12..1f728ed 100644
--- a/src/Database/Persist/Sql/Graph/Reachable.hs
+++ b/src/Database/Persist/Sql/Graph/Reachable.hs
@@ -38,7 +38,7 @@ import Database.Persist
 import Database.Persist.Sql
 import Database.Persist.Sql.Util
 
-import qualified Data.Text as T (null, intercalate)
+import qualified Data.Text as T (singleton, null, intercalate)
 
 import Database.Persist.Local.Class.PersistEntityGraph
 import Database.Persist.Local.Class.PersistQueryForest
@@ -127,7 +127,10 @@ xreachable' follow filter initials mlen proxy = do
                 , "FALSE"
                 , " FROM ", dbname $ entityDB tNode
                 , " WHERE ", entityDB tNode ^* fieldDB (entityId tNode)
-                , " = ANY(?)"
+                , " IN ("
+                , T.intercalate ", " $
+                    replicate (length initials) (T.singleton '?')
+                , ")"
             , " UNION ALL "
             , case follow of
                 FollowForward  -> sqlStep fwd bwd
@@ -147,8 +150,8 @@ xreachable' follow filter initials mlen proxy = do
                 Just _ -> " AND array_length(" <> temp ^* tpath <> ", 1) <= ?"
             ]
         toP = fmap toPersistValue
-        toPL = PersistList . map toPersistValue
-        vals = toPL initials : fvals ++ toP mlen ?: []
+        toPL = map toPersistValue
+        vals = toPL initials ++ fvals ++ toP mlen ?: []
     rawSql sql vals
 
 reachable