Skip to content

Commit c0f1283

Browse files
committed
Implement more of the server-side sync3
1 parent f7b8cee commit c0f1283

File tree

4 files changed

+133
-111
lines changed

4 files changed

+133
-111
lines changed

share-api/src/Share/Web/UCM/SyncV2/Impl.hs

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,29 @@ import Codec.Serialise qualified as CBOR
77
import Conduit qualified as C
88
import Control.Concurrent.STM qualified as STM
99
import Control.Concurrent.STM.TBMQueue qualified as STM
10-
import Control.Monad.Except (ExceptT (ExceptT), withExceptT)
10+
import Control.Monad.Except (withExceptT)
1111
import Control.Monad.Trans.Except (runExceptT)
1212
import Data.Binary.Builder qualified as Builder
1313
import Data.Set qualified as Set
14-
import Data.Text.Encoding qualified as Text
1514
import Data.Vector qualified as Vector
1615
import Ki.Unlifted qualified as Ki
1716
import Servant
1817
import Servant.Conduit (ConduitToSourceIO (..))
1918
import Servant.Types.SourceT (SourceT (..))
2019
import Servant.Types.SourceT qualified as SourceT
21-
import Share.Codebase qualified as Codebase
22-
import Share.IDs (ProjectBranchShortHand (..), ProjectReleaseShortHand (..), ProjectShortHand (..), UserHandle, UserId)
20+
import Share.IDs (UserId)
2321
import Share.IDs qualified as IDs
2422
import Share.Postgres qualified as PG
2523
import Share.Postgres.Causal.Queries qualified as CausalQ
2624
import Share.Postgres.Cursors qualified as Cursor
27-
import Share.Postgres.Queries qualified as PGQ
28-
import Share.Postgres.Users.Queries qualified as UserQ
2925
import Share.Prelude
30-
import Share.Project (Project (..))
31-
import Share.User (User (..))
3226
import Share.Utils.Logging qualified as Logging
3327
import Share.Utils.Unison (hash32ToCausalHash)
3428
import Share.Web.App
35-
import Share.Web.Authorization qualified as AuthZ
3629
import Share.Web.Errors
3730
import Share.Web.UCM.Sync.HashJWT qualified as HashJWT
31+
import Share.Web.UCM.SyncCommon.Impl
32+
import Share.Web.UCM.SyncCommon.Types
3833
import Share.Web.UCM.SyncV2.Queries qualified as SSQ
3934
import Share.Web.UCM.SyncV2.Types (IsCausalSpine (..), IsLibRoot (..))
4035
import U.Codebase.Sqlite.Orphans ()
@@ -58,17 +53,6 @@ server mayUserId =
5853
causalDependenciesStream = causalDependenciesStreamImpl mayUserId
5954
}
6055

61-
parseBranchRef :: SyncV2.BranchRef -> Either Text (Either ProjectReleaseShortHand ProjectBranchShortHand)
62-
parseBranchRef (SyncV2.BranchRef branchRef) =
63-
case parseRelease <|> parseBranch of
64-
Just a -> Right a
65-
Nothing -> Left $ "Invalid repo info: " <> branchRef
66-
where
67-
parseBranch :: Maybe (Either ProjectReleaseShortHand ProjectBranchShortHand)
68-
parseBranch = fmap Right . eitherToMaybe $ IDs.fromText @ProjectBranchShortHand branchRef
69-
parseRelease :: Maybe (Either ProjectReleaseShortHand ProjectBranchShortHand)
70-
parseRelease = fmap Left . eitherToMaybe $ IDs.fromText @ProjectReleaseShortHand branchRef
71-
7256
downloadEntitiesStreamImpl :: Maybe UserId -> SyncV2.DownloadEntitiesRequest -> WebApp (SourceIO (SyncV2.CBORStream SyncV2.DownloadEntitiesChunk))
7357
downloadEntitiesStreamImpl mayCallerUserId (SyncV2.DownloadEntitiesRequest {causalHash = causalHashJWT, branchRef, knownHashes}) = do
7458
either emitErr id <$> runExceptT do

