diff --git a/src/Vervis/Application.hs b/src/Vervis/Application.hs index 4db69a4..458b8ce 100644 --- a/src/Vervis/Application.hs +++ b/src/Vervis/Application.hs @@ -54,6 +54,7 @@ import Yesod.Default.Main (LogFunc) import Yesod.Mail.Send (runMailer) import qualified Data.Text as T (unpack) +import qualified Data.HashMap.Strict as M (empty) import Control.Concurrent.Local (forkCheck) @@ -121,6 +122,8 @@ makeFoundation appSettings = do newTVarIO =<< (,,) <$> generateActorKey <*> generateActorKey <*> pure True + appInstanceMutex <- newTVarIO M.empty + appActivities <- newTVarIO mempty -- We need a log function to create a connection pool. We need a connection diff --git a/src/Vervis/Foundation.hs b/src/Vervis/Foundation.hs index 558ce07..04f58d2 100644 --- a/src/Vervis/Foundation.hs +++ b/src/Vervis/Foundation.hs @@ -17,12 +17,16 @@ module Vervis.Foundation where import Prelude (init, last) +import Control.Concurrent.MVar (MVar, newEmptyMVar) +import Control.Concurrent.STM.TVar import Control.Monad.Logger.CallStack (logWarn) +import Control.Monad.STM (atomically) import Control.Monad.Trans.Except import Control.Monad.Trans.Maybe import Crypto.Error (CryptoFailable (..)) import Crypto.PubKey.Ed25519 (PublicKey, publicKey, signature, verify) import Data.Either (isRight) +import Data.HashMap.Strict (HashMap) import Data.Maybe (fromJust) import Data.PEM (pemContent) import Data.Text.Encoding (decodeUtf8') @@ -37,6 +41,7 @@ import Network.URI (URI, uriAuthority, uriFragment, uriRegName, parseURI) import Text.Shakespeare.Text (textFile) import Text.Hamlet (hamletFile) --import Text.Jasmine (minifym) +import UnliftIO.MVar (withMVar) import Yesod.Auth.Account import Yesod.Auth.Account.Message (AccountMsg (MsgUsernameExists)) import Yesod.Auth.Message (AuthMessage (IdentifierNotFound)) @@ -45,6 +50,7 @@ import Yesod.Default.Util (addStaticContentExternal) import qualified Data.ByteString.Char8 as BC (unpack) import qualified Data.ByteString.Lazy as BL (ByteString) +import qualified Data.HashMap.Strict as M (lookup, insert) import qualified Yesod.Core.Unsafe as Unsafe --import qualified Data.CaseInsensitive as CI import Data.Text as T (pack, intercalate, concat) @@ -85,6 +91,7 @@ data App = App , appMailQueue :: Maybe (Chan (MailRecipe App)) , appSvgFont :: PreparedFont Double , appActorKeys :: TVar (ActorKey, ActorKey, Bool) + , appInstanceMutex :: TVar (HashMap Text (MVar ())) , appCapSignKey :: ActorKey , appHashidEncode :: Int64 -> Text , appHashidDecode :: Text -> Maybe Int64 @@ -564,6 +571,29 @@ unsafeHandler = Unsafe.fakeHandlerGetLogger appLogger -- https://github.com/yesodweb/yesod/wiki/Serve-static-files-from-a-separate-domain -- https://github.com/yesodweb/yesod/wiki/i18n-messages-in-the-scaffolding +-- TODO this is copied from stm-2.5, remove when we upgrade LTS +stateTVar :: TVar s -> (s -> (a, s)) -> STM a +stateTVar var f = do + s <- readTVar var + let (a, s') = f s -- since we destructure this, we are strict in f + writeTVar var s' + return a + +withHostLock :: Text -> Handler a -> Handler a +withHostLock host action = do + tvar <- getsYesod appInstanceMutex + mvar <- liftIO $ do + existing <- M.lookup host <$> readTVarIO tvar + case existing of + Just v -> return v + Nothing -> do + v <- newEmptyMVar + atomically $ stateTVar tvar $ \ m -> + case M.lookup host m of + Just v' -> (v', m) + Nothing -> (v, M.insert host v m) + withMVar mvar $ const action + sumUpTo :: Int -> AppDB Int -> AppDB Int -> AppDB Bool sumUpTo limit action1 action2 = do n <- action1 @@ -751,21 +781,23 @@ keyListedByActorShared manager iid vkid host luKey luActor = do else Just $ Just rsid for_ mresult $ \ mrsid -> do luInbox <- actorInbox <$> ExceptT (keyListedByActor manager host luKey luActor) - ExceptT $ runDB $ case mrsid of - Nothing -> do - rsid <- insert $ RemoteSharer luActor iid luInbox - insert_ $ VerifKeySharedUsage vkid rsid - return $ Right () - Just rsid -> runExceptT $ do - case m of - RoomModeNoLimit -> return () - RoomModeLimit limit -> do - if reject - then do - room <- lift $ actorRoom limit rsid - unless room $ throwE "Actor key storage limit reached" - else lift $ makeActorRoomForUsage limit rsid - lift $ insert_ $ VerifKeySharedUsage vkid rsid + ExceptT $ runDB $ do + vkExists <- isJust <$> get vkid + case mrsid of + Nothing -> do + rsid <- insert $ RemoteSharer luActor iid luInbox + when vkExists $ insert_ $ VerifKeySharedUsage vkid rsid + return $ Right () + Just rsid -> runExceptT $ when vkExists $ do + case m of + RoomModeNoLimit -> return () + RoomModeLimit limit -> do + if reject + then do + room <- lift $ actorRoom limit rsid + unless room $ throwE "Actor key storage limit reached" + else lift $ makeActorRoomForUsage limit rsid + lift $ insert_ $ VerifKeySharedUsage vkid rsid data VerifKeyDetail = VerifKeyDetail { vkdKeyId :: LocalURI @@ -815,7 +847,7 @@ instance YesodHttpSig App where Just u -> return u manager <- getsYesod appHttpManager let iid = verifKeyInstance vk - keyListedByActorShared manager iid vkid host luKey ua + withHostLock' host $ keyListedByActorShared manager iid vkid host luKey ua return (ua, True) return ( Right (verifKeyInstance vk, vkid) @@ -837,7 +869,7 @@ instance YesodHttpSig App where if verify' (vkdKey vkd) && stillValid (vkdExpires vkd) then case inboxOrVkid of - Left uinb -> ExceptT $ runDB $ addVerifKey host uinb vkd + Left uinb -> ExceptT $ withHostLock host $ runDB $ addVerifKey host uinb vkd Right _ids -> return () else case inboxOrVkid of Left _uinb -> @@ -846,7 +878,7 @@ instance YesodHttpSig App where else errTime Right (iid, vkid) -> do let ua = vkdActorId vkd - listed = keyListedByActorShared manager iid vkid host luKey ua + listed = withHostLock' host $ keyListedByActorShared manager iid vkid host luKey ua (newKey, newExp) <- if vkdShared vkd then fetchKnownSharedKey manager listed sigAlgo host ua luKey @@ -956,6 +988,7 @@ instance YesodHttpSig App where lift $ insert_ $ VerifKey luKey iid mexpires key (Just rsid) updateVerifKey vkid vkd = update vkid [VerifKeyExpires =. vkdExpires vkd, VerifKeyPublic =. vkdKey vkd] + withHostLock' h = ExceptT . withHostLock h . runExceptT instance YesodBreadcrumbs App where breadcrumb route = return $ case route of diff --git a/src/Vervis/Import/NoFoundation.hs b/src/Vervis/Import/NoFoundation.hs index f8b13f8..6810810 100644 --- a/src/Vervis/Import/NoFoundation.hs +++ b/src/Vervis/Import/NoFoundation.hs @@ -15,7 +15,7 @@ module Vervis.Import.NoFoundation ( module Import ) where -import ClassyPrelude.Conduit as Import hiding (delete, deleteBy) +import ClassyPrelude.Conduit as Import hiding (delete, deleteBy, readTVarIO, newEmptyMVar, atomically) import Data.Default as Import (Default (..)) import Database.Persist.Sql as Import ( SqlBackend , SqlPersistT