Skip to content

Commit b530809

Browse files
Feature: Add Tag System for user made Workflows
1 parent c6a9847 commit b530809

File tree

8 files changed

+254
-14
lines changed

8 files changed

+254
-14
lines changed

invokeai/app/api/routers/workflows.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,15 @@ async def get_workflow_thumbnail(
223223
raise HTTPException(status_code=404)
224224

225225

226+
@workflows_router.get("/tags", operation_id="get_all_tags")
227+
async def get_all_tags(
228+
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
229+
) -> list[str]:
230+
"""Gets all unique tags from workflows"""
231+
232+
return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
233+
234+
226235
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
227236
async def get_counts_by_tag(
228237
tags: list[str] = Query(description="The tags to get counts for"),

invokeai/app/services/workflow_records/workflow_records_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,11 @@ def counts_by_tag(
7474
def update_opened_at(self, workflow_id: str) -> None:
7575
"""Open a workflow."""
7676
pass
77+
78+
@abstractmethod
79+
def get_all_tags(
80+
self,
81+
categories: Optional[list[WorkflowCategory]] = None,
82+
) -> list[str]:
83+
"""Gets all unique tags from workflows."""
84+
pass

invokeai/app/services/workflow_records/workflow_records_sqlite.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,48 @@ def update_opened_at(self, workflow_id: str) -> None:
332332
(workflow_id,),
333333
)
334334

335+
def get_all_tags(
336+
self,
337+
categories: Optional[list[WorkflowCategory]] = None,
338+
) -> list[str]:
339+
with self._db.transaction() as cursor:
340+
conditions: list[str] = []
341+
params: list[str] = []
342+
343+
# Only get workflows that have tags
344+
conditions.append("tags IS NOT NULL AND tags != ''")
345+
346+
if categories:
347+
assert all(c in WorkflowCategory for c in categories)
348+
placeholders = ", ".join("?" for _ in categories)
349+
conditions.append(f"category IN ({placeholders})")
350+
params.extend([category.value for category in categories])
351+
352+
stmt = """--sql
353+
SELECT DISTINCT tags
354+
FROM workflow_library
355+
"""
356+
357+
if conditions:
358+
stmt += " WHERE " + " AND ".join(conditions)
359+
360+
cursor.execute(stmt, params)
361+
rows = cursor.fetchall()
362+
363+
# Parse comma-separated tags and collect unique tags
364+
all_tags: set[str] = set()
365+
366+
for row in rows:
367+
tags_value = row[0]
368+
if tags_value and isinstance(tags_value, str):
369+
# Tags are stored as comma-separated string
370+
for tag in tags_value.split(","):
371+
tag_stripped = tag.strip()
372+
if tag_stripped:
373+
all_tags.add(tag_stripped)
374+
375+
return sorted(all_tags)
376+
335377
def _sync_default_workflows(self) -> None:
336378
"""Syncs default workflows to the database. Internal use only."""
337379

invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibrarySideNav.tsx

Lines changed: 131 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import { memo, useCallback, useMemo } from 'react';
3131
import { useTranslation } from 'react-i18next';
3232
import { PiArrowCounterClockwiseBold, PiStarFill } from 'react-icons/pi';
3333
import { useDispatch } from 'react-redux';
34-
import { useGetCountsByTagQuery } from 'services/api/endpoints/workflows';
34+
import { useGetAllTagsQuery, useGetCountsByTagQuery } from 'services/api/endpoints/workflows';
3535

3636
export const WorkflowLibrarySideNav = () => {
3737
const { t } = useTranslation();
@@ -40,11 +40,11 @@ export const WorkflowLibrarySideNav = () => {
4040
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
4141
<Flex flexDir="column" w="full" pb={2} gap={2}>
4242
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
43-
<WorkflowLibraryViewButton view="yours">{t('workflows.yourWorkflows')}</WorkflowLibraryViewButton>
43+
<YourWorkflowsButton />
4444
</Flex>
4545
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
4646
<BrowseWorkflowsButton />
47-
<DefaultsViewCheckboxesCollapsible />
47+
<TagCheckboxesCollapsible />
4848
</Flex>
4949
<Spacer />
5050
<NewWorkflowButton />
@@ -53,6 +53,40 @@ export const WorkflowLibrarySideNav = () => {
5353
);
5454
};
5555

56+
const YourWorkflowsButton = memo(() => {
57+
const { t } = useTranslation();
58+
const view = useAppSelector(selectWorkflowLibraryView);
59+
const dispatch = useAppDispatch();
60+
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
61+
const resetTags = useCallback(() => {
62+
dispatch(workflowLibraryTagsReset());
63+
}, [dispatch]);
64+
65+
if (view === 'yours' && selectedTags.length > 0) {
66+
return (
67+
<ButtonGroup>
68+
<WorkflowLibraryViewButton view="yours" w="auto">
69+
{t('workflows.yourWorkflows')}
70+
</WorkflowLibraryViewButton>
71+
<Tooltip label={t('workflows.deselectAll')}>
72+
<IconButton
73+
onClick={resetTags}
74+
size="md"
75+
aria-label={t('workflows.deselectAll')}
76+
icon={<PiArrowCounterClockwiseBold size={12} />}
77+
variant="ghost"
78+
bg="base.700"
79+
color="base.50"
80+
/>
81+
</Tooltip>
82+
</ButtonGroup>
83+
);
84+
}
85+
86+
return <WorkflowLibraryViewButton view="yours">{t('workflows.yourWorkflows')}</WorkflowLibraryViewButton>;
87+
});
88+
YourWorkflowsButton.displayName = 'YourWorkflowsButton';
89+
5690
const BrowseWorkflowsButton = memo(() => {
5791
const { t } = useTranslation();
5892
const view = useAppSelector(selectWorkflowLibraryView);
@@ -89,31 +123,114 @@ BrowseWorkflowsButton.displayName = 'BrowseWorkflowsButton';
89123

90124
const overlayscrollbarsOptions = getOverlayScrollbarsParams({ visibility: 'visible' }).options;
91125

92-
const DefaultsViewCheckboxesCollapsible = memo(() => {
126+
const TagCheckboxesCollapsible = memo(() => {
93127
const view = useAppSelector(selectWorkflowLibraryView);
94128

95129
return (
96-
<Collapse in={view === 'defaults'}>
130+
<Collapse in={view === 'defaults' || view === 'yours'}>
97131
<Flex flexDir="column" gap={2} pl={4} py={2} overflow="hidden" h="100%" minH={0}>
98132
<OverlayScrollbarsComponent style={overlayScrollbarsStyles} options={overlayscrollbarsOptions}>
99133
<Flex flexDir="column" gap={2} overflow="auto">
100-
{WORKFLOW_LIBRARY_TAG_CATEGORIES.map((tagCategory) => (
101-
<TagCategory key={tagCategory.categoryTKey} tagCategory={tagCategory} />
102-
))}
134+
{view === 'yours' ? <DynamicTagsList /> : <StaticTagCategories />}
103135
</Flex>
104136
</OverlayScrollbarsComponent>
105137
</Flex>
106138
</Collapse>
107139
);
108140
});
109-
DefaultsViewCheckboxesCollapsible.displayName = 'DefaultsViewCheckboxes';
141+
TagCheckboxesCollapsible.displayName = 'TagCheckboxesCollapsible';
110142

111-
const tagCountQueryArg = {
112-
tags: WORKFLOW_LIBRARY_TAGS.map((tag) => tag.label),
113-
categories: ['default'],
114-
} satisfies Parameters<typeof useGetCountsByTagQuery>[0];
143+
const StaticTagCategories = memo(() => {
144+
return (
145+
<>
146+
{WORKFLOW_LIBRARY_TAG_CATEGORIES.map((tagCategory) => (
147+
<TagCategory key={tagCategory.categoryTKey} tagCategory={tagCategory} />
148+
))}
149+
</>
150+
);
151+
});
152+
StaticTagCategories.displayName = 'StaticTagCategories';
153+
154+
const DynamicTagsList = memo(() => {
155+
const { t } = useTranslation();
156+
const { data: tags, isLoading } = useGetAllTagsQuery({ categories: ['user'] });
157+
158+
if (isLoading) {
159+
return <Text color="base.400">{t('common.loading')}</Text>;
160+
}
161+
162+
if (!tags || tags.length === 0) {
163+
return null;
164+
}
165+
166+
return (
167+
<Flex flexDir="column" gap={2}>
168+
{tags.map((tag) => (
169+
<DynamicTagCheckbox key={tag} tag={tag} />
170+
))}
171+
</Flex>
172+
);
173+
});
174+
DynamicTagsList.displayName = 'DynamicTagsList';
175+
176+
const DynamicTagCheckbox = memo(({ tag }: { tag: string }) => {
177+
const dispatch = useAppDispatch();
178+
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
179+
const isChecked = selectedTags.includes(tag);
180+
const count = useDynamicTagCount(tag);
181+
182+
const onChange = useCallback(() => {
183+
dispatch(workflowLibraryTagToggled(tag));
184+
}, [dispatch, tag]);
185+
186+
if (count === 0) {
187+
return null;
188+
}
189+
190+
return (
191+
<Flex alignItems="center" gap={2}>
192+
<Checkbox isChecked={isChecked} onChange={onChange} flexShrink={0} />
193+
<Text>{`${tag} (${count})`}</Text>
194+
</Flex>
195+
);
196+
});
197+
DynamicTagCheckbox.displayName = 'DynamicTagCheckbox';
198+
199+
const useDynamicTagCount = (tag: string) => {
200+
const queryArg = useMemo(
201+
() => ({
202+
tags: [tag],
203+
categories: ['user'] as ('user' | 'default')[],
204+
}),
205+
[tag]
206+
);
207+
208+
const queryOptions = useMemo(
209+
() => ({
210+
selectFromResult: ({ data }: { data?: Record<string, number> }) => ({
211+
count: data?.[tag] ?? 0,
212+
}),
213+
}),
214+
[tag]
215+
);
216+
217+
const { count } = useGetCountsByTagQuery(queryArg, queryOptions);
218+
return count;
219+
};
220+
221+
const useTagCountQueryArg = () => {
222+
const view = useAppSelector(selectWorkflowLibraryView);
223+
return useMemo(
224+
() => ({
225+
tags: WORKFLOW_LIBRARY_TAGS.map((tag) => tag.label),
226+
categories: view === 'yours' ? ['user'] : ['default'],
227+
}),
228+
[view]
229+
) satisfies Parameters<typeof useGetCountsByTagQuery>[0];
230+
};
115231

116232
const useCountForIndividualTag = (tag: string) => {
233+
const tagCountQueryArg = useTagCountQueryArg();
117234
const queryOptions = useMemo(
118235
() =>
119236
({
@@ -130,6 +247,7 @@ const useCountForIndividualTag = (tag: string) => {
130247
};
131248

132249
const useCountForTagCategory = (tagCategory: WorkflowTagCategory) => {
250+
const tagCountQueryArg = useTagCountQueryArg();
133251
const queryOptions = useMemo(
134252
() =>
135253
({

invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowList.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ const useInfiniteQueryAry = () => {
6060
direction,
6161
categories: getCategories(view),
6262
query: debouncedSearchTerm,
63-
tags: view === 'defaults' ? selectedTags : [],
63+
tags: view === 'defaults' || view === 'yours' ? selectedTags : [],
6464
has_been_opened: getHasBeenOpened(view),
6565
} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[0];
6666
}, [orderBy, direction, view, debouncedSearchTerm, selectedTags]);

invokeai/frontend/web/src/services/api/endpoints/workflows.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ export const workflowsApi = api.injectEndpoints({
3030
// Because this may change the order of the list, we need to invalidate the whole list
3131
{ type: 'Workflow', id: LIST_TAG },
3232
{ type: 'Workflow', id: workflow_id },
33+
'WorkflowTags',
3334
'WorkflowTagCounts',
3435
'WorkflowCategoryCounts',
3536
],
@@ -46,6 +47,7 @@ export const workflowsApi = api.injectEndpoints({
4647
invalidatesTags: [
4748
// Because this may change the order of the list, we need to invalidate the whole list
4849
{ type: 'Workflow', id: LIST_TAG },
50+
'WorkflowTags',
4951
'WorkflowTagCounts',
5052
'WorkflowCategoryCounts',
5153
],
@@ -61,10 +63,17 @@ export const workflowsApi = api.injectEndpoints({
6163
}),
6264
invalidatesTags: (response, error, workflow) => [
6365
{ type: 'Workflow', id: workflow.id },
66+
'WorkflowTags',
6467
'WorkflowTagCounts',
6568
'WorkflowCategoryCounts',
6669
],
6770
}),
71+
getAllTags: build.query<string[], { categories?: ('user' | 'default')[] } | void>({
72+
query: (params) => ({
73+
url: `${buildWorkflowsUrl('tags')}${params ? `?${queryString.stringify(params, { arrayFormat: 'none' })}` : ''}`,
74+
}),
75+
providesTags: ['WorkflowTags'],
76+
}),
6877
getCountsByTag: build.query<
6978
paths['/api/v1/workflows/counts_by_tag']['get']['responses']['200']['content']['application/json'],
7079
NonNullable<paths['/api/v1/workflows/counts_by_tag']['get']['parameters']['query']>
@@ -153,6 +162,7 @@ export const workflowsApi = api.injectEndpoints({
153162

154163
export const {
155164
useUpdateOpenedAtMutation,
165+
useGetAllTagsQuery,
156166
useGetCountsByTagQuery,
157167
useGetCountsByCategoryQuery,
158168
useLazyGetWorkflowQuery,

invokeai/frontend/web/src/services/api/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ const tagTypes = [
4747
'LoRAModel',
4848
'SDXLRefinerModel',
4949
'Workflow',
50+
'WorkflowTags',
5051
'WorkflowTagCounts',
5152
'WorkflowCategoryCounts',
5253
'StylePreset',

invokeai/frontend/web/src/services/api/schema.ts

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,26 @@ export type paths = {
16881688
patch?: never;
16891689
trace?: never;
16901690
};
1691+
"/api/v1/workflows/tags": {
1692+
parameters: {
1693+
query?: never;
1694+
header?: never;
1695+
path?: never;
1696+
cookie?: never;
1697+
};
1698+
/**
1699+
* Get All Tags
1700+
* @description Gets all unique tags from workflows
1701+
*/
1702+
get: operations["get_all_tags"];
1703+
put?: never;
1704+
post?: never;
1705+
delete?: never;
1706+
options?: never;
1707+
head?: never;
1708+
patch?: never;
1709+
trace?: never;
1710+
};
16911711
"/api/v1/workflows/counts_by_tag": {
16921712
parameters: {
16931713
query?: never;
@@ -28145,6 +28165,38 @@ export interface operations {
2814528165
};
2814628166
};
2814728167
};
28168+
get_all_tags: {
28169+
parameters: {
28170+
query?: {
28171+
/** @description The categories to include */
28172+
categories?: components["schemas"]["WorkflowCategory"][] | null;
28173+
};
28174+
header?: never;
28175+
path?: never;
28176+
cookie?: never;
28177+
};
28178+
requestBody?: never;
28179+
responses: {
28180+
/** @description Successful Response */
28181+
200: {
28182+
headers: {
28183+
[name: string]: unknown;
28184+
};
28185+
content: {
28186+
"application/json": string[];
28187+
};
28188+
};
28189+
/** @description Validation Error */
28190+
422: {
28191+
headers: {
28192+
[name: string]: unknown;
28193+
};
28194+
content: {
28195+
"application/json": components["schemas"]["HTTPValidationError"];
28196+
};
28197+
};
28198+
};
28199+
};
2814828200
get_counts_by_tag: {
2814928201
parameters: {
2815028202
query: {

0 commit comments

Comments
 (0)