From dc4b5d7adc34cb9aae9090bacb91f0b83f29befc Mon Sep 17 00:00:00 2001 From: Peter Turi Date: Fri, 27 Mar 2026 16:59:22 +0100 Subject: [PATCH 1/2] feat: subs-sync credits support --- app/config/billing.go | 10 + openmeter/billing/gatheringinvoice.go | 4 + openmeter/billing/invoiceline.go | 1 + openmeter/billing/stdinvoiceline.go | 9 + .../service/persistedstate/entity.go | 85 ++++++++ .../service/persistedstate/loader.go | 14 +- .../service/persistedstate/state.go | 2 +- .../subscriptionsync/service/reconcile.go | 4 +- .../service/reconciler/patch.go | 9 + .../service/reconciler/patchcreate.go | 39 +++- .../service/reconciler/patchdelete.go | 31 ++- .../service/reconciler/patchextend.go | 54 ++++- .../service/reconciler/patchhelpers.go | 30 +++ .../service/reconciler/patchshrink.go | 53 ++++- .../service/reconciler/prorating.go | 0 .../service/reconciler/reconciler.go | 193 +++++++----------- .../subscriptionsync/service/service.go | 16 +- .../worker/subscriptionsync/service/sync.go | 4 +- .../service/targetstate/targetstate.go | 4 +- 19 files changed, 407 insertions(+), 155 deletions(-) create mode 100644 openmeter/billing/worker/subscriptionsync/service/persistedstate/entity.go create mode 100644 openmeter/billing/worker/subscriptionsync/service/reconciler/prorating.go diff --git a/app/config/billing.go b/app/config/billing.go index 008e5feaf8..92099634b3 100644 --- a/app/config/billing.go +++ b/app/config/billing.go @@ -14,6 +14,7 @@ type BillingConfiguration struct { MaxParallelQuantitySnapshots int Worker BillingWorkerConfiguration FeatureSwitches BillingFeatureSwitchesConfiguration + Charges BillingChargesConfiguration } func (c BillingConfiguration) Validate() error { @@ -33,6 +34,14 @@ func (c BillingConfiguration) Validate() error { return errors.Join(errs...) } +type BillingChargesConfiguration struct { + Enabled bool +} + +func (c BillingChargesConfiguration) Validate() error { + return nil +} + type BillingFeatureSwitchesConfiguration struct { NamespaceLockdown []string } @@ -50,4 +59,5 @@ func ConfigureBilling(v *viper.Viper, flags *pflag.FlagSet) { _ = v.BindPFlag("billing.advancementStrategy", flags.Lookup("billing-advancement-strategy")) v.SetDefault("billing.advancementStrategy", billing.ForegroundAdvancementStrategy) v.SetDefault("billing.maxParallelQuantitySnapshots", 4) + v.SetDefault("billing.charges.enabled", false) } diff --git a/openmeter/billing/gatheringinvoice.go b/openmeter/billing/gatheringinvoice.go index 00fdb9ab7d..7774584f63 100644 --- a/openmeter/billing/gatheringinvoice.go +++ b/openmeter/billing/gatheringinvoice.go @@ -650,6 +650,10 @@ func (g GatheringLine) AsInvoiceLine() InvoiceLine { } } +func (g GatheringLine) AsLineOrHierarchy() (LineOrHierarchy, error) { + return NewLineOrHierarchy(g), nil +} + func (g GatheringLine) Equal(other GatheringLine) bool { return g.GatheringLineBase.Equal(other.GatheringLineBase) } diff --git a/openmeter/billing/invoiceline.go b/openmeter/billing/invoiceline.go index c2f3fc8a66..f3b850ee9c 100644 --- a/openmeter/billing/invoiceline.go +++ b/openmeter/billing/invoiceline.go @@ -141,6 +141,7 @@ type GenericInvoiceLineReader interface { Validate() error AsInvoiceLine() InvoiceLine + AsLineOrHierarchy() (LineOrHierarchy, error) GetRateCardDiscounts() Discounts GetSubscriptionReference() *SubscriptionReference GetSplitLineGroupID() *string diff --git a/openmeter/billing/stdinvoiceline.go b/openmeter/billing/stdinvoiceline.go index e95c2a6045..1d6d14334f 100644 --- a/openmeter/billing/stdinvoiceline.go +++ b/openmeter/billing/stdinvoiceline.go @@ -296,6 +296,15 @@ func (i StandardLine) AsInvoiceLine() InvoiceLine { } } +func (i StandardLine) AsLineOrHierarchy() (LineOrHierarchy, error) { + cloned, err := i.Clone() + if err != nil { + return LineOrHierarchy{}, err + } + + return NewLineOrHierarchy(cloned), nil +} + func (i StandardLine) GetQuantity() *alpacadecimal.Decimal { if i.UsageBased == nil { return nil diff --git a/openmeter/billing/worker/subscriptionsync/service/persistedstate/entity.go b/openmeter/billing/worker/subscriptionsync/service/persistedstate/entity.go new file mode 100644 index 0000000000..099b10b5f5 --- /dev/null +++ b/openmeter/billing/worker/subscriptionsync/service/persistedstate/entity.go @@ -0,0 +1,85 @@ +package persistedstate + +import ( + "errors" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges" + "github.com/openmeterio/openmeter/pkg/timeutil" +) + +type EntityType string + +const ( + EntityTypeLineOrHierarchy EntityType = "line_or_hierarchy" + EntityTypeCharge EntityType = "charge" +) + +var ErrEntityTypeMismatch = errors.New("entity type mismatch") + +type Entity interface { + IsFlatFee() bool + + GetServicePeriod() timeutil.ClosedPeriod + GetChildUniqueReferenceID() *string + GetType() EntityType + AsLineOrHierarchy() (billing.LineOrHierarchy, error) + AsCharge() (charges.Charge, error) +} + +// Implementations + +// Line +var _ Entity = (*LineEntity)(nil) + +type LineEntity struct { + line billing.GenericInvoiceLine +} + +func (e LineEntity) GetType() EntityType { + return EntityTypeLineOrHierarchy +} + +func (e LineEntity) AsLineOrHierarchy() (billing.LineOrHierarchy, error) { + return e.line.AsLineOrHierarchy() +} + +// Hierarchy + +var _ Entity = (*HierarchyEntity)(nil) + +type HierarchyEntity struct { + hierarchy *billing.SplitLineHierarchy +} + +func (e HierarchyEntity) GetType() EntityType { + return EntityTypeLineOrHierarchy +} + +func (e HierarchyEntity) AsLineOrHierarchy() (billing.LineOrHierarchy, error) { + return billing.NewLineOrHierarchy(e.hierarchy), nil +} + +// Charge + +var _ Entity = (*ChargeEntity)(nil) + +type ChargeEntity struct { + charge charges.Charge +} + +func (e ChargeEntity) GetType() EntityType { + return EntityTypeCharge +} + +func (e ChargeEntity) GetServicePeriod() timeutil.ClosedPeriod { + return e.charge.GetServicePeriod() +} + +func (e ChargeEntity) GetChildUniqueReferenceID() *string { + return e.charge.GetChildUniqueReferenceID() +} + +func (e ChargeEntity) AsLineOrHierarchy() (billing.LineOrHierarchy, error) { + return billing.LineOrHierarchy{}, ErrEntityTypeMismatch +} diff --git a/openmeter/billing/worker/subscriptionsync/service/persistedstate/loader.go b/openmeter/billing/worker/subscriptionsync/service/persistedstate/loader.go index 7e9f9baf2c..9b520fd342 100644 --- a/openmeter/billing/worker/subscriptionsync/service/persistedstate/loader.go +++ b/openmeter/billing/worker/subscriptionsync/service/persistedstate/loader.go @@ -7,6 +7,7 @@ import ( "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges" "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/openmeter/subscription" @@ -21,15 +22,26 @@ type billingService interface { type Loader struct { billingService billingService + chargesService charges.Service } -func NewLoader(billingService billingService) Loader { +func NewLoader(billingService billingService, chargesService charges.Service) Loader { return Loader{ billingService: billingService, + chargesService: chargesService, } } func (l Loader) LoadForSubscription(ctx context.Context, subs subscription.Subscription) (State, error) { + lines, err := l.loadLinesForSubscription(ctx, subs) + if err != nil { + return State{}, fmt.Errorf("loading lines for subscription: %w", err) + } + + return lines, nil +} + +func (l Loader) loadLinesForSubscription(ctx context.Context, subs subscription.Subscription) (State, error) { lines, err := l.billingService.GetLinesForSubscription(ctx, billing.GetLinesForSubscriptionInput{ Namespace: subs.Namespace, SubscriptionID: subs.ID, diff --git a/openmeter/billing/worker/subscriptionsync/service/persistedstate/state.go b/openmeter/billing/worker/subscriptionsync/service/persistedstate/state.go index 4fd729ec83..fb569308c8 100644 --- a/openmeter/billing/worker/subscriptionsync/service/persistedstate/state.go +++ b/openmeter/billing/worker/subscriptionsync/service/persistedstate/state.go @@ -8,7 +8,7 @@ import ( type State struct { Lines []billing.LineOrHierarchy - ByUniqueID map[string]billing.LineOrHierarchy + ByUniqueID map[string]Entity } func (s State) Validate() error { diff --git a/openmeter/billing/worker/subscriptionsync/service/reconcile.go b/openmeter/billing/worker/subscriptionsync/service/reconcile.go index 3cb79b5cd4..70fcf7abc3 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconcile.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconcile.go @@ -4,7 +4,6 @@ import ( "context" "time" - "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/persistedstate" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/reconciler" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/targetstate" "github.com/openmeterio/openmeter/openmeter/subscription" @@ -16,8 +15,7 @@ func (s *Service) buildSyncPlan(ctx context.Context, subsView subscription.Subsc span := tracex.Start[*reconciler.Plan](ctx, s.tracer, "billing.worker.subscription.sync.buildSyncPlan") return span.Wrap(func(ctx context.Context) (*reconciler.Plan, error) { - persistedLoader := persistedstate.NewLoader(s.billingService) - persisted, err := persistedLoader.LoadForSubscription(ctx, subsView.Subscription) + persisted, err := s.persistedStateLoader.LoadForSubscription(ctx, subsView.Subscription) if err != nil { return nil, err } diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/patch.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/patch.go index 06ce225e77..96b4304b08 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconciler/patch.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconciler/patch.go @@ -26,5 +26,14 @@ type GetInvoicePatchesInput struct { type Patch interface { Operation() PatchOperation UniqueReferenceID() string +} + +type InvoicePatch interface { + Patch GetInvoicePatches(input GetInvoicePatchesInput) ([]invoiceupdater.Patch, error) } + +type ChargePatch interface { + Patch + TodoPatchMechanism() +} diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchcreate.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchcreate.go index 9cc5eb5c7f..4d696914c2 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchcreate.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchcreate.go @@ -1,26 +1,57 @@ package reconciler import ( + "errors" "fmt" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/reconciler/invoiceupdater" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/targetstate" ) -type CreatePatch struct { +type NewCreatePatchInput struct { UniqueID string Target targetstate.SubscriptionItemWithPeriods } -func (p CreatePatch) Operation() PatchOperation { +func (i NewCreatePatchInput) Validate() error { + var errs []error + if i.UniqueID == "" { + errs = append(errs, errors.New("unique id is required")) + } + + if err := i.Target.Validate(); err != nil { + errs = append(errs, fmt.Errorf("target: %w", err)) + } + + return errors.Join(errs...) +} + +func (s *Service) NewCreatePatch(input NewCreatePatchInput) (Patch, error) { + if err := input.Validate(); err != nil { + return nil, fmt.Errorf("new create patch: %w", err) + } + + // TODO: use the service's field to decide if it should create a line or charge + return LineCreatePatch{ + UniqueID: input.UniqueID, + Target: input.Target, + }, nil +} + +type LineCreatePatch struct { + UniqueID string + Target targetstate.SubscriptionItemWithPeriods +} + +func (p LineCreatePatch) Operation() PatchOperation { return PatchOperationCreate } -func (p CreatePatch) UniqueReferenceID() string { +func (p LineCreatePatch) UniqueReferenceID() string { return p.UniqueID } -func (p CreatePatch) GetInvoicePatches(input GetInvoicePatchesInput) ([]invoiceupdater.Patch, error) { +func (p LineCreatePatch) GetInvoicePatches(input GetInvoicePatchesInput) ([]invoiceupdater.Patch, error) { line, err := p.Target.GetExpectedLine(input.Subscription, input.Currency) if err != nil { return nil, fmt.Errorf("generating line from subscription item [%s]: %w", p.Target.SubscriptionItem.ID, err) diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchdelete.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchdelete.go index b49eb78a73..9385cce23b 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchdelete.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchdelete.go @@ -1,23 +1,46 @@ package reconciler import ( + "errors" + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/persistedstate" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/reconciler/invoiceupdater" + "github.com/samber/lo" ) type DeletePatch struct { UniqueID string + Existing persistedstate.Entity +} + +func (s *Service) NewDeletePatch(existing persistedstate.Entity) (Patch, error) { + if existing == nil { + return nil, errors.New("new delete patch: existing entity is required") + } + + return newFromEntity(newFromEntityInput{ + Entity: existing, + NewInvoicePatch: func(lineOrHierarchy billing.LineOrHierarchy) (Patch, error) { + return LineDeletePatch{ + Existing: lineOrHierarchy, + }, nil + }, + }) +} + +type LineDeletePatch struct { Existing billing.LineOrHierarchy } -func (p DeletePatch) Operation() PatchOperation { +func (p LineDeletePatch) Operation() PatchOperation { return PatchOperationDelete } -func (p DeletePatch) UniqueReferenceID() string { - return p.UniqueID +func (p LineDeletePatch) UniqueReferenceID() string { + return lo.FromPtr(p.Existing.ChildUniqueReferenceID()) } -func (p DeletePatch) GetInvoicePatches(input GetInvoicePatchesInput) ([]invoiceupdater.Patch, error) { +func (p LineDeletePatch) GetInvoicePatches(input GetInvoicePatchesInput) ([]invoiceupdater.Patch, error) { return invoiceupdater.GetDeletePatchesForLine(p.Existing) } diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchextend.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchextend.go index 5c1eb553ef..a9b9196d47 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchextend.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchextend.go @@ -1,6 +1,7 @@ package reconciler import ( + "errors" "fmt" "slices" @@ -9,23 +10,60 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/reconciler/invoiceupdater" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/targetstate" "github.com/openmeterio/openmeter/pkg/timeutil" + "github.com/samber/lo" ) -type ExtendUsageBasedPatch struct { - UniqueID string +type NewLineExtendUsageBasedPatchInput struct { + Existing persistedstate.Entity + Target targetstate.SubscriptionItemWithPeriods +} + +func (i NewLineExtendUsageBasedPatchInput) Validate() error { + var errs []error + + if i.Existing == nil { + errs = append(errs, errors.New("existing is required")) + } + + if err := i.Target.Validate(); err != nil { + errs = append(errs, fmt.Errorf("target: %w", err)) + } + + return errors.Join(errs...) +} + +func (s *Service) NewLineExtendUsageBasedPatch(input NewLineExtendUsageBasedPatchInput) (Patch, error) { + if err := input.Validate(); err != nil { + return nil, fmt.Errorf("new line extend usage based patch: %w", err) + } + + return newFromEntity(newFromEntityInput{ + Entity: input.Existing, + NewInvoicePatch: func(lineOrHierarchy billing.LineOrHierarchy) (Patch, error) { + return LineExtendUsageBasedPatch{ + Existing: lineOrHierarchy, + Target: input.Target, + }, nil + }, + }) +} + +var _ InvoicePatch = (*LineExtendUsageBasedPatch)(nil) + +type LineExtendUsageBasedPatch struct { Existing billing.LineOrHierarchy Target targetstate.SubscriptionItemWithPeriods } -func (p ExtendUsageBasedPatch) Operation() PatchOperation { +func (p LineExtendUsageBasedPatch) Operation() PatchOperation { return PatchOperationExtend } -func (p ExtendUsageBasedPatch) UniqueReferenceID() string { - return p.UniqueID +func (p LineExtendUsageBasedPatch) UniqueReferenceID() string { + return lo.FromPtr(p.Existing.ChildUniqueReferenceID()) } -func (p ExtendUsageBasedPatch) GetInvoicePatches(input GetInvoicePatchesInput) ([]invoiceupdater.Patch, error) { +func (p LineExtendUsageBasedPatch) GetInvoicePatches(input GetInvoicePatchesInput) ([]invoiceupdater.Patch, error) { expectedLine, err := p.Target.GetExpectedLineOrErr(input.Subscription, input.Currency) if err != nil { return nil, err @@ -51,7 +89,7 @@ func (p ExtendUsageBasedPatch) GetInvoicePatches(input GetInvoicePatchesInput) ( } } -func (p ExtendUsageBasedPatch) getInvoicePatchesForLine(existingLine billing.GenericInvoiceLine, expectedLine billing.GatheringLine, invoices persistedstate.Invoices) ([]invoiceupdater.Patch, error) { +func (p LineExtendUsageBasedPatch) getInvoicePatchesForLine(existingLine billing.GenericInvoiceLine, expectedLine billing.GatheringLine, invoices persistedstate.Invoices) ([]invoiceupdater.Patch, error) { if shouldSkipLinePatch(existingLine, expectedLine) { return nil, nil } @@ -67,7 +105,7 @@ func (p ExtendUsageBasedPatch) getInvoicePatchesForLine(existingLine billing.Gen return getPatchesForUpdateUsageBasedLine(existingLine, expectedLine, invoices) } -func (p ExtendUsageBasedPatch) getInvoicePatchesForHierarchy(existingHierarchy *billing.SplitLineHierarchy, expectedLine billing.GatheringLine, invoices persistedstate.Invoices) ([]invoiceupdater.Patch, error) { +func (p LineExtendUsageBasedPatch) getInvoicePatchesForHierarchy(existingHierarchy *billing.SplitLineHierarchy, expectedLine billing.GatheringLine, invoices persistedstate.Invoices) ([]invoiceupdater.Patch, error) { if shouldSkipHierarchyPatch(existingHierarchy, expectedLine) { return nil, nil } diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchhelpers.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchhelpers.go index a10a9ba312..68d7dc4289 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchhelpers.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchhelpers.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/persistedstate" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/reconciler/invoiceupdater" "github.com/openmeterio/openmeter/openmeter/streaming" @@ -93,3 +94,32 @@ func getPatchesForUpdateUsageBasedLine(existingLine billing.GenericInvoiceLine, invoiceupdater.NewUpdateLinePatch(targetLine), }, nil } + +type newFromEntityInput struct { + Entity persistedstate.Entity + NewInvoicePatch func(billing.LineOrHierarchy) (Patch, error) + NewChargePatch func(charges.Charge) (Patch, error) +} + +func newFromEntity(input newFromEntityInput) (Patch, error) { + switch input.Entity.GetType() { + case persistedstate.EntityTypeLineOrHierarchy: + lineOrHierarchy, err := input.Entity.AsLineOrHierarchy() + if err != nil { + return nil, fmt.Errorf("getting line or hierarchy: %w", err) + } + if input.NewInvoicePatch == nil { + return nil, fmt.Errorf("invoice patching is not supported") + } + + return input.NewInvoicePatch(lineOrHierarchy) + case persistedstate.EntityTypeCharge: + if input.NewChargePatch == nil { + return nil, fmt.Errorf("charge patching is not supported") + } + + return nil, fmt.Errorf("charge patching is not supported") + default: + return nil, fmt.Errorf("unsupported entity type: %s", input.Entity.GetType()) + } +} diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchshrink.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchshrink.go index 6a4719ecfb..5b515c4e5f 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconciler/patchshrink.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconciler/patchshrink.go @@ -1,6 +1,7 @@ package reconciler import ( + "errors" "fmt" "github.com/openmeterio/openmeter/openmeter/billing" @@ -9,23 +10,59 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/targetstate" "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/pkg/timeutil" + "github.com/samber/lo" ) -type ShrinkUsageBasedPatch struct { - UniqueID string +type NewLineShrinkUsageBasedPatchInput struct { + Existing persistedstate.Entity + Target targetstate.SubscriptionItemWithPeriods +} + +func (i NewLineShrinkUsageBasedPatchInput) Validate() error { + var errs []error + if i.Existing == nil { + errs = append(errs, errors.New("existing is required")) + } + + if err := i.Target.Validate(); err != nil { + errs = append(errs, fmt.Errorf("target: %w", err)) + } + + return errors.Join(errs...) +} + +func (s *Service) NewLineShrinkUsageBasedPatch(input NewLineShrinkUsageBasedPatchInput) (Patch, error) { + if err := input.Validate(); err != nil { + return nil, fmt.Errorf("new line shrink usage based patch: %w", err) + } + + return newFromEntity(newFromEntityInput{ + Entity: input.Existing, + NewInvoicePatch: func(lineOrHierarchy billing.LineOrHierarchy) (Patch, error) { + return LineShrinkUsageBasedPatch{ + Existing: lineOrHierarchy, + Target: input.Target, + }, nil + }, + }) +} + +var _ InvoicePatch = (*LineShrinkUsageBasedPatch)(nil) + +type LineShrinkUsageBasedPatch struct { Existing billing.LineOrHierarchy Target targetstate.SubscriptionItemWithPeriods } -func (p ShrinkUsageBasedPatch) Operation() PatchOperation { +func (p LineShrinkUsageBasedPatch) Operation() PatchOperation { return PatchOperationShrink } -func (p ShrinkUsageBasedPatch) UniqueReferenceID() string { - return p.UniqueID +func (p LineShrinkUsageBasedPatch) UniqueReferenceID() string { + return lo.FromPtr(p.Existing.ChildUniqueReferenceID()) } -func (p ShrinkUsageBasedPatch) GetInvoicePatches(input GetInvoicePatchesInput) ([]invoiceupdater.Patch, error) { +func (p LineShrinkUsageBasedPatch) GetInvoicePatches(input GetInvoicePatchesInput) ([]invoiceupdater.Patch, error) { expectedLine, err := p.Target.GetExpectedLineOrErr(input.Subscription, input.Currency) if err != nil { return nil, err @@ -51,7 +88,7 @@ func (p ShrinkUsageBasedPatch) GetInvoicePatches(input GetInvoicePatchesInput) ( } } -func (p ShrinkUsageBasedPatch) getInvoicePatchesForLine(existingLine billing.GenericInvoiceLine, expectedLine billing.GatheringLine, invoices persistedstate.Invoices) ([]invoiceupdater.Patch, error) { +func (p LineShrinkUsageBasedPatch) getInvoicePatchesForLine(existingLine billing.GenericInvoiceLine, expectedLine billing.GatheringLine, invoices persistedstate.Invoices) ([]invoiceupdater.Patch, error) { if shouldSkipLinePatch(existingLine, expectedLine) { return nil, nil } @@ -67,7 +104,7 @@ func (p ShrinkUsageBasedPatch) getInvoicePatchesForLine(existingLine billing.Gen return getPatchesForUpdateUsageBasedLine(existingLine, expectedLine, invoices) } -func (p ShrinkUsageBasedPatch) getInvoicePatchesForHierarchy(existingHierarchy *billing.SplitLineHierarchy, expectedLine billing.GatheringLine, invoices persistedstate.Invoices) ([]invoiceupdater.Patch, error) { +func (p LineShrinkUsageBasedPatch) getInvoicePatchesForHierarchy(existingHierarchy *billing.SplitLineHierarchy, expectedLine billing.GatheringLine, invoices persistedstate.Invoices) ([]invoiceupdater.Patch, error) { if shouldSkipHierarchyPatch(existingHierarchy, expectedLine) { return nil, nil } diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/prorating.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/prorating.go new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/reconciler.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/reconciler.go index 5fdcb0e0f3..09ee84c405 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconciler/reconciler.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconciler/reconciler.go @@ -8,7 +8,6 @@ import ( "slices" "time" - "github.com/alpacahq/alpacadecimal" "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/billing" @@ -113,133 +112,59 @@ type diffItemResult struct { func (s *Service) diffItem( target *targetstate.SubscriptionItemWithPeriods, expectedLine *billing.GatheringLine, // TODO[later]: let's merge this with target as they are the same thing's different calculation stages - existing *billing.LineOrHierarchy, -) (diffItemResult, error) { + existing persistedstate.Entity, +) (Patch, error) { switch { case target == nil && existing == nil: - return diffItemResult{}, nil + return nil, nil case target == nil && existing != nil: - uniqueID := lo.FromPtr(existing.ChildUniqueReferenceID()) - - return diffItemResult{ - Patch: DeletePatch{ - UniqueID: uniqueID, - Existing: *existing, - }, - Changed: true, - }, nil + return s.NewDeletePatch(existing) case target != nil && existing == nil && expectedLine != nil: - return diffItemResult{ - Patch: CreatePatch{ - UniqueID: target.UniqueID, - Target: *target, - }, - Changed: true, - }, nil + return s.NewCreatePatch(NewCreatePatchInput{ + UniqueID: target.UniqueID, + Target: *target, + }) case target != nil && existing == nil && expectedLine == nil: // If the target is not nil, but the expected line is nil, we should not create a patch (most probably // because the line is ignored or empty service period) - return diffItemResult{}, nil + return nil, nil case target != nil && existing != nil && expectedLine == nil: - return diffItemResult{ - Patch: DeletePatch{ - UniqueID: target.UniqueID, - Existing: *existing, - }, - Changed: true, - }, nil + return s.NewDeletePatch(existing) } - existingPeriod := existing.ServicePeriod() + existingPeriod := existing.GetServicePeriod() targetPeriod := expectedLine.ServicePeriod - if decision, err := semanticProrateDecision(*existing, *expectedLine); err != nil { - return diffItemResult{}, err + if decision, err := semanticProrateDecision(existing, *expectedLine); err != nil { + return nil, err } else if decision.ShouldProrate { // Flat fee lines do not produce usage-based shrink/extend patches. Any period // change for a flat fee line is reconciled through ProratePatch so that the // service period and per-unit amount are updated together. - return diffItemResult{ - Patch: ProratePatch{ - UniqueID: target.UniqueID, - Existing: *existing, - Target: *target, - OriginalPeriod: existingPeriod, - TargetPeriod: targetPeriod, - OriginalAmount: decision.OriginalAmount, - TargetAmount: decision.TargetAmount, - }, - Changed: true, + return ProratePatch{ + UniqueID: target.UniqueID, + Existing: existing, + Target: *target, + OriginalPeriod: existingPeriod, + TargetPeriod: targetPeriod, + OriginalAmount: decision.OriginalAmount, + TargetAmount: decision.TargetAmount, }, nil } switch { case targetPeriod.To.Before(existingPeriod.To): - return diffItemResult{ - Patch: ShrinkUsageBasedPatch{ - UniqueID: target.UniqueID, - Existing: *existing, - Target: *target, - }, - Changed: true, - }, nil + return s.NewLineShrinkUsageBasedPatch(NewLineShrinkUsageBasedPatchInput{ + Existing: existing, + Target: *target, + }) case targetPeriod.To.After(existingPeriod.To): - return diffItemResult{ - Patch: ExtendUsageBasedPatch{ - UniqueID: target.UniqueID, - Existing: *existing, - Target: *target, - }, - Changed: true, - }, nil - default: - return diffItemResult{}, nil - } -} - -type ProrateDecision struct { - ShouldProrate bool - OriginalAmount alpacadecimal.Decimal - TargetAmount alpacadecimal.Decimal -} - -func semanticProrateDecision(existing billing.LineOrHierarchy, expectedLine billing.GatheringLine) (ProrateDecision, error) { - if !invoiceupdater.IsFlatFee(expectedLine) { - return ProrateDecision{}, nil - } - - // expectedLine is materialized through targetstate.LineFromSubscriptionRateCard, which - // applies the existing subscription-sync proration rules when deriving the flat-fee amount. - targetAmount, err := invoiceupdater.GetFlatFeePerUnitAmount(expectedLine) - if err != nil { - return ProrateDecision{}, fmt.Errorf("getting expected flat fee amount: %w", err) - } - - switch existing.Type() { - case billing.LineOrHierarchyTypeLine: - existingLine, err := existing.AsGenericLine() - if err != nil { - return ProrateDecision{}, fmt.Errorf("getting existing line: %w", err) - } - - if !invoiceupdater.IsFlatFee(existingLine) { - return ProrateDecision{}, nil - } - - existingAmount, err := invoiceupdater.GetFlatFeePerUnitAmount(existingLine) - if err != nil { - return ProrateDecision{}, fmt.Errorf("getting existing flat fee amount: %w", err) - } - - return ProrateDecision{ - ShouldProrate: !existingAmount.Equal(targetAmount) || !existingLine.GetServicePeriod().Equal(expectedLine.ServicePeriod), - OriginalAmount: existingAmount, - TargetAmount: targetAmount, - }, nil - case billing.LineOrHierarchyTypeHierarchy: - return ProrateDecision{}, errors.New("flat fee lines cannot be reconciled against a split line hierarchy") + return s.NewLineExtendUsageBasedPatch(NewLineExtendUsageBasedPatchInput{ + Existing: existing, + Target: *target, + }) default: - return ProrateDecision{}, fmt.Errorf("unsupported line or hierarchy type: %s", existing.Type()) + return nil, nil } } @@ -276,12 +201,12 @@ func (s *Service) Plan(ctx context.Context, input PlanInput) (*Plan, error) { return nil, fmt.Errorf("existing line[%s] not found in the existing lines", id) } - diff, err := s.diffItem(nil, nil, &line) + patch, err := s.diffItem(nil, nil, line) if err != nil { return nil, fmt.Errorf("diffing deleted line[%s]: %w", id, err) } - if diff.Changed { - patches = append(patches, diff.Patch) + if patch != nil { + patches = append(patches, patch) } } @@ -294,33 +219,58 @@ func (s *Service) Plan(ctx context.Context, input PlanInput) (*Plan, error) { existingLine, ok := persisted.ByUniqueID[id] if !ok { - diff, err := s.diffItem(&targetLine, expectedLine, nil) + patch, err := s.diffItem(&targetLine, expectedLine, nil) if err != nil { return nil, fmt.Errorf("diffing new line[%s]: %w", id, err) } - if diff.Changed { - patches = append(patches, diff.Patch) + if patch != nil { + patches = append(patches, patch) } continue } - diff, err := s.diffItem(&targetLine, expectedLine, &existingLine) + patch, err := s.diffItem(&targetLine, expectedLine, existingLine) if err != nil { return nil, fmt.Errorf("diffing existing line[%s]: %w", id, err) } - if diff.Changed { - patches = append(patches, diff.Patch) + if patch != nil { + patches = append(patches, patch) } } + filteredPatches := lo.Filter(patches, func(p Patch, _ int) bool { + return p != nil + }) + + if err := s.validatePatches(filteredPatches); err != nil { + return nil, fmt.Errorf("validating patches: %w", err) + } + return &Plan{ - Patches: lo.Filter(patches, func(p Patch, _ int) bool { - return p != nil - }), + Patches: filteredPatches, SubscriptionMaxGenerationTimeLimit: input.Target.MaxGenerationTimeLimit, }, nil } +func (s *Service) validatePatches(patches []Patch) error { + for _, patch := range patches { + if patch == nil { + return fmt.Errorf("patch is nil") + } + + // Let's mandate that all patches are implementing either InvoicePatch or ChargePatch. + _, isInvoicePatch := patch.(InvoicePatch) + _, isChargePatch := patch.(ChargePatch) + + // TODO: let's decide later if we want to support mixed invoice and charge patches. + + if !isInvoicePatch && !isChargePatch { + return fmt.Errorf("patch is not an invoice or charge patch: %T", patch) + } + } + return nil +} + func (s *Service) Apply(ctx context.Context, input ApplyInput) error { if err := input.Validate(); err != nil { return fmt.Errorf("validating input: %w", err) @@ -330,10 +280,17 @@ func (s *Service) Apply(ctx context.Context, input ApplyInput) error { return nil } + // TODO: Let's validate that patches are either invoice or charge patches. + invoicePatches := make([]invoiceupdater.Patch, 0, len(input.Plan.Patches)) for _, patch := range input.Plan.Patches { - newInvoicePatches, err := patch.GetInvoicePatches(GetInvoicePatchesInput{ + invoicePatch, ok := patch.(InvoicePatch) + if !ok { + continue + } + + newInvoicePatches, err := invoicePatch.GetInvoicePatches(GetInvoicePatchesInput{ Subscription: input.Subscription, Currency: input.Currency, Invoices: input.Invoices, diff --git a/openmeter/billing/worker/subscriptionsync/service/service.go b/openmeter/billing/worker/subscriptionsync/service/service.go index 29a8835d3f..0f6f86b039 100644 --- a/openmeter/billing/worker/subscriptionsync/service/service.go +++ b/openmeter/billing/worker/subscriptionsync/service/service.go @@ -8,7 +8,9 @@ import ( "go.opentelemetry.io/otel/trace" "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync" + "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/persistedstate" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/reconciler" "github.com/openmeterio/openmeter/openmeter/subscription" "github.com/openmeterio/openmeter/pkg/framework/transaction" @@ -23,9 +25,13 @@ type Config struct { BillingService billing.Service SubscriptionService subscription.Service SubscriptionSyncAdapter subscriptionsync.Adapter - FeatureFlags FeatureFlags - Logger *slog.Logger - Tracer trace.Tracer + // ChargesService is optional for now, signaling that charges are not enabled. + // going forward, we will have more settings to configure the scope that should be deferred to + // the charges service. + ChargesService charges.Service + FeatureFlags FeatureFlags + Logger *slog.Logger + Tracer trace.Tracer } func (c Config) Validate() error { @@ -62,6 +68,8 @@ type Service struct { featureFlags FeatureFlags logger *slog.Logger tracer trace.Tracer + chargesService charges.Service + persistedStateLoader persistedstate.Loader } func New(config Config) (*Service, error) { @@ -75,9 +83,11 @@ func New(config Config) (*Service, error) { if err != nil { return nil, err } + return &Service{ billingService: config.BillingService, reconciler: reconcilerSvc, + persistedStateLoader: persistedstate.NewLoader(config.BillingService, config.ChargesService), subscriptionSyncAdapter: config.SubscriptionSyncAdapter, featureFlags: config.FeatureFlags, subscriptionService: config.SubscriptionService, diff --git a/openmeter/billing/worker/subscriptionsync/service/sync.go b/openmeter/billing/worker/subscriptionsync/service/sync.go index ab405f2650..710013a335 100644 --- a/openmeter/billing/worker/subscriptionsync/service/sync.go +++ b/openmeter/billing/worker/subscriptionsync/service/sync.go @@ -12,7 +12,6 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync" - "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/persistedstate" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/reconciler" "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/subscription" @@ -160,8 +159,7 @@ func (s *Service) SynchronizeSubscription(ctx context.Context, subs subscription Namespace: subs.Subscription.Namespace, ID: subs.Subscription.CustomerId, }, func(ctx context.Context) error { - persistedLoader := persistedstate.NewLoader(s.billingService) - persistedInvoices, err := persistedLoader.LoadInvoicesForCustomer(ctx, customerID) + persistedInvoices, err := s.persistedStateLoader.LoadInvoicesForCustomer(ctx, customerID) if err != nil { return err } diff --git a/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstate.go b/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstate.go index 1d23f39bdd..60fb83592a 100644 --- a/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstate.go +++ b/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstate.go @@ -191,7 +191,7 @@ func (b Builder) correctPeriodStartForUpcomingLines(ctx context.Context, subscri continue } - if existingCurrentLine, ok := persisted.ByUniqueID[line.UniqueID]; ok { + if existingCurrentLine, ok := persisted.ByUniqueID.GetAsLineOrHierarchy(line.UniqueID); ok { syncIgnore, err := b.lineOrHierarchyHasAnnotation(existingCurrentLine, billing.AnnotationSubscriptionSyncIgnore) if err != nil { return nil, fmt.Errorf("checking if line has subscription sync ignore annotation: %w", err) @@ -233,7 +233,7 @@ func (b Builder) correctPeriodStartForUpcomingLines(ctx context.Context, subscri continue } - previousServicePeriod := existingPreviousLine.ServicePeriod() + previousServicePeriod := existingPreviousLine.GetServicePeriod() // The iterator output is already normalized to meter resolution, but this // continuity correction reuses a boundary from persisted state. Historical // rows can carry sub-second precision that the meter engine cannot query, so From 4cc05d5568fbf1f4f05f443fa3e847594a9f1239 Mon Sep 17 00:00:00 2001 From: Peter Turi Date: Fri, 27 Mar 2026 17:20:01 +0100 Subject: [PATCH 2/2] fix: stuff --- .../service/reconciler/prorating.go | 123 ++++++++++++++++ .../service/reconciler/reconciler.go | 7 +- .../service/targetstate/phaseiterator.go | 21 --- .../service/targetstate/targetstate.go | 123 ++-------------- .../service/targetstate/targetstateitem.go | 138 ++++++++++++++++++ 5 files changed, 279 insertions(+), 133 deletions(-) create mode 100644 openmeter/billing/worker/subscriptionsync/service/targetstate/targetstateitem.go diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/prorating.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/prorating.go index e69de29bb2..4645773012 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconciler/prorating.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconciler/prorating.go @@ -0,0 +1,123 @@ +package reconciler + +import ( + "errors" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/charges" + "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/persistedstate" + "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/reconciler/invoiceupdater" + "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/targetstate" +) + +func semanticProrateDecision(existing persistedstate.Entity, targetState targetstate.SubscriptionItemWithPeriods) (bool, error) { + if !existing.IsFlatFee() { + return false, nil + } + + // For the prorate decision we need to understand the underlying entity type, as the proration logic is different: + // - for a line phaseiterator yields the prorated amount based on the service period and the per-unit amount + // - for a charge the charge service will handle the proration logic + switch existing.GetType() { + case persistedstate.EntityTypeLineOrHierarchy: + lineOrHierarchy, err := existing.AsLineOrHierarchy() + if err != nil { + return false, fmt.Errorf("getting line or hierarchy: %w", err) + } + return semanticLineProrateDecision(lineOrHierarchy, targetState) + case persistedstate.EntityTypeCharge: + charge, err := existing.AsCharge() + if err != nil { + return false, fmt.Errorf("getting charge: %w", err) + } + + return semanticChargeProrateDecision(charge, targetState) + default: + return false, fmt.Errorf("unsupported entity type: %s", existing.GetType()) + } +} + +func semanticLineProrateDecision(existing billing.LineOrHierarchy, targetState targetstate.SubscriptionItemWithPeriods) (bool, error) { + gatheringLine, err := targetState.GetExpectedLineOrErr(input.Subscription, input.Currency) + if err != nil { + return false, fmt.Errorf("getting expected line: %w", err) + } + + // expectedLine is materialized through targetstate.LineFromSubscriptionRateCard, which + // applies the existing subscription-sync proration rules when deriving the flat-fee amount. + targetAmount, err := invoiceupdater.GetFlatFeePerUnitAmount(expectedLine) + if err != nil { + return false, fmt.Errorf("getting expected flat fee amount: %w", err) + } + + switch existing.Type() { + case billing.LineOrHierarchyTypeLine: + existingLine, err := existing.AsGenericLine() + if err != nil { + return false, fmt.Errorf("getting existing line: %w", err) + } + + if !invoiceupdater.IsFlatFee(existingLine) { + return false, nil + } + + existingAmount, err := invoiceupdater.GetFlatFeePerUnitAmount(existingLine) + if err != nil { + return false, fmt.Errorf("getting existing flat fee amount: %w", err) + } + + return !existingAmount.Equal(targetAmount) || !existingLine.GetServicePeriod().Equal(expectedLine.ServicePeriod), nil + case billing.LineOrHierarchyTypeHierarchy: + return false, errors.New("flat fee lines cannot be reconciled against a split line hierarchy") + default: + return false, fmt.Errorf("unsupported line or hierarchy type: %s", existing.Type()) + } +} + +func semanticChargeProrateDecision(existing charges.Charge, targetState targetstate.SubscriptionItemWithPeriods) (bool, error) { + if existing.Type() != meta.ChargeTypeFlatFee { + return false, nil + } + + existingFlatFee, err := existing.AsFlatFeeCharge() + if err != nil { + return false, fmt.Errorf("getting existing flat fee charge: %w", err) + } + + // Do not prorate if pro-rating is not enabled. + if !existingFlatFee.Intent.ProRating.Enabled { + return false, nil + } + + price := targetState.SubscriptionItem.RateCard.AsMeta().Price + if price == nil { + return false, fmt.Errorf("price is nil") + } + + priceFlat, err := price.AsFlat() + if err != nil { + return false, fmt.Errorf("getting price flat: %w", err) + } + + // Proration is required if: + // - the service period or full service period has changed + // - the amount before proration has changed + // + // As proration is calculated as lengthOfServicePeriod / lengthOfFullServicePeriod * amountBeforeProration, + + if !existingFlatFee.Intent.ServicePeriod.Equal(targetState.ServicePeriod.ToClosedPeriod()) { + return true, nil + } + + if !existingFlatFee.Intent.FullServicePeriod.Equal(targetState.FullServicePeriod.ToClosedPeriod()) { + return true, nil + } + + if !existingFlatFee.Intent.AmountBeforeProration.Equal(priceFlat.Amount) { + return true, nil + } + + return false, nil +} diff --git a/openmeter/billing/worker/subscriptionsync/service/reconciler/reconciler.go b/openmeter/billing/worker/subscriptionsync/service/reconciler/reconciler.go index 09ee84c405..b771ee7c9c 100644 --- a/openmeter/billing/worker/subscriptionsync/service/reconciler/reconciler.go +++ b/openmeter/billing/worker/subscriptionsync/service/reconciler/reconciler.go @@ -111,7 +111,6 @@ type diffItemResult struct { func (s *Service) diffItem( target *targetstate.SubscriptionItemWithPeriods, - expectedLine *billing.GatheringLine, // TODO[later]: let's merge this with target as they are the same thing's different calculation stages existing persistedstate.Entity, ) (Patch, error) { switch { @@ -119,7 +118,7 @@ func (s *Service) diffItem( return nil, nil case target == nil && existing != nil: return s.NewDeletePatch(existing) - case target != nil && existing == nil && expectedLine != nil: + case target != nil && existing == nil && target.IsBillable(): return s.NewCreatePatch(NewCreatePatchInput{ UniqueID: target.UniqueID, Target: *target, @@ -135,9 +134,9 @@ func (s *Service) diffItem( existingPeriod := existing.GetServicePeriod() targetPeriod := expectedLine.ServicePeriod - if decision, err := semanticProrateDecision(existing, *expectedLine); err != nil { + if shouldProrateDecision, err := semanticProrateDecision(existing, *expectedLine); err != nil { return nil, err - } else if decision.ShouldProrate { + } else if shouldProrateDecision { // Flat fee lines do not produce usage-based shrink/extend patches. Any period // change for a flat fee line is reconciled through ProratePatch so that the // service period and per-unit amount are updated together. diff --git a/openmeter/billing/worker/subscriptionsync/service/targetstate/phaseiterator.go b/openmeter/billing/worker/subscriptionsync/service/targetstate/phaseiterator.go index 1ee13392b4..27341c1dd3 100644 --- a/openmeter/billing/worker/subscriptionsync/service/targetstate/phaseiterator.go +++ b/openmeter/billing/worker/subscriptionsync/service/targetstate/phaseiterator.go @@ -2,7 +2,6 @@ package targetstate import ( "context" - "errors" "fmt" "log/slog" "runtime/debug" @@ -18,7 +17,6 @@ import ( "github.com/openmeterio/openmeter/openmeter/productcatalog" "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/openmeter/subscription" - "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/framework/tracex" "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/slicesx" @@ -97,25 +95,6 @@ func (r SubscriptionItemWithPeriods) GetInvoiceAt() time.Time { return lo.Latest(r.ServicePeriod.End, r.BillingPeriod.End) } -func (r SubscriptionItemWithPeriods) GetExpectedLine(view subscription.Subscription, currency currencyx.Calculator) (*billing.GatheringLine, error) { - return lineFromSubscriptionRateCard(view, r, currency) -} - -var ErrExpectedLineIsEmpty = errors.New("expected line is empty") - -func (r SubscriptionItemWithPeriods) GetExpectedLineOrErr(view subscription.Subscription, currency currencyx.Calculator) (billing.GatheringLine, error) { - line, err := r.GetExpectedLine(view, currency) - if err != nil { - return billing.GatheringLine{}, err - } - - if line == nil { - return billing.GatheringLine{}, fmt.Errorf("%w [child_unique_id: %s]", ErrExpectedLineIsEmpty, r.UniqueID) - } - - return *line, nil -} - func NewPhaseIterator(logger *slog.Logger, tracer trace.Tracer, subs subscription.SubscriptionView, phaseKey string) (*PhaseIterator, error) { phase, ok := subs.GetPhaseByKey(phaseKey) if !ok { diff --git a/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstate.go b/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstate.go index 60fb83592a..0447a27034 100644 --- a/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstate.go +++ b/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstate.go @@ -14,17 +14,15 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/billing/worker/subscriptionsync/service/persistedstate" - "github.com/openmeterio/openmeter/openmeter/productcatalog" "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/openmeter/subscription" - "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/framework/tracex" "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/timeutil" ) type State struct { - Items []SubscriptionItemWithPeriods + Items []StateItem MaxGenerationTimeLimit time.Time } @@ -103,8 +101,19 @@ func (b Builder) Build(ctx context.Context, input BuildInput) (State, error) { return State{}, fmt.Errorf("correcting period start for upcoming lines: %w", err) } + currencyCalculator, err := subs.Subscription.Currency.Calculator() + if err != nil { + return State{}, fmt.Errorf("getting currency calculator: %w", err) + } + return State{ - Items: inScopeLines, + Items: lo.Map(inScopeLines, func(item SubscriptionItemWithPeriods, _ int) StateItem { + return StateItem{ + SubscriptionItemWithPeriods: item, + Subscription: subs.Subscription, + CurrencyCalculator: currencyCalculator, + } + }), MaxGenerationTimeLimit: upcomingLinesResult.SubscriptionMaxGenerationTimeLimit, }, nil }) @@ -191,7 +200,7 @@ func (b Builder) correctPeriodStartForUpcomingLines(ctx context.Context, subscri continue } - if existingCurrentLine, ok := persisted.ByUniqueID.GetAsLineOrHierarchy(line.UniqueID); ok { + if existingCurrentLine, ok := persisted.ByUniqueID[line.UniqueID]; ok { syncIgnore, err := b.lineOrHierarchyHasAnnotation(existingCurrentLine, billing.AnnotationSubscriptionSyncIgnore) if err != nil { return nil, fmt.Errorf("checking if line has subscription sync ignore annotation: %w", err) @@ -233,7 +242,7 @@ func (b Builder) correctPeriodStartForUpcomingLines(ctx context.Context, subscri continue } - previousServicePeriod := existingPreviousLine.GetServicePeriod() + previousServicePeriod := existingPreviousLine.ServicePeriod() // The iterator output is already normalized to meter resolution, but this // continuity correction reuses a boundary from persisted state. Historical // rows can carry sub-second precision that the meter engine cannot query, so @@ -300,105 +309,3 @@ func (b Builder) hierarchyHasAnnotation(hierarchy *billing.SplitLineHierarchy, a return false, nil } - -// TODO: make a member of the SubscriptionItemWithPeriods type (for now it's kept here for easier review) -func lineFromSubscriptionRateCard(subs subscription.Subscription, item SubscriptionItemWithPeriods, currency currencyx.Calculator) (*billing.GatheringLine, error) { - line := billing.GatheringLine{ - GatheringLineBase: billing.GatheringLineBase{ - ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ - Namespace: subs.Namespace, - Name: item.Spec.RateCard.AsMeta().Name, - Description: item.Spec.RateCard.AsMeta().Description, - }), - ManagedBy: billing.SubscriptionManagedLine, - Currency: subs.Currency, - ChildUniqueReferenceID: &item.UniqueID, - TaxConfig: item.Spec.RateCard.AsMeta().TaxConfig, - ServicePeriod: item.ServicePeriod.ToClosedPeriod(), - InvoiceAt: item.GetInvoiceAt(), - RateCardDiscounts: discountsToBillingDiscounts(item.Spec.RateCard.AsMeta().Discounts), - Subscription: &billing.SubscriptionReference{ - SubscriptionID: subs.ID, - PhaseID: item.PhaseID, - ItemID: item.SubscriptionItem.ID, - BillingPeriod: timeutil.ClosedPeriod{ - From: item.BillingPeriod.Start, - To: item.BillingPeriod.End, - }, - }, - }, - } - - if price := item.SubscriptionItem.RateCard.AsMeta().Price; price != nil && price.GetPaymentTerm() == productcatalog.InArrearsPaymentTerm { - if item.FullServicePeriod.Duration() == time.Duration(0) { - return nil, nil - } - } - - switch item.SubscriptionItem.RateCard.AsMeta().Price.Type() { - case productcatalog.FlatPriceType: - price, err := item.SubscriptionItem.RateCard.AsMeta().Price.AsFlat() - if err != nil { - return nil, fmt.Errorf("converting price to flat: %w", err) - } - - perUnitAmount := currency.RoundToPrecision(price.Amount) - if !item.ServicePeriod.IsEmpty() && shouldProrate(item, subs) { - perUnitAmount = currency.RoundToPrecision(price.Amount.Mul(item.PeriodPercentage())) - } - - if perUnitAmount.IsZero() { - return nil, nil - } - - line.Price = lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ - Amount: perUnitAmount, - PaymentTerm: price.PaymentTerm, - })) - line.FeatureKey = lo.FromPtr(item.SubscriptionItem.RateCard.AsMeta().FeatureKey) - default: - if item.SubscriptionItem.RateCard.AsMeta().Price == nil { - return nil, fmt.Errorf("price must be defined for usage based price") - } - - line.Price = lo.FromPtr(item.SubscriptionItem.RateCard.AsMeta().Price) - line.FeatureKey = lo.FromPtr(item.SubscriptionItem.RateCard.AsMeta().FeatureKey) - } - - return &line, nil -} - -func discountsToBillingDiscounts(discounts productcatalog.Discounts) billing.Discounts { - out := billing.Discounts{} - - if discounts.Usage != nil { - out.Usage = &billing.UsageDiscount{UsageDiscount: *discounts.Usage} - } - - if discounts.Percentage != nil { - out.Percentage = &billing.PercentageDiscount{PercentageDiscount: *discounts.Percentage} - } - - return out -} - -func shouldProrate(item SubscriptionItemWithPeriods, subs subscription.Subscription) bool { - if !subs.ProRatingConfig.Enabled { - return false - } - - if item.Spec.RateCard.AsMeta().Price.Type() != productcatalog.FlatPriceType { - return false - } - - if subs.ActiveTo != nil && !subs.ActiveTo.After(item.ServicePeriod.End) { - return false - } - - switch subs.ProRatingConfig.Mode { - case productcatalog.ProRatingModeProratePrices: - return true - default: - return false - } -} diff --git a/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstateitem.go b/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstateitem.go new file mode 100644 index 0000000000..2ae92d3a6d --- /dev/null +++ b/openmeter/billing/worker/subscriptionsync/service/targetstate/targetstateitem.go @@ -0,0 +1,138 @@ +package targetstate + +import ( + "errors" + "fmt" + "time" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/openmeter/subscription" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/models" + "github.com/openmeterio/openmeter/pkg/timeutil" + "github.com/samber/lo" +) + +type StateItem struct { + SubscriptionItemWithPeriods + + Subscription subscription.Subscription + CurrencyCalculator currencyx.Calculator +} + +var ErrExpectedLineIsEmpty = errors.New("expected line is empty") + +func (r StateItem) GetExpectedLineOrErr() (billing.GatheringLine, error) { + line, err := r.GetExpectedLine() + if err != nil { + return billing.GatheringLine{}, err + } + + if line == nil { + return billing.GatheringLine{}, fmt.Errorf("%w [child_unique_id: %s]", ErrExpectedLineIsEmpty, r.UniqueID) + } + + return *line, nil +} + +func (r StateItem) GetExpectedLine() (*billing.GatheringLine, error) { + line := billing.GatheringLine{ + GatheringLineBase: billing.GatheringLineBase{ + ManagedResource: models.NewManagedResource(models.ManagedResourceInput{ + Namespace: r.Subscription.Namespace, + Name: r.Spec.RateCard.AsMeta().Name, + Description: r.Spec.RateCard.AsMeta().Description, + }), + ManagedBy: billing.SubscriptionManagedLine, + Currency: r.Subscription.Currency, + ChildUniqueReferenceID: &r.UniqueID, + TaxConfig: r.Spec.RateCard.AsMeta().TaxConfig, + ServicePeriod: r.ServicePeriod.ToClosedPeriod(), + InvoiceAt: r.GetInvoiceAt(), + RateCardDiscounts: discountsToBillingDiscounts(r.Spec.RateCard.AsMeta().Discounts), + Subscription: &billing.SubscriptionReference{ + SubscriptionID: r.Subscription.ID, + PhaseID: r.PhaseID, + ItemID: r.SubscriptionItem.ID, + BillingPeriod: timeutil.ClosedPeriod{ + From: r.BillingPeriod.Start, + To: r.BillingPeriod.End, + }, + }, + }, + } + + if price := r.Spec.RateCard.AsMeta().Price; price != nil && price.GetPaymentTerm() == productcatalog.InArrearsPaymentTerm { + if r.FullServicePeriod.Duration() == time.Duration(0) { + return nil, nil + } + } + + switch r.Spec.RateCard.AsMeta().Price.Type() { + case productcatalog.FlatPriceType: + price, err := r.Spec.RateCard.AsMeta().Price.AsFlat() + if err != nil { + return nil, fmt.Errorf("converting price to flat: %w", err) + } + + perUnitAmount := r.CurrencyCalculator.RoundToPrecision(price.Amount) + if !r.ServicePeriod.IsEmpty() && shouldProrate(r, r.Subscription) { + perUnitAmount = r.CurrencyCalculator.RoundToPrecision(price.Amount.Mul(r.PeriodPercentage())) + } + + if perUnitAmount.IsZero() { + return nil, nil + } + + line.Price = lo.FromPtr(productcatalog.NewPriceFrom(productcatalog.FlatPrice{ + Amount: perUnitAmount, + PaymentTerm: price.PaymentTerm, + })) + line.FeatureKey = lo.FromPtr(r.Spec.RateCard.AsMeta().FeatureKey) + default: + if r.Spec.RateCard.AsMeta().Price == nil { + return nil, fmt.Errorf("price must be defined for usage based price") + } + + line.Price = lo.FromPtr(r.Spec.RateCard.AsMeta().Price) + line.FeatureKey = lo.FromPtr(r.Spec.RateCard.AsMeta().FeatureKey) + } + + return &line, nil +} + +func discountsToBillingDiscounts(discounts productcatalog.Discounts) billing.Discounts { + out := billing.Discounts{} + + if discounts.Usage != nil { + out.Usage = &billing.UsageDiscount{UsageDiscount: *discounts.Usage} + } + + if discounts.Percentage != nil { + out.Percentage = &billing.PercentageDiscount{PercentageDiscount: *discounts.Percentage} + } + + return out +} + +func shouldProrate(item StateItem, subs subscription.Subscription) bool { + if !subs.ProRatingConfig.Enabled { + return false + } + + if item.Spec.RateCard.AsMeta().Price.Type() != productcatalog.FlatPriceType { + return false + } + + if subs.ActiveTo != nil && !subs.ActiveTo.After(item.ServicePeriod.End) { + return false + } + + switch subs.ProRatingConfig.Mode { + case productcatalog.ProRatingModeProratePrices: + return true + default: + return false + } +}