diff --git a/dbos/workflow.go b/dbos/workflow.go index bdfa648..ea5a359 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -82,9 +82,9 @@ type workflowOutcome[R any] struct { // The type parameter R represents the expected return type of the workflow. // Handles can be used to wait for workflow completion, check status, and retrieve results. type WorkflowHandle[R any] interface { - GetResult() (R, error) // Wait for workflow completion and return the result - GetStatus() (WorkflowStatus, error) // Get current workflow status without waiting - GetWorkflowID() string // Get the unique workflow identifier + GetResult(opts ...GetResultOption) (R, error) // Wait for workflow completion and return the result + GetStatus() (WorkflowStatus, error) // Get current workflow status without waiting + GetWorkflowID() string // Get the unique workflow identifier } type baseWorkflowHandle struct { @@ -92,6 +92,22 @@ type baseWorkflowHandle struct { dbosContext DBOSContext } +// GetResultOption is a functional option for configuring GetResult behavior. +type GetResultOption func(*getResultOptions) + +// getResultOptions holds the configuration for GetResult execution. +type getResultOptions struct { + timeout time.Duration +} + +// WithHandleTimeout sets a timeout for the GetResult operation. +// If the timeout is reached before the workflow completes, GetResult will return a timeout error. 
+func WithHandleTimeout(timeout time.Duration) GetResultOption { + return func(opts *getResultOptions) { + opts.timeout = timeout + } +} + // GetStatus returns the current status of the workflow from the database // If the DBOSContext is running in client mode, do not load input and outputs func (h *baseWorkflowHandle) GetStatus() (WorkflowStatus, error) { @@ -162,12 +178,33 @@ type workflowHandle[R any] struct { outcomeChan chan workflowOutcome[R] } -func (h *workflowHandle[R]) GetResult() (R, error) { - outcome, ok := <-h.outcomeChan // Blocking read - if !ok { - // Return an error if the channel was closed. In normal operations this would happen if GetResul() is called twice on a handler. The first call should get the buffered result, the second call find zero values (channel is empty and closed). - return *new(R), errors.New("workflow result channel is already closed. Did you call GetResult() twice on the same workflow handle?") +func (h *workflowHandle[R]) GetResult(opts ...GetResultOption) (R, error) { + options := &getResultOptions{} + for _, opt := range opts { + opt(options) } + + var timeoutChan <-chan time.Time + if options.timeout > 0 { + timeoutChan = time.After(options.timeout) + } + + select { + case outcome, ok := <-h.outcomeChan: + if !ok { + // Return error if channel closed (happens when GetResult() called twice) + return *new(R), errors.New("workflow result channel is already closed. 
Did you call GetResult() twice on the same workflow handle?") + } + return h.processOutcome(outcome) + case <-h.dbosContext.Done(): + return *new(R), context.Cause(h.dbosContext) + case <-timeoutChan: + return *new(R), fmt.Errorf("workflow result timeout after %v: %w", options.timeout, context.DeadlineExceeded) + } +} + +// processOutcome handles the common logic for processing workflow outcomes +func (h *workflowHandle[R]) processOutcome(outcome workflowOutcome[R]) (R, error) { // If we are calling GetResult inside a workflow, record the result as a step result workflowState, ok := h.dbosContext.Value(workflowStateKey).(*workflowState) isWithinWorkflow := ok && workflowState != nil @@ -198,9 +235,22 @@ type workflowPollingHandle[R any] struct { baseWorkflowHandle } -func (h *workflowPollingHandle[R]) GetResult() (R, error) { - result, err := retryWithResult(h.dbosContext, func() (any, error) { - return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(h.dbosContext, h.workflowID) +func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error) { + options := &getResultOptions{} + for _, opt := range opts { + opt(options) + } + + // Use timeout if specified, otherwise use DBOS context directly + ctx := h.dbosContext + var cancel context.CancelFunc + if options.timeout > 0 { + ctx, cancel = WithTimeout(h.dbosContext, options.timeout) + defer cancel() + } + + result, err := retryWithResult(ctx, func() (any, error) { + return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(ctx, h.workflowID) }, withRetrierLogger(h.dbosContext.(*dbosContext).logger)) if result != nil { typedResult, ok := result.(R) @@ -240,8 +290,8 @@ type workflowHandleProxy[R any] struct { wrappedHandle WorkflowHandle[any] } -func (h *workflowHandleProxy[R]) GetResult() (R, error) { - result, err := h.wrappedHandle.GetResult() +func (h *workflowHandleProxy[R]) GetResult(opts ...GetResultOption) (R, error) { + result, err := h.wrappedHandle.GetResult(opts...) 
if err != nil { var zero R return zero, err diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index 56b8cf8..8756efd 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -34,6 +34,11 @@ func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) { }) } +func slowWorkflow(dbosCtx DBOSContext, sleepTime time.Duration) (string, error) { + Sleep(dbosCtx, sleepTime) + return "done", nil +} + func simpleStep(_ context.Context) (string, error) { return "from step", nil } @@ -4523,3 +4528,71 @@ func TestWorkflowIdentity(t *testing.T) { assert.Equal(t, []string{"reader", "writer"}, status.AuthenticatedRoles) }) } + +func TestWorkflowHandleTimeout(t *testing.T) { + dbosCtx := setupDBOS(t, true, true) + RegisterWorkflow(dbosCtx, slowWorkflow) + + t.Run("WorkflowHandleTimeout", func(t *testing.T) { + handle, err := RunWorkflow(dbosCtx, slowWorkflow, 10*time.Second) + require.NoError(t, err, "failed to start workflow") + + start := time.Now() + _, err = handle.GetResult(WithHandleTimeout(10 * time.Millisecond)) + duration := time.Since(start) + + require.Error(t, err, "expected timeout error") + assert.Contains(t, err.Error(), "workflow result timeout") + assert.True(t, duration < 100*time.Millisecond, "timeout should occur quickly") + assert.True(t, errors.Is(err, context.DeadlineExceeded), + "expected error to be detectable as context.DeadlineExceeded, got: %v", err) + }) + + t.Run("WorkflowPollingHandleTimeout", func(t *testing.T) { + // Start a workflow that sleeps far longer than the GetResult timeout used below + originalHandle, err := RunWorkflow(dbosCtx, slowWorkflow, 10*time.Second) + require.NoError(t, err, "failed to start workflow") + + pollingHandle, err := RetrieveWorkflow[string](dbosCtx, originalHandle.GetWorkflowID()) + require.NoError(t, err, "failed to retrieve workflow") + + _, ok := pollingHandle.(*workflowPollingHandle[string]) + require.True(t, ok, "expected polling handle, got %T", pollingHandle) + + _, err = 
pollingHandle.GetResult(WithHandleTimeout(10 * time.Millisecond)) + + require.Error(t, err, "expected timeout error") + assert.True(t, errors.Is(err, context.DeadlineExceeded), + "expected error to be detectable as context.DeadlineExceeded, got: %v", err) + }) +} + +func TestWorkflowHandleContextCancel(t *testing.T) { + dbosCtx := setupDBOS(t, true, true) + RegisterWorkflow(dbosCtx, getEventWorkflow) + + t.Run("WorkflowHandleContextCancel", func(t *testing.T) { + getEventWorkflowStartedSignal.Clear() + handle, err := RunWorkflow(dbosCtx, getEventWorkflow, getEventWorkflowInput{ + TargetWorkflowID: "test-workflow-id", + Key: "test-key", + }) + require.NoError(t, err, "failed to start workflow") + + resultChan := make(chan error) + go func() { + _, err := handle.GetResult() + resultChan <- err + }() + + getEventWorkflowStartedSignal.Wait() + getEventWorkflowStartedSignal.Clear() + + dbosCtx.Shutdown(1 * time.Second) + + err = <-resultChan + require.Error(t, err, "expected error from cancelled context") + assert.True(t, errors.Is(err, context.Canceled), + "expected error to be detectable as context.Canceled, got: %v", err) + }) +}