src/Share/Web/UCM/SyncCommon/Impl.hs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,36 @@
1-
module Share.Web.UCM.SyncCommon.Impl (codebaseForBranchRef) where
1+
module Share.Web.UCM.SyncCommon.Impl
2+
( parseBranchRef,
3+
codebaseForBranchRef,
4+
)
5+
where
6+
7+
import Control.Monad.Except (ExceptT (ExceptT))
8+
import Servant
9+
import Share.Codebase qualified as Codebase
10+
import Share.IDs (ProjectBranchShortHand (..), ProjectReleaseShortHand (..), ProjectShortHand (..))
11+
import Share.IDs qualified as IDs
12+
import Share.Postgres qualified as PG
13+
import Share.Postgres.Queries qualified as PGQ
14+
import Share.Postgres.Users.Queries qualified as UserQ
15+
import Share.Prelude
16+
import Share.Project (Project (..))
17+
import Share.User (User (..))
18+
import Share.Web.App
19+
import Share.Web.Authorization qualified as AuthZ
20+
import Share.Web.UCM.SyncCommon.Types
21+
import U.Codebase.Sqlite.Orphans ()
22+
import Unison.SyncV2.Types qualified as SyncV2
23+
24+
parseBranchRef :: SyncV2.BranchRef -> Either Text (Either ProjectReleaseShortHand ProjectBranchShortHand)
25+
parseBranchRef (SyncV2.BranchRef branchRef) =
26+
case parseRelease <|> parseBranch of
27+
Just a -> Right a
28+
Nothing -> Left $ "Invalid repo info: " <> branchRef
29+
where
30+
parseBranch :: Maybe (Either ProjectReleaseShortHand ProjectBranchShortHand)
31+
parseBranch = fmap Right . eitherToMaybe $ IDs.fromText @ProjectBranchShortHand branchRef
32+
parseRelease :: Maybe (Either ProjectReleaseShortHand ProjectBranchShortHand)
33+
parseRelease = fmap Left . eitherToMaybe $ IDs.fromText @ProjectReleaseShortHand branchRef
234

