Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ class BoardField(BaseModel):
board_id: str = Field(description="The id of the board")


class StylePresetField(BaseModel):
"""A style preset primitive field"""

style_preset_id: str = Field(description="The id of the style preset")


class DenoiseMaskField(BaseModel):
"""An inpaint mask field"""

Expand Down
57 changes: 57 additions & 0 deletions invokeai/app/invocations/prompt_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import InputField, OutputField, StylePresetField, UIComponent
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("prompt_template_output")
class PromptTemplateOutput(BaseInvocationOutput):
"""Output for the Prompt Template node"""

positive_prompt: str = OutputField(description="The positive prompt with the template applied")
negative_prompt: str = OutputField(description="The negative prompt with the template applied")


@invocation(
"prompt_template",
title="Prompt Template",
tags=["prompt", "template", "style", "preset"],
category="prompt",
version="1.0.0",
)
class PromptTemplateInvocation(BaseInvocation):
"""Applies a Style Preset template to positive and negative prompts.

Select a Style Preset and provide positive/negative prompts. The node replaces
{prompt} placeholders in the template with your input prompts.
"""

style_preset: StylePresetField = InputField(
description="The Style Preset to use as a template",
)
positive_prompt: str = InputField(
default="",
description="The positive prompt to insert into the template's {prompt} placeholder",
ui_component=UIComponent.Textarea,
)
negative_prompt: str = InputField(
default="",
description="The negative prompt to insert into the template's {prompt} placeholder",
ui_component=UIComponent.Textarea,
)

def invoke(self, context: InvocationContext) -> PromptTemplateOutput:
# Fetch the style preset from the database
style_preset = context._services.style_preset_records.get(self.style_preset.style_preset_id)

# Get the template prompts
positive_template = style_preset.preset_data.positive_prompt
negative_template = style_preset.preset_data.negative_prompt

# Replace {prompt} placeholder with the input prompts
rendered_positive = positive_template.replace("{prompt}", self.positive_prompt)
rendered_negative = negative_template.replace("{prompt}", self.negative_prompt)

return PromptTemplateOutput(
positive_prompt=rendered_positive,
negative_prompt=rendered_negative,
)
4 changes: 3 additions & 1 deletion invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -2651,7 +2651,9 @@
"useForTemplate": "Use For Prompt Template",
"viewList": "View Template List",
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box.",
"togglePromptPreviews": "Toggle Prompt Previews"
"togglePromptPreviews": "Toggle Prompt Previews",
"selectPreset": "Select Style Preset",
"noMatchingPresets": "No matching presets"
},

"ui": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ import {
isStringFieldInputTemplate,
isStringGeneratorFieldInputInstance,
isStringGeneratorFieldInputTemplate,
isStylePresetFieldInputInstance,
isStylePresetFieldInputTemplate,
} from 'features/nodes/types/field';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { memo } from 'react';
Expand All @@ -67,6 +69,7 @@ import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import StylePresetFieldInputComponent from './inputs/StylePresetFieldInputComponent';

type Props = {
nodeId: string;
Expand Down Expand Up @@ -206,6 +209,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <BoardFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isStylePresetFieldInputTemplate(template)) {
if (!isStylePresetFieldInputInstance(field)) {
return null;
}
return <StylePresetFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isModelIdentifierFieldInputTemplate(template)) {
if (!isModelIdentifierFieldInputInstance(field)) {
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldStylePresetValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { StylePresetFieldInputInstance, StylePresetFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';

import type { FieldComponentProps } from './types';

const StylePresetFieldInputComponent = (
props: FieldComponentProps<StylePresetFieldInputInstance, StylePresetFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: stylePresets, isLoading } = useListStylePresetsQuery();

const options = useMemo<ComboboxOption[]>(() => {
const _options: ComboboxOption[] = [];
if (stylePresets) {
for (const preset of stylePresets) {
_options.push({
label: preset.name,
value: preset.id,
});
}
}
return _options;
}, [stylePresets]);

const onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v) {
return;
}

dispatch(
fieldStylePresetValueChanged({
nodeId,
fieldName: field.name,
value: { style_preset_id: v.value },
})
);
},
[dispatch, field.name, nodeId]
);

