55module Share.Web.UCM.SyncV3.Impl (server ) where
66
77import Control.Lens hiding ((.=) )
8+ import Control.Monad.Cont (ContT (.. ), MonadCont (.. ))
89import Control.Monad.Except (runExceptT )
9- import Control.Monad.Trans.Except (ExceptT )
1010import Data.Set qualified as Set
1111import Data.Set.Lens (setOf )
1212import Data.Vector (Vector )
@@ -18,17 +18,21 @@ import Share.IDs (UserId)
1818import Share.Postgres qualified as PG
1919import Share.Prelude
2020import Share.Web.App
21+ import Share.Web.Authentication.Types qualified as AuthN
2122import Share.Web.Authorization qualified as AuthZ
23+ import Share.Web.UCM.Sync.HashJWT qualified as HashJWT
24+ import Share.Web.UCM.SyncCommon.Impl (codebaseForBranchRef )
25+ import Share.Web.UCM.SyncCommon.Types
2226import Share.Web.UCM.SyncV3.API qualified as SyncV3
2327import Share.Web.UCM.SyncV3.Queries qualified as Q
2428import U.Codebase.Sqlite.Orphans ()
29+ import Unison.Debug qualified as Debug
2530import Unison.Hash32 (Hash32 )
26- import Unison.Share.API.Hash (HashJWT )
31+ import Unison.Share.API.Hash (HashJWT , HashJWTClaims ( .. ) )
2732import Unison.SyncV3.Types
2833import Unison.Util.Websockets (Queues (.. ), withQueues )
2934import UnliftIO qualified
3035import UnliftIO.STM
31- import Unison.Debug qualified as Debug
3236
3337-- Amount of entities to buffer from the network into the send/recv queues.
3438sendBufferSize :: Natural
@@ -48,69 +52,90 @@ server mayUserId =
4852 { downloadEntities = downloadEntitiesImpl mayUserId
4953 }
5054
55+ type SyncM = ContT (Either SyncError () ) WebApp
56+
5157downloadEntitiesImpl :: Maybe UserId -> WS. Connection -> WebApp ()
52- downloadEntitiesImpl _mayCallerUserId conn = do
58+ downloadEntitiesImpl mayCallerUserId conn = do
5359 Debug. debugLogM Debug. Temp " Got connection"
5460 -- Auth is currently done via HashJWTs
5561 _authZReceipt <- AuthZ. checkDownloadFromUserCodebase
56- doSyncEmitter shareEmitter conn
62+ doSyncEmitter mayCallerUserId conn
5763
5864-- | Given a helper which understands how to wire things into its backend, This
5965-- implements the sync emitter logic which is independent of the backend.
6066doSyncEmitter ::
61- forall m .
62- (MonadUnliftIO m ) =>
63- ( (SyncState HashTag Hash32 ) ->
64- Queues (MsgOrError SyncError (FromEmitterMessage Hash32 Text )) (MsgOrError SyncError (FromReceiverMessage HashJWT Hash32 )) ->
65- m (Maybe SyncError )
66- ) ->
67+ Maybe UserId ->
6768 WS. Connection ->
68- m ()
69- doSyncEmitter emitterImpl conn = do
69+ WebApp ()
70+ doSyncEmitter mayCallerUserId conn = do
7071 withQueues @ (MsgOrError SyncError (FromEmitterMessage Hash32 Text )) @ (MsgOrError SyncError (FromReceiverMessage HashJWT Hash32 ))
7172 recvBufferSize
7273 sendBufferSize
7374 conn
74- \ (q@ Queues {receive}) -> handleErr q $ do
75- Debug. debugLogM Debug. Temp " Got queues"
76- let recvM :: ExceptT SyncError m (FromReceiverMessage HashJWT Hash32 )
77- recvM = do
78- result <- liftIO $ atomically receive
79- Debug. debugM Debug. Temp " Received: " result
80- case result of
81- Msg msg -> pure msg
82- Err err -> throwError err
83-
84- Debug. debugLogM Debug. Temp " Waiting for init message"
85- initMsg <- recvM
86- Debug. debugM Debug. Temp " Got init: " initMsg
87- syncState <- case initMsg of
88- ReceiverInitStream initMsg -> lift $ initialize initMsg
89- other -> throwError $ InitializationError (" Expected ReceiverInitStream message, got: " <> tShow other)
90- lift (emitterImpl syncState q) >>= \ case
91- Nothing -> pure ()
92- Just err -> throwError err
75+ \ (q@ Queues {receive}) -> do
76+ handleErr q $ do
77+ withErrorCont \ onErr -> do
78+ Debug. debugLogM Debug. Temp " Got queues"
79+ let recvM :: SyncM (FromReceiverMessage HashJWT Hash32 )
80+ recvM = do
81+ result <- liftIO $ atomically receive
82+ Debug. debugM Debug. Temp " Received: " result
83+ case result of
84+ Msg msg -> pure msg
85+ Err err -> onErr err
86+
87+ Debug. debugLogM Debug. Temp " Waiting for init message"
88+ initMsg <- recvM
89+ Debug. debugM Debug. Temp " Got init: " initMsg
90+ syncState <- case initMsg of
91+ ReceiverInitStream initMsg -> initialize onErr mayCallerUserId initMsg
92+ other -> onErr $ InitializationError (" Expected ReceiverInitStream message, got: " <> tShow other)
93+ lift (shareEmitter syncState q)
94+ >>= maybe (pure () ) (onErr)
9395 where
96+ -- Given a continuation-based action, run it in the base monad, capturing any early exits
97+ withErrorCont ::
98+ ((forall x . SyncError -> SyncM x ) -> SyncM () ) ->
99+ WebApp (Either SyncError () )
100+ withErrorCont action = do
101+ flip runContT pure $ callCC \ cc -> do
102+ Right <$> action (fmap absurd . cc . Left )
103+ -- If we get an error, send it to the client then shut down.
104+ handleErr ::
105+ Queues (MsgOrError err a ) o ->
106+ WebApp (Either err () ) ->
107+ WebApp ()
94108 handleErr (Queues {send, shutdown}) action = do
95- runExceptT action >>= \ case
109+ action >>= \ case
96110 Left err -> do
97111 atomically $ do
98112 send (Err err)
99113 liftIO $ shutdown
100114 Right r -> pure r
101115
102- initialize :: InitMsg ah -> m (SyncState sh hash )
103- initialize InitMsg {initMsgClientVersion, initMsgBranchRef, initMsgRootCausal, initMsgRequestedDepth} = do
104- let initialCausalHash = hashjwtHash initMsgRootCausal
116+ initialize :: (forall x . SyncError -> SyncM x ) -> (Maybe UserId ) -> InitMsg HashJWT -> SyncM (SyncState sh Hash32 )
117+ initialize onErr caller InitMsg {initMsgRootCausal, initMsgBranchRef} = do
118+ HashJWTClaims {hash = initialCausalHash} <-
119+ lift (HashJWT. verifyHashJWT caller initMsgRootCausal) >>= \ case
120+ Right ch -> pure ch
121+ Left err -> onErr $ HashJWTVerificationError (AuthN. authErrMsg err)
105122 validRequestsVar <- newTVarIO Set. empty
106- requestedEntitiesVar <- newTVarIO (Set. singleton initialCausalHash)
123+ requestedEntitiesVar <- newTVarIO (Set. singleton ( CausalEntity , initialCausalHash) )
107124 entitiesAlreadySentVar <- newTVarIO Set. empty
108- pure $ SyncState
109- { codebase = PG. codebaseEnv,
110- validRequestsVar,
111- requestedEntitiesVar,
112- entitiesAlreadySentVar
113- }
125+ (lift . runExceptT $ codebaseForBranchRef initMsgBranchRef) >>= \ case
126+ Left err -> case err of
127+ CodebaseLoadingErrorNoReadPermission {} -> onErr $ NoReadPermission initMsgBranchRef
128+ CodebaseLoadingErrorProjectNotFound {} -> onErr $ ProjectNotFound initMsgBranchRef
129+ CodebaseLoadingErrorUserNotFound {} -> onErr $ UserNotFound initMsgBranchRef
130+ CodebaseLoadingErrorInvalidBranchRef msg _ -> onErr $ InvalidBranchRef msg initMsgBranchRef
131+ Right codebase ->
132+ pure $
133+ SyncState
134+ { codebase,
135+ validRequestsVar,
136+ requestedEntitiesVar,
137+ entitiesAlreadySentVar
138+ }
114139
115140data SyncState sh hash = SyncState
116141 { codebase :: CodebaseEnv ,
@@ -129,27 +154,26 @@ shareEmitter ::
129154 (SyncState HashTag Hash32 ) ->
130155 Queues (MsgOrError SyncError (FromEmitterMessage Hash32 Text )) (MsgOrError SyncError (FromReceiverMessage HashJWT Hash32 )) ->
131156 WebApp (Maybe SyncError )
132- shareEmitter SyncState {requestedEntitiesVar, entitiesAlreadySentVar, validRequestsVar, codebase} (Queues {send, receive, shutdown}) = Ki. scoped $ \ scope -> do
133- errVar <- newEmptyTMVarIO
134- let onErr :: SyncError -> STM ()
135- onErr e = do
136- UnliftIO. putTMVar errVar e
137- Ki. fork scope $ sendWorker onErr
138- Ki. fork scope $ receiveWorker onErr
139- r <- atomically $ (Ki. awaitAll scope $> Nothing ) <|> (Just <$> UnliftIO. takeTMVar errVar)
140- liftIO $ shutdown
141- pure r
157+ shareEmitter SyncState {requestedEntitiesVar, entitiesAlreadySentVar, validRequestsVar, codebase} (Queues {send, receive}) = do
158+ Ki. scoped $ \ scope -> do
159+ errVar <- newEmptyTMVarIO
160+ let onErrSTM :: SyncError -> STM ()
161+ onErrSTM e = do
162+ UnliftIO. putTMVar errVar e
163+ Ki. fork scope $ sendWorker onErrSTM
164+ Ki. fork scope $ receiveWorker onErrSTM
165+ atomically ((Ki. awaitAll scope $> Nothing ) <|> (Just <$> UnliftIO. takeTMVar errVar))
142166 where
143167 sendWorker :: (SyncError -> STM () ) -> WebApp ()
144- sendWorker onErr = forever $ do
168+ sendWorker onErrSTM = forever $ do
145169 reqs <- atomically $ do
146170 reqs <- readTVar requestedEntitiesVar
147171 validRequests <- readTVar validRequestsVar
148172 let forbiddenRequests = Set. difference reqs validRequests
149173 validRequests <-
150174 if not (Set. null forbiddenRequests)
151175 then do
152- onErr (ForbiddenEntityRequest forbiddenRequests)
176+ onErrSTM (ForbiddenEntityRequest forbiddenRequests)
153177 pure $ Set. difference validRequests forbiddenRequests
154178 else do
155179 pure validRequests
@@ -179,11 +203,11 @@ shareEmitter SyncState {requestedEntitiesVar, entitiesAlreadySentVar, validReque
179203 send $ Msg (EmitterEntityMsg entity)
180204
181205 receiveWorker :: (SyncError -> STM () ) -> WebApp ()
182- receiveWorker onErr = forever $ do
206+ receiveWorker onErrSTM = forever $ do
183207 atomically $ do
184208 receive >>= \ case
185- Err err -> onErr err
186- Msg (ReceiverInitStream {}) -> onErr (InitializationError " Received duplicate ReceiverInitStream message" )
209+ Err err -> onErrSTM err
210+ Msg (ReceiverInitStream {}) -> onErrSTM (InitializationError " Received duplicate ReceiverInitStream message" )
187211 Msg (ReceiverEntityRequest (EntityRequestMsg {hashes})) -> do
188212 modifyTVar' requestedEntitiesVar (\ s -> Set. union s (Set. fromList hashes))
189213
0 commit comments