forked from K-tang-mkv/ai_factory
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathupload_model.py
More file actions
133 lines (110 loc) · 4.04 KB
/
upload_model.py
File metadata and controls
133 lines (110 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""A script that pushes a model from disk to the subnet for evaluation.

Usage:
    python scripts/upload_model.py --load_model_dir <path to model> --hf_repo_id my-username/my-project --competition_id competitionID --wallet.name coldkey --wallet.hotkey hotkey

    Add --update_repo_visibility to make the Hugging Face repo public after uploading,
    or run with --list_competitions to print the competition schedule instead of uploading.

Prerequisites:
    1. HF_ACCESS_TOKEN is set in the environment or in a .env file.
    2. load_model_dir points to a directory containing a previously trained model, with the relevant Hugging Face files (e.g. config.json).
    3. Your miner's hotkey is registered on the subnet.
"""
import asyncio
import os
import argparse
import constants
from taoverse.metagraph import utils as metagraph_utils
from taoverse.model.storage.chain.chain_model_metadata_store import (
ChainModelMetadataStore,
)
from taoverse.model.storage.hugging_face.hugging_face_model_store import (
HuggingFaceModelStore,
)
from taoverse.utilities import utils as taoverse_utils
from taoverse.utilities.enum_action import IntEnumAction
import factory as fact
import bittensor as bt
from competitions.data import CompetitionId
from dotenv import load_dotenv
# Populate os.environ from a .env file so settings such as HF_ACCESS_TOKEN can
# be kept out of the shell environment.
load_dotenv()
# Explicitly enable Hugging Face tokenizers parallelism for this process.
os.environ.update(TOKENIZERS_PARALLELISM="true")
def get_config():
    """Define the command-line interface and return the parsed bittensor config.

    Script-specific flags are declared here. The wallet, subtensor and logging
    argument groups are added by bittensor itself (e.g. ``--wallet.name`` and
    ``--wallet.hotkey``).

    ``--hf_repo_id``, ``--load_model_dir`` and ``--competition_id`` are not
    marked ``required``, so each one parses as ``None`` when omitted. This is
    what allows ``--list_competitions`` to be used on its own.

    Returns:
        bt.config: the parsed configuration namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hf_repo_id",
        type=str,
        help="The hugging face repo id, which should include the org or user and repo name. E.g. jdoe/pretraining",
    )
    parser.add_argument(
        "--load_model_dir",
        type=str,
        default=None,
        help="If provided, loads a previously trained HF model from the specified directory",
    )
    parser.add_argument(
        "--netuid",
        type=int,
        default=constants.SUBNET_UID,
        help="The subnet UID.",
    )
    parser.add_argument(
        "--competition_id",
        type=CompetitionId,
        action=IntEnumAction,
        # The flag named in this help text must match the real option below
        # (underscore, not hyphen), otherwise argparse rejects it.
        help="competition to mine for (use --list_competitions to get all competitions)",
    )
    parser.add_argument(
        "--list_competitions", action="store_true", help="Print out all competitions"
    )
    parser.add_argument(
        "--update_repo_visibility",
        action="store_true",
        help="If true, the repo will be made public after uploading.",
    )

    # Pull in bittensor's own wallet, subtensor and logging options.
    bt.wallet.add_args(parser)
    bt.subtensor.add_args(parser)
    bt.logging.add_args(parser)

    # bt.config parses the command line against the combined parser.
    config = bt.config(parser)
    return config
async def main(config: bt.config) -> None:
    """Load a locally trained model and publish it to Hugging Face and the chain.

    Args:
        config: Parsed configuration from ``get_config``. Reads ``netuid``,
            ``competition_id``, ``load_model_dir``, ``hf_repo_id`` and
            ``update_repo_visibility``, plus the bittensor wallet, subtensor
            and logging sections.

    Raises:
        ValueError: if ``--hf_repo_id``, ``--competition_id`` or
            ``--load_model_dir`` was not supplied, or if ``--load_model_dir``
            is not an existing directory.
        RuntimeError: if the competition id has no entry in
            ``constants.MODEL_CONSTRAINTS_BY_COMPETITION_ID``.
    """
    # Configure logging first so later failures go through it.
    bt.logging(config=config)
    taoverse_utils.logging.reinitialize()
    taoverse_utils.configure_logging(config)

    # The parser does not mark these arguments as required (so that
    # --list_competitions can run alone). Check them here, before any network
    # connection is made, instead of failing deep inside load/push.
    if not config.hf_repo_id:
        raise ValueError("--hf_repo_id is required to upload a model.")
    if config.competition_id is None:
        raise ValueError("--competition_id is required to upload a model.")
    if config.load_model_dir is None:
        raise ValueError("--load_model_dir is required to upload a model.")
    if not os.path.isdir(config.load_model_dir):
        raise ValueError(
            f"--load_model_dir is not a directory: {config.load_model_dir}"
        )

    # Resolve the competition's model constraints. This is a pure lookup, so
    # it also runs before connecting to the chain.
    model_constraints = constants.MODEL_CONSTRAINTS_BY_COMPETITION_ID.get(
        config.competition_id, None
    )
    if model_constraints is None:
        raise RuntimeError(
            f"Could not find current competition for id: {config.competition_id}"
        )

    # Create bittensor objects.
    wallet = bt.wallet(config=config)
    subtensor = bt.subtensor(config=config)
    metagraph = subtensor.metagraph(config.netuid)
    chain_metadata_store = ChainModelMetadataStore(
        subtensor=subtensor,
        subnet_uid=config.netuid,
        wallet=wallet,
    )

    # Make sure we're registered and have a HuggingFace token.
    metagraph_utils.assert_registered(wallet, metagraph)
    HuggingFaceModelStore.assert_access_token_exists()

    # Load the model from disk and push it to the chain and Hugging Face.
    model = fact.mining.load_local_model(config.load_model_dir, model_constraints.kwargs)
    await fact.mining.push(
        model,
        config.hf_repo_id,
        wallet,
        config.competition_id,
        metadata_store=chain_metadata_store,
        update_repo_visibility=config.update_repo_visibility,
    )
if __name__ == "__main__":
    # Parse the command line once. --list_competitions short-circuits the upload.
    config = get_config()
    if not config.list_competitions:
        # Echo the effective configuration before doing any work.
        print(config)
        asyncio.run(main(config))
    else:
        print(constants.COMPETITION_SCHEDULE_BY_BLOCK)