Skip to content

Commit eeb2c0f

Browse files
committed
Merge branch 'add-replicate-own-deployment' into inline_probabilities
2 parents 2d32d48 + ddb3848 commit eeb2c0f

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

docs/docs/models/replicate.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ To run a [🤗 Transformers](./hf.html) model on Replicate, you need to:
1010

1111
1. Export the environment variable `REPLICATE_API_TOKEN` with the credential to use to authenticate the request.
1212

13-
2. Set the `transport=` argument to your model to `replicate:ORG/MODEL`, matching the name with which the model was uploaded.
13+
2. Set the `endpoint=` argument to your model to `replicate:ORG/MODEL`, matching the name with which the model was uploaded. If you want to use models from your organization's deployments, set the `endpoint=` argument to your deployment to `replicate:deployment/ORG/MODEL`.
1414

1515
3. Set the `tokenizer=` argument to your model to a huggingface transformers name from which correct configuration for the tokenizer in use can be downloaded.
1616

src/lmql/models/lmtp/lmtp_replicate_client.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def __init__(self, model_identifier, session, endpoint, **kwargs):
2121
else: # FIXME: Allow API key to be passed in kwargs?
2222
raise Exception('Please define REPLICATE_API_TOKEN as an environment variable to use Replicate models')
2323

24+
self.model_validated = False
25+
self.use_deployment_endpoint = False
26+
2427
endpoint = endpoint.removeprefix('replicate:')
2528
if len(endpoint) == 0:
2629
endpoint = model_identifier
@@ -31,13 +34,19 @@ def __init__(self, model_identifier, session, endpoint, **kwargs):
3134
self.model_identifier = endpoint
3235
self.model_version = None
3336
elif len(endpoint_pieces) == 3:
34-
# passed a name/version pair
35-
self.model_identifier = '/'.join(endpoint_pieces[:2])
36-
self.model_version = endpoint_pieces[-1]
37+
# check if it is a deployment endpoint
38+
if endpoint_pieces[0] == 'deployment':
39+
self.model_identifier = '/'.join(endpoint_pieces[1:3])
40+
self.model_version = None
41+
self.use_deployment_endpoint = True
42+
self.model_validated = True
43+
else:
44+
# passed a name/version pair
45+
self.model_identifier = '/'.join(endpoint_pieces[:2])
46+
self.model_version = endpoint_pieces[-1]
3747
else:
38-
raise Exception('Unknown endpoint descriptor for replicate; should be owner/model or owner/model/version')
48+
raise Exception('Unknown endpoint descriptor for replicate; should be owner/model, owner/model/version' or 'deployment/owner/model')
3949

40-
self.model_validated = False
4150
self.session = session
4251
self.stream_id = 0
4352
self.handler = None
@@ -81,8 +90,13 @@ async def submit_batch(self, batch):
8190
if self.model_version is None or not self.model_validated:
8291
await self.check_model()
8392
# FIXME: Maybe store id to use for later cancel calls?
84-
body = {"input": {"ops_batch_json": json.dumps(batch)}, "stream": True, "version": self.model_version}
85-
async with self.session.post('https://api.replicate.com/v1/predictions',
93+
if self.use_deployment_endpoint:
94+
body = {"input": {"ops_batch_json": json.dumps(batch)}, "stream": True}
95+
endpoint = f'https://api.replicate.com/v1/deployments/{self.model_identifier}/predictions'
96+
else:
97+
body = {"input": {"ops_batch_json": json.dumps(batch)}, "stream": True, "version": self.model_version}
98+
endpoint = 'https://api.replicate.com/v1/predictions'
99+
async with self.session.post(endpoint,
86100
headers={
87101
'Authorization': f'Token {self.api_key}',
88102
'Content-Type': 'application/json'

0 commit comments

Comments
 (0)