Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 63 additions & 13 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,32 @@ type workflowOutcome[R any] struct {
// The type parameter R represents the expected return type of the workflow.
// Handles can be used to wait for workflow completion, check status, and retrieve results.
type WorkflowHandle[R any] interface {
GetResult() (R, error) // Wait for workflow completion and return the result
GetStatus() (WorkflowStatus, error) // Get current workflow status without waiting
GetWorkflowID() string // Get the unique workflow identifier
GetResult(opts ...GetResultOption) (R, error) // Wait for workflow completion and return the result
GetStatus() (WorkflowStatus, error) // Get current workflow status without waiting
GetWorkflowID() string // Get the unique workflow identifier
}

type baseWorkflowHandle struct {
workflowID string
dbosContext DBOSContext
}

// GetResultOption is a functional option for configuring GetResult behavior.
type GetResultOption func(*getResultOptions)

// getResultOptions holds the configuration for GetResult execution.
type getResultOptions struct {
timeout time.Duration
}

// WithHandleTimeout sets a timeout for the GetResult operation.
// If the timeout is reached before the workflow completes, GetResult will return a timeout error.
func WithHandleTimeout(timeout time.Duration) GetResultOption {
return func(opts *getResultOptions) {
opts.timeout = timeout
}
}

// GetStatus returns the current status of the workflow from the database
// If the DBOSContext is running in client mode, do not load input and outputs
func (h *baseWorkflowHandle) GetStatus() (WorkflowStatus, error) {
Expand Down Expand Up @@ -162,12 +178,33 @@ type workflowHandle[R any] struct {
outcomeChan chan workflowOutcome[R]
}

func (h *workflowHandle[R]) GetResult() (R, error) {
outcome, ok := <-h.outcomeChan // Blocking read
if !ok {
// Return an error if the channel was closed. In normal operations this would happen if GetResul() is called twice on a handler. The first call should get the buffered result, the second call find zero values (channel is empty and closed).
return *new(R), errors.New("workflow result channel is already closed. Did you call GetResult() twice on the same workflow handle?")
func (h *workflowHandle[R]) GetResult(opts ...GetResultOption) (R, error) {
options := &getResultOptions{}
for _, opt := range opts {
opt(options)
}

var timeoutChan <-chan time.Time
if options.timeout > 0 {
timeoutChan = time.After(options.timeout)
}

select {
case outcome, ok := <-h.outcomeChan:
if !ok {
// Return error if channel closed (happens when GetResult() called twice)
return *new(R), errors.New("workflow result channel is already closed. Did you call GetResult() twice on the same workflow handle?")
}
return h.processOutcome(outcome)
case <-h.dbosContext.Done():
return *new(R), context.Cause(h.dbosContext)
case <-timeoutChan:
return *new(R), fmt.Errorf("workflow result timeout after %v: %w", options.timeout, context.DeadlineExceeded)
}
}

// processOutcome handles the common logic for processing workflow outcomes
func (h *workflowHandle[R]) processOutcome(outcome workflowOutcome[R]) (R, error) {
// If we are calling GetResult inside a workflow, record the result as a step result
workflowState, ok := h.dbosContext.Value(workflowStateKey).(*workflowState)
isWithinWorkflow := ok && workflowState != nil
Expand Down Expand Up @@ -198,9 +235,22 @@ type workflowPollingHandle[R any] struct {
baseWorkflowHandle
}

func (h *workflowPollingHandle[R]) GetResult() (R, error) {
result, err := retryWithResult(h.dbosContext, func() (any, error) {
return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(h.dbosContext, h.workflowID)
func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error) {
options := &getResultOptions{}
for _, opt := range opts {
opt(options)
}

// Use timeout if specified, otherwise use DBOS context directly
ctx := h.dbosContext
var cancel context.CancelFunc
if options.timeout > 0 {
ctx, cancel = WithTimeout(h.dbosContext, options.timeout)
defer cancel()
}

result, err := retryWithResult(ctx, func() (any, error) {
return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(ctx, h.workflowID)
}, withRetrierLogger(h.dbosContext.(*dbosContext).logger))
if result != nil {
typedResult, ok := result.(R)
Expand Down Expand Up @@ -240,8 +290,8 @@ type workflowHandleProxy[R any] struct {
wrappedHandle WorkflowHandle[any]
}

func (h *workflowHandleProxy[R]) GetResult() (R, error) {
result, err := h.wrappedHandle.GetResult()
func (h *workflowHandleProxy[R]) GetResult(opts ...GetResultOption) (R, error) {
result, err := h.wrappedHandle.GetResult(opts...)
if err != nil {
var zero R
return zero, err
Expand Down
73 changes: 73 additions & 0 deletions dbos/workflows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) {
})
}

func slowWorkflow(dbosCtx DBOSContext, sleepTime time.Duration) (string, error) {
Sleep(dbosCtx, sleepTime)
return "done", nil
}

func simpleStep(_ context.Context) (string, error) {
return "from step", nil
}
Expand Down Expand Up @@ -4523,3 +4528,71 @@ func TestWorkflowIdentity(t *testing.T) {
assert.Equal(t, []string{"reader", "writer"}, status.AuthenticatedRoles)
})
}

func TestWorkflowHandleTimeout(t *testing.T) {
dbosCtx := setupDBOS(t, true, true)
RegisterWorkflow(dbosCtx, slowWorkflow)

t.Run("WorkflowHandleTimeout", func(t *testing.T) {
handle, err := RunWorkflow(dbosCtx, slowWorkflow, 10*time.Second)
require.NoError(t, err, "failed to start workflow")

start := time.Now()
_, err = handle.GetResult(WithHandleTimeout(10 * time.Millisecond))
duration := time.Since(start)

require.Error(t, err, "expected timeout error")
assert.Contains(t, err.Error(), "workflow result timeout")
assert.True(t, duration < 100*time.Millisecond, "timeout should occur quickly")
assert.True(t, errors.Is(err, context.DeadlineExceeded),
"expected error to be detectable as context.DeadlineExceeded, got: %v", err)
})

t.Run("WorkflowPollingHandleTimeout", func(t *testing.T) {
// Start a workflow that will block on the first signal
originalHandle, err := RunWorkflow(dbosCtx, slowWorkflow, 10*time.Second)
require.NoError(t, err, "failed to start workflow")

pollingHandle, err := RetrieveWorkflow[string](dbosCtx, originalHandle.GetWorkflowID())
require.NoError(t, err, "failed to retrieve workflow")

_, ok := pollingHandle.(*workflowPollingHandle[string])
require.True(t, ok, "expected polling handle, got %T", pollingHandle)

_, err = pollingHandle.GetResult(WithHandleTimeout(10 * time.Millisecond))

require.Error(t, err, "expected timeout error")
assert.True(t, errors.Is(err, context.DeadlineExceeded),
"expected error to be detectable as context.DeadlineExceeded, got: %v", err)
})
}

func TestWorkflowHandleContextCancel(t *testing.T) {
dbosCtx := setupDBOS(t, true, true)
RegisterWorkflow(dbosCtx, getEventWorkflow)

t.Run("WorkflowHandleContextCancel", func(t *testing.T) {
getEventWorkflowStartedSignal.Clear()
handle, err := RunWorkflow(dbosCtx, getEventWorkflow, getEventWorkflowInput{
TargetWorkflowID: "test-workflow-id",
Key: "test-key",
})
require.NoError(t, err, "failed to start workflow")

resultChan := make(chan error)
go func() {
_, err := handle.GetResult()
resultChan <- err
}()

getEventWorkflowStartedSignal.Wait()
getEventWorkflowStartedSignal.Clear()

dbosCtx.Shutdown(1 * time.Second)

err = <-resultChan
require.Error(t, err, "expected error from cancelled context")
assert.True(t, errors.Is(err, context.Canceled),
"expected error to be detectable as context.Canceled, got: %v", err)
})
}
Loading