|
19 | 19 | from .feature import Feature, Features, Data |
20 | 20 | from .source import Source, Sources, SubsetSources |
21 | 21 | from .model import Model |
22 | | -from .util.cli import CMD, Arg, SourcesCMD, FeaturesCMD, ModelCMD, PortCMD, \ |
23 | | - KeysCMD, ListEntrypoint, ParseSourcesAction |
| 22 | +from .df import Input, \ |
| 23 | + MemoryInputSet, \ |
| 24 | + MemoryInputSetConfig, \ |
| 25 | + StringInputSetContext |
| 26 | +from .util.cli.base import CMD, Arg |
| 27 | +from .util.cli.parser import ParseSourcesAction |
| 28 | +from .util.cli.cmd import SourcesCMD, FeaturesCMD, ModelCMD, PortCMD, \ |
| 29 | + KeysCMD, ListEntrypoint, DataFlowFacilitatorCMD |
24 | 30 |
|
25 | 31 | class Version(CMD): |
26 | 32 | ''' |
@@ -137,6 +143,112 @@ async def run(self): |
137 | 143 | repo.merge(await self.dest.repo(repo.src_url)) |
138 | 144 | await self.dest.update(repo) |
139 | 145 |
|
class OperationsCMD(DataFlowFacilitatorCMD, SourcesCMD):
    # Base command for running dataflow operations over repos.
    # Combines the dataflow facilitator CLI arguments (networks,
    # orchestrator, etc.) with source selection arguments.

    # Sources become optional here (plain SourcesCMD requires them).
    arg_sources = SourcesCMD.arg_sources.modify(required=False)
    # -caching: when set, reuse previously computed results instead of
    # re-running the operations.
    arg_caching = Arg('-caching', help='Re-run operations or use last',
            required=False, default=False, action='store_true')
    # -cacheless: feature names which, when missing, should NOT trigger a
    # re-run of the operations.
    arg_cacheless = Arg('-cacheless',
            help='Do not re-run operations if these features are missing',
            required=False, default=[], nargs='+')
    # -update: when set, write each evaluated repo back to its source.
    arg_update = Arg('-update', help='Update repo with sources', required=False,
            default=False, action='store_true')
| 156 | + |
class OperationsAll(OperationsCMD):
    '''Operations all repos in sources'''

    async def run_operations(self, sources):
        '''
        Run the configured operations over every repo in ``sources``.

        Feeds each repo's inputs (plus the shared output specs) into the
        dataflow input network, evaluates the dataflow, remaps the output
        operation results onto feature names, and yields each evaluated
        repo. When ``self.update`` is set, the repo is also written back
        to ``sources`` after being yielded.
        '''
        # Orchestrate the running of these operations
        async with self.dff(
                input_network=self.input_network(
                    self.input_network.config(self)),
                operation_network=self.operation_network(
                    self.operation_network.config(self)),
                lock_network=self.lock_network(
                    self.lock_network.config(self)),
                rchecker=self.rchecker(
                    self.rchecker.config(self)),
                opimpn=self.opimpn(self.opimpn.config(self)),
                orchestrator=self.orchestrator(
                    self.orchestrator.config(self))
                ) as dffctx:

            # Create the inputs for the output operations (shared by every
            # repo's input set).
            output_specs = [Input(value=value,
                                  definition=self.definitions[def_name],
                                  parents=False)
                            for value, def_name in self.output_specs]

            # Add our inputs to the input network with the context being the
            # repo src_url
            async for repo in sources.repos():
                inputs = [Input(value=value,
                                definition=self.definitions[def_name],
                                parents=False)
                          for value, def_name in self.inputs]
                # Optionally seed the repo's URL as an input as well.
                if self.repo_def:
                    inputs.append(Input(value=repo.src_url,
                                        definition=self.definitions[self.repo_def],
                                        parents=False))

                await dffctx.ictx.add(
                    MemoryInputSet(
                        MemoryInputSetConfig(
                            ctx=StringInputSetContext(repo.src_url),
                            inputs=inputs + output_specs
                        )
                    )
                )

            async for ctx, results in dffctx.evaluate():
                ctx_str = (await ctx.handle()).as_string()
                # TODO Make a RepoInputSetContext which would let us store the
                # repo instead of recalling it by the URL
                repo = await sources.repo(ctx_str)
                # Remap the output operations to their feature
                remap = {}
                for output_operation_name, sub, feature_name in self.remap:
                    if output_operation_name not in results:
                        self.logger.error('[%s] results do not contain %s: %s',
                                          ctx_str,
                                          output_operation_name,
                                          results)
                        continue
                    if sub not in results[output_operation_name]:
                        # NOTE: the original passed only three arguments for
                        # the four %s placeholders, which makes logging raise
                        # a formatting error; output_operation_name was the
                        # missing argument.
                        self.logger.error('[%s] %s does not contain %s: %s',
                                          ctx_str,
                                          output_operation_name,
                                          sub,
                                          results[output_operation_name])
                        continue
                    remap[feature_name] = results[output_operation_name][sub]
                # Store the results
                repo.evaluated(remap)
                yield repo
                if self.update:
                    await sources.update(repo)

    async def run(self):
        '''Open the configured sources and yield each evaluated repo.'''
        async with self.sources as sources:
            async for repo in self.run_operations(sources):
                yield repo
| 238 | + |
class OperationsRepo(OperationsAll, KeysCMD):
    '''Operations features on individual repos'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Narrow the configured sources down to only the repos whose keys
        # were given on the command line (provided by KeysCMD).
        subset = SubsetSources(*self.sources, keys=self.keys)
        self.sources = subset
| 245 | + |
class Operations(CMD):
    '''Run operations for repos'''

    # Subcommands: `operations repo` runs on specific repos selected by
    # key; `operations all` (exposed as `all`, `_all` avoids shadowing the
    # builtin) runs on every repo in the sources.
    repo = OperationsRepo
    _all = OperationsAll
| 251 | + |
140 | 252 | class EvaluateCMD(FeaturesCMD, SourcesCMD): |
141 | 253 |
|
142 | 254 | arg_sources = SourcesCMD.arg_sources.modify(required=False) |
@@ -288,6 +400,7 @@ class CLI(CMD): |
288 | 400 | train = Train |
289 | 401 | accuracy = Accuracy |
290 | 402 | predict = Predict |
| 403 | + operations = Operations |
291 | 404 | evaluate = Evaluate |
292 | 405 | service = services() |
293 | 406 | applicable = Applicable |
0 commit comments