335
codebaseForBranchRef :: SyncV2.BranchRef -> (ExceptT CodebaseLoadingError WebApp Codebase.CodebaseEnv)
436
codebaseForBranchRef branchRef = do
Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,27 @@
1-
module Share.Web.UCM.SyncCommon.Types () where
1+
{-# LANGUAGE DataKinds #-}
22

3-
newtype BranchRef = BranchRef {unBranchRef :: Text}
4-
deriving (Serialise, Eq, Show, Ord, ToJSON, FromJSON) via Text
5-
6-
instance From (ProjectAndBranch ProjectName ProjectBranchName) BranchRef where
7-
from pab = BranchRef $ from pab
3+
module Share.Web.UCM.SyncCommon.Types (CodebaseLoadingError (..)) where
84

5+
import Data.Text.Encoding qualified as Text
6+
import Servant
7+
import Share.IDs
8+
import Share.IDs qualified as IDs
9+
import Share.Prelude
10+
import Share.Utils.Logging qualified as Logging
11+
import Share.Web.Errors
12+
import Unison.SyncCommon.Types
913

1014
data CodebaseLoadingError
1115
= CodebaseLoadingErrorProjectNotFound ProjectShortHand
1216
| CodebaseLoadingErrorUserNotFound UserHandle
13-
| CodebaseLoadingErrorNoReadPermission SyncV2.BranchRef
14-
| CodebaseLoadingErrorInvalidBranchRef Text SyncV2.BranchRef
17+
| CodebaseLoadingErrorNoReadPermission BranchRef
18+
| CodebaseLoadingErrorInvalidBranchRef Text BranchRef
1519
deriving stock (Show)
1620
deriving (Logging.Loggable) via Logging.ShowLoggable Logging.UserFault CodebaseLoadingError
1721

1822
instance ToServerError CodebaseLoadingError where
1923
toServerError = \case
2024
CodebaseLoadingErrorProjectNotFound projectShortHand -> (ErrorID "codebase-loading:project-not-found", Servant.err404 {errBody = from . Text.encodeUtf8 $ "Project not found: " <> (IDs.toText projectShortHand)})
2125
CodebaseLoadingErrorUserNotFound userHandle -> (ErrorID "codebase-loading:user-not-found", Servant.err404 {errBody = from . Text.encodeUtf8 $ "User not found: " <> (IDs.toText userHandle)})
22-
CodebaseLoadingErrorNoReadPermission branchRef -> (ErrorID "codebase-loading:no-read-permission", Servant.err403 {errBody = from . Text.encodeUtf8 $ "No read permission for branch ref: " <> (SyncV2.unBranchRef branchRef)})
23-
CodebaseLoadingErrorInvalidBranchRef err branchRef -> (ErrorID "codebase-loading:invalid-branch-ref", Servant.err400 {errBody = from . Text.encodeUtf8 $ "Invalid branch ref: " <> err <> " " <> (SyncV2.unBranchRef branchRef)})
24-
25-
codebaseForBranchRef :: SyncV2.BranchRef -> (ExceptT CodebaseLoadingError WebApp Codebase.CodebaseEnv)
26-
codebaseForBranchRef branchRef = do
27-
case parseBranchRef branchRef of
28-
Left err -> throwError (CodebaseLoadingErrorInvalidBranchRef err branchRef)
29-
Right (Left (ProjectReleaseShortHand {userHandle, projectSlug})) -> do
30-
let projectShortHand = ProjectShortHand {userHandle, projectSlug}
31-
(Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do
32-
project <- PGQ.projectByShortHand projectShortHand `whenNothingM` throwError (CodebaseLoadingErrorProjectNotFound $ projectShortHand)
33-
pure (project, Nothing)
34-
authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (CodebaseLoadingErrorNoReadPermission branchRef)
35-
let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId
36-
pure $ Codebase.codebaseEnv authZToken codebaseLoc
37-
Right (Right (ProjectBranchShortHand {userHandle, projectSlug, contributorHandle})) -> do
38-
let projectShortHand = ProjectShortHand {userHandle, projectSlug}
39-
(Project {ownerUserId = projectOwnerUserId}, contributorId) <- ExceptT . PG.tryRunTransaction $ do
40-
project <- (PGQ.projectByShortHand projectShortHand) `whenNothingM` throwError (CodebaseLoadingErrorProjectNotFound projectShortHand)
41-
mayContributorUserId <- for contributorHandle \ch -> fmap user_id $ (UserQ.userByHandle ch) `whenNothingM` throwError (CodebaseLoadingErrorUserNotFound ch)
42-
pure (project, mayContributorUserId)
43-
authZToken <- lift AuthZ.checkDownloadFromProjectBranchCodebase `whenLeftM` \_err -> throwError (CodebaseLoadingErrorNoReadPermission branchRef)
44-
let codebaseLoc = Codebase.codebaseLocationForProjectBranchCodebase projectOwnerUserId contributorId
45-
pure $ Codebase.codebaseEnv authZToken codebaseLoc
26+
CodebaseLoadingErrorNoReadPermission branchRef -> (ErrorID "codebase-loading:no-read-permission", Servant.err403 {errBody = from . Text.encodeUtf8 $ "No read permission for branch ref: " <> (unBranchRef branchRef)})
27+
CodebaseLoadingErrorInvalidBranchRef err branchRef -> (ErrorID "codebase-loading:invalid-branch-ref", Servant.err400 {errBody = from . Text.encodeUtf8 $ "Invalid branch ref: " <> err <> " " <> (unBranchRef branchRef)})

src/Share/Web/UCM/SyncV3/Impl.hs

Lines changed: 82 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
module Share.Web.UCM.SyncV3.Impl (server) where
66

77
import Control.Lens hiding ((.=))
8+
import Control.Monad.Cont (ContT (..), MonadCont (..))
89
import Control.Monad.Except (runExceptT)
9-
import Control.Monad.Trans.Except (ExceptT)
1010
import Data.Set qualified as Set
1111
import Data.Set.Lens (setOf)
1212
import Data.Vector (Vector)
@@ -18,17 +18,21 @@ import Share.IDs (UserId)
1818
import Share.Postgres qualified as PG
1919
import Share.Prelude
2020
import Share.Web.App
21+
import Share.Web.Authentication.Types qualified as AuthN
2122
import Share.Web.Authorization qualified as AuthZ
23+
import Share.Web.UCM.Sync.HashJWT qualified as HashJWT
24+
import Share.Web.UCM.SyncCommon.Impl (codebaseForBranchRef)
25+
import Share.Web.UCM.SyncCommon.Types
2226
import Share.Web.UCM.SyncV3.API qualified as SyncV3
2327
import Share.Web.UCM.SyncV3.Queries qualified as Q
2428
import U.Codebase.Sqlite.Orphans ()
29+
import Unison.Debug qualified as Debug
2530
import Unison.Hash32 (Hash32)
26-
import Unison.Share.API.Hash (HashJWT)
31+
import Unison.Share.API.Hash (HashJWT, HashJWTClaims (..))
2732
import Unison.SyncV3.Types
2833
import Unison.Util.Websockets (Queues (..), withQueues)
2934
import UnliftIO qualified
3035
import UnliftIO.STM
31-
import Unison.Debug qualified as Debug
3236

3337
-- Amount of entities to buffer from the network into the send/recv queues.
3438
sendBufferSize :: Natural
@@ -48,69 +52,90 @@ server mayUserId =
4852
{ downloadEntities = downloadEntitiesImpl mayUserId
4953
}
5054

55+
type SyncM = ContT (Either SyncError ()) WebApp
56+
5157
downloadEntitiesImpl :: Maybe UserId -> WS.Connection -> WebApp ()
52-
downloadEntitiesImpl _mayCallerUserId conn = do
58+
downloadEntitiesImpl mayCallerUserId conn = do
5359
Debug.debugLogM Debug.Temp "Got connection"
5460
-- Auth is currently done via HashJWTs
5561
_authZReceipt <- AuthZ.checkDownloadFromUserCodebase
56-
doSyncEmitter shareEmitter conn
62+
doSyncEmitter mayCallerUserId conn
5763

5864
-- | Given a helper which understands how to wire things into its backend, This
5965
-- implements the sync emitter logic which is independent of the backend.
6066
doSyncEmitter ::
61-
forall m.
62-
(MonadUnliftIO m) =>
63-
( (SyncState HashTag Hash32) ->
64-
Queues (MsgOrError SyncError (FromEmitterMessage Hash32 Text)) (MsgOrError SyncError (FromReceiverMessage HashJWT Hash32)) ->
65-
m (Maybe SyncError)
66-
) ->
67+
Maybe UserId ->
6768
WS.Connection ->
68-
m ()
69-
doSyncEmitter emitterImpl conn = do
69+
WebApp ()
70+
doSyncEmitter mayCallerUserId conn = do
7071
withQueues @(MsgOrError SyncError (FromEmitterMessage Hash32 Text)) @(MsgOrError SyncError (FromReceiverMessage HashJWT Hash32))
7172
recvBufferSize
7273
sendBufferSize
7374
conn
74-
\(q@Queues {receive}) -> handleErr q $ do
75-
Debug.debugLogM Debug.Temp "Got queues"
76-
let recvM :: ExceptT SyncError m (FromReceiverMessage HashJWT Hash32)
77-
recvM = do
78-
result <- liftIO $ atomically receive
79-
Debug.debugM Debug.Temp "Received: " result
80-
case result of
81-
Msg msg -> pure msg
82-
Err err -> throwError err
83-
84-
Debug.debugLogM Debug.Temp "Waiting for init message"
85-
initMsg <- recvM
86-
Debug.debugM Debug.Temp "Got init: " initMsg
87-
syncState <- case initMsg of
88-
ReceiverInitStream initMsg -> lift $ initialize initMsg
89-
other -> throwError $ InitializationError ("Expected ReceiverInitStream message, got: " <> tShow other)
90-
lift (emitterImpl syncState q) >>= \case
91-
Nothing -> pure ()
92-
Just err -> throwError err
75+
\(q@Queues {receive}) -> do
76+
handleErr q $ do
77+
withErrorCont \onErr -> do
78+
Debug.debugLogM Debug.Temp "Got queues"
79+
let recvM :: SyncM (FromReceiverMessage HashJWT Hash32)
80+
recvM = do
81+
result <- liftIO $ atomically receive
82+
Debug.debugM Debug.Temp "Received: " result
83+
case result of
84+
Msg msg -> pure msg
85+
Err err -> onErr err
86+
87+
Debug.debugLogM Debug.Temp "Waiting for init message"
88+
initMsg <- recvM
89+
Debug.debugM Debug.Temp "Got init: " initMsg
90+
syncState <- case initMsg of
91+
ReceiverInitStream initMsg -> initialize onErr mayCallerUserId initMsg
92+
other -> onErr $ InitializationError ("Expected ReceiverInitStream message, got: " <> tShow other)
93+
lift (shareEmitter syncState q)
94+
>>= maybe (pure ()) (onErr)
9395
where
96+
-- Given a continuation-based action, run it in the base monad, capturing any early exits
97+
withErrorCont ::
98+
((forall x. SyncError -> SyncM x) -> SyncM ()) ->
99+
WebApp (Either SyncError ())
100+
withErrorCont action = do
101+
flip runContT pure $ callCC \cc -> do
102+
Right <$> action (fmap absurd . cc . Left)
103+
-- If we get an error, send it to the client then shut down.
104+
handleErr ::
105+
Queues (MsgOrError err a) o ->
106+
WebApp (Either err ()) ->
107+
WebApp ()
94108
handleErr (Queues {send, shutdown}) action = do
95-
runExceptT action >>= \case
109+
action >>= \case
96110
Left err -> do
97111
atomically $ do
98112
send (Err err)
99113
liftIO $ shutdown
100114
Right r -> pure r
101115

102-
initialize :: InitMsg ah -> m (SyncState sh hash)
103-
initialize InitMsg{initMsgClientVersion, initMsgBranchRef, initMsgRootCausal, initMsgRequestedDepth} = do
104-
let initialCausalHash = hashjwtHash initMsgRootCausal
116+
initialize :: (forall x. SyncError -> SyncM x) -> (Maybe UserId) -> InitMsg HashJWT -> SyncM (SyncState sh Hash32)
117+
initialize onErr caller InitMsg {initMsgRootCausal, initMsgBranchRef} = do
118+
HashJWTClaims {hash = initialCausalHash} <-
119+
lift (HashJWT.verifyHashJWT caller initMsgRootCausal) >>= \case
120+
Right ch -> pure ch
121+
Left err -> onErr $ HashJWTVerificationError (AuthN.authErrMsg err)
105122
validRequestsVar <- newTVarIO Set.empty
106-
requestedEntitiesVar <- newTVarIO (Set.singleton initialCausalHash)
123+
requestedEntitiesVar <- newTVarIO (Set.singleton (CausalEntity, initialCausalHash))
107124
entitiesAlreadySentVar <- newTVarIO Set.empty
108-
pure $ SyncState
109-
{ codebase = PG.codebaseEnv,
110-
validRequestsVar,
111-
requestedEntitiesVar,
112-
entitiesAlreadySentVar
113-
}
125+
(lift . runExceptT $ codebaseForBranchRef initMsgBranchRef) >>= \case
126+
Left err -> case err of
127+
CodebaseLoadingErrorNoReadPermission {} -> onErr $ NoReadPermission initMsgBranchRef
128+
CodebaseLoadingErrorProjectNotFound {} -> onErr $ ProjectNotFound initMsgBranchRef
129+
CodebaseLoadingErrorUserNotFound {} -> onErr $ UserNotFound initMsgBranchRef
130+
CodebaseLoadingErrorInvalidBranchRef msg _ -> onErr $ InvalidBranchRef msg initMsgBranchRef
131+
Right codebase ->
132+
pure $
133+
SyncState
134+
{ codebase,
135+
validRequestsVar,
136+
requestedEntitiesVar,
137+
entitiesAlreadySentVar
138+
}
114139

115140
data SyncState sh hash = SyncState
116141
{ codebase :: CodebaseEnv,
@@ -129,27 +154,26 @@ shareEmitter ::
129154
(SyncState HashTag Hash32) ->
130155
Queues (MsgOrError SyncError (FromEmitterMessage Hash32 Text)) (MsgOrError SyncError (FromReceiverMessage HashJWT Hash32)) ->
131156
WebApp (Maybe SyncError)
132-
shareEmitter SyncState {requestedEntitiesVar, entitiesAlreadySentVar, validRequestsVar, codebase} (Queues {send, receive, shutdown}) = Ki.scoped $ \scope -> do
133-
errVar <- newEmptyTMVarIO
134-
let onErr :: SyncError -> STM ()
135-
onErr e = do
136-
UnliftIO.putTMVar errVar e
137-
Ki.fork scope $ sendWorker onErr
138-
Ki.fork scope $ receiveWorker onErr
139-
r <- atomically $ (Ki.awaitAll scope $> Nothing) <|> (Just <$> UnliftIO.takeTMVar errVar)
140-
liftIO $ shutdown
141-
pure r
157+
shareEmitter SyncState {requestedEntitiesVar, entitiesAlreadySentVar, validRequestsVar, codebase} (Queues {send, receive}) = do
158+
Ki.scoped $ \scope -> do
159+
errVar <- newEmptyTMVarIO
160+
let onErrSTM :: SyncError -> STM ()
161+
onErrSTM e = do
162+
UnliftIO.putTMVar errVar e
163+
Ki.fork scope $ sendWorker onErrSTM
164+
Ki.fork scope $ receiveWorker onErrSTM
165+
atomically ((Ki.awaitAll scope $> Nothing) <|> (Just <$> UnliftIO.takeTMVar errVar))
142166
where
143167
sendWorker :: (SyncError -> STM ()) -> WebApp ()
144-
sendWorker onErr = forever $ do
168+
sendWorker onErrSTM = forever $ do
145169
reqs <- atomically $ do
146170
reqs <- readTVar requestedEntitiesVar
147171
validRequests <- readTVar validRequestsVar
148172
let forbiddenRequests = Set.difference reqs validRequests
149173
validRequests <-
150174
if not (Set.null forbiddenRequests)
151175
then do
152-
onErr (ForbiddenEntityRequest forbiddenRequests)
176+
onErrSTM (ForbiddenEntityRequest forbiddenRequests)
153177
pure $ Set.difference validRequests forbiddenRequests
154178
else do
155179
pure validRequests
@@ -179,11 +203,11 @@ shareEmitter SyncState {requestedEntitiesVar, entitiesAlreadySentVar, validReque
179203
send $ Msg (EmitterEntityMsg entity)
180204

181205
receiveWorker :: (SyncError -> STM ()) -> WebApp ()
182-
receiveWorker onErr = forever $ do
206+
receiveWorker onErrSTM = forever $ do
183207
atomically $ do
184208
receive >>= \case
185-
Err err -> onErr err
186-
Msg (ReceiverInitStream {}) -> onErr (InitializationError "Received duplicate ReceiverInitStream message")
209+
Err err -> onErrSTM err
210+
Msg (ReceiverInitStream {}) -> onErrSTM (InitializationError "Received duplicate ReceiverInitStream message")
187211
Msg (ReceiverEntityRequest (EntityRequestMsg {hashes})) -> do
188212
modifyTVar' requestedEntitiesVar (\s -> Set.union s (Set.fromList hashes))
189213

0 commit comments

Comments
 (0)