Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/rest-server/src/config/v2/protocol.js
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ const protocolSchema = {
},
minItems: 1,
},
jobType: {
type: 'string',
enum: ['inference', 'training', 'others'],
default: 'others',
},
parameters: {
type: 'object',
additionalProperties: true,
Expand Down
28 changes: 24 additions & 4 deletions src/rest-server/src/controllers/v2/job.js
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,24 @@ const list = asyncHandler(async (req, res) => {
if ('tagsNotContain' in req.query) {
tagsNotContainFilter.name = req.query.tagsNotContain.split(',');
}
if ('jobType' in req.query) {
// validate jobType values
const validJobTypes = ['inference', 'training', 'others'];
const requestedTypes = req.query.jobType.split(',');
const invalidTypes = requestedTypes.filter(type => !validJobTypes.includes(type));
if (invalidTypes.length > 0) {
throw createError(
'Bad Request',
'InvalidParametersError',
`Invalid job type(s): ${invalidTypes.join(', ')}`
);
}
if (Array.isArray(tagsContainFilter.name)) {
tagsContainFilter.name.push(...requestedTypes);
} else {
tagsContainFilter.name = requestedTypes;
}
}
if ('keyword' in req.query) {
// match text in username, jobname, or vc
filters[Op.or] = [
Expand Down Expand Up @@ -199,6 +217,7 @@ const update = asyncHandler(async (req, res) => {
const jobName = res.locals.protocol.name;
const userName = req.user.username;
const frameworkName = `${userName}~${jobName}`;
const jobType = res.locals.protocol.jobType || 'others';

// check duplicate job
try {
Expand All @@ -216,6 +235,7 @@ const update = asyncHandler(async (req, res) => {
}
}
await job.put(frameworkName, res.locals.protocol, req.body);
await job.addTag(frameworkName, jobType);
res.status(status('Accepted')).json({
status: status('Accepted'),
message: `Update job ${jobName} for user ${userName} successfully.`,
Expand Down Expand Up @@ -368,10 +388,10 @@ const getLogs = asyncHandler(async (req, res) => {
throw error.code === 'NoTaskLogError'
? error
: createError(
'Internal Server Error',
'UnknownError',
'Failed to get log list',
);
'Internal Server Error',
'UnknownError',
'Failed to get log list',
);
}
});

Expand Down
30 changes: 29 additions & 1 deletion src/rest-server/src/middlewares/v2/protocol.js
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,34 @@ const protocolValidate = (protocolYAML) => {
}
}
}

// check jobType
if ('jobType' in protocolObj) {
if (protocolObj.jobType === 'inference') {
// check parameters for inference job
if (!('parameters' in protocolObj)) {
throw createError(
'Bad Request',
'InvalidProtocolError',
`The following parameters must be specified for inference job:
INTERNAL_SERVER_IP=$PAI_HOST_IP_taskrole_0
INTERNAL_SERVER_PORT=$PAI_PORT_LIST_taskrole_0_http
API_KEY=[any string]`,
);
}
const requiredParams = ['INTERNAL_SERVER_IP', 'INTERNAL_SERVER_PORT', 'API_KEY'];
for (const param of requiredParams) {
if (!(param in protocolObj.parameters)) {
throw createError(
'Bad Request',
'InvalidProtocolError',
`Parameter ${param} must be specified for inference job.`,
);
}
}
}
}

for (const taskRole of Object.keys(protocolObj.taskRoles)) {
for (const field of prerequisiteFields) {
if (
Expand Down Expand Up @@ -194,7 +222,7 @@ const protocolRender = (protocolObj) => {
],
$output:
protocolObj.prerequisites.output[
protocolObj.taskRoles[taskRole].output
protocolObj.taskRoles[taskRole].output
],
$data:
protocolObj.prerequisites.data[protocolObj.taskRoles[taskRole].data],
Expand Down
56 changes: 33 additions & 23 deletions src/rest-server/src/models/v2/job/k8s.js
Original file line number Diff line number Diff line change
Expand Up @@ -1000,33 +1000,43 @@ const list = async (
Object.keys(tagsContainFilter).length !== 0 ||
Object.keys(tagsNotContainFilter).length !== 0
) {
filters.name = {};
// tagsContain
// Build name filters by querying Tag table directly to avoid using QueryGenerator
const nameFilter = {};

// tagsContain -> include frameworks whose name appears in Tag rows matched by tagsContainFilter
if (Object.keys(tagsContainFilter).length !== 0) {
const queryContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery(
'tags',
{
attributes: ['frameworkName'],
where: tagsContainFilter,
},
);
filters.name[Sequelize.Op.in] = Sequelize.literal(`
(${queryContainFrameworkName.slice(0, -1)})
`);
const containRows = await databaseModel.Tag.findAll({
attributes: ['frameworkName'],
where: tagsContainFilter,
raw: true,
});
const containNames = [...new Set(containRows.map((r) => r.frameworkName))];
// if no tags match, result is empty
if (containNames.length === 0) {
if (withTotalCount) {
return { totalCount: 0, data: [] };
} else {
return [];
}
}
nameFilter[Sequelize.Op.in] = containNames;
}
// tagsNotContain

// tagsNotContain -> exclude frameworks whose name appears in Tag rows matched by tagsNotContainFilter
if (Object.keys(tagsNotContainFilter).length !== 0) {
const queryNotContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery(
'tags',
{
attributes: ['frameworkName'],
where: tagsNotContainFilter,
},
);
filters.name[Sequelize.Op.notIn] = Sequelize.literal(`
(${queryNotContainFrameworkName.slice(0, -1)})
`);
const notContainRows = await databaseModel.Tag.findAll({
attributes: ['frameworkName'],
where: tagsNotContainFilter,
raw: true,
});
const notContainNames = [...new Set(notContainRows.map((r) => r.frameworkName))];
if (notContainNames.length > 0) {
nameFilter[Sequelize.Op.notIn] = notContainNames;
}
}

// merge with any existing name filter
filters.name = Object.assign({}, filters.name || {}, nameFilter);
}

frameworks = await databaseModel.Framework.findAll({
Expand Down