const value = useMemo(() => {
const _value = field.value;
if (!_value) {
return null;
}
return options.find((o) => o.value === _value.style_preset_id) ?? null;
}, [field.value, options]);

const noOptionsMessage = useCallback(() => t('stylePresets.noMatchingPresets'), [t]);

return (
<Combobox
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
value={value}
options={options}
onChange={onChange}
placeholder={isLoading ? t('common.loading') : t('stylePresets.selectPreset')}
noOptionsMessage={noOptionsMessage}
/>
);
};

export default memo(StylePresetFieldInputComponent);
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import type {
StringFieldCollectionValue,
StringFieldValue,
StringGeneratorFieldValue,
StylePresetFieldValue,
} from 'features/nodes/types/field';
import {
zBoardFieldValue,
Expand All @@ -62,6 +63,7 @@ import {
zStringFieldCollectionValue,
zStringFieldValue,
zStringGeneratorFieldValue,
zStylePresetFieldValue,
} from 'features/nodes/types/field';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
Expand Down Expand Up @@ -438,6 +440,9 @@ const slice = createSlice({
fieldBoardValueChanged: (state, action: FieldValueAction<BoardFieldValue>) => {
fieldValueReducer(state, action, zBoardFieldValue);
},
fieldStylePresetValueChanged: (state, action: FieldValueAction<StylePresetFieldValue>) => {
fieldValueReducer(state, action, zStylePresetFieldValue);
},
fieldImageValueChanged: (state, action: FieldValueAction<ImageFieldValue>) => {
fieldValueReducer(state, action, zImageFieldValue);
},
Expand Down Expand Up @@ -588,6 +593,7 @@ export const {
fieldBoardValueChanged,
fieldBooleanValueChanged,
fieldColorValueChanged,
fieldStylePresetValueChanged,
fieldEnumModelValueChanged,
fieldImageValueChanged,
fieldImageCollectionValueChanged,
Expand Down
4 changes: 4 additions & 0 deletions invokeai/frontend/web/src/features/nodes/types/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ export const zBoardField = z.object({
});
export type BoardField = z.infer<typeof zBoardField>;

export const zStylePresetField = z.object({
style_preset_id: z.string().trim().min(1),
});

export const zColorField = z.object({
r: z.number().int().min(0).max(255),
g: z.number().int().min(0).max(255),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export const NO_PAN_CLASS = 'nopan';
export const FIELD_COLORS: { [key: string]: string } = {
BoardField: 'purple.500',
BooleanField: 'green.500',
StylePresetField: 'purple.400',
CLIPField: 'green.500',
ColorField: 'pink.300',
ConditioningField: 'cyan.500',
Expand Down
32 changes: 32 additions & 0 deletions invokeai/frontend/web/src/features/nodes/types/field.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
zModelIdentifierField,
zModelType,
zSchedulerField,
zStylePresetField,
} from './common';

/**
Expand Down Expand Up @@ -169,6 +170,11 @@ const zBoardFieldType = zFieldTypeBase.extend({
originalType: zStatelessFieldType.optional(),
});

const zStylePresetFieldType = zFieldTypeBase.extend({
name: z.literal('StylePresetField'),
originalType: zStatelessFieldType.optional(),
});

const zColorFieldType = zFieldTypeBase.extend({
name: z.literal('ColorField'),
originalType: zStatelessFieldType.optional(),
Expand Down Expand Up @@ -205,6 +211,7 @@ const zStatefulFieldType = z.union([
zEnumFieldType,
zImageFieldType,
zBoardFieldType,
zStylePresetFieldType,
zModelIdentifierFieldType,
zColorFieldType,
zSchedulerFieldType,
Expand Down Expand Up @@ -607,6 +614,27 @@ export const isBoardFieldInputInstance = buildInstanceTypeGuard(zBoardFieldInput
export const isBoardFieldInputTemplate = buildTemplateTypeGuard<BoardFieldInputTemplate>('BoardField');
// #endregion

// #region StylePresetField
export const zStylePresetFieldValue = zStylePresetField.optional();
const zStylePresetFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStylePresetFieldValue,
});
const zStylePresetFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zStylePresetFieldType,
originalType: zFieldType.optional(),
default: zStylePresetFieldValue,
});
const zStylePresetFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStylePresetFieldType,
});
export type StylePresetFieldValue = z.infer<typeof zStylePresetFieldValue>;
export type StylePresetFieldInputInstance = z.infer<typeof zStylePresetFieldInputInstance>;
export type StylePresetFieldInputTemplate = z.infer<typeof zStylePresetFieldInputTemplate>;
export const isStylePresetFieldInputInstance = buildInstanceTypeGuard(zStylePresetFieldInputInstance);
export const isStylePresetFieldInputTemplate =
buildTemplateTypeGuard<StylePresetFieldInputTemplate>('StylePresetField');
// #endregion

// #region ColorField
export const zColorFieldValue = zColorField.optional();
const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
Expand Down Expand Up @@ -1257,6 +1285,7 @@ export const zStatefulFieldValue = z.union([
zImageFieldValue,
zImageFieldCollectionValue,
zBoardFieldValue,
zStylePresetFieldValue,
zModelIdentifierFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
Expand Down Expand Up @@ -1284,6 +1313,7 @@ const zStatefulFieldInputInstance = z.union([
zImageFieldInputInstance,
zImageFieldCollectionInputInstance,
zBoardFieldInputInstance,
zStylePresetFieldInputInstance,
zModelIdentifierFieldInputInstance,
zColorFieldInputInstance,
zSchedulerFieldInputInstance,
Expand All @@ -1310,6 +1340,7 @@ const zStatefulFieldInputTemplate = z.union([
zImageFieldInputTemplate,
zImageFieldCollectionInputTemplate,
zBoardFieldInputTemplate,
zStylePresetFieldInputTemplate,
zModelIdentifierFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
Expand Down Expand Up @@ -1337,6 +1368,7 @@ const zStatefulFieldOutputTemplate = z.union([
zImageFieldOutputTemplate,
zImageFieldCollectionOutputTemplate,
zBoardFieldOutputTemplate,
zStylePresetFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zColorFieldOutputTemplate,
zSchedulerFieldOutputTemplate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
ModelIdentifierField: undefined,
SchedulerField: 'dpmpp_3m_k',
StringField: '',
StylePresetField: undefined,
FloatGeneratorField: undefined,
IntegerGeneratorField: undefined,
StringGeneratorField: undefined,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import type {
StringFieldCollectionInputTemplate,
StringFieldInputTemplate,
StringGeneratorFieldInputTemplate,
StylePresetFieldInputTemplate,
} from 'features/nodes/types/field';
import {
getFloatGeneratorArithmeticSequenceDefaults,
Expand Down Expand Up @@ -289,6 +290,20 @@ const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTem
return template;
};

const buildStylePresetFieldInputTemplate: FieldInputTemplateBuilder<StylePresetFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: StylePresetFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};

return template;
};

const buildImageFieldInputTemplate: FieldInputTemplateBuilder<ImageFieldInputTemplate> = ({
schemaObject,
baseField,
Expand Down Expand Up @@ -460,6 +475,7 @@ const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplate
ModelIdentifierField: buildModelIdentifierFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
StylePresetField: buildStylePresetFieldInputTemplate,
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
StringGeneratorField: buildStringGeneratorFieldInputTemplate,
Expand Down
Loading