diff --git a/src/rest-server/src/config/v2/protocol.js b/src/rest-server/src/config/v2/protocol.js index 163172f9..4f86c569 100644 --- a/src/rest-server/src/config/v2/protocol.js +++ b/src/rest-server/src/config/v2/protocol.js @@ -173,6 +173,11 @@ const protocolSchema = { }, minItems: 1, }, + jobType: { + type: 'string', + enum: ['inference', 'training', 'others'], + default: 'others', + }, parameters: { type: 'object', additionalProperties: true, diff --git a/src/rest-server/src/controllers/v2/job.js b/src/rest-server/src/controllers/v2/job.js index 6f59d3d3..339bac7c 100644 --- a/src/rest-server/src/controllers/v2/job.js +++ b/src/rest-server/src/controllers/v2/job.js @@ -86,6 +86,24 @@ const list = asyncHandler(async (req, res) => { if ('tagsNotContain' in req.query) { tagsNotContainFilter.name = req.query.tagsNotContain.split(','); } + if ('jobType' in req.query) { + // validate jobType values + const validJobTypes = ['inference', 'training', 'others']; + const requestedTypes = req.query.jobType.split(','); + const invalidTypes = requestedTypes.filter(type => !validJobTypes.includes(type)); + if (invalidTypes.length > 0) { + throw createError( + 'Bad Request', + 'InvalidParametersError', + `Invalid job type(s): ${invalidTypes.join(', ')}` + ); + } + if (Array.isArray(tagsContainFilter.name)) { + tagsContainFilter.name.push(...requestedTypes); + } else { + tagsContainFilter.name = requestedTypes; + } + } if ('keyword' in req.query) { // match text in username, jobname, or vc filters[Op.or] = [ @@ -199,6 +217,7 @@ const update = asyncHandler(async (req, res) => { const jobName = res.locals.protocol.name; const userName = req.user.username; const frameworkName = `${userName}~${jobName}`; + const jobType = res.locals.protocol.jobType || 'others'; // check duplicate job try { @@ -216,6 +235,7 @@ const update = asyncHandler(async (req, res) => { } } await job.put(frameworkName, res.locals.protocol, req.body); + await job.addTag(frameworkName, jobType); res.status(status('Accepted')).json({ status: status('Accepted'), message: `Update job ${jobName} for user ${userName} successfully.`, @@ -368,10 +388,10 @@ const getLogs = asyncHandler(async (req, res) => { throw error.code === 'NoTaskLogError' ? error : createError( - 'Internal Server Error', - 'UnknownError', - 'Failed to get log list', - ); + 'Internal Server Error', + 'UnknownError', + 'Failed to get log list', + ); } }); diff --git a/src/rest-server/src/middlewares/v2/protocol.js b/src/rest-server/src/middlewares/v2/protocol.js index 4e505be4..b989665e 100644 --- a/src/rest-server/src/middlewares/v2/protocol.js +++ b/src/rest-server/src/middlewares/v2/protocol.js @@ -120,6 +120,34 @@ const protocolValidate = (protocolYAML) => { } } } + + // check jobType + if ('jobType' in protocolObj) { + if (protocolObj.jobType === 'inference') { + // check parameters for inference job + if (!('parameters' in protocolObj)) { + throw createError( + 'Bad Request', + 'InvalidProtocolError', + `The following parameters must be specified for inference job: +INTERNAL_SERVER_IP=$PAI_HOST_IP_taskrole_0 +INTERNAL_SERVER_PORT=$PAI_PORT_LIST_taskrole_0_http +API_KEY=[any string]`, + ); + } + const requiredParams = ['INTERNAL_SERVER_IP', 'INTERNAL_SERVER_PORT', 'API_KEY']; + for (const param of requiredParams) { + if (!(param in protocolObj.parameters)) { + throw createError( + 'Bad Request', + 'InvalidProtocolError', + `Parameter ${param} must be specified for inference job.`, + ); + } + } + } + } + for (const taskRole of Object.keys(protocolObj.taskRoles)) { for (const field of prerequisiteFields) { if ( @@ -194,7 +222,7 @@ const protocolRender = (protocolObj) => { ], $output: protocolObj.prerequisites.output[ - protocolObj.taskRoles[taskRole].output + protocolObj.taskRoles[taskRole].output ], $data: protocolObj.prerequisites.data[protocolObj.taskRoles[taskRole].data], diff --git a/src/rest-server/src/models/v2/job/k8s.js b/src/rest-server/src/models/v2/job/k8s.js index 86ad75f8..5a8ba976 100644 --- a/src/rest-server/src/models/v2/job/k8s.js +++ b/src/rest-server/src/models/v2/job/k8s.js @@ -1000,33 +1000,43 @@ const list = async ( Object.keys(tagsContainFilter).length !== 0 || Object.keys(tagsNotContainFilter).length !== 0 ) { - filters.name = {}; - // tagsContain + // Build name filters by querying Tag table directly to avoid using QueryGenerator + const nameFilter = {}; + + // tagsContain -> include frameworks whose name appears in Tag rows matched by tagsContainFilter if (Object.keys(tagsContainFilter).length !== 0) { - const queryContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery( - 'tags', - { - attributes: ['frameworkName'], - where: tagsContainFilter, - }, - ); - filters.name[Sequelize.Op.in] = Sequelize.literal(` - (${queryContainFrameworkName.slice(0, -1)}) - `); + const containRows = await databaseModel.Tag.findAll({ + attributes: ['frameworkName'], + where: tagsContainFilter, + raw: true, + }); + const containNames = [...new Set(containRows.map((r) => r.frameworkName))]; + // if no tags match, result is empty + if (containNames.length === 0) { + if (withTotalCount) { + return { totalCount: 0, data: [] }; + } else { + return []; + } + } + nameFilter[Sequelize.Op.in] = containNames; } - // tagsNotContain + + // tagsNotContain -> exclude frameworks whose name appears in Tag rows matched by tagsNotContainFilter if (Object.keys(tagsNotContainFilter).length !== 0) { - const queryNotContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery( - 'tags', - { - attributes: ['frameworkName'], - where: tagsNotContainFilter, - }, - ); - filters.name[Sequelize.Op.notIn] = Sequelize.literal(` - (${queryNotContainFrameworkName.slice(0, -1)}) - `); + const notContainRows = await databaseModel.Tag.findAll({ + attributes: ['frameworkName'], + where: tagsNotContainFilter, + raw: true, + }); + const notContainNames = [...new Set(notContainRows.map((r) => r.frameworkName))]; + if (notContainNames.length > 0) { + nameFilter[Sequelize.Op.notIn] = notContainNames; + } } + + // merge with any existing name filter + filters.name = Object.assign({}, filters.name || {}, nameFilter); } frameworks = await databaseModel.Framework.findAll({